<a href="https://colab.research.google.com/github/MaazMikail/Vanilla-Transformer-Notebook/blob/main/Vanilla_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class InputEmbeddings(nn.Module):
  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.d_model = d_model
    self.vocab_size = vocab_size
    self.emb = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    return self.emb(x) + torch.sqrt(torch.tensor(self.d_model))

In [3]:
embed = InputEmbeddings(512, 4096)
res = embed(torch.tensor([143, 891, 1000, 482, 18, 12]))
res.shape

torch.Size([6, 512])

In [40]:
class PositionalEncoding(nn.Module):

  def __init__(self, d_model: int, seq_len: int, dropout: float):
    super().__init__()
    self.d_model = d_model
    self.seq_len = seq_len
    self.dropout = nn.Dropout()

    self.pe = torch.zeros(self.seq_len, self.d_model)

    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0,self.d_model,2).float() * (-math.log(10000.0)/self.d_model))

    self.pe[:, 0::2] = torch.sin(position * div_term)
    self.pe[:, 1::2] = torch.cos(position * div_term)

    self.pe = self.pe.unsqueeze(0)
    self.pe.requires_grad = False


    self.register_buffer('pos', self.pe)


  def forward(self, x):


    x = x + self.pe

    return self.dropout(x)



In [5]:
pos = PositionalEncoding(512, 6,0)
#res = res.unsqueeze(0)
print(res.shape)

torch.Size([1, 6, 512])
torch.Size([6, 512])


In [6]:
embs_pos = pos(res)
embs_pos.shape

torch.Size([1, 6, 512])

In [7]:
class LayerNorm(nn.Module):

  def __init__(self, eps: float = 1e-6):
    super().__init__()
    self.eps = eps
    self.alpha = nn.Parameter(torch.ones(1)) # mult
    self.bias = nn.Parameter(torch.ones(1)) # added

  def forward(self, x):
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim = -1, keepdim=True)
    return self.alpha * (x- mean) / (std + self.eps) + self.bias



In [8]:
ln = LayerNorm()
emb_norm = ln(embs_pos)
emb_norm.shape

torch.Size([1, 6, 512])

In [47]:
emb_norm[:, :, -1].mean()

tensor(0.6790, grad_fn=<MeanBackward0>)

In [49]:
embs_pos[:, :, -1].mean()

tensor(15.7962, grad_fn=<MeanBackward0>)

In [9]:
class FeedForward(nn.Module):

  def __init__(self, d_model: int, d_ff:int, dropout: float):
    super().__init__()
    self.ff_1 = nn.Linear(d_model, d_ff)
    self.ff_2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    return self.ff_2(self.dropout(nn.functional.relu(self.ff_1(x))))

In [10]:
ff = FeedForward(512, 2048, 0)
ff_out = ff(embs_pos)
ff_out.shape

torch.Size([1, 6, 512])

In [11]:
class MultiHeadAttentionBlock(nn.Module):

  def __init__(self, d_model: int, h: int, dropout: float):
    super().__init__()
    self.d_model = d_model
    self.h = h
    self.dropout = nn.Dropout()
    assert self.d_model % h == 0, "d_model is not divisible by h"
    self.d_k = d_model // h

    self.w_q, self.w_k, self.w_v = [nn.Linear(d_model, d_model) for _ in range(3)]

    self.w_o = nn.Linear(d_model, d_model)

  @staticmethod
  def attention(query, key, value, mask, dropout: nn.Dropout):

    d_k = query.shape[-1]
    attention_scores = (query @ key.transpose(-2,-1)) / math.sqrt(d_k)
    if mask is not None:
      attention_scores.masked_fill_(mask==0, -1e9)
    attention_scores = attention_scores.softmax(dim=-1)
    if dropout is not None:
      attention_scores = dropout(attention_scores)

    return (attention_scores @ value), attention_scores

  def forward(self, q, k, v, mask):
    query = self.w_q(q)
    key = self.w_k(k)
    value = self.w_v(v)

    query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1,2)
    key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1,2)
    value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1,2)

    x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

    x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h*self.d_k)

    return self.w_o(x)


In [12]:
mah = MultiHeadAttentionBlock(512, 8, 0)
mah.w_q.weight.shape, mah.w_q.bias.shape

(torch.Size([512, 512]), torch.Size([512]))

In [13]:
mah_res = mah(embs_pos, embs_pos, embs_pos, None)
mah_res.shape

torch.Size([1, 6, 512])

In [76]:
x, scores = MultiHeadAttentionBlock.attention(mah_res, mah_res, mah_res, None, None)

In [92]:
x.shape, x.transpose_(2, 1).contiguous().view(x.shape[0], x.shape[1], 8*64).shape

