In [10]:
# LLM inference and optimisations
# KV cache , optimises the token generation pipeline ofro N tokens from O(N3) summation O(n2) for each token = n3 to 
# O(n) for each n tokens = O(n2) complexity 
import torch 
import math 
import torch.nn as nn 
def scaled_dot_product_attention(Q, K, V, mask=  None):
    # Q,K,V: (B, H, T, D)
    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))  # (B,H,T,Tk)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, V)  # (B,H,T,D)
    return out, probs  # return both; caller can ignore probs if not needed
'''
The Complexity DifferenceWithout Cache:
 To generate token $N$, we process $N$ tokens. Total for a sequence is $\sum_{i=1}^{N} i \approx O(N^2)$.
 With Cache: To generate token $N$, we process $1$ token (attending to $N-1$ cached keys).
   Total for a sequence is $\sum_{i=1}^{N} 1 = O(N)$.
'''

'\nThe Complexity DifferenceWithout Cache:\n To generate token $N$, we process $N$ tokens. Total for a sequence is $\\sum_{i=1}^{N} i \x07pprox O(N^2)$.\n With Cache: To generate token $N$, we process $1$ token (attending to $N-1$ cached keys).\n   Total for a sequence is $\\sum_{i=1}^{N} 1 = O(N)$.\n'

In [None]:
class MultiheadAttention(nn.Module):
    def __init__(self, d_model, num_heads, use_kv_cache= False):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dk = d_model // num_heads
        self.use_kv_cache = use_kv_cache
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        # q groupd 

    def forward(self, x, kv_cache = None):
        # B, 10 , d 
        # B 11, d 
        # B ,1 , d 
        q = self.w_q(x) # (B, T, d_model)
        k = self.w_k(x) # (B, T, d_model)
        v = self.w_v(x) # (B, T, d_model)
        # now split into heads 
        B, T, _ = q.shape

        q = q.view(B, T, self.num_heads, self.dk).transpose(1, 2) # (B, H, T, dk)
        k_new = k.view(B, T, self.num_heads, self.dk).transpose(1, 2) # (B, H, T, dk)
        v_new  = v.view(B, T, self.num_heads, self.dk).transpose(1, 2) # (B, H, T, dk)
        # B, 1, d  , t+1,d 
        # B, 1, d , t+1, d
        # now we can do scaled dot product attention
        # lets say we are not using cache here 
        # if you are using the kv cache q, kv, becomes B, h, 1, dk because x is a single token
        # so instaed ofmultiplying b, n, d with dx d we now multiply b, 1, d 
        if self.use_kv_cache:
            if kv_cache is not None:
                k, v = kv_cache                              # (B, H, T_prev, Dh)
                k = torch.cat([k, k_new], dim=2)              # (B, H, T_prev+1, Dh)
                v = torch.cat([v, v_new], dim=2)              # (B, H, T_prev+1, Dh)
            else:
                k, v = k_new, v_new
            new_cache = (k, v)
            mask = None
        else:
            k, v = k_new, v_new
            new_cache = None
            # causal mask (lower-triangular)
            mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)  # (1,1,T,T)

        out, probs  = scaled_dot_product_attention(q, k, v) # (B, H, T, dk)
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)  # (B, T, d_model)

        out = self.w_o(out) 
        return out , kv_cache
    
    

In [42]:
torch.manual_seed(0)
X = torch.rand(1, 1, 64)  # (B=1, seq=1, d_model=64)
num_tokens = 5
print("\n--- WITHOUT KV CACHE ---")
attn_no = MultiheadAttention(64, 8, use_kv_cache=False)
for t in range(1, num_tokens + 1):
    x_in = X.repeat(1, t, 1)  # simulate full sequence each step
    print(f"\nStep {t}: input {x_in.shape}")
    y, _ = attn_no(x_in)
    


--- WITHOUT KV CACHE ---

Step 1: input torch.Size([1, 1, 64])

Step 2: input torch.Size([1, 2, 64])

Step 3: input torch.Size([1, 3, 64])

Step 4: input torch.Size([1, 4, 64])

Step 5: input torch.Size([1, 5, 64])


In [43]:
print("\n--- WITH KV CACHE ---")
attn_cache = MultiheadAttention(64, 8, use_kv_cache=True)
kv_cache = None
for t in range(num_tokens):
    x_in = X  # one token at a time
    print(f"\nStep {t+1}: input {x_in.shape}")
    y, kv_cache = attn_cache(x_in, kv_cache)



--- WITH KV CACHE ---

Step 1: input torch.Size([1, 1, 64])

Step 2: input torch.Size([1, 1, 64])

Step 3: input torch.Size([1, 1, 64])

Step 4: input torch.Size([1, 1, 64])

Step 5: input torch.Size([1, 1, 64])


In [34]:
import torch

# random weights
W = torch.randn(2, 4)
x = torch.randn(1, 4)

# --- quantization ---
W_min, W_max = W.min(), W.max()
qmin, qmax = 0, 255
scale = (W_max - W_min) / (qmax - qmin)
zero_point = int(qmin - W_min / scale)

# convert to uint8
W_q = torch.round(W / scale + zero_point).clamp(qmin, qmax).to(torch.uint8)

# --- dequantization ---
W_deq = scale * (W_q.float() - zero_point)

# --- run inference ---
y_fp32 = x @ W.T
y_int8 = x @ W_deq.T

print("Original weights:\n", W)
print("\nQuantized (uint8):\n", W_q)
print("\nDequantized weights:\n", W_deq)
print("\nFP32 output:", y_fp32)
print("INT8(dequantized) output:", y_int8)
print("Max diff:", (y_fp32 - y_int8).abs().max().item())


Original weights:
 tensor([[-0.5614,  0.7887,  0.4191,  1.0952],
        [-0.6537,  1.5501, -1.4322,  0.3474]])

Quantized (uint8):
 tensor([[ 74, 189, 158, 216],
        [ 66, 255,   0, 152]], dtype=torch.uint8)

Dequantized weights:
 tensor([[-0.5614,  0.7836,  0.4210,  1.0993],
        [-0.6549,  1.5554, -1.4268,  0.3508]])

FP32 output: tensor([[-0.5213,  0.6229]])
INT8(dequantized) output: tensor([[-0.5218,  0.6176]])
Max diff: 0.005379140377044678
