In [None]:
# --- 1. Library Imports ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Optional: For visualization
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# --- 2. Global Hyperparameters ---
# Reproducibility: the synthetic data is sampled with np.random and model
# weights come from torch, so both generators are seeded once, up front.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model Dimensions (Keep these small for the sprint)
D_MODEL = 64        # Embedding dimension
NUM_HEADS = 4       # Number of attention heads
NUM_LAYERS = 1      # Use 1 Encoder and 1 Decoder layer for simplicity
D_FF = 128          # Feed-Forward hidden dimension

# The planned multi-head split divides D_MODEL evenly across the heads.
assert D_MODEL % NUM_HEADS == 0, "D_MODEL must be divisible by NUM_HEADS"

# Training Parameters
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50

# Sequence Parameters
MAX_SEQ_LEN = 10    # Max tokens in any sequence

In [None]:
# --- 1. Vocabulary Definition and Mapping ---
# The fixed, small vocabulary for our lookup table. A token's id is its
# position in the list below, so the list order IS the id assignment.
VOCAB = {
    token: token_id
    for token_id, token in enumerate([
        # Special tokens
        '[PAD]', '[SOS]', '[EOS]',
        # Keys
        'COLOR', 'NAME', 'ITEM',
        # Values
        'RED', 'BLUE', 'GREEN',
        'ALICE', 'BOB', 'CHAIR',
        # Query Tokens
        'QUERY', 'FIND',
    ])
}
VOCAB_SIZE = len(VOCAB)
# Reverse mapping (id -> token string), e.g. for printing decoded sequences.
ID_TO_TOKEN = dict(zip(VOCAB.values(), VOCAB.keys()))

# --- 2. The Lookup Dictionary (Ground Truth) ---
# Every key -> value fact the model must learn, stored as token ids.
FACTS_DICT = {
    VOCAB[key_token]: VOCAB[value_token]
    for key_token, value_token in (
        ('RED', 'BLUE'),     # Red item's value is 'BLUE'
        ('ALICE', 'GREEN'),  # Alice's value is 'GREEN'
        ('CHAIR', 'RED'),    # Chair's value is 'RED'
    )
}
# The same facts as a list of (Key_ID, Value_ID) tuples, convenient for sampling.
FACT_PAIRS = [(key_id, value_id) for key_id, value_id in FACTS_DICT.items()]

# --- 3. Synthetic Data Generation Function ---
def generate_synthetic_data(num_samples, seed=None):
    """Generate synthetic (encoder input, decoder target) lookup examples.

    For every sample ONE fact is drawn at random from FACT_PAIRS. The encoder
    input is the marker token 'NAME' followed by the four tokens
    ['COLOR', key, 'ITEM', value] in random order; the decoder target is
    [value, '[EOS]'].

    NOTE(review): earlier comments described selecting 2 distinct facts and
    appending an end-of-list token, but the code has always used a single
    fact and PREPENDED 'NAME'. That behavior is kept here; confirm which
    design is intended.

    Args:
        num_samples: Number of examples to generate.
        seed: Optional integer seed. When given, a private
            np.random.RandomState is used so the output is reproducible and
            the global NumPy RNG is left untouched. When None (default) the
            global np.random state is used, as before.

    Returns:
        A list of dicts with keys:
            'enc_input'    - LongTensor of 5 token ids,
            'dec_target'   - LongTensor [value_id, EOS_id],
            'query_val_id' - the value id as a plain int (for checks/plots).
    """
    # RandomState exposes the same randint/shuffle API as the np.random module.
    rng = np.random if seed is None else np.random.RandomState(seed)

    data = []
    for _ in range(num_samples):
        # Pick the single fact that is both the context and the query target.
        query_key_id, target_val_id = FACT_PAIRS[rng.randint(len(FACT_PAIRS))]

        # Shuffle the fact tokens: order must not matter for the no-PE encoder test.
        fact_tokens = [VOCAB['COLOR'], query_key_id, VOCAB['ITEM'], target_val_id]
        rng.shuffle(fact_tokens)

        # The 'NAME' marker always occupies the first position.
        enc_input_tokens = [VOCAB['NAME']] + fact_tokens

        # Target sequence is [Value, EOS]. The vocabulary key is '[EOS]'
        # (with brackets); the previous 'EOS' lookup raised KeyError.
        dec_target_tokens = [target_val_id, VOCAB['[EOS]']]

        data.append({
            'enc_input': torch.tensor(enc_input_tokens, dtype=torch.long),
            'dec_target': torch.tensor(dec_target_tokens, dtype=torch.long),
            'query_val_id': target_val_id  # Used for checking/visualization
        })

    return data

