In [None]:
# Option 1: Import everything from the helper library
# NOTE(review): a star import hides which names come from `hybrid_transformer`.
# This notebook also defines HybridTransformer, NeedleInHaystackDataset and
# evaluate_random_baseline in its own cells, so whichever of this import or
# those cells ran most recently decides which definition is in effect.
from hybrid_transformer import *

# Option 2: Import specific components (explicit names make the origin of each
# symbol visible and avoid silently shadowing notebook-defined classes)
# from hybrid_transformer import (
#     HybridTransformer,
#     NeedleInHaystackDataset,
#     evaluate_random_baseline,
#     count_parameters
# )

print("✓ Hybrid Transformer library loaded!")

In [49]:
# Optional dependency installation; both commands are commented out, so
# running this cell installs nothing.
#!pip install torch torchvision torchaudio
#!pip install numpy matplotlib

# NOTE: this message is printed unconditionally; it does not verify that any
# package is actually installed.
print("✓ All packages installed successfully!")

✓ All packages installed successfully!


In [50]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from torch.nn import LayerNorm
import random
from torch.utils.data import Dataset, DataLoader

print("✓ Libraries imported successfully!")

✓ Libraries imported successfully!


## Import Libraries

In [51]:
import random
from torch.utils.data import Dataset, DataLoader

