- KV caching is a technique used to improve the computational efficiency of a GPT model.
- This caching is used at the inference stage.
- It saves the key and value vectors of the previous tokens, so that the query, key, and value vectors only have to be calculated for the newly generated token.

In [35]:
# Need a way to identify the first iteration.
# On the first iteration, q, k, v are computed for all input tokens, and k, v should be cached.

In [36]:
import torch
import torch.nn as nn
import tiktoken 

In [37]:
# Hyperparameters consumed by GPTModel / TransformerBlock / MultiheadAttention below
GPT_CONFIG_124M = dict(
    vocab_size=50257,      # Vocabulary size
    context_length=1024,   # Context length
    emb_dim=768,           # Embedding dimension
    n_heads=12,            # Number of attention heads
    n_layers=12,           # Number of layers
    drop_rate=0.1,         # Dropout rate
    qkv_bias=False,        # Query-Key-Value bias
    use_cache=True,        # Passed to every attention layer to enable its key/value cache
)
 

In [38]:
class MultiheadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size; split evenly across ``num_heads``.
        context_length: Side length of the registered ``mask`` buffer.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads; must divide ``d_out``.
        qkv_bias: Whether the query/key/value projections use a bias term.
        use_cache: If True, keys and values are cached between ``forward``
            calls while the module is in eval mode.

    Caching contract (``use_cache=True`` and the module is in eval mode):
        * The first call after construction or ``reset_cache()`` processes
          every token of ``x`` and stores their keys and values.
        * Every later call treats only the LAST token of ``x`` as new: it is
          projected, appended to the cache, and the returned tensor has a
          sequence length of 1.
        * The cache is never trimmed here; call ``reset_cache()`` before
          feeding an unrelated sequence.

    In training mode the cache is ignored, so every call attends over the
    whole input.
    """

    def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False,use_cache=False):
        super().__init__()
        assert (d_out % num_heads == 0),"d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.w_query = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.w_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.w_value = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.out_proj = nn.Linear(d_out,d_out)
        self.dropout = nn.Dropout(dropout)
        # forward() builds its mask on the fly; the buffer is kept so that
        # state_dicts saved from earlier versions of this class still load.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length,context_length),diagonal=1)
        )
        self.use_cache = use_cache
        self.k_cache = None
        self.v_cache = None

    def reset_cache(self):
        """Drop the cached keys/values so the next call starts a new sequence."""
        self.k_cache = None
        self.v_cache = None

    def _split_heads(self,t):
        """Reshape (b, n, d_out) into (b, num_heads, n, head_dim)."""
        b, n, _ = t.shape
        return t.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self,x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out), or (batch, 1, d_out)
            on a cached decode step (see the class docstring).
        """
        b, num_tokens, _ = x.shape

        # Caching only makes sense at inference; during training every batch
        # is an independent sequence.
        caching = self.use_cache and not self.training

        # A cache built for a different batch size cannot belong to this input.
        if caching and self.k_cache is not None and self.k_cache.size(0) != b:
            self.reset_cache()

        decode_step = caching and self.k_cache is not None
        if decode_step:
            # Only the newest token needs projecting; earlier ones are cached.
            x = x[:, -1:, :]

        queries = self._split_heads(self.w_query(x))
        keys = self._split_heads(self.w_key(x))
        values = self._split_heads(self.w_value(x))

        if caching:
            if decode_step:
                keys = torch.cat([self.k_cache, keys], dim=2)
                values = torch.cat([self.v_cache, values], dim=2)
            # Storing the real tensors keeps the cache on x's dtype and device.
            self.k_cache, self.v_cache = keys, values

        attn_scores = torch.matmul(queries, keys.transpose(2, 3))

        # Query i sits at absolute position (k_len - q_len + i), so the causal
        # diagonal has to be shifted by that offset. With q_len == k_len this
        # is the usual strictly-upper-triangular mask; on a decode step
        # (q_len == 1) nothing is masked and the new token sees every key.
        q_len, k_len = attn_scores.size(-2), attn_scores.size(-1)
        mask = torch.triu(
            torch.ones(q_len, k_len, device=x.device),
            diagonal=1 + (k_len - q_len),
        ).bool()
        attn_scores.masked_fill_(mask, float("-inf"))

        scaled_attn_scores = attn_scores / keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(scaled_attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, heads, q_len, head_dim) -> (b, q_len, d_out)
        context_vec = torch.matmul(attn_weights, values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, q_len, self.d_out)
        return self.out_proj(context_vec)

