<a href="https://colab.research.google.com/github/Hamza-Ali0237/PyTorch-Transformer-From-Scratch/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Encoder-Decoder Transformer From Scratch Using PyTorch

Implementing The Encoder-Decoder Transformer Architecture From The 2017 Paper Published By Google, ["*Attention Is All You Need*"](https://arxiv.org/abs/1706.03762)

In [None]:
# Install the Hugging Face `datasets` package (notebook shell command); it provides `load_dataset`.
!pip install datasets

In [None]:
# Importing Libraries
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import AutoTokenizer
import datasets
from datasets import load_dataset

In [None]:
class InputEmbeddings(nn.Module):
  """Token-embedding lookup whose output is scaled by sqrt(d_model)."""

  def __init__(self, vocab_size, d_model):
    """Create the embedding table.

    Args:
      vocab_size: Number of rows (distinct token ids) in the table.
      d_model: Width of each embedding vector.
    """
    super().__init__()

    self.d_model = d_model
    self.vocab_size = vocab_size
    self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

  def forward(self, x):
    """Return the embeddings of the ids in `x`, multiplied by sqrt(d_model)."""
    scale = math.sqrt(self.d_model)
    embedded = self.embeddings(x)
    return embedded * scale

In [None]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to its input.

  A table of shape (1, max_seq_len, d_model) is precomputed once and stored
  as the non-trainable buffer `pe`. Even columns hold sines and odd columns
  hold cosines of `position * div_term`.
  """

  def __init__(self, d_model, max_seq_len):
    """Precompute the encoding table.

    Args:
      d_model: Width of each encoding vector. Both even and odd widths are
        supported.
      max_seq_len: Number of positions to precompute.
    """
    super().__init__()

    pe = torch.zeros(max_seq_len, d_model)
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0)/d_model))

    angles = position * div_term

    pe[:, 0::2] = torch.sin(angles)
    # With an odd d_model there is one fewer odd column than even columns, so
    # trim the angles to the number of cosine slots. For an even d_model this
    # slice is the full tensor.
    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    # A buffer follows the module across devices and is saved in the state
    # dict, but it is not a trainable parameter.
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    """Add the encodings for the first `x.size(1)` positions to `x`.

    Dimension 1 of `x` is treated as the sequence axis; it should not exceed
    `max_seq_len`, because the table is sliced to that length before the add.
    """
    return x + self.pe[:, :x.size(1)]

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Queries, keys and values are projected with bias-free linear layers, split
  into `num_heads` heads of width `d_model // num_heads`, attended per head,
  merged back to width `d_model`, and passed through an output projection.
  """

  def __init__(self, d_model, num_heads):
    """Build the projection layers.

    Args:
      d_model: Width of the inputs and of the output.
      num_heads: Number of attention heads; must divide `d_model`.

    Raises:
      ValueError: If `d_model` is not divisible by `num_heads`.
    """
    super().__init__()

    # Raise explicitly rather than assert, so the check survives `python -O`.
    if d_model % num_heads != 0:
      raise ValueError('d_model must be divisible by num_heads.')

    self.num_heads = num_heads
    self.d_model = d_model
    self.head_dim = d_model // num_heads

    self.query_linear = nn.Linear(d_model, d_model, bias=False)
    self.key_linear = nn.Linear(d_model, d_model, bias=False)
    self.value_linear = nn.Linear(d_model, d_model, bias=False)

    self.output_linear = nn.Linear(d_model, d_model)

  def split_heads(self, x, batch_size):
    """Reshape (batch, seq, d_model) into (batch, num_heads, seq, head_dim)."""
    seq_len = x.size(1)
    x = x.reshape(batch_size, seq_len, self.num_heads, self.head_dim)

    return x.permute(0, 2, 1, 3)

  def compute_attention(self, query, key, value, mask=None):
    """Return the attention-weighted sum of `value` for each query position.

    Args:
      query, key, value: Per-head tensors as produced by `split_heads`.
      mask: Optional tensor broadcastable to the score tensor
        (batch, num_heads, q_len, k_len). Positions where `mask == 0` are
        excluded from attention.
    """
    scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

    if mask is not None:
      # Fill with the most negative finite value instead of -inf. Masked
      # positions still get zero weight after softmax, but a row whose keys
      # are all masked becomes uniform instead of NaN.
      scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = F.softmax(scores, dim=-1)

    return torch.matmul(attention_weights, value)

  def combine_heads(self, x, batch_size):
    """Merge (batch, num_heads, seq, head_dim) back into (batch, seq, d_model)."""
    # permute returns a non-contiguous view; view() needs contiguous memory.
    x = x.permute(0, 2, 1, 3).contiguous()
    return x.view(batch_size, -1, self.d_model)

  def forward(self, q, k, v, mask=None):
    """Attend from `q` over `k`/`v` and return a (batch, q_len, d_model) tensor.

    Args:
      q, k, v: Tensors of shape (batch, seq, d_model). `k` and `v` may have a
        different sequence length from `q`.
      mask: Optional mask forwarded to `compute_attention`.
    """
    batch_size = q.size(0)

    query = self.split_heads(self.query_linear(q), batch_size)
    key = self.split_heads(self.key_linear(k), batch_size)
    value = self.split_heads(self.value_linear(v), batch_size)

    attended = self.compute_attention(query, key, value, mask)

    output = self.combine_heads(attended, batch_size)

    return self.output_linear(output)

In [None]:
class FeedForwardSubLayer(nn.Module):
  """Two linear layers with a ReLU in between: d_model -> d_ff -> d_model."""

  def __init__(self, d_model, d_ff):
    """Create the expansion layer, the projection layer and the activation."""
    super().__init__()
    self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
    self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Apply fc1, ReLU, then fc2 to `x`."""
    hidden = self.fc1(x)
    activated = self.relu(hidden)
    return self.fc2(activated)

In [None]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention then feed-forward.

  Each sub-layer output goes through dropout, is added to its input, and the
  sum is layer-normalized.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout):
    """Create the attention, feed-forward, normalization and dropout modules."""
    super().__init__()

    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)

    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)

    self.dropout = nn.Dropout(dropout)

  def forward(self, x, src_mask):
    """Run `x` through the block; `src_mask` is passed to the self-attention."""
    # Sub-layer 1: self-attention with residual connection and normalization
    attended = self.self_attn(x, x, x, src_mask)
    x = self.norm1(x + self.dropout(attended))

    # Sub-layer 2: feed-forward with residual connection and normalization
    return self.norm2(x + self.dropout(self.ff_sublayer(x)))

class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, then feed-forward.

    Each sub-layer output goes through dropout, is added to its input, and
    the sum is layer-normalized.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        """Create the two attention modules, the feed-forward, norms and dropout."""
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, tgt_mask, cross_mask):
        """Run `x` through the block.

        `tgt_mask` is applied in the self-attention. `cross_mask` is applied
        in the cross-attention, where `x` supplies the queries and
        `encoder_output` supplies both keys and values.
        """
        # Sub-layer 1: attention of the target sequence over itself
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))

        # Sub-layer 2: attention of the target sequence over the encoder output
        memory = encoder_output
        attended = self.cross_attn(x, memory, memory, cross_mask)
        x = self.norm2(x + self.dropout(attended))

        # Sub-layer 3: feed-forward
        return self.norm3(x + self.dropout(self.ff_sublayer(x)))

In [None]:
class TransformerEncoder(nn.Module):
  """Embeds token ids, adds positional encodings, and applies a stack of EncoderLayers."""

  def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
    """Create the embedding, the positional encoding and `num_layers` encoder layers."""
    super().__init__()

    self.embedding = InputEmbeddings(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

    blocks = [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
    self.layers = nn.ModuleList(blocks)

  def forward(self, x, src_mask):
    """Encode the token ids `x`; `src_mask` is passed to every layer."""
    hidden = self.positional_encoding(self.embedding(x))

    for encoder_layer in self.layers:
      hidden = encoder_layer(hidden, src_mask)

    return hidden

class TransformerDecoder(nn.Module):
    """Embeds target ids, applies a stack of DecoderLayers, and projects to vocabulary logits."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        """Create the embedding, positional encoding, decoder layers and output projection."""
        super().__init__()

        self.embedding = InputEmbeddings(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        blocks = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, encoder_output, tgt_mask, cross_mask):
        """Decode the target ids `x` against `encoder_output`.

        Returns the unnormalized output of the final linear layer, whose last
        dimension has size `vocab_size`.
        """
        hidden = self.positional_encoding(self.embedding(x))

        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, tgt_mask, cross_mask)

        return self.fc(hidden)

In [None]:
class ClassificationHead(nn.Module):
  """Linear projection to `num_classes` followed by log-softmax."""

  def __init__(self, d_model, num_classes):
    """Create the d_model -> num_classes projection."""
    super().__init__()
    self.fc = nn.Linear(in_features=d_model, out_features=num_classes)

  def forward(self, x):
    """Return log-probabilities over the last dimension."""
    return F.log_softmax(self.fc(x), dim=-1)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from TransformerEncoder and TransformerDecoder."""

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout):
        """Build an encoder and a decoder with identical hyperparameters.

        Args:
            vocab_size: Vocabulary size used for both source and target.
            d_model: Model width.
            num_heads: Attention heads per layer.
            num_layers: Number of layers in each of the encoder and decoder.
            d_ff: Hidden width of the feed-forward sub-layers.
            max_seq_len: Number of positions in the positional encoding table.
            dropout: Dropout probability.
        """
        super().__init__()

        # Pass everything by keyword. The sub-modules take
        # (num_layers, num_heads) in the opposite order from this constructor,
        # so positional passing silently swaps the two values.
        shared = dict(
            vocab_size=vocab_size,
            d_model=d_model,
            num_layers=num_layers,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout,
            max_seq_length=max_seq_len,
        )
        self.encoder = TransformerEncoder(**shared)
        self.decoder = TransformerDecoder(**shared)

    def forward(self, src, src_mask, tgt, tgt_mask):
        """Encode `src`, decode `tgt` against it, and return the decoder output.

        `src_mask` is used both in the encoder self-attention and as the
        cross-attention mask in the decoder.
        """
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(tgt, encoder_output, tgt_mask, src_mask)
        return decoder_output

