In [1]:
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer,GPT2Config
import torch.nn.functional as F

In [2]:
config = GPT2Config()  # library defaults: GPT-2 "small" (n_layer=12, n_head=12, n_embd=768)
config

GPT2Config {
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.32.1",
  "use_cache": true,
  "vocab_size": 50257
}

In [3]:
torch.ones(10, 10).tril()  # causal mask: row i keeps columns 0..i (a position sees itself and earlier ones)

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [4]:
class gpt2attn(nn.Module):
    """Causal multi-head self-attention in the style of GPT-2.

    Args:
        config: object exposing ``n_positions``, ``n_embd`` and ``n_head``
            (e.g. ``transformers.GPT2Config``). ``attn_pdrop`` and
            ``resid_pdrop`` are used when present and default to 0.1.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super(gpt2attn, self).__init__()

        max_position = config.n_positions
        # Lower-triangular boolean causal mask of shape (1, 1, max_pos, max_pos),
        # broadcast over (batch, head). Registered as a buffer so that
        # `.to(device)` / `.cuda()` move it together with the parameters;
        # non-persistent because it is fully determined by n_positions.
        causal = torch.tril(torch.ones(max_position, max_position, dtype=torch.bool))
        self.register_buffer(
            "mask", causal.view(1, 1, max_position, max_position), persistent=False
        )

        self.embed_dim = config.n_embd  # GPT2Config calls this `n_embd`
        self.head = config.n_head  # number of attention heads
        if self.embed_dim % self.head != 0:
            raise ValueError(
                f"n_embd ({self.embed_dim}) must be divisible by n_head ({self.head})"
            )
        self.head_no = self.embed_dim // self.head  # per-head dimension (d_head)
        self.split_size = self.embed_dim
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)  # fused q, k, v projection
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)  # output projection after merging heads
        # dropout on attention probabilities
        self.dropout = nn.Dropout(getattr(config, "attn_pdrop", 0.1))
        # dropout on the projected output (residual branch)
        self.resid_dropout = nn.Dropout(getattr(config, "resid_pdrop", 0.1))

    def _attn(self, q, k, v):
        """Scaled dot-product attention with a causal mask.

        Args:
            q, k, v: tensors of shape (batch, n_head, seq_len, d_head).

        Returns:
            Tensor of shape (batch, n_head, seq_len, d_head).
        """
        attn_weight = torch.matmul(q, k.transpose(-1, -2))
        attn_weight = attn_weight / float(q.size(-1)) ** 0.5
        seq_len = q.size(-2)
        causal_mask = self.mask[:, :, :seq_len, :seq_len]
        # Future positions get the most negative representable value so softmax
        # gives them ~0 probability; masked_fill keeps the scores' dtype/device.
        attn_weight = attn_weight.masked_fill(
            ~causal_mask, torch.finfo(attn_weight.dtype).min
        )
        attn_weight = F.softmax(attn_weight, dim=-1)
        attn_weight = self.dropout(attn_weight)
        # (b, n_head, seq_len, seq_len) @ (b, n_head, seq_len, d_head)
        #   -> (b, n_head, seq_len, d_head); no transpose of v is needed.
        return torch.matmul(attn_weight, v)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (batch, seq_len, n_embd).

        Returns:
            Tensor of shape (batch, seq_len, n_embd).
        """
        b, seq_len, dim_model = x.size()
        # (b, seq_len, 3 * n_embd) -> three (b, seq_len, n_embd) chunks
        q, k, v = self.c_attn(x).split(self.split_size, dim=-1)
        # split heads: (b, seq_len, n_embd) -> (b, n_head, seq_len, d_head)
        q = q.view(b, seq_len, self.head, self.head_no).transpose(1, 2)
        k = k.view(b, seq_len, self.head, self.head_no).transpose(1, 2)
        v = v.view(b, seq_len, self.head, self.head_no).transpose(1, 2)

        attn_out = self._attn(q, k, v)
        # merge heads; the transpose makes the tensor non-contiguous, so use
        # reshape (which copies when needed) instead of view.
        attn_out = attn_out.transpose(1, 2).reshape(b, seq_len, dim_model)
        attn_out = self.c_proj(attn_out)
        return self.resid_dropout(attn_out)
        

