# ChronoSAE Quickstart Demo

This notebook demonstrates the ChronoSAE system end-to-end, including:
- Environment setup and device detection
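- Creating a ChronoSAE model from a config (weights are freshly initialized — no checkpoint is loaded; see `training/train.py` for checkpoint saving/loading)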
- Running inference and training steps
- Visualizing the six dials (Mem-Absorption, TPG, Cap-Gauge, ICL-Persistence, Weight-Δ, RAG-Trace)
- Examining saved metrics

**Compatible with**: CPU, Apple Silicon (MPS), GTX 1070, and other CUDA devices

In [None]:
# Enable auto-reload for development
%load_ext autoreload
%autoreload 2

# Better stack trace display
%xmode Verbose

import sys
import traceback

def print_exception():
    """Dump the exception currently being handled, framed by banner lines.

    Meant to be called from inside an ``except`` block. When no exception is
    active it prints nothing. The banner lines go to stdout; the traceback
    itself is written by ``traceback.print_exception`` to its default stream
    (stderr).
    """
    active_type, active_value, active_tb = sys.exc_info()
    if active_type is None:
        return
    print("\n=== FULL TRACEBACK ===")
    traceback.print_exception(active_type, active_value, active_tb)
    print("======================\n")

## 1. Environment Setup & Device Detection

In [None]:
import contextlib
import json
import os
from pathlib import Path
from typing import Dict, List
from unittest.mock import Mock

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Set up paths.
# The notebook lives in <project_root>/notebooks/, so step up one level only when
# the kernel was started from a directory literally named "notebooks" (a substring
# test on the whole path would also match e.g. /home/notebooks/project).
launch_dir = Path.cwd()
project_root = launch_dir.parent if launch_dir.name == 'notebooks' else launch_dir

# Guard the append so re-running this cell does not stack duplicate sys.path entries
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

print(f"Project root: {project_root}")
print(f"Python path includes: {project_root}")

In [None]:
# Device detection and capability check
def detect_device():
    """Pick the best available torch device and decide whether to use mixed precision.

    Preference order is CUDA -> Apple MPS -> CPU. Float16 mixed precision is
    enabled only on CUDA GPUs whose compute capability is >= 7.0 (the first
    generation with Tensor Cores). Older cards such as the GTX 1070 (compute
    capability 6.1) keep full precision. VRAM size is reported but deliberately
    not used for this decision.

    Returns:
        tuple[torch.device, bool]: the selected device and the mixed-precision flag.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
        gpu_name = torch.cuda.get_device_name(0)
        props = torch.cuda.get_device_properties(0)
        gpu_memory = props.total_memory / (1024**3)  # GB
        compute_capability = (props.major, props.minor)

        print(f"🚀 CUDA Device: {gpu_name}")
        print(f"💾 GPU Memory: {gpu_memory:.1f} GB")
        print(f"🧮 Compute Capability: {props.major}.{props.minor}")

        # Tensor Cores (what makes fp16 autocast worthwhile) start at compute
        # capability 7.0. A memory threshold would wrongly disable AMP on
        # small-VRAM Tensor-Core cards.
        if compute_capability >= (7, 0):
            print("✅ Tensor Core GPU detected - enabling mixed precision")
            use_mixed_precision = True
        else:
            print("⚠️  GPU without Tensor Cores (e.g. GTX 1070) - using conservative settings")
            use_mixed_precision = False

    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
        use_mixed_precision = False
        print("🍎 Apple Silicon MPS detected")
    else:
        device = torch.device('cpu')
        use_mixed_precision = False
        print("💻 Using CPU")

    print(f"🔧 Device: {device}")
    print(f"⚡ Mixed Precision: {use_mixed_precision}")

    return device, use_mixed_precision

device, use_amp = detect_device()

In [None]:
# Import ChronoSAE model components.
# Only the model is imported in this try block. An unrelated missing package
# (metric hooks, training utilities) must not swap a real, importable ChronoSAE
# for the demo fallback below. Metric hooks are imported in Section 4.
try:
    from src.algoverse.chrono.chrono_sae.model import ChronoSAE, ChronoSAEConfig, create_chrono_sae
    print("✅ Successfully imported ChronoSAE components")

except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Creating fallback implementation...")

    # Fallback dummy implementation for demo purposes
    class ChronoSAEConfig:
        """Minimal config object: every keyword argument becomes an attribute."""

        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    class ChronoSAE(nn.Module):
        """Demo stand-in for ChronoSAE: a single-layer ReLU sparse autoencoder.

        Reads ``d_model``, ``d_sae``, ``temporal_dropout_p``, ``lambda_sparsity``
        and ``beta_tpg`` from ``config``.
        """

        def __init__(self, config):
            super().__init__()
            self.config = config
            self.encoder = nn.Linear(config.d_model, config.d_sae)
            self.decoder = nn.Linear(config.d_sae, config.d_model)
            self.temporal_dropout = nn.Dropout(config.temporal_dropout_p)

        def forward(self, x, compute_loss=True):
            """Encode ``x`` (last dim ``d_model``), decode it, and optionally compute losses.

            Returns:
                dict with ``output`` (reconstruction) and ``activations`` (ReLU codes
                after dropout). When ``compute_loss`` is true it also contains ``loss``
                and ``loss_components`` (``mse_loss``, ``l1_loss``, ``temporal_loss``).
            """
            z = torch.relu(self.encoder(x))
            z = self.temporal_dropout(z)
            x_recon = self.decoder(z)

            if not compute_loss:
                return {'output': x_recon, 'activations': z}

            mse_loss = nn.functional.mse_loss(x_recon, x)
            l1_loss = z.abs().mean()
            # Placeholder: the fallback has no real temporal objective, so this term
            # is a constant and contributes no gradient.
            temporal_loss = torch.tensor(0.01, device=x.device)

            total_loss = (
                mse_loss
                + self.config.lambda_sparsity * l1_loss
                + self.config.beta_tpg * temporal_loss
            )

            return {
                'output': x_recon,
                'activations': z,
                'loss': total_loss,
                'loss_components': {
                    'mse_loss': mse_loss,
                    'l1_loss': l1_loss,
                    'temporal_loss': temporal_loss
                }
            }

        def get_sparsity_metrics(self, activations):
            """Summarize how sparse ``activations`` are.

            Returns:
                dict with ``l0_sparsity`` (fraction of entries with |a| > 1e-6, i.e. the
                active fraction), ``l1_norm`` (mean |a|) and ``max_activation`` (max |a|).
            """
            with torch.no_grad():
                magnitudes = activations.abs()
                return {
                    'l0_sparsity': (magnitudes > 1e-6).float().mean().item(),
                    'l1_norm': magnitudes.mean().item(),
                    'max_activation': magnitudes.max().item()
                }

    def create_chrono_sae(config):
        """Build a fallback ChronoSAE from ``config``."""
        return ChronoSAE(config)

    print("✅ Fallback implementation ready")

## 2. Create and Configure ChronoSAE Model

In [None]:
# Seed torch's RNGs (CPU and every CUDA device) so weight init, dropout masks and the
# demo data generated later are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Model configuration optimized for demo
config = ChronoSAEConfig(
    d_model=256,           # Smaller model for demo
    d_sae=1024,            # 4x expansion
    temporal_dropout_p=0.1,
    lambda_sparsity=1e-4,
    beta_tpg=1e-3,
    device=str(device)
)

# Create model
model = create_chrono_sae(config).to(device)

param_count = sum(p.numel() for p in model.parameters())
# Measure the parameters' real storage instead of assuming 4 bytes (float32) each
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)

print("🧠 ChronoSAE Model Created:")
print(f"   Parameters: {param_count:,}")
print(f"   Model Size: {model_size_mb:.1f} MB")
print(f"   Device: {device}")

# Print model architecture
print("\n📐 Model Architecture:")
print(model)

## 3. Generate Dummy Data and Run Inference

In [None]:
# Generate structured dummy data (simulating transformer activations)
def create_demo_data(batch_size=8, seq_len=32, d_model=256, device='cpu'):
    """Create random activations with simple temporal structure.

    Two passes are applied to Gaussian noise of shape (batch_size, seq_len, d_model):
      1. a recursive blend along the sequence axis,
         x[t] <- 0.7 * x[t] + 0.3 * x[t-1], where x[t-1] is already blended;
      2. mixing across positions by a random row-stochastic (softmax) matrix,
         drawn independently for each batch element.

    Args:
        batch_size: number of sequences.
        seq_len: positions per sequence.
        d_model: feature width.
        device: device on which every tensor, including the mixing weights, is created.

    Returns:
        torch.Tensor of shape (batch_size, seq_len, d_model).
    """
    # Base random data
    data = torch.randn(batch_size, seq_len, d_model, device=device)

    # The recursion only runs along time, so all batch rows are updated together
    for t in range(1, seq_len):
        data[:, t] = 0.7 * data[:, t] + 0.3 * data[:, t - 1]

    # Attention-like mixing. The weights must live on the same device as `data`;
    # otherwise the matrix product raises a device-mismatch error on GPU.
    attention_weights = torch.softmax(
        torch.randn(batch_size, seq_len, seq_len, device=device), dim=-1
    )
    return torch.bmm(attention_weights, data)

# Build a small demo batch whose feature width matches the model input
demo_data = create_demo_data(batch_size=4, seq_len=16, d_model=config.d_model, device=device)

# Quick sanity summary of the generated tensor
data_min, data_max = demo_data.min(), demo_data.max()
print(f"📊 Demo Data Shape: {demo_data.shape}")
print(f"📊 Data Range: [{data_min:.3f}, {data_max:.3f}]")
print(f"📊 Data Mean: {demo_data.mean():.3f}")
print(f"📊 Data Std: {demo_data.std():.3f}")

In [None]:
# Run inference.
# float16 autocast is entered only when use_amp is set (which only happens for CUDA).
# Otherwise a no-op context keeps a single code path for every device.
# torch.autocast replaces the deprecated torch.cuda.amp.autocast.
amp_context = torch.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()

model.eval()  # switch dropout layers to inference behaviour
with torch.no_grad(), amp_context:
    output = model(demo_data, compute_loss=True)

print("🔍 Inference Results:")
print(f"   Total Loss: {output['loss'].item():.6f}")
print(f"   MSE Loss: {output['loss_components']['mse_loss'].item():.6f}")
print(f"   L1 Loss: {output['loss_components']['l1_loss'].item():.6f}")
print(f"   Temporal Loss: {output['loss_components']['temporal_loss'].item():.6f}")

print("\n📏 Output Shapes:")
print(f"   Reconstructed: {output['output'].shape}")
print(f"   Activations: {output['activations'].shape}")

# Compute reconstruction error
recon_error = torch.mean((demo_data - output['output'])**2).item()
print(f"\n📊 Reconstruction Error (MSE): {recon_error:.6f}")

# Compute sparsity metrics if the model provides them
if hasattr(model, 'get_sparsity_metrics'):
    sparsity_metrics = model.get_sparsity_metrics(output['activations'])
    print("\n🎯 Sparsity Metrics:")
    for key, value in sparsity_metrics.items():
        print(f"   {key}: {value:.4f}")

## 4. Set Up Metric Hooks and Run Training Steps

In [None]:
# Import metric hooks (with fallback).
# If membench_x is unavailable, a no-op hook stands in for all six so the rest of
# the notebook still runs; the dials then simply receive no real readings.
try:
    from membench_x.metrics import (
        MemAbsorptionHook, TPGHook, CapGaugeHook,
        ICLPersistenceHook, WeightDeltaHook, RAGTraceHook
    )
    print("✅ Imported metric hooks")
except ImportError:
    print("❌ Could not import metric hooks, using fallback")

    class _NoOpMetricHook:
        """Fallback hook: same constructor signature as the real hooks, records nothing."""

        def __init__(self, writer, output_dir):
            self.writer = writer
            self.output_dir = output_dir

        def on_step(self, **kwargs):
            """Accept any step payload and report no metrics."""
            return {}

    MemAbsorptionHook = TPGHook = CapGaugeHook = _NoOpMetricHook
    ICLPersistenceHook = WeightDeltaHook = RAGTraceHook = _NoOpMetricHook

# Create temporary metrics directory
metrics_dir = project_root / "outputs" / "demo_metrics"
metrics_dir.mkdir(parents=True, exist_ok=True)

# A Mock accepts (and ignores) any writer call, so no TensorBoard writer is needed
mock_writer = Mock()

metric_hooks = {
    'mem_absorption': MemAbsorptionHook(mock_writer, metrics_dir),
    'tpg': TPGHook(mock_writer, metrics_dir),
    'cap_gauge': CapGaugeHook(mock_writer, metrics_dir),
    'icl_persistence': ICLPersistenceHook(mock_writer, metrics_dir),
    'weight_delta': WeightDeltaHook(mock_writer, metrics_dir),
    'rag_trace': RAGTraceHook(mock_writer, metrics_dir)
}

print(f"🎯 Created {len(metric_hooks)} metric hooks")
print(f"📁 Metrics will be saved to: {metrics_dir}")

In [None]:
# Set up optimizer for training steps
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# float16 autocast needs loss scaling so small gradients do not underflow to zero.
# With enabled=False the scaler is a pass-through, so CPU/MPS training is unchanged.
try:
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
except AttributeError:  # older PyTorch releases only ship the CUDA-specific class
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Collect metrics over several training steps
metrics_history = []
num_steps = 10

print("🚀 Running training steps with metric collection...")

model.train()
for step in range(num_steps):
    optimizer.zero_grad()

    # Forward pass (a fresh autocast context each step; no-op when AMP is off)
    with (torch.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()):
        outputs = model(demo_data, compute_loss=True)

    loss = outputs['loss']

    # Backward pass through the scaler (scale/step/update are identities when disabled)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Collect metrics
    step_metrics = {'step': step, 'loss': loss.item()}

    for name, hook in metric_hooks.items():
        try:
            hook_metrics = hook.on_step(
                step=step,
                model=model,
                activations=outputs['activations'],
                loss_dict=outputs['loss_components']
            )
            # A hook may report nothing (None / {}) for a given step
            if hook_metrics:
                step_metrics.update(hook_metrics)
        except Exception as e:
            # Best-effort: one failing hook must not abort the demo run
            print(f"⚠️  Warning: {name} hook failed: {e}")

    metrics_history.append(step_metrics)

    if step % 2 == 0:
        print(f"Step {step:2d}: Loss = {loss.item():.6f}")

print(f"✅ Completed {num_steps} training steps")
print(f"📊 Collected metrics for {len(metrics_history)} steps")

## 5. Visualize Metrics: The Six Dials

In [None]:
# Pull per-step series out of metrics_history for the dashboard
steps = [record['step'] for record in metrics_history]
losses = [record['loss'] for record in metrics_history]

def extract_metric(metric_name, default=0.0):
    """Return the per-step values of `metric_name`, using `default` where a step lacks it."""
    return [record.get(metric_name, default) for record in metrics_history]

(mem_absorption, tpg, cap_gauge, icl_persistence, weight_delta, rag_trace) = (
    extract_metric(dial_key)
    for dial_key in (
        'mem_absorption', 'tpg', 'cap_gauge',
        'icl_persistence', 'weight_delta', 'rag_trace',
    )
)

print("📈 Metrics extracted for visualization")
print(f"   Steps: {len(steps)}")
print(f"   Sample mem_absorption: {mem_absorption[:3]}")
print(f"   Sample cap_gauge: {cap_gauge[:3]}")

In [None]:
# Create comprehensive visualization
plt.style.use('default')
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
fig.suptitle('ChronoSAE Training Metrics: The Six Dials Dashboard', fontsize=16, fontweight='bold')

# Loss plot (log-scaled y axis)
loss_ax = axes[0, 0]
loss_ax.plot(steps, losses, 'b-', linewidth=2, marker='o', markersize=4)
loss_ax.set_title('Training Loss', fontweight='bold')
loss_ax.set_xlabel('Step')
loss_ax.set_ylabel('Loss')
loss_ax.grid(True, alpha=0.3)
loss_ax.set_yscale('log')

# The six dial panels differ only in data, styling and labels, so they are drawn
# from one table instead of six copy-pasted stanzas.
# Columns: grid position, series, line format, marker, title, title colour,
#          y-axis label, clamp y-axis to [0, 1]
dial_panels = [
    ((0, 1), mem_absorption, 'r-', 's', '🧠 Mem-Absorption', 'red', 'Memory Ratio', True),
    ((0, 2), tpg, 'g-', '^', '⏰ TPG (Temporal Policy Gradient)', 'green', 'TPG Value', False),
    ((1, 0), cap_gauge, 'orange', 'd', '📊 Cap-Gauge (Capacity)', 'orange', 'Capacity Utilization', True),
    ((1, 1), icl_persistence, 'purple', 'v', '🔄 ICL-Persistence', 'purple', 'Persistence Score', True),
    ((1, 2), weight_delta, 'brown', 'p', '⚖️ Weight-Δ (Change)', 'brown', 'Weight Change', False),
    ((2, 0), rag_trace, 'teal', 'h', '🔍 RAG-Trace (Retrieval)', 'teal', 'RAG Score', True),
]
for position, series, line_fmt, marker, title, title_color, ylabel, unit_range in dial_panels:
    ax = axes[position]
    ax.plot(steps, series, line_fmt, linewidth=2, marker=marker, markersize=4)
    ax.set_title(title, fontweight='bold', color=title_color)
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    if unit_range:
        ax.set_ylim(0, 1)

# Summary statistics
axes[2, 1].axis('off')
summary_text = f"""
📋 Training Summary:
• Steps: {len(steps)}
• Final Loss: {losses[-1]:.6f}
• Device: {device}
• Model Size: {model_size_mb:.1f} MB
• Parameters: {param_count:,}

🎯 Final Dial Readings:
• Mem-Absorption: {mem_absorption[-1]:.3f}
• Cap-Gauge: {cap_gauge[-1]:.3f}
• ICL-Persistence: {icl_persistence[-1]:.3f}
• Weight-Δ: {weight_delta[-1]:.6f}
• TPG: {tpg[-1]:.6f}
• RAG-Trace: {rag_trace[-1]:.3f}
"""
axes[2, 1].text(0.05, 0.95, summary_text, transform=axes[2, 1].transAxes, 
                fontsize=10, verticalalignment='top', 
                bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

# Device info
axes[2, 2].axis('off')
device_info = f"""
💻 System Info:
• Device: {device}
• Mixed Precision: {use_amp}
• PyTorch: {torch.__version__}

📁 Output Locations:
• Metrics: {metrics_dir.name}/
• Plots: notebooks/

🔗 Quick Links:
• JSONL files in metrics/
• Checkpoints in outputs/
• Config: configs/pythia70m.yml
"""
axes[2, 2].text(0.05, 0.95, device_info, transform=axes[2, 2].transAxes,
                fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

plt.tight_layout()

# Save the plot *before* showing it: plt.show() can close the figure in some
# backends. The target folder may not exist when the kernel was started
# outside the usual repository layout, so create it first.
plot_path = project_root / "notebooks" / "chrono_sae_dashboard.png"
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"💾 Dashboard saved to: {plot_path}")

## 6. Examine Saved Metrics (JSONL Files)

In [None]:
# List all JSONL files in metrics directory. Sort them so the listing (and the
# "first 3 files" preview below) is stable from run to run.
jsonl_files = sorted(metrics_dir.glob("*.jsonl"))
print(f"📁 Found {len(jsonl_files)} JSONL metric files:")

for file_path in jsonl_files:
    print(f"   • {file_path.name}")

# Display contents of a few metric files
for file_path in jsonl_files[:3]:  # Show first 3 files
    print(f"\n📊 Contents of {file_path.name}:")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Skip blank lines so json.loads never receives an empty string
            lines = [line for line in f if line.strip()]
        print(f"   Lines: {len(lines)}")

        # Show first few entries (each is expected to carry 'step' and 'value')
        for i, line in enumerate(lines[:3]):
            entry = json.loads(line)
            print(f"   Entry {i}: step={entry['step']}, value={entry['value']:.6f}")

        if len(lines) > 3:
            print(f"   ... and {len(lines) - 3} more entries")

    except Exception as e:
        # Best-effort preview: report the broken file and move on to the next one
        print(f"   ❌ Error reading file: {e}")

In [None]:
# Load and analyze one specific metric file
if jsonl_files:
    # Focus on mem_absorption as an example (first matching file, if any)
    mem_file = next((path for path in jsonl_files if 'mem_absorption' in path.name), None)

    if mem_file:
        print(f"🔍 Detailed analysis of {mem_file.name}:")

        # Load all data (blank lines skipped)
        with open(mem_file, 'r', encoding='utf-8') as f:
            mem_data = [json.loads(line) for line in f if line.strip()]

        # mem_-prefixed names on purpose: a bare `steps` here would overwrite the
        # dashboard's step axis and break re-running the plotting cell.
        mem_values = [d['value'] for d in mem_data]
        mem_steps = [d['step'] for d in mem_data]

        if not mem_values:
            print("   ⚠️ File contains no entries")
        else:
            print("   📊 Statistics:")
            print(f"      • Total entries: {len(mem_values)}")
            print(f"      • Min value: {min(mem_values):.6f}")
            print(f"      • Max value: {max(mem_values):.6f}")
            print(f"      • Mean value: {np.mean(mem_values):.6f}")
            print(f"      • Std deviation: {np.std(mem_values):.6f}")
            print(f"      • Step range: {min(mem_steps)} - {max(mem_steps)}")

            # Show trend between the first and last reading
            if len(mem_values) >= 2:
                delta = mem_values[-1] - mem_values[0]
                if delta > 0:
                    trend = "↗️ increasing"
                elif delta < 0:
                    trend = "↘️ decreasing"
                else:
                    trend = "➡️ flat"
                print(f"      • Trend: {trend} (change: {abs(delta):.6f})")
    else:
        print("🔍 No mem_absorption file found for detailed analysis")
else:
    print("📁 No JSONL files found - metrics may not have been saved")

## 7. Performance Assessment & Next Steps

In [None]:
# Performance summary: recap what ran, how the loss moved, and the final dial readings
print("🏁 CHRONO-SAE QUICKSTART COMPLETE")
print("=" * 50)

print("\n✅ Successfully demonstrated:")
demonstrated = [
    "Environment setup and device detection",
    "ChronoSAE model creation and configuration",
    "Forward pass inference with loss computation",
    f"Training loop with {num_steps} optimization steps",
    "Six-dial metric collection and streaming",
    "Real-time visualization dashboard",
    "JSONL metric persistence and analysis",
]
for item in demonstrated:
    print(f"   • {item}")

print("\n📊 Model Performance:")
initial_loss, final_loss = losses[0], losses[-1]
print(f"   • Initial Loss: {initial_loss:.6f}")
print(f"   • Final Loss: {final_loss:.6f}")
loss_reduction_pct = (initial_loss - final_loss) / initial_loss * 100
print(f"   • Loss Reduction: {loss_reduction_pct:.1f}%")
stability = '✅ Stable' if final_loss < initial_loss else '⚠️ Unstable'
print(f"   • Training Stability: {stability}")

print("\n🎯 Dial Summary:")
dial_summary = {
    'Mem-Absorption': mem_absorption[-1],
    'TPG': tpg[-1],
    'Cap-Gauge': cap_gauge[-1],
    'ICL-Persistence': icl_persistence[-1],
    'Weight-Δ': weight_delta[-1],
    'RAG-Trace': rag_trace[-1]
}

for dial, value in dial_summary.items():
    # green: within [0.1, 0.9]; yellow: any other positive value; red: everything else
    if 0.1 <= value <= 0.9:
        status = "🟢"
    elif value > 0:
        status = "🟡"
    else:
        status = "🔴"
    print(f"   • {dial}: {value:.4f} {status}")

print("\n🔗 Next Steps:")
next_steps = [
    "Experiment with different model configurations",
    "Try training on real transformer activations",
    "Explore checkpoint saving/loading: training/train.py",
    "Scale up with multi-GPU training using torchrun",
    f"Analyze saved metrics: {metrics_dir}/*.jsonl",
    "Visualize longer training runs with TensorBoard",
]
for number, step_text in enumerate(next_steps, start=1):
    print(f"   {number}. {step_text}")

print("\n📁 Generated Files:")
print("   • Dashboard: notebooks/chrono_sae_dashboard.png")
print(f"   • Metrics: {metrics_dir.relative_to(project_root)}/*.jsonl")
print("   • Config: configs/pythia70m.yml")

print("\n🚀 Ready for production training!")

In [None]:
# Optional tidy-up at the end of the run
import gc

# Return cached, currently unused CUDA allocator blocks to the driver (CUDA only)
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.empty_cache()
    print("🧹 GPU memory cleared")

# Force a full collection pass over unreachable Python objects
gc.collect()
print("🧹 Python garbage collection completed")

print("\n✨ Notebook execution complete! ✨")