In [None]:
# Load the "de-en" configuration of the WMT14 dataset and keep handles to the
# train and test splits
dataset = load_dataset("wmt14", "de-en")

train_dataset, test_dataset = dataset["train"], dataset["test"]

In [None]:
# Inspect the column schema of the training split
print(train_dataset.features)

In [None]:
# Load the pretrained multilingual BERT tokenizer; the same tokenizer is used
# for both the English and the German text
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

def preprocess(batch, max_length=128):
    """Tokenize a batch of English/German translation pairs.

    Args:
        batch: Mapping whose "translation" entry is a list of dicts with
            "en" and "de" strings.
        max_length: Maximum number of token ids kept per sentence; longer
            sentences are truncated.

    Returns:
        Dict with "src_input_ids" (English) and "tgt_input_ids" (German),
        each a list of unpadded token-id lists.
    """
    pairs = batch["translation"]
    src_texts = [pair["en"] for pair in pairs]
    tgt_texts = [pair["de"] for pair in pairs]

    # Truncate only. Padding is deferred to collate_fn, which pads each batch
    # to its own longest sequence; padding everything to max_length here
    # would make that dynamic padding a no-op and inflate every batch.
    src = tokenizer(src_texts, truncation=True, max_length=max_length)
    tgt = tokenizer(tgt_texts, truncation=True, max_length=max_length)

    # Without return_tensors the tokenizer already returns plain lists, so no
    # tensor -> list round trip is needed.
    return {
        "src_input_ids": src["input_ids"],
        "tgt_input_ids": tgt["input_ids"],
    }



