# Explore Crystalline Checkpoint

Interactive exploration of a trained Crystalline model checkpoint.

This notebook allows you to:
- Load and inspect checkpoint metadata
- Visualize bottleneck statistics (temperatures, codebook)
- Run inference and analyze code activations
- Generate interactive plots

In [None]:
# Make the project's packages importable from this notebook
import sys
from pathlib import Path

# The project root is taken to be two directory levels above the current
# working directory.
project_root = Path.cwd().parent.parent
root_entry = str(project_root)
if root_entry not in sys.path:
    # Prepend so the project takes priority over later sys.path entries.
    sys.path.insert(0, root_entry)

print(f"Project root: {project_root}")

In [None]:
# Imports
import torch
import numpy as np

from analysis import (
    load_checkpoint_for_analysis,
    extract_bottleneck_stats,
    setup_style,
    COLORS,
)

# Prefer the interactive (Plotly) plotting helpers; when they cannot be
# imported, fall back to the static plotting helpers instead.
try:
    from analysis.visualize_interactive import (
        plot_layer_temperatures_interactive,
        plot_codebook_usage_interactive,
    )
except ImportError:
    INTERACTIVE = False
    print("Plotly not available - using static plots")
    from analysis.visualize import plot_layer_temperatures, plot_codebook_usage
else:
    # Import succeeded: later cells branch on this flag.
    INTERACTIVE = True
    print("Interactive visualizations available (Plotly)")

## 1. Load Checkpoint

In [None]:
# Location of the checkpoint to explore (relative to the project root)
CHECKPOINT_PATH = project_root.joinpath("checkpoints", "tinystories", "checkpoint_final.pt")

# To explore a different checkpoint, point at it directly instead:
# CHECKPOINT_PATH = Path("/path/to/your/checkpoint.pt")

# Report the chosen path and whether it is present on disk
print(
    f"Loading checkpoint: {CHECKPOINT_PATH}",
    f"Exists: {CHECKPOINT_PATH.exists()}",
    sep="\n",
)

In [None]:
# Load the checkpoint, or leave `result` as None when the file is missing.
# Every later cell is guarded by `if result:` and is skipped in that case.
if not CHECKPOINT_PATH.exists():
    print("Checkpoint not found. Run training first or specify a different path.")
    result = None
else:
    result = load_checkpoint_for_analysis(CHECKPOINT_PATH)
    print("Checkpoint loaded successfully!")

## 2. Checkpoint Metadata

In [None]:
# Print the training position, model/bottleneck configuration and saved metrics
if result:
    banner = "=" * 50
    print(banner)
    print("CHECKPOINT METADATA")
    print(banner)
    print(f"Training Step: {result.step}")
    print(f"Epoch: {result.epoch}")

    # Model settings; the nested bottleneck settings get their own section
    print("\nModel Config:")
    model_cfg = result.config.get('model', {})
    for key, value in model_cfg.items():
        if key == 'bottleneck':
            continue
        print(f"  {key}: {value}")

    print("\nBottleneck Config:")
    for key, value in model_cfg.get('bottleneck', {}).items():
        print(f"  {key}: {value}")

    # Floats are shown with four decimals; anything else is printed as-is
    print("\nSaved Metrics:")
    for key, value in result.metrics.items():
        shown = f"{value:.4f}" if isinstance(value, float) else f"{value}"
        print(f"  {key}: {shown}")

## 3. Bottleneck Statistics

In [None]:
# Summarize the bottleneck statistics stored with the checkpoint.
# `stats` is reused by later cells.
if result:
    stats = result.bottleneck_stats

    banner = "=" * 50
    print(banner)
    print("BOTTLENECK STATISTICS")
    print(banner)
    print(f"Number of layers: {stats['n_layers']}")
    # Only the first entry of each per-layer list is reported here
    print(f"Codebook size: {stats['codebook_sizes'][0]}")
    print(f"Top-k codes: {stats['num_codes_k'][0]}")

    print("\nTemperature Summary:")
    temp_summary = stats['temperature_summary']
    # Labels are padded so the values line up in a column
    for label, field in (("Mean:", 'mean'), ("Min: ", 'min'), ("Max: ", 'max'), ("Std: ", 'std')):
        print(f"  {label} {temp_summary[field]:.4f}")

In [None]:
# Tabulate the attention and MLP bottleneck temperature of every layer
if result:
    rule = "-" * 40
    print("\nPer-Layer Temperatures:")
    print(rule)
    print(f"{'Layer':<8} {'Attention':<12} {'MLP':<12}")
    print(rule)
    for entry in stats['layers']:
        layer_no = entry['layer']
        attn_temp = entry['attn']['temperature']
        mlp_temp = entry['mlp']['temperature']
        print(f"L{layer_no:<7} {attn_temp:<12.4f} {mlp_temp:<12.4f}")

## 4. Visualize Temperatures

In [None]:
# Plot the per-layer temperatures, interactively when Plotly is available
# and as a static figure otherwise.
if result:
    temps = stats['temperatures']

    if INTERACTIVE:
        fig = plot_layer_temperatures_interactive(temps)
        fig.show()
    else:
        # pyplot is otherwise only imported by a later cell, so import it here;
        # without this the static fallback fails with NameError on a fresh
        # kernel run from top to bottom.
        import matplotlib.pyplot as plt

        setup_style('notebook')
        fig = plot_layer_temperatures(temps)
        plt.show()

## 5. Model Summary

In [None]:
# Report parameter counts and the basic dimensions of the loaded model.
# `model` is reused by later cells.
if result:
    model = result.model

    # One pass over the parameters: (element count, is-trainable) per tensor
    param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in param_sizes)
    trainable_params = sum(count for count, trainable in param_sizes if trainable)

    banner = "=" * 50
    print(banner)
    print("MODEL SUMMARY")
    print(banner)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Dimensions are read off the token embedding and the block list
    print("\nModel architecture:")
    print(f"  Vocab size: {model.token_embedding.num_embeddings}")
    print(f"  Hidden dim: {model.token_embedding.embedding_dim}")
    print(f"  Layers: {len(model.blocks)}")

## 6. Codebook Analysis

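Analyze the learned codebook vectors of the first layer's attention bottleneck: their shape, mean norm, and pairwise cosine similarities.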

In [None]:
# Fetch one codebook and report its shape and mean row norm.
# `codebook` is reused by the similarity cell below.
if result:
    from analysis.checkpoint_analysis import get_codebook_embeddings

    # Codebook of the attention bottleneck in layer 0
    codebook = get_codebook_embeddings(result.model, layer=0, bn_type='attn')

    print(f"Codebook shape: {codebook.shape}")
    # Average L2 norm taken along axis 1 (one norm per codebook row)
    mean_row_norm = np.linalg.norm(codebook, axis=1).mean()
    print(f"Codebook norm (should be ~1 for normalized): {mean_row_norm:.4f}")

In [None]:
# Cosine-similarity matrix between codebook rows, plus off-diagonal statistics
if result:
    import matplotlib.pyplot as plt

    # Scale every row to unit length so the Gram matrix holds cosine similarities
    row_norms = np.linalg.norm(codebook, axis=1, keepdims=True)
    codebook_norm = codebook / row_norms
    similarities = codebook_norm @ codebook_norm.T

    # Heatmap on a fixed [-1, 1] colour scale
    setup_style('notebook')
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(similarities, cmap='RdBu_r', vmin=-1, vmax=1)
    ax.set(
        xlabel='Code Index',
        ylabel='Code Index',
        title='Codebook Similarity Matrix (Layer 0, Attention)',
    )
    fig.colorbar(im, ax=ax, label='Cosine Similarity')
    fig.tight_layout()
    plt.show()

    # Strict upper triangle: each distinct pair once, diagonal excluded
    upper = np.triu_indices(similarities.shape[0], k=1)
    off_diag = similarities[upper]
    print("\nOff-diagonal similarity stats:")
    for label, stat in (("Mean:", off_diag.mean()), ("Std: ", off_diag.std()), ("Max: ", off_diag.max())):
        print(f"  {label} {stat:.4f}")

## 7. Run Inference (Optional)

Run the model on a small batch of randomly generated token IDs (a stand-in for real data) to collect code activations.

In [None]:
# Run one forward pass on a reproducible batch of random token IDs.
# `logits` and `infos` are left in the namespace for the next cell.
if result:
    # Take the model from `result` directly so this cell does not depend on
    # the Model Summary cell having been executed.
    model = result.model
    vocab_size = model.token_embedding.num_embeddings

    # Use a dedicated, seeded generator so the sample batch (and every
    # statistic derived from it) is the same on each run, without touching
    # the global RNG state.
    sample_rng = torch.Generator().manual_seed(0)
    # Shape (1, 32) - presumably (batch, sequence length); confirm against the model.
    sample_input = torch.randint(0, vocab_size, (1, 32), generator=sample_rng)

    # Place the batch on the device that holds the model's parameters
    model_device = next(model.parameters()).device
    sample_input = sample_input.to(model_device)
    print(f"Sample input shape: {sample_input.shape}")

    # Forward pass in eval mode with gradient tracking disabled
    model.eval()
    with torch.no_grad():
        logits, infos = model(sample_input)

    print(f"Output logits shape: {logits.shape}")
    print(f"Number of layer infos: {len(infos)}")

In [None]:
# Per-layer activation statistics from the forward pass.
# Skipped unless the inference cell has defined `infos`.
if result and 'infos' in dir():
    rule = "-" * 60
    print("\nCode Activation Statistics (per layer):")
    print(rule)
    print(f"{'Layer':<8} {'Type':<8} {'Entropy':<12} {'Active Codes':<15}")
    print(rule)

    for per_layer in infos:
        layer_idx = per_layer['layer']
        for bn_type in ('attn', 'mlp'):
            bn_info = per_layer[bn_type]
            entropy = bn_info['entropy'].item()
            hard_codes = bn_info['hard_codes']
            # "Active" counts the entries of hard_codes that exceed 0.5
            active = (hard_codes > 0.5).sum().item()
            total = hard_codes.numel()
            print(f"L{layer_idx:<7} {bn_type:<8} {entropy:<12.4f} {active}/{total}")

## 8. Save Analysis

Export a JSON summary of the analysis (`analysis_output.json` in the project root) for later use.

In [None]:
# Write a compact JSON summary of the checkpoint analysis to the project root
if result:
    import json

    # Per-layer temperatures, split by bottleneck type
    layer_temperatures = {
        'attn': stats['temperatures']['attn'],
        'mlp': stats['temperatures']['mlp'],
    }

    # Assemble the summary; key order determines the order in the JSON file
    analysis_summary = dict(
        checkpoint_path=str(CHECKPOINT_PATH),
        step=result.step,
        epoch=result.epoch,
        temperature_summary=stats['temperature_summary'],
        temperatures_per_layer=layer_temperatures,
        n_layers=stats['n_layers'],
        codebook_size=stats['codebook_sizes'][0],
    )

    # Overwrites any existing file at this location
    output_path = project_root / 'analysis_output.json'
    with output_path.open('w') as f:
        json.dump(analysis_summary, f, indent=2)

    print(f"Analysis saved to: {output_path}")

---

## Next Steps

- **FSM Analysis**: See `fsm_analysis.ipynb` for state-code alignment analysis
- **TinyStories Analysis**: See `tinystories_analysis.ipynb` for language model analysis
- **Generate Figures**: Use `analysis.visualize` to create publication-quality plots