# WS5: Circuit Analysis Exploration Notebook

This notebook provides an interactive exploration of circuit checkpoint analysis using the tested WS5 pipeline.

## Overview

- **Objective**: Analyze how circuits evolve during fine-tuning on WS2 synthetic constraints
- **Data**: WS3 checkpoints at 25%, 50%, 75%, and 100% training completion
- **Methods**: Circuit comparison, saturation detection, learning pattern extraction

## Setup

In [None]:
# Import required libraries
import sys
import json
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any

# Import WS5 analysis functions (tested pipeline)
from core import (
    load_checkpoint_circuits, generate_circuit_analysis, compare_attribution_graphs,
    detect_saturation, extract_learning_patterns, save_analysis_results
)

# Configure plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)

print("✅ WS5 Analysis Pipeline loaded successfully")
print(f"📅 Notebook started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

## 1. Load and Explore Checkpoints

First, let's load the checkpoint metadata from the WS3 fine-tuning run.

In [None]:
# Define paths
CHECKPOINT_DIR = Path("../ws3/outputs/run_20250719_181803/circuit_checkpoints")
WS2_DATASET_PATH = Path("../data/ws2_synthetic_corpus_hf")
BASE_MODEL_PATH = Path("../models/gemma-2b")

# Load checkpoint information
print("📁 Loading checkpoint information...")
checkpoints = load_checkpoint_circuits(CHECKPOINT_DIR)

print(f"✅ Found {len(checkpoints)} checkpoints")
for name, info in sorted(checkpoints.items(), key=lambda x: x[1].step):
    print(f"   {name}: step {info.step}, progress {info.progress:.1%}, loss {info.loss:.3f}")
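
The cells in this notebook read six attributes from each checkpoint record: `name`, `step`, `progress`, `loss`, `learning_rate`, and `epoch`. The sketch below is a hypothetical stand-in for that record. The class name `CheckpointStub` and all sample values are invented for illustration and are not taken from `core` or from the WS3 run. It may be useful for exercising the plotting cells when no checkpoint files are on disk.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class CheckpointStub:
    """Hypothetical stand-in exposing the attributes this notebook reads."""
    name: str
    step: int
    progress: float       # fraction in [0, 1]; the notebook formats it with :.1%
    loss: float
    learning_rate: float
    epoch: float


def make_stub_checkpoints() -> Dict[str, CheckpointStub]:
    """Build four illustrative checkpoints at 25/50/75/100% progress (made-up values)."""
    rows = [
        ("checkpoint-250", 250, 0.25, 2.10, 2.0e-4, 0.75),
        ("checkpoint-500", 500, 0.50, 1.60, 1.5e-4, 1.50),
        ("checkpoint-750", 750, 0.75, 1.35, 1.0e-4, 2.25),
        ("checkpoint-1000", 1000, 1.00, 1.30, 5.0e-5, 3.00),
    ]
    return {row[0]: CheckpointStub(*row) for row in rows}


stub_checkpoints = make_stub_checkpoints()
for name, info in sorted(stub_checkpoints.items(), key=lambda x: x[1].step):
    print(f"   {name}: step {info.step}, progress {info.progress:.1%}, loss {info.loss:.3f}")
```

Assigning `checkpoints = stub_checkpoints` would let the later cells run on this synthetic data, though any results would of course be meaningless.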

In [None]:
# Create checkpoint progression visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# Extract data for plotting
sorted_checkpoints = sorted(checkpoints.values(), key=lambda x: x.step)
steps = [cp.step for cp in sorted_checkpoints]
progress = [cp.progress for cp in sorted_checkpoints]
losses = [cp.loss for cp in sorted_checkpoints]
learning_rates = [cp.learning_rate for cp in sorted_checkpoints]
epochs = [cp.epoch for cp in sorted_checkpoints]

# Plot training progression
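# progress is stored as a fraction (it is printed with :.1% above), so scale to percent
ax1.plot(steps, [p * 100 for p in progress], 'o-', linewidth=2, markersize=8)
ax1.set_xlabel('Training Step')
ax1.set_ylabel('Progress (%)')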
ax1.set_title('Training Progress')
ax1.grid(True, alpha=0.3)

ax2.plot(steps, losses, 'o-', linewidth=2, markersize=8, color='red')
ax2.set_xlabel('Training Step')
ax2.set_ylabel('Loss')
ax2.set_title('Loss Progression')
ax2.grid(True, alpha=0.3)

ax3.plot(steps, learning_rates, 'o-', linewidth=2, markersize=8, color='green')
ax3.set_xlabel('Training Step')
ax3.set_ylabel('Learning Rate')
ax3.set_title('Learning Rate Schedule')
ax3.grid(True, alpha=0.3)
ax3.ticklabel_format(style='scientific', axis='y', scilimits=(0,0))

ax4.plot(steps, epochs, 'o-', linewidth=2, markersize=8, color='purple')
ax4.set_xlabel('Training Step')
ax4.set_ylabel('Epoch')
ax4.set_title('Epoch Progression')
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle('WS3 Fine-tuning Checkpoint Analysis', y=1.02, fontsize=16)
plt.show()

print(f"📈 Loss improvement: {losses[0] - losses[-1]:.3f}")
print(f"🎯 Final checkpoint: {sorted_checkpoints[-1].name} at step {sorted_checkpoints[-1].step}")

## 2. Prepare Constraint Examples

Load the WS2 synthetic dataset and prepare constraint examples for analysis.

In [None]:
# Load WS2 dataset
try:
    from datasets import load_from_disk
    ws2_dataset = load_from_disk(str(WS2_DATASET_PATH))
    print(f"✅ Loaded WS2 dataset: {len(ws2_dataset)} examples")
except Exception as e:
    print(f"⚠️  Could not load WS2 dataset: {e}")
    ws2_dataset = None

# Extract constraint examples
constraint_examples = {}
test_prompts = []

if ws2_dataset:
    for example in ws2_dataset:
        # Group example texts by their constraint type
        constraint_examples.setdefault(example['constraint_type'], []).append(example['text'])
    
    # Limit to manageable number for analysis
    for constraint_type in constraint_examples:
        constraint_examples[constraint_type] = constraint_examples[constraint_type][:5]
    
    # Create test prompts list
    for examples in constraint_examples.values():
        test_prompts.extend(examples)
    
    print("📋 Constraint Examples:")
    for constraint_type, examples in constraint_examples.items():
        print(f"   {constraint_type}: {len(examples)} examples")
        for i, example in enumerate(examples[:2]):
            print(f"     {i+1}. {example}")
        if len(examples) > 2:
            print(f"     ... and {len(examples)-2} more")
else:
    # Fallback test prompts
    test_prompts = [
        "The blarf cat is happy",
        "The gleem day was sad",
        "The zephyr car goes fast", 
        "The glide bird flies upward",
        "The cascade water falls downward"
    ]
    constraint_examples = {
        'simple_mapping': test_prompts[:3],
        'spatial_relationship': test_prompts[3:]
    }
    print("📝 Using fallback test prompts")

print(f"\n🧪 Total test prompts: {len(test_prompts)}")

## 3. Circuit Analysis Options

Choose an analysis mode based on available resources:

- **Quick Mode**: Metadata analysis only (CPU-friendly)
- **Full Mode**: Load models and run circuit analysis (requires more resources)

For this demonstration, we'll start with Quick Mode and show how to extend to Full Mode.

In [None]:
# Configuration
QUICK_MODE = True  # Set to False for full model loading and circuit analysis
USE_CIRCUIT_TRACER = False  # Set to True if GPU and circuit tracer available

print(f"⚙️  Analysis Mode: {'Quick' if QUICK_MODE else 'Full'}")
print(f"🔬 Circuit Tracer: {'Enabled' if USE_CIRCUIT_TRACER else 'Disabled'}")

if QUICK_MODE:
    print("\n📊 Running Quick Analysis (metadata only)...")
    
    # Analyze checkpoint progression
    analysis_results = {
        'analysis_info': {
            'timestamp': datetime.now().isoformat(),
            'mode': 'quick',
            'num_checkpoints': len(checkpoints)
        },
        'checkpoints': checkpoints,
        'progression_analysis': {
            'loss_trend': losses,
            'loss_improvement': losses[0] - losses[-1],
            'training_steps': steps,
            'learning_rate_decay': learning_rates[0] / learning_rates[-1] if learning_rates[-1] > 0 else float('inf')
        }
    }
    
    print("✅ Quick analysis complete")
    
else:
    print("\n🧠 Running Full Analysis (loading models)...")
    print("⚠️  This may take several minutes and requires significant memory")
    
    # Full analysis would go here
    # See the "Full Analysis Example" section below for implementation
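
The two flags above are set by hand. As a rough sketch, they could instead be derived from simple resource checks. The function name `choose_modes` and its parameters are hypothetical. The 16 GB default mirrors the memory recommendation in Section 5, and the GPU requirement for the circuit tracer follows the comment on `USE_CIRCUIT_TRACER`.

```python
from pathlib import Path
from typing import Tuple


def choose_modes(base_model_path: Path, gpu_available: bool, memory_gb: float,
                 min_memory_gb: float = 16.0) -> Tuple[bool, bool]:
    """Return (quick_mode, use_circuit_tracer) from simple resource checks.

    Full mode is chosen only when the base model exists and enough memory is
    available; the circuit tracer is enabled only in full mode with a GPU.
    """
    can_run_full = base_model_path.exists() and memory_gb >= min_memory_gb
    quick_mode = not can_run_full
    use_circuit_tracer = can_run_full and gpu_available
    return quick_mode, use_circuit_tracer


# A path that does not exist forces quick mode regardless of other resources
missing_model = choose_modes(Path("no_such_model_dir_for_demo"), True, 64.0)
# The current directory exists, so this case qualifies for full mode with tracer
full_with_gpu = choose_modes(Path("."), True, 32.0)
print(missing_model, full_with_gpu)
```

In the notebook, `gpu_available` could come from `torch.cuda.is_available()`. How to measure available memory is left open here.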

## 4. Analysis Results Visualization

Visualize the training progression and any detected patterns.

In [None]:
# Loss analysis
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Loss progression
ax1.plot(steps, losses, 'o-', linewidth=3, markersize=10, label='Training Loss')
ax1.set_xlabel('Training Step')
ax1.set_ylabel('Loss')
ax1.set_title('Fine-tuning Loss Progression')
ax1.grid(True, alpha=0.3)
ax1.legend()

# Add checkpoint labels
for i, (step, loss) in enumerate(zip(steps, losses)):
    checkpoint_name = sorted_checkpoints[i].name.replace('checkpoint-', '')
    ax1.annotate(checkpoint_name, (step, loss), 
                textcoords="offset points", xytext=(0,10), ha='center')

# Loss improvement rate
loss_diffs = []  # stays empty when there is only one checkpoint, so the summary below cannot fail
if len(losses) > 1:
    loss_diffs = [losses[i] - losses[i+1] for i in range(len(losses)-1)]
    
    ax2.bar(range(len(loss_diffs)), loss_diffs, alpha=0.7)
    ax2.set_xlabel('Checkpoint Interval')
    ax2.set_ylabel('Loss Improvement')
    ax2.set_title('Loss Improvement Between Checkpoints')
    ax2.set_xticks(range(len(loss_diffs)))
    ax2.set_xticklabels([f'{i+1}→{i+2}' for i in range(len(loss_diffs))])
    ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary statistics
print("📊 Training Summary:")
print(f"   Total loss improvement: {losses[0] - losses[-1]:.3f}")
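step_span = steps[-1] - steps[0]  # steps between the first and last checkpoint
if step_span > 0:
    print(f"   Average loss improvement per step: {(losses[0] - losses[-1]) / step_span:.5f}")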
if len(loss_diffs) > 0:
    print(f"   Biggest improvement interval: {np.argmax(loss_diffs)+1}→{np.argmax(loss_diffs)+2}")
    print(f"   Improvement magnitude: {max(loss_diffs):.3f}")
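
The bar chart above compares raw loss drops, which is only a fair comparison when checkpoints are evenly spaced. A minimal sketch of normalising each drop by the length of its interval follows. The helper name `improvement_per_step` and the sample numbers are illustrative and not part of the WS5 pipeline.

```python
from typing import List, Sequence


def improvement_per_step(steps: Sequence[int], losses: Sequence[float]) -> List[float]:
    """Loss drop per training step for each consecutive checkpoint interval."""
    if len(steps) != len(losses):
        raise ValueError("steps and losses must have the same length")
    rates = []
    for i in range(len(steps) - 1):
        span = steps[i + 1] - steps[i]
        if span <= 0:
            raise ValueError("steps must be strictly increasing")
        rates.append((losses[i] - losses[i + 1]) / span)
    return rates


# Both intervals drop the loss by 0.5, but the second interval is twice as long
example_rates = improvement_per_step([100, 200, 400], [2.0, 1.5, 1.0])
print(example_rates)
```

With the notebook's own variables the call would be `improvement_per_step(steps, losses)`.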

## 5. Full Analysis Example (Optional)

This section shows how to run the complete analysis pipeline when resources are available.

**Note**: Uncomment and run this section only if you have:
- Sufficient memory (16GB+ recommended)
- GPU access (for circuit tracer)
- Time for model loading (~5-10 minutes per checkpoint)

In [None]:
# # Full Analysis Example - Uncomment to run
# if not QUICK_MODE and BASE_MODEL_PATH.exists():
#     print("🚀 Starting full circuit analysis...")
#     
#     analyses = []
#     
#     for name, checkpoint_info in sorted(checkpoints.items(), key=lambda x: x[1].step):
#         print(f"\n🔄 Analyzing {name}...")
#         
#         try:
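#             # Load model with adapter
#             # NOTE: load_model_for_analysis is not imported in the Setup cell;
#             # import it before running (assumed here to be exported by core).
#             model = load_model_for_analysis(checkpoint_info, BASE_MODEL_PATH)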
#             
#             # Generate circuit analysis
#             analysis = generate_circuit_analysis(
#                 model, test_prompts, checkpoint_info, USE_CIRCUIT_TRACER
#             )
#             analyses.append(analysis)
#             
#             print(f"✅ {name} analysis complete")
#             
#             # Clean up to save memory
#             del model
#             if torch.cuda.is_available():
#                 torch.cuda.empty_cache()
#             
#         except Exception as e:
#             print(f"❌ Failed to analyze {name}: {e}")
#             continue
#     
#     print(f"\n✅ Completed {len(analyses)} analyses")
#     
#     # Run comparisons
#     # Defaults keep the results dict well-defined if fewer than two analyses succeed
#     comparisons, saturation_result, learning_patterns = [], {}, {}
#     if len(analyses) >= 2:
#         print("\n🔄 Computing checkpoint comparisons...")
#         
#         for i in range(len(analyses) - 1):
#             comparison = compare_attribution_graphs(analyses[i], analyses[i + 1])
#             comparisons.append(comparison)
#             print(f"   {comparison.checkpoint_1} → {comparison.checkpoint_2}: "
#                   f"similarity {comparison.similarity_score:.3f}")
#         
#         # Detect saturation
#         print("\n📈 Detecting saturation...")
#         saturation_result = detect_saturation(analyses)
#         
#         if saturation_result['saturated']:
#             print(f"🎯 Saturation detected at step {saturation_result['saturation_step']}")
#         else:
#             print("📊 No saturation detected")
#         
#         # Extract learning patterns
#         if constraint_examples:
#             print("\n🧬 Extracting learning patterns...")
#             learning_patterns = extract_learning_patterns(analyses, constraint_examples)
#             
#             if learning_patterns.get('learning_order'):
#                 print("📚 Learning order:")
#                 for constraint, improvement in learning_patterns['learning_order']:
#                     print(f"   {constraint}: {improvement:.3f} improvement")
#     
#     # Store full results (also when comparisons could not be computed)
#     analysis_results = {
#         'analysis_info': {
#             'timestamp': datetime.now().isoformat(),
#             'mode': 'full',
#             'num_checkpoints': len(checkpoints),
#             'use_circuit_tracer': USE_CIRCUIT_TRACER
#         },
#         'checkpoints': checkpoints,
#         'analyses': analyses,
#         'comparisons': comparisons,
#         'saturation': saturation_result,
#         'learning_patterns': learning_patterns
#     }
# 
# else:
#     print("⏭️  Skipping full analysis (quick mode enabled or base model not found)")

print("💡 To run full analysis, uncomment the block above, set QUICK_MODE = False, and ensure the base model is available")
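
The commented block reports a `similarity_score` for each pair of consecutive checkpoints, but this notebook does not show how `compare_attribution_graphs` computes it. Purely as an illustration of what such a score could measure, the sketch below uses Jaccard overlap between two sets of graph edges. The function name, the edge representation, and the node labels are all assumptions, not the pipeline's metric.

```python
from typing import Hashable, Set, Tuple

Edge = Tuple[Hashable, Hashable]  # (source node, target node)


def jaccard_edge_similarity(edges_a: Set[Edge], edges_b: Set[Edge]) -> float:
    """Share of edges common to both graphs, relative to all edges in either."""
    if not edges_a and not edges_b:
        return 1.0  # two empty graphs are treated as identical
    return len(edges_a & edges_b) / len(edges_a | edges_b)


# Made-up graphs: the later checkpoint reroutes through an extra head
graph_early = {("emb", "L0H1"), ("L0H1", "L1H3"), ("L1H3", "logit")}
graph_late = {("emb", "L0H1"), ("L0H1", "L1H3"), ("L1H3", "L2H0"), ("L2H0", "logit")}

similarity = jaccard_edge_similarity(graph_early, graph_late)
print(f"similarity {similarity:.3f}")
```

A set-based score like this ignores edge weights. A weighted variant would be needed if attribution magnitudes matter.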

## 6. Save Results and Summary

Save the analysis results and generate a summary report.

In [None]:
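# Save results (analysis_results exists only after the Quick Mode cell or the Full Analysis block has run)
if 'analysis_results' not in globals():
    raise RuntimeError("No analysis results found: run the Quick Mode cell or the Full Analysis block first")
output_file = f"analysis_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
save_analysis_results(analysis_results, output_file)
print(f"💾 Results saved to {output_file}")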

# Generate summary
print("\n" + "="*60)
print("🎯 WS5 CIRCUIT ANALYSIS SUMMARY")
print("="*60)

print(f"📅 Analysis completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"📁 Checkpoint directory: {CHECKPOINT_DIR}")
print(f"🔢 Checkpoints analyzed: {len(checkpoints)}")
print(f"⚙️  Analysis mode: {analysis_results['analysis_info']['mode']}")

if 'progression_analysis' in analysis_results:
    prog = analysis_results['progression_analysis']
    print(f"📉 Loss improvement: {prog['loss_improvement']:.3f}")
    print(f"🎯 Final step: {prog['training_steps'][-1]}")

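if 'saturation' not in analysis_results:
    # Quick Mode never runs saturation detection, so do not report a negative result
    print("⏭️  Saturation not evaluated (requires full analysis)")
elif analysis_results['saturation'].get('saturated'):
    print(f"🔄 Saturation detected at step: {analysis_results['saturation']['saturation_step']}")
else:
    print("📈 No saturation detected")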

if 'learning_patterns' in analysis_results and analysis_results['learning_patterns'].get('learning_order'):
    print("📚 Learning order detected:")
    for constraint, improvement in analysis_results['learning_patterns']['learning_order']:
        print(f"   {constraint}: {improvement:.3f}")

print("\n✅ Analysis complete! Use the CLI for additional analysis options.")
print(f"💡 Try: python cli.py analyze -c {CHECKPOINT_DIR} --quick")
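
In Quick Mode the results dictionary stores the checkpoint records directly under `'checkpoints'`, and how `save_analysis_results` serialises them is not shown here. If results ever need to be written without that helper, one possible approach is the `default=` hook of `json.dumps`. The sketch below assumes the records are dataclasses. `DemoCheckpoint`, `to_jsonable`, and the sample values are hypothetical.

```python
import json
from dataclasses import asdict, dataclass, is_dataclass
from datetime import datetime
from pathlib import Path


def to_jsonable(obj):
    """Fallback for json.dumps: convert types the json module cannot encode."""
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)          # dataclass instance -> plain dict
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError(f"Cannot serialise object of type {type(obj).__name__}")


@dataclass
class DemoCheckpoint:
    """Hypothetical record; the real checkpoint type may differ."""
    name: str
    step: int
    loss: float


demo_results = {
    "analysis_info": {"mode": "quick", "checkpoint_dir": Path("circuit_checkpoints")},
    "checkpoints": {"checkpoint-100": DemoCheckpoint("checkpoint-100", 100, 1.25)},
}
demo_json = json.dumps(demo_results, default=to_jsonable, indent=2)
print(demo_json)
```

Raising `TypeError` for unknown types keeps unexpected objects from being silently dropped from the saved file.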

## 7. Next Steps and Extensions

This notebook demonstrates the WS5 analysis pipeline. Here are suggested next steps:

### Immediate Actions
1. **Run Full Analysis**: Set `QUICK_MODE = False`, uncomment the Section 5 cell, and run with GPU access for detailed circuit analysis
2. **Circuit Tracer Integration**: Enable `USE_CIRCUIT_TRACER = True` for detailed attribution graphs
3. **Custom Constraints**: Modify `constraint_examples` to test specific learning hypotheses

### Extended Analysis
1. **Saturation Studies**: Experiment with different saturation detection thresholds
2. **Learning Rate Effects**: Compare results across different learning rate schedules
3. **Constraint Complexity**: Analyze learning patterns for different constraint complexities
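
To get a feel for how a threshold changes the outcome before touching `detect_saturation`, a simple loss-based heuristic can be tried on checkpoint metadata alone. The sketch below is not the pipeline's method. The function name, the relative-improvement criterion, and the sample numbers are all assumptions made for illustration.

```python
from typing import Optional, Sequence


def first_saturated_step(steps: Sequence[int], losses: Sequence[float],
                         rel_threshold: float = 0.01) -> Optional[int]:
    """Return the first step whose relative loss improvement over the previous
    checkpoint falls below rel_threshold, or None if that never happens."""
    for i in range(1, len(losses)):
        prev = losses[i - 1]
        if prev <= 0:
            continue  # relative improvement is undefined for non-positive loss
        rel_improvement = (prev - losses[i]) / prev
        if rel_improvement < rel_threshold:
            return steps[i]
    return None


demo_steps = [250, 500, 750, 1000]
demo_losses = [2.0, 1.5, 1.2, 1.195]  # made-up values that flatten at the end

for threshold in (0.001, 0.01, 0.22):
    print(threshold, first_saturated_step(demo_steps, demo_losses, threshold))
```

A looser threshold flags saturation earlier, so the chosen value should be reported alongside any saturation claim.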

### Visualization Enhancements
1. **Interactive Plots**: Add widget controls for parameter exploration
2. **Circuit Graphs**: Visualize attribution graphs when available
3. **Comparative Analysis**: Side-by-side comparison of different training runs

### Integration
1. **WS1 Circuit Tracer**: Full integration with circuit tracer for detailed analysis
2. **WS3 Pipeline**: Automated analysis as part of fine-tuning workflow
3. **Reporting**: Automated report generation for multiple experiments

---

**WS5 Analysis Pipeline Complete**

This notebook walks through checkpoint metadata analysis in Quick Mode and outlines the Full Mode pipeline for circuit comparison, saturation detection, and learning pattern extraction. Quick Mode runs on CPU only; Full Mode needs the base model and, for the circuit tracer, a GPU.

*Generated by WS5: Circuit Analysis Pipeline*