In [None]:
!pip install shap
!pip install graphviz



In [1]:
import numpy as np
import pandas as pd
import pickle
import shap
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
import warnings
warnings.filterwarnings('ignore')
import os


# Create every output folder used below; makedirs also creates missing parents.
for _subdir in ('shap', 'local_trees'):
    os.makedirs(f'../models/explanations/{_subdir}', exist_ok=True)

# 1. Load Data and Select Model


def load_data_and_models(data_path="../data/processed/",
                         model_dir="../models/advanced_models",
                         model_names=('lightgbm', 'xgboost', 'random_forest', 'cnn', 'rnn')):
    """Load the scaled test split, its feature names, and the trained models.

    Args:
        data_path: Directory holding ``X_test_scaled.npy`` and ``y_test.npy``
            (required) and ``feature_names.npy`` (optional).
        model_dir: Directory holding the trained models. ``cnn``/``rnn`` are
            loaded from ``<name>.h5`` with Keras; every other name from
            ``<name>.pkl`` with pickle.
        model_names: Names of the models to try. A model that fails to load
            is reported and skipped.

    Returns:
        Tuple ``(X_test, y_test, feature_names, models)``. ``feature_names``
        is a list of ``str`` (``Feature_<i>`` defaults when the names file
        cannot be read). ``models`` maps each successfully loaded name to its
        model object.

    Raises:
        OSError: If the test arrays cannot be read. Unlike feature names and
            models, there is no fallback for missing test data.
    """
    print("Loading data and models...")

    # Test split (required).
    X_test = np.load(os.path.join(data_path, "X_test_scaled.npy"))
    y_test = np.load(os.path.join(data_path, "y_test.npy"))

    # Feature names (optional). allow_pickle=True lets object-dtype arrays load
    # (e.g. names saved straight from a pandas Index); with the default they
    # fail and generic names would silently be used instead. This file is a
    # local pipeline artifact, trusted like the pickled models loaded below.
    try:
        raw_names = np.load(os.path.join(data_path, "feature_names.npy"), allow_pickle=True)
        feature_names = [str(name) for name in np.ravel(raw_names)]
        print(f" Loaded {len(feature_names)} feature names")
    except Exception as e:  # best effort: any failure falls back to generic names
        print(f" Error loading feature names: {e}")
        feature_names = [f'Feature_{i}' for i in range(X_test.shape[1])]
        print(" Using default feature names")
    else:
        if len(feature_names) != X_test.shape[1]:
            print(f" Warning: {len(feature_names)} feature names for "
                  f"{X_test.shape[1]} columns; labels may be misaligned")

    models = {}
    for model_name in model_names:
        model_path = os.path.join(model_dir, model_name)
        try:
            if model_name in ('cnn', 'rnn'):
                # Imported lazily so TensorFlow is only needed when a Keras model is requested.
                from tensorflow.keras.models import load_model
                models[model_name] = load_model(f'{model_path}.h5')
            else:
                # pickle can execute arbitrary code: only load trusted, locally produced files.
                with open(f'{model_path}.pkl', 'rb') as f:
                    models[model_name] = pickle.load(f)
            print(f" Loaded {model_name}")
        except Exception as e:  # one bad model should not abort the whole analysis
            print(f" Could not load {model_name}: {e}")

    return X_test, y_test, feature_names, models

X_test, y_test, feature_names, models = load_data_and_models()

# Select working model: the first loaded model that exposes predict_proba and
# can actually score one test row. Models without predict_proba (e.g. Keras
# models in recent TensorFlow versions) are skipped.
working_model = None
working_model_name = None

for model_name, model_obj in models.items():
    if not hasattr(model_obj, 'predict_proba'):
        continue
    try:
        model_obj.predict_proba(X_test[:1])  # smoke test on a single row
    except Exception as e:  # try the next model, but report why this one was rejected
        print(f" Skipping {model_name}: predict_proba failed ({e})")
        continue
    working_model = model_obj
    working_model_name = model_name
    print(f" Using {model_name} for analysis")
    break

if working_model is None:
    raise RuntimeError("No working models available!")



# 2. SHAP


print("\n" + "="*50)
print("SHAP ANALYSIS")
print("="*50)

