# TransformerLens + tom_quantum × mess3: Simple Training Example

Train a small transformer on the product of the tom_quantum and mess3 Hidden Markov Models using TransformerLens, then inspect its residual stream with PCA.

In [None]:
# Cell 1: Installation (skip if already installed)
%pip -q install --upgrade pip wheel setuptools
%pip -q install "einops>=0.7.0" "jaxtyping>=0.2.28" "beartype>=0.14" better_abc
%pip -q install --no-deps "transformer-lens>=2.16.1"
%pip -q install "git+https://github.com/Astera-org/simplexity.git@MATS_2025_app"
print("✅ Ready!")

In [None]:
# Cell 2: Setup - Product of tom_quantum and mess3
import torch
import torch.nn.functional as F
import numpy as np
import jax
import jax.numpy as jnp
from transformer_lens import HookedTransformer, HookedTransformerConfig
from simplexity.generative_processes.builder import build_hidden_markov_model
from simplexity.generative_processes.torch_generator import generate_data_batch
from simplexity.generative_processes.hidden_markov_model import HiddenMarkovModel
from simplexity.generative_processes.transition_matrices import tom_quantum as tom_quantum_matrix

# Create mess3 process
mess3 = build_hidden_markov_model("mess3", x=0.15, a=0.6)

# Create tom_quantum manually since it's not in the builder
tom_transition = tom_quantum_matrix(alpha=1, beta=51)
tom_quantum = HiddenMarkovModel(transition_matrices=tom_transition)

print(f"mess3: vocab_size={mess3.vocab_size}, states={mess3.num_states}")
print(f"tom_quantum: vocab_size={tom_quantum.vocab_size}, states={tom_quantum.num_states}")

# The product process emits one token per (tom_quantum symbol, mess3 symbol)
# pair, so its vocabulary is the product of the two vocabulary sizes. The
# actual sizes are whatever the two processes report (printed above).
product_vocab_size = tom_quantum.vocab_size * mess3.vocab_size
print(f"Product space: vocab_size={product_vocab_size}")

# Create TransformerLens model for product space
device = "cuda" if torch.cuda.is_available() else "cpu"
cfg = HookedTransformerConfig(
    d_model=128,  # Bigger model for more complex data
    d_head=32,  # d_model // n_heads; d_head is a required config field
    n_heads=4,
    n_layers=2,
    n_ctx=32,
    d_vocab=product_vocab_size,  # one logit per product-space token
    act_fn="relu",
    device=device,
)
model = HookedTransformer(cfg)
print(f"Model: {sum(p.numel() for p in model.parameters()):,} params on {device}")

In [None]:
# Cell 3: Train on Product Space
from tqdm import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = []
batch_size, seq_len = 32, 32
key = jax.random.PRNGKey(42)

# Get stationary distributions
tom_stationary = tom_quantum.stationary_state
mess3_stationary = mess3.stationary_state

# Every sequence in every batch starts from the stationary distribution, so
# the initial-state batches are the same on each step: build them once.
tom_states = jnp.repeat(tom_stationary[None, :], batch_size, axis=0)
mess3_states = jnp.repeat(mess3_stationary[None, :], batch_size, axis=0)

# Radix used to pack a (tom, mess3) symbol pair into a single token id.
# Taken from the process itself rather than hard-coded so the packing stays
# consistent with product_vocab_size.
mess3_vocab_size = int(mess3.vocab_size)

for step in tqdm(range(1000), desc="Training"):  # More steps for harder task
    # Generate NEW batch each time: one fresh PRNG key per process
    key, key1, key2 = jax.random.split(key, 3)

    # Sample both component processes independently
    _, tom_inputs, _ = generate_data_batch(tom_states, tom_quantum, batch_size, seq_len, key1)
    _, mess3_inputs, _ = generate_data_batch(mess3_states, mess3, batch_size, seq_len, key2)

    # The generator's return type is handled both ways: torch tensors are
    # moved to host memory first, anything else is converted via numpy.
    if isinstance(tom_inputs, torch.Tensor):
        tom_arr = tom_inputs.cpu().numpy()
        mess3_arr = mess3_inputs.cpu().numpy()
    else:
        tom_arr = np.array(tom_inputs)
        mess3_arr = np.array(mess3_inputs)

    # Combine into product space: token = tom * |mess3 vocab| + mess3,
    # i.e. (tom=0, mess3=0) -> 0, (0, 1) -> 1, ..., (1, 0) -> |mess3 vocab|, ...
    product_tokens = tom_arr * mess3_vocab_size + mess3_arr
    tokens = torch.from_numpy(product_tokens).long().to(device)

    # Train step (return_type="loss" makes the model return its LM loss)
    loss = model(tokens, return_type="loss")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"\nFinal loss: {losses[-1]:.4f} (started at {losses[0]:.4f})")

# Plot the per-step training loss recorded by the loop above
import matplotlib.pyplot as plt
loss_axis = plt.gca()
loss_axis.plot(losses)
loss_axis.set_xlabel('Step')
loss_axis.set_ylabel('Loss')
loss_axis.set_title('Training Loss on Product Space (tom_quantum × mess3)')
plt.show()

In [None]:
# Cell 4: Extract activations from residual stream (Product Space)
from sklearn.decomposition import PCA
import plotly.graph_objects as go

# Generate a batch for analysis
n_analysis_seqs = 100  # number of sequences sampled for the PCA plots
# Radix that packs a (tom, mess3) symbol pair into one token id; also used
# below to split token ids back into their two components.
mess3_vocab_size = int(mess3.vocab_size)

key, key1, key2 = jax.random.split(key, 3)

# Generate product space data (each sequence starts from the stationary state)
tom_states = jnp.repeat(tom_stationary[None, :], n_analysis_seqs, axis=0)
_, tom_inputs, _ = generate_data_batch(tom_states, tom_quantum, n_analysis_seqs, seq_len, key1)