In [39]:
#implement layer normalization class
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift.

    Args:
        emb_dim: Size of the last dimension of the inputs; also the length of
            the ``scale`` (initialised to ones) and ``shift`` (zeros) parameters.
    """

    def __init__(self,emb_dim):
        super().__init__()
        self.eps = 1e-5  # added to the variance to avoid division by zero
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self,x):
        """Normalize ``x`` along its last dimension, then apply scale and shift."""
        mean = x.mean(dim=-1,keepdim=True)
        # Layer norm uses the biased (population) variance. torch's default is
        # the Bessel-corrected estimate, which also yields NaN when the last
        # dimension has a single element.
        var = x.var(dim=-1,keepdim=True,unbiased=False)
        norm_x = (x-mean)/torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
    
    
#Gelu activation function
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation."""

    def forward(self,x):
        """Return 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
        coeff = torch.sqrt(torch.tensor(2.0/torch.pi))
        cubic = x+0.044715 * torch.pow(x, 3)
        gate = 1+torch.tanh(coeff * cubic)
        return 0.5 * x * gate
    
#Feed forward network
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4 * emb_dim, apply GELU, project back to emb_dim.

    Args:
        cfg: Mapping that provides the ``'emb_dim'`` entry.
    """

    def __init__(self,cfg):
        super().__init__()
        emb_dim = cfg['emb_dim']
        hidden_dim = 4*emb_dim
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),   # expansion
            GELU(),                           # non-linearity
            nn.Linear(hidden_dim, emb_dim),   # projection back to emb_dim
        )

    def forward(self,x):
        """Run ``x`` through the expansion, activation and projection layers."""
        out = self.layers(x)
        return out