# --- 4. PyTorch Dataset and DataLoader Setup ---
class LookupDataset(Dataset):
    """Dataset wrapper around the list of samples from generate_synthetic_data.

    Each sample is a dict with at least the tensors 'enc_input' and
    'dec_target'. Indexing returns the triple needed for teacher forcing:
    (encoder input, decoder input, true target).
    """

    def __init__(self, data):
        # `data` is kept by reference; it is not copied.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return (enc_input, dec_input, true_target) for sample `idx`."""
        item = self.data[idx]

        # The true target is the stored target sequence (no [SOS] token).
        true_target = item['dec_target']

        # Decoder input is the target shifted right: [SOS] + target[:-1].
        # The vocabulary key is '[SOS]' (with brackets); the previous 'SOS'
        # lookup raised KeyError on every access.
        sos = torch.tensor([VOCAB['[SOS]']], dtype=torch.long)
        dec_input = torch.cat([sos, true_target[:-1]])

        return item['enc_input'], dec_input, true_target

# Collate function to handle padding for the DataLoader
def collate_fn(batch):
    """Collate LookupDataset samples into padded batch tensors.

    Args:
        batch: Sequence of (enc_input, dec_input, dec_target) tensor triples.

    Returns:
        (enc_inputs, dec_inputs, dec_targets, src_padding_mask, tgt_padding_mask).
        The first three are batch-first tensors padded with the '[PAD]' id up to
        the longest sequence in this batch (pad_sequence does not pad to
        MAX_SEQ_LEN). The masks are boolean: True marks a padding position to be
        masked out, False a real token.
    """
    pad_id = VOCAB['[PAD]']

    def _pad_column(position):
        # Gather the tensor at `position` from every sample and pad them together.
        sequences = [sample[position] for sample in batch]
        return nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_id)

    enc_inputs, dec_inputs, dec_targets = (_pad_column(position) for position in range(3))

    # Masks are derived from the padded encoder / decoder INPUT tensors.
    src_padding_mask = enc_inputs == pad_id
    tgt_padding_mask = dec_inputs == pad_id

    return enc_inputs, dec_inputs, dec_targets, src_padding_mask, tgt_padding_mask

# Generate and split data: 80% train / 10% validation / remainder test.
NUM_SAMPLES = 5000
TRAIN_FRACTION = 0.8
VAL_FRACTION = 0.1

full_data = generate_synthetic_data(num_samples=NUM_SAMPLES)

# Split sizes come from the data actually generated, so changing NUM_SAMPLES
# cannot leave the split boundaries out of sync.
train_split = int(TRAIN_FRACTION * len(full_data))
val_split = int(VAL_FRACTION * len(full_data))

train_data = full_data[:train_split]
val_data = full_data[train_split:train_split + val_split]
test_data = full_data[train_split + val_split:]
assert len(train_data) + len(val_data) + len(test_data) == len(full_data)


# Create DataLoaders
def _make_loader(samples, shuffle):
    """Wrap `samples` in a LookupDataset and a DataLoader using the shared batch settings."""
    return DataLoader(LookupDataset(samples), batch_size=BATCH_SIZE, shuffle=shuffle, collate_fn=collate_fn)


# Only the training loader is shuffled; validation/test order stays fixed.
train_loader = _make_loader(train_data, shuffle=True)
val_loader = _make_loader(val_data, shuffle=False)
test_loader = _make_loader(test_data, shuffle=False)
print(f"Data ready: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}")

In [None]:
# --- 1. Scaled Dot-Product Attention Function (Helper) ---
# Function that calculates QK^T / sqrt(dk) and applies the mask.
# NOTE: Positional Encoding (PE) is deliberately omitted in the final Encoder!

# --- 2. Multi-Head Attention (MHA) Class ---
class MultiHeadAttention(nn.Module):
    """Multi-head attention block (placeholder: nothing is implemented yet).

    Planned responsibilities:
      * initialize the Q, K and V projections;
      * split into heads, calculate attention scores, concatenate the heads
        and apply the final linear layer.
    """