class NeedleInHaystackDataset(Dataset):
    """
    A toy dataset for testing long-context memory.

    Task: Find a special key-value pair hidden in a long sequence.
    Format: [random tokens...] <KEY> key_id <VALUE> value_id [random tokens...] ... <QUERY> key_id
    Goal: Given the sequence, predict the value_id associated with the queried key_id

    This tests if the model can:
    1. Identify the needle (key-value pair) in a long sequence
    2. Store it in memory
    3. Retrieve it later when queried

    Vocabulary layout (ids ascending):
        0                      PAD
        1 .. V-4-2n            haystack filler tokens
        V-3-2n .. V-4-n        key ids      (n = num_needles)
        V-3-n  .. V-4          value ids
        V-3, V-2, V-1          <KEY>, <VALUE>, <QUERY> markers
    """

    # Tokens used by one needle: <KEY> key_id <VALUE> value_id
    _TOKENS_PER_NEEDLE = 4
    # Tokens used by the trailing query: <QUERY> key_id
    _QUERY_TOKENS = 2

    def __init__(self, num_samples=1000, vocab_size=100,
                 haystack_length=1000, num_needles=5, seed=42):
        """
        Args:
            num_samples: Number of examples in dataset
            vocab_size: Size of vocabulary. Must be at least
                2 * num_needles + 4 (PAD + key ids + value ids + 3 markers),
                plus one more if any haystack filler token is needed.
            haystack_length: Total sequence length (needles and query included).
                Must be at least 4 * num_needles + 2.
            num_needles: Number of key-value pairs to hide (at least 1)
            seed: Random seed for reproducibility. Note that this seeds the
                global `random` module, exactly as the data is drawn from it.

        Raises:
            ValueError: if the arguments cannot produce a well-formed sample.
        """
        if num_needles < 1:
            raise ValueError(f"num_needles must be >= 1, got {num_needles}")

        min_length = num_needles * self._TOKENS_PER_NEEDLE + self._QUERY_TOKENS
        if haystack_length < min_length:
            raise ValueError(
                f"haystack_length={haystack_length} is too short for "
                f"{num_needles} needles plus a query; need at least {min_length}"
            )

        # PAD (id 0) + key ids + value ids + the three marker tokens.
        min_vocab = 2 * num_needles + 4
        if haystack_length > min_length:
            # Filler positions exist, so at least one filler token id is required.
            min_vocab += 1
        if vocab_size < min_vocab:
            raise ValueError(
                f"vocab_size={vocab_size} is too small for {num_needles} needles; "
                f"need at least {min_vocab}"
            )

        self.num_samples = num_samples
        self.vocab_size = vocab_size
        self.haystack_length = haystack_length
        self.num_needles = num_needles

        # Special tokens
        self.PAD_TOKEN = 0
        self.KEY_TOKEN = vocab_size - 3  # Special token indicating a key
        self.VALUE_TOKEN = vocab_size - 2  # Special token indicating a value
        self.QUERY_TOKEN = vocab_size - 1  # Special token for query

        # Key and value IDs (use tokens that won't be in haystack)
        self.key_ids = list(range(vocab_size - 3 - num_needles * 2,
                                  vocab_size - 3 - num_needles))
        self.value_ids = list(range(vocab_size - 3 - num_needles,
                                    vocab_size - 3))

        # Tokens available for haystack (excluding special tokens and key/value IDs)
        self.haystack_vocab = list(range(1, vocab_size - 3 - num_needles * 2))

        random.seed(seed)
        self.data = self._generate_dataset()

    def _generate_dataset(self):
        """Generate all samples"""
        return [self._generate_sample() for _ in range(self.num_samples)]

    def _generate_sample(self):
        """
        Generate a single sample with format:
        [haystack] <KEY> key_id <VALUE> value_id [haystack] ... <QUERY> key_id -> target: value_id

        Returns:
            dict with keys 'input' (1-D long tensor of length haystack_length),
            'target' (0-D long tensor), 'needle_pairs' (list of (key, value)
            tuples in sequence order) and 'query_key' (int).
        """
        # Create needle pairs (key-value associations); key i always maps to
        # value i, only the order of appearance is shuffled.
        needle_pairs = list(zip(self.key_ids[:self.num_needles],
                                self.value_ids[:self.num_needles]))
        random.shuffle(needle_pairs)

        # Choose which needle to query
        query_key, query_value = random.choice(needle_pairs)

        # Space left for filler once needles and the query are accounted for
        needle_tokens = self.num_needles * self._TOKENS_PER_NEEDLE
        available_haystack = self.haystack_length - needle_tokens - self._QUERY_TOKENS

        # Split the filler as evenly as possible into num_needles + 1 segments
        # (one before each needle, one after the last).
        haystack_segments = []
        remaining = available_haystack
        for i in range(self.num_needles + 1):
            segment_length = remaining // (self.num_needles + 1 - i)
            haystack_segments.append(
                [random.choice(self.haystack_vocab) for _ in range(segment_length)]
            )
            remaining -= segment_length

        # Build sequence: interleave haystack and needles
        sequence = []
        for i, (key_id, value_id) in enumerate(needle_pairs):
            sequence.extend(haystack_segments[i])
            sequence.extend([self.KEY_TOKEN, key_id, self.VALUE_TOKEN, value_id])
        sequence.extend(haystack_segments[-1])  # Final haystack segment

        # Add query at the end
        sequence.extend([self.QUERY_TOKEN, query_key])

        return {
            'input': torch.tensor(sequence, dtype=torch.long),
            # Target is the value associated with the queried key
            'target': torch.tensor(query_value, dtype=torch.long),
            'needle_pairs': needle_pairs,
            'query_key': query_key
        }

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

    def collate_fn(self, batch):
        """Custom collate function for DataLoader: stacks inputs and targets."""
        inputs = torch.stack([item['input'] for item in batch])
        targets = torch.stack([item['target'] for item in batch])
        return inputs, targets

In [52]:
# Random baseline
def evaluate_random_baseline(dataset):
    """Evaluate random guessing performance.

    For every sample a value id is drawn uniformly from ``dataset.value_ids``
    (using the global ``random`` module) and compared with the sample's target.

    Args:
        dataset: object supporting ``len()``, integer indexing that returns a
            dict with a ``'target'`` tensor, and a ``value_ids`` list.

    Returns:
        Accuracy as a percentage in [0, 100]. An empty dataset yields 0.0
        instead of raising ZeroDivisionError.
    """
    total = len(dataset)
    if total == 0:
        return 0.0

    correct = 0
    for i in range(total):
        target = dataset[i]['target'].item()
        # Random guess from possible values
        if random.choice(dataset.value_ids) == target:
            correct += 1

    return correct / total * 100