In [40]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    Args:
        cfg: Mapping with ``'emb_dim'``, ``'context_length'``, ``'n_heads'``,
            ``'drop_rate'``, ``'qkv_bias'`` and ``'use_cache'`` entries.
    """

    def __init__(self,cfg):
        super().__init__()
        self.att = MultiheadAttention(  #initializing attention layer
            d_in = cfg['emb_dim'],
            d_out = cfg['emb_dim'],
            context_length = cfg['context_length'],
            num_heads = cfg['n_heads'],
            dropout = cfg['drop_rate'],
            qkv_bias = cfg['qkv_bias'],
            use_cache = cfg['use_cache']
        )
        self.ff = FeedForward(cfg) #initializing feedforward layer
        self.norm1 = LayerNorm(cfg['emb_dim']) #initializing layer normalization
        self.norm2 = LayerNorm(cfg['emb_dim'])
        self.drop_shortcut = nn.Dropout(cfg['drop_rate']) #add dropout

    def forward(self,x):
        """Apply the attention and feed-forward sub-blocks to ``x``.

        Returns a tensor with as many positions as the attention layer
        produced: all of them normally, only the newest one on a cached
        decode step.
        """
        shortcut = x              #shortcut connection for attention block
        x = self.norm1(x)         #normalize input
        x = self.att(x)           #forward through attention layer
        # On a cached decode step the attention layer returns only the newest
        # position. Keep the matching tail of the shortcut so the residual is
        # added position-for-position instead of being broadcast over the
        # whole sequence (which would also run the feed-forward on every
        # position again).
        shortcut = shortcut[:, -x.size(1):, :]
        x = self.drop_shortcut(x) #dropout on the attention output
        x = x + shortcut          #add original input to attention block back

        shortcut = x              #shortcut connection for feed forward block
        x = self.norm2(x)         #normalize input for feed forward block
        x = self.ff(x)            #forward through feed forward block
        x = self.drop_shortcut(x) #dropout
        x = x + shortcut          #add original input for ff block back

        return x

In [41]:
class GPTModel(nn.Module):
    """GPT-style decoder: embeddings, a stack of transformer blocks, and a vocabulary head.

    Args:
        cfg: Mapping with ``'vocab_size'``, ``'context_length'``, ``'emb_dim'``,
            ``'drop_rate'`` and ``'n_layers'`` entries, plus whatever
            ``TransformerBlock`` reads from it.
    """

    def __init__(self,cfg):
        super().__init__()
        emb_dim = cfg['emb_dim']

        # Token and learned positional embeddings, followed by embedding dropout
        self.tok_emb = nn.Embedding(cfg['vocab_size'], emb_dim)
        self.pos_emb = nn.Embedding(cfg['context_length'], emb_dim)
        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        # n_layers transformer blocks applied one after another
        self.trf_blocks = nn.Sequential(
            *(TransformerBlock(cfg) for _ in range(cfg['n_layers']))
        )

        # Final normalization and bias-free projection to vocabulary logits
        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, cfg['vocab_size'], bias=False)

    def forward(self,in_idx):
        """Map token ids of shape (batch, seq_len) to vocabulary logits."""
        _, seq_len = in_idx.shape
        positions = torch.arange(seq_len, device=in_idx.device)

        # Positional embeddings broadcast over the batch dimension
        hidden = self.tok_emb(in_idx) + self.pos_emb(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)
        return self.out_head(hidden)

In [42]:
# Build the model from the config above. The weights are randomly initialised
# and untrained, and no RNG seed is set in the cells shown here, so the text
# generated further down is gibberish and differs from one kernel run to the next.
model = GPTModel(GPT_CONFIG_124M)

In [43]:
#A function for GPT model to generate text
def generate_text_simple(model,idx,max_new_tokens,context_size): #idx is a (batch,n_tokens) array of indices in the current context
    """Greedily extend ``idx`` by ``max_new_tokens`` tokens.

    Args:
        model: Callable mapping (batch, n_tokens) token ids to logits whose
            last position is used to pick the next token.
        idx: (batch, n_tokens) tensor of token ids forming the prompt.
        max_new_tokens: Number of tokens to append.
        context_size: Maximum number of trailing tokens fed to the model.

    Returns:
        (batch, n_tokens + max_new_tokens) tensor: the prompt followed by the
        generated token ids.

    Any submodule exposing ``k_cache`` / ``v_cache`` attributes has them reset
    to None before generation starts, so keys/values left over from a previous
    call are never mixed into this prompt.
    """
    def _clear_kv_cache():
        # Plain callables without .modules() have no caches to clear.
        iter_modules = getattr(model, "modules", None)
        if not callable(iter_modules):
            return
        for module in iter_modules():
            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                module.k_cache = None
                module.v_cache = None

    _clear_kv_cache()

    for _ in range(max_new_tokens):
        # Once the window starts sliding, every token's position inside the
        # window changes, so previously cached keys/values no longer apply:
        # drop them and let the model recompute the whole window.
        if idx.size(1) > context_size:
            _clear_kv_cache()

        idx_cond = idx[:,-context_size:] #selecting last idxs of context size from each batch
        with torch.no_grad():
            logits = model(idx_cond) #generate logits to predict the next token

        logits = logits[:,-1,:] #get the last row of logits for each batch
        # softmax is monotonic, so argmax over the raw logits picks the same token
        idx_next = torch.argmax(logits,dim=-1,keepdim=True)
        idx = torch.cat((idx,idx_next),dim=1)

    return idx

In [44]:
# Encode the prompt with tiktoken's 'gpt2' encoding
tokenizer = tiktoken.get_encoding('gpt2')
start_context = "Hello, I am"
encoded = tokenizer.encode(start_context)

# Wrapping the id list in an outer list gives the tensor its batch dimension
encoded_tensor = torch.tensor([encoded])

print("encoded:",encoded)
print("encoded_tensor.shape:", encoded_tensor.shape)

encoded: [15496, 11, 314, 716]
encoded_tensor.shape: torch.Size([1, 4])


In [45]:
# Put the model in eval mode, then extend the prompt by 20 greedily chosen tokens
model.eval()
generation_args = dict(
    model=model,
    idx=encoded_tensor,
    max_new_tokens=20,
    context_size=GPT_CONFIG_124M['context_length'],
)
out = generate_text_simple(**generation_args)
print("Output:", out)
print("Output length:", out.shape[1])

Output: tensor([[15496,    11,   314,   716, 18159,  1194,  3576, 10972,  3193, 24905,
         29833,  3999, 16450, 10909, 22031, 42839, 40805,  8086, 43110, 46423,
         45111,  3876, 40331, 46761]])
Output length: 24


In [46]:
# Drop the batch dimension and turn the token ids back into text
token_ids = out.squeeze(0).tolist()
decoded_text = tokenizer.decode(token_ids)
print(decoded_text)

Hello, I am masters another LondonptoniblyiftyAIN Chinese 1969 acceptableexpensiveKER crappyAtt sparing Meter flungmarSUP771
