# SHAP Sensitivity Analysis - Violin Plot Version

Refactored analysis with publication-quality plots.

**Structure:**
1. Configuration and Data Loading (Cells 1-6)
2. Analysis Execution (Cell 7)
3. Plots 1-a through 4 (remaining cells)

In [None]:
# ========================================================
# Cell 1: Global Configuration
# ========================================================

# CHANGE DATASET HERE
# This is the only value to edit when switching datasets; the setup cell
# copies it into Config.DATASET after importing plot_violin.
DATASET = 'german'  # Options: 'acs', 'diabetes', 'german'

# NOTE: Models are not selected in this notebook; they are configured in
# plot_violin.py (Config.MODELS).
# Current models: ['dt', 'rf', 'xgb', 'ftt', 'mlp', 'tabpfn']
# (logistic regression 'lr' has been removed, 'mlp' added between 'ftt' and 'tabpfn')

# NOTE: For cross-dataset comparison, set these axis limits in plot_violin.py
# after running all 3 datasets to find the maximum values:
#
# Config.L2_YLIM = (0, 2.5)  # Example: unified y-limit for L2 distance plots
# Config.FEATURE_L2_YLIM = (0, 0.5)  # Example: unified x-limit for feature-wise plots
#
# This ensures consistent scales across the acs, german, and diabetes datasets.

# Echo the selection so the notebook output records which dataset was used.
print(f"üìä Dataset: {DATASET}")

In [None]:
# ========================================================
# Cell 2: Setup and Imports
# ========================================================

import os
import re
import pickle
import numpy as np
import pandas as pd
import shap
from sklearn.model_selection import StratifiedKFold
from IPython.display import Javascript
import importlib
import sys

# Keep long cell outputs fully expanded: the notebook front-end is told never
# to switch an output area into scroll mode.
display(
    Javascript(
        'IPython.OutputArea.prototype._should_scroll = function(lines) {return false;}'
    )
)

# Forget any previously imported copy so the import below re-executes
# plot_violin.py and picks up edits made since the last run.
sys.modules.pop('plot_violin', None)

# Bring in the refactored module (Config, the analyzer and the plot functions).
from plot_violin import *

# The notebook-level DATASET choice takes precedence over the module default.
Config.DATASET = DATASET

print("‚úÖ Imports successful!")
print(f"Dataset: {Config.DATASET}")
print(f"Models: {Config.MODELS}")
print(f"Module location: {sys.modules['plot_violin'].__file__}")
print("\nüìù Recent changes:")
for change_note in (
    "  - Baselines: Using Monte Carlo (Mallows) for Jaccard/RBO, analytical for L2",
    "  - Visualization: Shaded region (min-max) across parameter sweeps",
    "  - Parameters: rho=[0.5,0.6,0.7], kappa=[5..15], q=[0.2,0.3,0.4]",
    "  - All baselines enabled (L2, Jaccard, RBO)",
    "  - RBO: p parameter now computed dynamically as p=1-(1/n_features)",
    "  - RBO: Uses full feature depth (r=d) for both baseline and distance calculations",
    "  - RBO: Properly computes distance as (1 - similarity) everywhere",
    "  - Feature sorting: descending order (largest at bottom for 1-d)",
    "  - Plot 2-b, 3-b: reversed lists (descending from top to bottom)",
):
    print(change_note)

In [None]:
# ========================================================
# Cell 3: Load Global Data (X, Y)
# ========================================================

# X and Y are stored as separate pickles named after the selected dataset.
with open(f'./dataset/{Config.DATASET}_X.pkl', 'rb') as x_file:
    X = pickle.load(x_file)

with open(f'./dataset/{Config.DATASET}_Y.pkl', 'rb') as y_file:
    Y = pickle.load(y_file)

print(f"\nLoaded Dataset: {Config.DATASET}")
print(f"  X shape: {X.shape}")
# Y may be an object without a .shape attribute; report its length in that case.
if hasattr(Y, 'shape'):
    print(f"  Y shape: {Y.shape}")
else:
    print(f"  Y shape: {len(Y)}")

In [None]:
# ========================================================
# Cell 4: File Parsing Helper Functions
# ========================================================

