In [111]:
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

In [144]:
@dataclass
class Config:
    """Hyperparameters shared by the layers defined in this notebook.

    Attributes:
        n_layers: Layer count.
        d_model: Size of the last (feature) dimension of the tensors the
            layers operate on.
        eps: Small constant used by layer normalization for numerical
            stability.
        hidden_layer: Width of the feed-forward hidden layer. When left as
            ``None`` it is resolved to ``4 * d_model`` for *this instance*.
        num_heads: Number of attention heads; ``d_model`` must be divisible
            by it.
    """

    n_layers: int = 2
    d_model: int = 15
    eps: float = 1e-5
    # A class-level default of `4 * d_model` would be evaluated once, while the
    # class body runs, and would stay at 60 even for Config(d_model=32).
    # Deferring the computation to __post_init__ keeps it tied to the instance.
    hidden_layer: Optional[int] = None
    num_heads: int = 3

    def __post_init__(self):
        """Fill in ``hidden_layer`` from ``d_model`` unless it was given explicitly."""
        if self.hidden_layer is None:
            self.hidden_layer = 4 * self.d_model
    
# Build the configuration with all default values; the bare name on the last
# line makes the notebook display the dataclass repr.
config = Config()
config

Config(n_layers=2, d_model=15, eps=1e-05, hidden_layer=60, num_heads=3)

In [None]:
# Layer Normalization layer

