# Interpretability Analysis Explorer

Interactive exploration of SHAP, Integrated Gradients, and LIME results across multiple architectures and datasets.

**Purpose**: Visualize and analyze SNP importance rankings for thesis/journal publication.

In [None]:
import os
import sys
import json
import warnings
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr

# Silence all warnings to keep notebook output readable.
warnings.filterwarnings('ignore')

# Make project modules importable. Assumes the kernel's working directory is one level
# below the project root (e.g. a notebooks/ folder) — TODO confirm for your setup.
project_root = Path().absolute().parent
sys.path.insert(0, str(project_root))

# Setup plotting
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['font.size'] = 10

print("✓ Libraries imported successfully")

## 1. Load Interpretability Results

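The next cell locates the base directory containing the analysis results (by default `outputs/interpretability_analysis` under the project root) and lists the checkpoints and datasets found there.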

In [None]:
# Configure base directory for analysis results.
# Prefer <project_root>/outputs/interpretability_analysis, then the same path relative to the
# current working directory. If neither exists, keep the project-root path so the message
# below shows where results are expected.
_default_analysis_dir = Path(project_root) / 'outputs' / 'interpretability_analysis'
_candidate_analysis_dirs = [
    _default_analysis_dir,
    Path('.') / 'outputs' / 'interpretability_analysis',
]
ANALYSIS_BASE_DIR = next(
    (candidate for candidate in _candidate_analysis_dirs if candidate.exists()),
    _default_analysis_dir,
)

print(f"Analysis base directory: {ANALYSIS_BASE_DIR}")
print(f"Exists: {ANALYSIS_BASE_DIR.exists()}")

# Subdirectories of the base dir that are notebook outputs, not checkpoints
# (publication_figures/ and supplementary_data/ are created by later cells).
_non_checkpoint_dirs = {'figures', 'data', 'publication_figures', 'supplementary_data'}
_method_dir_names = {'shap', 'ig', 'lime'}

# Discover available results; expected layout is <base>/<checkpoint>/<dataset>/<method>/
available_results = {
    'checkpoints': [],
    'datasets': [],
    'methods': ['shap', 'ig', 'lime']
}

if ANALYSIS_BASE_DIR.exists():
    # Sorted so the default selection (first checkpoint) is deterministic across runs.
    for item in sorted(ANALYSIS_BASE_DIR.iterdir()):
        if not item.is_dir() or item.name in _non_checkpoint_dirs:
            continue
        available_results['checkpoints'].append(item.name)

        # Find datasets within checkpoint, skipping directories named like methods.
        for dataset_item in item.iterdir():
            if (dataset_item.is_dir()
                    and dataset_item.name not in _method_dir_names
                    and dataset_item.name not in available_results['datasets']):
                available_results['datasets'].append(dataset_item.name)
    available_results['datasets'].sort()

print(f"\n📊 Available Results:")
print(f"  Checkpoints: {len(available_results['checkpoints'])}")
print(f"  Datasets: {len(available_results['datasets'])}")
print(f"  Methods: {available_results['methods']}")

if not available_results['checkpoints']:
    print("\n⚠️ No analysis results found yet. Run these commands first:")
    print("  python src/interpretability_pipeline.py")
    print("  or")
    print("  python src/shap_explainability.py --checkpoint_path <path>")

In [None]:
# Configure base directory for analysis results.
# Reuses the directory resolved by the previous cell when that cell has been run; otherwise
# falls back to a path relative to the current working directory. To use a different output
# directory, replace the right-hand side with an explicit Path(...).
ANALYSIS_BASE_DIR = Path(globals().get('ANALYSIS_BASE_DIR', 'outputs/interpretability_analysis'))

# Discover available results
def discover_results(base_dir: Path) -> Dict[str, List[str]]:
    """Discover available checkpoints, datasets and interpretability methods.

    Expects the layout ``<base_dir>/<checkpoint>/<dataset>/<method>/``. Datasets and
    methods are collected across *all* checkpoints (not only the first one). Output
    directories that this notebook writes into ``base_dir`` are not treated as checkpoints.

    Args:
        base_dir: Root directory of the interpretability analysis outputs.

    Returns:
        Dict with sorted, de-duplicated name lists under 'checkpoints', 'datasets' and
        'methods'. All lists are empty (and a warning is printed) if ``base_dir`` is missing.
    """
    results = {
        'checkpoints': [],
        'datasets': [],
        'methods': []
    }

    if not base_dir.exists():
        print(f"⚠️ Analysis directory not found: {base_dir}")
        return results

    # Not checkpoints: publication_figures/ and supplementary_data/ are created inside
    # base_dir by later cells, so they would otherwise appear here on a re-run.
    excluded = {'figures', 'data', 'publication_figures', 'supplementary_data'}

    checkpoint_names = sorted(
        entry.name for entry in base_dir.iterdir()
        if entry.is_dir() and entry.name not in excluded
    )

    dataset_names = set()
    method_names = set()
    for checkpoint_name in checkpoint_names:
        for dataset_dir in (base_dir / checkpoint_name).iterdir():
            if not dataset_dir.is_dir():
                continue
            dataset_names.add(dataset_dir.name)
            method_names.update(m.name for m in dataset_dir.iterdir() if m.is_dir())

    results['checkpoints'] = checkpoint_names
    results['datasets'] = sorted(dataset_names)
    results['methods'] = sorted(method_names)
    return results

# Load results (overwrites available_results from the previous cell)
available_results = discover_results(ANALYSIS_BASE_DIR)
print(f"Discovered {len(available_results['checkpoints'])} checkpoints")
print(f"Discovered {len(available_results['datasets'])} datasets")
print(f"Discovered {len(available_results['methods'])} methods: {available_results['methods']}")

if available_results['checkpoints']:
    print("\n✓ Results loaded. Available checkpoints:")
    for ckpt in available_results['checkpoints'][:5]:
        print(f"  - {ckpt}")
    # Only the first five are listed; say how many were left out.
    n_hidden = len(available_results['checkpoints']) - 5
    if n_hidden > 0:
        print(f"  ... and {n_hidden} more")
else:
    print("⚠️ No results found. Run interpretability_pipeline.py first.")

## 2. Utilities for SNP Analysis

Load and compare SNP importance rankings across methods and architectures.

In [None]:
def load_snp_ranking(checkpoint: str, dataset: str, method: str,
                     base_dir: Optional[Path] = None) -> pd.DataFrame:
    """Load the SNP ranking table for one checkpoint/dataset/method combination.

    Reads ``<base_dir>/<checkpoint>/<dataset>/<method>/top_<method>_snps.csv``.

    Args:
        checkpoint: Checkpoint (architecture) directory name.
        dataset: Dataset directory name.
        method: Interpretability method directory name (e.g. 'shap', 'ig', 'lime').
        base_dir: Analysis root directory. When omitted, the global ``ANALYSIS_BASE_DIR``
            is looked up at call time rather than at definition time, so reconfiguring
            the directory later in the notebook takes effect for all callers.

    Returns:
        The ranking as a DataFrame, or an empty DataFrame (after printing a warning)
        if the file does not exist.
    """
    if base_dir is None:
        base_dir = ANALYSIS_BASE_DIR
    ranking_file = Path(base_dir) / checkpoint / dataset / method / f'top_{method}_snps.csv'

    if not ranking_file.exists():
        print(f"⚠️ File not found: {ranking_file}")
        return pd.DataFrame()

    return pd.read_csv(ranking_file)

def compare_top_snps_across_methods(checkpoint: str, dataset: str, top_k: int = 20) -> pd.DataFrame:
    """Compare the top-K SNPs of each interpretability method side by side.

    Methods are taken from ``available_results['methods']``. Methods whose ranking file
    is missing are skipped.

    Args:
        checkpoint: Checkpoint (architecture) directory name.
        dataset: Dataset directory name.
        top_k: Number of best-ranked SNPs taken from each method.

    Returns:
        DataFrame indexed by SNP_ID with one column per method (sorted by name) holding
        that method's rank, and '-' where the SNP is not in the method's top-K.
        Empty if no ranking could be loaded.
    """
    per_method_ranks = {}
    for method in available_results['methods']:
        df = load_snp_ranking(checkpoint, dataset, method)
        if df.empty:
            continue
        # Sort by Rank rather than trusting file order. Drop repeated SNP IDs so the
        # index is unique, because concat(axis=1) cannot align duplicate labels.
        top = (df.sort_values('Rank', kind='mergesort')
                 .drop_duplicates('SNP_ID')
                 .head(top_k))
        per_method_ranks[method] = top.set_index('SNP_ID')['Rank']

    if not per_method_ranks:
        return pd.DataFrame()

    # Combine rankings: dict keys become the column names, in alphabetical method order.
    combined = pd.concat({m: per_method_ranks[m] for m in sorted(per_method_ranks)}, axis=1)
    return combined.fillna('-')

def get_top_consensus_snps(dataset: str, top_k: int = 30, min_agreement: float = 0.5) -> pd.DataFrame:
    """Rank SNPs by how consistently they appear across checkpoints and methods.

    Every (checkpoint, method) pair in ``available_results`` whose ranking file exists for
    ``dataset`` is one ranking. A SNP listed in a ranking counts as one appearance.
    The agreement ratio is appearances divided by the number of rankings actually
    loaded, so a missing result file does not count as disagreement.

    Args:
        dataset: Dataset directory name.
        top_k: Maximum number of SNPs returned.
        min_agreement: Minimum agreement ratio (0-1) for a SNP to be kept.

    Returns:
        DataFrame with columns SNP_ID, Appearances, Agreement_Ratio, Mean_Rank and
        Std_Rank. Rows are sorted by Agreement_Ratio (descending), then Mean_Rank
        (ascending). When no SNP qualifies, an empty DataFrame with these columns is
        returned instead of raising.
    """
    columns = ['SNP_ID', 'Appearances', 'Agreement_Ratio', 'Mean_Rank', 'Std_Rank']
    snp_ranks = {}  # SNP_ID -> ranks, one entry per ranking the SNP appears in
    n_rankings = 0

    for checkpoint in available_results['checkpoints']:
        for method in available_results['methods']:
            df = load_snp_ranking(checkpoint, dataset, method)
            if df.empty:
                continue
            n_rankings += 1
            # Count each SNP at most once per ranking, keeping its best rank, so that
            # agreement can never exceed 1.
            df = df.sort_values('Rank', kind='mergesort').drop_duplicates('SNP_ID')
            for snp_id, rank in zip(df['SNP_ID'], df['Rank']):
                snp_ranks.setdefault(snp_id, []).append(rank)

    consensus_data = []
    for snp_id, ranks in snp_ranks.items():
        agreement = len(ranks) / n_rankings
        if agreement >= min_agreement:
            consensus_data.append({
                'SNP_ID': snp_id,
                'Appearances': len(ranks),
                'Agreement_Ratio': agreement,
                'Mean_Rank': np.mean(ranks),
                'Std_Rank': np.std(ranks),
            })

    # Explicit columns keep sort_values working when consensus_data is empty.
    consensus_df = pd.DataFrame(consensus_data, columns=columns)
    consensus_df = (consensus_df
                    .sort_values(['Agreement_Ratio', 'Mean_Rank'],
                                 ascending=[False, True], kind='mergesort')
                    .head(top_k)
                    .reset_index(drop=True))

    return consensus_df

print("✓ Utility functions loaded")

## 3. Explore Top SNPs by Architecture

Choose which checkpoint (architecture) and dataset the following sections analyze. By default, the first discovered checkpoint and dataset are selected.

In [None]:
# Select checkpoint and dataset to explore (edit these two variables to switch views)
if available_results['checkpoints']:
    selected_checkpoint = available_results['checkpoints'][0]  # First checkpoint by default
    # Falls back to 'autism' only when no dataset directories were discovered.
    selected_dataset = available_results['datasets'][0] if available_results['datasets'] else 'autism'

    print(f"Currently viewing: {selected_checkpoint} / {selected_dataset}")
    print("\nTo explore different results, modify selected_checkpoint or selected_dataset")
    print(f"\nAvailable checkpoints: {len(available_results['checkpoints'])}")
    print(f"Available datasets: {len(available_results['datasets'])}")
else:
    print("⚠️ No results available to explore")

## 4. Compare Top SNPs Across Methods

View how the top-20 SNPs compare across SHAP, IG, and LIME for a given architecture/dataset.

In [None]:
# Compare top SNPs across methods for the selected checkpoint/dataset
comparison_top_k = 20
if available_results['checkpoints']:
    comparison_df = compare_top_snps_across_methods(selected_checkpoint, selected_dataset,
                                                    top_k=comparison_top_k)

    if not comparison_df.empty:
        print(f"Top {comparison_top_k} SNPs Comparison ({selected_checkpoint} - {selected_dataset}):\n")
        print(comparison_df.to_string())
    else:
        print("⚠️ No data available for this checkpoint/dataset combination")
else:
    print("⚠️ No results available")

## 5. Consensus SNPs Across Architectures

Identify SNPs that consistently rank among the top features across multiple architectures and interpretability methods. This cross-model agreement makes them more robust candidates to report in a publication.

In [None]:
# Compute consensus SNPs for the selected dataset
consensus_min_agreement = 0.3
if available_results['datasets']:
    consensus_snps = get_top_consensus_snps(selected_dataset, top_k=30,
                                            min_agreement=consensus_min_agreement)

    if not consensus_snps.empty:
        print(f"\nConsensus SNPs for {selected_dataset}:")
        # Derived from the threshold so the caption cannot drift from the computation.
        print(f"(SNPs appearing in ≥{consensus_min_agreement:.0%} of architecture/method combinations)\n")
        print(consensus_snps.to_string(index=False))

        # Summary statistics
        print(f"\n📊 Summary:")
        print(f"  Total consensus SNPs: {len(consensus_snps)}")
        print(f"  Mean agreement ratio: {consensus_snps['Agreement_Ratio'].mean():.2%}")
        print(f"  Max agreement ratio: {consensus_snps['Agreement_Ratio'].max():.2%}")
    else:
        print("⚠️ No consensus SNPs found with current threshold")
else:
    print("⚠️ No datasets available")

## 6. Generate Publication-Quality Figures

Create high-resolution figures suitable for thesis and journal articles.

In [None]:
# Figures go into a publication_figures/ folder under the analysis root; make sure it exists.
FIGURE_OUTPUT_DIR = Path(ANALYSIS_BASE_DIR).joinpath('publication_figures')
FIGURE_OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

print(f"Publication figures will be saved to: {FIGURE_OUTPUT_DIR}")

# Figure 1: Consensus SNPs bar plot
def create_consensus_figure(dataset: str, top_k: int = 30):
    """Save a horizontal bar chart of the consensus SNPs for one dataset.

    SNPs come from ``get_top_consensus_snps(dataset, top_k=top_k, min_agreement=0.2)``.
    The figure is written as a 300-dpi PNG to
    ``FIGURE_OUTPUT_DIR / f'consensus_snps_{dataset}_top{top_k}.png'`` and then closed.
    If no SNP qualifies, a message is printed and nothing is saved.

    Args:
        dataset: Dataset directory name.
        top_k: Maximum number of SNPs to plot (fewer are shown if fewer qualify).
    """
    consensus_df = get_top_consensus_snps(dataset, top_k=top_k, min_agreement=0.2)

    if consensus_df.empty:
        print(f"No data for {dataset}")
        return

    ratios = consensus_df['Agreement_Ratio'].to_numpy(dtype=float)
    positions = np.arange(len(ratios))

    fig, ax = plt.subplots(figsize=(12, 10))

    # Colour encodes agreement relative to the best-supported SNP in this plot.
    colors = plt.cm.viridis(ratios / ratios.max())
    ax.barh(positions, ratios, color=colors)

    ax.set_yticks(positions)
    ax.set_yticklabels(consensus_df['SNP_ID'], fontsize=9)
    # The ratio is computed over checkpoint x method rankings, not over models alone.
    ax.set_xlabel('Consensus Ratio (fraction of architecture/method rankings)', fontsize=12, fontweight='bold')
    ax.set_ylabel('SNP Identifier', fontsize=12, fontweight='bold')
    # Report the number actually plotted, which can be smaller than top_k.
    ax.set_title(f'Top {len(ratios)} Consensus SNPs\n{dataset.capitalize()} Dataset - Robust Across Multiple Architectures',
                 fontsize=14, fontweight='bold', pad=20)

    ax.invert_yaxis()  # highest-agreement SNP (first row) at the top
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)
    # Leave room right of the longest bar for its percentage label.
    ax.set_xlim(0, ratios.max() + 0.08)

    # Add percentage labels
    for position, ratio in zip(positions, ratios):
        ax.text(ratio + 0.01, position, f'{ratio:.0%}', va='center', fontsize=8)

    fig.tight_layout()

    output_file = FIGURE_OUTPUT_DIR / f'consensus_snps_{dataset}_top{top_k}.png'
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {output_file}")
    plt.close(fig)

# Produce one consensus figure (top 30 SNPs) for every discovered dataset.
for dataset in available_results['datasets']:
    create_consensus_figure(dataset=dataset, top_k=30)

## 7. Export Results for Publication

Save consensus SNP rankings and summary statistics for supplementary materials.

In [None]:
# Export consensus SNPs to CSV for supplementary materials
DATA_OUTPUT_DIR = Path(ANALYSIS_BASE_DIR) / 'supplementary_data'
DATA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Despite the "_full" file name, each table is capped at this many SNPs; raise it if needed.
export_top_k = 200
n_exported = 0

for dataset in available_results['datasets']:
    consensus_df = get_top_consensus_snps(dataset, top_k=export_top_k, min_agreement=0.2)

    if consensus_df.empty:
        print(f"⚠️ No consensus SNPs for {dataset}; nothing exported")
        continue

    output_file = DATA_OUTPUT_DIR / f'consensus_snps_{dataset}_full.csv'
    consensus_df.to_csv(output_file, index=False)
    n_exported += 1
    print(f"✓ Exported: {output_file} ({len(consensus_df)} SNPs)")

# Only claim success when at least one table was actually written.
if n_exported:
    print(f"\n✓ All results exported to: {DATA_OUTPUT_DIR}")
else:
    print(f"\n⚠️ No consensus tables exported to: {DATA_OUTPUT_DIR}")