<a href="https://colab.research.google.com/github/Hamza-Ali0237/PyTorch-Transformer-From-Scratch/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Encoder-Decoder Transformer From Scratch Using PyTorch

Implementing The Encoder-Decoder Transformer Architecture From The 2017 Paper Published By Google, ["*Attention Is All You Need*"](https://arxiv.org/abs/1706.03762).

In [12]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [13]:
# Importing Libraries
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import AutoTokenizer
import datasets
from datasets import load_dataset

In [2]:
class InputEmbeddings(nn.Module):
  """Embedding lookup whose output is scaled by sqrt(d_model)."""

  def __init__(self, vocab_size, d_model):
    """Build the embedding table.

    Args:
      vocab_size: Number of entries in the embedding table.
      d_model: Width of each embedding vector.
    """
    super().__init__()

    self.vocab_size = vocab_size
    self.d_model = d_model
    self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

  def forward(self, x):
    """Return the embeddings of the index tensor `x`, multiplied by sqrt(d_model)."""
    scale = math.sqrt(self.d_model)
    embedded = self.embeddings(x)
    return embedded * scale

In [3]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal positional encodings to a batch of embeddings.

  A table of shape (1, max_seq_len, d_model) is precomputed and stored as the
  non-trainable buffer `pe`: even columns hold sin(position * frequency) and
  odd columns hold cos(position * frequency).
  """

  def __init__(self, d_model, max_seq_len):
    """Precompute the encoding table.

    Args:
      d_model: Width of each encoding vector (even or odd).
      max_seq_len: Number of positions to precompute.
    """
    super().__init__()

    pe = torch.zeros(max_seq_len, d_model)
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0)/d_model))
    angles = position * div_term

    pe[:, 0::2] = torch.sin(angles)
    # For an odd d_model there is one fewer odd column than even columns, so
    # drop the surplus frequency; for an even d_model this slice is a no-op.
    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    # Registered as a buffer so it follows the module across devices and is
    # saved in the state dict without being a trainable parameter.
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    """Add the encodings for the first x.size(1) positions to `x`.

    Args:
      x: Tensor whose dimension 1 indexes sequence positions and whose last
        dimension matches d_model; x.size(1) must not exceed max_seq_len.

    Returns:
      `x` plus the matching slice of the encoding table (broadcast over dim 0).
    """
    return x + self.pe[:, :x.size(1)]

In [4]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Inputs are projected with bias-free linear layers, split into `num_heads`
  heads of width `head_dim = d_model // num_heads`, attended per head, merged
  back to width `d_model`, and passed through a final linear layer.
  """

  def __init__(self, d_model, num_heads):
    """Create the projection layers.

    Args:
      d_model: Width of the input and output features.
      num_heads: Number of attention heads; must divide d_model.

    Raises:
      AssertionError: If d_model is not divisible by num_heads.
    """
    super().__init__()

    assert d_model % num_heads == 0, 'd_model must be divisible by num_heads.'

    self.num_heads = num_heads
    self.d_model = d_model
    self.head_dim = d_model // num_heads

    self.query_linear = nn.Linear(d_model, d_model, bias=False)
    self.key_linear = nn.Linear(d_model, d_model, bias=False)
    self.value_linear = nn.Linear(d_model, d_model, bias=False)

    self.output_linear = nn.Linear(d_model, d_model)

  def split_heads(self, x, batch_size):
    """Reshape (batch, seq, d_model) into (batch, num_heads, seq, head_dim)."""
    seq_len = x.size(1)
    x = x.reshape(batch_size, seq_len, self.num_heads, self.head_dim)

    return x.permute(0, 2, 1, 3)

  def compute_attention(self, query, key, value, mask=None):
    """Apply scaled dot-product attention to already-split heads.

    Args:
      query: Tensor of shape (batch, num_heads, q_len, head_dim).
      key: Tensor of shape (batch, num_heads, k_len, head_dim).
      value: Tensor of shape (batch, num_heads, k_len, head_dim).
      mask: Optional tensor broadcastable to (batch, num_heads, q_len, k_len);
        positions where it equals 0 are excluded from attention.

    Returns:
      The attention-weighted sum of `value`, shape
      (batch, num_heads, q_len, head_dim).
    """
    scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

    if mask is not None:
      # Use the dtype's most negative finite value rather than -inf: masked
      # keys still get exactly zero weight after softmax, but a row whose keys
      # are all masked yields finite (uniform) weights instead of NaN.
      scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = F.softmax(scores, dim=-1)

    return torch.matmul(attention_weights, value)

  def combine_heads(self, x, batch_size):
    """Merge (batch, num_heads, seq, head_dim) back into (batch, seq, d_model)."""
    # permute returns a non-contiguous view; view() needs contiguous memory.
    x = x.permute(0, 2, 1, 3).contiguous()
    return x.view(batch_size, -1, self.d_model)

  def forward(self, q, k, v, mask=None):
    """Run multi-head attention.

    Args:
      q: Query input of shape (batch, q_len, d_model).
      k: Key input of shape (batch, k_len, d_model).
      v: Value input of shape (batch, k_len, d_model).
      mask: Optional mask passed through to `compute_attention`.

    Returns:
      Tensor of shape (batch, q_len, d_model).
    """
    batch_size = q.size(0)

    query = self.split_heads(self.query_linear(q), batch_size)
    key = self.split_heads(self.key_linear(k), batch_size)
    value = self.split_heads(self.value_linear(v), batch_size)

    attn_output = self.compute_attention(query, key, value, mask)

    output = self.combine_heads(attn_output, batch_size)

    return self.output_linear(output)

In [5]:
class FeedForwardSubLayer(nn.Module):
  """Two linear layers with a ReLU in between (d_model -> d_ff -> d_model)."""

  def __init__(self, d_model, d_ff):
    """Create the two linear layers and the activation.

    Args:
      d_model: Input and output feature width.
      d_ff: Width of the hidden layer.
    """
    super().__init__()
    self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
    self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Apply fc1, then ReLU, then fc2 to `x`."""
    hidden = self.fc1(x)
    activated = self.relu(hidden)
    return self.fc2(activated)

In [7]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention then feed-forward.

  Each sub-layer's output goes through dropout, is added back to the
  sub-layer's input, and the sum is layer-normalized.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout):
    """Create the sub-layers.

    Args:
      d_model: Feature width used throughout the block.
      num_heads: Number of heads for the self-attention sub-layer.
      d_ff: Hidden width of the feed-forward sub-layer.
      dropout: Dropout probability applied to each sub-layer's output.
    """
    super().__init__()

    # Sub-layers.
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)

    # One LayerNorm per residual connection; a single shared Dropout module.
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, src_mask):
    """Run the block on `x`, using `src_mask` as the self-attention mask."""
    attended = self.self_attn(x, x, x, src_mask)
    after_attn = self.norm1(x + self.dropout(attended))

    fed_forward = self.ff_sublayer(after_attn)
    return self.norm2(after_attn + self.dropout(fed_forward))

class DecoderLayer(nn.Module):
  """One decoder block: self-attention, cross-attention, then feed-forward.

  Each sub-layer's output goes through dropout, is added back to the
  sub-layer's input, and the sum is layer-normalized.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout):
    """Create the sub-layers.

    Args:
      d_model: Feature width used throughout the block.
      num_heads: Number of heads for both attention sub-layers.
      d_ff: Hidden width of the feed-forward sub-layer.
      dropout: Dropout probability applied to each sub-layer's output.
    """
    super().__init__()

    # Sub-layers.
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.cross_attn = MultiHeadAttention(d_model, num_heads)
    self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)

    # One LayerNorm per residual connection; a single shared Dropout module.
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, y, tgt_mask, cross_mask):
    """Run the block.

    Args:
      x: Decoder-side input; used as query, key and value for self-attention.
      y: Tensor used as key and value for cross-attention (presumably the
        encoder output - confirm against the caller).
      tgt_mask: Mask for the self-attention sub-layer.
      cross_mask: Mask for the cross-attention sub-layer.
    """
    self_attended = self.self_attn(x, x, x, tgt_mask)
    after_self = self.norm1(x + self.dropout(self_attended))

    cross_attended = self.cross_attn(after_self, y, y, cross_mask)
    after_cross = self.norm2(after_self + self.dropout(cross_attended))

    fed_forward = self.ff_sublayer(after_cross)
    return self.norm3(after_cross + self.dropout(fed_forward))