mess3_states = jnp.repeat(mess3_stationary[None, :], n_analysis_seqs, axis=0)
_, mess3_inputs, _ = generate_data_batch(mess3_states, mess3, n_analysis_seqs, seq_len, key2)

# Convert and combine
if isinstance(tom_inputs, torch.Tensor):
    tom_arr = tom_inputs.cpu().numpy()
    mess3_arr = mess3_inputs.cpu().numpy()
else:
    tom_arr = np.array(tom_inputs)
    mess3_arr = np.array(mess3_inputs)

product_tokens_np = tom_arr * mess3_vocab_size + mess3_arr
tokens = torch.from_numpy(product_tokens_np).long().to(device)

# Run with cache to get all activations. This is inference only, so autograd
# is disabled to avoid building a graph that is never used.
with torch.no_grad():
    logits, cache = model.run_with_cache(tokens)

# Extract residual stream activations
residual_streams = {
    'embeddings': cache['hook_embed'],
    'layer_0': cache['blocks.0.hook_resid_post'],
    'layer_1': cache['blocks.1.hook_resid_post'],
}

print("Activation shapes:")
for name, acts in residual_streams.items():
    print(f"  {name}: {acts.shape}")

# Flatten for PCA: collapse all leading axes so each row is one activation
# vector of size acts.shape[-1]
activations_flat = {
    name: acts.reshape(-1, acts.shape[-1]).cpu().numpy()
    for name, acts in residual_streams.items()
}

# Create labels for visualization (flattened in the same order as the
# activations, so row i of activations_flat[...] has label token_labels[i])
token_labels = tokens.flatten().cpu().numpy()
tom_labels = token_labels // mess3_vocab_size  # Extract tom component
mess3_labels = token_labels % mess3_vocab_size   # Extract mess3 component

print(f"\nTotal points for PCA: {activations_flat['layer_1'].shape[0]}")
# minlength keeps one count per vocabulary entry even if the highest token
# ids happen not to appear in this sample.
print(f"Token distribution: {np.bincount(token_labels, minlength=int(product_vocab_size))}")

In [None]:
# Cell 5: 2D PCA visualization of Product Space
%pip -q install plotly scikit-learn

# Perform PCA on the final layer activations
pca = PCA(n_components=2)
pca_coords = pca.fit_transform(activations_flat['layer_1'])

print(f"PCA explained variance ratio: {pca.explained_variance_ratio_}")
print(f"Total variance explained: {sum(pca.explained_variance_ratio_):.2%}")

# Create interactive 2D plot with plotly
fig = go.Figure()

# Color by full token (6 colors)
import matplotlib.cm as cm
colors = cm.tab10(np.linspace(0, 0.6, 6))

for token_id in range(6):
    mask = token_labels == token_id
    tom_val = token_id // 3
    mess3_val = token_id % 3
    
    fig.add_trace(go.Scatter(
        x=pca_coords[mask, 0],
        y=pca_coords[mask, 1],
        mode='markers',
        name=f'({tom_val}, {mess3_val})',
        marker=dict(
            size=5,
            color=f'rgb({int(colors[token_id][0]*255)},{int(colors[token_id][1]*255)},{int(colors[token_id][2]*255)})',
            opacity=0.6,
        ),
        text=[f'tom={tom_val}, mess3={mess3_val}' for _ in range(mask.sum())],
        hovertemplate='%{text}<br>PC1: %{x:.2f}<br>PC2: %{y:.2f}'
    ))

fig.update_layout(
    title='2D PCA of Product Space (tom_quantum × mess3)',
    xaxis_title='PC1',
    yaxis_title='PC2',
    height=600,
    width=800,
    showlegend=True
)

fig.show()

In [None]:
# Cell 6: Compare PCA across layers (2D)
import matplotlib.pyplot as plt

layer_names = ['embeddings', 'layer_0', 'layer_1']

# One color per product-space token id. Every token of the product vocabulary
# is plotted; iterating over a fixed 3 ids would drop all tokens with id >= 3.
n_tokens = int(product_vocab_size)
cmap = plt.get_cmap('tab10' if n_tokens <= 10 else 'tab20')
colors = [cmap(token_id % cmap.N) for token_id in range(n_tokens)]

fig, axes = plt.subplots(1, len(layer_names), figsize=(15, 5))

# Fraction of variance captured by the first two PCs, per layer. Stored from
# the fit used for the plot so the printed numbers match the figure titles
# without fitting PCA a second time.
variance_explained = {}

for idx, layer_name in enumerate(layer_names):
    # PCA for this layer
    pca = PCA(n_components=2)
    pca_coords = pca.fit_transform(activations_flat[layer_name])
    variance_explained[layer_name] = sum(pca.explained_variance_ratio_)

    # Plot each token type
    for token_id in range(n_tokens):
        mask = token_labels == token_id
        axes[idx].scatter(
            pca_coords[mask, 0],
            pca_coords[mask, 1],
            color=colors[token_id],
            label=f'Token {token_id}',
            alpha=0.5,
            s=10
        )

    axes[idx].set_title(f'{layer_name}\n({variance_explained[layer_name]:.1%} var)')
    axes[idx].set_xlabel('PC1')
    axes[idx].set_ylabel('PC2')
    axes[idx].legend(fontsize='small')
    axes[idx].grid(True, alpha=0.3)

plt.suptitle('Token Representations Across Layers (2D PCA)', fontsize=14)
plt.tight_layout()
plt.show()

# Show variance explained progression
print("Variance explained by first 2 PCs at each layer:")
for layer_name in layer_names:
    print(f"  {layer_name}: {variance_explained[layer_name]:.2%}")