In [1]:
import math
import torch.nn as nn

In [2]:
GPT_2_CONFIGURATION_124M={
    "vocab_size":50257,
    "context_length":1024,
    "emb_dim":768,
    "n_heads":12,
    "n_layers":12,
    "dropout":0.1,
    "qkv_bias":False
}

In [3]:
class LayerNorm(nn.Module):
    def __init__(self,emb_dim):
        super(LayerNorm, self).__init__()
        self.eps=1e-5
        self.scale=nn.Parameter(torch.ones(emb_dim))
        self.shift=nn.Parameter(torch.zeros(emb_dim))
        
    def forward(self,x):
        mean=x.mean(dim=-1,keepdim=True)
        var=x.var(dim=-1,keepdim=True)
        norm_x=(x-mean)/torch.sqrt(var+self.eps)
        return self.scale*norm_x+self.shift
    
class GELU(nn.Module):
    def __init__(self):
        super().__init__()
        # Precompute constant √(2/π) as a buffer for efficiency
        self.sqrt_2_pi = math.sqrt(2.0 / math.pi)

    def forward(self, x):
        # Ensure dtype and device compatibility
        sqrt_2_pi = torch.tensor(self.sqrt_2_pi, dtype=x.dtype, device=x.device)
        return 0.5 * x * (1 + torch.tanh(sqrt_2_pi * (x + 0.044715 * torch.pow(x, 3))))

class FeedForward(nn.Module):
    def __init__(self,cfg):
        super().__init__()
        self.layers=nn.Sequential(nn.Linear(cfg["emb_dim"],4*cfg["emb_dim"]),
                                 GELU(),
                                 nn.Linear(4*cfg["emb_dim"],cfg["emb_dim"])
                                 )
        
    def forward(self,x):
        return self.layers(x)
    
    

In [4]:
class Multihead_Attention_V2(nn.Module):
    def __init__(self,d_in,d_out,context_length,num_heads,dropout,qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"
        self.d_out=d_out
        self.num_heads=num_heads
        self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_keys=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_values=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.head_dim=d_out//num_heads
        self.out_proj=nn.Linear(d_out,d_out)
        self.dropout=nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
        
    
    def forward(self,x):
        b,num_tokens,d_in=x.shape
        keys=self.W_keys(x)
        queries=self.W_query(x)
        values=self.W_values(x)
        #now change dimensions
        
        keys=keys.view(b,num_tokens,self.num_heads,self.head_dim)
        queries=queries.view(b,num_tokens,self.num_heads,self.head_dim)
        values=values.view(b,num_tokens,self.num_heads,self.head_dim)
        #group as num_heads
        keys=keys.transpose(1,2)
        queries=queries.transpose(1,2)
        values=values.transpose(1,2)
        
        attn_scores=queries @ keys.transpose(2,3)
        
        mask_bool=self.mask.bool()[:num_tokens, :num_tokens]  #if num_tokens is less than specified context length
        
        attn_scores.masked_fill_(mask_bool,-torch.inf)
        
        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
        attn_weights=self.dropout(attn_weights)
        
        context_vectors=(attn_weights @ values).transpose(1,2)
        context_vectors=context_vectors.contiguous().view(b,num_tokens,self.d_out)
        context_vectors=self.out_proj(context_vectors)
        
        return context_vectors
        
        
        

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self,cfg):
        super().__init__()
        self.att=Multihead_Attention_V2(d_in=cfg["emb_dim"],
                                       d_out=cfg["emb_dim"],
                                       context_length=cfg["context_length"],
                                       num_heads=cfg["n_heads"],
                                       dropout=cfg["dropout"],
                                       qkv_bias=cfg["qkv_bias"])
        self.ff=FeedForward(cfg)
        self.norm1=LayerNorm(cfg["emb_dim"])
        self.norm2=LayerNorm(cfg["emb_dim"])
        self.drop_shortcut=nn.Dropout(cfg["dropout"])
    
    def forward(self,x):
        shortcut=x;
        x=self.norm1(x)
        x=self.att(x)
        x=self.drop_shortcut(x)
        x=x+shortcut
        
        shortcut=x
        x=self.norm2(x)
        x=self.ff(x)
        x=self.drop_shortcut(x)
        x=x+shortcut
        return x

In [6]:
torch.manual_seed(123)
x=torch.rand(2,4,768)
block=TransformerBlock(GPT_2_CONFIGURATION_124M)
output=block(x)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
output