(torch.Size([1, 8, 6, 64]), torch.Size([1, 6, 512]))

In [78]:
scores.shape

torch.Size([1, 8, 6, 6])

In [91]:
x.shape

torch.Size([1, 8, 6, 64])

In [14]:
class ResidualConnection(nn.Module):

  def __init__(self, dropout: float):
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.norm = LayerNorm()

  def forward(self, x, sublayer):
    return x + self.dropout(sublayer(self.norm(x)))

In [15]:
rc = ResidualConnection(0)
rc

ResidualConnection(
  (dropout): Dropout(p=0, inplace=False)
  (norm): LayerNorm()
)

In [16]:
class EncoderBlock(nn.Module):

  def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForward, dropout: float):
    super().__init__()
    self.attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block
    self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])


  def forward(self, x, src_mask):
    x = self.residual_connections[0](x, lambda x: self.attention_block(x,x,x,src_mask))
    x = self.residual_connections[1](x, self.feed_forward_block)

    return x


In [17]:
en = EncoderBlock(mah, ff, 0)

In [18]:
all_result = en(embs_pos, None)

In [19]:
all_result.shape

torch.Size([1, 6, 512])

In [20]:
class Encoder(nn.Module):

  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNorm()

  def forward(self, x, mask):
    for layer in self.layers:
      x = layer(x, mask)
    return self.norm(x)

In [21]:
encoder = Encoder(nn.ModuleList([EncoderBlock(mah, ff, 0) for _ in range(5)]))
encoder

Encoder(
  (layers): ModuleList(
    (0-4): 5 x EncoderBlock(
      (attention_block): MultiHeadAttentionBlock(
        (dropout): Dropout(p=0.5, inplace=False)
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_forward_block): FeedForward(
        (ff_1): Linear(in_features=512, out_features=2048, bias=True)
        (ff_2): Linear(in_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0, inplace=False)
      )
      (residual_connections): ModuleList(
        (0-1): 2 x ResidualConnection(
          (dropout): Dropout(p=0, inplace=False)
          (norm): LayerNorm()
        )
      )
    )
  )
  (norm): LayerNorm()
)

In [22]:
encoder_result = encoder(embs_pos, None)
encoder_result.shape

torch.Size([1, 6, 512])

In [28]:
class DecoderBlock(nn.Module):

  def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward: FeedForward, dropout: float):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward
    self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    x = self.residual_connections[0](x, lambda x: self.self_attention_block(x,x,x,tgt_mask))
    x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
    x = self.residual_connections[2](x, self.feed_forward_block)

    return x

In [29]:
dec_block = DecoderBlock(mah, mah, ff, 0)
dec_block

