<a href="https://colab.research.google.com/github/a-mhamdi/deep_learning_nlp/blob/main/03_attention_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention Mechanisms and Advanced Architectures
---

## Outline
1. [Attention mechanisms](#att-mech)
1. [Transformer architecture basics](#transform-arch)
1. [Self-attention](#self-att)
1. [Multi-head attention](#multi-head-attention)

## Attention mechanisms <a name="att-mech"></a>


## Transformer architecture basics <a name="transform-arch"></a>


## Self-attention <a name="self-att"></a>


## Multi-head attention <a name="multi-head-attention"></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import random
from collections import Counter


Sinusoidal positional encoding (fixed, not learned)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to batch-first embeddings.

    A table of shape (1, max_seq_length, d_model) is precomputed. Even
    feature columns hold sin(pos / 10000^(2i / d_model)) and odd columns
    hold the matching cos term. ``forward`` adds the first ``seq_len`` rows
    of that table to its input.

    Args:
        d_model: Feature dimension of the embeddings. Odd values are supported.
        max_seq_length: Longest sequence the precomputed table can encode.
    """

    def __init__(self, d_model, max_seq_length=100):
        super().__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per (sin, cos) column pair: 1 / 10000^(2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd (cos) column than there
        # are frequencies, so only use as many frequencies as there are slots.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading batch axis so the table broadcasts over (batch, seq, d_model).
        # Registered as a buffer: it follows .to(device) and is saved in the
        # state_dict, but it is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the positional encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence axis and last dim is d_model
               (used here as (batch, seq_len, d_model)).

        Raises:
            ValueError: if the sequence is longer than ``max_seq_length``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]


Sequence-to-sequence translation model built on PyTorch's `nn.Transformer`

In [None]:
class TransformerTranslator(nn.Module):
    """Encoder-decoder translation model built on ``nn.Transformer``.

    Public inputs are batch-first token-id tensors of shape (batch, seq_len),
    and token id 0 is treated as padding. Token embeddings are scaled by
    sqrt(d_model), summed with sinusoidal positional encodings, and
    transposed to the sequence-first layout that ``nn.Transformer`` uses by
    default. A final linear layer maps decoder states to target-vocabulary
    logits.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers,
                 num_decoder_layers, dim_feedforward, dropout=0.1):
        super().__init__()

        # Source and target embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # Positional encoding layer (shared by encoder and decoder inputs)
        self.positional_encoding = PositionalEncoding(d_model)

        # Transformer layers (sequence-first: batch_first is left at its default)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

        # Final projection to target-vocabulary logits
        self.output_layer = nn.Linear(d_model, tgt_vocab_size)

        # Initialize embedding and output-layer parameters
        self.init_weights()

        # Save model dimensions
        self.d_model = d_model
        self.tgt_vocab_size = tgt_vocab_size

    def init_weights(self):
        """Uniformly initialize embeddings and the output projection in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        self.src_embedding.weight.data.uniform_(-initrange, initrange)
        self.tgt_embedding.weight.data.uniform_(-initrange, initrange)
        self.output_layer.bias.data.zero_()
        self.output_layer.weight.data.uniform_(-initrange, initrange)

    @staticmethod
    def _causal_mask(size, device):
        """Return a (size, size) bool mask that is True above the diagonal (future positions to hide)."""
        return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)

    def _embed(self, tokens, embedding):
        """Embed (batch, seq) token ids, add positions, and return (seq, batch, d_model)."""
        emb = embedding(tokens) * math.sqrt(self.d_model)
        return self.positional_encoding(emb).transpose(0, 1)

    def create_masks(self, src, tgt):
        """Build attention masks for a batch.

        Returns:
            src_padding_mask: bool, same shape as ``src``, True where src == 0.
            tgt_mask: (tgt_len, tgt_len) bool causal mask, True where a query
                must not attend (positions after it).
            tgt_padding_mask: bool, same shape as ``tgt``, True where tgt == 0.
        """
        src_padding_mask = src == 0
        tgt_padding_mask = tgt == 0
        tgt_mask = self._causal_mask(tgt.size(1), tgt.device)
        return src_padding_mask, tgt_mask, tgt_padding_mask

    def forward(self, src, tgt):
        """Teacher-forced forward pass.

        Args:
            src: (batch, src_len) source token ids.
            tgt: (batch, tgt_len) decoder input token ids.

        Returns:
            (batch, tgt_len, tgt_vocab_size) logits.
        """
        src_padding_mask, tgt_mask, tgt_padding_mask = self.create_masks(src, tgt)

        src_emb = self._embed(src, self.src_embedding)
        tgt_emb = self._embed(tgt, self.tgt_embedding)

        output = self.transformer(
            src_emb, tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            # Without this, decoder cross-attention would also attend to the
            # encoder outputs at source padding positions.
            memory_key_padding_mask=src_padding_mask,
        )

        # [seq_len, batch, d_model] -> [batch, seq_len, d_model] -> logits
        return self.output_layer(output.transpose(0, 1))

    def greedy_decode(self, src, max_len, start_symbol, end_symbol):
        """Perform greedy decoding for inference.

        Args:
            src: (batch, src_len) source token ids (0 = padding).
            max_len: Maximum length of each returned sequence, including the
                start symbol.
            start_symbol: Token id placed at position 0 of every output.
            end_symbol: Token id that marks a finished sequence.

        Returns:
            (batch, L) long tensor with L <= max_len, beginning with
            ``start_symbol``. Decoding stops once every sequence has produced
            ``end_symbol`` at some step. Tokens generated after a sequence's
            own end symbol are left in place, so callers should truncate at
            the first ``end_symbol``.
        """
        batch_size = src.size(0)
        device = src.device

        # Encode the source once and reuse the memory at every decoding step.
        src_padding_mask = src == 0
        memory = self.transformer.encoder(
            self._embed(src, self.src_embedding),
            src_key_padding_mask=src_padding_mask,
        )

        ys = torch.full((batch_size, 1), start_symbol, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            output = self.transformer.decoder(
                self._embed(ys, self.tgt_embedding), memory,
                tgt_mask=self._causal_mask(ys.size(1), device),
                tgt_key_padding_mask=ys == 0,
                memory_key_padding_mask=src_padding_mask,
            )
            # output is (seq, batch, d_model); only the last position predicts the next token.
            logits = self.output_layer(output[-1])
            next_word = logits.argmax(dim=1, keepdim=True)

            ys = torch.cat([ys, next_word], dim=1)

            # Stop once every sequence has emitted the end symbol at some step.
            finished |= next_word.squeeze(1) == end_symbol
            if finished.all():
                break

        return ys

Translation dataset

In [None]:
class TranslationDataset(Dataset):
    """Parallel corpus of (source, target) sentence pairs encoded as token-id tensors.

    Sentences are lowercased and split on whitespace. Each one is encoded as
    [<sos>, word ids..., <eos>], with words truncated so the sequence is at
    most ``max_seq_length`` long (always at least the two boundary tokens).
    Unknown words map to <unk>.

    Args:
        src_sentences: Source-language sentences.
        tgt_sentences: Target-language sentences aligned with ``src_sentences``.
        src_vocab: Optional prebuilt word -> index dict for the source side.
        tgt_vocab: Optional prebuilt word -> index dict for the target side.
        max_vocab_size: Maximum size of a built vocabulary, special tokens included.
        max_seq_length: Maximum encoded length, including <sos> and <eos>.

    Raises:
        ValueError: if the source and target sentence lists differ in length.
    """

    def __init__(self, src_sentences, tgt_sentences, src_vocab=None, tgt_vocab=None,
                 max_vocab_size=10000, max_seq_length=100):
        if len(src_sentences) != len(tgt_sentences):
            raise ValueError(
                f"got {len(src_sentences)} source sentences but {len(tgt_sentences)} target sentences"
            )
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.max_seq_length = max_seq_length

        # Special tokens (index 0 doubles as the padding value used by the collate function)
        self.PAD_IDX = 0
        self.SOS_IDX = 1
        self.EOS_IDX = 2
        self.UNK_IDX = 3

        # Build vocabularies if not provided
        self.src_vocab = self.build_vocab(src_sentences, max_vocab_size) if src_vocab is None else src_vocab
        self.tgt_vocab = self.build_vocab(tgt_sentences, max_vocab_size) if tgt_vocab is None else tgt_vocab

    def build_vocab(self, sentences, max_vocab_size):
        """Build a word -> index dict from the most frequent words in ``sentences``.

        Indices 0-3 are reserved for <pad>, <sos>, <eos> and <unk>. The other
        words are ranked by descending frequency (ties keep first-seen order)
        and numbered consecutively from 4, keeping at most ``max_vocab_size``
        entries in total. Corpus words that spell a special token are skipped
        so they cannot overwrite a reserved index.
        """
        vocab = {
            '<pad>': self.PAD_IDX,
            '<sos>': self.SOS_IDX,
            '<eos>': self.EOS_IDX,
            '<unk>': self.UNK_IDX
        }
        num_special = len(vocab)

        counter = Counter(
            word
            for sentence in sentences
            for word in sentence.lower().split()
            if word not in vocab
        )
        ranked = sorted(counter.items(), key=lambda item: item[1], reverse=True)
        kept = ranked[:max(max_vocab_size - num_special, 0)]

        for index, (word, _) in enumerate(kept, start=num_special):
            vocab[word] = index

        return vocab

    def sentence_to_indices(self, sentence, vocab):
        """Encode ``sentence`` as [<sos>, word ids..., <eos>] using ``vocab``."""
        # Reserve two slots for <sos>/<eos>; clamp so a tiny max_seq_length
        # never becomes a negative slice that silently drops trailing words.
        words = sentence.lower().split()[:max(self.max_seq_length - 2, 0)]
        return [self.SOS_IDX, *(vocab.get(word, self.UNK_IDX) for word in words), self.EOS_IDX]

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        """Return the (source, target) pair at ``idx`` as 1-D int64 tensors."""
        src_indices = self.sentence_to_indices(self.src_sentences[idx], self.src_vocab)
        tgt_indices = self.sentence_to_indices(self.tgt_sentences[idx], self.tgt_vocab)
        return torch.tensor(src_indices), torch.tensor(tgt_indices)


Collate function for padding sequences to the same length

In [None]:
def collate_translation_batch(batch):
    """Pad a list of (source, target) tensor pairs into two batch-first tensors.

    Each side is right-padded with 0 (the <pad> index) to the length of its
    longest sequence in the batch.

    Returns:
        (src_padded, tgt_padded), each of shape (batch, max_len_on_that_side).
    """
    sources, targets = zip(*batch)
    return tuple(
        pad_sequence(group, batch_first=True, padding_value=0)
        for group in (sources, targets)
    )


Training function (one epoch with teacher forcing)

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device, max_grad_norm=1.0):
    """Run one teacher-forced training pass and return the mean per-batch loss.

    Args:
        model: Module called as ``model(src, tgt_input)`` that returns
            (batch, tgt_len, vocab) logits.
        dataloader: Any iterable of (src, tgt) batches of token ids. It does
            not need to support ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking (N, vocab) logits and (N,) targets.
        device: Device the batches are moved to.
        max_grad_norm: Gradient-norm clipping threshold. Pass None to disable
            clipping.

    Returns:
        The mean of the per-batch losses, or 0.0 if no batches were yielded.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)

        # Teacher forcing: the decoder reads tokens [0, T-1) and learns to
        # predict the next token at each position, i.e. tokens [1, T).
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        logits = model(src, tgt_input)
        # Flatten batch and time so every position is one classification sample.
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches actually seen instead of relying on len(dataloader).
    return total_loss / num_batches if num_batches else 0.0


Evaluation function

In [None]:
def evaluate(model, dataloader, criterion, device):
    """Compute the mean per-batch teacher-forced loss without updating the model.

    Args:
        model: Module called as ``model(src, tgt_input)`` that returns
            (batch, tgt_len, vocab) logits. It is switched to eval mode.
        dataloader: Any iterable of (src, tgt) batches of token ids. It does
            not need to support ``len()``.
        criterion: Loss taking (N, vocab) logits and (N,) targets.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch losses, or 0.0 if no batches were yielded.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)

            # Same input/target shift as in training.
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            logits = model(src, tgt_input)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))

            total_loss += loss.item()
            num_batches += 1

    # Count batches actually seen instead of relying on len(dataloader).
    return total_loss / num_batches if num_batches else 0.0


Translation function

In [None]:
def translate(model, src_sentence, src_vocab, tgt_vocab, max_len=50, device='cpu'):
    """Translate one sentence with greedy decoding and return the result as text.

    The sentence is lowercased, split on whitespace, mapped through
    ``src_vocab`` (unknown words become index 3) and wrapped in <sos>/<eos>
    (indices 1 and 2). The ids decoded by ``model.greedy_decode`` are mapped
    back through ``tgt_vocab``. Output stops at the first <eos>, <sos> and
    <pad> tokens are dropped, and ids missing from ``tgt_vocab`` are rendered
    as '<unk>'.
    """
    model.eval()

    sos, eos, unk = 1, 2, 3
    dropped = ('<sos>', '<pad>')
    id_to_word = {index: word for word, index in tgt_vocab.items()}

    token_ids = [sos, *(src_vocab.get(word, unk) for word in src_sentence.lower().split()), eos]
    batch = torch.tensor([token_ids]).to(device)

    with torch.no_grad():
        decoded = model.greedy_decode(batch, max_len, sos, eos).squeeze(0)

    words = []
    for token_id in decoded.tolist():
        word = id_to_word.get(token_id, '<unk>')
        if word == '<eos>':
            break
        if word in dropped:
            continue
        words.append(word)

    return ' '.join(words)


Main function to run the translator

In [None]:
def main():
    """Train a small Transformer on a toy English-to-French corpus and print translations.

    Every 10 epochs it prints the average training loss and a sample
    translation. After training it translates a few unseen sentences.
    """
    # Toy parallel corpus: each English sentence paired with its French translation.
    corpus = [
        ("hello how are you", "bonjour comment allez vous"),
        ("I am a student", "je suis un étudiant"),
        ("where is the library", "où est la bibliothèque"),
        ("the book is on the table", "le livre est sur la table"),
        ("I like to read books", "j'aime lire des livres"),
    ]
    src_sentences = [english for english, _ in corpus]
    tgt_sentences = [french for _, french in corpus]

    dataset = TranslationDataset(src_sentences, tgt_sentences)
    train_dataloader = DataLoader(
        dataset,
        batch_size=2,
        collate_fn=collate_translation_batch,
        shuffle=True,
    )

    # Architecture hyperparameters; vocabulary sizes come from the dataset.
    hyperparams = {
        "d_model": 128,
        "nhead": 4,
        "num_encoder_layers": 3,
        "num_decoder_layers": 3,
        "dim_feedforward": 512,
        "dropout": 0.1,
    }

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = TransformerTranslator(
        src_vocab_size=len(dataset.src_vocab),
        tgt_vocab_size=len(dataset.tgt_vocab),
        **hyperparams,
    ).to(device)

    # Index 0 is <pad>, so padded target positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

    num_epochs = 100
    report_every = 10
    probe_sentence = "I am learning to translate"
    print(f"Starting training on device: {device}")

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_dataloader, optimizer, criterion, device)
        if (epoch + 1) % report_every:
            continue

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}")
        translation = translate(model, probe_sentence, dataset.src_vocab, dataset.tgt_vocab, device=device)
        print(f"Source: {probe_sentence}")
        print(f"Translation: {translation}")
        print("-" * 50)

    print("\nFinal Translations:")
    for sentence in ("hello my friend", "I want to learn French", "the cat is black"):
        translation = translate(model, sentence, dataset.src_vocab, dataset.tgt_vocab, device=device)
        print(f"Source: {sentence}")
        print(f"Translation: {translation}")
        print("-" * 30)


In [None]:
if __name__ == "__main__":
    main()