# Description

In this notebook, I test GPT-2 inference before and after int8 quantization of its linear layers, compare generation time, and measure perplexity.

In [None]:
import os
import time
from pathlib import Path
import zipfile
import math
from datasets import load_dataset

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import tiktoken
import gemm_int8

In [None]:
def quantized_column_matrix_int_symmetric(mat: torch.Tensor):
    """
    Symmetric quantization to int8 on a per-column basis.

    Each column is divided by its own scale ``max(|column|) / 127`` and rounded,
    so ``q_mat * scales`` approximates ``mat``.

    mat: (N, M) float tensor (e.g., torch.float32 or torch.bfloat16)
    Returns:
      q_mat: (N, M) int8
      scales: (M,) float32, always strictly positive
    """
    qmin, qmax = -128, 127

    # Compute in float32: the zero guard below (1e-8) is not representable in
    # float16, and the returned scales are float32 anyway.
    mat_f32 = mat.to(torch.float32)

    max_vals = mat_f32.abs().amax(dim=0, keepdim=True)  # shape (1, M)
    # An all-zero column would otherwise yield scale 0 and a 0/0 division.
    max_vals = max_vals.clamp(min=1e-8)

    scales = max_vals / qmax  # shape (1, M), broadcasts over the rows
    q_mat = torch.clamp(torch.round(mat_f32 / scales), qmin, qmax).to(torch.int8)  # shape (N, M)

    return q_mat, scales.squeeze(0).detach()

def quantize_row_int8_symmetric(mat: torch.Tensor):
    """
    Symmetric int8 quantization per row.

    Each row is divided by its own scale ``max(|row|) / 127`` and rounded,
    so ``q_mat * scales[:, None]`` approximates ``mat``.

    mat: (N, M) float tensor
    Returns:
      q_mat: (N, M) int8
      scales: (N,) float32, always strictly positive
    """
    qmin, qmax = -128, 127

    # Compute in float32: clamping in the input dtype loses the guard for
    # float16 inputs, where 1e-8 underflows to 0 and an all-zero row would
    # still end up with a zero scale.
    mat_f32 = mat.to(torch.float32)

    max_vals = mat_f32.abs().amax(dim=1, keepdim=True)  # (N, 1)
    max_vals = max_vals.clamp(min=1e-8)                 # avoid div-by-zero

    scales = max_vals / qmax                            # (N, 1), broadcasts over columns
    q_mat = torch.clamp(torch.round(mat_f32 / scales), qmin, qmax).to(torch.int8)

    return q_mat, scales.squeeze(1)


def quantize_row_matrix_int8_symmetric_batched(mat: torch.Tensor):
    """
    Symmetric per-row quantization for batched 3D tensor.
    mat: [B, N, D]  (float tensor)

    Returns:
        q_mat:   [B, N, D] int8
        scales:  [B, N]    float32  (scale per row within each batch,
                                     always strictly positive)
    """
    qmin, qmax = -128, 127

    # Compute in float32: the 1e-12 guard below underflows to 0 in float16,
    # which would leave all-zero rows with a zero scale.
    mat_f32 = mat.to(torch.float32)

    # Max abs per row (per batch) - shape [B, N, 1]
    max_vals = mat_f32.abs().amax(dim=2, keepdim=True)

    # Per-row scales, guarded against div-by-zero - shape [B, N, 1]
    scales = (max_vals / qmax).clamp(min=1e-12)

    # Quantize; the trailing singleton dim of `scales` broadcasts over D
    q_mat = torch.clamp(torch.round(mat_f32 / scales), qmin, qmax).to(torch.int8)

    # Return float scales of shape [B, N]
    return q_mat, scales.squeeze(2)

