# Mechanistic Stability Analysis

This notebook tests whether the model's internal "explanations" (which components are important) remain stable when inputs are slightly perturbed but predictions stay the same.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from pathlib import Path

from Utilities.TST_trainer import TimeSeriesTransformer, load_dataset
from Utilities.utils import sweep_heads, get_probs
from Utilities.perturbations import (
    apply_perturbation, validate_perturbation, get_perturbation_configs
)
from Utilities.stability_metrics import (
    get_head_importance, compute_all_metrics, 
    plot_importance_comparison, create_summary_table
)

%matplotlib inline

## 1. Configuration

In [None]:
# Which dataset to analyze, and where its trained checkpoint lives
DATASET_NAME = 'JapaneseVowels'
MODEL_PATH = '../TST_models/' + f'TST_{DATASET_NAME.lower()}.pth'

# Every artifact written by this notebook goes under a per-dataset folder
RESULTS_DIR = Path('..') / 'Results' / 'Stability' / DATASET_NAME
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Seed the torch and NumPy global RNGs with the same value for reproducibility
SEED = 42
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(SEED)

print(f'Dataset: {DATASET_NAME}')
print(f'Results will be saved to: {RESULTS_DIR}')

## 2. Load Data and Model

In [None]:
# Build the train/test loaders for the configured dataset
train_loader, test_loader = load_dataset(DATASET_NAME, batch_size=32)

# Pull the full tensors out of the wrapped datasets: index 0 is used as the
# inputs and index 1 as the labels
test_tensors = test_loader.dataset.tensors
X_test, y_test = test_tensors[0], test_tensors[1]
train_labels = train_loader.dataset.tensors[1]
test_labels = y_test

# The class count is inferred from the largest label found in either split
all_labels = torch.cat([train_labels, test_labels])
num_classes = int(all_labels.max().item()) + 1

# assumes X_test is laid out as (samples, seq_len, channels) — TODO confirm
seq_len, channels = X_test.shape[1:3]

print(f'Test set: {X_test.shape[0]} samples')
print(f'Sequence length: {seq_len}, Channels: {channels}')
print(f'Number of classes: {num_classes}')

In [None]:
# Instantiate the architecture with dimensions taken from the data
model = TimeSeriesTransformer(
    input_dim=channels,
    num_classes=num_classes,
    seq_len=seq_len,
)

# Restore the trained weights on CPU; weights_only=True restricts unpickling
# to tensor data rather than arbitrary objects
state_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Read the layer and head counts straight from the encoder stack
encoder_layers = model.transformer_encoder.layers
num_layers = len(encoder_layers)
num_heads = encoder_layers[0].self_attn.num_heads

print(f'Model loaded: {num_layers} layers, {num_heads} heads')

## 3. Validate Perturbations

Check which perturbation settings preserve accuracy (within 5% drop).

In [None]:
# Candidate settings per perturbation method; `valid_configs` keeps only the
# settings that validate_perturbation marks as valid.
configs = get_perturbation_configs()
valid_configs = {}

print('Validating perturbations (must preserve accuracy within 5%):\n')

for method in configs:
    print(f'{method.upper()}:')
    # Every method gets an entry, even if none of its settings pass
    accepted = valid_configs[method] = []

    for params in configs[method]:
        perturbed = apply_perturbation(X_test, method, seed=SEED, **params)
        check = validate_perturbation(model, X_test, perturbed, y_test)
        passed = check['valid']

        label = ', '.join(f'{k}={v}' for k, v in params.items())
        verdict = 'PASS' if passed else 'FAIL'

        print(f"  {label:20s} | Acc: {check['original_acc']:.3f} -> {check['perturbed_acc']:.3f} | {verdict}")

        if passed:
            accepted.append(params)
    print()

## 4. Select Sample Pairs for Analysis

Pick pairs of (clean, corrupt) samples where the model makes different predictions.

In [None]:
# Predicted class for every test sample (inference only, so no gradients)
with torch.no_grad():
    preds = model(X_test).argmax(dim=1)

# Split test indices by whether the prediction matches the label
correct_mask = (preds == y_test)
correct_idx = torch.where(correct_mask)[0].numpy()
incorrect_idx = torch.where(~correct_mask)[0].numpy()

print(f'Correctly classified: {len(correct_idx)}')
print(f'Incorrectly classified: {len(incorrect_idx)}')

# "Clean" samples: a random subset of the correctly classified samples.
# The seed is reset here so re-running this cell alone yields the same pairs.
NUM_PAIRS = min(20, len(correct_idx))
np.random.seed(SEED)
clean_indices = np.random.choice(correct_idx, NUM_PAIRS, replace=False)

# "Corrupt" samples: for each clean sample, draw a correctly classified sample
# whose predicted class differs, so every pair really has different predictions.
# The original code only required the two indices to differ, and could loop
# forever when a single sample was correctly classified.
# Note: a corrupt sample may now be reused across pairs.
pred_classes = preds.numpy()
corrupt_indices = np.empty(NUM_PAIRS, dtype=correct_idx.dtype)
same_class_pairs = 0

for i, clean_idx in enumerate(clean_indices):
    candidates = correct_idx[pred_classes[correct_idx] != pred_classes[clean_idx]]
    if len(candidates) == 0:
        # No correctly classified sample of another class exists: fall back to
        # any other correctly classified sample (the original behaviour).
        candidates = correct_idx[correct_idx != clean_idx]
        same_class_pairs += 1
    if len(candidates) == 0:
        raise ValueError(
            'Need at least two correctly classified samples to form a (clean, corrupt) pair'
        )
    corrupt_indices[i] = np.random.choice(candidates)

if same_class_pairs:
    print(f'Warning: {same_class_pairs} pair(s) share the same predicted class')

print(f'\nSelected {NUM_PAIRS} pairs for analysis')

## 5. Run Stability Analysis

For each perturbation config:
1. Perturb the "corrupt" input
2. Run patching analysis on (clean, corrupt) and (clean, perturbed)
3. Compare the head importance rankings

In [None]:
# Maps config name -> metrics averaged over all sample pairs
all_results = {}

# Baseline (unperturbed) head importance depends only on the (clean, corrupt)
# pair, not on the perturbation config. Compute it lazily once per pair and
# reuse it, instead of re-running the head sweep for every config.
baseline_importance = {}

for method, param_list in valid_configs.items():
    for params in param_list:
        config_name = f"{method}_" + "_".join(f"{v}" for v in params.values())
        print(f'\nRunning: {config_name}')

        # Store metrics for each pair
        pair_metrics = []

        for i in range(NUM_PAIRS):
            clean_idx = clean_indices[i]
            corrupt_idx = corrupt_indices[i]

            # Slice (rather than index) so the leading batch dimension is kept
            clean = X_test[clean_idx:clean_idx+1]
            corrupt = X_test[corrupt_idx:corrupt_idx+1]
            true_label = y_test[corrupt_idx].item()

            # Perturb the corrupt input; the seed varies per pair but is the
            # same across configs, so runs are repeatable
            corrupt_pert = apply_perturbation(corrupt, method, seed=SEED+i, **params)

            # Baseline patching: clean -> corrupt (cached per pair)
            if i not in baseline_importance:
                baseline_probs = sweep_heads(model, clean, corrupt, num_classes)
                baseline_raw = get_probs(model, corrupt)
                baseline_importance[i] = get_head_importance(baseline_probs, baseline_raw, true_label)
            baseline_imp = baseline_importance[i]

            # Perturbed patching: clean -> perturbed_corrupt
            perturbed_probs = sweep_heads(model, clean, corrupt_pert, num_classes)
            perturbed_raw = get_probs(model, corrupt_pert)
            perturbed_imp = get_head_importance(perturbed_probs, perturbed_raw, true_label)

            # Compute metrics
            metrics = compute_all_metrics(baseline_imp, perturbed_imp)
            pair_metrics.append(metrics)

        # With no pairs there is nothing to average (and pair_metrics[0] would fail)
        if not pair_metrics:
            print('  No sample pairs available; skipping')
            continue

        # Average metrics across pairs; values that are dicts are averaged
        # one sub-key at a time
        avg_metrics = {}
        for key, first_value in pair_metrics[0].items():
            if isinstance(first_value, dict):
                avg_metrics[key] = {
                    subkey: np.mean([m[key][subkey] for m in pair_metrics])
                    for subkey in first_value
                }
            else:
                avg_metrics[key] = np.mean([m[key] for m in pair_metrics])

        all_results[config_name] = avg_metrics
        print(f"  Rank corr: {avg_metrics['rank_correlation']:.3f}, Top-5: {avg_metrics['topk_overlap_k5']:.3f}")

## 6. Results Summary

In [None]:
# Render the per-config results as a text table and show it
summary_table = create_summary_table(all_results)
print(summary_table)

# Persist the table under a small markdown heading
summary_path = RESULTS_DIR / 'summary_table.md'
with summary_path.open('w') as handle:
    handle.write(f'# Stability Results: {DATASET_NAME}\n\n' + summary_table)

In [None]:
# Flatten the headline metrics into one row per perturbation config
metric_columns = [
    'rank_correlation',
    'topk_overlap_k3',
    'topk_overlap_k5',
    'topk_overlap_k10',
    'stability_score',
]
rows = [
    {'config': config_name, **{column: metrics[column] for column in metric_columns}}
    for config_name, metrics in all_results.items()
]

df = pd.DataFrame(rows)

# Save alongside the other results and report the location
csv_path = RESULTS_DIR / 'stability_metrics.csv'
df.to_csv(csv_path, index=False)
print(f'Saved to {csv_path}')
df

## 7. Visualizations

In [None]:
# Bar chart of the average rank correlation for each perturbation config
fig, ax = plt.subplots(figsize=(10, 5))
# Note: this rebinds `configs` to the list of result names
configs = list(all_results.keys())
correlations = [all_results[c]['rank_correlation'] for c in configs]

# Colour by config name: blue if it contains 'gaussian', green if it contains
# 'time', red otherwise
colors = ['blue' if 'gaussian' in c else 'green' if 'time' in c else 'red' for c in configs]
ax.bar(configs, correlations, color=colors)
ax.axhline(y=0.8, color='gray', linestyle='--', label='High stability threshold')
ax.set_ylabel('Rank Correlation')
ax.set_xlabel('Perturbation Config')
ax.set_title(f'Mechanism Stability: {DATASET_NAME}')
# Without a legend the threshold line's label is never rendered
ax.legend()
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'rank_correlation_bar.png', dpi=150)
plt.show()

In [None]:
# Grouped bar chart: top-K overlap per perturbation config for K = 3, 5, 10
fig, ax = plt.subplots(figsize=(10, 5))

# Derive the config names from the results themselves so this cell does not
# depend on a `configs` variable rebound by another cell.
config_names = list(all_results)
x = np.arange(len(config_names))
width = 0.25

# One bar per K inside each group, shifted right by one bar width each time
for i, k in enumerate([3, 5, 10]):
    values = [all_results[name][f'topk_overlap_k{k}'] for name in config_names]
    ax.bar(x + i*width, values, width, label=f'K={k}')

ax.set_ylabel('Jaccard Overlap')
ax.set_xlabel('Perturbation Config')
ax.set_title(f'Top-K Head Overlap: {DATASET_NAME}')
# Place each tick under the middle (second) bar of its three-bar group
ax.set_xticks(x + width)
ax.set_xticklabels(config_names, rotation=45, ha='right')
ax.legend()
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'topk_overlap_bar.png', dpi=150)
plt.show()

## 8. Case Study: Example Comparison

Visualize how head importance changes for a specific example.

In [None]:
# Pick the first perturbation setting that passed validation for the case study.
# `valid_configs` holds an empty list for any method whose settings all failed,
# so skip those instead of indexing [0] on the first method unconditionally.
case_method, case_params = next(
    (
        (method_name, passing_params[0])
        for method_name, passing_params in valid_configs.items()
        if passing_params
    ),
    (None, None),
)
if case_method is None:
    raise RuntimeError('No perturbation config passed validation; cannot build a case study')
case_config = f"{case_method}_" + "_".join(f"{v}" for v in case_params.values())

# Use first pair
clean_idx = clean_indices[0]
corrupt_idx = corrupt_indices[0]

# Slice (rather than index) so the leading batch dimension is kept
clean = X_test[clean_idx:clean_idx+1]
corrupt = X_test[corrupt_idx:corrupt_idx+1]
true_label = y_test[corrupt_idx].item()

# Perturb the corrupt input with the chosen setting
corrupt_pert = apply_perturbation(corrupt, case_method, seed=SEED, **case_params)

# Head importance without perturbation (clean -> corrupt)
baseline_probs = sweep_heads(model, clean, corrupt, num_classes)
baseline_raw = get_probs(model, corrupt)
baseline_imp = get_head_importance(baseline_probs, baseline_raw, true_label)

# Head importance with perturbation (clean -> perturbed corrupt)
perturbed_probs = sweep_heads(model, clean, corrupt_pert, num_classes)
perturbed_raw = get_probs(model, corrupt_pert)
perturbed_imp = get_head_importance(perturbed_probs, perturbed_raw, true_label)

# Plot comparison and save it next to the other results
fig = plot_importance_comparison(
    baseline_imp, perturbed_imp,
    title=f'Case Study: {case_config}',
    save_path=str(RESULTS_DIR / 'case_study_comparison.png')
)
plt.show()

## 9. Findings Summary

In [None]:
# Mean of each headline metric over every perturbation config in `df`
avg_rank_corr, avg_topk5, avg_stability = (
    df[column].mean()
    for column in ('rank_correlation', 'topk_overlap_k5', 'stability_score')
)

# Rank-correlation verdict: the first threshold exceeded wins, else "unstable"
rank_verdict = next(
    (
        text
        for threshold, text in (
            (0.7, "- Mechanistic explanations appear **relatively stable** under tested perturbations\n"),
            (0.4, "- Mechanistic explanations show **moderate stability** - some heads consistently important\n"),
        )
        if avg_rank_corr > threshold
    ),
    "- Mechanistic explanations appear **unstable** - different heads become important under perturbation\n",
)

# Top-5 overlap verdict: a single cut-off at 0.5
topk_verdict = (
    "- Top-5 most important heads show **good overlap** across conditions\n"
    if avg_topk5 > 0.5
    else "- Top-5 most important heads show **limited overlap** - explanations are context-dependent\n"
)

# Assemble the markdown report: header, two verdict bullets, then the footer
report_header = (
    f"# Stability Analysis Findings: {DATASET_NAME}\n"
    "\n"
    "## Summary Statistics\n"
    f"- Average rank correlation: {avg_rank_corr:.3f}\n"
    f"- Average top-5 overlap: {avg_topk5:.3f}\n"
    f"- Average stability score: {avg_stability:.3f}\n"
    "\n"
    "## Interpretation\n"
    "\n"
)
report_footer = (
    "\n## Perturbation-Specific Results\n"
    "\n"
    f"{summary_table}\n"
    "\n"
    "## Notes\n"
    f"- Analysis based on {NUM_PAIRS} sample pairs\n"
    f"- Random seed: {SEED}\n"
)
findings = report_header + rank_verdict + topk_verdict + report_footer

print(findings)

# Write the same text to disk next to the other results
findings_path = RESULTS_DIR / 'findings.md'
with findings_path.open('w') as handle:
    handle.write(findings)

print(f'\nSaved to {findings_path}')

---
## Done!

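Results are saved to `../Results/Stability/<DATASET_NAME>/` (relative to this notebook), where `<DATASET_NAME>` is the dataset chosen in the Configuration section.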