# Statistical Reanalysis of Patching Results

This notebook applies rigorous statistical analysis to all existing patching results, addressing the methodological gaps identified in the assessment:

1. **Confidence Intervals**: BCa bootstrap CIs for all claims
2. **Effect Sizes**: Cohen's d with interpretation
3. **FDR Correction**: Benjamini-Hochberg for multiple comparisons
4. **Random Baseline**: Null distribution for statistical significance (helpers imported below but not yet applied in this notebook — see Limitations)
5. **Proper Framing**: Case studies, not generalizable correlations

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# Our statistical infrastructure
from Utilities.statistics import (
    compute_confidence_interval,
    compute_effect_size,
    apply_fdr_correction,
    compute_correlation_ci,
    significance_test,
    power_analysis,
    compute_aggregate_statistics,
    format_result_table
)

from Utilities.baselines import (
    random_patching_baseline,
    compute_empirical_p_value
)

# Seed NumPy's global RNG for repeatability, then set the plotting theme.
np.random.seed(42)
plt.style.use('seaborn-v0_8-whitegrid')

## 1. Load All Results

In [None]:
# Load the aggregate stability results table and show it
stability_csv_path = '../Results/Stability/Summary/aggregate_results.csv'
stability_df = pd.read_csv(stability_csv_path)
print("Stability Results:")
display(stability_df)

In [None]:
# Function to load all raw results
def load_all_raw_results(base_path: Path, datasets=('JapaneseVowels', 'PenDigits', 'LSST')) -> dict:
    """Load all raw patching NPZ files, organised by dataset and patching mode.

    Scans ``base_path/<dataset>/<mode>/class_<id>/pair_*/raw_results.npz``
    for ``mode`` in ``('denoise', 'noise')``. Missing mode directories,
    non-directory entries, and pair directories without a
    ``raw_results.npz`` are skipped.

    Args:
        base_path: Root results directory (``Path`` or ``str``).
        datasets: Dataset sub-directories to scan.

    Returns:
        ``{dataset: {'denoise': [record, ...], 'noise': [record, ...]}}``.
        Each record has ``class_id`` (int parsed from the ``class_<id>``
        directory name), ``pair_name``, ``baseline``, and ``head_patch`` /
        ``layer_patch`` (``None`` when the array is absent from the archive).
    """
    results = {}

    for dataset in datasets:
        results[dataset] = {'denoise': [], 'noise': []}
        dataset_path = Path(base_path) / dataset

        for mode in ('denoise', 'noise'):
            mode_path = dataset_path / mode
            if not mode_path.exists():
                continue

            # iterdir() order is filesystem-dependent; sort so the record
            # order (and anything derived from it) is reproducible.
            for class_dir in sorted(mode_path.iterdir()):
                if not class_dir.is_dir() or not class_dir.name.startswith('class_'):
                    continue

                class_id = int(class_dir.name.split('_')[1])

                for pair_dir in sorted(class_dir.iterdir()):
                    if not pair_dir.is_dir() or not pair_dir.name.startswith('pair_'):
                        continue

                    npz_path = pair_dir / 'raw_results.npz'
                    if not npz_path.exists():
                        continue

                    # np.load on an .npz returns an NpzFile holding an open
                    # file handle; the context manager closes it once the
                    # arrays below have been read into memory.
                    with np.load(npz_path) as data:
                        results[dataset][mode].append({
                            'class_id': class_id,
                            'pair_name': pair_dir.name,
                            'baseline': data['baseline'],
                            'head_patch': data['head_patch'] if 'head_patch' in data else None,
                            'layer_patch': data['layer_patch'] if 'layer_patch' in data else None,
                        })

    return results

base_path = Path('../Results')
all_results = load_all_raw_results(base_path)

# Report how many pairs were loaded for each dataset and mode
for dataset_name, pairs_by_mode in all_results.items():
    print(
        f"{dataset_name}: {len(pairs_by_mode['denoise'])} denoise pairs, "
        f"{len(pairs_by_mode['noise'])} noise pairs"
    )

## 2. Compute Delta P with Confidence Intervals

For each dataset, compute mean ΔP across all pairs with BCa bootstrap 95% CIs.

In [None]:
def compute_delta_p_with_ci(results: dict, mode: str = 'denoise') -> pd.DataFrame:
    """Compute mean ΔP with 95% bootstrap (BCa) CIs for each dataset.

    ΔP of a head in one pair is ``head_patch[layer, head, true_label] -
    baseline[true_label]``, where ``true_label`` is the pair's ``class_id``.

    CIs resample *pairs*, not individual (pair, head) values. Every head in
    a pair shares the same input and baseline, so pooling them as
    independent draws would make the interval too narrow. The mean CI is
    therefore computed over per-pair mean ΔP. When all pairs have the same
    number of heads, its point estimate equals the pooled mean over all heads.

    Args:
        results: ``{dataset: {mode: [pair_record, ...]}}`` as built by
            ``load_all_raw_results``.
        mode: Which patching mode to summarise.

    Returns:
        One row per dataset that has head-patching data, with columns:
        Dataset; n_pairs (pairs actually used); n_observations (total
        pair×head ΔP values); mean_delta_p / mean_ci_lower / mean_ci_upper;
        and max_delta_p_mean / max_ci_lower / max_ci_upper (the CI of the
        per-pair maximum ΔP over heads). The columns are present even when
        no dataset qualifies.
    """
    columns = [
        'Dataset', 'n_pairs', 'n_observations',
        'mean_delta_p', 'mean_ci_lower', 'mean_ci_upper',
        'max_delta_p_mean', 'max_ci_lower', 'max_ci_upper',
    ]
    summary = []

    for dataset, modes in results.items():
        per_pair_mean = []
        per_pair_max = []
        n_observations = 0

        for pair in modes.get(mode, []):
            if pair['head_patch'] is None:
                continue

            true_label = pair['class_id']
            head_delta = pair['head_patch'][:, :, true_label] - pair['baseline'][true_label]
            per_pair_mean.append(head_delta.mean())
            per_pair_max.append(head_delta.max())
            n_observations += head_delta.size

        if not per_pair_mean:
            continue

        # Bootstrap unit = pair (see docstring).
        mean_ci = compute_confidence_interval(np.asarray(per_pair_mean), confidence=0.95, method='bca')
        max_ci = compute_confidence_interval(np.asarray(per_pair_max), confidence=0.95, method='bca')

        summary.append({
            'Dataset': dataset,
            # Count only pairs that contributed data (pairs without
            # head_patch were skipped above).
            'n_pairs': len(per_pair_mean),
            'n_observations': n_observations,
            'mean_delta_p': mean_ci.mean,
            'mean_ci_lower': mean_ci.lower,
            'mean_ci_upper': mean_ci.upper,
            'max_delta_p_mean': max_ci.mean,
            'max_ci_lower': max_ci.lower,
            'max_ci_upper': max_ci.upper,
        })

    return pd.DataFrame(summary, columns=columns)

# Summarise the denoising runs: mean ΔP per dataset with bootstrap CIs
delta_p_summary = compute_delta_p_with_ci(all_results, mode='denoise')
print("\n=== Delta P Results with 95% CIs (Denoise Mode) ===")
display(delta_p_summary)

In [None]:
# Visualization with error bars
fig, ax = plt.subplots(figsize=(10, 6))

datasets = delta_p_summary['Dataset'].values
means = delta_p_summary['mean_delta_p'].values
ci_lower = delta_p_summary['mean_ci_lower'].values
ci_upper = delta_p_summary['mean_ci_upper'].values
# Error-bar lengths must be non-negative (recent matplotlib raises otherwise),
# but a BCa interval is not guaranteed to contain the sample mean: clip at 0.
lower_err = np.clip(means - ci_lower, 0, None)
upper_err = np.clip(ci_upper - means, 0, None)

x = np.arange(len(datasets))
ax.bar(x, means, yerr=[lower_err, upper_err], capsize=5, color=['#2ecc71', '#3498db', '#e74c3c'])
ax.set_xticks(x)
ax.set_xticklabels(datasets)
ax.set_ylabel('Mean ΔP')
ax.set_title('Mean Patching Effect (ΔP) with 95% Bootstrap CIs')
ax.axhline(0, color='black', linestyle='--', alpha=0.3)

# Annotate each bar with its mean and CI, just above the upper error bar
for i, (m, lo, hi) in enumerate(zip(means, ci_lower, ci_upper)):
    ax.annotate(f'{m:.3f}\n[{lo:.3f}, {hi:.3f}]',
                xy=(i, m + upper_err[i] + 0.02),
                ha='center', fontsize=9)

plt.tight_layout()
# The Summary directory is otherwise only created in the "Save Results"
# section, which runs later; create it here so savefig works on a fresh run.
figure_path = Path('../Results/Summary/delta_p_with_ci.png')
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
plt.show()

## 3. Head Importance with FDR Correction

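For each head, test whether its mean ΔP across pairs is greater than zero (one-sided one-sample t-test), then apply Benjamini–Hochberg FDR correction across all heads of a dataset.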

In [None]:
def compute_head_importance_with_fdr(results: dict, mode: str = 'denoise') -> dict:
    """Test every attention head for a positive patching effect, with FDR control.

    For each dataset, ΔP of head (layer, head) in one pair is
    ``head_patch[layer, head, true_label] - baseline[true_label]``. Each
    head's ΔP values across pairs are tested against 0 with a one-sided
    one-sample t-test (H1: mean ΔP > 0). The p-values are then corrected
    across all heads with Benjamini-Hochberg.

    Args:
        results: ``{dataset: {mode: [pair_record, ...]}}`` as built by
            ``load_all_raw_results``.
        mode: Which patching mode to analyse.

    Returns:
        ``{dataset: {'head_stats': DataFrame, 'n_significant': ..., 'n_tests': ...}}``.
        ``head_stats`` has one row per head (layer, head, mean_delta_p,
        ci_lower, ci_upper, n_pairs, p_value, p_value_corrected,
        significant_fdr), sorted by mean_delta_p descending. Datasets with
        no head-patching data are omitted.
    """
    fdr_results = {}

    for dataset, modes in results.items():
        # (layer, head) -> ΔP of that head in every pair with head-patching data
        head_deltas = {}

        for pair in modes.get(mode, []):
            if pair['head_patch'] is None:
                continue

            true_label = pair['class_id']
            head_delta = pair['head_patch'][:, :, true_label] - pair['baseline'][true_label]
            for (l, h), value in np.ndenumerate(head_delta):
                head_deltas.setdefault((l, h), []).append(value)

        if not head_deltas:
            # Nothing to test; an empty frame cannot be sorted by mean_delta_p.
            continue

        p_values = []
        head_stats = []

        for (l, h), deltas in head_deltas.items():
            deltas = np.asarray(deltas)

            if len(deltas) > 1:
                p_val = stats.ttest_1samp(deltas, 0, alternative='greater').pvalue
                # Zero-variance input (e.g. identical ΔP in every pair) may
                # yield NaN; treat it as "no evidence" so the FDR step always
                # receives valid p-values.
                if not np.isfinite(p_val):
                    p_val = 1.0
                ci = compute_confidence_interval(deltas, confidence=0.95)
                ci_lower, ci_upper = ci.lower, ci.upper
            else:
                # A single pair supports neither a t-test nor a CI.
                p_val = 1.0
                ci_lower = ci_upper = np.nan

            p_values.append(p_val)
            head_stats.append({
                'layer': l,
                'head': h,
                'mean_delta_p': np.mean(deltas),
                'ci_lower': ci_lower,
                'ci_upper': ci_upper,
                'n_pairs': len(deltas),
                'p_value': p_val,
            })

        # Benjamini-Hochberg across all heads of this dataset
        fdr_result = apply_fdr_correction(p_values, method='benjamini_hochberg')

        for i, stat in enumerate(head_stats):
            stat['p_value_corrected'] = fdr_result['p_corrected'][i]
            stat['significant_fdr'] = fdr_result['significant'][i]

        fdr_results[dataset] = {
            'head_stats': pd.DataFrame(head_stats).sort_values('mean_delta_p', ascending=False),
            'n_significant': fdr_result['n_significant'],
            'n_tests': fdr_result['n_tests'],
        }

    return fdr_results

fdr_results = compute_head_importance_with_fdr(all_results, mode='denoise')

# Per dataset: how many heads survive FDR, then the ten largest mean effects
for dataset_name, fdr in fdr_results.items():
    n_sig, n_tests, head_table = fdr['n_significant'], fdr['n_tests'], fdr['head_stats']
    print(f"\n=== {dataset_name}: Head Importance with FDR Correction ===")
    print(f"Significant heads after FDR: {n_sig}/{n_tests}")
    display(head_table.head(10))

## 4. Stability Analysis with Proper Statistics

Reanalyze stability results with:
- CIs for all metrics
- Honest framing as case studies (n=3)

In [None]:
# Compute per-dataset stability summaries with CIs
def compute_stability_with_ci(stability_df: pd.DataFrame) -> pd.DataFrame:
    """Compute stability metrics with 95% CIs per dataset.

    Args:
        stability_df: One row per perturbation run, with columns
            ``dataset``, ``rank_corr``, ``top5`` and ``stability``.

    Returns:
        One row per dataset (in order of first appearance) with columns
        Dataset and n_perturbations. For each metric it also has
        ``<metric>_mean`` and ``<metric>_ci``. The ``_ci`` value is the
        interval formatted as ``"[lower, upper]"``, or ``"N/A"`` when it
        cannot be computed (fewer than two runs, or a missing lower bound
        from the CI helper).
    """
    from types import SimpleNamespace

    metrics = ('rank_corr', 'top5', 'stability')

    def _interval(values):
        # A CI needs at least two observations; otherwise report the single
        # value (if any) as the mean, with undefined bounds.
        if len(values) >= 2:
            return compute_confidence_interval(values, confidence=0.95)
        point = values[0] if len(values) > 0 else np.nan
        return SimpleNamespace(mean=point, lower=np.nan, upper=np.nan)

    def _format(ci) -> str:
        if pd.isna(ci.lower):
            return "N/A"
        return f"[{ci.lower:.3f}, {ci.upper:.3f}]"

    summaries = []
    for dataset in stability_df['dataset'].unique():
        subset = stability_df[stability_df['dataset'] == dataset]
        row = {'Dataset': dataset, 'n_perturbations': len(subset)}
        for metric in metrics:
            ci = _interval(subset[metric].values)
            row[f'{metric}_mean'] = ci.mean
            row[f'{metric}_ci'] = _format(ci)
        summaries.append(row)

    return pd.DataFrame(summaries)

# Per-dataset stability metrics with 95% CIs
stability_summary = compute_stability_with_ci(stability_df=stability_df)
print("\n=== Stability Summary with 95% CIs ===")
display(stability_summary)

## 5. The n=3 Problem: Honest Assessment

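With only 3 datasets, any cross-dataset correlation claim is statistically meaningless. Spearman's ρ can take only the values ±0.5 and ±1. Under the null, all 3! = 6 rank orderings are equally likely, and 2 of them give |ρ| = 1, so a "perfect" correlation arises by chance one time in three. This section documents the limitation and reframes findings as case studies.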

In [None]:
# Demonstrate the n=3 problem
from itertools import permutations

print("=== THE n=3 PROBLEM ===")
print()

# Our observed values
# NOTE(review): hard-coded summary values; keep in sync with the tables above.
delta_p_values = [0.59, 0.15, 0.01]  # JapaneseVowels, PenDigits, LSST
stability_values = [0.89, 0.87, 0.52]  # Mean stability
n_datasets = len(delta_p_values)

observed_rho, observed_p = stats.spearmanr(delta_p_values, stability_values)
# scipy's p-value relies on a large-sample t approximation, which is not
# meaningful for n=3; the exact permutation p-value is reported below.
print(f"Observed Spearman ρ = {observed_rho:.3f}, p = {observed_p:.4f} "
      f"(asymptotic; unreliable at n={n_datasets})")

# Exact null distribution: with no association every ordering of the ranks
# is equally likely, so enumerate all n! orderings instead of simulating.
null_rhos = np.array([
    stats.spearmanr(np.arange(n_datasets), np.array(perm))[0]
    for perm in permutations(range(n_datasets))
])
# Small tolerance: a perfect ranking can come back as 0.9999999999999998.
exact_p = np.mean(np.abs(null_rhos) >= abs(observed_rho) - 1e-9)
print(f"Exact permutation p-value (two-sided) = {exact_p:.3f}")
possible_rhos = sorted({float(v) for v in np.round(null_rhos, 3)})
print(f"Possible Spearman ρ values with n={n_datasets}: {possible_rhos}")
print()

# How often does |ρ| > 0.8 occur with no real association? (exact, not simulated)
false_positive_rate = np.mean(np.abs(null_rhos) > 0.8)
print(f"False positive rate (|ρ| > 0.8 by chance with n={n_datasets}): {false_positive_rate:.1%}")
print()
print(f"CONCLUSION: With n={n_datasets}, correlation claims are statistically meaningless.")
print(f"A correlation as strong as the observed ρ = {observed_rho:.2f} arises by chance "
      f"{exact_p:.0%} of the time.")
print()
print("REFRAMING: We present three case studies that suggest a possible pattern:")
print("  - Datasets with stronger patching effects show higher stability")
print("  - This preliminary observation warrants investigation with more datasets")

In [None]:
# Power analysis: how many datasets would we need?
print("\n=== POWER ANALYSIS ===\n")

# Sample sizes needed for a range of target correlations follow below.
from scipy.stats import pearsonr

def required_n_for_correlation(r: float, alpha: float = 0.05, power: float = 0.80) -> int:
    """Return the sample size needed to detect a correlation of ``r``.

    Uses the Fisher z-transformation approximation for a two-sided test of
    H0: correlation = 0 at level ``alpha`` with the requested ``power``::

        n = ((z_{1-alpha/2} + z_{power}) / atanh(r)) ** 2 + 3

    The approximation is derived for Pearson's r. Spearman's ρ generally
    needs a somewhat larger sample, so treat the result as a lower bound
    when ρ is the planned statistic.

    Raises:
        ValueError: if ``r`` is 0 or ``|r| >= 1`` (the transform is zero or
            undefined there), or if ``alpha``/``power`` lie outside (0, 1).
    """
    from scipy.stats import norm

    if not 0 < abs(r) < 1:
        raise ValueError(f"r must satisfy 0 < |r| < 1, got {r}")
    if not (0 < alpha < 1 and 0 < power < 1):
        raise ValueError(f"alpha and power must lie in (0, 1), got alpha={alpha}, power={power}")

    # Fisher's z transformation (equivalent to atanh(r))
    z = 0.5 * np.log((1 + r) / (1 - r))
    z_alpha = norm.ppf(1 - alpha / 2)
    z_beta = norm.ppf(power)

    n = ((z_alpha + z_beta) / z) ** 2 + 3
    return int(np.ceil(n))

# Sample sizes for a range of target correlations.
# NOTE(review): the Fisher-z formula targets Pearson's r; Spearman's ρ
# generally needs a somewhat larger n.
print("Required n for 80% power to detect correlation:")
for target_r in (0.9, 0.8, 0.7, 0.6, 0.5):
    print(f"  ρ = {target_r}: n = {required_n_for_correlation(target_r)} datasets")

print()
print(f"Current n = 3. For ρ ≈ 0.89, we need n ≈ {required_n_for_correlation(0.89)} datasets.")

## 6. Generate Publication-Ready Tables

In [None]:
# Table 1: Main Results with 95% CIs
print("\n" + "="*60)
print("TABLE 1: Main Results with 95% Bootstrap CIs")
print("="*60)
print()

table1 = []
for _, row in delta_p_summary.iterrows():
    dataset = row['Dataset']
    stab_rows = stability_summary[stability_summary['Dataset'] == dataset]

    if stab_rows.empty:
        # No stability run for this dataset: mark it instead of failing with
        # an IndexError on .iloc[0].
        stability_cell = top5_cell = "N/A"
    else:
        stab_row = stab_rows.iloc[0]
        stability_cell = f"{stab_row['rank_corr_mean']:.3f} {stab_row['rank_corr_ci']}"
        top5_cell = f"{stab_row['top5_mean']:.3f} {stab_row['top5_ci']}"

    table1.append({
        'Dataset': dataset,
        'n pairs': row['n_pairs'],
        'Mean ΔP': f"{row['mean_delta_p']:.3f} [{row['mean_ci_lower']:.3f}, {row['mean_ci_upper']:.3f}]",
        'Stability (ρ)': stability_cell,
        'Top-5 Overlap': top5_cell,
    })

table1_df = pd.DataFrame(table1)
print(table1_df.to_markdown(index=False))

In [None]:
# Table 2: Head Importance with FDR (Top 10 per dataset)
print("\n" + "="*60)
print("TABLE 2: Top-10 Most Influential Heads (FDR-corrected p-values)")
print("="*60)

# Source column -> display header, in table order
table2_headers = {
    'layer': 'Layer',
    'head': 'Head',
    'mean_delta_p': 'Mean ΔP',
    'ci_lower': 'CI Lower',
    'ci_upper': 'CI Upper',
    'p_value_corrected': 'p (FDR)',
    'significant_fdr': 'Sig.',
}

for dataset, result in fdr_results.items():
    print(f"\n{dataset}:")
    top10 = (
        result['head_stats']
        .head(10)
        .loc[:, list(table2_headers)]
        .rename(columns=table2_headers)
    )
    print(top10.to_markdown(index=False))

## 7. Save Results

In [None]:
# Create Summary directory if it doesn't exist
summary_dir = Path('../Results/Summary')
summary_dir.mkdir(parents=True, exist_ok=True)

# Every table to persist, keyed by output filename (written in this order)
csv_outputs = {
    'delta_p_with_ci.csv': delta_p_summary,
    'stability_with_ci.csv': stability_summary,
    'main_results_table.csv': table1_df,
}
# Per-dataset FDR head-importance tables
csv_outputs.update({
    f'{dataset}_head_importance_fdr.csv': result['head_stats']
    for dataset, result in fdr_results.items()
})

for filename, frame in csv_outputs.items():
    frame.to_csv(summary_dir / filename, index=False)

print("Results saved to Results/Summary/")

## 8. Conclusions & Limitations

### Key Findings (with proper statistical reporting)

1. **JapaneseVowels** shows strong patching effects (mean ΔP with 95% CI above zero) and high stability
2. **PenDigits** shows moderate effects with moderate stability  
3. **LSST** shows near-zero effects with low stability

### Critical Limitations

1. **Sample Size (n=3 datasets)**: Cannot make correlation claims across datasets
2. **Single Training Run**: Cannot distinguish task-specific vs run-specific mechanisms
3. **No Null Baseline Yet**: Need random patching comparison for significance

### Recommended Framing

> "Through three case studies, we observe that datasets with stronger patching effects 
> (JapaneseVowels: ΔP = 0.XX [CI], PenDigits: ΔP = 0.XX [CI]) also exhibit higher mechanism
> stability under perturbation, while LSST shows neither strong effects nor stability.
> This preliminary pattern warrants further investigation with additional datasets."