In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import GptConfig
config = GptConfig()

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [2]:
import math 


class GPTAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.Q_w = nn.Linear(config.d_model, config.d_model)
        self.K_w = nn.Linear(config.d_model, config.d_model)
        self.V_w = nn.Linear(config.d_model, config.d_model)
        self.O_w = nn.Linear(config.d_model, config.d_model)
        
    def _mask_atten_scores(self, atten_score, seq_len):
        # Masking (Casual)
        mask = torch.tril(torch.ones(1, 1, seq_len, seq_len)).to(self.config.device)  # lower triangle
        atten_score = atten_score.masked_fill(mask==0, float('-inf'))
        return atten_score
    
    def _calc_atten_score(self, query, key):
        d_k = key.size(-1)
        atten_score = (query @ key.transpose(-1, -2)) / math.sqrt(d_k)
        return atten_score
    
    def _create_contextualized_embeds(self, atten_score, value):
        contextulized_embeds = atten_score @ value # (b, n_head, seq, seq) * (b, n_head, seq, d_model) -> (b, n_head, seq, d_model)
        contextulized_embeds = contextulized_embeds.transpose(1, 2).contiguous().view(self.batch_size, self.seq_len, self.config.n_heads * self.config.head_dim)
        return contextulized_embeds
        
    def forward(self, x: torch.Tensor):
        self.batch_size, self.seq_len = x.shape[0], x.shape[1]
        
        query = self.Q_w(x)
        key = self.K_w(x)
        value = self.V_w(x)
        
        # (b, s, d_model) -> view -> (batch_size, seq_len, num_head, head_dim) -> permute ->(b, num_head, s, head_dim)
        query = query.view(self.batch_size, self.seq_len, config.n_heads, config.head_dim).transpose(1, 2)
        key = key.view(self.batch_size, self.seq_len, config.n_heads, config.head_dim).transpose(1, 2)
        value = value.view(self.batch_size, self.seq_len, config.n_heads, config.head_dim).transpose(1, 2)
        
        atten_score = self._calc_atten_score(query, key)
        atten_score = self._mask_atten_scores(atten_score, self.seq_len)
        atten_score = F.softmax(atten_score, dim=-1)
        contextualized_embed = self._create_contextualized_embeds(atten_score, value)
        
        return self.O_w(contextualized_embed)
        
        
        
        
        

In [3]:
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_model, config.intermidiate_size),
            nn.GELU(),
            nn.Linear(config.intermidiate_size, config.d_model)    
        )
        
    def forward(self, x):
        return self.feed_forward(x)

In [4]:
class RMSNorm(nn.Module):
    def __init__(self, config, eps=1e-8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(config.d_model))
        self.eps = eps
    def forward(self, x: torch.Tensor):
        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        x_norm = x / (rms + self.eps)
        return x_norm * self.scale

In [5]:
attention = GPTAttention(config).to(device=config.device)
feedforward = FeedForward(config).to(device=config.device)
rms_norm = RMSNorm(config).to(config.device)

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.attention = GPTAttention(config)
        self.feed_forward = FeedForward(config)
        self.norm = RMSNorm(config)
        
    def forward(self, x):
        x = self.norm(x)
        x = self.attention(x)
        x = feedforward(x)
        
        return x

In [10]:
class PositionalEncoding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pos_embed = nn.Embedding(config.max_len, config.d_model)
    
    def forward(self, x):
        seq_len = x.shape[1] 
        positions = torch.range(0, seq_len - 1, dtype=torch.long, device=config.device).unsqueeze(0)
        print(positions)
        return x + self.pos_embed(positions)

In [17]:
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.tok_embed = nn.Embedding(config.vocab_size, config.n_embed)
        self.pos_embed = PositionalEncoding(config)
        self.decoder_blocks = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layers)])
        self.final_rms_norm = RMSNorm(config)
        self.final_layer = nn.Linear(config.n_embed, config.vocab_size)
        self.criteria = nn.CrossEntropyLoss()
    
    def forward(self, x, target):
        """x: (b, seq)
           target: (b, seq)"""
        batch_size, seq_len = x.size()
        x = self.tok_embed(x)
        x = self.pos_embed(x)
        
        for block in self.decoder_blocks:
            x = block(x)
            
        x = self.final_rms_norm(x)
        logits = self.final_layer(x) # (b, seq, vocab_size)
        print(logits.shape, target.shape)
        loss = self.criteria(logits.view(batch_size*seq_len, -1), target.view(-1))
        
        return {'logits' : logits, "loss": loss}
    

In [None]:
sample = torch.randint(1, 10, size=(1, 5)).to('cuda')
target = torch.randint(1, 10, size=(1, 5)).to('cuda')

gpt = GPT(config).to('cuda')

gpt(sample, target)

tensor([[0, 1, 2, 3, 4]], device='cuda:0')
torch.Size([1, 5, 55]) torch.Size([1, 5])


  positions = torch.range(0, seq_len - 1, dtype=torch.long, device=config.device).unsqueeze(0)


{'logits': tensor([[[ 0.1365, -0.3977, -0.1294,  0.3840, -0.4148,  0.6821, -0.3824,
            0.9976,  0.1885, -0.3748,  0.3146, -0.1096,  0.4535,  0.2106,
           -0.1533, -1.0743,  0.2641,  0.7782, -0.0524, -0.3638,  0.6550,
            0.1163,  0.4044, -0.1925,  0.8356,  0.6839, -1.0257, -0.5049,
           -0.6323, -0.3597,  0.0450,  0.3520, -0.1769,  0.5743, -0.1087,
           -0.0463,  0.8905, -0.0746, -0.4379,  0.0583,  0.0269, -0.9097,
           -0.1149, -0.1206,  0.7310,  0.2275, -0.0553, -0.2520,  0.9179,
           -0.5907, -0.0777, -0.1224, -0.1404, -0.6363, -0.7409],
          [ 0.1368, -0.3978, -0.1298,  0.3841, -0.4146,  0.6822, -0.3822,
            0.9977,  0.1886, -0.3747,  0.3147, -0.1096,  0.4536,  0.2105,
           -0.1535, -1.0744,  0.2638,  0.7785, -0.0522, -0.3638,  0.6551,
            0.1161,  0.4045, -0.1926,  0.8355,  0.6839, -1.0259, -0.5051,
           -0.6323, -0.3599,  0.0451,  0.3522, -0.1767,  0.5743, -0.1089,
           -0.0462,  0.8908, -0.0746

In [None]:
attention

GPTAttention(
  (Q_w): Linear(in_features=768, out_features=768, bias=True)
  (K_w): Linear(in_features=768, out_features=768, bias=True)
  (V_w): Linear(in_features=768, out_features=768, bias=True)
  (O_w): Linear(in_features=768, out_features=768, bias=True)
)