# XGBoost Model Decision Analysis

This notebook analyzes the trained XGBoost model for CCS prediction:
- Feature importance analysis
- Tree structure visualization
- SHAP value analysis for comprehensive feature impact understanding

In [1]:
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sqlite3
import sys
sys.path.append('../..')

from xgboost import plot_tree, plot_importance
import shap

from utils import Utils

  from .autonotebook import tqdm as notebook_tqdm


## 1. Load Model and Prepare Data

In [2]:
# Deserialize the trained model and echo its main hyperparameters.
# NOTE: joblib files are pickle-based; only load files from a trusted source.
model = joblib.load('ccsbase2.joblib')
print("Model loaded successfully")
for label, attr in (
    ("Number of estimators", "n_estimators"),
    ("Max depth", "max_depth"),
    ("Learning rate", "learning_rate"),
):
    print(f"{label}: {getattr(model, attr)}")

Model loaded successfully
Number of estimators: 6000
Max depth: 10
Learning rate: 0.03


configuration generated by an older version of XGBoost, please export the model by calling
`Booster.save_model` from that version first, then load it back in current version. See:

    https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html

for more details about differences between saving model and serializing.

  setstate(state)


In [4]:
# Get the adduct list (intended to match the list used during training):
# every adduct that occurs at least 100 times in the `master_clean` table.
database_file = '../../ccs.db'
query = "SELECT adduct FROM master_clean GROUP BY adduct HAVING COUNT(*) >= 100 ORDER BY adduct"

conn = sqlite3.connect(database_file)
try:
    adduct_rows = pd.read_sql_query(query, conn)
finally:
    # Release the database handle even when the query raises (e.g. missing
    # table); otherwise the connection stays open for the kernel's lifetime.
    conn.close()

# The query returns a single column; sort its values in Python so the order
# does not depend on the database collation.
adducts = sorted(adduct_rows.iloc[:, 0].tolist())

print(f"Number of adducts: {len(adducts)}")
print(f"Adducts: {adducts}")

DatabaseError: Execution failed on sql 'SELECT adduct FROM master_clean GROUP BY adduct HAVING COUNT(*) >= 100 ORDER BY adduct': no such table: master_clean

In [5]:
# Build the column labels in a fixed order: four scalar descriptors,
# one slot per known adduct plus a catch-all slot, then 1024 fingerprint bits.
feature_names = [
    'MolecularWeight', 'AdductMass', 'Charge', 'LabuteASA',
    *(f'Adduct_{adduct}' for adduct in adducts),
    'Adduct_Other',
    *(f'MorganFP_{i}' for i in range(1024)),
]

print(f"Total features: {len(feature_names)}")

NameError: name 'adducts' is not defined

In [6]:
# Read the test set that the SHAP analysis is run on
test_df = pd.read_csv('../../pretrained/test_data.csv')
print(f"Test data shape: {test_df.shape}")

utils = Utils()

# Featurize row by row. Rows for which calculate_descriptors returns None
# are skipped; valid_indices records the dataframe index of every kept row.
X_list, y_list, valid_indices = [], [], []

for idx, row in test_df.iterrows():
    feat_values = utils.calculate_descriptors(
        row['smi'], row['mass'], row['z'], adducts, row['adduct']
    )
    if feat_values is None:
        continue
    X_list.append(feat_values)
    y_list.append(row['ccs'])
    valid_indices.append(idx)

# Stack the per-row features and targets into arrays
X_test = np.array(X_list)
y_test = np.array(y_list)

print(f"Valid test samples: {len(X_test)}")

Test data shape: (9749, 8)


NameError: name 'adducts' is not defined

## 2. Feature Importance Analysis

In [None]:
# Plot the 20 highest-scoring features for each XGBoost importance type
importance_types = ['weight', 'gain', 'cover']


def _feature_label(key):
    """Translate a booster feature key such as 'f12' into a readable name.

    Keys that are not of the form 'f<index>', or whose index falls outside
    ``feature_names``, are returned unchanged instead of raising, matching
    the bounds-checked lookup used for the importance table.
    """
    suffix = key[1:]
    if key.startswith('f') and suffix.isdecimal():
        idx = int(suffix)
        if idx < len(feature_names):
            return feature_names[idx]
    return key


fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, imp_type in zip(axes, importance_types):
    importance = model.get_booster().get_score(importance_type=imp_type)

    # Sort ascending and keep the tail so the largest bar ends up on top
    imp_df = pd.DataFrame({
        'feature': list(importance.keys()),
        'importance': list(importance.values())
    }).sort_values('importance', ascending=True).tail(20)

    # Map booster feature keys to human-readable names
    imp_df['feature_name'] = imp_df['feature'].apply(_feature_label)

    ax.barh(imp_df['feature_name'], imp_df['importance'])
    ax.set_xlabel(f'Importance ({imp_type})')
    ax.set_title(f'Top 20 Features by {imp_type.capitalize()}')

plt.tight_layout()
plt.savefig('feature_importance.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Create comprehensive importance table (one row per feature, one column
# per importance type).
# Query the booster once per importance type; the scores do not change
# between features, so they must not be recomputed inside the row loop.
scores_by_type = {
    imp_type: model.get_booster().get_score(importance_type=imp_type)
    for imp_type in importance_types
}

# Every feature key that appears under at least one importance type
all_features = set().union(*scores_by_type.values())

importance_data = []
# Iterate in sorted order so the table is reproducible across kernel restarts
for feat in sorted(all_features):
    suffix = feat[1:]
    # Only keys of the form 'f<index>' are positional; anything else keeps its own name
    feat_idx = int(suffix) if feat.startswith('f') and suffix.isdecimal() else -1
    feat_name = feature_names[feat_idx] if 0 <= feat_idx < len(feature_names) else feat

    row_data = {'feature': feat, 'feature_name': feat_name}
    for imp_type in importance_types:
        # A feature missing from one importance type scores 0 for that type
        row_data[imp_type] = scores_by_type[imp_type].get(feat, 0)
    importance_data.append(row_data)

# Explicit columns keep the sort below valid even when no feature was scored
importance_df = pd.DataFrame(
    importance_data, columns=['feature', 'feature_name', *importance_types]
)
importance_df = importance_df.sort_values('gain', ascending=False, kind='stable')

print("Top 30 Features by Gain:")
importance_df.head(30)

## 3. Tree Structure Visualization

In [None]:
# Render the tree at index 0, laid out top-to-bottom, and save it to disk
fig, ax = plt.subplots(figsize=(30, 15))
plot_tree(model, num_trees=0, rankdir='TB', ax=ax)
ax.set_title('Tree 0 Structure')
fig.savefig('tree_0.png', dpi=100, bbox_inches='tight')
plt.show()

In [None]:
# Draw up to the first five trees, one figure per tree, for side-by-side comparison.
# Each figure is written to tree_<index>.png and closed before the next one is created.
for tree_idx in range(min(5, model.n_estimators)):
    fig, ax = plt.subplots(figsize=(25, 12))
    plot_tree(model, num_trees=tree_idx, rankdir='TB', ax=ax)
    ax.set_title(f'Tree {tree_idx} Structure')
    fig.savefig(f'tree_{tree_idx}.png', dpi=80, bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
# Flatten the booster into a dataframe and report basic node statistics
booster = model.get_booster()
trees_df = booster.trees_to_dataframe()

print(f"Total nodes across all trees: {trees_df.shape[0]}")
print("\nNode types:")
# Frequency of each value in the 'Feature' column, 20 most common first
feature_counts = trees_df['Feature'].value_counts()
print(feature_counts.head(20))

In [None]:
# Analyze tree depth distribution
def _max_depth_per_tree(trees):
    """Return the maximum node depth of every tree, indexed by tree number.

    The documented columns of ``Booster.trees_to_dataframe()`` (Tree, Node,
    ID, Feature, Split, Yes, No, Missing, Gain, Cover, Category) include no
    depth column, so depth is derived from the Yes/No child links: the root
    (Node == 0) has depth 0 and every child is one level below its parent.
    If the dataframe already carries a 'Depth' column it is used directly.
    """
    if 'Depth' in trees.columns:
        return trees.groupby('Tree')['Depth'].max()

    children = trees.set_index('ID')[['Yes', 'No']]
    # Row position of every node ID, so depths can be written into a flat array
    position = pd.Series(np.arange(len(trees)), index=children.index)
    node_depth = np.full(len(trees), -1, dtype=np.int64)

    # Breadth-first sweep over all trees at once: one iteration per depth level
    frontier = trees.loc[trees['Node'] == 0, 'ID'].to_numpy()
    level = 0
    while frontier.size:
        node_depth[position.loc[frontier].to_numpy()] = level
        kids = children.loc[frontier].to_numpy().ravel()
        # Rows without children hold missing values in Yes/No; drop them
        frontier = kids[pd.notna(kids)]
        level += 1

    depth_per_node = pd.Series(node_depth, index=trees.index, name='Depth')
    return depth_per_node.groupby(trees['Tree']).max()


depth_stats = _max_depth_per_tree(trees_df)

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(depth_stats, bins=20, edgecolor='black')
ax.set_xlabel('Max Tree Depth')
ax.set_ylabel('Count')
ax.set_title('Distribution of Tree Depths')
plt.savefig('tree_depth_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Mean max depth: {depth_stats.mean():.2f}")
print(f"Std max depth: {depth_stats.std():.2f}")

## 4. SHAP Value Analysis

In [None]:
# Build a tree explainer for the fitted model
explainer = shap.TreeExplainer(model)

# Explain a random subset of at most 1000 test rows to keep the run time down;
# the global NumPy seed is fixed so the same rows are drawn on every run.
n_test_rows = len(X_test)
sample_size = min(1000, n_test_rows)
np.random.seed(42)
sample_indices = np.random.choice(n_test_rows, sample_size, replace=False)
X_sample = X_test[sample_indices]

print(f"Computing SHAP values for {sample_size} samples...")
shap_values = explainer.shap_values(X_sample)
print("SHAP values computed.")

In [None]:
# Bar-style SHAP summary: up to 30 features, drawn without showing so the
# title can be added and the figure saved first.
bar_plot_options = dict(
    feature_names=feature_names,
    plot_type='bar',
    max_display=30,
    show=False,
)

plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values, X_sample, **bar_plot_options)
plt.title('SHAP Feature Importance (Mean |SHAP|)')
plt.tight_layout()
plt.savefig('shap_importance_bar.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Default (beeswarm) SHAP summary for up to 30 features; unlike the bar
# variant it also shows the sign of each feature's impact.
beeswarm_options = dict(
    feature_names=feature_names,
    max_display=30,
    show=False,
)

plt.figure(figsize=(12, 10))
shap.summary_plot(shap_values, X_sample, **beeswarm_options)
plt.title('SHAP Summary (Beeswarm)')
plt.tight_layout()
plt.savefig('shap_beeswarm.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Dependence plots for the six features with the largest mean |SHAP| value
mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
# argsort is ascending: reverse it and keep the first six (largest first)
top_feature_indices = np.argsort(mean_abs_shap)[::-1][:6]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.ravel()

# One panel of the 2x3 grid per selected feature
for panel, feat_idx in zip(axes, top_feature_indices):
    shap.dependence_plot(
        feat_idx,
        shap_values,
        X_sample,
        feature_names=feature_names,
        ax=panel,
        show=False,
    )

plt.tight_layout()
plt.savefig('shap_dependence_top6.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Waterfall plot: per-feature contributions to one individual prediction
sample_idx = 0

# Wrap the raw arrays for the chosen row in an Explanation object,
# which is the input type waterfall_plot is given here.
single_explanation = shap.Explanation(
    values=shap_values[sample_idx],
    base_values=explainer.expected_value,
    data=X_sample[sample_idx],
    feature_names=feature_names,
)

plt.figure(figsize=(12, 8))
shap.waterfall_plot(single_explanation, max_display=20, show=False)
plt.title(f'SHAP Waterfall for Sample {sample_idx}')
plt.tight_layout()
plt.savefig('shap_waterfall_example.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Interactive force plot; initjs() loads the JavaScript the plot needs
shap.initjs()

# Force plot for the first sampled row. The call is the cell's last
# expression so that the notebook renders its return value.
shap.force_plot(
    explainer.expected_value,
    shap_values[0],
    X_sample[0],
    feature_names=feature_names,
)

## 5. Grouped Feature Analysis

In [None]:
# Partition the feature columns into three categories and compare their SHAP magnitude.
# Column layout: 4 scalar descriptors, then one column per adduct plus one
# extra column (hence the +1), then every remaining column.
adduct_end = 4 + len(adducts) + 1
feature_groups = {
    'Molecular Properties': [0, 1, 2, 3],  # MW, AdductMass, Charge, LabuteASA
    'Adduct Encoding': list(range(4, adduct_end)),
    'Morgan Fingerprints': list(range(adduct_end, len(feature_names))),
}

# NOTE(review): .mean() averages over samples AND over every column in the
# group, so this is a per-feature average rather than a group total; confirm
# that is the intended normalization given the very different group sizes.
group_importance = {
    group_name: np.abs(shap_values[:, indices]).mean()
    for group_name, indices in feature_groups.items()
}

# Bar chart with one bar per group
fig, ax = plt.subplots(figsize=(10, 5))
groups = list(group_importance)
values = [group_importance[name] for name in groups]

bars = ax.bar(groups, values, color=['#2ecc71', '#3498db', '#9b59b6'])
ax.set_ylabel('Mean |SHAP Value|')
ax.set_title('Feature Group Importance')

# Annotate each bar with its value, centered just above the bar top
for bar, val in zip(bars, values):
    x_center = bar.get_x() + bar.get_width()/2
    y_text = bar.get_height() + 0.01
    ax.text(x_center, y_text, f'{val:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.savefig('grouped_feature_importance.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nFeature Group Importance:")
for group, imp in group_importance.items():
    print(f"  {group}: {imp:.4f}")

In [None]:
# Rank the fingerprint columns by mean |SHAP| and report the 20 strongest.
# Fingerprint columns start after the 4 scalar columns and len(adducts) + 1 adduct columns.
fp_start = 4 + len(adducts) + 1
fp_shap = np.abs(shap_values[:, fp_start:]).mean(axis=0)

# Indices are relative to fp_start (bit 0 is the first fingerprint column)
top_fp_indices = np.argsort(fp_shap)[::-1][:20]

print("Top 20 Morgan Fingerprint Bits by SHAP Importance:")
for rank, bit in enumerate(top_fp_indices, start=1):
    print(f"  {rank}. Bit {bit}: {fp_shap[bit]:.4f}")

# Bar chart of the same 20 bits, strongest first
fig, ax = plt.subplots(figsize=(12, 6))
bar_positions = range(20)
ax.bar(bar_positions, [fp_shap[bit] for bit in top_fp_indices])
ax.set_xticks(bar_positions)
ax.set_xticklabels([f'Bit {bit}' for bit in top_fp_indices], rotation=45, ha='right')
ax.set_ylabel('Mean |SHAP Value|')
ax.set_title('Top 20 Morgan Fingerprint Bits')
plt.tight_layout()
plt.savefig('top_fingerprint_bits.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Feature Interactions

In [None]:
# SHAP interaction values are computationally expensive, so only the first
# (at most) 200 rows of the already-sampled matrix are used.
interaction_sample_size = min(200, len(X_sample))
X_interaction = X_sample[:interaction_sample_size]

print(f"Computing SHAP interaction values for {interaction_sample_size} samples...")
print("(This may take a while)")

# Remember the ten features with the largest mean |SHAP| (largest first)
# so later analysis can focus on them.
top_n = 10
top_indices = np.argsort(mean_abs_shap)[::-1][:top_n]

# Interaction values are still computed for the full feature set
shap_interaction = explainer.shap_interaction_values(X_interaction)
print("Interaction values computed.")

In [None]:
# Heatmap of mean |interaction| between the selected top features.
# Select the top features on both feature axes, then average |value| over samples.
top_block = shap_interaction[:, top_indices][:, :, top_indices]
interaction_matrix = np.mean(np.abs(top_block), axis=0)

top_feature_names = [feature_names[i] for i in top_indices]
tick_positions = range(len(top_indices))

fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(interaction_matrix, cmap='YlOrRd')
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(top_feature_names, rotation=45, ha='right')
ax.set_yticklabels(top_feature_names)
plt.colorbar(im, label='Mean |Interaction Value|')
ax.set_title('SHAP Interaction Values (Top 10 Features)')
plt.tight_layout()
plt.savefig('shap_interaction_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Summary

In [None]:
# Print a plain-text recap: model configuration, top SHAP features and group shares
banner = "=" * 60
print(banner)
print("XGBoost Model Analysis Summary")
print(banner)

print("\nModel Configuration:")
print(f"  - Number of trees: {model.n_estimators}")
print(f"  - Max depth: {model.max_depth}")
print(f"  - Learning rate: {model.learning_rate}")
print(f"  - Total features: {len(feature_names)}")

print("\nTop 10 Most Important Features (by SHAP):")
# argsort is ascending: reverse it and keep the first ten (largest first)
top_10_indices = np.argsort(mean_abs_shap)[::-1][:10]
for rank, idx in enumerate(top_10_indices, start=1):
    print(f"  {rank}. {feature_names[idx]}: {mean_abs_shap[idx]:.4f}")

print("\nFeature Group Contributions:")
# Each group's share of the summed group importances, as a percentage
total_importance = sum(group_importance.values())
for group, imp in group_importance.items():
    pct = (imp / total_importance) * 100
    print(f"  - {group}: {pct:.1f}%")

print("\n" + banner)