# Week 3 Day 12: Sequence Packing & Masking for Causal LM

## Overview
In this notebook, we'll explore efficient sequence packing and masking techniques for causal language modeling, focusing on:
- Implementing sequence packing algorithms
- Creating causal attention masks
- Building an efficient data pipeline for transformer training

In [None]:
# Import necessary libraries
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import time
import random
from typing import List, Dict, Tuple, Optional
from tokenizers import Tokenizer

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Loading Tokenizer and Sample Data

Let's load the BPE tokenizer we created in Part 1 and prepare some sample data.

In [None]:
# Load the BPE tokenizer
try:
    tokenizer = Tokenizer.from_file("tokenizers/bpe_tokenizer.json")
    print("Loaded BPE tokenizer successfully")
except FileNotFoundError:
    print("Tokenizer file not found; training a small fallback tokenizer instead.")
    print("Run Part 1 first to use the full BPE tokenizer.")
    # Create a simple fallback tokenizer for demonstration
    # (Tokenizer itself is already imported at the top of the notebook)
    from tokenizers.models import BPE
    from tokenizers.trainers import BpeTrainer
    from tokenizers.pre_tokenizers import Whitespace
    
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    
    # Train on a small sample
    sample_texts = [
        "This is a sample text for tokenizer training.",
        "We need a tokenizer to demonstrate sequence packing.",
        "Transformers use attention mechanisms for language modeling."
    ]
    
    trainer = BpeTrainer(
        vocab_size=1000,
        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
    )
    
    # Save sample texts to file
    import os
    os.makedirs("data", exist_ok=True)
    with open("data/sample.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(sample_texts))
    
    tokenizer.train(["data/sample.txt"], trainer)
    print("Created fallback tokenizer")

# Load or create sample data
try:
    with open("data/all_texts.txt", "r", encoding="utf-8") as f:
        all_text = f.read()
    print(f"Loaded sample data: {len(all_text)} characters")
except FileNotFoundError:
    print("Sample data not found. Creating minimal sample.")
    all_text = """
    Sequence packing is an important technique for efficient transformer training.
    It allows us to maximize GPU utilization by combining multiple sequences into a single batch.
    Causal masking ensures that the model can only attend to previous tokens in autoregressive language modeling.
    Efficient data pipelines are crucial for training large language models at scale.
    """
    print(f"Created minimal sample: {len(all_text)} characters")

## 2. Implementing Sequence Packing

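Let's implement a sequence packing algorithm that combines multiple short sequences into fixed-length packed sequences. The implementation below sorts the tokenized texts by length (longest first) and fills one packed sequence at a time, starting a new one whenever the next text no longer fits. Leftover slots are padded, and a parallel mapping records which original text each token came from (`-1` marks padding).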

In [None]:
def create_packed_sequences(texts, tokenizer, max_seq_len=512):
    """Pack multiple sequences into fixed-length chunks."""
    # Tokenize all texts
    tokenized_texts = [tokenizer.encode(text) for text in texts]
    token_ids = [encoding.ids for encoding in tokenized_texts]
    
    # Get pad token ID
    pad_id = tokenizer.token_to_id("[PAD]")
    if pad_id is None:
        pad_id = 0  # Fallback
    
    # Initialize packed sequences
    packed_sequences = []
    sequence_mappings = []  # To track which original sequence each token belongs to
    current_sequence = []
    current_mapping = []
    current_length = 0
    
    # Sort sequences by length (descending); a simple heuristic for the greedy packing below
    token_ids_with_idx = list(enumerate(token_ids))
    token_ids_with_idx.sort(key=lambda x: len(x[1]), reverse=True)
    
    # Pack sequences
    for orig_idx, ids in token_ids_with_idx:
        # Truncate over-long sequences so that a single sequence always fits
        ids = ids[:max_seq_len]
        if not ids:
            continue  # Skip texts that produce no tokens
        
        # If adding this sequence would exceed max_seq_len, start a new packed sequence
        if current_length + len(ids) > max_seq_len:
            # Pad current sequence to max_seq_len
            padding_needed = max_seq_len - current_length
            current_sequence.extend([pad_id] * padding_needed)
            current_mapping.extend([-1] * padding_needed)  # -1 indicates padding
            
            # Add to packed sequences
            packed_sequences.append(current_sequence)
            sequence_mappings.append(current_mapping)
            
            # Start new sequence
            current_sequence = []
            current_mapping = []
            current_length = 0
        
        # Add current sequence
        current_sequence.extend(ids)
        current_mapping.extend([orig_idx] * len(ids))
        current_length += len(ids)
    
    # Add the last sequence if not empty
    if current_length > 0:
        # Pad to max_seq_len
        padding_needed = max_seq_len - current_length
        current_sequence.extend([pad_id] * padding_needed)
        current_mapping.extend([-1] * padding_needed)
        
        packed_sequences.append(current_sequence)
        sequence_mappings.append(current_mapping)
    
    return packed_sequences, sequence_mappings

def visualize_packed_sequences(packed_sequences, sequence_mappings):
    """Visualize packed sequences, colored by original sequence index (-1 is padding)."""
    num_sequences = len(packed_sequences)
    # Share one color scale across rows so an index has the same color in every row
    vmax = max(max(mapping) for mapping in sequence_mappings)
    
    fig, axes = plt.subplots(num_sequences, 1, figsize=(12, num_sequences * 1.0), squeeze=False)
    
    for i, mapping in enumerate(sequence_mappings):
        # Create a row for each packed sequence
        row = np.array(mapping).reshape(1, -1)
        
        # Plot as heatmap on this row's own axis
        ax = axes[i, 0]
        sns.heatmap(row, cmap='viridis', cbar=False, vmin=-1, vmax=vmax,
                    xticklabels=50, yticklabels=False, ax=ax)
        ax.set_title(f"Packed Sequence {i+1}")
    
    plt.tight_layout()
    plt.show()

In [None]:
# Split text into sentences for packing
import re
sentences = re.split(r'(?<=[.!?])\s+', all_text)
sentences = [s.strip() for s in sentences if len(s.strip()) > 0]

# Take a subset for demonstration
sample_sentences = sentences[:50]
print(f"Number of sentences: {len(sample_sentences)}")

# Create packed sequences
max_seq_len = 128
packed_sequences, sequence_mappings = create_packed_sequences(
    sample_sentences, tokenizer, max_seq_len
)

print(f"Number of packed sequences: {len(packed_sequences)}")
print(f"Sequence length: {len(packed_sequences[0])}")

# Calculate packing efficiency
total_tokens = sum(len(tokenizer.encode(s).ids) for s in sample_sentences)
packed_tokens = len(packed_sequences) * max_seq_len
packing_efficiency = total_tokens / packed_tokens * 100

print(f"Total tokens in original sentences: {total_tokens}")
print(f"Total token slots in packed sequences (including padding): {packed_tokens}")
print(f"Packing efficiency: {packing_efficiency:.2f}%")

# Visualize packed sequences
visualize_packed_sequences(packed_sequences, sequence_mappings)
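
To make the packing step easier to follow by hand, here is a small, self-contained sketch that works on toy token counts instead of real tokenizer output. `pack_next_fit` mirrors the greedy, longest-first strategy described above. `pack_first_fit` is an alternative that this notebook does not use: it tries every packed sequence that still has room before opening a new one. Both function names and the toy lengths are assumptions made for illustration.

```python
from typing import List


def pack_next_fit(lengths: List[int], max_seq_len: int) -> List[List[int]]:
    """Greedy packing with a single open bin: close it when the next item does not fit."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: List[List[int]] = []
    current: List[int] = []
    used = 0
    for i in order:
        n = min(lengths[i], max_seq_len)  # treat over-long sequences as truncated
        if used + n > max_seq_len and current:
            bins.append(current)
            current, used = [], 0
        current.append(i)
        used += n
    if current:
        bins.append(current)
    return bins


def pack_first_fit(lengths: List[int], max_seq_len: int) -> List[List[int]]:
    """First-fit decreasing: try every open bin before opening a new one."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: List[List[int]] = []
    free: List[int] = []  # remaining capacity of each bin
    for i in order:
        n = min(lengths[i], max_seq_len)
        for b, room in enumerate(free):
            if n <= room:
                bins[b].append(i)
                free[b] -= n
                break
        else:
            # No existing bin has room, so open a new one
            bins.append([i])
            free.append(max_seq_len - n)
    return bins


toy_lengths = [7, 6, 5, 4, 3, 3, 2]  # hypothetical token counts per sentence
next_fit_bins = pack_next_fit(toy_lengths, max_seq_len=10)
first_fit_bins = pack_first_fit(toy_lengths, max_seq_len=10)
print(f"next-fit:  {len(next_fit_bins)} packed sequences {next_fit_bins}")
print(f"first-fit: {len(first_fit_bins)} packed sequences {first_fit_bins}")
```

Each inner list holds the indices of the sentences that share one packed sequence. On this particular toy input the single-bin strategy needs four packed sequences and first-fit needs three; on other inputs the two can tie.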

## 3. Creating Attention Masks for Packed Sequences

Now let's create attention masks for our packed sequences, ensuring proper causal masking and preventing attention across different packed examples.

In [None]:
def create_causal_mask(seq_len):
    """Create an additive causal mask: 0.0 where attention is allowed, -inf where it is blocked."""
    # True above the diagonal marks future positions, which must not be attended to
    future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    # Each position can attend to itself and to previous positions (0.0); the rest is -inf
    mask = torch.zeros(seq_len, seq_len).masked_fill(future, float('-inf'))
    return mask

def create_packed_attention_mask(sequence_mapping, seq_len):
    """Create attention mask for packed sequences."""
    # Start with causal mask
    mask = create_causal_mask(seq_len)
    
    # Compare every query position (rows) with every key position (columns)
    mapping = torch.as_tensor(sequence_mapping, dtype=torch.long)
    same_sequence = mapping.unsqueeze(1) == mapping.unsqueeze(0)
    is_padding = mapping == -1
    
    # If tokens belong to different sequences or one is padding, mask the attention
    blocked = ~same_sequence | is_padding.unsqueeze(1) | is_padding.unsqueeze(0)
    # Keep the diagonal open: a row that is entirely -inf becomes NaN after softmax
    diag = torch.arange(seq_len)
    blocked[diag, diag] = False
    mask = mask.masked_fill(blocked, float('-inf'))
    
    return mask

def visualize_attention_mask(mask, title="Attention Mask"):
    """Visualize an attention mask."""
    # Convert -inf to a small value for visualization
    vis_mask = mask.clone()
    vis_mask[vis_mask == float('-inf')] = -10
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(vis_mask, cmap='Blues')
    plt.title(title)
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.show()

In [None]:
# Create a standard causal mask
seq_len = 64  # Smaller for visualization
causal_mask = create_causal_mask(seq_len)
visualize_attention_mask(causal_mask, "Standard Causal Mask")

# Create a packed attention mask
# Simulate a sequence mapping with 3 sequences
sequence_mapping = [0] * 20 + [1] * 25 + [2] * 15 + [-1] * 4  # Three sequences + padding
packed_mask = create_packed_attention_mask(sequence_mapping, seq_len)
visualize_attention_mask(packed_mask, "Packed Sequence Attention Mask")

# Compare with a real packed sequence
if len(packed_sequences) > 0:
    real_mapping = sequence_mappings[0][:seq_len]  # Take first seq_len tokens
    real_packed_mask = create_packed_attention_mask(real_mapping, seq_len)
    visualize_attention_mask(real_packed_mask, "Real Packed Sequence Attention Mask")
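
The heatmaps are hard to check cell by cell, so here is a tiny worked example that prints the same rule as text: a query may see a key only if the key is not in the future and both belong to the same original sequence. It uses plain Python lists rather than tensors, and `visibility_grid` is a hypothetical helper introduced only for this illustration.

```python
from typing import List


def visibility_grid(mapping: List[int]) -> List[str]:
    """Return one string per query position: 'x' = may attend, '.' = blocked."""
    rows = []
    for q, q_seq in enumerate(mapping):
        row = ""
        for k, k_seq in enumerate(mapping):
            causal = k <= q                          # no attention to future positions
            same = q_seq == k_seq and q_seq != -1    # same original sequence, not padding
            row += "x" if causal and same else "."
        rows.append(row)
    return rows


toy_mapping = [0, 0, 0, 1, 1, -1]  # two sequences followed by one padding slot
grid = visibility_grid(toy_mapping)
print("\n".join(grid))
# x.....
# xx....
# xxx...
# ...x..
# ...xx.
# ......
```

The result is block-diagonal: one lower-triangular block per original sequence. The last row belongs to the padding slot and has no visible key at all. A softmax over such a row is undefined, so a tensor implementation typically either lets padding positions attend to themselves or makes sure those positions never reach the loss.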

## 4. Building an Efficient Data Pipeline

Let's create an efficient data pipeline for training language models with packed sequences.

In [None]:
class PackedSequenceDataset(torch.utils.data.Dataset):
    """Dataset for packed sequences."""
    
    def __init__(self, packed_sequences, sequence_mappings):
        self.packed_sequences = packed_sequences
        self.sequence_mappings = sequence_mappings
    
    def __len__(self):
        return len(self.packed_sequences)
    
    def __getitem__(self, idx):
        # Get packed sequence and mapping
        sequence = self.packed_sequences[idx]
        mapping = self.sequence_mappings[idx]
        
        # Convert to tensors
        sequence_tensor = torch.tensor(sequence, dtype=torch.long)
        mapping_tensor = torch.tensor(mapping, dtype=torch.long)
        
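        # Create input and target tensors for language modeling
        # Input: all tokens except the last one
        # Target: all tokens except the first one (the input shifted left by one)
        # Note: targets that are padding, or that belong to a different sequence than
        # their input token, are not masked here and should be excluded from the loss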
        input_tensor = sequence_tensor[:-1]
        target_tensor = sequence_tensor[1:]
        
        # Create attention mask
        seq_len = len(input_tensor)
        mask = create_packed_attention_mask(mapping[:-1], seq_len)
        
        return {
            'input_ids': input_tensor,
            'targets': target_tensor,
            'attention_mask': mask,
            'sequence_mapping': mapping_tensor[:-1]
        }

def collate_packed_sequences(batch):
    """Collate function for packed sequences."""
    # Stack all tensors
    input_ids = torch.stack([item['input_ids'] for item in batch])
    targets = torch.stack([item['targets'] for item in batch])
    attention_masks = torch.stack([item['attention_mask'] for item in batch])
    sequence_mappings = torch.stack([item['sequence_mapping'] for item in batch])
    
    return {
        'input_ids': input_ids,
        'targets': targets,
        'attention_mask': attention_masks,
        'sequence_mapping': sequence_mappings
    }
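
The dataset returns `'sequence_mapping'` alongside the inputs. One possible use for it, which this notebook does not implement, is to restart position indices at every sequence boundary so that each packed sequence begins at position 0. Whether that is needed depends on how the model encodes positions. The sketch below uses plain Python lists, and `packed_position_ids` is a hypothetical helper name.

```python
from typing import List


def packed_position_ids(mapping: List[int]) -> List[int]:
    """Restart the position counter whenever the original sequence index changes."""
    positions = []
    previous = None
    counter = 0
    for seq_idx in mapping:
        if seq_idx != previous:
            counter = 0  # a new sequence (or the padding region) starts here
        positions.append(counter)
        counter += 1
        previous = seq_idx
    return positions


toy_mapping = [3, 3, 3, 0, 0, -1, -1]  # sequence 3, then sequence 0, then padding
position_ids = packed_position_ids(toy_mapping)
print(position_ids)  # [0, 1, 2, 0, 1, 0, 1]
```

The values assigned to padding slots do not matter as long as those slots are masked out of attention and the loss.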

In [None]:
# Create dataset and dataloader
dataset = PackedSequenceDataset(packed_sequences, sequence_mappings)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_packed_sequences
)

# Test the dataloader
batch = next(iter(dataloader))
print(f"Input shape: {batch['input_ids'].shape}")
print(f"Target shape: {batch['targets'].shape}")
print(f"Attention mask shape: {batch['attention_mask'].shape}")
print(f"Sequence mapping shape: {batch['sequence_mapping'].shape}")

# Visualize a sample attention mask from the batch
sample_mask = batch['attention_mask'][0]
visualize_attention_mask(sample_mask, "Sample Batch Attention Mask")
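
Because the targets are simply the inputs shifted by one position, some of them should not contribute to the loss: the target after the last token of one sequence is the first token of the next, and trailing targets are padding. The following is a minimal sketch of one way to handle this, assuming the loss is a cross-entropy that skips a special label value. `-100` is the default `ignore_index` of `torch.nn.functional.cross_entropy`; `mask_targets` and the toy token IDs are hypothetical.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def mask_targets(targets: List[int], input_mapping: List[int],
                 target_mapping: List[int]) -> List[int]:
    """Replace targets that should not contribute to the loss with IGNORE_INDEX."""
    labels = []
    for tok, src, dst in zip(targets, input_mapping, target_mapping):
        # Keep the target only if input and target come from the same real sequence
        keep = src == dst and dst != -1
        labels.append(tok if keep else IGNORE_INDEX)
    return labels


# Toy packed sequence: sequence 0 has 3 tokens, sequence 1 has 2, then 1 padding slot
tokens = [11, 12, 13, 21, 22, 0]
mapping = [0, 0, 0, 1, 1, -1]

# Same shift as in the dataset: inputs are tokens[:-1], targets are tokens[1:]
labels = mask_targets(tokens[1:], mapping[:-1], mapping[1:])
print(labels)  # [12, 13, -100, 22, -100]
```

The two masked entries are the boundary between sequence 0 and sequence 1, and the final padding target. In a tensor pipeline the same comparison can be done on `sequence_mapping` and its shifted copy before the targets are passed to the loss.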

## 5. Measuring Efficiency Gains

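Let's measure the efficiency gains from using packed sequences compared to a standard-padding baseline in which every sentence is padded to `max_seq_len` on its own. Efficiency here means the share of token slots that hold real tokens rather than padding.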

In [None]:
def create_standard_sequences(texts, tokenizer, max_seq_len=512):
    """Create standard padded sequences without packing."""
    # Tokenize all texts
    tokenized_texts = [tokenizer.encode(text) for text in texts]
    token_ids = [encoding.ids for encoding in tokenized_texts]
    
    # Get pad token ID
    pad_id = tokenizer.token_to_id("[PAD]")
    if pad_id is None:
        pad_id = 0  # Fallback
    
    # Pad sequences
    padded_sequences = []
    for ids in token_ids:
        # Truncate if too long
        if len(ids) > max_seq_len:
            ids = ids[:max_seq_len]
        
        # Pad if too short
        padding_needed = max_seq_len - len(ids)
        padded_ids = ids + [pad_id] * padding_needed
        padded_sequences.append(padded_ids)
    
    return padded_sequences

# Compare standard padding vs. sequence packing
standard_sequences = create_standard_sequences(sample_sentences, tokenizer, max_seq_len)

# Calculate efficiency metrics
total_tokens = sum(len(tokenizer.encode(s).ids) for s in sample_sentences)
standard_tokens = len(standard_sequences) * max_seq_len
packed_tokens = len(packed_sequences) * max_seq_len

standard_efficiency = total_tokens / standard_tokens * 100
packed_efficiency = total_tokens / packed_tokens * 100
efficiency_gain = packed_efficiency / standard_efficiency

print(f"Total tokens in original sentences: {total_tokens}")
print(f"Total tokens with standard padding: {standard_tokens}")
print(f"Total tokens with sequence packing: {packed_tokens}")
print(f"Standard padding efficiency: {standard_efficiency:.2f}%")
print(f"Sequence packing efficiency: {packed_efficiency:.2f}%")
print(f"Efficiency gain: {efficiency_gain:.2f}x")

# Visualize comparison
labels = ['Original', 'Standard Padding', 'Sequence Packing']
values = [total_tokens, standard_tokens, packed_tokens]

plt.figure(figsize=(10, 6))
bars = plt.bar(labels, values)
plt.title('Token Count Comparison')
plt.ylabel('Number of Tokens')
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Add efficiency labels
plt.text(1, values[1] * 1.05, f"{standard_efficiency:.1f}% Efficient", ha='center')
plt.text(2, values[2] * 1.05, f"{packed_efficiency:.1f}% Efficient", ha='center')

plt.show()
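
The efficiency figures above are a ratio of real tokens to total slots, which is easy to verify by hand on a small input. The sketch below repeats the calculation on toy token counts and adds one more baseline that this notebook does not use: padding every sentence to the longest one in the batch. The toy lengths and both helper names are assumptions for illustration, and `count_packed` assumes no sentence is longer than `max_seq_len`.

```python
from typing import List


def utilization(real_tokens: int, total_slots: int) -> float:
    """Share of slots that hold real tokens, in percent."""
    return 100.0 * real_tokens / total_slots


def count_packed(lengths: List[int], max_seq_len: int) -> int:
    """Count packed sequences for greedy, longest-first packing."""
    count, used = 0, 0
    for n in sorted(lengths, reverse=True):
        if used + n > max_seq_len and used > 0:
            count += 1  # close the current packed sequence
            used = 0
        used += n
    return count + (1 if used > 0 else 0)


toy_lengths = [7, 6, 5, 4, 3, 3, 2]  # hypothetical token counts per sentence
max_seq_len = 10
real_tokens = sum(toy_lengths)

# Baseline used in this section: every sentence is padded to max_seq_len
pad_to_max_slots = len(toy_lengths) * max_seq_len
# Alternative baseline: pad every sentence to the longest one in the batch
pad_to_longest_slots = len(toy_lengths) * max(toy_lengths)
# Packing: each packed sequence occupies max_seq_len slots
packed_slots = count_packed(toy_lengths, max_seq_len) * max_seq_len

print(f"Pad to max_seq_len: {utilization(real_tokens, pad_to_max_slots):.1f}%")      # 42.9%
print(f"Pad to longest:     {utilization(real_tokens, pad_to_longest_slots):.1f}%")  # 61.2%
print(f"Packing:            {utilization(real_tokens, packed_slots):.1f}%")          # 75.0%
```

On this toy input, packing uses 40 slots for 30 real tokens, against 70 slots when padding to `max_seq_len`. That is a 1.75x gain in the sense used above, and the gain is smaller when measured against padding to the longest sentence.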

## 6. Summary and Key Insights

In this notebook, we've explored sequence packing and masking techniques for efficient language model training:

1. **Sequence Packing**:
   - Combining multiple sequences into fixed-length packed sequences
   - Tracking sequence boundaries with mappings
   - Reducing the share of padding tokens, so less computation is spent on padding

2. **Attention Masking**:
   - Creating causal masks for autoregressive language modeling
   - Preventing attention across different packed sequences
   - Handling padding tokens appropriately

3. **Efficient Data Pipeline**:
   - Building a PyTorch dataset for packed sequences
   - Creating proper collate functions
   - Generating input-target pairs for language modeling

4. **Efficiency Analysis**:
   - Measuring token utilization with different approaches
   - Quantifying the benefits of sequence packing
   - Visualizing efficiency gains

These techniques help train large language models efficiently, especially when working with limited computational resources.