In [None]:
class CustomLinear(nn.Module):
    """Linear layer ``y = x @ W^T + b`` with a switchable int8 forward path.

    The layer starts in float mode. Calling `quantize_weights()` fills the
    int8 weight buffer and its per-row scales and flips `is_quantized`; from
    then on `forward` quantizes the activations, multiplies in int8 through
    `gemm_int8.bmm_int8_matmul`, and rescales the result.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features)) if bias else None

        # Quantized weights and scales live in non-persistent buffers, so they
        # follow .to(device) but are not part of the state_dict.
        self.register_buffer(
            "weight_q",
            torch.empty_like(self.weight, dtype=torch.int8),
            persistent=False,
        )
        self.register_buffer(
            "weight_scale",
            torch.empty(out_features, dtype=torch.float32),
            persistent=False,
        )
        self.is_quantized = False

    @torch.no_grad()
    def quantize_weights(self):
        """Quantize `self.weight` per output row and enable the int8 path."""
        q_weight, q_scale = quantize_row_int8_symmetric(self.weight)
        self.weight_q.copy_(q_weight)
        self.weight_scale.copy_(q_scale)
        print(f"[INFO] Done quantize linear layer weights - shape {self.weight_q.shape}")
        self.is_quantized = True

    def forward(self, x):
        """Apply the layer to `x` using the float or the int8 path."""
        if self.is_quantized:
            # The batched quantizer documents a [B, N, D] input and [B, N] scales.
            x_q, x_scale = quantize_row_matrix_int8_symmetric_batched(x)
            # NOTE(review): presumably returns the integer product x_q @ weight_q^T
            # with shape [B, N, out_features] - confirm against gemm_int8.
            acc = gemm_int8.bmm_int8_matmul(x_q, self.weight_q)
            # Undo both quantizations: per-row activation scale and
            # per-output-channel weight scale.
            out = x_scale.unsqueeze(-1) * acc * self.weight_scale[None, :]
            out = out.to(x.dtype)
        else:
            out = torch.matmul(x, self.weight.t())

        if self.bias is not None:
            out = out + self.bias
        return out


In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Query, key and value projections are `CustomLinear` layers; the output
    projection is a regular `nn.Linear`. A strictly upper-triangular mask of
    size (context_length, context_length) blocks attention to later tokens.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head width; heads together span d_out

        self.W_query = CustomLinear(d_in, d_out, bias=qkv_bias)
        self.W_key = CustomLinear(d_in, d_out, bias=qkv_bias)
        self.W_value = CustomLinear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)

        # Ones above the diagonal mark the positions a token must not attend to.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        """Map `x` of shape (b, num_tokens, d_in) to (b, num_tokens, d_out)."""
        batch, seq_len, _ = x.shape
        split_shape = (batch, seq_len, self.num_heads, self.head_dim)

        # Project, split the last dim into heads, then move heads ahead of tokens:
        # (b, T, d_out) -> (b, T, H, head_dim) -> (b, H, T, head_dim)
        k = self.W_key(x).view(split_shape).transpose(1, 2)
        q = self.W_query(x).view(split_shape).transpose(1, 2)
        v = self.W_value(x).view(split_shape).transpose(1, 2)

        # Raw dot-product scores per head: (b, H, T, T)
        scores = torch.matmul(q, k.transpose(2, 3))

        # Hide future positions before the softmax.
        causal = self.mask[:seq_len, :seq_len].bool()
        scores.masked_fill_(causal, -torch.inf)

        weights = torch.softmax(scores / self.head_dim**0.5, dim=-1)
        weights = self.dropout(weights)

        # (b, H, T, head_dim) -> (b, T, H, head_dim) -> (b, T, d_out)
        merged = torch.matmul(weights, v).transpose(1, 2)
        merged = merged.reshape(batch, seq_len, self.d_out)

        return self.out_proj(merged)


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift."""

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5  # added to the variance before the square root
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize `x` along its last dim, then apply scale and shift."""
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance.
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        x_hat = (x - mu) / (sigma_sq + self.eps).sqrt()
        return self.scale * x_hat + self.shift


class GELU(nn.Module):
    """GELU activation using the tanh approximation."""

    def forward(self, x):
        """Return 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
        sqrt_two_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
        cubic = x + 0.044715 * torch.pow(x, 3)
        return 0.5 * x * (1 + torch.tanh(sqrt_two_over_pi * cubic))

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the embedding width, GELU, project back."""

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = 4 * emb_dim
        self.layers = nn.Sequential(
            CustomLinear(emb_dim, hidden_dim),
            GELU(),
            CustomLinear(hidden_dim, emb_dim),
        )

    def forward(self, x):
        """Run `x` through the expand / activate / project stack."""
        return self.layers(x)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        """Apply both sub-layers to `x` of shape [batch_size, num_tokens, emb_size]."""
        # Attention sub-layer: normalize, attend, dropout, add the input back.
        attn_branch = self.drop_resid(self.att(self.norm1(x)))
        x = attn_branch + x

        # Feed-forward sub-layer, same normalize / transform / dropout / add pattern.
        ff_branch = self.drop_resid(self.ff(self.norm2(x)))
        return ff_branch + x

class GPTModel(nn.Module):
    """GPT-style decoder: token + position embeddings, transformer blocks, LM head."""

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]

        self.tok_emb = nn.Embedding(cfg["vocab_size"], emb_dim)
        self.pos_emb = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)

        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        """Return logits of shape [batch_size, num_tokens, vocab_size] for token ids `in_idx`."""
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)

        # Position embeddings broadcast over the batch dimension.
        hidden = self.tok_emb(in_idx) + self.pos_emb(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)
        return self.out_head(hidden)


def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedily extend `idx` by `max_new_tokens` tokens.

    model: callable mapping (B, T) token ids to (B, T, vocab_size) logits
    idx: (B, T) tensor of token ids forming the initial context
    max_new_tokens: number of tokens to append
    context_size: only the last `context_size` tokens are fed to the model

    Returns the (B, T + max_new_tokens) tensor of prompt plus generated ids.
    """
    sequence = idx
    for _ in range(max_new_tokens):
        # Feed only the most recent tokens if the sequence outgrew the context.
        window = sequence[:, -context_size:]

        with torch.no_grad():
            logits = model(window)

        # The prediction at the last position is the next token; take the
        # highest-scoring vocab entry, keeping a (B, 1) shape for concatenation.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        sequence = torch.cat((sequence, next_token), dim=1)

    return sequence