# Tokenize the training split, calling `preprocess` on batches of 64 examples
train_data = train_dataset.map(preprocess, batched=True, batch_size=64)

In [None]:
# Show the first processed example as a sanity check
print(train_data[0])

In [None]:
# Batch assembly for the DataLoader: tensorize each example and pad per batch
def collate_fn(batch):
    """Turn a list of examples into padded (src, tgt) id tensors.

    Each example must provide "src_input_ids" and "tgt_input_ids". Sequences
    are padded with the tokenizer's pad id to the longest sequence in the
    batch, and the batch dimension comes first.
    """
    pad_id = tokenizer.pad_token_id

    def _pad_column(column):
        rows = [torch.tensor(example[column]) for example in batch]
        return torch.nn.utils.rnn.pad_sequence(rows, batch_first=True, padding_value=pad_id)

    return _pad_column("src_input_ids"), _pad_column("tgt_input_ids")

# Shuffled batches of 32 examples, assembled by collate_fn
train_dataloader = DataLoader(train_data, batch_size=32, collate_fn=collate_fn, shuffle=True)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Model hyperparameters
vocab_size = tokenizer.vocab_size
# NOTE(review): 504 is a multiple of both 6 and 8; the paper's base model
# uses 512 — confirm this width is intended.
d_model = 504
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_len = 128
dropout = 0.1


