In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import GptConfig
config = GptConfig()

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [8]:
import math 


class GPTAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.Q_w = nn.Linear(config.d_model, config.d_model)
        self.K_w = nn.Linear(config.d_model, config.d_model)
        self.V_w = nn.Linear(config.d_model, config.d_model)
        self.O_w = nn.Linear(config.d_model, config.d_model)
        
    def _mask_atten_scores(self, atten_score, seq_len):
        # Masking (Casual)
        mask = torch.tril(torch.ones(1, 1, seq_len, seq_len)).to(self.config.device)  # lower triangle
        atten_score = atten_score.masked_fill(mask==0, float('-inf'))
        return atten_score
    
    def _calc_atten_score(self, query, key):
        d_k = key.size(-1)
        atten_score = (query @ key.transpose(-1, -2)) / math.sqrt(d_k)
        return atten_score
    
    def _create_contextualized_embeds(self, atten_score, value):
        contextulized_embeds = atten_score @ value # (b, n_head, seq, seq) * (b, n_head, seq, d_model) -> (b, n_head, seq, d_model)
        contextulized_embeds = contextulized_embeds.transpose(1, 2).contiguous().view(self.batch_size, self.seq_len, self.config.n_heads * self.config.head_dim)
        return contextulized_embeds
        
    def forward(self, x: torch.Tensor):
        self.batch_size, self.seq_len = x.shape[0], x.shape[1]
        
        query = self.Q_w(x)
        key = self.K_w(x)
        value = self.V_w(x)
        
        # (b, s, d_model) -> view -> (batch_size, seq_len, num_head, head_dim) -> permute ->(b, num_head, s, head_dim)
        query = query.view(self.batch_size, self.seq_len, config.n_heads, config.head_dim).transpose(1, 2)
        key = key.view(self.batch_size, self.seq_len, config.n_heads, config.head_dim).transpose(1, 2)
        value = value.view(self.batch_size, self.seq_len, config.n_heads, config.head_dim).transpose(1, 2)
        
        atten_score = self._calc_atten_score(query, key)
        atten_score = self._mask_atten_scores(atten_score, self.seq_len)
        atten_score = F.softmax(atten_score, dim=-1)
        contextualized_embed = self._create_contextualized_embeds(atten_score, value)
        
        return self.O_w(contextualized_embed)
        
        
        
        
        

In [9]:
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_model, config.intermidiate_size),
            nn.GELU(),
            nn.Linear(config.intermidiate_size, config.d_model)    
        )
        
    def forward(self, x):
        return self.feed_forward(x)

In [19]:
class RMSNorm(nn.Module):
    def __init__(self, config, eps=1e-8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(config.d_model))
        self.eps = eps
    def forward(self, x: torch.Tensor):
        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        x_norm = x / (rms + self.eps)
        return x_norm * self.scale

In [23]:
attention = GPTAttention(config).to(device=config.device)
feedforward = FeedForward(config).to(device=config.device)
rms_norm = RMSNorm(config).to(config.device)

In [24]:
sample = torch.rand(1, 5, 768).to('cuda')
print(F.softmax(sample, dim=-1))

rms_norm(feedforward(attention(sample))).shape

tensor([[[0.0019, 0.0009, 0.0009,  ..., 0.0016, 0.0008, 0.0018],
         [0.0008, 0.0011, 0.0014,  ..., 0.0009, 0.0013, 0.0014],
         [0.0011, 0.0008, 0.0008,  ..., 0.0011, 0.0011, 0.0013],
         [0.0008, 0.0009, 0.0020,  ..., 0.0009, 0.0011, 0.0019],
         [0.0020, 0.0018, 0.0008,  ..., 0.0016, 0.0008, 0.0011]]],
       device='cuda:0')


torch.Size([1, 5, 768])

In [25]:
attention

GPTAttention(
  (Q_w): Linear(in_features=768, out_features=768, bias=True)
  (K_w): Linear(in_features=768, out_features=768, bias=True)
  (V_w): Linear(in_features=768, out_features=768, bias=True)
  (O_w): Linear(in_features=768, out_features=768, bias=True)
)