def load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    The file is opened in binary mode and closed again before returning.
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


def parse_filename(filename):
    """Parse a SHAP result filename into its components.

    Two layouts are recognised, both prefixed with ``Config.DATASET``; fields
    are separated by one or more underscores and the chunk part is optional:

    * SHAP values:   ``<dataset>_<model>_<split>_<m_seed>_<e_seed>[_[chunk]<k>]_sv.pkl``
    * Probabilities: ``<dataset>_<model>_<split>_<m_seed>[_[chunk]<k>]_proba.pkl``

    Args:
        filename: Bare file name (without directory).

    Returns:
        ``None`` if the name matches neither layout, otherwise a dict with
        keys ``type`` (``'sv'`` or ``'proba'``), ``model``, ``split``,
        ``m_seed``, ``e_seed`` (``None`` for probability files),
        ``chunk_idx`` (``-1`` when the name carries no chunk number) and
        ``path`` (``filename`` joined onto ``Config.RESULT_DIR``).
    """
    # Escape the dataset name so that any regex metacharacter it might contain
    # is matched literally instead of being interpreted as pattern syntax.
    dataset = re.escape(Config.DATASET)

    # Pieces shared by both layouts.
    head = rf'{dataset}_+([a-zA-Z0-9]+)_+(\d+)_+(\d+)'
    chunk = r'(?:_+(?:chunk)?(\d+))?'

    pattern_sv = re.compile(rf'{head}_+(\d+){chunk}_+sv_?\.pkl')
    pattern_proba = re.compile(rf'{head}{chunk}_+proba_?\.pkl')

    # fullmatch() instead of match(): a name with trailing characters after
    # ".pkl" (e.g. a temporary or backup copy) must not be treated as a result.
    match = pattern_sv.fullmatch(filename)
    if match:
        file_type = 'sv'
        model, split, m_seed, e_seed, chunk_id = match.groups()
        e_seed = int(e_seed)
    else:
        match = pattern_proba.fullmatch(filename)
        if not match:
            return None
        file_type = 'proba'
        model, split, m_seed, chunk_id = match.groups()
        e_seed = None  # probabilities do not depend on an explainer seed

    return {
        'type': file_type,
        'model': model,
        'split': int(split),
        'm_seed': int(m_seed),
        'e_seed': e_seed,
        # -1 marks an unchunked file.
        'chunk_idx': int(chunk_id) if chunk_id is not None else -1,
        'path': os.path.join(Config.RESULT_DIR, filename),
    }


def merge_chunks(file_list, return_meta=False):
    """Load the chunk files of one result and concatenate them along axis 0.

    Args:
        file_list: List of parsed-file dicts (as produced by
            ``parse_filename``); each needs a ``'path'`` and a ``'chunk_idx'``.
            Chunks are merged in ascending ``chunk_idx`` order. The list
            itself is left untouched.
        return_meta: When true, also return a metadata dict.

    Returns:
        The concatenated values array, or ``(values, meta)`` when
        ``return_meta`` is true. For an empty ``file_list`` the result is
        ``None`` or ``(None, {})`` respectively.

        ``meta`` is only populated when the first chunk is a
        ``shap.Explanation``. It then holds ``base_values`` and
        ``feature_names`` of that first chunk, and ``data`` (concatenated over
        all ``Explanation`` chunks when there is more than one file).
    """
    if not file_list:
        return (None, {}) if return_meta else None

    # sorted() rather than list.sort(): the caller's list must not be reordered
    # as a side effect of merging.
    ordered = sorted(file_list, key=lambda item: item['chunk_idx'])

    value_parts = []
    data_parts = []
    meta = {}
    collect_meta = False

    # Each file is unpickled exactly once; values and (if requested) data are
    # gathered in the same pass.
    for position, item in enumerate(ordered):
        obj = load_pickle(item['path'])
        is_explanation = isinstance(obj, shap.Explanation)

        if position == 0 and return_meta and is_explanation:
            meta['base_values'] = obj.base_values
            meta['data'] = obj.data
            meta['feature_names'] = obj.feature_names
            collect_meta = True

        if collect_meta and is_explanation:
            data_parts.append(obj.data)

        value_parts.append(obj.values if is_explanation else obj)

    full_values = np.concatenate(value_parts, axis=0)

    # With a single file the first chunk's data is already complete.
    # NOTE(review): base_values is taken from the first chunk only and is not
    # concatenated like data - confirm that this is intended.
    if collect_meta and len(ordered) > 1 and data_parts:
        meta['data'] = np.concatenate(data_parts, axis=0)

    return (full_values, meta) if return_meta else full_values


