# Этап 3: INT8 Quantization — Полный сравнительный анализ

Мы сравниваем две модели по трем ключевым бизнес-метрикам:
1. **Memory (Память)**: Физический объем занимаемого места.
2. **Throughput (Пропускная способность)**: Сколько символов генерируется в секунду.
3. **Latency (Задержка)**: Время до появления самого первого символа (Time to First Token).

In [None]:
import torch
import copy
import time
import pandas as pd
import numpy as np
from src.model import GPTLanguageModel, device, get_batch, estimate_loss, decode, encode

def get_model_size_mb(mdl, real_int8=False):
    """Return the storage footprint of ``mdl`` in mebibytes.

    Args:
        mdl: Module exposing ``parameters()``, ``named_parameters()`` and
            ``buffers()`` (e.g. a ``torch.nn.Module``).
        real_int8: When False, report the size of the tensors as they are
            currently stored. When True, report an *estimate* of the size the
            model would have if every parameter whose name contains
            ``'weight'`` and that has more than one dimension were stored at
            1 byte per element. Per-tensor scale factors are not counted.

    Returns:
        Size in MB (bytes / 1024**2) as a float.
    """
    # Buffers are never quantized, so they are counted at their real dtype
    # width in both modes (not at an assumed 32 bits).
    buffer_bytes = sum(b.nelement() * b.element_size() for b in mdl.buffers())

    if not real_int8:
        param_bytes = sum(p.nelement() * p.element_size() for p in mdl.parameters())
    else:
        param_bytes = 0
        for name, param in mdl.named_parameters():
            # Same selection rule the quantization loop uses: matrix-shaped
            # weights become INT8, everything else keeps its own dtype width.
            is_quantized = 'weight' in name and param.dim() > 1
            bytes_per_element = 1 if is_quantized else param.element_size()
            param_bytes += param.nelement() * bytes_per_element

    return (param_bytes + buffer_bytes) / 1024**2

@torch.no_grad()
def measure_performance(mdl, num_tokens=50):
    """Measure single-token latency and generation throughput of ``mdl``.

    Args:
        mdl: Model exposing ``generate(context, max_new_tokens=...)``.
        num_tokens: Number of tokens requested in the throughput run.

    Returns:
        ``(latency, throughput)`` where ``latency`` is the wall time in
        milliseconds of a ``generate`` call with ``max_new_tokens=1`` and
        ``throughput`` is ``num_tokens`` divided by the wall time in seconds
        of a ``generate`` call with ``max_new_tokens=num_tokens``.
    """
    context = torch.zeros((1, 1), dtype=torch.long, device=device)

    def _sync():
        # CUDA kernels are launched asynchronously; without a barrier the
        # clock may be read before the queued work has actually finished.
        if context.is_cuda:
            torch.cuda.synchronize(context.device)

    def _timed_generate(n_new):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        _sync()
        started = time.perf_counter()
        _ = mdl.generate(context, max_new_tokens=n_new)
        _sync()
        return time.perf_counter() - started

    # Warmup so one-time initialization cost is not attributed to the
    # measurements below.
    _ = mdl.generate(context, max_new_tokens=5)

    # 1. Latency (TTFT)
    latency = _timed_generate(1) * 1000

    # 2. Throughput
    duration = _timed_generate(num_tokens)
    throughput = num_tokens / duration

    return latency, throughput

def quantize_tensor_int8(x):
    """Round ``x`` onto a symmetric per-tensor INT8 grid and scale it back.

    The result is still a floating-point tensor: values are divided by
    ``scale = max(|x|) / 127``, rounded, clamped to ``[-128, 127]`` and then
    multiplied by ``scale`` again, so only the rounding error of 8-bit
    storage is simulated.

    Args:
        x: Tensor to quantize.

    Returns:
        The quantize-dequantized tensor. ``x`` itself is returned unchanged
        when it is empty or contains only zeros.
    """
    # max() of an empty tensor raises; there is nothing to quantize anyway.
    if x.numel() == 0:
        return x
    x_max = x.abs().max().item()
    # All-zero tensor: a zero scale would divide by zero below.
    if x_max == 0:
        return x
    scale = x_max / 127.0
    return torch.round(x / scale).clamp(-128, 127) * scale

### 1. Замер исходной модели (Baseline FP32)
Здесь мы фиксируем «золотой стандарт» качества и скорости.

In [None]:
# Build the FP32 baseline and try to restore trained weights from disk.
model_fp32 = GPTLanguageModel().to(device)
try:
    model_fp32.load_state_dict(torch.load('model_ckpt.pt', map_location=device))
except FileNotFoundError:
    # Only a missing checkpoint file is tolerated (the message below says
    # "checkpoint not found"). A corrupt file or a state dict that does not
    # match the model must surface as an error instead of silently
    # benchmarking random or partially loaded weights.
    print("‚ö†Ô∏è –ß–µ–∫–ø–æ–∏–Ω—Ç –Ω–µ –Ω–∞–π–¥–µ–Ω, –∑–∞–º–µ—Ä—ã –±—É–¥—É—Ç –Ω–∞ —Å–ª—É—á–∞–π–Ω—ã—Ö –≤–µ—Å–∞—Ö.")