# --- 3. Feed-Forward Network (FFN) Class ---
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network (placeholder: nothing is implemented yet).

    Planned as the two-layer transformation Linear -> ReLU -> Linear.
    """

In [None]:
# --- 1. The Encoder Layer ---
class EncoderLayer(nn.Module):
    """Single encoder layer (placeholder: nothing is implemented yet).

    Planned to combine self-attention (MHA) and the FFN as
    (Input -> MHA -> Add&Norm) -> (Input -> FFN -> Add&Norm).

    CRITICAL: no positional encoding is to be added here; the experiment
    relies on that to demonstrate content-addressability.
    """

# --- 2. The Decoder Layer ---
class DecoderLayer(nn.Module):
    """Single decoder layer (placeholder: nothing is implemented yet).

    Planned structure:
      a. Masked self-attention: Input -> Masked MHA -> Add&Norm
      b. CROSS-ATTENTION:       Input -> Cross MHA  -> Add&Norm
      c. Feed-forward:          Input -> FFN        -> Add&Norm
    """

In [None]:
# --- 1. Full Encoder Module ---
class Encoder(nn.Module):
    """Full encoder (placeholder: nothing is implemented yet).

    Planned to add the initial token embedding and stack 1 or 2 EncoderLayers.
    """

# --- 2. Full Decoder Module ---
class Decoder(nn.Module):
    """Full decoder (placeholder: nothing is implemented yet).

    Planned to handle the initial token embedding plus positional-encoding (PE)
    addition, and to stack 1 or 2 DecoderLayers.

    NOTE: PE is typically required in the decoder for temporal/sequential
    generation.
    """

# --- 3. CrossAttentionLookupModel (The Final Model) ---
class CrossAttentionLookupModel(nn.Module):
    """The final encoder-decoder lookup model (placeholder: nothing is implemented yet).

    Planned to initialize the Encoder, the Decoder and the final projection
    layer (D_MODEL -> VOCAB_SIZE), and to implement the full forward pass
    Enc(X) -> Dec(Y, Enc_Output).
    """

In [None]:
# --- 1. Initialization ---
# Instantiate the model, optimizer (Adam), and loss function (CrossEntropyLoss).
# TODO: only the model is created so far; the optimizer and the loss function
# are not instantiated yet.
# NOTE(review): the `...` below is a literal Ellipsis placeholder, not real
# constructor arguments. Replace it once CrossAttentionLookupModel defines its
# constructor; until then this cell is not expected to run cleanly.
# NOTE(review): collate_fn pads dec_targets with VOCAB['[PAD]'], so the loss
# should presumably skip that id (e.g. via ignore_index) - confirm when adding it.
model = CrossAttentionLookupModel(...)

# --- 2. Training Loop ---
# TODO: not implemented yet. Planned behavior:
# Iterates through NUM_EPOCHS:
#   - Trains on train_loader.
#   - Evaluates on val_loader to track performance.
#   - Includes logic to save the best model weights.

In [None]:
# --- 1. Inference Function ---
def greedy_decode(model, enc_input, max_len):
    """Token-by-token autoregressive (greedy) generation for testing.

    Not implemented yet. The stub fails loudly instead of silently returning
    None: the planned scrambling test compares the decode of the original
    input with the decode of the scrambled input, and two None results would
    make that verification pass trivially.

    Args:
        model: The trained lookup model to decode with.
        enc_input: Encoder input token ids for the example(s) to decode.
        max_len: Maximum number of tokens to generate.

    Raises:
        NotImplementedError: Always, until the decoding loop is written.
    """
    raise NotImplementedError("greedy_decode is not implemented yet")

# --- 2. The Scrambling Test Execution ---
# Selects a test batch.
# Runs inference on the ORIGINAL enc_input.
# Creates a SCRAMBLED version of the enc_input.
# Runs inference on the SCRAMBLED enc_input.
# Prints the results: Input (Original) -> Output; Input (Scrambled) -> Output.
# VERIFICATION: The two outputs must be identical.

In [None]:
# --- 1. Attention Hook Setup ---
# Add a 'hook' function to the Cross-Attention block in the DecoderLayer
# to capture and save the attention weight tensor during the forward pass.

# --- 2. Extraction Run ---
# Run a single batch from the test_loader through the model to trigger the hook.
# Save the extracted attention matrix.

# --- 3. Visualization ---
def plot_cross_attention(attention_matrix, enc_tokens, dec_tokens):
    """Plot the cross-attention weights as a heatmap.

    Not implemented yet. The stub fails loudly rather than being a silent
    no-op, so a notebook run cannot appear to succeed without ever drawing
    the heatmap.

    Planned behavior: use Matplotlib/Seaborn to draw `attention_matrix` as a
    heatmap, label the X-axis with the encoder tokens and the Y-axis with the
    decoder tokens, and so visually confirm the sharp spike of attention from
    the query token to the value token.

    Args:
        attention_matrix: Captured cross-attention weights to display.
        enc_tokens: Encoder token labels for the X-axis.
        dec_tokens: Decoder token labels for the Y-axis.

    Raises:
        NotImplementedError: Always, until the plotting code is written.
    """
    raise NotImplementedError("plot_cross_attention is not implemented yet")

# Execute plot_cross_attention to display the key heatmap.