# FSM Experiment Analysis

Deep-dive analysis of the Finite State Machine validation experiment.

This notebook allows you to:
- Generate FSM data and visualize the state transitions
- Train a model and observe crystallization
- Analyze code-state alignment
- Compare training with and without temperature annealing

In [None]:
# Setup: put the repository root on sys.path so the project packages
# (`src`, `experiments`, `analysis`) can be imported from this notebook.
# NOTE(review): assumes the notebook runs two directories below the repo root.
import sys
from pathlib import Path

project_root = Path.cwd().parent.parent
root_str = str(project_root)
if root_str not in sys.path:
    # Prepend so the local project shadows any installed copy.
    sys.path.insert(0, root_str)

print(f"Project root: {project_root}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from experiments.fsm.generate_data import FSMConfig, FiniteStateMachine, compute_state_code_alignment
from src.model import CrystallineTransformer
from src.config import ModelConfig, BottleneckConfig

from analysis import setup_style, COLORS

# Plotly is optional: PLOTLY records whether interactive figures are available;
# otherwise later cells fall back to static matplotlib figures.
try:
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
except ImportError:
    PLOTLY = False
    print("Plotly not available - using matplotlib")
else:
    PLOTLY = True

## 1. Generate FSM Data

In [None]:
# Configure the FSM. The vocabulary size is NUM_STATES * TOKENS_PER_STATE
# (presumably a disjoint token set per state — confirm in generate_data).
NUM_STATES = 8
TOKENS_PER_STATE = 3
SEED = 42

fsm_config = FSMConfig(
    num_states=NUM_STATES,
    tokens_per_state=TOKENS_PER_STATE,
    vocab_size=NUM_STATES * TOKENS_PER_STATE,
    seed=SEED,
)
fsm = FiniteStateMachine(fsm_config)

config_report = [
    "FSM Configuration:",
    f"  States: {NUM_STATES}",
    f"  Tokens per state: {TOKENS_PER_STATE}",
    f"  Vocabulary size: {fsm_config.vocab_size}",
]
print("\n".join(config_report))

In [None]:
# Heatmap of the FSM transition matrix (rows: current state, columns: next
# state). Every nonzero transition is annotated with a '1'.
setup_style('notebook')

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(fsm.transition_matrix, cmap='Blues')
ax.set(
    xlabel='Next State',
    ylabel='Current State',
    title='FSM Transition Matrix (deterministic)',
    xticks=range(NUM_STATES),
    yticks=range(NUM_STATES),
)
plt.colorbar(im, ax=ax)

# np.nonzero yields (row, col) pairs in row-major order over the displayed grid.
transition_grid = np.asarray(fsm.transition_matrix)[:NUM_STATES, :NUM_STATES]
for row, col in zip(*np.nonzero(transition_grid > 0)):
    ax.text(col, row, '1', ha='center', va='center', color='white', fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
# Draw a small demo batch to inspect the data layout. `seq_len` is reused
# below as the model's max_seq_len and as the training sequence length.
batch_size = 4
seq_len = 32

inputs, targets, states = fsm.generate_batch(batch_size, seq_len)

for shape_label, shape_src in (("Input", inputs), ("Target", targets), ("States", states)):
    print(f"{shape_label} shape: {shape_src.shape}")

print("\nSample sequence (first batch):")
print(f"  Tokens: {inputs[0, :10].tolist()}")
print(f"  States: {states[0, :10].tolist()}")

## 2. Create Model

In [None]:
# Bottleneck settings: a deliberately small codebook (32 entries) to
# encourage codes to map onto FSM states.
bottleneck_config = BottleneckConfig(
    codebook_size=32,
    num_codes_k=4,
    temp_init=2.0,
    temp_min=0.1,
)

# Small transformer sized for the toy FSM task; dropout disabled.
model_config = ModelConfig(
    vocab_size=fsm_config.vocab_size,
    dim=128,
    n_layers=3,
    n_heads=4,
    max_seq_len=seq_len,
    dropout=0.0,
    bottleneck=bottleneck_config,
)

model = CrystallineTransformer(model_config)
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")

## 3. Quick Training Loop

Train for a short period while annealing the bottleneck temperature, to observe crystallization.

In [None]:
import torch.nn.functional as F
from torch.optim import AdamW
from src.losses import compression_loss, commitment_loss

# Optimization hyperparameters and regularizer weights.
N_STEPS = 500
BATCH_SIZE = 32
LR = 3e-4
LAMBDA_COMPRESS = 0.01
LAMBDA_COMMIT = 0.25

# The bottleneck temperature is annealed linearly from TEMP_START to TEMP_END.
TEMP_START = 2.0
TEMP_END = 0.3

optimizer = AdamW(model.parameters(), lr=LR)

# Metric history filled in by the training loop (one entry per logged step).
history = {metric: [] for metric in ('steps', 'loss', 'accuracy', 'temperature', 'entropy')}

In [None]:
# Training loop: anneal the bottleneck temperature linearly from TEMP_START to
# TEMP_END while minimizing next-token cross-entropy plus the compression and
# commitment regularizers, each averaged over all attn/mlp bottlenecks.
LOG_EVERY = 50  # metrics are logged every LOG_EVERY steps and on the final step

model.train()
torch.manual_seed(SEED)

for step in range(N_STEPS):
    # Linear schedule: progress runs 0 -> 1, so the last step uses exactly TEMP_END.
    progress = step / max(N_STEPS - 1, 1)
    target_temp = TEMP_START + progress * (TEMP_END - TEMP_START)

    # Overwrite each bottleneck's private temperature tensor in place, outside autograd.
    with torch.no_grad():
        for block in model.blocks:
            block.attn_bottleneck._temperature.fill_(target_temp)
            block.mlp_bottleneck._temperature.fill_(target_temp)

    # Fresh FSM batch every step.
    inputs, targets, states = fsm.generate_batch(BATCH_SIZE, seq_len)

    # Forward: `infos` is a list of per-layer dicts with 'attn' and 'mlp' entries.
    logits, infos = model(inputs)

    # Next-token prediction loss over the full vocabulary.
    pred_loss = F.cross_entropy(logits.reshape(-1, fsm_config.vocab_size), targets.reshape(-1))

    # Start the accumulators on the logits' device/dtype (rather than a CPU
    # float32 tensor) so this loop still works if the model is moved to a GPU.
    compress_loss_val = logits.new_zeros(())
    commit_loss_val = logits.new_zeros(())
    for layer_info in infos:
        for bn_type in ['attn', 'mlp']:
            compress_loss_val = compress_loss_val + compression_loss(layer_info[bn_type]['soft_codes'])
            commit_loss_val = commit_loss_val + commitment_loss(layer_info[bn_type]['input'], layer_info[bn_type]['output'])

    n_bottlenecks = len(infos) * 2  # one attn + one mlp bottleneck per entry in infos
    compress_loss_val = compress_loss_val / n_bottlenecks
    commit_loss_val = commit_loss_val / n_bottlenecks

    loss = pred_loss + LAMBDA_COMPRESS * compress_loss_val + LAMBDA_COMMIT * commit_loss_val

    # Backward with gradient-norm clipping.
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Also log the final step so history[...][-1] reflects the end of training
    # (with only `step % LOG_EVERY == 0`, step N_STEPS - 1 was never recorded).
    is_last_step = step == N_STEPS - 1
    if step % LOG_EVERY == 0 or is_last_step:
        with torch.no_grad():
            # Training accuracy on this batch (logits come from before optimizer.step()).
            preds = logits.argmax(dim=-1)
            acc = (preds == targets).float().mean().item()

            # Mean entropy over the attention bottlenecks only.
            entropies = [layer_info['attn']['entropy'].item() for layer_info in infos]
            avg_entropy = sum(entropies) / len(entropies)

        history['steps'].append(step)
        history['loss'].append(loss.item())
        history['accuracy'].append(acc)
        history['temperature'].append(target_temp)
        history['entropy'].append(avg_entropy)

        print(f"Step {step:4d} | Loss: {loss.item():.4f} | Acc: {acc:.3f} | Temp: {target_temp:.2f} | Entropy: {avg_entropy:.3f}")

print("\nTraining complete!")

## 4. Visualize Training Curves

In [None]:
# Training curves: one panel per tracked metric in a 2x2 grid, drawn with
# Plotly when it is available and with matplotlib otherwise. Both backends
# share the same panel list and overall title so their output stays consistent.
metric_panels = [
    ('loss', 'Loss'),
    ('accuracy', 'Accuracy'),
    ('temperature', 'Temperature'),
    ('entropy', 'Entropy'),
]

if PLOTLY:
    fig = make_subplots(rows=2, cols=2, subplot_titles=tuple(title for _, title in metric_panels))
    for panel_idx, (metric, title) in enumerate(metric_panels):
        # Row-major placement: 0 -> (1, 1), 1 -> (1, 2), 2 -> (2, 1), 3 -> (2, 2).
        row, col = divmod(panel_idx, 2)
        fig.add_trace(go.Scatter(x=history['steps'], y=history[metric], name=title), row=row + 1, col=col + 1)

    fig.update_layout(height=500, title_text='FSM Training Progress', showlegend=False)
    fig.show()
else:
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    for ax, (metric, title) in zip(axes.flat, metric_panels):
        ax.plot(history['steps'], history[metric])
        ax.set_title(title)
        ax.set_xlabel('Step')

    # Same overall title as the Plotly branch (previously missing here).
    fig.suptitle('FSM Training Progress')
    plt.tight_layout()
    plt.show()

## 5. Code-State Alignment Analysis

In [None]:
# Collect hard code assignments from the first entry of `infos` (layer 0's
# attention bottleneck) together with the ground-truth FSM states, over
# several fresh batches, for the alignment analysis below.
N_EVAL_BATCHES = 20

model.eval()

code_batches = []
state_batches = []
with torch.no_grad():
    for _ in range(N_EVAL_BATCHES):
        inputs, targets, states = fsm.generate_batch(BATCH_SIZE, seq_len)
        _, infos = model(inputs)
        code_batches.append(infos[0]['attn']['hard_codes'])
        state_batches.append(states)

# Concatenate along the batch dimension.
all_hard_codes = torch.cat(code_batches, dim=0)
all_states = torch.cat(state_batches, dim=0)

print(f"Collected codes shape: {all_hard_codes.shape}")
print(f"Collected states shape: {all_states.shape}")

In [None]:
# Summarize how well codes align with FSM states; the reported random
# baseline purity is 1/NUM_STATES.
alignment = compute_state_code_alignment(all_hard_codes, all_states, NUM_STATES)
random_purity = 1 / NUM_STATES

print(
    "\nCode-State Alignment Results:\n"
    f"  Purity: {alignment['purity']:.3f}\n"
    f"  Active codes: {alignment['active_codes']}\n"
    f"  Random baseline purity: {random_purity:.3f}"
)

In [None]:
# Visualize alignment matrix
from analysis.visualize import plot_code_state_alignment

# Build a (codebook_size, NUM_STATES) co-occurrence matrix: entry [c, s] counts
# the positions whose FSM state is s and where code c is active (value > 0.5).
# This is a vectorized replacement for a per-position Python loop: the active-code
# mask (positions x codes) transposed, times a one-hot state matrix
# (positions x states), gives the same float64 counts.
codebook_size = model_config.bottleneck.codebook_size

codes_np = all_hard_codes.detach().cpu().numpy()
states_np = all_states.detach().cpu().numpy()
# Expect hard codes laid out as (batch, seq, codebook_size) and states as
# (batch, seq); fail loudly instead of silently miscounting if not.
assert codes_np.shape[-1] == codebook_size and codes_np.shape[:-1] == states_np.shape, (
    f"unexpected shapes: hard_codes {codes_np.shape}, states {states_np.shape}, "
    f"codebook_size {codebook_size}"
)

active_mask = (codes_np > 0.5).reshape(-1, codebook_size).astype(np.float64)
state_onehot = np.eye(NUM_STATES)[states_np.reshape(-1).astype(np.int64)]
alignment_matrix = active_mask.T @ state_onehot

fig = plot_code_state_alignment(
    alignment_matrix,
    n_states=NUM_STATES,
    codebook_size=codebook_size,
    purity=alignment['purity'],
    title='Code-State Alignment (Layer 0 Attention)'
)
plt.show()

## 6. Summary

Key observations from the FSM experiment are printed below.

In [None]:
# Experiment summary. Two different chance baselines apply:
#   - token accuracy: the model predicts over fsm_config.vocab_size classes, so
#     uniform guessing scores 1/vocab_size (not 1/NUM_STATES);
#   - code-state purity: 1/NUM_STATES, the random-baseline purity.
PURITY_SUCCESS_FACTOR = 1.5  # purity must beat its random baseline by this factor

acc_baseline = 1 / fsm_config.vocab_size
purity_baseline = 1 / NUM_STATES
# history holds only logged steps, so report which step the "final" values are from.
final_step = history['steps'][-1]
final_acc = history['accuracy'][-1]

print("=" * 50)
print("FSM EXPERIMENT SUMMARY")
print("=" * 50)
print("\nConfiguration:")
print(f"  States: {NUM_STATES}")
print(f"  Codebook size: {codebook_size}")
print(f"  Training steps: {N_STEPS}")
print(f"  Temperature: {TEMP_START} -> {TEMP_END}")

print("\nResults:")
print(f"  Final accuracy (step {final_step}): {final_acc:.3f}")
print(f"  Random baseline (1/vocab_size): {acc_baseline:.3f}")
print(f"  Improvement: {final_acc / acc_baseline:.1f}x")
print(f"\n  Final entropy: {history['entropy'][-1]:.3f}")
print(f"  Code-state purity: {alignment['purity']:.3f} (random baseline: {purity_baseline:.3f})")
print(f"  Active codes: {alignment['active_codes']}/{codebook_size}")

print("\nConclusion:")
if alignment['purity'] > purity_baseline * PURITY_SUCCESS_FACTOR:
    print("  Crystallization SUCCESSFUL - codes learned to represent states!")
else:
    print("  Crystallization incomplete - try more steps or adjust hyperparameters")