class LayerNormalization(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift.

    Args:
        config: Object exposing ``d_model`` (size of the normalized last
            dimension) and ``eps`` (constant added to the variance so the
            division is safe for constant inputs).
    """

    def __init__(self, config):
        super().__init__()
        self.eps = config.eps
        self.scale = nn.Parameter(torch.ones(config.d_model))
        self.shift = nn.Parameter(torch.zeros(config.d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along its last dimension, then apply scale and shift.

        Args:
            x: Tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x_mean = x.mean(dim=-1, keepdim=True)
        # Biased variance with eps inside the square root, the same formula
        # torch.nn.LayerNorm uses.
        x_var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
        # Scale/shift the *normalized* activations; applying them to the raw
        # input would make the layer a plain affine transform.
        return x_norm * self.scale + self.shift
                
        
        
        

In [None]:
# FeedForward Layer
class FeedForwardLayer(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) with a residual connection.

    Args:
        config: Object exposing ``d_model`` (input/output feature size) and
            ``hidden_layer`` (width of the intermediate projection).
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model
        hidden_layer = config.hidden_layer

        # nn.Sequential takes its modules as positional arguments (or a single
        # OrderedDict); handing it a plain list raises TypeError.
        self.ff_layer = nn.Sequential(
            nn.Linear(d_model, hidden_layer),
            nn.GELU(),
            nn.Linear(hidden_layer, d_model),
        )

    def forward(self, x):
        """Return ``x`` plus the MLP applied to ``x`` (residual is added here).

        Args:
            x: Tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return x + self.ff_layer(x)
        
        

In [None]:
class SelfAttentionLayer(nn.Module):
    """Multi-head causal self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``d_model // num_heads``, combined with
    scaled dot-product attention under a causal mask, and the heads are
    concatenated back to ``d_model``. There is no output projection.

    Args:
        config: Object exposing ``d_model`` and ``num_heads``.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    # The annotation is a string so the class can be defined without `Config`
    # being resolvable at definition time.
    def __init__(self, config: "Config"):
        super().__init__()
        assert config.d_model % config.num_heads == 0, "d_model should be divisible by num_heads"
        self.query_weights = nn.Linear(config.d_model, config.d_model)
        self.key_weights = nn.Linear(config.d_model, config.d_model)
        self.value_weights = nn.Linear(config.d_model, config.d_model)
        self.num_heads = config.num_heads
        self.h_dmodel = config.d_model // config.num_heads
        # A true -inf gives masked positions exactly zero weight after the
        # softmax for every dtype and score magnitude. Every row of the causal
        # mask keeps its diagonal entry, so no row is fully masked here.
        self.neg_inf = float("-inf")

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tuple ``(context, weights)``: ``context`` has shape
            ``(batch, seq_len, d_model)`` and ``weights`` holds the
            post-softmax attention weights with shape
            ``(batch, num_heads, seq_len, seq_len)``.
        """
        batch_size, seq_len, _ = x.size()

        queries = self._split_heads(self.query_weights(x))
        keys = self._split_heads(self.key_weights(x))
        values = self._split_heads(self.value_weights(x))

        # True strictly above the diagonal marks future positions that must not
        # be attended to. Built on x's device so it can be combined with scores
        # computed on an accelerator.
        causal_mask = torch.ones(seq_len, seq_len, device=x.device).triu(diagonal=1).bool()

        attention_weights = self.calculate_attention_score(queries, keys, causal_mask)
        context = attention_weights @ values

        # (B, num_heads, S, h_dmodel) -> (B, S, num_heads, h_dmodel) -> (B, S, d_model)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(batch_size, seq_len, self.num_heads * self.h_dmodel)
        return (context, attention_weights)

    def _split_heads(self, projected):
        """Reshape ``(B, S, d_model)`` to ``(B, num_heads, S, h_dmodel)``."""
        batch_size, seq_len, _ = projected.size()
        per_head = projected.view(batch_size, seq_len, self.num_heads, self.h_dmodel)
        return per_head.permute(0, 2, 1, 3)

    def calculate_attention_score(self, qeury, key, mask):
        """Compute masked, softmax-normalized scaled dot-product attention weights.

        The parameter name ``qeury`` is kept as-is so keyword callers keep working.

        Args:
            qeury: Query tensor of shape ``(B, num_heads, S, h_dmodel)``.
            key: Key tensor of shape ``(B, num_heads, S, h_dmodel)``.
            mask: Boolean tensor broadcastable to ``(B, num_heads, S, S)``;
                ``True`` entries are excluded from attention. A row that is
                entirely ``True`` yields NaN weights.

        Returns:
            Attention weights of shape ``(B, num_heads, S, S)``; each row sums to 1.
        """
        head_dim = key.size(-1)
        # (B, H, S, h_dmodel) @ (B, H, h_dmodel, S) -> (B, H, S, S)
        scores = (qeury @ key.transpose(-1, -2)) / math.sqrt(head_dim)
        scores = scores.masked_fill(mask, self.neg_inf)
        return torch.softmax(scores, dim=-1)
        
        
        

In [140]:
# Fixed seed so the sample values, and the products displayed further down,
# are reproducible on a fresh kernel.
torch.manual_seed(0)
tensor = torch.rand((3, 3, 15))

# Split the last dimension into 3 chunks of 5: (3, 3, 15) -> (3, 3, 3, 5).
tensora = tensor.view(3, 3, 3, 5)
# Swap the last two dims so `tensora @ tensorb` is a valid batched product,
# (3, 5) @ (5, 3) -> (3, 3). Two (3, 3, 3, 5) operands cannot be multiplied.
tensorb = tensor.view(3, 3, 3, 5).transpose(-1, -2)

In [145]:
# Instantiate the attention layer from the shared `config` object.
ffv = SelfAttentionLayer(config)

In [None]:
# forward() needs an input batch; `tensor` is (3, 3, 15), so its last
# dimension matches the default d_model. Show shapes rather than raw tensors.
context_vectors, attention_weights = ffv(tensor)
context_vectors.shape, attention_weights.shape

In [137]:
# Batched matrix product via the `@` operator. tensora is (3, 3, 3, 5), so
# tensorb must have trailing dims (5, 3) for this to be defined.
(tensora @ tensorb)

tensor([[[[0.9562, 1.0250, 0.5964],
          [1.0250, 2.0533, 0.7009],
          [0.5964, 0.7009, 0.3985]],

         [[1.5012, 1.5688, 1.1156],
          [1.5688, 2.3999, 1.5817],
          [1.1156, 1.5817, 1.3885]],

         [[1.0882, 0.6742, 0.5203],
          [0.6742, 1.9828, 0.7379],
          [0.5203, 0.7379, 0.6984]]],


        [[[1.6625, 2.1261, 0.7818],
          [2.1261, 3.1737, 1.1274],
          [0.7818, 1.1274, 0.4887]],

         [[2.0842, 2.0933, 1.7668],
          [2.0933, 2.4566, 2.3007],
          [1.7668, 2.3007, 2.3089]],

         [[1.5826, 0.7673, 1.0756],
          [0.7673, 2.0198, 0.7691],
          [1.0756, 0.7691, 0.9417]]],


        [[[2.2417, 1.2798, 1.7922],
          [1.2798, 1.0155, 1.0145],
          [1.7922, 1.0145, 1.8621]],

         [[0.8233, 1.2009, 0.7710],
          [1.2009, 2.3798, 1.4923],
          [0.7710, 1.4923, 1.1312]],

         [[1.3741, 0.6474, 1.0072],
          [0.6474, 0.3615, 0.5839],
          [1.0072, 0.5839, 1.5147]]]])

In [138]:
# Same batched product through the Tensor.matmul method, for comparison with
# the `@` operator form.
tensora.matmul(tensorb)

tensor([[[[0.9562, 1.0250, 0.5964],
          [1.0250, 2.0533, 0.7009],
          [0.5964, 0.7009, 0.3985]],

         [[1.5012, 1.5688, 1.1156],
          [1.5688, 2.3999, 1.5817],
          [1.1156, 1.5817, 1.3885]],

         [[1.0882, 0.6742, 0.5203],
          [0.6742, 1.9828, 0.7379],
          [0.5203, 0.7379, 0.6984]]],


        [[[1.6625, 2.1261, 0.7818],
          [2.1261, 3.1737, 1.1274],
          [0.7818, 1.1274, 0.4887]],

         [[2.0842, 2.0933, 1.7668],
          [2.0933, 2.4566, 2.3007],
          [1.7668, 2.3007, 2.3089]],

         [[1.5826, 0.7673, 1.0756],
          [0.7673, 2.0198, 0.7691],
          [1.0756, 0.7691, 0.9417]]],


        [[[2.2417, 1.2798, 1.7922],
          [1.2798, 1.0155, 1.0145],
          [1.7922, 1.0145, 1.8621]],

         [[0.8233, 1.2009, 0.7710],
          [1.2009, 2.3798, 1.4923],
          [0.7710, 1.4923, 1.1312]],

         [[1.3741, 0.6474, 1.0072],
          [0.6474, 0.3615, 0.5839],
          [1.0072, 0.5839, 1.5147]]]])