In [5]:
class gpt2mlp(nn.Module):
    """Position-wise feed-forward network of a GPT-2 block.

    Linear(n_embd -> inner) -> GELU (tanh approximation, i.e. GPT-2's
    "gelu_new") -> Linear(inner -> n_embd) -> Dropout, where
    inner = config.n_inner if set, else 4 * n_embd.

    Args:
        config: object exposing ``n_embd``; ``n_inner`` and ``resid_pdrop``
            are used when present (defaults: 4 * n_embd and 0.1).
    """

    def __init__(self, config) -> None:
        super(gpt2mlp, self).__init__()
        embed_dim = config.n_embd  # GPT2Config calls this `n_embd`
        # GPT2Config.n_inner is None by default, meaning 4 * n_embd.
        inner_dim = getattr(config, "n_inner", None) or 4 * embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, inner_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(getattr(config, "resid_pdrop", 0.1)),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x`` (size n_embd)."""
        return self.mlp(x)
    

In [6]:
class gpt2block(nn.Module):
    """One pre-LayerNorm GPT-2 transformer block.

    Computes::

        h = x + attn(layer_norm_1(x))
        out = h + mlp(layer_norm_2(h))

    Args:
        config: object exposing ``n_embd`` (plus whatever ``gpt2attn`` and
            ``gpt2mlp`` read); ``layer_norm_epsilon`` is used when present.
    """

    def __init__(self, config) -> None:
        super(gpt2block, self).__init__()
        embed_dim = config.n_embd  # GPT2Config calls this `n_embd`
        eps = getattr(config, "layer_norm_epsilon", 1e-5)
        self.layer_norm_1 = nn.LayerNorm(embed_dim, eps=eps)
        self.layer_norm_2 = nn.LayerNorm(embed_dim, eps=eps)
        self.attn = gpt2attn(config)
        self.mlp = gpt2mlp(config)

    def forward(self, hidden_states):
        """Map (batch, seq_len, n_embd) -> (batch, seq_len, n_embd)."""
        # attention sub-layer: normalize, attend, add residual
        residual = hidden_states
        hidden_states = residual + self.attn(self.layer_norm_1(hidden_states))
        # feed-forward sub-layer: normalize with layer_norm_2 before the MLP
        residual = hidden_states
        hidden_states = residual + self.mlp(self.layer_norm_2(hidden_states))
        return hidden_states

In [7]:
class gpt2model(nn.Module):
    """GPT-2 decoder stack without a head.

    Token embeddings + learned absolute position embeddings, dropout,
    ``n_layer`` ``gpt2block``s, then a final LayerNorm.

    Args:
        config: object exposing ``n_embd``, ``vocab_size``, ``n_positions``
            and ``n_layer``; ``embd_pdrop`` and ``layer_norm_epsilon`` are
            used when present (defaults 0.1 and 1e-5).
    """

    def __init__(self, config) -> None:
        super(gpt2model, self).__init__()
        self.embed_dim = config.n_embd  # GPT2Config calls this `n_embd`
        self.vocab_size = config.vocab_size
        self.token_embed = nn.Embedding(self.vocab_size, self.embed_dim)
        # (sic) attribute name kept so existing state-dict keys still match
        self.postional_embed = nn.Embedding(config.n_positions, self.embed_dim)

        self.drop = nn.Dropout(getattr(config, "embd_pdrop", 0.1))
        self.blocks = nn.ModuleList([gpt2block(config) for _ in range(config.n_layer)])

        self.layer_norm_final = nn.LayerNorm(
            self.embed_dim, eps=getattr(config, "layer_norm_epsilon", 1e-5)
        )

    def forward(self, input_ids=None, position_ids=None,):
        """Run the decoder stack.

        Args:
            input_ids: LongTensor of token ids, shape (batch, seq_len).
            position_ids: optional LongTensor broadcastable to
                (batch, seq_len); defaults to 0 .. seq_len - 1.

        Returns:
            Hidden states of shape (batch, seq_len, n_embd).

        Raises:
            ValueError: if ``input_ids`` is None or seq_len exceeds n_positions.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        seq_len = input_ids.size(-1)
        max_positions = self.postional_embed.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds n_positions ({max_positions})"
            )

        if position_ids is None:
            # (1, seq_len), broadcast across the batch
            position_ids = torch.arange(
                seq_len, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)

        hidden_states = self.token_embed(input_ids) + self.postional_embed(position_ids)
        hidden_states = self.drop(hidden_states)

        for block in self.blocks:
            hidden_states = block(hidden_states)

        return self.layer_norm_final(hidden_states)

In [8]:
torch.arange(10)  # start defaults to 0 -> tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [None]:
class gpt2languagehead(nn.Module):
    """GPT-2 with a language-modelling head and optional next-token loss.

    Args:
        config: object exposing ``n_embd`` and ``vocab_size`` (plus whatever
            ``gpt2model`` reads). If ``config.tie_word_embeddings`` is true
            (the default when absent), ``lm_head`` shares its weight with the
            token embedding, as in GPT-2.
        ignore_index: label value excluded from the loss (e.g. padding).
            Defaults to -100, the ``nn.CrossEntropyLoss`` default.
    """

    def __init__(self, config, ignore_index: int = -100) -> None:
        super(gpt2languagehead, self).__init__()
        self.embed_dim = config.n_embd  # GPT2Config calls this `n_embd`
        # (sic) attribute name kept so existing state-dict keys still match
        self.transfomer = gpt2model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if getattr(config, "tie_word_embeddings", True):
            # share the output projection with the input token embedding matrix
            self.lm_head.weight = self.transfomer.token_embed.weight
        # ignore_index must be an int; label positions equal to it add no loss.
        self.xe = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, input_ids=None, position_ids=None, labels=None):
        """Compute vocabulary logits and, if ``labels`` is given, the LM loss.

        Args:
            input_ids: LongTensor of token ids, shape (batch, seq_len).
            position_ids: optional position ids forwarded to the decoder.
            labels: optional LongTensor with the same shape as ``input_ids``,
                NOT pre-shifted (typically ``labels = input_ids``). The logits
                at position t are scored against ``labels[:, t + 1]``; e.g. for
                input "<bos> hey dude !" the targets are "hey dude !".

        Returns:
            (logits, loss): logits of shape (batch, seq_len, vocab_size);
            loss is a scalar tensor, or None when ``labels`` is None.
        """
        hidden_states = self.transfomer(input_ids, position_ids=position_ids)
        lm_logits = self.lm_head(hidden_states)  # (batch, seq_len, vocab_size)

        loss = None
        if labels is not None:
            # drop the last prediction (no next token) and the first label (no predictor)
            shift_logits = lm_logits[:, :-1, :]
            shift_labels = labels[:, 1:]
            # the shifted slices are non-contiguous, so flatten with reshape, not view
            loss = self.xe(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )

        return lm_logits, loss
            