# Build the model with explicit keyword arguments and move it to the device
model = Transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_len=max_seq_len,
    dropout=dropout,
)
model.to(device)

In [None]:
# Loss: cross entropy that ignores targets equal to the padding id
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Optimizer: Adam over all model parameters
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)

In [None]:
# Define function to train the model
def train(model, dataloader, criterion, optimizer, num_epochs=10, pad_token_id=None):
    """Train `model` with teacher forcing and print the mean loss per epoch.

    Args:
        model: Module called as `model(src, src_mask, tgt_input, tgt_mask)`
            that returns logits whose last dimension is the vocabulary.
        dataloader: Iterable of `(src_batch, tgt_batch)` id tensors with the
            batch dimension first.
        criterion: Loss taking `(logits[N, vocab], targets[N])`.
        optimizer: Optimizer over the model's parameters.
        num_epochs: Number of passes over `dataloader`.
        pad_token_id: Id used for padding; defaults to the module-level
            tokenizer's `pad_token_id`.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    # Follow the model's device instead of assuming CUDA, so training also
    # runs on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for src_batch, tgt_batch in dataloader:
            src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)

            # Padding mask over source keys: (batch, 1, 1, src_len)
            src_mask = (src_batch != pad_token_id).unsqueeze(1).unsqueeze(2)

            # Shift target for teacher forcing: the decoder sees tokens
            # [0, T-1) and must predict tokens [1, T)
            tgt_input = tgt_batch[:, :-1]
            tgt_output = tgt_batch[:, 1:]

            # Target mask = padding mask AND causal mask -> (batch, 1, T, T).
            # Without the lower-triangular part every position could attend
            # to the very tokens it is asked to predict.
            tgt_len = tgt_input.size(1)
            causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, device=device)).bool()
            tgt_pad_mask = (tgt_input != pad_token_id).unsqueeze(1).unsqueeze(2)
            tgt_mask = tgt_pad_mask & causal_mask

            # Forward pass
            outputs = model(src_batch, src_mask, tgt_input, tgt_mask)
            # Read the vocabulary size from the logits, not from a global
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError when the dataloader is empty
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / max(len(dataloader), 1):.4f}")

In [None]:
# Run 10 epochs of training over the tokenized training split
train(model, train_dataloader, criterion, optimizer, num_epochs=10)