# Status message emitted once the helper definitions of this cell have run.
print("‚úÖ Helper functions loaded")

In [None]:
# ========================================================
# Cell 5: Load SHAP Results into Arrays
# ========================================================

def load_model_results(model_name):
    """Gather all result files of one model into dense, NaN-padded arrays.

    Every name in ``Config.RESULT_DIR`` is run through ``parse_filename``;
    files whose model field equals *model_name* are grouped per run (all
    chunks of a run together), merged with ``merge_chunks`` and written into
    float32 arrays that start out filled with NaN.

    Args:
        model_name: Model identifier as it appears in the result filenames.

    Returns:
        ``None`` when no SHAP file matches. Otherwise a dict with:

        * ``'shap_values'``: array of shape ``(N_SPLITS, N_MODEL_SEEDS,
          N_EXPLAINER_SEEDS, max_samples, n_features)``.
        * ``'proba'``: array of shape ``(N_SPLITS, N_MODEL_SEEDS,
          max_samples, n_classes)``.
        * ``'data'``: ``split -> data`` metadata of the first SHAP run seen
          for that split.
        * ``'base_values'``: ``(split, m_seed) -> base_values`` metadata of the
          first SHAP run seen for that pair.
        * ``'feature_names'``: first non-``None`` feature-name metadata found.

        Slots without a file, and rows beyond a run's own sample count,
        remain NaN.
    """
    print(f"\nüîÑ Loading model: {model_name}")

    # Bucket the matching files: one list of chunk descriptors per run.
    shap_groups = {}
    proba_groups = {}
    for entry in os.listdir(Config.RESULT_DIR):
        parsed = parse_filename(entry)
        if not parsed or parsed['model'] != model_name:
            continue
        if parsed['type'] == 'sv':
            run_key = (parsed['split'], parsed['m_seed'], parsed['e_seed'])
            shap_groups.setdefault(run_key, []).append(parsed)
        elif parsed['type'] == 'proba':
            run_key = (parsed['split'], parsed['m_seed'])
            proba_groups.setdefault(run_key, []).append(parsed)

    if not shap_groups:
        print(f"  ‚ö†Ô∏è  No SHAP files found")
        return None

    # First pass over the SHAP runs: find the array dimensions and pick up the
    # metadata, keeping only the first entry seen per key.
    feature_names = None
    data_by_split = {}
    base_values_map = {}
    max_samples = 0
    n_features = 0
    for (split, m_seed, _e_seed), chunk_files in shap_groups.items():
        merged, meta = merge_chunks(chunk_files, return_meta=True)
        n_rows, n_features = merged.shape[0], merged.shape[1]
        if n_rows > max_samples:
            max_samples = n_rows

        if feature_names is None:
            feature_names = meta.get('feature_names')
        data_by_split.setdefault(split, meta.get('data'))
        base_values_map.setdefault((split, m_seed), meta.get('base_values'))

    print(f"  üìè Dimensions: N={max_samples}, F={n_features}")

    # NaN marks every slot that no file fills.
    shap_shape = (
        Config.N_SPLITS,
        Config.N_MODEL_SEEDS,
        Config.N_EXPLAINER_SEEDS,
        max_samples,
        n_features,
    )
    shap_array = np.full(shap_shape, np.nan, dtype=np.float32)

    # The class count is read from one probability run; 1-D output counts as
    # a single column.
    n_classes = 1
    if proba_groups:
        sample_proba = merge_chunks(next(iter(proba_groups.values())))
        if sample_proba.ndim > 1:
            n_classes = sample_proba.shape[1]

    proba_shape = (Config.N_SPLITS, Config.N_MODEL_SEEDS, max_samples, n_classes)
    proba_array = np.full(proba_shape, np.nan, dtype=np.float32)

    # Probabilities: only (split, seed) pairs inside the configured grid are read.
    for split_idx in range(Config.N_SPLITS):
        for seed_idx in range(Config.N_MODEL_SEEDS):
            chunk_files = proba_groups.get((split_idx, seed_idx))
            if chunk_files is None:
                continue
            probs = merge_chunks(chunk_files)
            if probs.ndim == 1:
                probs = probs.reshape(-1, 1)
            proba_array[split_idx, seed_idx, :probs.shape[0], :] = probs

    # Second pass over the SHAP runs: copy the values into their slot.
    for (split, m_seed, e_seed), chunk_files in shap_groups.items():
        merged = merge_chunks(chunk_files, return_meta=False)
        shap_array[split, m_seed, e_seed, :merged.shape[0], :] = merged

    print(f"  ‚úÖ {model_name} loaded")

    return {
        'shap_values': shap_array,
        'proba': proba_array,
        'data': data_by_split,
        'base_values': base_values_map,
        'feature_names': feature_names,
    }


