In [None]:
%run '/home/christianl/Zhang-Lab/Zhang Lab Code/Boilerplate/Fig_config_utilities.py'

In [None]:
from shap_model_comparison import SHAPModelComparator
import os
import json
import pandas as pd
import numpy as np 

# Directory where computed SHAP values are saved; named once so later save cells can reuse it.
SHAP_OUTPUT_DIR = '/home/christianl/Zhang-Lab/Zhang Lab Data/Saved SHAP values'
os.makedirs(SHAP_OUTPUT_DIR, exist_ok=True)  # no-op if the directory already exists

In [None]:
# Subsample the centred test set: SHAP cost grows with the number of explained rows.
N_SHAP_SAMPLES = 1000  # number of test rows to explain
SAMPLE_SEED = 42       # fixed so the same rows are explained on every run

x_test_centered_df = pd.DataFrame(x_test_centered)
# DataFrame.sample raises ValueError when n exceeds the row count, so cap it.
subsetted_x_test_centered = x_test_centered_df.sample(
    n=min(N_SHAP_SAMPLES, len(x_test_centered_df)),
    random_state=SAMPLE_SEED,
)

# Feature names are taken from the DataFrame columns.
# NOTE(review): if x_test_centered is a bare NumPy array these are integer
# positions rather than real feature names — confirm.
feature_names = subsetted_x_test_centered.columns.tolist()


In [None]:
# Work around an incompatibility between XGBoost 2.0+ model configs and the SHAP package.

def fix_xgboost_for_shap(model):
    """Rewrite a bracketed ``base_score`` in an XGBoost model config as a scalar.

    Newer XGBoost releases may serialise ``learner_model_param.base_score`` as
    a bracketed string such as ``'[5E-1]'``; code that calls ``float()`` on
    that value (as some SHAP versions do) then fails. This function rewrites
    the value as a plain float string (e.g. ``'0.5'``), loads the edited
    config back into the booster, and re-reads the config to confirm that the
    change actually stuck.

    The patch is best-effort: any problem is reported with a printed note and
    the model is returned rather than raising.

    Parameters
    ----------
    model : object
        An XGBoost scikit-learn style estimator (anything with
        ``get_booster()``) or a booster exposing ``save_config()`` /
        ``load_config()``.

    Returns
    -------
    object
        The same ``model`` object; its booster is modified in place.
    """
    try:
        booster = model.get_booster() if hasattr(model, 'get_booster') else model
        config = json.loads(booster.save_config())
        params = config['learner']['learner_model_param']
        base_score = str(params['base_score']).strip()

        if not (base_score.startswith('[') and base_score.endswith(']')):
            print("Note: XGBoost fix not needed (base_score is already a scalar)")
            return model

        entries = [v for v in base_score.strip('[]').split(',') if v.strip()]
        if len(entries) != 1:
            # One base score per output (or none at all): collapsing this to a
            # single float would not be a faithful rewrite, so leave it alone.
            print(f"Note: XGBoost fix skipped: base_score has {len(entries)} entries ({base_score})")
            return model

        params['base_score'] = str(float(entries[0]))
        booster.load_config(json.dumps(config))

        # Re-read the config: if XGBoost re-serialises the value with brackets,
        # the patch did not take effect and SHAP may still fail on this model.
        reloaded = str(
            json.loads(booster.save_config())['learner']['learner_model_param']['base_score']
        ).strip()
        if reloaded.startswith('['):
            print(f"⚠ XGBoost config still reports base_score={reloaded}; SHAP may still fail")
        else:
            print("✓ Fixed XGBoost model for SHAP compatibility")
    except Exception as e:  # best-effort patch: report instead of aborting the notebook
        print(f"Note: XGBoost fix failed: {e}")
    return model

# Patch the loaded XGBRF model before SHAP sees it (its booster config is edited in place).
xgbrf_loaded = fix_xgboost_for_shap(model=xgbrf_loaded)

In [None]:
# Trained models to compare, keyed by the name later passed to compute_shap_values.
# TODO (06/01/26): add ('LEMBAS-RNN', rnn_model) once the RNN is retrained; the
# wrapper cell defines `rnn_model` after this cell, so rebuild this dict afterwards.
models = dict([
    ('MLR', mlr_loaded),
    ('XGBRF', xgbrf_loaded),
])

In [None]:
# Optional: wrap the retrained LEMBAS-RNN so it exposes a scikit-learn style .predict() for SHAP.

