# The Transformer Architecture in PyTorch

Source: [The Transformer Architecture in PyTorch (original notes on ut.philkr.net)](https://ut.philkr.net/deeplearning/transformers/the_transformer_architecture_in_pytorch/)

In [1]:
import torch
import torch.nn as nn

In [2]:
class TransformerEncoderLayer(nn.Module):
    """A single pre-norm Transformer encoder layer.

    Each sub-block is applied to a LayerNorm-ed copy of its input and its
    result is added back onto the residual stream:

        x = x + SelfAttention(LayerNorm(x))
        x = x + MLP(LayerNorm(x))

    Args:
        embed_dim: Size of each token embedding; also the layer's input and
            output feature size.
        num_heads: Number of attention heads for ``nn.MultiheadAttention``
            (which requires ``embed_dim`` to be divisible by it).
        mlp_ratio: Width multiplier for the MLP hidden layer; the hidden size
            is ``int(embed_dim * mlp_ratio)``. Defaults to 4.

    The attention module is built with ``nn.MultiheadAttention``'s default
    ``batch_first=False``, so batched input is laid out as
    ``(sequence_length, batch_size, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4):
        super().__init__()
        # Submodules are registered in the same order as before so parameter
        # initialisation consumes the RNG identically and state_dict keys match.
        self.self_att = nn.MultiheadAttention(embed_dim, num_heads)
        self.in_norm = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.mlp_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Apply self-attention and the MLP, each with a residual connection.

        Args:
            x: Token embeddings, ``(sequence_length, batch_size, embed_dim)``
               for batched input.
            attn_mask: Optional attention mask, forwarded unchanged to
                ``nn.MultiheadAttention``.
            key_padding_mask: Optional key padding mask, forwarded unchanged
                to ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention block: attend over the normalized input (pre-norm).
        x_norm = self.in_norm(x)
        # Only the attention output is used, so ask the module not to compute
        # and return the (head-averaged) attention weights.
        attn_out, _ = self.self_att(
            x_norm,
            x_norm,
            x_norm,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )
        x = x + attn_out
        # Feed-forward block, also pre-norm with a residual connection.
        x = x + self.mlp(self.mlp_norm(x))
        return x

class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` ``TransformerEncoderLayer`` modules.

    Args:
        embed_dim: Embedding size given to every layer.
        num_heads: Number of attention heads given to every layer.
        num_layers: How many layers to stack; they are applied in order.
    """

    def __init__(self, embed_dim, num_heads, num_layers):
        super().__init__()
        layers = [
            TransformerEncoderLayer(embed_dim, num_heads)
            for _ in range(num_layers)
        ]
        # Kept as an nn.Sequential named `network` so parameter names remain
        # `network.<index>.*`.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Feed ``x`` through every layer in turn and return the result."""
        for layer in self.network:
            x = layer(x)
        return x

In [3]:
# Input layout (sequence-first): (sequence_length, batch_size, embed_dim).
x = torch.randn(16, 10, 128)
tran = TransformerEncoderLayer(embed_dim=128, num_heads=8)
nn_tran = nn.TransformerEncoderLayer(d_model=128, nhead=8)
# Sanity check against PyTorch's built-in layer: only shapes are compared,
# since the two layers are initialized independently.
assert tran(x).shape == nn_tran(x).shape
tran(x).shape

torch.Size([16, 10, 128])