# Transformer Architecture
Follows the "Attention is all you need" paper architecture

In [2]:
import torch
from torch import nn

In [30]:
d_model = 512
d_ff = 2048
h = 8
head_dim = d_model/h
N = 6
device = "cuda" if torch.cuda.is_available() else "cpu"

In [63]:
class AttentionHead(nn.Module):
    def __init__(self,head_dim:int = 64, masked:bool = False) -> None:
        super().__init__()
        self.masked = masked
        self.queries= nn.Linear(in_features=head_dim,out_features=head_dim)
        self.keys = nn.Linear(in_features=head_dim,out_features=head_dim)
        self.values = nn.Linear(in_features=head_dim,out_features=head_dim)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        B,T,C = X.shape

        K = self.keys(X)
        Q = self.queries(X)
        V = self.values(X)

        # print(K.shape,Q.shape,V.shape)
        # print(K.transpose(2,1).shape)
        # print((Q @ K.transpose(2,1)).shape)

        scaled_dot_product_attention = (Q @ K.transpose(2,1))/torch.sqrt(torch.tensor(head_dim))
        # print(scaled_dot_product_attention.shape)
        if self.masked:
            mask = torch.tril(torch.ones(T,T)) == 0
            scaled_dot_product_attention = scaled_dot_product_attention.masked_fill(mask,-float("inf"))
            # print(scaled_dot_product_attention[0])

        dot_product_softened = torch.softmax(scaled_dot_product_attention,dim=1)

        return dot_product_softened @ V

attn = AttentionHead(masked=True)
v = attn(torch.rand((10,3,64)))
v.shape

torch.Size([10, 3, 64])

In [27]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, heads:int = 8, masked = False) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.heads = [AttentionHead() for h in range(heads)]
    
    def forward(self,X:torch.Tensor) ->torch.Tensor:
        return self.layer_norm(X+self.heads(X))

In [20]:
class FeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(in_features=d_model,out_features=d_ff),
            nn.ReLU(),
            nn.Linear(in_features=d_ff,out_features=d_model)
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.layer_norm(X+self.ffn(X))
# feedforward = FeedForward()
# feedforward.state_dict()

In [18]:
class EncoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.multihead_self_attention = MultiHeadSelfAttention()
        self.feedforward = FeedForward()

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.feedforward(self.multihead_self_attention(X))
    
encoder_layer = EncoderLayer()
encoder_layer

EncoderLayer(
  (multihead_self_attention): MultiHeadSelfAttention()
  (feedforward): FeedForward()
)

In [14]:
class Encoder(nn.Module):
    def __init__(self, N:int=6) -> None:
        super().__init__()
        self.encode = nn.Sequential(*[EncoderLayer() for n in range(N)])

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.encode(X)

encoder = Encoder(N=6)
encoder

Encoder(
  (encode): Sequential(
    (0): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (1): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (2): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (3): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (4): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (5): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
  )
)

In [25]:
class Decoder(nn.Module):
    def __init__(self, N:int=6) -> None:
        super().__init__()
        self.decode = nn.Sequential(*[DecoderLayer() for n in range(N)])


decoder = Decoder()
decoder

Decoder(
  (decode): Sequential(
    (0): DecoderLayer()
    (1): DecoderLayer()
    (2): DecoderLayer()
    (3): DecoderLayer()
    (4): DecoderLayer()
    (5): DecoderLayer()
  )
)

In [22]:
class Transformer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
model = Transformer()
model

Transformer(
  (encoder): Encoder(
    (encode): Sequential(
      (0): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (1): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (2): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (3): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (4): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (5): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
    )
  )
  (decoder): Decoder()
)