tensor([[[ 0.1649,  0.4003, -0.0746,  ...,  1.2644,  0.3327,  0.7242],
         [ 0.0295,  0.0499,  0.2529,  ...,  0.4699,  0.1284,  0.9746],
         [ 0.5534,  0.5785, -0.0309,  ...,  1.1541,  0.3949,  0.7598],
         [ 0.1631,  0.7129,  0.7272,  ...,  0.3312,  0.5731,  0.9255]],

        [[ 0.1788,  1.1680,  0.5809,  ...,  0.1828,  0.0076, -0.5598],
         [-0.2919,  0.6317,  0.2002,  ...,  0.3218,  0.4671, -0.0381],
         [ 0.9273,  0.4202,  0.3183,  ...,  0.3771,  0.7189, -0.1203],
         [ 0.6033,  0.5767,  0.3411,  ...,  1.3796,  1.2681,  0.3915]]],
       grad_fn=<AddBackward0>)

In [8]:
class GPT_MODEL(nn.Module):
    def __init__(self,cfg):
        super().__init__()
        self.tok_emb=nn.Embedding(cfg["vocab_size"],cfg["emb_dim"])
        self.pos_emb=nn.Embedding(cfg["context_length"],cfg["emb_dim"])
        self.drop_emb=nn.Dropout(cfg["dropout"])
        self.trf_block=nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
        self.final_norm=LayerNorm(cfg["emb_dim"])
        self.out_head=nn.Linear(cfg["emb_dim"],cfg["vocab_size"],bias=False)
    
    
    def forward(self,x):
        batch_size,seq_len=x.shape
        tok_embs=self.tok_emb(x)
        pos_embs=self.pos_emb(torch.arange(seq_len))
        x=tok_embs+pos_embs
        x=self.drop_emb(x)
        x=self.trf_block(x)
        x=self.final_norm(x)
        logits=self.out_head(x)
        return logits
    
        

In [9]:
torch.manual_seed(123)
batch=torch.tensor([[6109,3626,6100,345],
                   [6109,1100,6622,257]])
model=GPT_MODEL(GPT_2_CONFIGURATION_124M)
output=model(batch)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [10]:
output

tensor([[[ 0.1381,  0.0079, -0.1957,  ..., -0.0222, -0.1062,  0.1717],
         [ 0.3867, -0.8400, -0.6558,  ..., -0.5162,  0.2362, -0.3349],
         [ 0.6985, -0.1826, -0.1634,  ...,  0.1472, -0.6503, -0.0054],
         [-0.4288,  0.1670, -0.1262,  ...,  1.1571,  0.5297, -0.5542]],

        [[ 0.1095, -0.2890, -0.1463,  ..., -0.0557,  0.2907, -0.2818],
         [-0.0709, -0.2588, -1.4234,  ...,  0.9827,  0.8914,  0.1527],
         [ 0.6661,  0.5157, -0.3347,  ...,  0.6909,  0.4849, -0.3056],
         [-0.1123,  0.0137,  0.4846,  ...,  1.1734, -0.4077, -0.0847]]],
       grad_fn=<UnsafeViewBackward0>)

In [53]:
def generate_text(model,max_num_tokens,context_size,idx):
    for _ in range(max_num_tokens):
        idx_cond=idx[:,-context_size:]
        with torch.no_grad():
            logits=model(idx_cond)
            logits=logits[:,-1,:]
            probabs=torch.softmax(logits,dim=-1)
            idx_next=torch.argmax(probabs,dim=-1,keepdim=True)
            idx=torch.cat((idx,idx_next),dim=1)
    return idx

In [54]:
import tiktoken

In [55]:
tokenizer=tiktoken.get_encoding("gpt2")

In [56]:
context_text="Hello,I am"
encoded=tokenizer.encode(context_text)


as this encoded is a list convert to tensor

In [57]:
encoded_tensor=torch.tensor(encoded).unsqueeze(0)

In [58]:
encoded_tensor.shape

torch.Size([1, 4])

In [59]:
model.eval()
out=generate_text(model=model,max_num_tokens=6,context_size=GPT_2_CONFIGURATION_124M["context_length"],idx=encoded_tensor)

In [60]:
out

tensor([[15496,    11,    40,   716, 27018,  7283, 46275, 11472, 21692, 43530]])

In [61]:
decoded_text=tokenizer.decode(out.squeeze(0).tolist())

In [62]:
decoded_text

'Hello,I am Feature IT snowball shocked merits neocons'