def _positive_class_shap_values(shap_values):
    """Reduce raw explainer output to a 2-D ``(n_samples, n_features)`` array.

    Depending on the shap version and the model, binary-classifier SHAP
    values arrive in one of three forms:
    - a list with one array per class (older shap releases);
    - a 3-D ``(n_samples, n_features, n_classes)`` array (newer releases,
      multi-output models);
    - an already 2-D array (single-output models).
    The positive class (index 1) is kept when more than one class is present.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[1] if len(shap_values) > 1 else shap_values[0])
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[:, :, 1] if values.shape[2] > 1 else values[:, :, 0]
    return values


def _save_shap_figure(output_dir, stem):
    """Save the current matplotlib figure as ``<stem>.png`` (300 dpi) and ``<stem>.svg``."""
    plt.savefig(os.path.join(output_dir, f'{stem}.png'), dpi=300, bbox_inches='tight')
    plt.savefig(os.path.join(output_dir, f'{stem}.svg'), format='svg', bbox_inches='tight')


def perform_shap_analysis(model, X_data, feature_names, model_name,
                          output_dir='../models/explanations/shap'):
    """Compute SHAP values for ``model`` on ``X_data`` and save summary plots.

    Args:
        model: Fitted classifier. Tree ensembles are explained with
            ``shap.TreeExplainer``; anything else must expose ``predict_proba``
            for ``shap.KernelExplainer``.
        X_data: Feature matrix to explain.
        feature_names: Column labels used in the plots.
        model_name: Name used to choose the explainer and to title the plots.
        output_dir: Directory where ``shap_summary.{png,svg}`` and
            ``shap_importance.{png,svg}`` are written. Created if missing.

    Returns:
        ``(shap_values, explainer)``. ``shap_values`` is always 2-D
        ``(n_samples, n_features)`` for the positive class.
    """
    print(" Computing SHAP values...")
    os.makedirs(output_dir, exist_ok=True)

    name = model_name.lower()
    if any(key in name for key in ('lightgbm', 'xgboost', 'random')):
        explainer = shap.TreeExplainer(model)
    else:
        # NOTE: KernelExplainer cost scales with background size x explained
        # rows, so explaining all of X_data here can be very slow.
        explainer = shap.KernelExplainer(model.predict_proba, X_data[:100])

    # Normalise list / 3-D outputs so plots and downstream importances get 2-D values.
    shap_values = _positive_class_shap_values(explainer.shap_values(X_data))
    print(f" SHAP values computed: {shap_values.shape}")

    # Beeswarm summary: distribution of per-sample contributions per feature.
    plt.figure(figsize=(10, 8))
    shap.summary_plot(shap_values, X_data, feature_names=feature_names, show=False)
    plt.title(f'SHAP Summary Plot - {model_name.upper()}', fontsize=14)
    plt.tight_layout()
    _save_shap_figure(output_dir, 'shap_summary')
    plt.show()

    # Bar chart of mean |SHAP| per feature.
    plt.figure(figsize=(10, 6))
    shap.summary_plot(shap_values, X_data, feature_names=feature_names, plot_type="bar", show=False)
    plt.title(f'SHAP Feature Importance - {model_name.upper()}', fontsize=14)
    plt.tight_layout()
    _save_shap_figure(output_dir, 'shap_importance')
    plt.show()

    return shap_values, explainer

# Run SHAP analysis
shap_values, explainer = perform_shap_analysis(
    model=working_model,
    X_data=X_test,
    feature_names=feature_names,
    model_name=working_model_name,
)


# 3. Calculate Surrogate Tree


# Section banner: blank line, rule, title, rule.
print("\n".join(["\n" + "=" * 50, "SURROGATE DECISION TREE", "=" * 50]))

def _to_binary_labels(scores, threshold=0.5):
    """Threshold model scores into 0/1 integer labels.

    ``scores`` may be 1-D (one score per sample), 2-D with a single column
    (e.g. a sigmoid output layer), or 2-D with one column per class. In the
    last case column 1, the positive class, is used.
    """
    scores = np.asarray(scores)
    if scores.ndim > 1:
        scores = scores[:, 1] if scores.shape[1] > 1 else scores[:, 0]
    return (scores > threshold).astype(int)


def create_surrogate_tree(model, X_data, y_true, feature_names, model_name,
                          max_depth=4, min_samples_split=30, min_samples_leaf=15,
                          random_state=42):
    """Fit a shallow decision tree that mimics ``model``'s predicted labels.

    The tree is trained on the black-box model's *predictions*, not on
    ``y_true``. The returned score is therefore fidelity: agreement with the
    model, measured on the same ``X_data`` the tree was fitted on.

    Args:
        model: Fitted classifier exposing ``predict_proba`` or ``predict``.
        X_data: 2-D feature matrix.
        y_true: Ground-truth labels or ``None``. When given and of matching
            length, they are only used to print context accuracies.
        feature_names: Column labels. Only the first
            ``min(n_columns, len(feature_names))`` columns are used.
        model_name: Name of the explained model (kept for interface parity).
        max_depth, min_samples_split, min_samples_leaf, random_state:
            Passed to ``DecisionTreeClassifier``.

    Returns:
        ``(surrogate_tree, X_tree, feature_names_tree, y_pred, accuracy)``,
        where ``accuracy`` is the tree's fidelity to ``y_pred``.
    """
    print(" Generating model predictions...")

    if hasattr(model, 'predict_proba'):
        y_pred = _to_binary_labels(model.predict_proba(X_data))
    else:
        # Handles 1-D scores and (n, 1) sigmoid outputs as well as per-class columns.
        y_pred = _to_binary_labels(model.predict(X_data))

    print(f"Predictions generated: {len(y_pred)} samples")
    print(f"Class distribution: {pd.Series(y_pred).value_counts().to_dict()}")

    # Keep only columns that have a name so every tree split is labelled.
    n_features = min(X_data.shape[1], len(feature_names))
    X_tree = X_data[:, :n_features]
    feature_names_tree = feature_names[:n_features]

    print(" Training surrogate decision tree...")

    surrogate_tree = DecisionTreeClassifier(
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        random_state=random_state,
    )

    surrogate_tree.fit(X_tree, y_pred)
    print(" Surrogate tree trained successfully!")

    # Fidelity: how often the tree reproduces the black-box model's labels.
    accuracy = surrogate_tree.score(X_tree, y_pred)
    print(f"Surrogate tree accuracy: {accuracy:.3f}")

    if y_true is not None:
        y_ref = np.asarray(y_true).ravel()
        if y_ref.shape == y_pred.shape:
            print(f"Model accuracy vs. true labels: {np.mean(y_pred == y_ref):.3f}")
            print("Surrogate accuracy vs. true labels: "
                  f"{np.mean(surrogate_tree.predict(X_tree) == y_ref):.3f}")

    return surrogate_tree, X_tree, feature_names_tree, y_pred, accuracy

# Create surrogate tree
surrogate_tree, X_tree, feature_names_tree, y_pred, accuracy = create_surrogate_tree(
    model=working_model,
    X_data=X_test,
    y_true=y_test,
    feature_names=feature_names,
    model_name=working_model_name,
)


# 4. Visualize Surrogate Tree


def visualize_decision_tree(tree_model, feature_names, class_names, model_name, accuracy):
    """Draw the surrogate tree twice and save both renderings.

    The first ("compact") rendering uses ``class_names`` and a title showing
    ``accuracy``. The second ("minimal") one is smaller and uses the short
    labels ``'No'``/``'Yes'``. Each is written as PNG (300 dpi) and SVG under
    ``../models/explanations/local_trees`` and then shown.
    """
    print("Creating tree visualizations...")
    out_dir = '../models/explanations/local_trees'

    # Each rendering: (figsize, class labels, node font size, title,
    # title kwargs, tight_layout pad, output file stem, savefig pad_inches).
    renderings = [
        ((16, 10), class_names, 9,
         f'Surrogate Decision Tree for {model_name.upper()}\n(Accuracy: {accuracy:.3f})',
         {'fontsize': 14, 'pad': 10}, 1.0, 'compact_decision_tree', 0.1),
        ((12, 8), ['No', 'Yes'], 8,
         f'Minimal Tree - {model_name}',
         {'fontsize': 12}, 0.8, 'minimal_decision_tree', 0.05),
    ]

    for size, labels, node_font, title, title_kwargs, layout_pad, stem, pad_in in renderings:
        plt.figure(figsize=size)
        plot_tree(tree_model,
                  feature_names=feature_names,
                  class_names=labels,
                  filled=True,
                  rounded=True,
                  fontsize=node_font,
                  proportion=True,
                  impurity=False)
        plt.title(title, **title_kwargs)
        plt.tight_layout(pad=layout_pad)
        plt.savefig(f'{out_dir}/{stem}.png', dpi=300, bbox_inches='tight', pad_inches=pad_in)
        plt.savefig(f'{out_dir}/{stem}.svg', format='svg', bbox_inches='tight', pad_inches=pad_in)
        plt.show()

# Visualize trees
visualize_decision_tree(
    tree_model=surrogate_tree,
    feature_names=feature_names_tree,
    class_names=['No Heart Disease', 'Heart Disease'],
    model_name=working_model_name,
    accuracy=accuracy,
)


# 5. Decision Rules and Feature Importance

print("\n" + "="*50)
print("DECISION RULES & FEATURE ANALYSIS")
print("="*50)

# Export the surrogate tree as indented text rules; show_weights=True adds
# the class weights at each leaf.
tree_rules = export_text(surrogate_tree,
                       feature_names=feature_names_tree,
                       decimals=2,
                       show_weights=True)

print("Key Decision Rules:")
print("=" * 40)
lines = tree_rules.split('\n')
for line in lines[:15]:  # preview the first 15 lines of rule text; full text is saved below
    print(line)

# Save full rules to file
with open('../models/explanations/local_trees/decision_rules.txt', 'w', encoding='utf-8') as f:
    f.write(tree_rules)

# Feature importance comparison
print("\nFEATURE IMPORTANCE COMPARISON")
print("=" * 40)

# SHAP importance = mean |SHAP| per feature. Reduce per-class outputs (list,
# or a trailing class axis in newer shap releases) to the positive class first
# so each feature gets exactly one score.
if isinstance(shap_values, list):
    abs_shap = np.abs(np.asarray(shap_values[1] if len(shap_values) > 1 else shap_values[0]))
else:
    abs_shap = np.abs(np.asarray(shap_values))
if abs_shap.ndim == 3:
    abs_shap = abs_shap[:, :, 1] if abs_shap.shape[2] > 1 else abs_shap[:, :, 0]

shap_importance = pd.DataFrame({
    'feature': feature_names_tree,
    'shap_importance': abs_shap.mean(0)[:len(feature_names_tree)]
}).sort_values('shap_importance', ascending=False)

# Surrogate tree importance
tree_importance = pd.DataFrame({
    'feature': feature_names_tree,
    'tree_importance': surrogate_tree.feature_importances_
}).sort_values('tree_importance', ascending=False)

print("Top 10 Features - SHAP vs Tree:")
print("SHAP Importance:")
print(shap_importance.head(10).to_string(index=False))
print("\nTree Importance:")
print(tree_importance.head(10).to_string(index=False))

# Persist the side-by-side comparison (listed under FILES CREATED in the summary).
importance_comparison = shap_importance.merge(tree_importance, on='feature')
importance_comparison.to_csv('../models/explanations/feature_importance_comparison.csv', index=False)


# 6. Summary


print("\n" + "="*50)
print(" COMPREHENSIVE ANALYSIS SUMMARY")
print("="*50)

# NOTE(review): the "CLINICAL INSIGHTS" bullets below are fixed text, not the
# result of any check in this notebook; "accuracy" is the surrogate's
# fidelity to the black-box model's predictions, not accuracy vs. true labels.
print(f"""
MODEL PERFORMANCE:
------------------
• Model Used: {working_model_name.upper()}
• Test Samples: {len(X_test):,}
• Surrogate Tree Accuracy: {accuracy:.3f}
• Class Distribution: {pd.Series(y_pred).value_counts().to_dict()}

