<a href="https://colab.research.google.com/github/AnasAlhasan/large-models-course/blob/main/notebooks/TransformersNotebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

In [2]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is linearly projected to queries, keys and values. Each is split
    into ``num_heads`` heads of size ``d_model // num_heads``. Attention is
    computed per head, and the heads are concatenated back and passed through
    an output projection.

    Args:
        d_model: Feature size of the input and output (last dimension).
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        # Feature size handled by each individual head.
        self.depth = d_model // num_heads

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, depth)``."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq, seq)`` (e.g. a ``(seq, seq)`` causal
                mask or a ``(batch, 1, 1, seq)`` padding mask). Entries equal
                to 0/False mark key positions a query may not attend to.
                ``None`` (the default) lets every query attend to every key.

        Returns:
            Tensor of shape ``(batch, seq, d_model)``.
        """
        batch_size = x.size(0)
        Q = self.split_heads(self.Wq(x), batch_size)
        K = self.split_heads(self.Wk(x), batch_size)
        V = self.split_heads(self.Wv(x), batch_size)

        # (batch, heads, seq_q, seq_k), scaled by sqrt(depth) to keep the
        # softmax out of its saturated region.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.depth ** 0.5)
        if mask is not None:
            # Use the most negative finite value rather than -inf so a fully
            # masked row produces uniform weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=-1)

        context = torch.matmul(attn_weights, V)
        # (batch, heads, seq, depth) -> (batch, seq, heads, depth) -> (batch, seq, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(context)

In [19]:
# Feed-Forward Network

class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Expands the last dimension from ``d_model`` to ``d_ff`` and projects it
    back to ``d_model``; ``nn.Linear`` acts on the last dimension only, so
    each position is transformed independently.

    Args:
        d_model: Input/output feature size.
        d_ff: Hidden (inner) feature size.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Attribute names determine the state_dict keys, and creation order
        # determines how the global RNG initialises the weights.
        self.Linear1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.Linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = self.Linear1(x)
        hidden = self.relu(hidden)
        return self.Linear2(hidden)

In [20]:
# Transformer Block
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block.

    Each sub-layer (self-attention, then feed-forward) is wrapped in a
    residual connection followed by LayerNorm::

        x = norm1(x + attention(x))
        x = norm2(x + ffn(x))

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads (passed to MultiHeadAttention).
        d_ff: Hidden size of the feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # Submodules are created in this order; reordering them would change
        # how the global RNG initialises their weights.
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Attention sub-layer: residual add, then normalise.
        x = self.norm1(x + self.attention(x))
        # Feed-forward sub-layer: residual add, then normalise.
        return self.norm2(x + self.ffn(x))

In [21]:
# Smoke test: a batch of 2 sequences, 5 positions each, 64 features per position.
x = torch.rand(2, 5, 64)

block = TransformerBlock(
    d_model=64,
    num_heads=8,
    d_ff=256,
)

# The block preserves the input shape: expect torch.Size([2, 5, 64]).
out = block(x)
print(out.size())

torch.Size([2, 5, 64])