# Load all models
print("=" * 60)
print("Loading SHAP Results")
print("=" * 60)

# Only models for which results were actually found end up in raw_results.
raw_results = {}
for model_name in Config.MODELS:
    res = load_model_results(model_name)
    if not res:
        continue
    raw_results[model_name] = res

print(f"\n‚úÖ Loaded {len(raw_results)} models")

In [None]:
# ========================================================
# Cell 6: Generate Y_test Lists
# ========================================================

print("üîÑ Generating y_test lists...")

# Use the underlying arrays when X / Y expose a .values attribute.
X_arr = X.values if hasattr(X, 'values') else X
Y_arr = Y.values if hasattr(Y, 'values') else Y

# NOTE(review): the labels only line up with the stored SHAP rows if these
# folds (shuffle=True, random_state=42) are the ones used when the results
# were generated - confirm against the generation script.
skf = StratifiedKFold(n_splits=Config.N_SPLITS, shuffle=True, random_state=42)
y_test_list = [Y_arr[test_index] for _, test_index in skf.split(X_arr, Y_arr)]

# Every model entry receives the same list object of per-split test labels.
for model_entry in raw_results.values():
    model_entry['y_test'] = y_test_list

print("‚úÖ Y_test lists generated")

In [None]:
# ========================================================
# Cell 7: Run Sensitivity Analysis
# ========================================================

print("=" * 60)
print("Running Sensitivity Analysis")
print("=" * 60)

analysis_results = {}

for model_name, raw_data in raw_results.items():
    print(f"\nüîÑ Analyzing: {model_name}")

    analyzer = SensitivityAnalyzer(
        shap_5d=raw_data['shap_values'],
        proba_4d=raw_data['proba'],
        y_test_list=raw_data['y_test'],
        feature_names=raw_data['feature_names'],
    )

    # Subgroup masks, reused below for the per-subgroup statistics.
    cert_masks = analyzer.get_certainty_masks()
    pred_masks = analyzer.get_prediction_masks()

    # The entries are filled one after another so that both the order of the
    # analyzer calls and the key order of the dict stay fixed.
    results = {}
    results['overall_pooled'] = analyzer.compute_overall_pooled()
    results['separated'] = analyzer.compute_separated_sensitivity()
    results['seedwise_explainer'] = analyzer.compute_seedwise_explainer()
    results['seedwise_model'] = analyzer.compute_seedwise_model()
    results['certainty'] = analyzer.compute_certainty_subgroups()
    results['prediction'] = analyzer.compute_prediction_subgroups()
    results['certainty_separated'] = analyzer.compute_certainty_separated_sensitivity()
    results['prediction_separated'] = analyzer.compute_prediction_separated_sensitivity()
    results['feature_overall'] = {
        'explainer': analyzer.compute_feature_sensitivity(mode='explainer'),
        'model': analyzer.compute_feature_sensitivity(mode='model'),
    }
    # Placeholders; populated per subgroup at the end of the iteration.
    results['feature_certainty'] = {}
    results['feature_prediction'] = {}
    results['certainty_feature_separated'] = analyzer.compute_certainty_feature_separated()
    results['prediction_feature_separated'] = analyzer.compute_prediction_feature_separated()
    results['feature_names'] = raw_data['feature_names']

    # Mean absolute SHAP values for y-axis labels: overall first, then one
    # entry per subgroup mask (a missing mask is passed on as None).
    results['mean_abs_shap'] = analyzer.compute_mean_abs_shap()
    for result_key, mask_lookup, mask_name in (
        ('mean_abs_shap_certain', cert_masks, 'Certain'),
        ('mean_abs_shap_uncertain', cert_masks, 'Uncertain'),
        ('mean_abs_shap_tp', pred_masks, 'TP'),
        ('mean_abs_shap_tn', pred_masks, 'TN'),
        ('mean_abs_shap_fp', pred_masks, 'FP'),
        ('mean_abs_shap_fn', pred_masks, 'FN'),
    ):
        results[result_key] = analyzer.compute_mean_abs_shap(
            subgroup_mask=mask_lookup.get(mask_name)
        )

    # Probability data for plot 4
    results['proba_4d'] = raw_data['proba']

    # Feature-level explainer sensitivity within each subgroup.
    results['feature_certainty'].update(
        (grp, analyzer.compute_feature_sensitivity(mode='explainer', subgroup_mask=masks))
        for grp, masks in cert_masks.items()
    )
    results['feature_prediction'].update(
        (grp, analyzer.compute_feature_sensitivity(mode='explainer', subgroup_mask=masks))
        for grp, masks in pred_masks.items()
    )

    analysis_results[model_name] = results
    print(f"  ‚úÖ {model_name} complete")