In [8]:
class TransformerEncoder(nn.Module):
  """Token embedding + positional encoding followed by a stack of EncoderLayers."""

  def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
    """Create the embedding, positional encoding and encoder stack.

    Args:
      vocab_size: Size of the embedding table.
      d_model: Feature width used throughout the encoder.
      num_layers: Number of EncoderLayer blocks.
      num_heads: Attention heads per block.
      d_ff: Hidden width of each block's feed-forward sub-layer.
      dropout: Dropout probability passed to each block.
      max_seq_length: Number of positions in the positional-encoding table.
    """
    super().__init__()

    self.embedding = InputEmbeddings(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

    encoder_layers = []
    for _ in range(num_layers):
      encoder_layers.append(EncoderLayer(d_model, num_heads, d_ff, dropout))
    self.layers = nn.ModuleList(encoder_layers)

  def forward(self, x, src_mask):
    """Embed the index tensor `x` and pass it through every layer with `src_mask`."""
    hidden = self.positional_encoding(self.embedding(x))

    for encoder_layer in self.layers:
      hidden = encoder_layer(hidden, src_mask)

    return hidden

class TransformerDecoder(nn.Module):
  """Token embedding + positional encoding, a stack of DecoderLayers, and a
  final linear projection to vocabulary-sized logits."""

  def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
    """Create the embedding, positional encoding, decoder stack and output layer.

    Args:
      vocab_size: Size of the embedding table and of the output logits.
      d_model: Feature width used throughout the decoder.
      num_layers: Number of DecoderLayer blocks.
      num_heads: Attention heads per block.
      d_ff: Hidden width of each block's feed-forward sub-layer.
      dropout: Dropout probability passed to each block.
      max_seq_length: Number of positions in the positional-encoding table.
    """
    super().__init__()

    self.embedding = InputEmbeddings(vocab_size, d_model)

    self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

    self.layers = nn.ModuleList([
        DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
    ])

    self.fc = nn.Linear(d_model, vocab_size)

  def forward(self, x, y, tgt_mask=None, cross_mask=None):
    """Decode `x` while attending to `y`.

    Args:
      x: Index tensor of decoder-side tokens.
      y: Tensor each layer uses as key/value for cross-attention (presumably
        the encoder output - confirm against the caller).
      tgt_mask: Optional mask for the decoder self-attention.
      cross_mask: Optional mask for the cross-attention.

    Returns:
      Logits whose last dimension has size vocab_size.
    """
    x = self.embedding(x)

    x = self.positional_encoding(x)

    # DecoderLayer.forward takes (x, y, tgt_mask, cross_mask); every layer
    # receives the same cross-attention source and masks.
    for layer in self.layers:
      x = layer(x, y, tgt_mask, cross_mask)

    x = self.fc(x)

    return x

In [9]:
class ClassificationHead(nn.Module):
  """Linear projection to `num_classes` followed by log-softmax."""

  def __init__(self, d_model, num_classes):
    """Create the projection layer.

    Args:
      d_model: Width of the incoming features.
      num_classes: Number of output classes.
    """
    super().__init__()
    self.fc = nn.Linear(in_features=d_model, out_features=num_classes)

  def forward(self, x):
    """Return log-probabilities over the last dimension of fc(x)."""
    return F.log_softmax(self.fc(x), dim=-1)

In [10]:
class Transformer(nn.Module):
  """Encoder-decoder model: a TransformerEncoder feeding a TransformerDecoder."""

  def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout):
    """Build the encoder and decoder with a shared configuration.

    Args:
      vocab_size: Vocabulary size used by both encoder and decoder.
      d_model: Feature width used throughout the model.
      num_heads: Attention heads per layer.
      num_layers: Number of layers in each of the encoder and decoder.
      d_ff: Hidden width of the feed-forward sub-layers.
      max_seq_len: Number of positions in the positional-encoding tables.
      dropout: Dropout probability.
    """
    super().__init__()

    # Keyword arguments on purpose: this constructor takes
    # (num_heads, num_layers) while TransformerEncoder/TransformerDecoder take
    # (num_layers, num_heads), so positional passing silently swaps them.
    self.encoder = TransformerEncoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=dropout,
        max_seq_length=max_seq_len,
    )

    self.decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=dropout,
        max_seq_length=max_seq_len,
    )

  def forward(self, x, src_mask, tgt_mask, cross_mask, tgt=None):
    """Encode `x`, then decode while attending to the encoder output.

    Args:
      x: Index tensor of source tokens fed to the encoder.
      src_mask: Mask for the encoder self-attention.
      tgt_mask: Mask for the decoder self-attention.
      cross_mask: Mask for the decoder cross-attention.
      tgt: Optional index tensor of target tokens fed to the decoder. When
        omitted, `x` is fed to the decoder as well.

    Returns:
      The decoder output.
    """
    encoder_output = self.encoder(x, src_mask)

    decoder_input = x if tgt is None else tgt
    decoder_output = self.decoder(decoder_input, encoder_output, tgt_mask, cross_mask)

    return decoder_output