In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Tokenizer

In [2]:
# Show the library's default GPT-2 hyperparameters (rendered as the cell output below).
GPT2Config()

GPT2Config {
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.30.2",
  "use_cache": true,
  "vocab_size": 50257
}

In [None]:
# Default GPT-2 hyperparameters, meant to be passed to the model classes defined below.
# NOTE(review): this cell has no execution count in the saved notebook; run it before
# constructing any of the models.
config = GPT2Config()

In [4]:
class GPT2Attention(nn.Module):
    """Multi-head causal self-attention as used in GPT-2.

    A single linear layer produces query, key and value together. Each head
    attends under a lower-triangular (causal) mask. The heads are then
    concatenated and projected back to the embedding size.

    Args:
        config: object exposing ``n_positions``, ``n_embd`` and ``n_head``.
            ``attn_pdrop`` and ``resid_pdrop`` set the dropout rates when
            present; each defaults to 0.1.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super(GPT2Attention, self).__init__()

        max_positions = config.n_positions
        # Causal mask of shape (1, 1, max_positions, max_positions), so it
        # broadcasts over batch and heads. It is registered as a non-persistent
        # buffer: it follows the module through .to(device) but stays out of
        # state_dict.
        causal = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool))
        self.register_buffer(
            "mask", causal.view(1, 1, max_positions, max_positions), persistent=False
        )

        self.emb_dim = config.n_embd
        self.num_heads = config.n_head
        if self.emb_dim % self.num_heads != 0:
            raise ValueError(
                f"n_embd ({self.emb_dim}) must be divisible by n_head ({self.num_heads})"
            )
        self.head_dim = self.emb_dim // self.num_heads
        self.split_size = config.n_embd
        # One projection yields q, k and v concatenated along the last dimension.
        self.c_attention = nn.Linear(self.emb_dim, 3 * self.emb_dim)
        self.c_proj = nn.Linear(self.emb_dim, self.emb_dim)
        self.attn_dropout = nn.Dropout(getattr(config, "attn_pdrop", 0.1))
        self.resid_dropout = nn.Dropout(getattr(config, "resid_pdrop", 0.1))

    def _attn(self, query, key, value):
        """Scaled dot-product attention with a causal mask.

        Args:
            query, key, value: tensors of shape (batch, n_heads, seq_len, head_dim).

        Returns:
            Tensor of shape (batch, n_heads, seq_len, head_dim).
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        attn_weights = attn_weights / float(query.size(-1) ** 0.5)

        T = query.size(-2)
        causal_mask = self.mask[:, :, :T, :T]
        # Future positions must receive zero probability after softmax. They are
        # therefore filled with the most negative finite value of the dtype, not
        # a small positive number (which would leave them attendable).
        attn_weights = attn_weights.masked_fill(
            ~causal_mask, torch.finfo(attn_weights.dtype).min
        )

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        return torch.matmul(attn_weights, value)

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (batch, seq_len, n_embd); seq_len must not exceed
                n_positions.

        Returns:
            Tensor of shape (batch, seq_len, n_embd).
        """
        B, T, C = x.size()
        # (B, T, 3*C) -> three tensors of shape (B, T, C)
        query, key, value = self.c_attention(x).split(self.split_size, dim=-1)
        # (B, T, C) -> (B, n_heads, T, head_dim)
        query = query.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        att_output = self._attn(query, key, value)  # (B, n_heads, T, head_dim)
        # transpose() yields a non-contiguous tensor; view() needs contiguous memory.
        att_output = att_output.transpose(1, 2).contiguous().view(B, T, C)
        att_output = self.c_proj(att_output)
        return self.resid_dropout(att_output)

In [4]:
class GPT2MLP(nn.Module):
    """Position-wise feed-forward sub-layer of a GPT-2 block.

    Computes Linear(n_embd -> inner) -> GELU -> Linear(inner -> n_embd) -> Dropout.
    ``inner`` is ``config.n_inner`` when set and ``4 * n_embd`` otherwise.

    Args:
        config: object exposing ``n_embd``. ``n_inner`` and ``resid_pdrop`` are
            used when present; they default to 4 * n_embd and 0.1.
    """

    def __init__(self, config):
        super(GPT2MLP, self).__init__()
        self.emb_dim = config.n_embd
        n_inner = getattr(config, "n_inner", None)
        inner_dim = n_inner if n_inner is not None else 4 * self.emb_dim
        # NOTE: GPT2Config's default activation_function is "gelu_new" (the tanh
        # approximation). nn.GELU() here is the exact erf form, so activations
        # differ slightly from the reference model.
        self.mlp = nn.Sequential(
            nn.Linear(self.emb_dim, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, self.emb_dim),
            nn.Dropout(getattr(config, "resid_pdrop", 0.1)),
        )

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x`` (size n_embd); the shape is preserved."""
        return self.mlp(x)