print("\n" + "=" * 60)
print("‚úÖ All analyses complete!")
print("=" * 60)

---
# Plots
---

In [None]:
# Plot 1-a: Overall Sensitivity (All 25 seeds pooled)
# One figure per distance metric, each with its baseline drawn as a shaded region.
for metric in ('l2', 'jaccard', 'rbo'):
    plot_1a_overall_pooled(analysis_results, metric=metric, show_baseline=True)

In [None]:
# Plot 1-b: Separated Sensitivity (Explainer vs Model)
# One figure per distance metric, each with its baseline drawn as a shaded region.
for metric in ('l2', 'jaccard', 'rbo'):
    plot_1b_separated(analysis_results, metric=metric, show_baseline=True)

In [None]:
# Plot 1-b: Separated Sensitivity (Explainer vs Model)
# One figure per distance metric, each with its baseline drawn as a shaded region.
# NOTE(review): this cell issues the same calls as the preceding Plot 1-b cell,
# so the figures are rendered twice - confirm whether one copy can be dropped.
for metric in ('l2', 'jaccard', 'rbo'):
    plot_1b_separated(analysis_results, metric=metric, show_baseline=True)

In [None]:
# ========================================================
# Plot 1-c: Seed-wise Decomposition (all models, all metrics)
# With baselines for all metrics (L2, Jaccard, RBO)
# ========================================================
# NOTE(review): the loop runs over Config.MODELS, while analysis_results only
# holds models whose result files were found - confirm the plot function copes
# with a model that is missing from analysis_results.
for model_name in Config.MODELS:
    for metric in ('l2', 'jaccard', 'rbo'):
        # Explainer-seed view first, then model-seed view.
        for sensitivity_type in ('explainer', 'model'):
            plot_1c_seedwise(
                analysis_results,
                model_name,
                metric=metric,
                sensitivity_type=sensitivity_type,
                show_baseline=True,
            )

In [None]:
# ========================================================
# Plot 1-d: Feature-level Overall Sensitivity
# Shows which features are most unstable (largest mean absolute differences)
# ========================================================
for model_name in Config.MODELS:
    # Explainer-seed view first, then model-seed view.
    for mode in ('explainer', 'model'):
        plot_1d_feature_overall(analysis_results, model_name, mode=mode, use_boxplot=True)

In [None]:
# Plot 2-a: Certainty-based Sensitivity (Explainer Only)
# One figure per distance metric, each with its baseline drawn as a shaded region.
for metric in ('l2', 'jaccard', 'rbo'):
    plot_2a_certainty(analysis_results, metric=metric, show_baseline=True)