else:
    print("‚úÖ –£—Å–ø–µ—à–Ω–æ –∑–∞–≥—Ä—É–∂–µ–Ω—ã –≤–µ—Å–∞ –º–æ–¥–µ–ª–∏.")
model_fp32.eval()

# Collect the baseline metrics: validation loss, size, latency, throughput.
print("\n--- [STEP 1] Baseline FP32 Performance ---")
loss_fp32 = estimate_loss(model_fp32)['val'].item()
size_fp32 = get_model_size_mb(model_fp32)
lat_fp32, thr_fp32 = measure_performance(model_fp32)

print(f"üîπ Memory Usage:     {size_fp32:.2f} MB")
print(f"üîπ Inference Latency: {lat_fp32:.2f} ms (Time to First Token)")
print(f"üîπ Throughput Rate:   {thr_fp32:.2f} tokens/sec")
print(f"üîπ Model Quality:     {loss_fp32:.4f} (Val Loss)")

### 2. Квантование и замер INT8
Теперь мы «портим» веса округлением до 8 бит и смотрим, как изменятся те же метрики.

In [None]:
# Work on an independent copy so the FP32 baseline stays untouched.
model_int8 = copy.deepcopy(model_fp32)
with torch.no_grad():
    for param_name, tensor in model_int8.named_parameters():
        # Only matrix-shaped parameters whose name contains 'weight' are
        # rounded; everything else is left as it is.
        if tensor.dim() <= 1 or 'weight' not in param_name:
            continue
        tensor.copy_(quantize_tensor_int8(tensor.data))

# Same measurements, in the same order, as for the baseline model.
print("--- [STEP 2] Quantized INT8 Performance ---")
loss_int8 = estimate_loss(model_int8)['val'].item()
size_int8 = get_model_size_mb(model_int8, real_int8=True)
lat_int8, thr_int8 = measure_performance(model_int8)

int8_report = (
    f"üî∏ Memory Usage:     {size_int8:.2f} MB (Estimated storage size)",
    f"üî∏ Inference Latency: {lat_int8:.2f} ms",
    f"üî∏ Throughput Rate:   {thr_int8:.2f} tokens/sec",
    f"üî∏ Model Quality:     {loss_int8:.4f} (Val Loss)",
)
print("\n".join(int8_report))

### 3. Анализ эффективности
Сравним выигрыш в ресурсах против потери качества.

In [None]:
# Derived comparison figures, named once and reused below.
memory_ratio = size_fp32 / size_int8
throughput_change_pct = (thr_int8 / thr_fp32 - 1) * 100
latency_change_ms = lat_int8 - lat_fp32
loss_change_pct = (loss_int8 / loss_fp32 - 1) * 100

# One row per metric; the "Delta" column holds preformatted strings.
results = {
    "Metric": ["Memory (MB)", "Throughput (tokens/s)", "Latency (ms)", "Validation Loss"],
    "FP32": [size_fp32, thr_fp32, lat_fp32, loss_fp32],
    "INT8": [size_int8, thr_int8, lat_int8, loss_int8],
    "Delta": [
        f"{memory_ratio:.1f}x smaller",
        f"{throughput_change_pct:+.1f}% check",
        f"{latency_change_ms:+.2f} ms",
        f"{loss_change_pct:+.2f}% quality loss",
    ],
}

df = pd.DataFrame(results)
display(df)

# Headline summary: memory saved and relative change in validation loss.
print(f"\nüöÄ –ò–¢–û–ì: –í—ã –æ—Å–≤–æ–±–æ–¥–∏–ª–∏ {size_fp32 - size_int8:.2f} MB –ø–∞–º—è—Ç–∏!")
print(f"üìâ –ü–æ—Ç–µ—Ä—è –∫–∞—á–µ—Å—Ç–≤–∞ —Å–æ—Å—Ç–∞–≤–∏–ª–∞ –≤—Å–µ–≥–æ {loss_change_pct:.2f}%")

### 4. Генерация текста (Blind Test)
Напишем промпт и посмотрим разницу в стиле. 
*Если вы видите NameError, убедитесь, что выполнили самую первую ячейку кода.*

In [None]:
from src.model import decode, encode

# Encode the prompt and add a leading batch dimension of size 1.
prompt = "ROMEO: "
context = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)

# Text generation is inference only: disable gradient tracking, as
# measure_performance does, so no autograd state is kept while sampling.
with torch.no_grad():
    print("--- FP32 OUTPUT ---")
    print(decode(model_fp32.generate(context, max_new_tokens=100)[0].tolist()))

    print("\n--- INT8 OUTPUT ---")
    print(decode(model_int8.generate(context, max_new_tokens=100)[0].tolist()))