In [2]:
import torch
import torch.nn as nn

In [3]:
# Model hyper-parameters shared by the cells below. The values match the
# published GPT-2 "small" (124M) configuration.
cfg = dict(
    vocab_size=50257,      # not consumed by the cells shown here
    context_length=1024,   # size of the causal mask in MultiHead_Attention
    emb_dim=768,           # width of attention, feed-forward and LayerNorm
    n_heads=12,            # number of attention heads per block
    n_layers=12,           # not consumed by the cells shown here
    drop_rate=0.1,         # dropout probability used throughout the block
    qkv_bias=False,        # bias on the query/key/value projections
)

In [49]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable affine map.

    Equivalent to ``torch.nn.LayerNorm(emb_dim, eps=eps)``. The parameters are
    named ``scale`` (initialised to 1) and ``shift`` (initialised to 0).

    Args:
        emb_dim: size of the last dimension of the input.
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then apply scale and shift."""
        mean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as in nn.LayerNorm / GPT-2. The default
        # unbiased estimator divides by N-1 and gives different outputs.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

In [50]:
class Feed_Forward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The hidden layer is four times wider than ``cfg["emb_dim"]``, and the output
    has the same last-dimension size as the input. The layers are stored in the
    ``nn.Sequential`` attribute ``layers``.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg["emb_dim"]
        hidden = 4 * width
        stages = [
            nn.Linear(width, hidden),   # expand
            nn.GELU(),
            nn.Linear(hidden, width),   # project back to emb_dim
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the expand -> GELU -> project stack to ``x``."""
        out = self.layers(x)
        return out

In [51]:
class MultiHead_Attention(nn.Module):
    """Causal (masked) multi-head self-attention.

    Args:
        d_in: size of the last dimension of the input ``x``.
        d_out: total size of the projected queries/keys/values and of the
            output; split evenly across ``num_head`` heads.
        num_head: number of attention heads; must divide ``d_out``.
        dropout: dropout probability applied to the attention weights.
        context_length: largest sequence length the causal mask supports.
        bias: whether the query/key/value projections carry a bias term.

    Raises:
        ValueError: if ``d_out`` is not divisible by ``num_head``.
    """

    def __init__(self, d_in,
                 d_out,
                 num_head,
                 dropout,
                 context_length,
                 bias=False):
        super().__init__()
        if d_out % num_head != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by num_head ({num_head})"
            )
        # Creation order is kept as-is so seeded initialisation is unchanged.
        self.W_Query = nn.Linear(d_in, d_out, bias=bias)
        self.W_Key = nn.Linear(d_in, d_out, bias=bias)
        self.W_Value = nn.Linear(d_in, d_out, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.num_head = num_head
        self.head_dim = d_out // num_head
        self.d_out = d_out
        self.out_project = nn.Linear(d_out, d_out)

        # Entries strictly above the diagonal are 1: position i may not attend
        # to any later position j > i.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: if ``num_tokens`` exceeds ``context_length``.
        """
        b, num_tokens, _ = x.shape
        if num_tokens > self.mask.shape[-1]:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length "
                f"{self.mask.shape[-1]}"
            )

        keys = self.W_Key(x)
        query = self.W_Query(x)
        value = self.W_Value(x)

        # (b, num_tokens, d_out) -> (b, num_head, num_tokens, head_dim)
        keys = keys.view(b, num_tokens, self.num_head, self.head_dim).transpose(1, 2)
        query = query.view(b, num_tokens, self.num_head, self.head_dim).transpose(1, 2)
        value = value.view(b, num_tokens, self.num_head, self.head_dim).transpose(1, 2)

        att_score = query @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        att_score.masked_fill_(mask_bool, -torch.inf)

        # Scaled dot-product: divide by sqrt(head_dim) before the softmax.
        att_weight = torch.softmax(att_score / self.head_dim ** 0.5, dim=-1)
        att_weight = self.dropout(att_weight)

        # (b, num_head, num_tokens, head_dim) -> (b, num_tokens, num_head, head_dim)
        context_vec = (att_weight @ value).transpose(1, 2)
        # Merge the heads back into d_out (not d_in: they differ when d_in != d_out).
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_project(context_vec)

In [52]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Each of the two sub-layers (causal multi-head attention, then the
    feed-forward network) is applied as ``dropout(sublayer(norm(x))) + x``.

    Args:
        cfg: mapping with the keys ``emb_dim``, ``n_heads``, ``drop_rate``,
            ``context_length`` and ``qkv_bias``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg['emb_dim']
        rate = cfg['drop_rate']

        self.att = MultiHead_Attention(
            d_in=emb_dim,
            d_out=emb_dim,
            num_head=cfg['n_heads'],
            dropout=rate,
            context_length=cfg['context_length'],
            bias=cfg['qkv_bias'],
        )
        self.feed_forward = Feed_Forward(cfg=cfg)
        self.norm1 = LayerNorm(emb_dim=emb_dim)
        self.norm2 = LayerNorm(emb_dim=emb_dim)
        # One parameter-free dropout module serves both residual branches.
        self.shortcut_drop = nn.Dropout(rate)

    def _residual(self, x, norm, sublayer):
        """Return ``dropout(sublayer(norm(x))) + x`` (pre-norm residual step)."""
        branch = self.shortcut_drop(sublayer(norm(x)))
        return branch + x

    def forward(self, x):
        """Run the attention and feed-forward residual steps in sequence."""
        hidden = self._residual(x, self.norm1, self.att)
        return self._residual(hidden, self.norm2, self.feed_forward)

In [55]:
# Smoke test: a TransformerBlock must map (batch, tokens, emb_dim) to the same shape.
torch.manual_seed(42)
x = torch.randn(2, 4, cfg["emb_dim"])  # tied to cfg instead of a hard-coded 768
transformer = TransformerBlock(cfg=cfg)
out = transformer(x)
assert out.shape == x.shape, "TransformerBlock must preserve the input shape"
print(f"Input: {x.shape}\nOutput: {out.shape}")


Input: torch.Size([2, 4, 768])
Output: torch.Size([2, 4, 768])
