# Head Specialization Analysis

This notebook analyzes attention-head behavior to identify the **functional specialization** that emerges during training.

**Goal**: Determine what each attention head has learned to focus on by correlating its attention patterns with semantic driving features (TTC, distance, spatial relationships).

## Key Outputs
1. **HSI Scores** - Head Specialization Index for each head
2. **Functional Labels** - Human-readable labels (Safety Head, Lane Head, etc.)
3. **Visualizations** - Heatmaps, scatter plots, evolution charts
4. **Head Function Registry** - JSON file for real-time XAI explanations

## Data Sources
This analysis requires data from **two directories**:
1. **Feature Dir** (`runs/{model}/attention_logs/`): Training logs with semantic features, written by `AttentionLogger` during training
2. **Attention Dir** (`xai/attention_analysis/attention_extractions/`): Attention weights extracted offline by `offline_attention_extraction.ipynb`

## 1. Configuration

In [None]:
# ============================================================================
# CONFIGURATION - Update these paths for your setup
# ============================================================================

# Path to training logs (semantic features + metadata)
# These are generated by AttentionLogger during training
FEATURE_DIR = "../../runs/PPO_VEC_WAYFORMER/attention_logs/"  # UPDATE THIS

# Path to offline extracted attention weights
# These are generated by offline_attention_extraction.ipynb
ATTENTION_DIR = "./attention_extractions/"  # UPDATE THIS

# Output directory for results
OUTPUT_DIR = "./hsi_results/"

# Optional: Filter logs by training step range
STEP_RANGE = None  # Or set to (start, end) e.g., (10000, 50000)
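
Before loading anything, it can help to confirm that both directories exist and actually contain files. The helper below is a minimal sketch: `check_data_dir` and `report_data_dirs` are hypothetical names that are not part of `head_specialization_analysis`, and the check makes no assumption about the file extensions the loader expects.

```python
from pathlib import Path


def check_data_dir(path):
    """Return the number of regular files directly inside `path`.

    Returns -1 if `path` does not exist or is not a directory.
    """
    p = Path(path)
    if not p.is_dir():
        return -1
    return sum(1 for entry in p.iterdir() if entry.is_file())


def report_data_dirs(dirs):
    """Map each label to its file count (-1 means the directory is missing)."""
    return {label: check_data_dir(path) for label, path in dirs.items()}
```

Calling `report_data_dirs({"FEATURE_DIR": FEATURE_DIR, "ATTENTION_DIR": ATTENTION_DIR})` after the configuration cell would flag a wrong path (count of -1) or an empty directory (count of 0) before the loader runs.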

## 2. Setup

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Import our analysis module
from head_specialization_analysis import (
    AttentionLogLoader,
    HeadSpecializationAnalyzer,
    HeadVisualization,
    FEATURE_TO_LABEL
)

# Set up matplotlib
%matplotlib inline
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

print("✓ Setup complete")

## 3. Load Attention Logs

The loader will merge:
- **Semantic features** from training logs (TTC, distance, spatial relationships)
- **Attention weights** from offline extractions
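
One way to picture the merge is as a join on training step. The sketch below assumes both sources can be represented as dictionaries keyed by step and that `step_range` is inclusive on both ends. Both points are assumptions for illustration; the actual matching logic lives in `head_specialization_analysis` and may differ. `merge_by_step` is a hypothetical name.

```python
def merge_by_step(features_by_step, attention_by_step, step_range=None):
    """Inner-join two {step: payload} mappings on their shared steps.

    step_range: optional (start, end) tuple, treated as inclusive here.
    Returns a list of merged records sorted by step.
    """
    # Keep only steps that are present in both sources.
    common = sorted(set(features_by_step) & set(attention_by_step))
    if step_range is not None:
        start, end = step_range
        common = [s for s in common if start <= s <= end]
    return [
        {
            "step": s,
            "semantic_features": features_by_step[s],
            "attention_weights": attention_by_step[s],
        }
        for s in common
    ]
```

With an inner join, a step that appears in only one directory is dropped silently, so comparing the merged count against the size of each source is a quick way to spot a mismatch between the two directories.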

In [None]:
# Create analyzer with both directories
analyzer = HeadSpecializationAnalyzer(
    feature_dir=FEATURE_DIR,
    attention_dir=ATTENTION_DIR
)

# Load and merge data
analyzer.load_data(step_range=STEP_RANGE)

print(f"\nLoaded {len(analyzer.logs)} attention logs")
if analyzer.logs:
    print(f"Step range: {analyzer.logs[0].step} → {analyzer.logs[-1].step}")

In [None]:
# Inspect first log structure
if analyzer.logs:
    sample_log = analyzer.logs[0]
    
    print("=" * 50)
    print("ATTENTION LOG STRUCTURE")
    print("=" * 50)
    print(f"\nStep: {sample_log.step}")
    
    print("\nAttention weight keys:")
    for key, val in sample_log.attention_weights.items():
        print(f"  {key}: shape={val.shape}")
    
    print("\nSemantic features:")
    for key, val in sample_log.semantic_features.items():
        if isinstance(val, np.ndarray):
            print(f"  {key}: shape={val.shape}")
        else:
            print(f"  {key}: {val}")
    
    print("\nToken boundaries:")
    for key, val in sample_log.token_boundaries.items():
        print(f"  {key}: {val}")
    
    print("\nConfig:")
    for key, val in sample_log.config.items():
        print(f"  {key}: {val}")
else:
    print("⚠ No logs loaded. Check FEATURE_DIR and ATTENTION_DIR paths.")

## 4. Aggregate Attention by Vehicle

In [None]:
# Aggregate attention weights by vehicle
attention_per_vehicle = analyzer.aggregate_attention_by_vehicle()

print(f"\nAggregated attention shape: {attention_per_vehicle.shape}")
print("  (N_scenarios, N_heads, N_vehicles)")
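
To make the `(N_scenarios, N_heads, N_vehicles)` shape concrete, the sketch below shows one plausible way token-level attention could be pooled per vehicle. It assumes each scenario provides an `(n_heads, n_tokens)` array and that each vehicle owns a contiguous, half-open token span; both are assumptions, as are the name `aggregate_by_vehicle` and the choice of summation. The real `aggregate_attention_by_vehicle()` may pool differently.

```python
import numpy as np


def aggregate_by_vehicle(attn, vehicle_spans):
    """Sum attention over each vehicle's token span.

    attn: array of shape (n_heads, n_tokens) for one scenario.
    vehicle_spans: list of (start, end) token indices, end exclusive.
    Returns an array of shape (n_heads, n_vehicles).
    """
    attn = np.asarray(attn, dtype=float)
    per_vehicle = [attn[:, start:end].sum(axis=1) for start, end in vehicle_spans]
    return np.stack(per_vehicle, axis=1)
```

Stacking the per-scenario results with `np.stack` along a new first axis would then give the three-dimensional array printed above.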

In [None]:
# Visualize attention distribution per head
n_heads = attention_per_vehicle.shape[1]

fig, axes = plt.subplots(1, n_heads, figsize=(4*n_heads, 4))
if n_heads == 1:
    axes = [axes]

for h, ax in enumerate(axes):
    head_attn = attention_per_vehicle[:, h, :].flatten()
    ax.hist(head_attn, bins=50, alpha=0.7, edgecolor='black')
    ax.set_xlabel('Attention Weight')
    ax.set_ylabel('Count')
    ax.set_title(f'Head {h}')
    ax.axvline(head_attn.mean(), color='r', linestyle='--', label=f'Mean={head_attn.mean():.3f}')
    ax.legend(fontsize=8)

plt.suptitle('Attention Weight Distribution per Head', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Aggregate Semantic Features

In [None]:
# Aggregate semantic features
features = analyzer.aggregate_features()

if features:
    print("\nAggregated feature shapes:")
    for name, arr in features.items():
        print(f"  {name}: {arr.shape}")
else:
    print("⚠ No semantic features found in logs.")
    print("  HSI analysis requires training logs (from AttentionLogger).")
    print("  Make sure FEATURE_DIR points to runs/{model}/attention_logs/")

## 6. Compute Head Specialization Index (HSI)

In [None]:
# Compute HSI
try:
    result = analyzer.compute_hsi()
    
    print("\n" + "=" * 60)
    print("HSI RESULTS SUMMARY")
    print("=" * 60)
    for h in range(len(result.hsi_scores)):
        label = result.head_labels[h]
        print(f"\nHead {h}:")
        print(f"  Name: {label['name']}")
        print(f"  HSI Score: {label['hsi']:.3f}")
        print(f"  Primary Feature: {label.get('primary_feature', 'N/A')}")
        print(f"  Correlation: {label.get('correlation', 0):.3f}")
        print(f"  Description: {label['description']}")
    print("=" * 60)
    
except Exception as e:
    print(f"⚠ Could not compute HSI: {e}")
    print("  This typically means semantic features are not available.")
    result = None
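
The notebook does not spell out how the HSI is computed, so the following is only an illustrative sketch of the correlation idea stated in the goal: for each head, take the Pearson correlation between its attention and every feature, and use the largest absolute correlation as the score. This formula, the function name `sketch_hsi`, and the input layout are all assumptions; `compute_hsi()` may use a different definition.

```python
import numpy as np


def sketch_hsi(attention, features):
    """Correlation-based specialization score (illustrative only).

    attention: array of shape (n_samples, n_heads).
    features: dict mapping feature name to an array of shape (n_samples,).
    Returns (hsi, primary, corr), where corr has shape (n_heads, n_features).
    """
    attention = np.asarray(attention, dtype=float)
    names = list(features)
    n_heads = attention.shape[1]
    corr = np.zeros((n_heads, len(names)))
    for j, name in enumerate(names):
        x = np.asarray(features[name], dtype=float)
        for h in range(n_heads):
            a = attention[:, h]
            # Pearson correlation is undefined for constant inputs; keep 0.
            if a.std() == 0 or x.std() == 0:
                continue
            corr[h, j] = np.corrcoef(a, x)[0, 1]
    abs_corr = np.abs(corr)
    hsi = abs_corr.max(axis=1)
    # For a head with all-zero correlations this falls back to the first feature.
    primary = [names[j] for j in abs_corr.argmax(axis=1)]
    return hsi, primary, corr
```

Keeping the signed correlation matrix alongside the score matters for labeling, because the sign separates, for example, a head that attends to low-TTC vehicles from one that attends to high-TTC vehicles.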

## 7. Visualizations

In [None]:
# Create visualizer
if result is not None:
    viz = HeadVisualization(analyzer)
    print("✓ Visualizer ready")
else:
    viz = None
    print("⚠ Cannot create visualizations without HSI results")

### 7.1 Correlation Heatmap

In [None]:
if viz is not None:
    viz.plot_correlation_heatmap()

### 7.2 HSI Bar Chart

In [None]:
if viz is not None:
    viz.plot_hsi_bar()

### 7.3 Attention vs. Feature Scatter Plots

In [None]:
if viz is not None and result is not None:
    # Plot scatter for each specialized head
    for h, label in result.head_labels.items():
        if label.get('primary_feature') and label['hsi'] >= 0.3:
            print(f"\n--- Head {h}: {label['name']} ---")
            viz.plot_attention_vs_feature(h, label['primary_feature'])

## 8. Export Head Function Registry

In [None]:
import json  # os is already imported in the Setup cell

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

if result is not None:
    registry_path = os.path.join(OUTPUT_DIR, 'head_functions.json')
    analyzer.export_registry(registry_path)
    
    # Display the registry
    with open(registry_path, 'r') as f:
        registry = json.load(f)
    
    print("\nHEAD FUNCTION REGISTRY")
    print("=" * 50)
    print(json.dumps(registry, indent=2))
else:
    print("⚠ Cannot export registry without HSI results")

## 9. Save All Visualizations

In [None]:
if viz is not None:
    viz.plot_all(OUTPUT_DIR)
    print(f"\n✓ All visualizations saved to {OUTPUT_DIR}")
else:
    print("⚠ Cannot save visualizations without results")

## 10. Interpretation Guide

### What the Results Mean

| HSI Score | Interpretation |
|-----------|----------------|
| > 0.7 | **Strong specialization** - Head clearly focuses on specific feature |
| 0.5 - 0.7 | **Moderate specialization** - Head shows preference for certain features |
| 0.3 - 0.5 | **Weak specialization** - Slight tendency toward specific features |
| < 0.3 | **General context** - Diffuse attention, no clear specialization |
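
The bands above can be turned into a small lookup. This is a sketch with a hypothetical function name, `interpret_hsi`. The table leaves a score of exactly 0.5 ambiguous, so the sketch assigns it to the moderate band, matching the `>=` comparisons used in the summary cell; a score of exactly 0.7 stays in the moderate band because the table reserves the strong band for scores above 0.7.

```python
def interpret_hsi(score):
    """Map an HSI score to the interpretation bands from the table."""
    if score > 0.7:
        return "Strong specialization"
    if score >= 0.5:
        return "Moderate specialization"
    if score >= 0.3:
        return "Weak specialization"
    return "General context"
```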

### Expected Head Functions

| Function | Feature | Expected Correlation | Interpretation |
|----------|---------|---------------------|----------------|
| **Safety Head** | TTC | Negative | Attends to low-TTC (collision threats) |
| **Proximity Head** | Distance | Negative | Attends to nearby vehicles |
| **Threat Assessment** | Closing Speed | Positive | Attends to approaching vehicles |
| **Traffic Flow** | Is Ahead | Positive | Attends to leading vehicles |
| **Lane Awareness** | Is Left/Right | Positive | Monitors adjacent lanes |
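
The table can also be read as a rule: a head earns a function name only when its primary feature correlates in the expected direction. The sketch below encodes that rule. The feature keys (`ttc`, `closing_speed`, and so on) and the names `EXPECTED_FUNCTIONS` and `expected_label` are hypothetical; the module's own `FEATURE_TO_LABEL` mapping may use different keys and logic.

```python
# Feature key -> (function name, expected sign of the correlation).
# Keys are illustrative and may not match the logged feature names.
EXPECTED_FUNCTIONS = {
    "ttc": ("Safety Head", -1),
    "distance": ("Proximity Head", -1),
    "closing_speed": ("Threat Assessment", +1),
    "is_ahead": ("Traffic Flow", +1),
    "is_left": ("Lane Awareness", +1),
    "is_right": ("Lane Awareness", +1),
}


def expected_label(feature, correlation):
    """Return the function name if the correlation has the expected sign, else None."""
    entry = EXPECTED_FUNCTIONS.get(feature)
    if entry is None:
        return None
    name, sign = entry
    return name if correlation * sign > 0 else None
```

A head whose strongest correlation has the opposite sign (for instance, positive with TTC) gets no label under this rule, which is a useful prompt to inspect its scatter plot by hand.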

### Using the Registry for XAI

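```python
import json

# Load registry at inference time
with open('head_functions.json') as f:
    HEAD_REGISTRY = json.load(f)

# Interpret attention during inference.
# attention_weights: one scalar activation per head (assumed here).
# threshold: illustrative cutoff; tune it for your model.
def explain_action(attention_weights, threshold=0.5):
    for head_idx, weight in enumerate(attention_weights):
        # JSON object keys are strings, so look up by str(head_idx)
        info = HEAD_REGISTRY.get(str(head_idx))
        if info is not None and weight > threshold:
            print(f"{info['name']} active: {info['description']}")
```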

## 11. Summary

In [None]:
print("\n" + "=" * 60)
print("ANALYSIS SUMMARY")
print("=" * 60)
print(f"\nData Sources:")
print(f"  Feature Dir: {FEATURE_DIR}")
print(f"  Attention Dir: {ATTENTION_DIR}")
print(f"\nLogs analyzed: {len(analyzer.logs)}")
print(f"Scenarios processed: {attention_per_vehicle.shape[0]}")
print(f"Attention heads: {attention_per_vehicle.shape[1]}")

if result is not None:
    # Thresholds follow the bands in the Interpretation Guide
    specialized = sum(1 for score in result.hsi_scores if score >= 0.3)
    moderate_or_stronger = sum(1 for score in result.hsi_scores if score >= 0.5)
    
    print(f"\nSpecialized heads (HSI ≥ 0.3): {specialized}")
    print(f"Moderately or strongly specialized (HSI ≥ 0.5): {moderate_or_stronger}")
    
    print(f"\nOutput files:")
    print(f"  Registry: {os.path.join(OUTPUT_DIR, 'head_functions.json')}")
    print(f"  Heatmap: {os.path.join(OUTPUT_DIR, 'correlation_heatmap.png')}")
    print(f"  HSI Chart: {os.path.join(OUTPUT_DIR, 'hsi_scores.png')}")

print("\n" + "=" * 60)
print("✓ Head Specialization Analysis Complete")
print("=" * 60)