# Measure the random-guess baseline on `test_dataset`.
# NOTE(review): `test_dataset` is not created in this cell; in this notebook it
# is assigned in the "Create a small dataset to visualize" cell, which appears
# further down. On a fresh kernel executed top to bottom this line would raise
# NameError unless that cell has already run — confirm the intended cell order.
random_acc = evaluate_random_baseline(test_dataset)
print(f"Random baseline accuracy: {random_acc:.2f}%")
# Chance level is 1/num_needles: the guess is drawn uniformly from the
# dataset's value ids, and there is one value id per needle. With only a
# handful of samples the measured accuracy can differ noticeably from it.
print(f"Expected random accuracy: {100/test_dataset.num_needles:.2f}% (1/{test_dataset.num_needles})")
print(f"\nThis is what we need to beat! The model must learn to:")
print("  1. Identify and memorize the key-value pairs")
print("  2. Recognize the query key")
print("  3. Retrieve the correct value from memory")

Random baseline accuracy: 20.00%
Expected random accuracy: 33.33% (1/3)

This is what we need to beat! The model must learn to:
  1. Identify and memorize the key-value pairs
  2. Recognize the query key
  3. Retrieve the correct value from memory


### Baseline: Random Guessing Performance

In [53]:
# Build the train/validation splits. Both share the same task geometry
# (vocabulary, sequence length, number of needles); only the sample count and
# the seed differ so the validation data is drawn independently.
_split_geometry = dict(vocab_size=100, haystack_length=512, num_needles=5)

train_dataset = NeedleInHaystackDataset(num_samples=1000, seed=42, **_split_geometry)
val_dataset = NeedleInHaystackDataset(num_samples=200, seed=123, **_split_geometry)

# Wrap each split in a DataLoader; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
                          collate_fn=train_dataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False,
                        collate_fn=val_dataset.collate_fn)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Pull one batch to sanity-check the collated tensor shapes
inputs, targets = next(iter(train_loader))
print(f"\nBatch shapes:")
print(f"  Inputs: {inputs.shape}")
print(f"  Targets: {targets.shape}")

Training batches: 125
Validation batches: 25

Batch shapes:
  Inputs: torch.Size([8, 512])
  Targets: torch.Size([8])


### Create DataLoader for Training

In [54]:
# Create a small dataset to visualize
# (5 short sequences with 3 needles each, so a whole sample is easy to read)
test_dataset = NeedleInHaystackDataset(
    num_samples=5,
    vocab_size=50,
    haystack_length=100,
    num_needles=3,
    seed=42
)

# Summarize the dataset configuration and the reserved token ids
print("Dataset Information:")
print(f"Total samples: {len(test_dataset)}")
print(f"Vocab size: {test_dataset.vocab_size}")
print(f"Haystack length: {test_dataset.haystack_length}")
print(f"Number of needles: {test_dataset.num_needles}")
print(f"Special tokens: KEY={test_dataset.KEY_TOKEN}, VALUE={test_dataset.VALUE_TOKEN}, QUERY={test_dataset.QUERY_TOKEN}")
print(f"\nKey IDs: {test_dataset.key_ids}")
print(f"Value IDs: {test_dataset.value_ids}")

# Examine one sample
sample = test_dataset[0]
sequence = sample['input']
target = sample['target']

print(f"\n{'='*60}")
print("Sample Sequence Breakdown:")
print(f"{'='*60}")
print(f"Total sequence length: {len(sequence)}")
print(f"Target value: {target.item()}")
print(f"Needle pairs in this sample: {sample['needle_pairs']}")
print(f"Queried key: {sample['query_key']}")