In [6]:
class GPT2Block(nn.Module):
    """One pre-LayerNorm transformer block.

    Computes ``x = x + Attn(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.

    Args:
        config: object exposing ``n_embd``, plus whatever GPT2Attention and
            GPT2MLP read. ``layer_norm_epsilon`` is used when present
            (default 1e-5).
    """

    def __init__(self, config):
        super(GPT2Block, self).__init__()
        emb_dim = config.n_embd
        eps = getattr(config, "layer_norm_epsilon", 1e-5)
        self.layernorm1 = nn.LayerNorm(emb_dim, eps=eps)
        self.layernorm2 = nn.LayerNorm(emb_dim, eps=eps)
        self.attn = GPT2Attention(config)
        self.mlp = GPT2MLP(config)

    def forward(self, hidden_states):
        """Apply the attention and MLP sub-layers, each with a residual connection.

        Args:
            hidden_states: tensor whose last dimension is n_embd.

        Returns:
            Tensor of the same shape as ``hidden_states``.
        """
        # Attention sub-layer: normalize, attend, add residual.
        residual = hidden_states
        hidden_states = residual + self.attn(self.layernorm1(hidden_states))
        # MLP sub-layer: the MLP input must also be normalized (pre-LN).
        residual = hidden_states
        hidden_states = residual + self.mlp(self.layernorm2(hidden_states))
        return hidden_states

In [7]:
class GPT2Model(nn.Module):
    """GPT-2 decoder stack.

    Consists of token embeddings plus position embeddings, then ``n_layer``
    blocks, then a final LayerNorm.

    Args:
        config: object exposing ``n_embd``, ``vocab_size``, ``n_positions`` and
            ``n_layer``, plus whatever GPT2Block reads. ``embd_pdrop`` and
            ``layer_norm_epsilon`` are used when present; they default to 0.1
            and 1e-5.
    """

    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.emb_dim = config.n_embd
        self.vocab_size = config.vocab_size
        self.max_positions = config.n_positions
        self.wte = nn.Embedding(self.vocab_size, self.emb_dim)
        self.wpe = nn.Embedding(config.n_positions, self.emb_dim)
        self.dropout = nn.Dropout(getattr(config, "embd_pdrop", 0.1))
        self.blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        self.layernormf = nn.LayerNorm(
            self.emb_dim, eps=getattr(config, "layer_norm_epsilon", 1e-5)
        )

    def forward(self, input_ids=None, position_ids=None):
        """Embed the tokens and run them through every block.

        Args:
            input_ids: LongTensor of token ids with shape (batch, seq_len).
            position_ids: optional LongTensor of positions, broadcastable to
                (batch, seq_len). Defaults to 0..seq_len-1.

        Returns:
            Hidden states of shape (batch, seq_len, n_embd).

        Raises:
            ValueError: if ``input_ids`` is None or seq_len exceeds n_positions.
        """
        if input_ids is None:
            raise ValueError("input_ids is required")
        seq_len = input_ids.size(1)
        if seq_len > self.max_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds n_positions={self.max_positions}"
            )

        if position_ids is None:
            position_ids = torch.arange(
                0, seq_len, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)  # (1, seq_len), broadcast over the batch

        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        hidden_states = self.dropout(hidden_states)

        for block in self.blocks:
            hidden_states = block(hidden_states)
        return self.layernormf(hidden_states)

In [None]:
class GPT2LMHead(nn.Module):
    """GPT-2 with a language-modelling head that maps hidden states to vocabulary logits.

    Args:
        config: object exposing ``n_embd`` and ``vocab_size``, plus whatever
            GPT2Model reads.
        ignore_index: label value excluded from the loss. Defaults to
            ``config.pad_token_id`` when that is set, and otherwise to -100
            (the nn.CrossEntropyLoss default).
    """

    def __init__(self, config, ignore_index=None):
        super(GPT2LMHead, self).__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if ignore_index is None:
            pad_id = getattr(config, "pad_token_id", None)
            ignore_index = pad_id if pad_id is not None else -100
        # ignore_index must be an integer id, not a token string.
        self.xe = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, input_ids=None, position_ids=None, labels=None):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids: LongTensor of token ids with shape (batch, seq_len).
            position_ids: optional positions, forwarded to the transformer.
            labels: optional LongTensor with the same shape as ``input_ids``.

        Returns:
            Tuple ``(lm_logits, loss)``. ``lm_logits`` has shape
            (batch, seq_len, vocab_size). ``loss`` is None when no labels are
            given.
        """
        hidden_states = self.transformer(input_ids, position_ids=position_ids)
        lm_logits = self.lm_head(hidden_states)  # (batch, seq_len, vocab_size)

        loss = None
        if labels is not None:
            # Next-token prediction: the logits at position t are scored against
            # the label at t+1. Drop the last logit step and the first label.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = self.xe(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
        return lm_logits, loss
            