In [1]:
import time
import tiktoken
import torch
import torch.nn as nn

### Multi-Head Attention with simple KV Cache
- Two approaches:
    1. Simple: grow the cache with `torch.cat` at every decoding step.
    2. Typical: preallocate a fixed `[batch, num_heads, max_len, head_dim]` buffer once, then write new keys/values into it in place.
- Optimization: PagedAttention, Quantization (KV: FP16, Q: FP32), FlashAttention

In [2]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional, concatenation-based KV cache.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size; split evenly across ``num_heads``.
        context_length: Side length of the precomputed causal mask, i.e. the
            maximum number of key positions (cached + new) that can be attended.
        drop_rate: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads; must divide ``d_out``.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, drop_rate, num_heads, qkv_bias = False):
        super().__init__()
        assert d_out%num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out//num_heads # e.g. 768//12 == 64

        self.Q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.K = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.V = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(drop_rate)
        self.proj_out = nn.Linear(d_out, d_out) # to combine head outputs
        # Upper-triangular ones (strictly above the diagonal) mark "future" positions.
        # Non-trainable state, excluded from the state_dict (persistent=False).
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False
        )

        ####################################################
        # KV Cache: held as non-persistent buffers, None until the first cached call.
        # Stored layout: [batch, cached_tokens, num_heads, head_dim]
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        ####################################################

    def forward(self, x, use_cache=False):
        """Compute causal self-attention over ``x``.

        Args:
            x: Tensor of shape ``[batch, num_tokens, d_in]``.
            use_cache: When True, the keys/values of ``x`` are appended to the
                cache and the queries attend to all cached positions plus the
                new ones. When False, the cache is neither read nor written.

        Returns:
            Tensor of shape ``[batch, num_tokens, d_out]``.

        Raises:
            ValueError: If cached + new tokens exceed the mask's context length.
        """
        batch, num_tokens, _ = x.shape

        # Number of positions already in the cache (0 when not caching / empty cache).
        past_len = 0
        if use_cache and self.cache_k is not None:
            past_len = self.cache_k.shape[1]
        total_len = past_len + num_tokens
        # Check before touching the cache so a failed call leaves it unchanged.
        if total_len > self.mask.shape[0]:
            raise ValueError(
                f"sequence length {total_len} exceeds context_length {self.mask.shape[0]}"
            )

        # 1) project to q, k, v: [batch, num_tokens, d_out], d_out = num_heads * head_dim
        # 2) split into heads: [batch, num_tokens, num_heads, head_dim]
        q = self.Q(x).view(batch, num_tokens, self.num_heads, self.head_dim)
        k_new = self.K(x).view(batch, num_tokens, self.num_heads, self.head_dim)
        v_new = self.V(x).view(batch, num_tokens, self.num_heads, self.head_dim)

        ####################################################
        # KV Cache: grows along the token axis (batch, num_heads, head_dim are fixed).
        # torch.cat reallocates on every step, so total copying is O(L^2) over L steps.
        if use_cache:
            # `is None` rather than truthiness: bool() of a multi-element tensor raises.
            if self.cache_k is None:
                self.cache_k, self.cache_v = k_new, v_new
            else:
                self.cache_k = torch.cat([self.cache_k, k_new], dim=1) # along the token axis
                self.cache_v = torch.cat([self.cache_v, v_new], dim=1)
            k, v = self.cache_k, self.cache_v
        else:
            k, v = k_new, v_new
        ####################################################

        # 3) [batch, num_heads, tokens, head_dim]
        q = q.transpose(1,2)
        k = k.transpose(1,2)
        v = v.transpose(1,2)

        # 4) attention scores: [batch, num_heads, num_tokens, total_len]
        attn_score = q @ k.transpose(2,3)

        # 5) causal masking. The new queries sit at absolute positions
        #    past_len .. total_len-1, so take those mask rows and the first
        #    total_len key columns.
        mask_bool = self.mask.bool()[past_len:total_len, :total_len]
        attn_score.masked_fill_(mask_bool, -torch.inf)

        # 6) scaled softmax over the key axis
        attn_weight = torch.softmax(attn_score/self.head_dim**0.5, dim=-1)

        # 7) dropout
        attn_weight = self.dropout(attn_weight)

        # 8) context vectors
        # ==> [batch, num_heads, num_tokens, head_dim] -> [batch, num_tokens, num_heads, head_dim]
        context_vec = (attn_weight @ v).transpose(1,2)

        # 9) merge heads back to [batch, num_tokens, d_out] and mix them
        context_vec = context_vec.reshape(batch, num_tokens, self.d_out)
        context_vec = self.proj_out(context_vec)
        return context_vec

    ####################################################
    # KV Cache: reset
    def reset_cache(self):
        """Drop all cached keys and values."""
        self.cache_k, self.cache_v = None, None
    ####################################################

### Transformer Architecture
1. Layer Normalization
2. GELU + Feed forward
3. Skip Connection in Transformer block

In [3]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift.

    Args:
        emb_dim: Size of the last dimension of the inputs to normalize.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5  # added to the variance to avoid division by zero
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then apply scale and shift.

        Returns:
            Tensor with the same shape as ``x``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, i.e. divide by N rather than N-1.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

