# Deep Learning Assignment: The Transformer

### Instructions:
1.  **File > Save a copy in Drive** to create your own editable version of this notebook.
2.  Read the instructions in the Markdown cells.
3.  Implement your code in the designated "YOUR CODE HERE" cells.
4.  Answer the analysis questions in the Markdown cells at the end.
5.  **Deliverable:** Share your completed Colab notebook.

### Assignment Objectives
* **Understand & Implement** the core "heart" of the Transformer: Multi-Head Attention.
* **Learn** to debug a neural network by "seeing" what it "sees" (analyzing attention weights).
* **Appreciate** the power of abstraction by using the Hugging Face library.
* **Connect** the "from scratch" theory to the "in-practice" library.

### Task 0: Warm-up & Context

**1. Watch a Video:**
To understand the architecture you're about to build, watch a detailed visual explanation of the Transformer.
* **Recommended Video:** **"The Illustrated Transformer" by Jay Alammar** (either the article or a video based on it), or the [Hugging Face LLM course](https://huggingface.co/learn/llm-course/en/chapter1/4).

**2. Write a Reflection (Markdown Cell):**
In the cell below, write a 1-paragraph reflection:
* "What is the key idea of 'self-attention'? Based on the video, why was this a significant change from models like RNNs and LSTMs?"

**Reflection:** Jay Alammar's Illustrated Transformer helped me see how self-attention lets every token attend to the full sentence without the sequential bottleneck of RNNs. The query/key/value breakdown and scaled dot-product softmax made it clearer how the model learns to focus on syntactic roles (like subject vs. object) while still capturing long-range dependencies. I also appreciated how multi-head attention acts like several parallel lenses, each specializing in different patterns (positions, agreement, emphasis) and then recombining them. The residual connections, layer norm, and positional encodings felt less like implementation details and more like the glue that keeps deep stacks stable and order-aware. Overall, the architecture feels elegant: simple primitives plus good engineering to make training fast, parallel, and expressive.


### Setup: Imports and Installs

Run this cell to install and import all necessary libraries.

In [1]:
!pip install -q datasets transformers torch
!pip install -q nltk

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import nltk
from nltk.corpus import movie_reviews
import random

# For Part C
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# Helper to set seeds for reproducible debugging
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


### Part A: Build MultiHeadAttention From Scratch

Your goal is to implement the `MultiHeadAttention` module.

#### 1. Implement `ScaledDotProductAttention`

This is the core function. Implement the formula:
$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + \text{mask}\right)V$

* Use `-1e9` (a large negative number) as the mask value at positions that must be ignored, so that their softmax weight is effectively zero.
* The function should return **both** the `output` and the attention `weights`.

In [3]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Implements the Scaled Dot-Product Attention.
    Args:
        Q (torch.Tensor): Queries (batch_size, num_heads, seq_len_q, head_dim)
        K (torch.Tensor): Keys (batch_size, num_heads, seq_len_k, head_dim)
        V (torch.Tensor): Values (batch_size, num_heads, seq_len_v, head_dim)
                         (seq_len_k and seq_len_v are the same)
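        mask (torch.Tensor, optional): Boolean mask, True at key positions to hide
                         (e.g. padding). (batch_size, 1, 1, seq_len_k)
    Returns:
        output (torch.Tensor): The context vector.
        attn_weights (torch.Tensor): The attention weights.
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # The classifier in Part B passes (x == 0), i.e. True at padding, so
        # fill where the mask is True; `mask == 0` would hide the real tokens.
        scores = scores.masked_fill(mask.bool(), -1e9)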
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
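
As a quick sanity check of the formula, the sketch below applies it by hand to a tiny toy input. The tensors and the mask convention (`True` marks a key position to hide, matching the `(x == 0)` padding mask built in Part B) are illustrative assumptions, not part of the assignment's tests.

```python
import math

import torch
import torch.nn.functional as F

# Toy input: batch_size=1, num_heads=1, seq_len=3, head_dim=2
toy_Q = torch.tensor([[[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]])
toy_K = toy_Q.clone()
toy_V = torch.tensor([[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]])

# Assumed convention: True marks a key position to hide (here, the last token)
hide = torch.tensor([[[[False, False, True]]]])  # (batch_size, 1, 1, seq_len_k)

scores = torch.matmul(toy_Q, toy_K.transpose(-2, -1)) / math.sqrt(toy_Q.size(-1))
scores = scores.masked_fill(hide, -1e9)
toy_weights = F.softmax(scores, dim=-1)
toy_output = torch.matmul(toy_weights, toy_V)

print(toy_weights[0, 0])  # each row sums to 1; the last column is 0
print(toy_output.shape)   # torch.Size([1, 1, 3, 2])
```

Because the hidden column receives zero weight, each output row is a mixture of the first two value vectors only.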


#### 2. Implement `MultiHeadAttention`

Now, create a module that uses your function to run multiple heads in parallel.

**Crucial Instruction:** Your module's `forward` method must return **two** values:
1.  `context_vector` (the final output of the module)
2.  `attention_weights` (the weights from the scaled dot-product attention)

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, head_dim)"""
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
    def forward(self, x, mask=None):
        """
        Args:
            x (torch.Tensor): Input (batch_size, seq_len, embed_dim)
            mask (torch.Tensor, optional): Mask for attention.
        Returns:
            context_vector (torch.Tensor): Final output
            attention_weights (torch.Tensor): Attention weights
        """
        batch_size = x.size(0)
        Q = self.split_heads(self.W_q(x), batch_size)
        K = self.split_heads(self.W_k(x), batch_size)
        V = self.split_heads(self.W_v(x), batch_size)
        context, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        context_vector = self.W_o(context)
        return context_vector, attn_weights
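
The reshaping in `split_heads` and the merge step in `forward` are easy to get subtly wrong, so it can help to check them in isolation. The following is a minimal sketch with made-up dimensions; it only verifies that splitting and merging are inverses of each other.

```python
import torch

batch_size, seq_len, embed_dim, num_heads = 2, 5, 8, 4
head_dim = embed_dim // num_heads

x = torch.arange(batch_size * seq_len * embed_dim, dtype=torch.float32)
x = x.view(batch_size, seq_len, embed_dim)

# Split: (batch, seq, embed) -> (batch, heads, seq, head_dim)
heads = x.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

# Merge: undo the transpose, then flatten the heads back into embed_dim
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, embed_dim)

print(heads.shape)             # torch.Size([2, 4, 5, 2])
print(torch.equal(merged, x))  # True
```

Each head sees a contiguous `head_dim`-sized slice of every token's embedding; for example, head 0 of token 0 is `x[0, 0, :2]`.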


#### 3. Test Your Module

Run this cell. If your implementation is correct, it will run without errors and print the correct output shape.

In [5]:
# Unit Test
set_seed(42)
test_mha = MultiHeadAttention(embed_dim=64, num_heads=8).to(device)
dummy_input = torch.rand(4, 10, 64).to(device)  # (batch_size, seq_len, embed_dim)
context_vec, attn_weights = test_mha(dummy_input)

# Check shapes
assert context_vec.shape == (4, 10, 64), f"Context vector shape is {context_vec.shape}"
assert attn_weights.shape == (4, 8, 10, 10), f"Attn weights shape is {attn_weights.shape}"

print("Part A: MultiHeadAttention Test Passed!")
print(f"Input shape: {dummy_input.shape}")
print(f"Output context shape: {context_vec.shape}")
print(f"Output weights shape: {attn_weights.shape}")

Part A: MultiHeadAttention Test Passed!
Input shape: torch.Size([4, 10, 64])
Output context shape: torch.Size([4, 10, 64])
Output weights shape: torch.Size([4, 8, 10, 10])


### Part B: Task (Debugging)

**Your Task:** I have provided a `BuggyTransformerEncoderLayer`. It uses your `MultiHeadAttention` module. This code **runs** but **fails to learn**.

Your job is to:
1.  Run the buggy training loop.
2.  Use a PyTorch forward hook (`register_forward_hook`) to "spy" on your `MultiHeadAttention` module.
3.  Extract and plot the attention weights.
4.  Analyze the plot and identify the bug in the code.
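
If forward hooks are new to you, the mechanism can be tried on a toy module first. In this sketch, `TwoOutputs` and `save_second_output` are hypothetical names invented for illustration; the only PyTorch API relied on is `nn.Module.register_forward_hook`, which calls the hook with `(module, inputs, output)` after every forward pass.

```python
import torch
import torch.nn as nn

class TwoOutputs(nn.Module):
    """Hypothetical stand-in for a module whose forward returns a tuple."""
    def forward(self, x):
        return x * 2, x.sum(dim=-1)

captured = []

def save_second_output(module, inputs, output):
    # `output` is whatever forward() returned; here a 2-tuple
    captured.append(output[1].detach().cpu())

toy = TwoOutputs()
handle = toy.register_forward_hook(save_second_output)

toy(torch.ones(3, 4))   # the hook fires after this forward pass
handle.remove()         # detach the hook once it is no longer needed
toy(torch.ones(3, 4))   # this call is no longer recorded

print(len(captured))    # 1
print(captured[0])      # tensor([4., 4., 4.])
```

Calling `.detach().cpu()` inside the hook keeps the stored copy free of the autograd graph and of GPU memory.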

#### 1. The Buggy Code (Read Only)

Do not edit this cell. This is the broken code I am providing.

In [None]:
class BuggyTransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
        super(BuggyTransformerEncoderLayer, self).__init__()

        # We use your MultiHeadAttention module
        self.mha = MultiHeadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, embed_dim)
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # --- This is the buggy part ---
        # Find the one-line error in this forward pass.

        attn_output, attn_weights = self.mha(x, mask)
        x = self.dropout1(attn_output)
        x = self.norm1(x)  # <--- HINT: Something is missing here!

        ffn_output = self.ffn(x)
        x = x + self.dropout2(ffn_output) # <--- And something is wrong here!
        x = self.norm2(x)

        return x, attn_weights # Pass weights through for analysis

#### 2. Data and Training Helpers (Read Only)

This cell contains the helper functions to load data and run the training.

In [None]:
# Download NLTK data (if not already)
nltk.download('movie_reviews')

def get_nltk_data():
    documents = [(list(movie_reviews.words(fileid)), category)
                 for category in movie_reviews.categories()
                 for fileid in movie_reviews.fileids(category)]
    random.shuffle(documents)
    return documents

# Simple vocabulary and data processing
class Vocab:
    def __init__(self, documents):
        all_words = [w.lower() for doc, cat in documents for w in doc]
        self.word_freq = nltk.FreqDist(all_words)
        self.vocab = {word: i+2 for i, (word, freq) in enumerate(self.word_freq.most_common(3000))}
        self.vocab['<PAD>'] = 0
        self.vocab['<UNK>'] = 1

    def tokenize(self, doc, max_len=100):
        tokens = [self.vocab.get(w.lower(), self.vocab['<UNK>']) for w in doc]
        tokens = tokens[:max_len]
        tokens = tokens + [self.vocab['<PAD>']] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

def get_dataloaders(batch_size=32):
    documents = get_nltk_data()
    train_docs, test_docs = documents[:1800], documents[1800:]
    vocab = Vocab(documents)

    train_data = [(vocab.tokenize(doc), 1 if cat == 'pos' else 0) for doc, cat in train_docs]
    test_data = [(vocab.tokenize(doc), 1 if cat == 'pos' else 0) for doc, cat in test_docs]

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
    return train_loader, test_loader, len(vocab.vocab)

# The full (buggy) classifier
class BuggyClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, ffn_dim):
        super(BuggyClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # We use the BuggyTransformerEncoderLayer
        self.encoder = BuggyTransformerEncoderLayer(embed_dim, num_heads, ffn_dim)
        self.classifier = nn.Linear(embed_dim, 2)
        self.embed_dim = embed_dim

    def forward(self, x):
        # x shape: (batch_size, seq_len=100)
        padding_mask = (x == 0).unsqueeze(1).unsqueeze(2) # (batch_size, 1, 1, seq_len)

        x = self.embedding(x) * math.sqrt(self.embed_dim)
        # Positional encoding is skipped to make the bug more obvious

        x, attn_weights = self.encoder(x, mask=padding_mask)

        # Mean-pool over the sequence: (batch_size, seq_len, embed_dim) -> (batch_size, embed_dim)
        x = x.mean(dim=1)
        logits = self.classifier(x)
        return logits, attn_weights

# Training Loop
def train(model, loader, optimizer, criterion):
    model.train()
    for data, labels in loader:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        logits, _ = model(data)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            logits, _ = model(data)
            pred = logits.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    return correct / total
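
The padding mask built in `BuggyClassifier.forward` relies on broadcasting. The sketch below uses made-up token ids to show how a `(batch_size, 1, 1, seq_len)` mask lines up with `(batch_size, num_heads, seq_len, seq_len)` attention scores. The `masked_fill` call assumes the convention that `True` marks a position to hide.

```python
import torch

# Hypothetical token ids; 0 is the <PAD> index, as in the Vocab class
token_ids = torch.tensor([[5, 9, 2, 0, 0],
                          [7, 3, 0, 0, 0]])

pad_mask = (token_ids == 0).unsqueeze(1).unsqueeze(2)
print(pad_mask.shape)  # torch.Size([2, 1, 1, 5])

# Placeholder scores for 4 heads: (batch_size, num_heads, seq_len_q, seq_len_k)
scores = torch.zeros(2, 4, 5, 5)
masked = scores.masked_fill(pad_mask, -1e9)

# Every head and every query row hides the same padded key columns
print(masked[0, 0, 0])
```

The two singleton dimensions are expanded across heads and query positions, so one mask per sentence is enough.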

#### 3. Run the Buggy Training (and inspect the problem)

This cell trains the buggy model. Notice that the test accuracy stays around 50%: on a balanced two-class dataset, that means the model is not learning and is effectively guessing.

In [None]:
set_seed(42)
train_loader, test_loader, vocab_size = get_dataloaders()
buggy_model = BuggyClassifier(vocab_size, 64, 8, 128).to(device)
optimizer = torch.optim.Adam(buggy_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

print("--- Training Buggy Model ---")
for epoch in range(3):
    train(buggy_model, train_loader, optimizer, criterion)
    acc = evaluate(buggy_model, test_loader)
    print(f"Epoch {epoch+1}, Test Accuracy: {acc:.4f}")

#### 4. Your Task: Debug with Hooks

Fill in the code below to:
1.  Instantiate a new `BuggyClassifier`.
2.  Register a `forward_hook` on the `MultiHeadAttention` module inside it.
3.  The hook should save the `attn_weights` (the 2nd item returned by your MHA's forward pass) into the `attention_storage` list.
4.  Run training for **one** epoch to populate the `attention_storage`.
5.  Plot the attention weights for the first head of the first batch.

In [None]:
import matplotlib.pyplot as plt

set_seed(42)
# 1. Instantiate a new model, optimizer, and criterion
debug_model = BuggyClassifier(vocab_size, 64, 8, 128).to(device)
optimizer = torch.optim.Adam(debug_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# 2. Set up the hook
attention_storage = []
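def hook_fn(module, input, output):
    # output is (context_vector, attention_weights)
    # We want to save the attention_weights

    # Detach and move to CPU so the stored copy holds no graph or GPU memory
    attention_storage.append(output[1].detach().cpu())


# 3. Register the hook on the MHA module
#    (Hint: debug_model.encoder.mha)
hook_handle = debug_model.encoder.mha.register_forward_hook(hook_fn)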



# 4. Run training for one epoch
print("Running one epoch of training to capture weights...")
train(debug_model, train_loader, optimizer, criterion)
print("Training done.")


# 5. Plot the attention weights
if attention_storage:
    print("Plotting attention weights from the first batch...")
    # Get weights from the first batch: (batch_size, num_heads, seq_len, seq_len)
    first_batch_weights = attention_storage[0]

    # Get weights for the first item in the batch, first head
    # (seq_len, seq_len)
    weights_to_plot = first_batch_weights[0, 0, :, :].detach().cpu().numpy()

    plt.figure(figsize=(8, 8))
    plt.imshow(weights_to_plot)
    plt.title("Attention Weights (First Batch, First Head)")
    plt.xlabel("Key (Token Position)")
    plt.ylabel("Query (Token Position)")
    plt.colorbar(label="Attention Weight")
    plt.show()
else:
    print("Hook failed to capture attention weights. Check your hook implementation.")

#### 5. Analysis & The Bug

Based on the plot above (which will likely look very strange or random, not focused), answer the following:

<div style="border: 2px dashed #ccc; padding: 10px; background-color: #f9f9f9;">
    
<b>YOUR ANALYSIS HERE:</b>

**1. Analyze the Plot:**
<i>(Double-click to edit. Describe what you see in the plot. Is the attention focused? Is it random? Is it attending to padding tokens? What does this tell you about the model's ability to learn?)</i>

**2. Identify the Bug:**
<i>(Look at the <code>BuggyTransformerEncoderLayer</code> code. There are two one-line bugs related to residual connections and layer norm. Find them.)</i>

* **Bug 1 (Line 23):**
* **Bug 2 (Line 26):**

**3. Explain the Bug:**
<i>(Why do these bugs (especially the missing residual connection) cause the model to fail and produce the attention plot you see? What is a residual connection <i>for</i>?)</i>

</div>
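
For reference when writing your explanation: the original Transformer wires each sublayer as `LayerNorm(x + Dropout(Sublayer(x)))`. The sketch below shows that pattern in isolation. `PostNormResidual` is a hypothetical helper written for illustration and is not part of the assignment code.

```python
import torch
import torch.nn as nn

class PostNormResidual(nn.Module):
    """Hypothetical helper: wraps any shape-preserving sublayer."""
    def __init__(self, embed_dim, sublayer, dropout=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # The input is added back before normalization
        return self.norm(x + self.dropout(self.sublayer(x)))

block = PostNormResidual(16, nn.Linear(16, 16)).eval()
x = torch.randn(2, 5, 16)
print(block(x).shape)  # torch.Size([2, 5, 16])
```

Because `x` is added back, the block still passes the input signal (and its gradient) through even if the sublayer contributes nothing useful early in training.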

### Part C: Use the Abstraction (Hugging Face)

Now, let's solve the *same problem* (NLTK movie reviews) with Hugging Face.

In [None]:
# 1. Load the NLTK data into a simple list format
set_seed(42)
nltk.download('movie_reviews')

def get_hf_dataset():
    documents = []
    labels = []
    for fileid in movie_reviews.fileids('pos'):
        documents.append(movie_reviews.raw(fileid))
        labels.append(1) # 'pos'
    for fileid in movie_reviews.fileids('neg'):
        documents.append(movie_reviews.raw(fileid))
        labels.append(0) # 'neg'

    # Create a datasets.Dataset
    from datasets import Dataset
    dataset = Dataset.from_dict({'text': documents, 'label': labels})

    # Split into train/test
    dataset = dataset.train_test_split(test_size=0.1, seed=42)
    return dataset['train'], dataset['test']

train_dataset, test_dataset = get_hf_dataset()
print(f"Hugging Face train dataset: {train_dataset}")

# 2. Load Tokenizer and Model
# Use "distilbert-base-uncased" as the model_name
model_name = "distilbert-base-uncased"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the model (for sequence classification, 2 labels)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)


# 3. Create a preprocessing function
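def tokenize_fn(batch):
    # Use the tokenizer on the 'text' field.
    # Truncate long reviews and pad to a fixed length so batches stack cleanly.
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512)

# 4. Tokenize the datasets
train_dataset_tok = train_dataset.map(tokenize_fn, batched=True)
test_dataset_tok = test_dataset.map(tokenize_fn, batched=True)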


# 5. Set up Trainer
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,  # 1 epoch is fine for this demo
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_steps=50,
    eval_strategy="epoch",  # older transformers releases name this argument evaluation_strategy
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none", # Disable wandb/tensorboard reporting
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": (predictions == labels).mean()}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_tok,
    eval_dataset=test_dataset_tok,
    compute_metrics=compute_metrics,
)

# 6. Train!
print("\n--- Training Hugging Face Model ---")
trainer.train()

# 7. Evaluate
print("\n--- Evaluating Hugging Face Model ---")
eval_results = trainer.evaluate()
print(f"Final Test Accuracy (Hugging Face): {eval_results['eval_accuracy']:.4f}")
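
The `compute_metrics` function used by the `Trainer` can be checked on its own before a full training run. This is a small worked example with made-up logits and labels, not output from the model.

```python
import numpy as np

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": (predictions == labels).mean()}

# Hypothetical logits for four reviews; columns are (neg, pos)
logits = np.array([[2.0, -1.0],
                   [0.1, 0.3],
                   [-0.5, 1.5],
                   [1.2, 0.4]])
labels = np.array([0, 1, 0, 0])

print(compute_metrics((logits, labels)))  # accuracy is 0.75
```

The argmax picks classes `[0, 1, 1, 0]`, which matches three of the four labels.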

### Part D: Final Analysis & Reflection

Answer the final questions in this Markdown cell.

<div style="border: 2px dashed #ccc; padding: 10px; background-color: #f9f9f9;">
    
<b>YOUR FINAL ANALYSIS:</b>

**1. Final Accuracy:**
<i>(What was the final test accuracy you achieved in Part C with Hugging Face?)</i>


**2. Debugging vs. Building:**
<i>(What did you learn from the <i>debugging</i> task in Part B that you wouldn't have learned from just <i>implementing</i> in Part A?)</i>


**3. Connecting the Pieces:**
<i>(In Part A, you built <i>one</i> `MultiHeadAttention` module. The `distilbert-base-uncased` model you used in Part C has <b>6 layers</b> and <b>12 attention heads</b> (per layer). This means it's using <b>6</b> of the <i>type</i> of module you built, and each one is 12-headed. How does this make you feel about the complexity of modern models?)</i>


**4. Final Thought:**
<i>(Why is it valuable to build a component like `MultiHeadAttention` from scratch, even if you will almost always use a library like Hugging Face in practice?)</i>

</div>