# Find and display the needles in the sequence
print(f"\nNeedles found in sequence:")
seq_list = sequence.tolist()
for i, token in enumerate(seq_list):
    if token == test_dataset.KEY_TOKEN:
        # A needle is laid out as <KEY> key_id <VALUE> value_id, so the key id
        # sits one position after the marker and the value id three after.
        key_id = seq_list[i+1]
        value_id = seq_list[i+3]
        is_queried = "← QUERIED" if key_id == sample['query_key'] else ""
        print(f"  Position {i:3d}: <KEY> {key_id} <VALUE> {value_id} {is_queried}")

# Display query
# The dataset appends <QUERY> key_id as the last two tokens of each sequence.
query_pos = seq_list.index(test_dataset.QUERY_TOKEN)
query_key = seq_list[query_pos + 1]
print(f"\nQuery at position {query_pos}: <QUERY> {query_key} -> Expected answer: {target.item()}")

# Show a snippet of the sequence
print(f"\n{'='*60}")
print("Sequence snippet (first 50 tokens):")
print(f"{'='*60}")
print(seq_list[:50])

Dataset Information:
Total samples: 5
Vocab size: 50
Haystack length: 100
Number of needles: 3
Special tokens: KEY=47, VALUE=48, QUERY=49

Key IDs: [41, 42, 43]
Value IDs: [44, 45, 46]

Sample Sequence Breakdown:
Total sequence length: 100
Target value: 45
Needle pairs in this sample: [(42, 45), (41, 44), (43, 46)]
Queried key: 42

Needles found in sequence:
  Position  21: <KEY> 42 <VALUE> 45 ← QUERIED
  Position  46: <KEY> 41 <VALUE> 44 
  Position  72: <KEY> 43 <VALUE> 46 

Query at position 98: <QUERY> 42 -> Expected answer: 45

Sequence snippet (first 50 tokens):
[18, 16, 15, 9, 7, 35, 6, 38, 28, 3, 2, 6, 14, 15, 33, 39, 2, 36, 13, 35, 27, 47, 42, 48, 45, 15, 29, 38, 18, 1, 11, 28, 22, 18, 10, 14, 22, 7, 6, 25, 7, 23, 23, 39, 17, 3, 47, 41, 48, 44]


### Test the Dataset Generator

In [55]:
# Smoke test: build a small model and run one random batch through it.
# NOTE(review): HybridTransformer is defined in a cell that appears further
# down in this notebook (or comes from the helper-library import) — confirm it
# is available before this cell runs on a fresh kernel.
vocab_size, hidden_size, num_layers = 100, 64, 2
n_slots, window_size = 8, 128

model = HybridTransformer(vocab_size=vocab_size, hidden_size=hidden_size,
                          num_layers=num_layers, n_slots=n_slots,
                          window_size=window_size)

# Dummy token ids drawn uniformly from the vocabulary
batch_size, seq_len = 2, 50
inputs = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

# One forward pass returns per-position logits plus the final memory state
logits, memory = model(inputs)

print(f"Input shape: {inputs.shape}")
print(f"Output logits shape: {logits.shape}")
print(f"Memory state shape: {memory.shape}")
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
print("\n✓ Model instantiated and forward pass successful!")

Input shape: torch.Size([2, 50])
Output logits shape: torch.Size([2, 50, 100])
Memory state shape: torch.Size([2, 8, 64])
Model has 97920 parameters

✓ Model instantiated and forward pass successful!


### Test Model Instantiation

---
## Testing & Evaluation

Now that all components are defined, let's test the model and dataset:

## Needle-in-Haystack Dataset

This toy task tests whether the model can find and retrieve a specific "needle" (key-value pair) hidden in a long sequence of "haystack" (random tokens).