In [None]:
# Settings shared by every GPT-2 size; the per-size entries below override
# the embedding width, depth and head count.
GPT_CONFIG_BASE = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 1024, # Original context length
    "emb_dim": 768,         # Embedding dimension
    "n_heads": 12,          # Number of attention heads
    "n_layers": 12,         # Number of layers
    "drop_rate": 0.0,       # Dropout rate
    "qkv_bias": True        # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

# NOTE: PATH_MODEL and model_name must describe the same checkpoint; change
# them together when switching model size.
PATH_MODEL = "/scratch/tnguyen10/gpt2-xl-1558M.pth"
# PATH_MODEL = "/scratch/tnguyen10/gpt2-medium-355M.pth"
model_name = "gpt2-xl (1558M)"
# model_name = "gpt2-medium (355M)"


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

NEW_CONFIG = GPT_CONFIG_BASE.copy()
NEW_CONFIG.update(model_configs[model_name])

model = GPTModel(NEW_CONFIG).to(DEVICE).eval()
# model = torch.compile(model, dynamic=True)

# map_location remaps the checkpoint tensors onto DEVICE, so a checkpoint
# saved from a GPU still loads on a CPU-only machine.
state_dict = torch.load(PATH_MODEL, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)

In [None]:
# Load tiktoken's "gpt2" encoding and report how many tokens it defines.
tokenizer = tiktoken.get_encoding("gpt2")
print(f"Tokenizer vocab size: {tokenizer.n_vocab}")

In [None]:
# Greedy generation from a single prompt with the (not yet quantized) model.
prompt = "What is the capital of France?"
enc_prompt = tokenizer.encode(prompt)
# Wrap in a list to get shape (1, num_prompt_tokens), and place the ids on
# DEVICE (the device the model was moved to) instead of a hard-coded "cuda",
# which fails on CPU-only machines.
enc_prompt = torch.tensor([enc_prompt], device=DEVICE)

token_ids = generate_text_simple(
    model=model,
    idx=enc_prompt,
    max_new_tokens=256,
    context_size=NEW_CONFIG["context_length"]
)

output = tokenizer.decode(token_ids.squeeze().tolist())
print(output)

In [None]:
# Time greedy generation: 2 warm-up runs, then the average over n_iter runs.
prompt = "What is the King of England?"
enc_prompt = tokenizer.encode(prompt)
# Shape (1, num_prompt_tokens), on the same device as the model.
enc_prompt = torch.tensor([enc_prompt], device=DEVICE)

# Measure inference time
# Warm-up
for _ in range(2):
    token_ids = generate_text_simple(
        model=model,
        idx=enc_prompt,
        max_new_tokens=256,
        context_size=NEW_CONFIG["context_length"]
    )

