In [1]:
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
import os
import joblib

# --- Hyperparameters, seed, target column, and file locations ---
# Tuned settings that are unpacked straight into the XGBClassifier constructor.
BEST_PARAMS = dict(
    subsample=0.6,
    n_estimators=400,
    min_child_weight=1,
    max_depth=5,
    learning_rate=0.01,
    colsample_bytree=0.6,
)
RANDOM_STATE = 42
TARGET_COL = "status"
# Both CSV files are addressed by bare name, i.e. relative to the working directory.
TRAIN_PATH, TEST_PATH = 'train_clean.csv', 'test_clean.csv'
# Folder that receives the saved model, the plots, and the insights file.
OUTPUT_DIR = 'docs/tuning_results/'

# Set Matplotlib Backend for saving the plots.
# 'Agg' is a non-interactive backend: figures are rendered to files only,
# so the script does not need a display to run.
plt.switch_backend('Agg')

def load_data_and_train_model():
    """Load the cleaned train/test CSVs and retrain the tuned XGBoost model.

    Reads ``TRAIN_PATH`` and ``TEST_PATH``, fits an ``XGBClassifier`` built
    from ``BEST_PARAMS`` on the training split, and stores the fitted model
    as ``best_xgboost_model.joblib`` inside ``OUTPUT_DIR``.

    Returns:
        A ``(X_test, best_model)`` tuple, where ``X_test`` is the test
        DataFrame without ``TARGET_COL`` (column names are kept because the
        SHAP step needs them). If any step fails, the error is printed and
        ``(None, None)`` is returned instead.
    """
    print("--- 1. Load data and retrain the model ---")
    try:
        train_clean = pd.read_csv(TRAIN_PATH)
        test_clean = pd.read_csv(TEST_PATH)

        # Separate the features from the target column.
        X_train = train_clean.drop(columns=[TARGET_COL])
        y_train = train_clean[TARGET_COL]
        X_test = test_clean.drop(columns=[TARGET_COL])

        # `use_label_encoder` is intentionally not passed: XGBoost reports it
        # as an unused parameter and only emits a warning for it.
        best_model = XGBClassifier(
            **BEST_PARAMS,
            random_state=RANDOM_STATE,
            eval_metric="logloss",
            n_jobs=-1,
        )
        print("Model is being retrained with the best parameters...")
        best_model.fit(X_train, y_train)

        # Create the output folder here as well, so the function also works
        # when it is called without the script's entry point having run.
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        # Save the model (optional, but useful for subsequent steps).
        joblib.dump(best_model, os.path.join(OUTPUT_DIR, 'best_xgboost_model.joblib'))

        return X_test, best_model

    except Exception as e:
        # Script-level boundary: report the problem and signal failure with
        # (None, None), which the caller checks before running the analysis.
        print(f"ERROR loading or training: {e}")
        return None, None

def _safe_filename(name):
    """Return *name* with every character other than letters, digits, '.', '_' or '-' replaced by '_'."""
    return "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in str(name))