KEY FINDINGS:
-------------
1. Primary Risk Factors: {list(tree_importance.head(3)['feature'])}
2. Decision Depth: {surrogate_tree.get_depth()} levels
3. Number of Leaves: {surrogate_tree.get_n_leaves()}
4. Top SHAP Features: {list(shap_importance.head(3)['feature'])}

CLINICAL INSIGHTS:
------------------
• The surrogate tree provides transparent decision pathways
• {accuracy:.1%} accuracy in mimicking complex model behavior
• Clear risk stratification rules identified
• Suitable for clinical decision support implementation

FILES CREATED:
--------------
• SHAP summary plots (.png, .svg)
• Surrogate tree visualizations (.png, .svg)
• Decision rules (text file)
• Feature importance comparison
""")

# Only mark the rules excerpt as truncated when it actually was cut.
RULES_EXCERPT_CHARS = 1000
rules_excerpt = (tree_rules if len(tree_rules) <= RULES_EXCERPT_CHARS
                 else tree_rules[:RULES_EXCERPT_CHARS] + "...")

# Save summary to file
summary_text = f"""
SHAP + Surrogate Tree Analysis Summary
======================================

Model: {working_model_name}
Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')}

Performance Metrics:
- Test samples: {len(X_test):,}
- Surrogate accuracy: {accuracy:.3f}
- Tree depth: {surrogate_tree.get_depth()}
- Number of leaves: {surrogate_tree.get_n_leaves()}

Top Features:
- SHAP: {list(shap_importance.head(5)['feature'])}
- Tree: {list(tree_importance.head(5)['feature'])}

Key Decision Patterns:
{rules_excerpt}
"""

with open('../models/explanations/analysis_summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary_text)

print("Analysis completed successfully!")
print("All files saved to: ../models/explanations/")

📁 Loading data and models...


FileNotFoundError: [Errno 2] No such file or directory: '../data/processed/X_test_scaled.npy'