# Head Specialization Analysis - Complete Workflow

This notebook demonstrates the complete workflow for analyzing attention head specialization:
1. Extract attention weights + semantic features from checkpoints
2. Compute Head Specialization Index (HSI)
3. Visualize results and interpret head functions

**Prerequisites:**
- Trained PPO model with Wayformer encoder
- Dataset for scenario sampling

## Setup

In [None]:
import os
import sys
import pickle
import numpy as np
import matplotlib.pyplot as plt

# Add project root to path
project_root = os.path.abspath("../..")
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from xai.attention_analysis.head_specialization_analysis import HeadSpecializationAnalyzer

print("✓ Setup complete")

## Configuration

In [None]:
# Paths
RUN_DIR = "../../runs/PPO_VEC_WAYFORMER"  # Your training run directory
DATASET_PATH = "../../training.tfrecord"  # Dataset for scenario sampling
EXTRACTION_DIR = "./extractions"  # Where to save/load extraction results
ANALYSIS_DIR = "./hsi_results"  # Where to save analysis results

# Parameters
N_SCENARIOS = 100  # Number of scenarios to analyze
CHECKPOINTS = ["model_final.pkl"]  # Can add multiple: ["model_10000.pkl", "model_50000.pkl", ...]

os.makedirs(EXTRACTION_DIR, exist_ok=True)
os.makedirs(ANALYSIS_DIR, exist_ok=True)

print(f"Run directory: {RUN_DIR}")
print(f"Dataset: {DATASET_PATH}")
print(f"Analyzing {len(CHECKPOINTS)} checkpoint(s) with {N_SCENARIOS} scenarios each")

---
## Step 1: Extract Attention + Semantic Features

Run `offline_extraction.py` to extract both attention weights and semantic features from scenarios.

This step:
- Loads the trained model checkpoint
- Samples scenarios from the dataset  
- Extracts attention weights from the Wayformer encoder
- Computes semantic features (TTC, distance, closing speed, etc.)
- Aggregates attention per vehicle per head
- Saves everything to a unified `.pkl` file
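
The per-vehicle aggregation can be pictured with a small NumPy sketch. It assumes, purely for illustration, that each head produces one attention weight per input token and that a hypothetical `token_to_vehicle` array maps every token to a vehicle index (`-1` for non-vehicle tokens). The function name and the token layout are assumptions; the actual logic in `offline_extraction.py` may differ.

```python
import numpy as np

def aggregate_attention_per_vehicle(attn, token_to_vehicle, n_vehicles):
    """Sum per-token attention into per-vehicle totals for each head.

    attn:             (n_heads, n_tokens) attention weights (assumed layout)
    token_to_vehicle: (n_tokens,) vehicle index per token, -1 for non-vehicle tokens
    returns:          (n_heads, n_vehicles) aggregated attention
    """
    attn = np.asarray(attn, dtype=float)
    token_to_vehicle = np.asarray(token_to_vehicle)
    out = np.zeros((attn.shape[0], n_vehicles))
    for v in range(n_vehicles):
        mask = token_to_vehicle == v          # tokens belonging to vehicle v
        out[:, v] = attn[:, mask].sum(axis=1)  # total attention each head gives to v
    return out

# Toy input: 2 heads, 4 tokens; tokens 0-1 belong to vehicle 0, token 2 to vehicle 1,
# and token 3 is a non-vehicle token that is ignored.
attn = np.array([[0.1, 0.2, 0.3, 0.4],
                 [0.4, 0.3, 0.2, 0.1]])
token_to_vehicle = np.array([0, 0, 1, -1])
per_vehicle = aggregate_attention_per_vehicle(attn, token_to_vehicle, n_vehicles=2)
print(per_vehicle.shape)
```

Summing is only one possible reduction; a mean or max over a vehicle's tokens would fit the same interface.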

In [None]:
# Option 1: Run via command line (recommended for large extractions)
# This cell only prints the command; copy it into a terminal to run it.

cmd = f"""python offline_extraction.py \\
    --run_dir {RUN_DIR} \\
    --dataset {DATASET_PATH} \\
    --n_scenarios {N_SCENARIOS} \\
    --output_dir {EXTRACTION_DIR} \\
    --checkpoints {' '.join(CHECKPOINTS)}
"""

print("Run this command in terminal:")
print(cmd)
print("\nOr uncomment the next cell to run from notebook...")

In [None]:
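# Option 2: Run from notebook (may be slower, but shows progress)
# Paths are made absolute because the command runs from the project root after `cd`.
# !cd ../../ && python xai/attention_analysis/offline_extraction.py \
#     --run_dir {os.path.abspath(RUN_DIR)} \
#     --dataset {os.path.abspath(DATASET_PATH)} \
#     --n_scenarios {N_SCENARIOS} \
#     --output_dir {os.path.abspath(EXTRACTION_DIR)} \
#     --checkpoints {' '.join(CHECKPOINTS)}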

### Inspect Extraction Output

In [None]:
import glob

# Find extraction files
extraction_files = sorted(glob.glob(os.path.join(EXTRACTION_DIR, "extraction_*.pkl")))

if not extraction_files:
    print("⚠ No extraction files found. Please run Step 1 first.")
else:
    print(f"Found {len(extraction_files)} extraction file(s):")
    for f in extraction_files:
        size_mb = os.path.getsize(f) / 1024 / 1024
        print(f"  - {os.path.basename(f)} ({size_mb:.2f} MB)")
    
    # Load and inspect the first one
    with open(extraction_files[0], 'rb') as f:
        data = pickle.load(f)
    
    print(f"\n📊 Extraction Summary:")
    print(f"  Checkpoint: {data['checkpoint']}")
    print(f"  Training Step: {data['step']}")
    print(f"  Number of Scenarios: {data['n_scenarios']}")
    
    # Inspect first scenario
    scenario = data['scenarios'][0]
    print(f"\n📋 Per-Scenario Data:")
    print(f"  Attention per vehicle shape: {scenario['attention_per_vehicle'].shape}")
    print(f"  Semantic features: {list(scenario['semantic_features'].keys())}")
    
    # Show sample feature values
    print(f"\n📈 Example Features (Scenario 0):")
    for key in ['distance_to_ego', 'ttc', 'closing_speed']:
        if key in scenario['semantic_features']:
            vals = scenario['semantic_features'][key]
            print(f"  {key}: {vals[:3]}... (shape: {vals.shape})")

---
## Step 2: Compute Head Specialization Index (HSI)

The HSI analysis computes correlations between each attention head and semantic features:
- For each head $h$ and feature $f$: $\rho_{h,f} = \text{corr}(\text{attention}_h, f)$
- HSI Score: $\text{HSI}_h = \max_f |\rho_{h,f}|$
- Primary Function: The feature with the strongest correlation

Heads with an HSI at or above 0.3 are considered "specialized".
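
The definition above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the analyzer's actual implementation: it uses Pearson correlation via `np.corrcoef` (the notebook does not say which correlation `HeadSpecializationAnalyzer` uses), it drops non-finite feature values such as an infinite TTC, and the function name `compute_hsi_sketch` and the synthetic data are hypothetical.

```python
import numpy as np

def compute_hsi_sketch(attention, features):
    """Correlate each head's attention with each feature and derive HSI.

    attention: (n_samples, n_heads) attention per (scenario, vehicle) sample
    features:  dict mapping feature name -> (n_samples,) values
    returns:   rho (n_heads, n_features), hsi (n_heads,), primary feature names
    """
    attention = np.asarray(attention, dtype=float)
    names = list(features)
    n_heads = attention.shape[1]
    rho = np.zeros((n_heads, len(names)))
    for j, name in enumerate(names):
        f = np.asarray(features[name], dtype=float)
        valid = np.isfinite(f)            # assumption: skip inf/NaN feature values
        if valid.sum() < 2:
            continue                      # not enough samples to correlate
        fv = f[valid]
        for h in range(n_heads):
            a = attention[valid, h]
            if a.std() == 0 or fv.std() == 0:
                continue                  # correlation undefined; leave as 0
            rho[h, j] = np.corrcoef(a, fv)[0, 1]   # Pearson correlation (assumed)
    abs_rho = np.abs(rho)
    hsi = abs_rho.max(axis=1)                             # HSI_h = max_f |rho_{h,f}|
    primary = [names[k] for k in abs_rho.argmax(axis=1)]  # strongest feature per head
    return rho, hsi, primary

# Synthetic data: head 0 attends more to nearby vehicles, head 1 is noise.
rng = np.random.default_rng(0)
n = 200
distance = rng.uniform(1.0, 50.0, n)
ttc = rng.uniform(0.5, 10.0, n)
ttc[::10] = np.inf                        # some vehicles are not on a collision course
attention = np.column_stack([
    -distance + rng.normal(0.0, 1.0, n),
    rng.normal(0.0, 1.0, n),
])
rho, hsi, primary = compute_hsi_sketch(
    attention, {"distance_to_ego": distance, "ttc": ttc}
)
print(primary[0], round(float(hsi[0]), 2))
```

A rank correlation such as Spearman's would be a reasonable alternative when the relationship is monotonic but not linear.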

In [None]:
# Create analyzer
analyzer = HeadSpecializationAnalyzer(extraction_dir=EXTRACTION_DIR)

# Load extraction data
analyzer.load_data()  # Loads first .pkl file by default

# Compute HSI
results = analyzer.compute_hsi()

print("\n✓ HSI computation complete!")

### Inspect Results

In [None]:
print("="*60)
print("HEAD SPECIALIZATION SUMMARY")
print("="*60)

for head_idx, label in results.head_labels.items():
    specialized = "✓" if label['hsi'] >= analyzer.HSI_THRESHOLD else "✗"
    print(f"\n{specialized} Head {head_idx}: {label['name']}")
    print(f"   HSI Score: {label['hsi']:.3f}")
    if label.get('primary_feature'):
        print(f"   Primary Feature: {label['primary_feature']}")
        print(f"   Correlation: {label['correlation']:.3f} ({label['correlation_sign']})")
        print(f"   Description: {label['description']}")

print("\n" + "="*60)

---
## Step 3: Visualizations

### 3.1 Correlation Heatmap

Shows correlations between all heads and all semantic features.

In [None]:
from xai.attention_analysis.head_specialization_analysis import HeadVisualization

viz = HeadVisualization(analyzer)

# Generate heatmap
fig = viz.plot_correlation_heatmap()
plt.show()

### 3.2 HSI Bar Chart

Shows specialization scores for each head.

In [None]:
fig = viz.plot_hsi_bar()
plt.show()

### 3.3 Scatter Plots for Specialized Heads

For each specialized head, plot attention vs. its primary feature.

In [None]:
for head_idx, label in results.head_labels.items():
    if label.get('primary_feature'):
        print(f"\n📊 Head {head_idx}: {label['name']}")
        fig = viz.plot_attention_vs_feature(
            head_idx=head_idx,
            feature_name=label['primary_feature']
        )
        plt.show()

---
## Step 4: Export Results

In [None]:
# Export head registry to JSON
registry_path = os.path.join(ANALYSIS_DIR, "head_registry.json")
analyzer.export_registry(registry_path)

# Save all visualizations
viz_dir = os.path.join(ANALYSIS_DIR, "visualizations")
analyzer.visualize_all(viz_dir)

print(f"\n✓ Results saved to: {ANALYSIS_DIR}")
print(f"  - Registry: {registry_path}")
print(f"  - Visualizations: {viz_dir}")

---
## Step 5: Evolution Analysis (Optional)

If you extracted multiple checkpoints, analyze how head specialization evolves during training.

In [None]:
if len(extraction_files) > 1:
    print("📈 Evolution Analysis\n")
    
    evolution_data = []
    
    for extraction_file in extraction_files:
        # Create fresh analyzer for each checkpoint
        analyzer_tmp = HeadSpecializationAnalyzer(extraction_dir=EXTRACTION_DIR)
        analyzer_tmp.load_data(extraction_file=os.path.basename(extraction_file))
        result = analyzer_tmp.compute_hsi()
        
        evolution_data.append({
            'step': result.checkpoint_step,
            'hsi_scores': result.hsi_scores,
            'primary_features': result.primary_features
        })
    
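    # Order checkpoints by training step (file names sort lexicographically, not numerically)
    evolution_data.sort(key=lambda d: d['step'])

    # Plot HSI evolution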
    n_heads = len(evolution_data[0]['hsi_scores'])
    
    fig, ax = plt.subplots(figsize=(12, 6))
    
    for h in range(n_heads):
        steps = [d['step'] for d in evolution_data]
        hsi_values = [d['hsi_scores'][h] for d in evolution_data]
        ax.plot(steps, hsi_values, marker='o', label=f'Head {h}')
    
    ax.axhline(y=analyzer.HSI_THRESHOLD, color='red', linestyle='--', label='Specialization Threshold')
    ax.set_xlabel('Training Step')
    ax.set_ylabel('HSI Score')
    ax.set_title('Head Specialization Evolution During Training')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\n✓ Evolution analysis complete!")
else:
    print("ℹ Only one checkpoint found. Extract multiple checkpoints to see evolution.")
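
To pinpoint when specialization emerges, one option is to find the first checkpoint at which each head's HSI reaches the threshold. The sketch below assumes a list of dictionaries shaped like `evolution_data` above, with `step` and `hsi_scores` keys. The helper name `first_specialized_step`, the default threshold of 0.3, and the toy values are illustrative assumptions.

```python
def first_specialized_step(evolution, threshold=0.3):
    """Return {head_index: first training step with HSI >= threshold, or None}."""
    # Sort by step so the result does not depend on file-name ordering.
    ordered = sorted(evolution, key=lambda d: d["step"])
    n_heads = len(ordered[0]["hsi_scores"])
    first = {h: None for h in range(n_heads)}
    for entry in ordered:
        for h in range(n_heads):
            if first[h] is None and entry["hsi_scores"][h] >= threshold:
                first[h] = entry["step"]
    return first

# Made-up HSI values for two heads across three checkpoints (deliberately unsorted).
toy_evolution = [
    {"step": 50000, "hsi_scores": [0.45, 0.20]},
    {"step": 10000, "hsi_scores": [0.10, 0.05]},
    {"step": 100000, "hsi_scores": [0.60, 0.35]},
]
emergence = first_specialized_step(toy_evolution)
print(emergence)  # {0: 50000, 1: 100000}
```

A head that never reaches the threshold maps to `None`, which makes it easy to list heads that stay general throughout training.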

---
## Interpretation Guide

### Understanding HSI Scores

**HSI (Head Specialization Index)** measures how strongly a head correlates with semantic features:
- **HSI ≥ 0.5**: Strong specialization (head has a clear function)
- **0.3 ≤ HSI < 0.5**: Moderate specialization
- **HSI < 0.3**: General attention (no clear specialization)
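
These bands translate directly into a small helper. The sketch below places boundary values in the higher band, which matches the `>=` comparison against `HSI_THRESHOLD` in the summary cell for the 0.3 boundary; treating exactly 0.5 as "strong" is an assumption, and the function name `hsi_band` is hypothetical.

```python
def hsi_band(hsi, moderate=0.3, strong=0.5):
    """Map an HSI score to the interpretation bands used in this guide."""
    if hsi >= strong:
        return "strong"
    if hsi >= moderate:
        return "moderate"
    return "general"

# Illustrative scores only.
bands = [hsi_band(score) for score in (0.62, 0.30, 0.12)]
print(bands)  # ['strong', 'moderate', 'general']
```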

### Common Head Functions

1. **Safety Head** (correlates with TTC)
   - Focuses on vehicles with low time-to-collision
   - Critical for collision avoidance
   
2. **Proximity Head** (correlates with distance)
   - Attends to nearby vehicles
   - Important for lane changing and merging
   
3. **Traffic Flow Head** (correlates with is_ahead)
   - Monitors vehicles ahead in the driving direction
   - Essential for speed regulation
   
4. **Lane Monitoring Heads** (correlate with is_left/is_right)
   - Track specific lane zones
   - Useful for lane-keeping and blind spot awareness

### What to Look For

✅ **Healthy Model**:
- Multiple specialized heads (HSI > 0.3)
- At least one Safety Head (TTC correlation)
- Diverse primary features across heads

⚠️ **Potential Issues**:
- All heads are general (HSI < 0.3) → Model may be under-trained
- No Safety Head → Model might not be safety-conscious
- All heads have the same primary feature → Redundant attention
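
The checklist can be automated with a short helper. The sketch assumes a dictionary shaped like `results.head_labels`, where each entry has an `hsi` score and an optional `primary_feature`, and it treats `ttc` as the feature that marks a Safety Head. The function name `check_head_health`, the default threshold, and the example labels are hypothetical.

```python
def check_head_health(head_labels, threshold=0.3, safety_feature="ttc"):
    """Return a list of warnings based on the checklist above."""
    specialized = {
        h: label for h, label in head_labels.items() if label["hsi"] >= threshold
    }
    features = [
        label["primary_feature"]
        for label in specialized.values()
        if label.get("primary_feature")
    ]
    warnings = []
    if not specialized:
        warnings.append("all heads are general")
        return warnings
    if safety_feature not in features:
        warnings.append("no safety head")
    if len(specialized) > 1 and len(set(features)) == 1:
        warnings.append("all specialized heads share one primary feature")
    return warnings

# Made-up labels: two heads both track distance, one head is general.
example_labels = {
    0: {"hsi": 0.55, "primary_feature": "distance_to_ego"},
    1: {"hsi": 0.40, "primary_feature": "distance_to_ego"},
    2: {"hsi": 0.10, "primary_feature": None},
}
issues = check_head_health(example_labels)
print(issues)
```

An empty list means none of the listed issues were detected; it does not by itself show that the model is safe.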

### Next Steps

1. **If specialization is weak**: Train longer or adjust network architecture
2. **If missing critical functions**: Adjust reward function or training scenarios
3. **For evolution analysis**: Compare early vs. late training to see when specialization emerges

---
## Summary

This notebook demonstrated:
1. ✅ Extracting unified attention + semantic data
2. ✅ Computing Head Specialization Index (HSI)
3. ✅ Visualizing head functions
4. ✅ Interpreting results
5. ✅ (Optional) Analyzing evolution across checkpoints

**Files Generated:**
- `extractions/extraction_*.pkl`: Raw extraction data
- `hsi_results/head_registry.json`: Head function labels
- `hsi_results/visualizations/`: All plots

For more details, see:
- `offline_extraction.py`: Extraction implementation
- `head_specialization_analysis.py`: HSI computation and visualization