import torch

class PyTorchRNNWrapper:
    """Expose a PyTorch model through a scikit-learn style ``predict`` method.

    Lets a ``torch.nn.Module`` sit in the same models dict as the other
    estimators so SHAP can call it as a black-box prediction function.

    Parameters
    ----------
    model : torch.nn.Module
        The network to wrap. It is moved to ``device`` and put in eval mode.
    device : str or torch.device, default 'cpu'
        Device used for both the model weights and the input tensors.
    """

    def __init__(self, model, device='cpu'):
        self.device = device
        # Put the weights on the same device the inputs are sent to; otherwise a
        # non-CPU device (or a GPU-saved model) raises a device-mismatch error.
        self.model = model.to(device)
        self.model.eval()

    def predict(self, X):
        """Run the model on ``X`` without gradients and return a flat NumPy array.

        ``X`` may be a DataFrame, a 2-D array-like of shape (n_samples,
        n_features), or a single 1-D sample. A length-1 sequence axis is
        inserted at dim 1 before calling the model.
        NOTE(review): assumes the RNN expects (batch, seq, features) — confirm.
        All model outputs are flattened to 1-D.
        """
        if isinstance(X, pd.DataFrame):
            X = X.values

        X_tensor = torch.as_tensor(np.asarray(X), dtype=torch.float32, device=self.device)

        if X_tensor.dim() == 1:
            X_tensor = X_tensor.unsqueeze(0)  # single sample -> batch of one
        if X_tensor.dim() == 2:
            X_tensor = X_tensor.unsqueeze(1)  # add sequence dimension

        with torch.no_grad():
            output = self.model(X_tensor)

        return output.cpu().numpy().flatten()

# Load and wrap the retrained LEMBAS-RNN only if its checkpoint exists, so
# Restart & Run All still works before the RNN has been retrained.
RNN_MODEL_PATH = 'models/lembas_rnn.pth'  # relative to the notebook's working directory

if os.path.exists(RNN_MODEL_PATH):
    # The checkpoint is used directly as a callable module, so it is a pickled
    # nn.Module: torch >= 2.6 defaults to weights_only=True, which refuses that.
    # SECURITY: weights_only=False unpickles arbitrary objects — trusted files only.
    # map_location='cpu' matches the wrapper's default device.
    rnn_base_model = torch.load(RNN_MODEL_PATH, map_location='cpu', weights_only=False)
    rnn_model = PyTorchRNNWrapper(rnn_base_model)

    # Smoke test on the first few subsampled rows
    test_pred = rnn_model.predict(subsetted_x_test_centered.iloc[:5])
    print(f"✓ RNN loaded and wrapped. Test predictions: {test_pred[:3]}")
else:
    rnn_model = None
    print(f"Note: {RNN_MODEL_PATH} not found; skipping LEMBAS-RNN")

In [None]:
banner = "=" * 80
print(f"\n{banner}")
print("Initializing SHAP Comparator")
print(banner)

comparator = SHAPModelComparator(
    models_dict=models,
    X_data=subsetted_x_test_centered,
    feature_names=feature_names,
    background_samples=100,  # lower this if SHAP is slow
)

n_rows, n_cols = subsetted_x_test_centered.shape
print(f"✓ Comparator initialized with {len(models)} models")
print(f"✓ Data: {n_rows} samples, {n_cols} features")

In [None]:
banner = "=" * 80
print(f"\n{banner}")
print("Computing MLR SHAP Values (this may take a few minutes...)")
print(f"{banner}\n")

# MLR model, explained with the comparator's 'linear' method
mlr_shap_values = comparator.compute_shap_values('MLR', 'linear')

print("\n✓ SHAP computation complete!")


Computing MLR SHAP Values (this may take a few minutes...)

Computing SHAP values for MLR...


In [None]:
banner = "=" * 80
print(f"\n{banner}")
print("Computing XGBRF SHAP Values (this may take a few minutes...)")
print(f"{banner}\n")

# XGBRF model, explained with the comparator's 'kernel' method (chosen explicitly).
# NOTE(review): kernel SHAP over the full subsample can be very slow for a tree
# ensemble; if SHAPModelComparator supports a tree-specific method it is likely
# much faster — confirm.
xgbrf_shap_values = comparator.compute_shap_values('XGBRF', 'kernel')

print("\n✓ SHAP computation complete!")