DecoderBlock(
  (self_attention_block): MultiHeadAttentionBlock(
    (dropout): Dropout(p=0.5, inplace=False)
    (w_q): Linear(in_features=512, out_features=512, bias=True)
    (w_k): Linear(in_features=512, out_features=512, bias=True)
    (w_v): Linear(in_features=512, out_features=512, bias=True)
    (w_o): Linear(in_features=512, out_features=512, bias=True)
  )
  (cross_attention_block): MultiHeadAttentionBlock(
    (dropout): Dropout(p=0.5, inplace=False)
    (w_q): Linear(in_features=512, out_features=512, bias=True)
    (w_k): Linear(in_features=512, out_features=512, bias=True)
    (w_v): Linear(in_features=512, out_features=512, bias=True)
    (w_o): Linear(in_features=512, out_features=512, bias=True)
  )
  (feed_forward_block): FeedForward(
    (ff_1): Linear(in_features=512, out_features=2048, bias=True)
    (ff_2): Linear(in_features=2048, out_features=512, bias=True)
    (dropout): Dropout(p=0, inplace=False)
  )
  (residual_connections): ModuleList(
    (0-2): 3 x Resi

In [24]:
class Decoder(nn.Module):

  def __init__(self, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNorm()

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    for layer in self.layers:
      x = layer(x, encoder_output, src_mask, tgt_mask)
    return self.norm(x)

In [30]:
dec = Decoder(nn.ModuleList([DecoderBlock(mah,mah, ff, 0) for _ in range(5)]))
dec

Decoder(
  (layers): ModuleList(
    (0-4): 5 x DecoderBlock(
      (self_attention_block): MultiHeadAttentionBlock(
        (dropout): Dropout(p=0.5, inplace=False)
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (cross_attention_block): MultiHeadAttentionBlock(
        (dropout): Dropout(p=0.5, inplace=False)
        (w_q): Linear(in_features=512, out_features=512, bias=True)
        (w_k): Linear(in_features=512, out_features=512, bias=True)
        (w_v): Linear(in_features=512, out_features=512, bias=True)
        (w_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_forward_block): FeedForward(
        (ff_1): Linear(in_features=512, out_features=2048, bias=True)
        (ff_2): Linear(in_features=2048, out_features=512,

In [32]:
decoder_output = dec(embs_pos, encoder_result, None, None)
decoder_output.shape

torch.Size([1, 6, 512])

In [25]:
class ProjectionLayer(nn.Module):

  def __init__(self, d_model:int, vocab_size: int):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    x = self.proj(x)

    return torch.log_softmax(x, dim=-1)


In [33]:
proj = ProjectionLayer(512, 4096)
proj = proj(decoder_output)
proj.shape

torch.Size([1, 6, 4096])

In [34]:
class TransformerBlock(nn.Module):

  def __init__(self, encoder: Encoder, decoder: Decoder, src_emb: InputEmbeddings, tgt_emb: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, proj: ProjectionLayer):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_emb = src_emb
    self.tgt_emb = tgt_emb
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.proj = proj

  def encode(self, src, src_mask):
    src = self.src_embed(src)
    src = self.src_pos(src)
    return self.encoder(src, src_mask)

  def decode(self, encoder_output, src_mask, tgt, tgt_mask):
    tgt = self.tgt_embed(tgt)
    tgt = self.tgt_pos(tgt)
    return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

  def project(self, x):
    return self.proj(x)




In [46]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len:int, d_ff: int = 2048, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.0):

  # embed layer
  src_embed = InputEmbeddings(d_model, src_vocab_size)
  tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

  # pos embs
  src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
  tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)


  # encoder
  encoder_blocks = []
  for _ in range(N):
    encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
    feed_forward_block = FeedForward(d_model, d_ff, dropout)
    encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
    encoder_blocks.append(encoder_block)

  # decoder
  decoder_blocks = []
  decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
  decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
  feed_forward_block = FeedForward(d_model, d_ff, dropout)
  decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
  decoder_blocks.append(decoder_block)

  # combine
  encoder = Encoder(nn.ModuleList(encoder_blocks))
  decoder = Decoder(nn.ModuleList(decoder_blocks))


  # projection
  projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

  # transformer

  transformer = TransformerBlock(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

  for p in transformer.parameters():
    if p.dim()>1:
      nn.init.xavier_uniform_(p)


  return transformer


In [47]:
tf = build_transformer(4096, 4096, 6, 6)

In [48]:
tf

TransformerBlock(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (attention_block): MultiHeadAttentionBlock(
          (dropout): Dropout(p=0.5, inplace=False)
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (w_o): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward_block): FeedForward(
          (ff_1): Linear(in_features=512, out_features=2048, bias=True)
          (ff_2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (residual_connections): ModuleList(
          (0-1): 2 x ResidualConnection(
            (dropout): Dropout(p=0.0, inplace=False)
            (norm): LayerNorm()
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (l

In [49]:
# Get named learnable parameters
learnable_named_params = [(name, param) for name, param in tf.named_parameters() if param.requires_grad]

for name, param in learnable_named_params:
    print(f"Parameter name: {name}, value: {param.shape}")


Parameter name: encoder.layers.0.attention_block.w_q.weight, value: torch.Size([512, 512])
Parameter name: encoder.layers.0.attention_block.w_q.bias, value: torch.Size([512])
Parameter name: encoder.layers.0.attention_block.w_k.weight, value: torch.Size([512, 512])
Parameter name: encoder.layers.0.attention_block.w_k.bias, value: torch.Size([512])
Parameter name: encoder.layers.0.attention_block.w_v.weight, value: torch.Size([512, 512])
Parameter name: encoder.layers.0.attention_block.w_v.bias, value: torch.Size([512])
Parameter name: encoder.layers.0.attention_block.w_o.weight, value: torch.Size([512, 512])
Parameter name: encoder.layers.0.attention_block.w_o.bias, value: torch.Size([512])
Parameter name: encoder.layers.0.feed_forward_block.ff_1.weight, value: torch.Size([2048, 512])
Parameter name: encoder.layers.0.feed_forward_block.ff_1.bias, value: torch.Size([2048])
Parameter name: encoder.layers.0.feed_forward_block.ff_2.weight, value: torch.Size([512, 2048])
Parameter name: enc

In [56]:
# Calculate the total number of learnable parameters
learnable_params = [param for param in tf.parameters() if param.requires_grad]
total_learnable_params = sum(p.numel() for p in learnable_params)
print(f"Total Learnable params : {(total_learnable_params / 1000000):.1f} Million")


Total Learnable params : 29.4 Million
