# Transformer Architecture
Follows the "Attention is all you need" paper architecture

In [1]:
import torch
from torch import nn
import math

In [2]:
d_model = 512
d_ff = 2048
h = 8
head_dim = d_model/h
N = 6
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class AttentionHead(nn.Module):
    def __init__(self,head_dim:int = 64, masked:bool = False) -> None:
        super().__init__()
        self.masked = masked
        self.queries= nn.Linear(in_features=head_dim,out_features=head_dim) # kaparthy set bias=False why?
        self.keys = nn.Linear(in_features=head_dim,out_features=head_dim) # kaparthy set bias=False why?
        self.values = nn.Linear(in_features=head_dim,out_features=head_dim) # kaparthy set bias=False why?

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        B,T,C = X.shape

        K = self.keys(X)
        Q = self.queries(X)
        V = self.values(X)

        scaled_dot_product_attention = (Q @ K.transpose(2,1))/torch.sqrt(torch.tensor(head_dim))
        if self.masked:
            mask = torch.tril(torch.ones(T,T)) == 0
            scaled_dot_product_attention = scaled_dot_product_attention.masked_fill(mask,-float("inf"))

        dot_product_softened = torch.softmax(scaled_dot_product_attention,dim=-1)
        return dot_product_softened @ V

# attn = AttentionHead(masked=True)
# v = attn(torch.rand((10,3,64)))
# v.shape

In [4]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self,d_model:int = 512, number_of_heads:int = 8, masked = False) -> None:
        super().__init__()
        self.head_dim = math.floor(d_model/number_of_heads)
        self.layer_norm = nn.LayerNorm(d_model)
        self.heads = [AttentionHead(head_dim=self.head_dim,masked=masked) for h in range(number_of_heads)]
        self.linear = nn.Linear(d_model,d_model)
    
    def forward(self,X:torch.Tensor) ->torch.Tensor:

        splits = torch.split(X,self.head_dim,dim=2) # paper says 'project'
        heads_output = []
        for head_index,head in enumerate(self.heads):
            x = head(splits[head_index]) # this could be distributed to multiple devices for // processing
            heads_output.append(x) # accumulate result
        
        o = torch.cat(heads_output,dim=-1)

        linear_output = self.linear(o)

        return self.layer_norm(X+linear_output)


# mhsa = MultiHeadSelfAttention(d_model=512,number_of_heads=8,masked=True)
# mhsa(torch.randn((10,5,512)))

In [5]:
class FeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(in_features=d_model,out_features=d_ff),
            nn.ReLU(),
            nn.Linear(in_features=d_ff,out_features=d_model)
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.layer_norm(X+self.ffn(X))


# feedforward = FeedForward()
# feedforward.state_dict()

# Encorder

In [8]:
class EncoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.multihead_self_attention = MultiHeadSelfAttention()
        self.feedforward = FeedForward()

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.feedforward(self.multihead_self_attention(X))
    
encoder_layer = EncoderLayer()
encoder_layer(torch.randn(10,5,512)).shape

torch.Size([10, 5, 512])

In [9]:
class Encoder(nn.Module):
    def __init__(self, N:int=6) -> None:
        super().__init__()
        self.encode = nn.Sequential(*[EncoderLayer() for n in range(N)])

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.encode(X)

encoder = Encoder(N=6)
encoder(torch.randn(10,5,512)).shape

torch.Size([10, 5, 512])

# Decoder

In [7]:
class DecoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.masked_multi_head_self_attention = MultiHeadSelfAttention(d_model=512,number_of_heads=8,masked=True)
        self.multi_head_self_attention = MultiHeadSelfAttention(d_model=512,number_of_heads=8,masked=False)
        self.feedforward = FeedForward()

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.feedforward(self.multi_head_self_attention(self.masked_multi_head_self_attention(X)))
    
decoder_layer = DecoderLayer()
decoder_layer(torch.randn(10,5,512)).shape

torch.Size([10, 5, 512])

In [10]:
class Decoder(nn.Module):
    def __init__(self, N:int=6) -> None:
        super().__init__()
        self.decode = nn.Sequential(*[DecoderLayer() for n in range(N)])
        
    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.decode(X)

decoder = Decoder()
decoder

Decoder(
  (decode): Sequential(
    (0): DecoderLayer(
      (masked_multi_head_self_attention): MultiHeadSelfAttention(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (linear): Linear(in_features=512, out_features=512, bias=True)
      )
      (multi_head_self_attention): MultiHeadSelfAttention(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (linear): Linear(in_features=512, out_features=512, bias=True)
      )
      (feedforward): FeedForward(
        (ffn): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): DecoderLayer(
      (masked_multi_head_self_attention): MultiHeadSelfAttention(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (linear): Linea

# Transformer Model

In [13]:
vocab_size = 1200

In [17]:
class Transformer(nn.Module):
    def __init__(self,vocab_size:int,context_size:int,d_model:int,number_of_encoder_blocks:int=6,number_of_decoder_blocks:int=6) -> None:
        super().__init__()
        self.encoder = Encoder(N=number_of_encoder_blocks)
        self.decoder = Decoder(N=number_of_decoder_blocks)
        self.linear = nn.Linear(in_features=d_model*context_size,out_features=vocab_size)
    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return X
model = Transformer(vocab_size=vocab_size,context_size=10,d_model=512)
model(torch.randn(10,5,512)).shape

torch.Size([10, 5, 512])

In [18]:
param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 137.712MB
