# Training TransformerLens on the mess3 Process

This notebook demonstrates how to:
1. Use the simplexity library to generate data from the mess3 Hidden Markov Model
2. Train a TransformerLens model on this data
3. Use TransformerLens's interpretability features to analyze what the model learned

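The mess3 process is a Hidden Markov Model with 3 hidden states and a 3-token vocabulary. Because the hidden state is never observed directly, predicting the next token well requires inferring it from the tokens seen so far.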

## Quick Start for Google Colab

**Option 1: Open this notebook directly in Colab**
- Go to: https://colab.research.google.com/github/Astera-org/simplexity/blob/MATS_2025_app/notebooks/train_transformerlens_mess3.ipynb

**Option 2: Download and run in any Colab notebook**
```python
# Download this notebook to Colab
!wget https://raw.githubusercontent.com/Astera-org/simplexity/MATS_2025_app/notebooks/train_transformerlens_mess3.ipynb
```

Then File → Open → Upload the downloaded notebook

## 1. Setup

Install the dependencies and import the libraries.

In [None]:
# Installation for Google Colab.
# Run this cell before importing anything.
%pip -q install --upgrade pip wheel setuptools

# Keep Colab's preinstalled NumPy 2.x to avoid binary (ABI) incompatibilities.
# Install the minimal pure-Python dependencies that TransformerLens actually uses.
%pip -q install "einops>=0.7.0" "jaxtyping>=0.2.28" "beartype>=0.14"

# Install TransformerLens with --no-deps so pip does not pull in or re-resolve heavy packages.
# This relies on Colab's preinstalled packages (e.g. torch and transformers) for the rest of its requirements.
# Version 2.16.1 or later supports Python 3.12+ and works with NumPy 2.x.
%pip -q install --no-deps "transformer-lens>=2.16.1"

# Install simplexity from the MATS_2025_app branch, without extras, so pip does not re-resolve TransformerLens
%pip -q install "git+https://github.com/Astera-org/simplexity.git@MATS_2025_app"

# Install better_abc (required by TransformerLens)
%pip -q install better_abc

print("✅ Installation complete! You can now run the import cell.")

In [None]:
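# Alternative: one-line installation with the TransformerLens extra.
# Run it on a freshly restarted runtime instead of the cell above, not in addition to it.
# Avoid pinning an old NumPy such as 1.24.x here: it has no wheels for Python 3.12+, so the install fails on those runtimes.
!pip install -q "simplexity[transformerlens] @ git+https://github.com/Astera-org/simplexity.git@MATS_2025_app"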

print("✅ Ready to go!")

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import jax
import jax.numpy as jnp

# TransformerLens imports
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformer_lens import utils as tl_utils

# Simplexity imports
from simplexity.generative_processes.builder import build_hidden_markov_model
from simplexity.generative_processes.torch_generator import generate_data_batch
from simplexity.predictive_models.transformerlens_model import TransformerLensWrapper

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Create the mess3 Process

Build the process with simplexity and inspect its basic properties.

In [None]:
# Create the mess3 process
mess3 = build_hidden_markov_model("mess3", x=0.15, a=0.6)

print(f"mess3 process created:")
print(f"  Vocabulary size: {mess3.vocab_size}")
print(f"  Number of states: {mess3.num_states}")
print(f"  Initial state shape: {mess3.initial_state.shape}")

# Get the stationary distribution
stationary_state = mess3.stationary_state
print(f"\nStationary distribution: {stationary_state}")
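
To make the `x` and `a` parameters concrete, here is a minimal NumPy sketch of mess3's labeled transition matrices. It is an assumption that simplexity uses exactly this convention: `T[k, i, j]` is the probability of emitting token `k` and moving to state `j` from state `i`; the hidden state stays put with probability `1 - 2x` and moves to each other state with probability `x`; the emitted token matches the new state with probability `a` and is each other token with probability `(1 - a) / 2`. `mess3_matrices` and `stationary_distribution` are hypothetical helpers, not simplexity functions.

```python
import numpy as np

def mess3_matrices(x: float, a: float) -> np.ndarray:
    """Labeled transition matrices T[k, i, j] = P(emit token k, move to state j | state i).

    ASSUMED parameterization (verify against simplexity's own builder):
    the hidden state stays put with probability 1 - 2x and moves to each other state
    with probability x; the emitted token equals the new state with probability a and
    is each of the other two tokens with probability (1 - a) / 2.
    """
    state_T = np.full((3, 3), x)
    np.fill_diagonal(state_T, 1.0 - 2.0 * x)   # P(next state j | state i)
    emit = np.full((3, 3), (1.0 - a) / 2.0)
    np.fill_diagonal(emit, a)                  # P(token k | new state j)
    # T[k, i, j] = P(j | i) * P(k | j)
    return np.einsum("ij,jk->kij", state_T, emit)

def stationary_distribution(T: np.ndarray) -> np.ndarray:
    """Stationary state distribution: left eigenvector of sum_k T[k] with eigenvalue 1."""
    M = T.sum(axis=0)  # marginalize out the emitted token
    eigvals, eigvecs = np.linalg.eig(M.T)
    v = np.real(eigvecs[:, np.argmin(np.abs(eigvals - 1.0))])
    return v / v.sum()

T_mess3 = mess3_matrices(x=0.15, a=0.6)
print(T_mess3.shape)                     # (vocab, states, states)
print(T_mess3.sum(axis=(0, 2)))          # total outgoing probability per state: all ones
print(stationary_distribution(T_mess3))  # uniform, by the symmetry of the three states
```

Under this assumed convention the state-to-state matrix is symmetric and doubly stochastic, so the stationary distribution is uniform whenever 0 < x ≤ 1/2; if simplexity uses the same convention, `stationary_state` above should be close to [1/3, 1/3, 1/3].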

## 3. Configure the TransformerLens Model

We'll create a small transformer (2 layers, 4 attention heads, d_model = 64) to predict the next token of mess3 sequences.

In [None]:
# Model configuration
model_config = {
    "d_model": 64,           # Model dimension
    "d_head": 16,            # Head dimension  
    "n_heads": 4,            # Number of attention heads
    "n_layers": 2,           # Number of transformer layers
    "n_ctx": 64,             # Context window
    "d_vocab": mess3.vocab_size,  # Vocabulary size (3 for mess3)
    "act_fn": "relu",        # Activation function
    "normalization_type": "LN",  # Layer normalization
    "device": str(device),
    "seed": 42,
    "attn_only": False,      # Include MLPs
    "use_cache": True,       # Enable caching for interpretability
    "use_hook_tokens": True, # Enable hook points for analysis
}

# Create the model using the wrapper
model = TransformerLensWrapper(**model_config)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {total_params:,} parameters")
print(f"Model configuration:")
for key, value in model_config.items():
    print(f"  {key}: {value}")

## 4. Data Generation Function

Create a function to generate training batches from the mess3 process.

In [None]:
def generate_training_batch(
    generator,
    batch_size: int,
    sequence_len: int,
    key: jax.Array,  # e.g. created with jax.random.PRNGKey(seed)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate a batch of training data from the mess3 process.
    
    Args:
        generator: The generative process (mess3)
        batch_size: Number of sequences in the batch
        sequence_len: Length of each sequence
        key: JAX random key
        
    Returns:
        Tuple of (inputs, labels) as PyTorch tensors
    """
    # Initialize generator states
    gen_state = generator.initial_state
    gen_states = jnp.repeat(gen_state[None, :], batch_size, axis=0)
    
    # Generate data using simplexity's torch_generator
    gen_states, inputs, labels = generate_data_batch(
        gen_states,
        generator,
        batch_size,
        sequence_len,
        key,
        bos_token=None,
        eos_token=None,
    )
    
    # Convert to PyTorch tensors
    inputs_torch = torch.from_numpy(np.array(inputs)).long().to(device)
    labels_torch = torch.from_numpy(np.array(labels)).long().to(device)
    
    return inputs_torch, labels_torch

# Test data generation
test_key = jax.random.PRNGKey(0)
test_inputs, test_labels = generate_training_batch(mess3, batch_size=4, sequence_len=10, key=test_key)
print(f"Generated batch shape: {test_inputs.shape}")
print(f"Sample sequence: {test_inputs[0].cpu().numpy()}")
print(f"Sample labels: {test_labels[0].cpu().numpy()}")
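
Conceptually, each training example is one sampled token sequence split into next-token pairs. The sketch below illustrates that with plain NumPy under two assumptions: sequences come from labeled transition matrices `T[k, i, j]` (probability of emitting token `k` and moving to state `j` from state `i`), and `generate_data_batch` returns `inputs = tokens[:, :-1]` and `labels = tokens[:, 1:]`. `sample_sequences` is a hypothetical helper, not part of simplexity.

```python
import numpy as np

def sample_sequences(T: np.ndarray, init_dist: np.ndarray, batch_size: int,
                     seq_len: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: sample token sequences from an HMM.

    T[k, i, j] = P(emit token k and move to state j | current state i).
    Returns an int array of shape (batch_size, seq_len).
    """
    _, num_states, _ = T.shape
    tokens = np.empty((batch_size, seq_len), dtype=np.int64)
    for b in range(batch_size):
        state = rng.choice(num_states, p=init_dist)
        for t in range(seq_len):
            # Sample (token, next state) jointly from the current state's slice.
            joint = T[:, state, :].ravel()  # flat index = token * num_states + next_state
            token, state = divmod(int(rng.choice(joint.size, p=joint)), num_states)
            tokens[b, t] = token
    return tokens

# Toy 2-state HMM: state 0 emits token 0 and moves to state 1;
# state 1 emits token 1 and moves to state 0.
T_toy = np.zeros((2, 2, 2))
T_toy[0, 0, 1] = 1.0
T_toy[1, 1, 0] = 1.0

rng = np.random.default_rng(0)
toy_tokens = sample_sequences(T_toy, np.array([1.0, 0.0]), batch_size=2, seq_len=6, rng=rng)
# Next-token pairs: toy_labels[:, t] is the token that follows toy_inputs[:, t].
toy_inputs, toy_labels = toy_tokens[:, :-1], toy_tokens[:, 1:]
print(toy_tokens[0])
```

The one-position offset assumed for `generate_data_batch` can be checked in the notebook with `(test_labels[:, :-1] == test_inputs[:, 1:]).all()` after the test batch above.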

## 5. Training Loop

Train the TransformerLens model on data from the mess3 process.

In [None]:
def train_transformerlens_on_mess3(
    model: TransformerLensWrapper,
    generator,
    num_steps: int = 1000,
    batch_size: int = 32,
    sequence_len: int = 64,
    learning_rate: float = 1e-3,
    log_every: int = 100,
    seed: int = 42,
):
    """Train TransformerLens model on mess3 data.
    
    Args:
        model: TransformerLens model wrapper
        generator: mess3 generative process
        num_steps: Number of training steps
        batch_size: Batch size
        sequence_len: Sequence length
        learning_rate: Learning rate for Adam optimizer
        log_every: Log metrics every N steps
        seed: Random seed
        
    Returns:
        List of losses
    """
    # Setup optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Initialize JAX random key
    key = jax.random.PRNGKey(seed)
    
    # Training metrics
    losses = []
    
    print(f"Starting training for {num_steps} steps...")
    print(f"Batch size: {batch_size}, Sequence length: {sequence_len}")
    
    model.train()
    
    for step in range(num_steps):
        # Generate new batch
        key, batch_key = jax.random.split(key)
        inputs, labels = generate_training_batch(
            generator, batch_size, sequence_len, batch_key
        )
        
        # Forward pass: with return_type="loss", HookedTransformer computes the next-token
        # cross-entropy internally, scoring the logits at position t against inputs[:, t + 1].
        # The separate `labels` tensor is therefore not used here.
        loss = model.model(inputs, return_type="loss")
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track loss
        loss_value = loss.item()
        losses.append(loss_value)
        
        # Logging
        if (step + 1) % log_every == 0:
            avg_loss = np.mean(losses[-log_every:])
            print(f"Step {step + 1}/{num_steps}: Loss = {loss_value:.4f}, Avg = {avg_loss:.4f}")
    
    print(f"\nTraining completed! Final loss: {losses[-1]:.4f}")
    return losses

# Train the model
training_losses = train_transformerlens_on_mess3(
    model=model,
    generator=mess3,
    num_steps=1000,
    batch_size=32,
    sequence_len=64,
    learning_rate=1e-3,
    log_every=100,
    seed=42,
)
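
For reference, a NumPy sketch of the shifted next-token cross-entropy that this training loss corresponds to (an illustration, not TransformerLens's source code), together with the uniform-prediction baseline of ln 3 ≈ 1.0986 nats for a three-token vocabulary. `shifted_cross_entropy` is a hypothetical helper.

```python
import numpy as np

def shifted_cross_entropy(logits: np.ndarray, tokens: np.ndarray) -> float:
    """Mean next-token cross-entropy in nats: logits[:, t] is scored against tokens[:, t + 1].

    logits: (batch, seq, vocab) float array; tokens: (batch, seq) int array.
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Drop the last position (no next token) and the first target (no preceding logit).
    pred = log_probs[:, :-1, :]
    targets = tokens[:, 1:]
    picked = np.take_along_axis(pred, targets[..., None], axis=-1)[..., 0]
    return float(-picked.mean())

# Identical logits everywhere mean a uniform prediction, so the loss equals ln(vocab_size).
toy_tokens = np.array([[0, 2, 1, 1, 0]])
uniform_logits = np.zeros((1, 5, 3))
print(shifted_cross_entropy(uniform_logits, toy_tokens))
```

To also score the final label that this shift discards, one could instead compute `logits = model.model(inputs)` and take the cross-entropy against `labels` at every position.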

## 6. Visualize Training Progress

In [None]:
# Plot training loss
plt.figure(figsize=(10, 5))
plt.plot(training_losses, alpha=0.7)
plt.xlabel('Training Step')
plt.ylabel('Cross-entropy loss (nats)')
plt.title('TransformerLens Training on mess3 Process')
plt.grid(True, alpha=0.3)

# Add smoothed curve
window_size = 50
if len(training_losses) > window_size:
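    smoothed = np.convolve(training_losses, np.ones(window_size)/window_size, mode='valid')
    # Center each averaged value on the middle of its window (correct for odd and even window sizes)
    x_smoothed = np.arange(len(smoothed)) + window_size // 2
    plt.plot(x_smoothed, smoothed, 'r-', linewidth=2, label='Smoothed')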
    plt.legend()

plt.show()

print(f"Initial loss: {training_losses[0]:.4f}")
print(f"Final loss: {training_losses[-1]:.4f}")
print(f"Improvement: {(training_losses[0] - training_losses[-1]) / training_losses[0] * 100:.1f}%")

## 7. Model Evaluation

Evaluate the trained model on fresh data from the mess3 process.

In [None]:
def evaluate_model(model, generator, num_batches=10, batch_size=32, sequence_len=64, seed=999):
    """Evaluate model on fresh data."""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0
    
    key = jax.random.PRNGKey(seed)
    
    with torch.no_grad():
        for _ in range(num_batches):
            key, batch_key = jax.random.split(key)
            inputs, labels = generate_training_batch(
                generator, batch_size, sequence_len, batch_key
            )
            
            # Get model predictions
            logits = model(inputs, return_loss=False)
            
            # Calculate loss
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                inputs[:, 1:].reshape(-1)
            )
            total_loss += loss.item()
            
            # Calculate accuracy
            predictions = logits[:, :-1].argmax(dim=-1)
            correct = (predictions == inputs[:, 1:]).sum().item()
            total_correct += correct
            total_tokens += inputs[:, 1:].numel()
    
    avg_loss = total_loss / num_batches
    accuracy = total_correct / total_tokens
    
    model.train()
    return avg_loss, accuracy

# Evaluate the model
val_loss, val_accuracy = evaluate_model(model, mess3)
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.2%}")

# Compare with random baseline
random_accuracy = 1.0 / mess3.vocab_size
print(f"\nRandom baseline accuracy: {random_accuracy:.2%}")
print(f"Improvement over random: {(val_accuracy - random_accuracy) / random_accuracy * 100:.1f}%")
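
Accuracy is hard to read on its own because mess3 is stochastic: even a predictor that knows the process exactly cannot reach 100%. The sketch below computes that ceiling by Bayesian filtering, assuming access to labeled transition matrices `T[k, i, j]` and an initial state distribution: it updates a belief over hidden states after each token, scores the true next token under the predicted distribution, and predicts the most likely token. `optimal_next_token_metrics` is a hypothetical helper.

```python
import numpy as np

def optimal_next_token_metrics(T: np.ndarray, init_belief: np.ndarray, tokens: np.ndarray):
    """Loss (nats) and accuracy of the Bayes-optimal next-token predictor.

    T[k, i, j] = P(emit token k, move to state j | state i); init_belief is the
    distribution over the hidden state before the first token; tokens is (batch, seq).
    Like the shifted transformer loss, it scores the predictions of tokens[:, 1:].
    """
    total_nll, correct, count = 0.0, 0, 0
    for seq in tokens:
        belief = init_belief.astype(float)
        for t, tok in enumerate(seq):
            # Predictive distribution: P(k | past) = sum_{i,j} belief[i] * T[k, i, j]
            probs = np.einsum("i,kij->k", belief, T)
            if t > 0:  # the first token has no preceding context in the transformer loss
                total_nll -= np.log(probs[tok])
                correct += int(np.argmax(probs) == tok)
                count += 1
            # Bayesian update on the observed token, then renormalize.
            belief = belief @ T[tok]
            belief /= belief.sum()
    return total_nll / count, correct / count

# Toy check: a 2-state HMM that emits 0, 1, 0, 1, ... deterministically.
T_toy = np.zeros((2, 2, 2))
T_toy[0, 0, 1] = 1.0  # state 0 emits token 0 and moves to state 1
T_toy[1, 1, 0] = 1.0  # state 1 emits token 1 and moves to state 0
toy_tokens = np.array([[0, 1, 0, 1], [1, 0, 1, 0]])
opt_loss, opt_acc = optimal_next_token_metrics(T_toy, np.array([0.5, 0.5]), toy_tokens)
print(opt_loss, opt_acc)  # after one token the sequence is fully predictable
```

In the notebook one could call it on `inputs.cpu().numpy()` from a fresh batch, with `T` set to mess3's labeled transition matrices (built under an assumed parameterization, or read from simplexity if it exposes them) and `init_belief` set to the process's initial state distribution. The gap between `val_loss` and the returned loss then measures how far the transformer is from optimal.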

## 8. Interpretability Analysis with TransformerLens

Use TransformerLens's built-in features to analyze what the model learned.

In [None]:
# Generate a sample sequence for analysis
model.eval()
analysis_key = jax.random.PRNGKey(123)
sample_inputs, _ = generate_training_batch(mess3, batch_size=1, sequence_len=20, key=analysis_key)

print(f"Analyzing sequence: {sample_inputs[0].cpu().numpy()}")

# Run with cache to get all activations
logits, cache = model.run_with_cache(sample_inputs)

# Get predictions
predictions = logits[0].argmax(dim=-1).cpu().numpy()
print(f"Model predictions: {predictions[:-1]}")
print(f"Actual next tokens: {sample_inputs[0, 1:].cpu().numpy()}")

## 9. Attention Pattern Visualization

Visualize the attention patterns learned by the model.

In [None]:
# Extract attention patterns for each layer
n_layers = model.config.n_layers
n_heads = model.config.n_heads

# squeeze=False always returns a 2D array of axes, even when n_layers or n_heads is 1
fig, axes = plt.subplots(n_layers, n_heads, figsize=(n_heads * 4, n_layers * 4), squeeze=False)

for layer in range(n_layers):
    # Get attention patterns for this layer
    attn_patterns = cache["pattern", layer][0].cpu().numpy()  # Shape: (n_heads, seq_len, seq_len)
    
    for head in range(n_heads):
        ax = axes[layer, head]
        
        # Plot attention pattern (rows: query positions, columns: key positions)
        im = ax.imshow(attn_patterns[head], cmap='Blues', aspect='auto')
        ax.set_title(f'Layer {layer}, Head {head}')
        ax.set_xlabel('Key (source) position')
        ax.set_ylabel('Query (destination) position')
        plt.colorbar(im, ax=ax, fraction=0.046)

plt.suptitle('Attention Patterns Across Layers and Heads', fontsize=16)
plt.tight_layout()
plt.show()

# Analyze attention focus
for layer in range(n_layers):
    attn_patterns = cache["pattern", layer][0].cpu().numpy()
    avg_attention = attn_patterns.mean(axis=0)  # Average across heads
    
    # Calculate entropy to measure attention focus
    entropy = -np.sum(avg_attention * np.log(avg_attention + 1e-10), axis=-1)
    print(f"Layer {layer} - Average attention entropy: {entropy.mean():.3f}")
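
One caveat on these entropy numbers: with a causal mask, query position `i` can attend to only `i + 1` keys, so its entropy is at most `ln(i + 1)`, and early positions pull the average down regardless of what the heads learned. Dividing by that maximum gives a score in [0, 1], where 1 means uniform attention over the visible prefix and 0 means all attention on one key. `normalized_causal_entropy` is a hypothetical NumPy helper for a single `(seq_len, seq_len)` pattern.

```python
import numpy as np

def normalized_causal_entropy(pattern: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """Per-query attention entropy divided by its causal maximum ln(i + 1).

    pattern: (seq_len, seq_len) array whose row i is a distribution over keys 0..i.
    Position 0 can only attend to itself, so its score is defined as 0.
    """
    seq_len = pattern.shape[0]
    entropy = -np.sum(pattern * np.log(pattern + eps), axis=-1)
    max_entropy = np.log(np.arange(1, seq_len + 1))  # ln(1), ln(2), ..., ln(seq_len)
    scores = np.zeros(seq_len)
    scores[1:] = entropy[1:] / max_entropy[1:]
    return scores

# Uniform causal attention: each query spreads evenly over its prefix.
n = 5
uniform_pattern = np.tril(np.ones((n, n)))
uniform_pattern /= uniform_pattern.sum(axis=-1, keepdims=True)
print(normalized_causal_entropy(uniform_pattern).round(3))
```

Applied to `attn_patterns[head]` from the cache, scores near 1 indicate a head spreading attention across its whole prefix, and scores near 0 a head focusing on a single position.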

## 10. Analyze Model's Internal Representations

Examine the model's learned token embeddings and extract its final-layer residual stream, the internal representation from which next-token logits are computed.

In [None]:
# Get embeddings and residual stream activations
embed = cache["embed"][0].cpu().numpy()  # Token embeddings
final_residual = cache["resid_post", n_layers - 1][0].cpu().numpy()  # Final layer residual

# Analyze token embeddings
print("Token Embedding Analysis:")
print(f"Embedding shape: {embed.shape}")
print(f"Final residual shape: {final_residual.shape}")

# Calculate cosine similarity between token embeddings
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# Get unique token embeddings
unique_tokens = list(range(mess3.vocab_size))
token_embeds = model.model.embed.W_E.detach().cpu().numpy()

print("\nCosine similarity between token embeddings:")
sim_matrix = np.zeros((len(unique_tokens), len(unique_tokens)))
for i, token_i in enumerate(unique_tokens):
    for j, token_j in enumerate(unique_tokens):
        sim = cosine_similarity(token_embeds[token_i], token_embeds[token_j])
        sim_matrix[i, j] = sim
        if i < j:
            print(f"  Token {token_i} vs Token {token_j}: {sim:.3f}")

# Visualize similarity matrix
plt.figure(figsize=(6, 5))
plt.imshow(sim_matrix, cmap='coolwarm', vmin=-1, vmax=1)
plt.colorbar(label='Cosine Similarity')
plt.xlabel('Token')
plt.ylabel('Token')
plt.title('Token Embedding Similarity Matrix')
plt.xticks(range(len(unique_tokens)), unique_tokens)
plt.yticks(range(len(unique_tokens)), unique_tokens)
for i in range(len(unique_tokens)):
    for j in range(len(unique_tokens)):
        plt.text(j, i, f'{sim_matrix[i, j]:.2f}', ha='center', va='center')
plt.show()
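
The embedding similarities describe tokens, not hidden states. One way to test whether the residual stream tracks mess3's hidden state is a linear probe: regress activations (for example `final_residual` gathered over many sequences and flattened to `(n_positions, d_model)`) onto ground-truth belief states from a Bayesian filter, and measure the variance explained. Below is a minimal least-squares probe in NumPy; `fit_linear_probe` is a hypothetical helper, and the demo uses synthetic data only to check the arithmetic.

```python
import numpy as np

def fit_linear_probe(acts: np.ndarray, targets: np.ndarray):
    """Hypothetical helper: least-squares affine probe from activations to targets.

    acts: (n_samples, d_model); targets: (n_samples, n_outputs), e.g. belief states.
    Returns (weights, bias, r2), where r2 is the fraction of target variance explained.
    """
    X = np.hstack([acts, np.ones((acts.shape[0], 1))])  # extra column of ones for the bias
    coef, *_ = np.linalg.lstsq(X, targets, rcond=None)
    preds = X @ coef
    ss_res = np.sum((targets - preds) ** 2)
    ss_tot = np.sum((targets - targets.mean(axis=0)) ** 2)
    return coef[:-1], coef[-1], 1.0 - ss_res / ss_tot

# Sanity check on synthetic data: exactly affine targets should give r2 close to 1.
rng = np.random.default_rng(0)
toy_acts = rng.normal(size=(200, 64))
toy_targets = toy_acts @ rng.normal(size=(64, 3)) + 0.5
probe_W, probe_b, r2 = fit_linear_probe(toy_acts, toy_targets)
print(f"R^2 = {r2:.4f}")
```

In practice, fit the probe on one set of sequences and report R² on held-out ones: with 64 features, even random targets receive a nonzero in-sample R².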

## 11. Probe Model's Understanding of mess3 Dynamics

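Measure how the model's predictions depend on the current token. These are first-order (bigram) token statistics: since mess3's hidden state must be inferred from the whole history, they summarize the model's behavior rather than recover the HMM's hidden transition structure.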

In [None]:
def analyze_transition_predictions(model, generator, num_samples=1000, seed=456):
    """Analyze how well the model predicts transitions."""
    model.eval()
    
    # Track predictions for each token pair
    transition_counts = np.zeros((generator.vocab_size, generator.vocab_size))
    transition_correct = np.zeros((generator.vocab_size, generator.vocab_size))
    
    key = jax.random.PRNGKey(seed)
    
    with torch.no_grad():
        for _ in range(num_samples // 32):
            key, batch_key = jax.random.split(key)
            inputs, _ = generate_training_batch(generator, 32, 64, batch_key)
            
            logits = model(inputs, return_loss=False)
            predictions = logits.argmax(dim=-1)
            
            # Count transitions
            for i in range(inputs.shape[1] - 1):
                current_tokens = inputs[:, i].cpu().numpy()
                next_tokens = inputs[:, i + 1].cpu().numpy()
                predicted_tokens = predictions[:, i].cpu().numpy()
                
                for curr, next_tok, pred in zip(current_tokens, next_tokens, predicted_tokens):
                    transition_counts[curr, next_tok] += 1
                    if pred == next_tok:
                        transition_correct[curr, next_tok] += 1
    
    # Calculate accuracy for each transition
    transition_accuracy = np.divide(
        transition_correct,
        transition_counts,
        out=np.zeros_like(transition_correct),
        where=transition_counts > 0
    )
    
    # Normalize to get empirical transition probabilities
    empirical_transitions = transition_counts / transition_counts.sum(axis=1, keepdims=True)
    
    return transition_accuracy, empirical_transitions, transition_counts

# Analyze transitions
trans_acc, emp_trans, trans_counts = analyze_transition_predictions(model, mess3)

# Visualize transition accuracy
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot empirical transitions
im1 = axes[0].imshow(emp_trans, cmap='Blues', vmin=0, vmax=1)
axes[0].set_title('Empirical Transition Probabilities\n(from generated data)')
axes[0].set_xlabel('Next Token')
axes[0].set_ylabel('Current Token')
plt.colorbar(im1, ax=axes[0])
for i in range(mess3.vocab_size):
    for j in range(mess3.vocab_size):
        axes[0].text(j, i, f'{emp_trans[i, j]:.2f}', ha='center', va='center')

# Plot prediction accuracy
im2 = axes[1].imshow(trans_acc, cmap='RdYlGn', vmin=0, vmax=1)
axes[1].set_title('Model Prediction Accuracy\nfor Each Transition')
axes[1].set_xlabel('Next Token')
axes[1].set_ylabel('Current Token')
plt.colorbar(im2, ax=axes[1])
for i in range(mess3.vocab_size):
    for j in range(mess3.vocab_size):
        if trans_counts[i, j] > 0:
            axes[1].text(j, i, f'{trans_acc[i, j]:.2f}', ha='center', va='center')

# Plot count distribution
im3 = axes[2].imshow(trans_counts, cmap='YlOrRd')
axes[2].set_title('Transition Counts\n(total observations)')
axes[2].set_xlabel('Next Token')
axes[2].set_ylabel('Current Token')
plt.colorbar(im3, ax=axes[2])
for i in range(mess3.vocab_size):
    for j in range(mess3.vocab_size):
        axes[2].text(j, i, f'{int(trans_counts[i, j])}', ha='center', va='center')

plt.tight_layout()
plt.show()

# Summary statistics (correct predictions per transition = accuracy × count)
trans_correct = trans_acc * trans_counts
overall_accuracy = trans_correct.sum() / trans_counts.sum()
print(f"\nOverall transition prediction accuracy: {overall_accuracy:.2%}")
print(f"\nPer-token accuracy:")
for i in range(mess3.vocab_size):
    token_acc = trans_correct[i].sum() / trans_counts[i].sum() if trans_counts[i].sum() > 0 else 0
    print(f"  Token {i}: {token_acc:.2%}")
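
The nested Python loop in `analyze_transition_predictions` copies data to the CPU once per position. A vectorized alternative, sketched here in NumPy with the same definitions (counts indexed by current and next token; a prediction is correct when the argmax matches the next token), accumulates all pairs at once with `np.add.at`. `count_transitions` is a hypothetical helper.

```python
import numpy as np

def count_transitions(inputs: np.ndarray, predictions: np.ndarray, vocab_size: int):
    """Bigram counts and correct-prediction counts, indexed [current token, next token].

    inputs, predictions: (batch, seq) int arrays; predictions[:, t] is the model's
    guess for inputs[:, t + 1].
    """
    current = inputs[:, :-1].ravel()
    nxt = inputs[:, 1:].ravel()
    hit = (predictions[:, :-1].ravel() == nxt).astype(float)
    counts = np.zeros((vocab_size, vocab_size))
    correct = np.zeros((vocab_size, vocab_size))
    # np.add.at is unbuffered, so repeated (current, next) pairs are all accumulated.
    np.add.at(counts, (current, nxt), 1.0)
    np.add.at(correct, (current, nxt), hit)
    return counts, correct

toy_inputs = np.array([[0, 1, 1, 2], [2, 2, 0, 1]])
toy_preds = np.array([[1, 0, 2, 0], [2, 1, 1, 0]])
toy_counts, toy_correct = count_transitions(toy_inputs, toy_preds, vocab_size=3)
print(toy_counts)
print(toy_correct)
```

With the notebook's tensors, one would pass `inputs.cpu().numpy()` and `predictions.cpu().numpy()` once per batch and sum the returned arrays across batches.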

## 12. Summary and Conclusions

This notebook demonstrated:
1. How to use simplexity to generate data from the mess3 Hidden Markov Model
2. How to train a TransformerLens model on this synthetic data
3. How to use TransformerLens's interpretability features to understand what the model learned

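Key observations:
- The model learns to predict the next token in mess3 sequences, though its loss stays above zero because the process is stochastic
- Attention patterns show which earlier positions each head draws information from
- The token embeddings and residual stream provide a starting point for testing whether the model's internal representations capture the structure of the underlying HMM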

This approach can be extended to:
- Other generative processes in simplexity
- Larger transformer models
- More complex analysis using TransformerLens's advanced features

In [None]:
# Save the trained model if needed
save_model = False  # Set to True to save

if save_model:
    model_path = "transformerlens_mess3_model.pt"
    torch.save({
        'model_state_dict': model.model.state_dict(),
        'config': model_config,
        'training_losses': training_losses,
    }, model_path)
    print(f"Model saved to {model_path}")