In [56]:
class HybridTransformer(nn.Module):
    """Token-level model that stacks hybrid blocks sharing a slot memory.

    Tokens are embedded, summed with sinusoidal position encodings, and passed
    through ``num_layers`` HybridTransformerBlock layers. Each block receives
    the memory state produced by the previous block. The output projection
    shares its weight matrix with the token embedding.

    Args:
        vocab_size: number of token ids.
        hidden_size: embedding / hidden width.
        num_layers: number of hybrid blocks.
        n_slots: number of memory slots.
        window_size: window size passed to every block's local attention.
        max_len: longest sequence the positional table supports. Defaults to
            1000, the default of PositionalEncoding.
    """
    def __init__(self, vocab_size, hidden_size, num_layers, n_slots=16, window_size=512,
                 max_len=1000):
        super().__init__()

        self.tok_emb = nn.Embedding(vocab_size, hidden_size)
        self.pos_emb = PositionalEncoding(hidden_size, max_len=max_len)

        # Initialize memory slots as learnable embeddings
        self.memory_init = nn.Parameter(torch.randn(1, n_slots, hidden_size) * 0.02)

        # Stack of hybrid blocks
        self.blocks = nn.ModuleList([
            HybridTransformerBlock(hidden_size, n_slots, window_size)
            for _ in range(num_layers)
        ])

        self.head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight  # Weight tying

        self.hidden_size = hidden_size
        self.n_slots = n_slots
        self.max_len = max_len

    def forward(self, inputs, memory_state=None):
        """
        inputs: shape (batch, seq_len)
        memory_state: optional, shape (batch, n_slots, hidden_size)
        Returns: logits (batch, seq_len, vocab_size), final_memory

        Raises:
            IndexError: if seq_len exceeds ``max_len``. Checked up front so the
                failure is a clear Python error rather than an out-of-range
                lookup inside the positional table.
        """
        batch_size = inputs.shape[0]
        seq_len = inputs.shape[1]

        if seq_len > self.max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}; "
                "construct the model with a larger max_len"
            )

        # Initialize memory if not provided (learned slots, shared across the batch)
        if memory_state is None:
            memory_state = self.memory_init.expand(batch_size, -1, -1)

        # Embeddings
        pos = torch.arange(0, seq_len, device=inputs.device).long()
        x = self.tok_emb(inputs) + self.pos_emb(pos)

        # Pass through transformer blocks with memory
        for block in self.blocks:
            x, memory_state = block(x, memory_state)

        # Output projection
        logits = self.head(x)

        return logits, memory_state

## Full Hybrid Transformer Model

In [57]:
class HybridTransformerBlock(nn.Module):
    """One pre-norm layer: local attention -> memory read -> MLP, then a memory write.

    Each of the three sub-layers is applied to a LayerNorm of its input and
    added back residually. The memory write runs last, on the block output.
    """
    def __init__(self, hidden_size, n_slots=16, window_size=512):
        super().__init__()
        # One LayerNorm per residual sub-layer
        self.norm1 = LayerNorm(hidden_size)
        self.norm2 = LayerNorm(hidden_size)
        self.norm3 = LayerNorm(hidden_size)

        # Token-to-token mixing
        self.local_attention = WindowedAttention(hidden_size, window_size)

        # Slot-memory interface: read via cross-attention, write via gated update
        self.memory_read = CrossAttention(hidden_size, n_slots)
        self.memory_write = GatedSSM(hidden_size, n_slots)

        # Position-wise feed-forward network
        self.mlp = MLP(hidden_size)

    def forward(self, x, memory_state):
        """
        x: shape (batch, seq_len, hidden_size)
        memory_state: shape (batch, n_slots, hidden_size)
        Returns: output (batch, seq_len, hidden_size), new_memory (batch, n_slots, hidden_size)
        """
        # Sub-layer 1: local attention (the attention weights are discarded)
        attn_out, _ = self.local_attention(self.norm1(x))
        hidden = x + attn_out

        # Sub-layer 2: pull context out of the incoming memory state
        hidden = hidden + self.memory_read(self.norm2(hidden), memory_state)

        # Sub-layer 3: feed-forward
        hidden = hidden + self.mlp(self.norm3(hidden))

        # Finally derive the updated memory from the block output
        return hidden, self.memory_write(hidden, memory_state)

## Hybrid Transformer Block

This combines local attention with memory read/write:

In [58]:
class GatedSSM(nn.Module):
    """Gated recurrent-style update used to write a sequence summary into memory."""
    def __init__(self, hidden_size, n_slots):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_slots = n_slots

        # Sigmoid gate in (0, 1): per-feature share of the memory to overwrite
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Sigmoid()
        )

        # Candidate content: A acts on the old memory, B on the new input
        self.A = nn.Linear(hidden_size, hidden_size, bias=False)
        self.B = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, h, memory_state):
        """
        h: shape (batch, seq_len, hidden_size) - current hidden states
        memory_state: shape (batch, n_slots, hidden_size) - current memory
        Returns: new_memory (batch, n_slots, hidden_size)

        Implements: M_{t+1} = (1 - G_t) ⊙ M_t + G_t ⊙ tanh(A M_t + B U_t),
        where U_t is the mean of h over the sequence, shared by every slot.
        """
        # Mean-pool over the sequence and give every slot the same summary
        pooled = h.mean(dim=1, keepdim=True).expand(-1, self.n_slots, -1)

        # Gate depends only on the pooled input, not on the memory content
        write_gate = self.gate_net(pooled)

        # Candidate memory content
        candidate = torch.tanh(self.A(memory_state) + self.B(pooled))

        # Convex blend of the previous memory and the candidate
        keep_gate = 1 - write_gate
        return keep_gate * memory_state + write_gate * candidate

In [59]:
class CrossAttention(nn.Module):
    """Reads from memory: sequence positions query, memory slots supply keys and values."""
    def __init__(self, hidden_size, n_slots):
        super().__init__()
        # Bias-free projections for queries, keys and values
        self.Q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.K = nn.Linear(hidden_size, hidden_size, bias=False)
        self.V = nn.Linear(hidden_size, hidden_size, bias=False)
        self.hidden_size = hidden_size
        self.n_slots = n_slots

    def forward(self, x, memory):
        """
        x: shape (batch, seq_len, hidden_size) - queries from input
        memory: shape (batch, n_slots, hidden_size) - keys/values from memory
        Returns: context (batch, seq_len, hidden_size)
        """
        # Unmasked attention (mask=0): every position may read every slot.
        # The attention weights are computed by the helper but not returned.
        context, _ = scaled_dot_attention(
            self.Q(x),       # (batch, seq_len, hidden_size)
            self.K(memory),  # (batch, n_slots, hidden_size)
            self.V(memory),  # (batch, n_slots, hidden_size)
            mask=0,
        )
        return context

## Memory Components (New - To Implement)

These handle reading from and writing to external memory slots:

In [60]:
class WindowedAttention(nn.Module):
    """Causal local attention within a sliding window.

    Position i attends to positions j with ``i - window_size < j <= i``: itself
    and the ``window_size - 1`` tokens before it. When the sequence is no
    longer than ``window_size`` this reduces to ordinary causal attention.
    """
    def __init__(self, hidden_size, window_size=512):
        super().__init__()
        if window_size < 1:
            # A window of zero would mask every key and make softmax return NaN.
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.Q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.K = nn.Linear(hidden_size, hidden_size, bias=False)
        self.V = nn.Linear(hidden_size, hidden_size, bias=False)
        self.hidden_size = hidden_size
        self.window_size = window_size

    def forward(self, x):
        """
        x: shape (batch, seq_len, hidden_size)
        Returns: context (batch, seq_len, hidden_size),
                 attention weights (batch, seq_len, seq_len)
        """
        seq_len = x.shape[1]

        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)

        # offset[i, j] = i - j; built on x's device so GPU inputs work.
        positions = torch.arange(seq_len, device=x.device)
        offset = positions.unsqueeze(1) - positions.unsqueeze(0)
        # Block future keys (offset < 0) and keys that fell out of the window.
        blocked = (offset < 0) | (offset >= self.window_size)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.hidden_size)
        # The diagonal is never blocked, so every row keeps a finite score.
        scores = scores.masked_fill(blocked, float("-inf"))
        attention = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention, v)

        return context, attention

## Windowed Attention (New Component)

This implements local attention within a sliding window:

In [61]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table.

    Even feature columns hold sin(pos * w_k) and odd columns cos(pos * w_k),
    with w_k = 10000^(-2k / hidden_size). The table is registered as a buffer
    (not trained) and indexed by position id in forward().

    Args:
        hidden_size: feature width; may be even or odd.
        max_len: number of positions in the table; valid indices are
            0 .. max_len - 1.
    """
    def __init__(self, hidden_size, max_len = 1000):
        super().__init__()
        # For odd hidden_size there is one more even (sin) column than odd (cos).
        n_sin = (hidden_size + 1) // 2
        n_cos = hidden_size // 2

        pos = torch.arange(max_len).float().unsqueeze(1)
        dim = torch.arange(n_sin).float().unsqueeze(0)
        div_term = torch.exp(-math.log(10000.0) * (2 * dim) / hidden_size)
        angle = pos * div_term  # (max_len, n_sin)

        pe = torch.zeros(max_len, hidden_size)
        pe[:, 0::2] = torch.sin(angle)
        pe[:, 1::2] = torch.cos(angle[:, :n_cos])
        self.register_buffer("pe", pe)

    def forward(self, idx):
        """Return the encodings for the position ids in ``idx`` (shape idx.shape + (hidden_size,))."""
        return self.pe[idx]


class MLP(nn.Module):
    """Two bias-free linear layers with a ReLU in between; width stays at hidden_size."""
    def __init__(self, hidden_size):
        super().__init__()
        self.layer1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.layer2 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply layer1 -> ReLU -> layer2 to the last dimension of x."""
        return self.layer2(self.relu(self.layer1(x)))

In [None]:
def scaled_dot_attention(q, k, v, mask=0):
    """Scaled dot-product attention with an optional additive mask.

    q : shape (batch, k, hidden_size)
    k : shape (batch, seq_len, hidden_size)
    v : shape (batch, seq_len, hidden_size)
    mask: optional, shape (k, seq_len). Entries equal to 1 are suppressed
          (a penalty of -1e9 is added to their score); 0 leaves the score
          untouched. The default 0 disables masking.

    returns:
        context   : shape (batch, k, hidden_size)
        attention : shape (batch, k, seq_len)
    """
    # Scale by sqrt(hidden_size), computed as a float32 scalar tensor
    scale = torch.sqrt(torch.tensor(q.shape[-1], dtype=torch.float32))
    scores = (q @ k.transpose(-2, -1)) / scale

    # Additive masking followed by normalization over the key axis
    attention = (scores + mask * -1e9).softmax(dim=-1)

    context = attention @ v
    return context, attention


class Attention(nn.Module):
    """Single-head attention: queries come from x, keys and values from annots."""
    def __init__(self, hidden_size):
        super().__init__()
        # Bias-free projections for queries, keys and values
        self.Q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.K = nn.Linear(hidden_size, hidden_size, bias=False)
        self.V = nn.Linear(hidden_size, hidden_size, bias=False)
        self.hidden_size = hidden_size

    def forward(self, x, annots):
        """Project x to queries and annots to keys/values, then attend without a mask.

        Returns the (context, attention) pair produced by scaled_dot_attention.
        """
        return scaled_dot_attention(self.Q(x), self.K(annots), self.V(annots))


class CausalAttention(Attention):
    """Causal (autoregressive) self-attention: position i only sees positions <= i."""
    def __init__(self, hidden_size):
        super().__init__(hidden_size)

    def forward(self, x):
        """
        x: shape (batch, seq_len, hidden_size)
        Returns: the (context, attention) pair from scaled_dot_attention.
        """
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        seq_len = x.shape[1]
        # Ones strictly above the diagonal mark future positions to suppress.
        # The mask is built on x's device; a CPU mask cannot be added to
        # scores that live on the GPU.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal = 1)
        return scaled_dot_attention(q, k, v, mask)

# Hybrid Transformer with Memory Slots

This notebook implements a transformer with external memory slots for long-context processing, building on components from the NN assignment.