In [None]:
# Plot 2-a: Certainty-based Sensitivity (Explainer Only)
# One figure per distance metric, each with its baseline drawn as a shaded region.
# NOTE(review): this cell issues the same calls as the preceding Plot 2-a cell,
# so the figures are rendered twice - confirm whether one copy can be dropped.
for metric in ('l2', 'jaccard', 'rbo'):
    plot_2a_certainty(analysis_results, metric=metric, show_baseline=True)

In [None]:
# Plot 2-b: Certainty Feature Sensitivity (Explainer Only)
# Distinct colors: Certain=blue, Uncertain=red
for model_name in Config.MODELS:
    for subgroup in ('Certain', 'Uncertain'):
        plot_2b_certainty_features(analysis_results, model_name, subgroup=subgroup)

# Plot 3-a: Prediction-based Sensitivity (Explainer Only)
# One figure per distance metric, each with its baseline drawn as a shaded region.
# NOTE(review): the same Plot 3-a calls are repeated in the next cell - confirm
# whether they belong in this Plot 2-b cell at all.
for metric in ('l2', 'jaccard', 'rbo'):
    plot_3a_prediction(analysis_results, metric=metric, show_baseline=True)

In [None]:
# Plot 3-a: Prediction-based Sensitivity (Explainer Only)
# One figure per distance metric, each with its baseline drawn as a shaded region.
for metric in ('l2', 'jaccard', 'rbo'):
    plot_3a_prediction(analysis_results, metric=metric, show_baseline=True)

In [None]:
# Plot 3-b: Prediction Feature Sensitivity (Explainer Only)
# Distinct colors: TP=green, TN=blue, FP=red, FN=orange
# One figure per model and confusion-matrix subgroup, in the order listed.
for model_name in Config.MODELS:
    for subgroup in ('TP', 'TN', 'FP', 'FN'):
        plot_3b_prediction_features(
            analysis_results,
            model_name,
            subgroup=subgroup,
        )

**Note:** Plot 3-b shows feature-wise sensitivity for the prediction subgroups using explainer seeds only. Model-seed sensitivity is omitted because a different model seed changes the predictions, and therefore the TP/TN/FP/FN masks themselves.

In [None]:
# ========================================================
# Plot 4: Probability Distribution by Model Seeds
# Shows how prediction probabilities vary across different model seeds
# ========================================================
# NOTE(review): the loop runs over Config.MODELS, while analysis_results only
# holds models whose result files were found - confirm the plot function copes
# with a model that is missing from analysis_results.
for model_name in Config.MODELS:
    plot_4_probability_distribution(analysis_results, model_name)

In [None]:
# import matplotlib.pyplot as plt
# import numpy as np

# # Check the distribution of the model data
# model_data = analysis_results['dt']['separated']['model']['jaccard']

# fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# # Histogram
# axes[0].hist(model_data, bins=30, edgecolor='black', alpha=0.7, color='green')
# axes[0].set_xlabel('Jaccard Distance')
# axes[0].set_ylabel('Count')
# axes[0].set_title('Model Data - Histogram')
# axes[0].axvline(0, color='red', linestyle='--', label='Zero')
# axes[0].legend()

# # Summary statistics
# stats_text = f"""Model Data Statistics:

# Min: {np.min(model_data):.4f}
# Max: {np.max(model_data):.4f}
# Mean: {np.mean(model_data):.4f}
# Median: {np.median(model_data):.4f}
# Std: {np.std(model_data):.4f}

# Q1 (25%): {np.percentile(model_data, 25):.4f}
# Q3 (75%): {np.percentile(model_data, 75):.4f}

# Count = 0: {np.sum(model_data == 0)}
# Count < 0.1: {np.sum(model_data < 0.1)}
# Count < 0.2: {np.sum(model_data < 0.2)}
# """

# axes[1].text(0.1, 0.5, stats_text, fontsize=11, verticalalignment='center', family='monospace')
# axes[1].axis('off')
# plt.tight_layout()
# plt.show()

# print(f"\nData points: {len(model_data)}")
# print(f"% below 0.1: {100 * np.sum(model_data < 0.1) / len(model_data):.1f}%")
# print(f"% below 0.2: {100 * np.sum(model_data < 0.2) / len(model_data):.1f}%")