# TinyStories Analysis

Analysis of a Crystalline model trained on the TinyStories dataset.

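This notebook covers:
- Loading the trained checkpoint
- Inspecting the model configuration and crystallization (temperature) metrics
- Generating sample stories from short prompts
- Interpreting code activations on sample stories
- Visualizing the codebook with PCA
- Evaluating perplexity (planned — listed under Next Steps in the summary, not yet implemented here)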

In [None]:
# Setup: put the repository root on sys.path so the local `analysis` package is importable.
import sys
from pathlib import Path

# The notebook's working directory is two levels below the repository root.
project_root = Path.cwd().parent.parent
root_str = str(project_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)

print(f"Project root: {project_root}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from analysis import load_checkpoint_for_analysis, setup_style, COLORS

try:
    from transformers import AutoTokenizer
    TOKENIZER_AVAILABLE = True
except ImportError:
    TOKENIZER_AVAILABLE = False
    print("transformers not available - some features disabled")

try:
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    PLOTLY = True
except ImportError:
    PLOTLY = False

## 1. Load Checkpoint

In [None]:
CHECKPOINT_PATH = project_root / "checkpoints" / "tinystories" / "checkpoint_final.pt"

# Load the final TinyStories checkpoint. When it is missing, `result` is None and
# the later `if result:` cells are skipped.
if not CHECKPOINT_PATH.exists():
    print(f"Checkpoint not found: {CHECKPOINT_PATH}")
    print("Run TinyStories training first.")
    result = None
else:
    result = load_checkpoint_for_analysis(CHECKPOINT_PATH)
    model = result.model
    print("Checkpoint loaded!")
    print(f"Training step: {result.step}")

## 2. Model Information

In [None]:
if result:
    # Bottleneck statistics and the training config saved alongside the checkpoint.
    stats = result.bottleneck_stats
    config = result.config
    rule = "=" * 50

    print(rule)
    print("MODEL CONFIGURATION")
    print(rule)

    print("\nModel:")
    model_cfg = config.get('model', {})
    model_fields = (
        ("Vocabulary", 'vocab_size'),
        ("Dimension", 'dim'),
        ("Layers", 'n_layers'),
        ("Heads", 'n_heads'),
    )
    for label, key in model_fields:
        print(f"  {label}: {model_cfg.get(key, 'N/A')}")

    bn_cfg = model_cfg.get('bottleneck', {})
    print("\nBottleneck:")
    print(f"  Codebook size: {bn_cfg.get('codebook_size', 'N/A')}")
    print(f"  Top-k codes: {bn_cfg.get('num_codes_k', 'N/A')}")

    print("\nCurrent State:")
    temp_summary = stats['temperature_summary']
    print(f"  Mean temperature: {temp_summary['mean']:.4f}")
    print(f"  Temperature range: [{temp_summary['min']:.4f}, {temp_summary['max']:.4f}]")

## 3. Temperature Analysis

In [None]:
if result:
    # Per-layer temperature plot. The Plotly helper is imported only inside the
    # PLOTLY branch: importing it unconditionally would break the matplotlib
    # fallback if `visualize_interactive` itself needs Plotly (presumably, given
    # its name — TODO confirm).
    if PLOTLY:
        from analysis.visualize_interactive import plot_layer_temperatures_interactive

        fig = plot_layer_temperatures_interactive(stats['temperatures'])
        fig.show()
    else:
        from analysis.visualize import plot_layer_temperatures

        setup_style('notebook')
        fig = plot_layer_temperatures(stats['temperatures'])
        plt.show()

## 4. Load Tokenizer

In [None]:
if TOKENIZER_AVAILABLE:
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")
    print(f"Vocabulary size: {len(tokenizer)}")
else:
    tokenizer = None
    print("Tokenizer not available")

## 5. Sample Text Generation

In [None]:
def generate_text(model, tokenizer, prompt, max_length=50, temperature=0.8):
    """Sample a continuation of ``prompt`` one token at a time.

    Args:
        model: Module called as ``model(input_ids)`` that returns ``(logits, infos)``.
            ``logits[0, -1, :]`` is used as the next-token distribution.
        tokenizer: Object providing ``encode(text, return_tensors='pt')``,
            ``decode(ids, skip_special_tokens=True)`` and ``eos_token_id``
            (e.g. a Hugging Face tokenizer).
        prompt: Text to continue.
        max_length: Maximum number of *new* tokens to generate.
        temperature: Softmax temperature for sampling. Values ``<= 0`` switch to
            greedy (argmax) decoding instead of dividing by zero.

    Returns:
        The decoded prompt plus generated continuation, with special tokens removed.
        Generation stops early once the EOS token is produced.
    """
    model.eval()

    # Put the prompt on the same device as the model's weights (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    with torch.no_grad():
        for _ in range(max_length):
            logits, _ = model(input_ids)
            next_token_logits = logits[0, -1, :]
            if temperature <= 0:
                # Greedy decoding: avoids logits / 0 -> NaN probabilities.
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            # next_token has shape (1,); add a batch dim to append along the sequence axis.
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

In [None]:
if result is not None and tokenizer is not None:
    # Fix the RNG so sampled stories are reproducible under Restart & Run All.
    GENERATION_SEED = 0
    torch.manual_seed(GENERATION_SEED)

    prompts = [
        "Once upon a time",
        "The little girl",
        "One day, a boy",
    ]

    print("Generated Stories:")
    print("=" * 60)

    for prompt in prompts:
        # Best-effort per prompt: report the failure and continue with the next prompt.
        try:
            generated = generate_text(model, tokenizer, prompt, max_length=30)
            print(f"\nPrompt: '{prompt}'")
            print(f"Generated: {generated}")
            print("-" * 60)
        except Exception as e:
            print(f"Error generating for '{prompt}': {type(e).__name__}: {e}")

## 6. Code Activation Analysis

In [None]:
def analyze_code_activations(model, tokenizer, text, layer=0, bn_type='attn'):
    """Collect the hard code assignments of one bottleneck for every token of ``text``.

    Args:
        model: Module called as ``model(input_ids)`` that returns ``(logits, infos)``.
            ``infos[layer][bn_type]`` must contain ``'hard_codes'`` and a
            one-element tensor ``'entropy'``.
        tokenizer: Hugging Face-style tokenizer providing ``encode`` and
            ``convert_ids_to_tokens``.
        text: Input text to analyze.
        layer: Index into ``infos`` of the layer to inspect (default: first layer).
        bn_type: Bottleneck key inside that layer's info dict (default ``'attn'``).

    Returns:
        dict with:
        - ``'tokens'``: list of token strings;
        - ``'hard_codes'``: batch element 0 of the hard codes, moved to CPU;
        - ``'entropy'``: the bottleneck entropy as a Python float.
    """
    model.eval()

    # Run on the same device as the model's weights (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    input_ids = tokenizer.encode(text, return_tensors='pt')
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    with torch.no_grad():
        _, infos = model(input_ids.to(device))

    bottleneck_info = infos[layer][bn_type]
    # Batch element 0; presumably shaped (seq_len, codebook_size) — TODO confirm against the model.
    hard_codes = bottleneck_info['hard_codes'][0].cpu()
    entropy = bottleneck_info['entropy'].item()

    return {
        'tokens': tokens,
        'hard_codes': hard_codes,
        'entropy': entropy,
    }

In [None]:
if result is not None and tokenizer is not None:
    sample_text = "The little cat sat on the mat."
    # A code is reported as active when its hard assignment exceeds this value.
    ACTIVE_CODE_THRESHOLD = 0.5

    # Named `activation_info` rather than `analysis` to avoid confusion with the
    # project's `analysis` package imported at the top of the notebook.
    activation_info = analyze_code_activations(model, tokenizer, sample_text)

    print(f"Text: '{sample_text}'")
    print(f"Entropy: {activation_info['entropy']:.4f}")
    print("\nToken -> Active Codes:")
    print("-" * 50)

    for i, token in enumerate(activation_info['tokens']):
        code_row = activation_info['hard_codes'][i]
        active = (code_row > ACTIVE_CODE_THRESHOLD).nonzero(as_tuple=True)[0].tolist()
        print(f"  {token:<15} -> codes {active}")

## 7. Codebook Visualization

In [None]:
if result:
    from analysis.checkpoint_analysis import get_codebook_embeddings
    from sklearn.decomposition import PCA

    # Codebook of the first layer's attention bottleneck.
    codebook = get_codebook_embeddings(model, layer=0, bn_type='attn')
    # sklearn needs a NumPy array. If a torch tensor comes back, detach it and move
    # it to CPU first, because NumPy conversion raises for tensors that require grad
    # or live on GPU. This is a no-op when the helper already returns NumPy.
    if torch.is_tensor(codebook):
        codebook = codebook.detach().cpu().numpy()

    # Project the code embeddings to 2-D for plotting.
    pca = PCA(n_components=2)
    codebook_2d = pca.fit_transform(codebook)

    print(f"Codebook shape: {codebook.shape}")
    print(f"Variance explained: {pca.explained_variance_ratio_.sum():.2%}")

In [None]:
if result:
    # Scatter the 2-D PCA projection of the codebook and label each point with its code index.
    code_labels = [str(i) for i in range(len(codebook_2d))]
    xs = codebook_2d[:, 0]
    ys = codebook_2d[:, 1]

    if not PLOTLY:
        setup_style('notebook')
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.scatter(xs, ys, alpha=0.7)
        for label, x, y in zip(code_labels, xs, ys):
            ax.annotate(label, (x, y), fontsize=8)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
        ax.set_title('Codebook Embeddings (PCA)')
        plt.show()
    else:
        scatter = go.Scatter(
            x=xs,
            y=ys,
            mode='markers+text',
            text=code_labels,
            textposition='top center',
            marker=dict(size=10, color=COLORS['primary']),
            hovertemplate='Code %{text}<br>PC1: %{x:.3f}<br>PC2: %{y:.3f}<extra></extra>',
        )
        fig = go.Figure()
        fig.add_trace(scatter)
        fig.update_layout(
            title='Codebook Embeddings (PCA)',
            xaxis_title='PC1',
            yaxis_title='PC2',
            height=500,
        )
        fig.show()

## 8. Summary

In [None]:
if result:
    banner = "=" * 50
    print(banner)
    print("TINYSTORIES ANALYSIS SUMMARY")
    print(banner)

    print("\nModel:")
    n_params = sum(param.numel() for param in model.parameters())
    print(f"  Parameters: {n_params:,}")
    print(f"  Layers: {stats['n_layers']}")

    print("\nCrystallization State:")
    mean_temp = stats['temperature_summary']['mean']
    print(f"  Mean temperature: {mean_temp:.4f}")

    # A mean temperature below 1.0 is reported as crystallized.
    crystallized = mean_temp < 1.0
    print(
        "  Status: CRYSTALLIZED (temperature < 1.0)"
        if crystallized
        else "  Status: Still warm (temperature >= 1.0)"
    )

    print("\nNext Steps:")
    next_steps = (
        "Evaluate perplexity on validation set",
        "Analyze code specialization patterns",
        "Compare with baseline transformer",
    )
    for step in next_steps:
        print(f"  - {step}")