# Transformer Architecture
Follows the "Attention is all you need" paper architecture

In [2]:
import torch
from torch import nn

In [3]:
d_model = 512
d_ff = 2048
h = 8
N = 6
device = "cuda" if torch.cuda.is_available() else "cpu"

In [27]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, heads:int = 8, masked = False) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.heads = nn.Sequential()
    
    def forward(self,X:torch.Tensor) ->torch.Tensor:
        return self.layer_norm(X+self.heads(X))

In [20]:
class FeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(in_features=d_model,out_features=d_ff),
            nn.ReLU(),
            nn.Linear(in_features=d_ff,out_features=d_model)
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.layer_norm(X+self.ffn(X))
# feedforward = FeedForward()
# feedforward.state_dict()

In [18]:
class EncoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.multihead_self_attention = MultiHeadSelfAttention()
        self.feedforward = FeedForward()

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.feedforward(self.multihead_self_attention(X))
    
encoder_layer = EncoderLayer()
encoder_layer

EncoderLayer(
  (multihead_self_attention): MultiHeadSelfAttention()
  (feedforward): FeedForward()
)

In [14]:
class Encoder(nn.Module):
    def __init__(self, N:int=6) -> None:
        super().__init__()
        self.encode = nn.Sequential(*[EncoderLayer() for n in range(N)])

    def forward(self,X:torch.Tensor) -> torch.Tensor:
        return self.encode(X)

encoder = Encoder(N=6)
encoder

Encoder(
  (encode): Sequential(
    (0): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (1): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (2): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (3): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (4): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
    (5): EncoderLayer(
      (mhsa): MultiHeadSelfAttention()
      (feedforward): FeedForward()
    )
  )
)

In [28]:
class DecoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.masked_multihead_self_attention = MultiHeadSelfAttention(heads=8,masked=True)
        self.multihead_self_attention = MultiHeadSelfAttention(heads=8)
        self.feedforward = FeedForward()
        
    def forward(self,X:torch.Tensor) -> torch.Tensor:
        pass
    
decoder_layer = DecoderLayer()
decoder_layer

DecoderLayer(
  (masked_multihead_self_attention): MultiHeadSelfAttention(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (heads): Sequential()
  )
  (multihead_self_attention): MultiHeadSelfAttention(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (heads): Sequential()
  )
  (feedforward): FeedForward(
    (ffn): Sequential(
      (0): Linear(in_features=512, out_features=2048, bias=True)
      (1): ReLU()
      (2): Linear(in_features=2048, out_features=512, bias=True)
    )
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)

In [25]:
class Decoder(nn.Module):
    def __init__(self, N:int=6) -> None:
        super().__init__()
        self.decode = nn.Sequential(*[DecoderLayer() for n in range(N)])


decoder = Decoder()
decoder

Decoder(
  (decode): Sequential(
    (0): DecoderLayer()
    (1): DecoderLayer()
    (2): DecoderLayer()
    (3): DecoderLayer()
    (4): DecoderLayer()
    (5): DecoderLayer()
  )
)

In [22]:
class Transformer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
model = Transformer()
model

Transformer(
  (encoder): Encoder(
    (encode): Sequential(
      (0): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (1): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (2): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (3): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (4): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
      (5): EncoderLayer(
        (multihead_self_attention): MultiHeadSelfAttention()
        (feedforward): FeedForward()
      )
    )
  )
  (decoder): Decoder()
)