## 1. Load Trained Models and Data

Load the best-performing model (HistGradientBoosting) and the test dataset for interpretation.

In [None]:
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from sklearn.inspection import permutation_importance
import warnings

# Silence every library warning so the cell outputs below stay readable.
warnings.filterwarnings('ignore')

# Shared look for all figures in this notebook.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Read the raw dataset (path is relative to the working directory).
data_path = Path('data/initial-smell-dataset.csv')
df = pd.read_csv(data_path)

# Target: the scent identifier, cast to int.
y = df['scent_id'].astype(int)

# Features: everything that is not an identifier or label column.
# errors='ignore' means a column missing from the CSV is skipped, not raised.
non_feature_columns = ['sample_id', 'scent_name', 'scent_id']
X = df.drop(columns=non_feature_columns, errors='ignore')

print(f"Data shape: {X.shape}")
print(f"Target shape: {y.shape}")

In [None]:
# Demo only: rebuild the hold-out split here. In practice the saved test set
# should be loaded instead.
# NOTE(review): this only reproduces the training-time split if the same
# test_size / random_state / stratify settings were used there — confirm.
from sklearn.model_selection import train_test_split

split_options = dict(test_size=0.20, random_state=42, stratify=y)
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

# Restore the persisted best model (HistGradientBoosting) from disk.
model_path = Path('models/histgradientboosting_model.joblib')
best_model = joblib.load(model_path)

print(f"Loaded model from {model_path}")
print(f"Test set size: {X_test.shape[0]}")

## 2. Permutation Importance

**What it does**: Shuffles each feature and measures drop in model performance.

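**Interpretation**: The larger the drop in performance when a feature is shuffled, the more the model relies on that feature for its predictions. An importance near zero (or negative) means shuffling the feature did not hurt performance on this test set.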

**Advantages**: Model-agnostic, easy to understand

In [None]:
# Permutation importance of every input column on the held-out split,
# measured as the change in accuracy over repeated shuffles.
print("Computing permutation importance (this may take a minute)...\n")

perm_options = {
    'n_repeats': 10,
    'random_state': 42,
    'n_jobs': -1,
    'scoring': 'accuracy',
}
perm_importance = permutation_importance(best_model, X_test, y_test, **perm_options)

# One row per feature, holding the mean and spread across the repeats.
perm_table = pd.DataFrame({
    'Feature': X.columns,
    'Importance': perm_importance.importances_mean,
    'Std': perm_importance.importances_std,
})

# Rank from most to least important and renumber the rows.
perm_df = perm_table.sort_values('Importance', ascending=False)
perm_df = perm_df.reset_index(drop=True)

print("\nPermutation Importance Results:")
print(perm_df.to_string(index=False))

## 2.1 Visualize Permutation Importance

In [None]:
# Plot the highest-ranked features from the permutation-importance table.
# `top_n` is an upper bound: when the table has fewer rows, all of them are
# shown and the title reports the number actually plotted.
top_n = 15
perm_top = perm_df.head(top_n)
n_shown = len(perm_top)
bar_positions = list(range(n_shown))

fig, ax = plt.subplots(figsize=(10, 8))

# Horizontal bars with the std across shuffles drawn as error bars.
ax.barh(bar_positions, perm_top['Importance'],
        xerr=perm_top['Std'], alpha=0.8, color='steelblue', edgecolor='black')
ax.set_yticks(bar_positions)
ax.set_yticklabels(perm_top['Feature'], fontsize=11)
ax.set_xlabel('Drop in Accuracy (Importance)', fontweight='bold', fontsize=12)
ax.set_title(f'Top {n_shown} Features - Permutation Importance\n(HistGradientBoosting)',
             fontsize=14, fontweight='bold')
ax.invert_yaxis()  # most important feature at the top
plt.tight_layout()
plt.show()

# Summary statistics
print(f"\nTotal features analyzed: {len(perm_df)}")
print(f"Features with positive importance: {(perm_df['Importance'] > 0).sum()}")
print(f"Features with negative importance: {(perm_df['Importance'] < 0).sum()}")
print("\nTop 5 Most Important:")
for row in perm_df.head(5).itertuples(index=False):
    # \u00b1 is the plus-minus sign; written as an escape so the character
    # cannot be corrupted by a file re-encoding.
    print(f"  {row.Feature}: {row.Importance:.6f} (\u00b1{row.Std:.6f})")

## 3. SHAP (SHapley Additive exPlanations)

**What it does**: Uses Shapley values from cooperative game theory to fairly attribute each model prediction to the individual input features.

**Interpretation**: A feature's SHAP value is its contribution to pushing a given prediction away from the base value (the model's average output).

**Advantages**: Theoretically sound, provides local interpretability per sample

**Note**: Install via `pip install shap` if not already installed.

In [None]:
# SHAP is an optional dependency: record in HAS_SHAP whether it could be
# imported so the cells below can skip the SHAP analysis when it is missing.
try:
    import shap
except ImportError:
    HAS_SHAP = False
    print("SHAP not installed. Install via: pip install shap")
    print("Continuing with permutation importance only.\n")
    print("Skipping SHAP analysis...")
else:
    HAS_SHAP = True
    print("SHAP library loaded successfully.\n")

In [None]:
# Compute global SHAP feature importance for the fitted pipeline.
# Best-effort: any failure is reported and HAS_SHAP is switched off so the
# cells below fall back to permutation importance only.
if HAS_SHAP:
    print("Computing SHAP values (this may take 2-5 minutes)...\n")

    try:
        # Split the pipeline into its preprocessing and estimator steps
        # (looked up under the step names 'pre' and 'model').
        model_component = best_model.named_steps['model']
        preprocessor = best_model.named_steps['pre']

        # The explainer must see the same representation the estimator sees.
        X_test_transformed = preprocessor.transform(X_test)

        # Use TreeExplainer for HistGradientBoosting
        explainer = shap.TreeExplainer(model_component)
        shap_values = explainer.shap_values(X_test_transformed)

        print("SHAP values computed successfully.")
        print(f"SHAP values shape: {np.array(shap_values).shape}")

        # Reduce to one (n_samples, n_features) matrix of absolute values.
        if isinstance(shap_values, list):
            # Older SHAP releases: one (n_samples, n_features) array per class.
            shap_values_abs = np.mean([np.abs(sv) for sv in shap_values], axis=0)
        else:
            shap_values_abs = np.abs(np.asarray(shap_values))
            if shap_values_abs.ndim == 3:
                # Newer SHAP releases return one array for multiclass models;
                # the class axis is assumed to be the last one
                # (n_samples, n_features, n_classes) — TODO confirm for the
                # installed SHAP version.
                shap_values_abs = shap_values_abs.mean(axis=-1)

        # Global feature importance from SHAP: mean |SHAP| over samples.
        feature_importance_shap = np.mean(shap_values_abs, axis=0)
        n_transformed = feature_importance_shap.shape[0]

        # Get feature names from the preprocessor; fall back to generic names
        # when it cannot provide them or the count does not match the columns.
        try:
            feature_names_transformed = [
                str(name) for name in preprocessor.get_feature_names_out()
            ]
        except Exception:
            feature_names_transformed = None
        if (feature_names_transformed is None
                or len(feature_names_transformed) != n_transformed):
            feature_names_transformed = [f"Feature_{i}" for i in range(n_transformed)]

        # Create SHAP importance dataframe, ranked from most to least important.
        shap_df = pd.DataFrame({
            'Feature': feature_names_transformed,
            'SHAP_Importance': feature_importance_shap
        }).sort_values('SHAP_Importance', ascending=False).reset_index(drop=True)

        print("\nSHAP Feature Importance (Top 20):")
        print(shap_df.head(20).to_string(index=False))

    except Exception as e:
        print(f"Error computing SHAP: {e}")
        HAS_SHAP = False

## 3.1 SHAP Summary Plots

## 4. Comparison: Permutation vs SHAP Importance

Compare rankings between permutation importance and SHAP to validate feature importance findings.

In [None]:
# Side-by-side listing of the ten highest-ranked features from each method.
# NOTE(review): permutation names come from the raw columns while SHAP names
# come from the preprocessor output; if the preprocessor renames columns the
# overlap below may be empty even when the rankings agree — confirm.
if not HAS_SHAP:
    print("SHAP not available. Showing permutation importance only.")
else:
    print("\nComparison of Top Features:\n")
    print(f"{'Permutation Top 10':<30} | {'SHAP Top 10':<30}")
    print("-" * 62)

    perm_top10 = perm_df['Feature'].head(10).values
    shap_top10 = shap_df['Feature'].head(10).values

    # Pad the shorter column with blanks so the table always has ten rows.
    perm_cells = list(perm_top10) + [""] * (10 - len(perm_top10))
    shap_cells = list(shap_top10) + [""] * (10 - len(shap_top10))
    for perm_feat, shap_feat in zip(perm_cells, shap_cells):
        print(f"{perm_feat:<30} | {shap_feat:<30}")

    # Features that both methods place in their top five.
    common_top5 = set(perm_top10[:5]).intersection(shap_top10[:5])
    print(f"\nCommon features in both Top 5: {common_top5}")

## 5. Feature Importance by Sensor Type

Group features by sensor type to understand which **types** of sensors matter most for scent detection.

In [None]:
# Mapping from sensor category to the dataset columns that belong to it.
sensor_groups = {
    'Gas Sensors': ['gas_bme', 'srawVoc', 'srawNox', 'NO2', 'ethanol', 'VOC_multichannel', 'COandH2'],
    'Environmental': ['temp_C', 'humidity_pct', 'pressure_kPa'],
    'Temporal': ['time_s'],
    'Metadata': ['trial_number', 'phase']
}

# Sum the permutation importance of each category's columns; columns that do
# not appear in perm_df simply contribute nothing.
group_importance = {
    category: perm_df.loc[perm_df['Feature'].isin(members), 'Importance'].sum()
    for category, members in sensor_groups.items()
}

# Bar chart with one bar per sensor category.
groups = list(group_importance)
importances = [group_importance[category] for category in groups]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(groups, importances, alpha=0.8, color=colors, edgecolor='black')

# Write each bar's value just above its top edge.
for rect in bars:
    bar_top = rect.get_height()
    label_x = rect.get_x() + rect.get_width()/2.
    ax.text(label_x, bar_top, f'{bar_top:.4f}',
            ha='center', va='bottom', fontweight='bold')

ax.set_ylabel('Cumulative Importance', fontweight='bold', fontsize=12)
ax.set_title('Feature Importance by Sensor Type (Permutation)', fontsize=14, fontweight='bold')
plt.xticks(rotation=15, ha='right')
plt.tight_layout()
plt.show()

# Text summary, largest cumulative importance first.
print("Feature Importance by Sensor Type:")
ranked_groups = sorted(group_importance.items(), key=lambda item: item[1], reverse=True)
for category, total in ranked_groups:
    print(f"  {category}: {total:.6f}")