n_iter = 5
# CUDA kernels run asynchronously with respect to the host, so wait for all
# pending work before starting and before stopping the clock; otherwise the
# timer can stop while the GPU is still computing.
if DEVICE == "cuda":
    torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(n_iter):
    token_ids = generate_text_simple(
        model=model,
        idx=enc_prompt,
        max_new_tokens=256,
        context_size=NEW_CONFIG["context_length"]
    )
if DEVICE == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
avg_time = (end_time - start_time) / n_iter
print(f"Average inference time over {n_iter} runs: {avg_time:.4f} seconds")
print('-'*10)

output = tokenizer.decode(token_ids.squeeze().tolist())
print(output)

# 2. Quantize model

In [None]:
# Walk every sub-module and switch each CustomLinear to its int8 path.
for name, module in model.named_modules():
    if not isinstance(module, CustomLinear):
        continue
    print(f"\nQuantize weights module: {name} ...")
    module.quantize_weights()
    print(f"Quantizate done : {name}.")

In [None]:
# Time greedy generation with the quantized model: 2 warm-up runs, then the
# average over n_iter runs.
prompt = "What is the capital of France?"
enc_prompt = tokenizer.encode(prompt)
# Shape (1, num_prompt_tokens), on the same device as the model.
enc_prompt = torch.tensor([enc_prompt], device=DEVICE)

# Measure inference time after quantization
# Warm-up
for _ in range(2):
    token_ids = generate_text_simple(
        model=model,
        idx=enc_prompt,
        max_new_tokens=256,
        context_size=NEW_CONFIG["context_length"]
    )

n_iter = 5
# CUDA kernels run asynchronously with respect to the host, so wait for all
# pending work before starting and before stopping the clock.
if DEVICE == "cuda":
    torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(n_iter):
    token_ids = generate_text_simple(
        model=model,
        idx=enc_prompt,
        # Same generation length as the warm-up and the other timing cells
        # (the timed runs previously used 100), so the averages are comparable.
        max_new_tokens=256,
        context_size=NEW_CONFIG["context_length"]
    )
if DEVICE == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
avg_time = (end_time - start_time) / n_iter
print(f"Average inference time after quantization over {n_iter} runs: {avg_time:.4f} seconds")
print('-'*10)

output = tokenizer.decode(token_ids.squeeze().tolist())
print(output)

In [None]:
# Time greedy generation: 2 warm-up runs, then the average over n_iter runs.
prompt = "What is the King of England?"
enc_prompt = tokenizer.encode(prompt)
# Shape (1, num_prompt_tokens), on the same device as the model.
enc_prompt = torch.tensor([enc_prompt], device=DEVICE)

# Measure inference time
# Warm-up
for _ in range(2):
    token_ids = generate_text_simple(
        model=model,
        idx=enc_prompt,
        max_new_tokens=256,
        context_size=NEW_CONFIG["context_length"]
    )

n_iter = 5
# CUDA kernels run asynchronously with respect to the host, so wait for all
# pending work before starting and before stopping the clock; otherwise the
# timer can stop while the GPU is still computing.
if DEVICE == "cuda":
    torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(n_iter):
    token_ids = generate_text_simple(
        model=model,
        idx=enc_prompt,
        max_new_tokens=256,
        context_size=NEW_CONFIG["context_length"]
    )
if DEVICE == "cuda":
    torch.cuda.synchronize()
end_time = time.perf_counter()
avg_time = (end_time - start_time) / n_iter
print(f"Average inference time over {n_iter} runs: {avg_time:.4f} seconds")
print('-'*10)

output = tokenizer.decode(token_ids.squeeze().tolist())
print(output)

# 3. Measure PPL

In [None]:
EOS_ID = 50256  # gpt2's end token (not strictly needed here)