In [5]:
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation."""

    def forward(self, x):
        """Return ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``."""
        root_two_over_pi = torch.sqrt(torch.tensor(2.0/torch.pi))
        cubic_term = x + 0.044715 * torch.pow(x, 3)
        gate = 1 + torch.tanh(root_two_over_pi * cubic_term)
        return 0.5 * x * gate

In [6]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The hidden layer is four times as wide as the input/output.

    Args:
        emd_dim: Size of the input and output features. (The parameter name is
            a historical misspelling of ``emb_dim``; it is kept so existing
            keyword callers keep working.)
    """

    def __init__(self, emd_dim):
        super().__init__()
        # Use the argument itself; the body previously read an undefined
        # (or accidentally global) name `emb_dim`.
        emb_dim = emd_dim
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, 4*emb_dim),
            GELU(),
            nn.Linear(4*emb_dim, emb_dim)
        )

    def forward(self, x):
        """Apply the feed-forward stack to ``x``; the last dimension is preserved."""
        return self.layers(x)

In [7]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each wrapped in a skip connection.

    Args:
        cfg: Mapping providing ``"emb_dim"``, ``"context_length"``,
            ``"drop_rate"``, ``"n_heads"`` and ``"qkv_bias"``.
    """

    def __init__(self, cfg):
        super().__init__()

        # Registered as `att`: forward() and GPT2Model.reset_kv_cache() both
        # access the attention module under this name.
        self.att = MultiHeadAttention(
            d_in = cfg["emb_dim"],
            d_out = cfg["emb_dim"],
            context_length = cfg["context_length"],
            drop_rate = cfg["drop_rate"],
            num_heads = cfg["n_heads"],
            qkv_bias = cfg["qkv_bias"])
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.ff = FeedForward(cfg["emb_dim"])
        self.dropAdd = nn.Dropout(cfg["drop_rate"])

    def forward(self, x, use_cache=False):
        """Run the block on ``x`` and return a tensor of the same shape.

        Args:
            x: Input activations; the last dimension is ``cfg["emb_dim"]``.
            use_cache: Forwarded to the attention layer to enable its KV cache.
        """
        # Attention sub-layer: norm -> attention -> dropout -> residual add
        skip_ = x
        x = self.norm1(x)
        x = self.att(x, use_cache=use_cache)
        x = self.dropAdd(x)
        # Out-of-place add: avoids mutating a tensor autograd may still need.
        x = x + skip_

        # Feed-forward sub-layer: norm -> feed-forward -> dropout -> residual add
        skip_ = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.dropAdd(x)
        x = x + skip_

        return x

### GPT
1. token embedding
2. positional embedding
3. dropout
4. transformer block x 12
5. final layernorm
6. linear output projection

In [11]:
class GPT2Model(nn.Module):
    """Token + positional embeddings, a stack of transformer blocks, a final norm and an output head.

    Args:
        cfg: Mapping providing ``"vocab_size"``, ``"emb_dim"``,
            ``"context_length"``, ``"drop_rate"`` and ``"n_layers"`` (plus
            whatever ``TransformerBlock`` reads from it).
    """

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        # KV cache: blocks are kept in a ModuleList so `use_cache` can be
        # passed to each one; `current_pos` is the position of the next token.
        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.ModuleList(blocks)
        self.current_pos = 0

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx, use_cache=False):
        """Map token ids to next-token logits.

        Args:
            in_idx: 2-D tensor of token ids, ``[batch, seq_len]``.
            use_cache: When True, positions continue from ``self.current_pos``
                (which is then advanced by ``seq_len``) and every block runs
                with its KV cache enabled. When False, positions start at 0.

        Returns:
            The output of ``out_head`` applied to the normalized hidden states.
        """
        _, seq_len = in_idx.shape

        # Token embedding
        tok_embeds = self.tok_emb(in_idx)

        # Positional embedding: offset by the number of tokens already cached
        first_pos = self.current_pos if use_cache else 0
        pos_ids = torch.arange(
            first_pos, first_pos + seq_len, device=in_idx.device, dtype=torch.long
        )
        if use_cache:
            self.current_pos = first_pos + seq_len
        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)

        # Sum embeddings, then dropout
        hidden = self.drop_emb(tok_embeds + pos_embeds)

        # Transformer blocks, each told whether to use its cache
        for block in self.trf_blocks:
            hidden = block(hidden, use_cache=use_cache)

        # Final layer norm followed by the linear output projection
        return self.out_head(self.final_norm(hidden))

    def reset_kv_cache(self):
        """Clear the attention cache of every block and rewind the position counter."""
        for block in self.trf_blocks:
            block.att.reset_cache()
        self.current_pos = 0

In [10]:
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
    """Greedily append ``max_new_tokens`` tokens to ``idx``.

    Args:
        model: Module exposing ``pos_emb.num_embeddings`` (the context length),
            ``reset_kv_cache()`` and ``__call__(idx, use_cache=...)`` returning
            logits whose second dimension indexes sequence positions.
        idx: 2-D tensor of token ids, ``[batch, seq_len]``.
        max_new_tokens: Number of tokens to generate.
        use_cache: When True, the prompt is fed once and afterwards only the
            newest token is fed per step. When False, the last ``ctx_len``
            tokens are re-fed on every step.

    Returns:
        ``idx`` extended along dim 1 by ``max_new_tokens`` token ids.

    Note:
        The model is switched to eval mode. In cached mode nothing truncates
        the cache, so prompt length + generated tokens must fit in the model's
        context length.
    """
    model.eval()

    ctx_len = model.pos_emb.num_embeddings
    with torch.no_grad():
        if use_cache:
            model.reset_kv_cache() # start from an empty cache
            # Prefill with the LAST ctx_len prompt tokens.
            logits = model(idx[:, -ctx_len:], use_cache=True)

            for step in range(max_new_tokens):
                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_idx], dim=1)
                # Feed only the new token; skip the pass after the final
                # token since its logits would never be used.
                if step + 1 < max_new_tokens:
                    logits = model(next_idx, use_cache=True)
        else:
            for _ in range(max_new_tokens):
                logits = model(idx[:, -ctx_len:], use_cache=False)
                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_idx], dim=1)
    return idx