def run_shap_analysis(X_test, best_model):
    """Compute SHAP values for the model and write plots plus an insights stub.

    Produces, inside ``OUTPUT_DIR``:
      * ``shap_summary_plot.png`` - summary (dot) plot of the top 10 features,
      * ``shap_dependence_<feature>.png`` - one dependence plot for each of
        the three features with the largest mean absolute SHAP value,
      * ``shap_insights.md`` - a Markdown skeleton with one section per
        successfully plotted feature, to be filled in by hand.

    Args:
        X_test: DataFrame of test features; its column names are used as
            feature names.
        best_model: Fitted tree-based model handed to ``shap.TreeExplainer``.

    Returns:
        None. If the summary plot fails, the error is printed and the
        function stops before the dependence plots and the insights file.
        A failing dependence plot is reported and skipped.
    """
    print("\n--- 2. Calculate SHAP values ---")

    # SHAP explainer for tree-based models.
    explainer = shap.TreeExplainer(best_model)

    # Explain how predictions are made on the test data.
    shap_values = explainer.shap_values(X_test)
    print("SHAP values calculated.")

    # The values are used as a single 2D array (n_samples × n_features);
    # no per-class indexing is applied.
    shap_values_class_1 = shap_values

    # --- 3. Create SHAP Summary Plot (Global Importance) ---
    print("\n--- 3. Create SHAP Summary Plot ---")
    try:
        plt.figure(figsize=(10, 8))
        shap.summary_plot(
            shap_values_class_1,
            X_test,
            show=False,
            plot_type="dot",
            max_display=10,  # Show the top 10 features
        )
        summary_plot_path = os.path.join(OUTPUT_DIR, 'shap_summary_plot.png')
        plt.tight_layout()
        plt.savefig(summary_plot_path, dpi=300)
        print(f"Summary plot saved at: {summary_plot_path}")

    except Exception as e:
        print(f"Error creating the summary plot: {e}")
        return
    finally:
        # Close every open figure, also on failure, so nothing stays in memory.
        plt.close('all')

    # --- 4. Create SHAP Dependence Plots (Top 3 Features) and generate insights ---
    print("\n--- 4. Create SHAP Dependence Plots and Insights File ---")

    # Rank the features by mean(|SHAP|), largest first, and keep the top 3.
    mean_abs_shap = np.abs(shap_values_class_1).mean(0)
    top_features_indices = np.argsort(mean_abs_shap)[::-1]
    feature_names = X_test.columns.to_list()
    top_feature_names = [feature_names[i] for i in top_features_indices[:3]]

    insights = ["# SHAP Insights (Lan)\n\n"]
    insights.append(f"The SHAP analysis of the tuned XGBoost model identifies the following top 3 features that have the greatest influence on the prediction of startup success ({TARGET_COL}):")

    for feature_name in top_feature_names:
        print(f"Creating dependence plot for: {feature_name}")
        try:
            plt.figure(figsize=(8, 6))
            shap.dependence_plot(
                feature_name,
                shap_values_class_1,
                X_test,
                interaction_index="auto",
                show=False,
                title=f"SHAP Dependence Plot: {feature_name}"
            )
            # Sanitize the feature name so that characters such as path
            # separators cannot break the output file name.
            dependence_plot_path = os.path.join(
                OUTPUT_DIR, f"shap_dependence_{_safe_filename(feature_name)}.png"
            )
            plt.tight_layout()
            plt.savefig(dependence_plot_path, dpi=300)

            insights.append(f"\n## {feature_name}")
            insights.append(f"**Interpretation (according to Dependence Plot - See {dependence_plot_path}):**")
            insights.append("*(Please fill in the detailed observation here, e.g.: 'High values for this feature (red) lead to positive SHAP values, meaning that a higher value increases the likelihood of startup success.')*")

        except Exception as e:
            print(f"Error creating the dependence plot for {feature_name}: {e}")
        finally:
            # plt.close() would only close the current figure; closing all of
            # them also releases the pre-sized figure if SHAP drew on another
            # one, and covers the failure path.
            plt.close('all')

    # Save the insights with an explicit encoding so the result does not
    # depend on the platform default.
    insights_path = os.path.join(OUTPUT_DIR, 'shap_insights.md')
    with open(insights_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(insights))
    print(f"\nInsights file created. Please fill in the explanations at: {insights_path}")


if __name__ == "__main__":
    # Make sure the target folder for the model, plots, and insights exists.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    X_test, best_model = load_data_and_train_model()

    # The loader signals failure by returning None for both values; only run
    # the SHAP analysis when neither of them is missing.
    if not (X_test is None or best_model is None):
        run_shap_analysis(X_test, best_model)


--- 1. Load data and retrain the model ---
Model is being retrained with the best parameters...


Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)



--- 2. Calculate SHAP values ---
SHAP values calculated.

--- 3. Create SHAP Summary Plot ---
Summary plot saved at: docs/tuning_results/shap_summary_plot.png

--- 4. Create SHAP Dependence Plots and Insights File ---
Creating dependence plot for: age_last_milestone_year
Creating dependence plot for: milestones
Creating dependence plot for: funding_total_usd

Insights file created. Please fill in the explanations at: docs/tuning_results/shap_insights.md