def _score_token_ids(model, ids_t, context_size, stride):
    """Return (nll_sum, num_scored) for one 1-D token tensor of length >= 2.

    Windows of at most `context_size` input tokens start every `stride`
    tokens. Each window is scored with one forward pass, and only the
    predictions not covered by the previous window are counted, so every
    token after the first is scored exactly once.
    """
    n_inputs = ids_t.size(0) - 1  # inputs 0..L-2 predict tokens 1..L-1
    nll_sum = 0.0
    scored_end = 0                # inputs [0, scored_end) are already scored

    for begin in range(0, n_inputs, stride):
        end = min(begin + context_size, n_inputs)
        logits = model(ids_t[begin:end].unsqueeze(0))[0]   # [w, V]

        # Only the trailing n_new positions are new in this window.
        n_new = end - scored_end
        targets = ids_t[end - n_new + 1:end + 1]           # [n_new]
        loss = F.cross_entropy(logits[-n_new:].float(), targets, reduction="sum")
        nll_sum += float(loss.item())

        scored_end = end
        if end == n_inputs:
            break

    return nll_sum, scored_end


@torch.no_grad()
def compute_ppl(model, tokenizer, texts, context_size, device=None, stride=None):
    """
    Perplexity of `model` on each text and on the whole corpus.

    Every text is tokenized and scored with sliding windows of at most
    `context_size` tokens. One forward pass scores all new tokens of a
    window, and each token after the first is counted exactly once. A text
    that fits in a single window gives every token its full left context.
    Longer texts give each scored token at least `context_size - stride`
    tokens of context.

    model: module mapping [1, w] token ids to [1, w, V] logits
    tokenizer: object whose `encode(text)` returns a list of token ids
    texts: iterable of strings
    context_size: maximum number of tokens fed to the model at once (>= 1)
    device: where the token ids are placed; defaults to the device of the
        model's first parameter (CPU if the model has no parameters)
    stride: step between window starts, clipped to [1, context_size];
        defaults to context_size // 2 (smaller is slower but gives more context)

    Returns:
      results: one dict per text with "num_tokens", "nll_sum" and "ppl"
               (texts with fewer than 2 tokens get 0 tokens and ppl NaN)
      corpus_ppl: exp(total NLL / total scored tokens), NaN if nothing was scored

    Raises:
      ValueError: if context_size < 1
    """
    if context_size < 1:
        raise ValueError("context_size must be >= 1")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    if stride is None:
        stride = max(1, context_size // 2)
    stride = max(1, min(stride, context_size))

    model_was_training = model.training
    model.eval()

    results = []
    try:
        for txt in texts:
            ids = tokenizer.encode(txt)
            if len(ids) < 2:
                # Nothing to predict: a single token has no target.
                results.append({"num_tokens": 0, "nll_sum": 0.0, "ppl": float("nan")})
                continue

            ids_t = torch.tensor(ids, dtype=torch.long, device=device)
            nll_sum, tok_cnt = _score_token_ids(model, ids_t, context_size, stride)
            ppl = math.exp(nll_sum / tok_cnt)
            results.append({"num_tokens": tok_cnt, "nll_sum": nll_sum, "ppl": ppl})
    finally:
        # Restore the caller's mode even if scoring raised.
        if model_was_training:
            model.train()

    total_nll = sum(r["nll_sum"] for r in results)
    total_tok = sum(r["num_tokens"] for r in results)
    corpus_ppl = math.exp(total_nll / total_tok) if total_tok else float("nan")

    return results, corpus_ppl

In [None]:
# ----- Load up to n (default 1,000) text lines from the WikiText-2 validation split -----
def load_wikitext2_samples(n=1000, min_length=10):
    """Return the first `n` validation lines whose stripped length exceeds `min_length`."""
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")

    # Drop empty or too-short lines; the kept text itself is left unstripped.
    kept = []
    for row in dataset:
        text = row["text"]
        if len(text.strip()) > min_length:
            kept.append(text)

    return kept[:n]

# Number of text samples requested from the loader for the perplexity run.
num_samples = 10

# The loader may return fewer than num_samples if too few lines pass its
# length filter, so report the actual count.
samples = load_wikitext2_samples(num_samples)
print(f"Loaded {len(samples)} samples. Computing perplexity...")

In [None]:
# Score the loaded samples with the current model, using the model's full
# context length as the window size. `per_text` holds one result dict per
# sample; `corpus_ppl` is the perplexity over all scored tokens.
per_text, corpus_ppl = compute_ppl(
    model=model,
    tokenizer=tokenizer,           
    texts=samples,
    context_size=NEW_CONFIG["context_length"],
    device=DEVICE
)

# print(per_text)
print("Corpus PPL:", corpus_ppl)