In [2]:
#!/usr/bin/env python
# coding: utf-8
#
# LightGBM SHAP Analysis - v2 (GPU-Accelerated & Robust)
#
from __future__ import annotations
import warnings
from pathlib import Path
from typing import Dict, Any

import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.preprocessing import StandardScaler
import shap
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore", category=UserWarning)

class LightGBM_SHAP_Analysis:
    """Fit a LightGBM classifier on lagged panel data and explain it with SHAP.

    On construction the CSV at ``config['csv_path']`` is loaded, every column
    not listed in ``config['meta_cols']`` is treated as a feature, and
    flattened lag windows are built per entity. ``run_shap_analysis`` then
    trains the model on a snapshot of those windows and writes two SHAP
    summary plots to the current working directory.

    Config keys read by this class: ``csv_path``, ``meta_cols``, ``id_col``,
    ``quarter_col``, ``target_col``, ``lags``, ``sliding_win_size`` and
    ``lgbm_champion_params``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Load the data and precompute all lag windows.

        Args:
            config: Settings dictionary; see the class docstring for the keys
                that are read.
        """
        self.config = config
        self.df = self._load_data(config['csv_path'])
        self.feat_cols = [c for c in self.df.columns if c not in config['meta_cols']]
        self.X_all, self.y_all = self._make_windows()

    def _load_data(self, path: str | Path) -> pd.DataFrame:
        """Read the CSV, validate required columns, sort and drop incomplete rows.

        Args:
            path: Location of the CSV file.

        Returns:
            A frame with duplicated column labels removed, the quarter column
            parsed to datetimes, rows sorted by (id, quarter) and every row
            containing a missing value dropped.

        Raises:
            KeyError: If any column named in ``config['meta_cols']`` is absent.
        """
        print("─" * 60 + "\n1. Loading and cleaning data...")
        cfg = self.config

        # Parse the file once; the previous version read it twice just to
        # build the duplicate-column mask.
        raw = pd.read_csv(path)
        # .copy() so the column assignment below targets an independent frame.
        df = raw.loc[:, ~raw.columns.duplicated()].copy()

        missing = set(cfg['meta_cols']) - set(df.columns)
        if missing:
            raise KeyError(f"Missing cols: {missing}")

        df[cfg['quarter_col']] = pd.to_datetime(df[cfg['quarter_col']])
        df = df.sort_values([cfg['id_col'], cfg['quarter_col']])
        return df.dropna()

    def _make_windows(self) -> tuple[np.ndarray, np.ndarray]:
        """Build flattened lag windows and their targets for every entity.

        For each entity (rows grouped by ``id_col`` and ordered by
        ``quarter_col``) and each row index ``i >= lags``, the feature rows
        ``i - lags .. i - 1`` are flattened row by row into one sample and the
        target of row ``i`` becomes its label.

        Returns:
            ``(X, y)`` where ``X`` has one row per window with
            ``lags * len(feat_cols)`` columns and ``y`` holds the matching
            targets. Samples are ordered entity by entity, not by calendar
            time across entities.
        """
        print("2. Preparing sequence data...")
        cfg = self.config
        lags = cfg['lags']
        X, y = [], []
        for _, g in self.df.groupby(cfg['id_col']):
            g = g.sort_values(cfg['quarter_col'])
            arr = g[self.feat_cols].to_numpy()
            lbl = g[cfg['target_col']].to_numpy()
            for i in range(lags, len(g)):
                X.append(arr[i - lags:i].ravel())
                y.append(lbl[i])

        if not X:
            # Keep a well-formed 2-D shape even when no entity has more than
            # `lags` rows, so downstream code sees (0, n_features).
            return np.empty((0, lags * len(self.feat_cols))), np.empty((0,))
        return np.asarray(X), np.asarray(y)

    def _select_training_snapshot(self) -> tuple[np.ndarray, np.ndarray]:
        """Pick the training snapshot from the tail of the window arrays.

        The last 40% of the windows form the analysis pool and the final
        ``sliding_win_size`` samples of that pool are the initial snapshot.
        If the snapshot holds a single class it is grown backward in steps of
        ``sliding_win_size`` until a second class appears or the pool is
        exhausted.

        Returns:
            ``(X_train, y_train)`` containing at least two classes.

        Raises:
            ValueError: If there are no windows, if ``sliding_win_size`` is
                not positive, or if the whole analysis pool has one class.
        """
        # NOTE(review): windows are stored entity by entity, so "the tail" is
        # the last entities in group order rather than the globally most
        # recent quarters - confirm this is the intended snapshot.
        n = len(self.y_all)
        if n == 0:
            raise ValueError(
                "No lag windows available: no entity has more rows than 'lags'."
            )

        win_size = self.config['sliding_win_size']
        if win_size <= 0:
            # A non-positive step would make the expansion loop below spin forever.
            raise ValueError(f"'sliding_win_size' must be positive, got {win_size}")

        analysis_start_point = int(n * 0.6)
        X_analysis = self.X_all[analysis_start_point:]
        y_analysis = self.y_all[analysis_start_point:]
        X_train, y_train = X_analysis[-win_size:], y_analysis[-win_size:]

        if len(np.unique(y_train)) < 2:
            print("⚠️  Initial window has only one class. Expanding backward...")
            extra_history = win_size
            while len(np.unique(y_train)) < 2 and extra_history < len(X_analysis):
                extra_history += win_size
                X_train = X_analysis[-extra_history:]
                y_train = y_analysis[-extra_history:]
            if len(np.unique(y_train)) < 2:
                # Previously this path still printed "Success!" and went on to fit.
                raise ValueError(
                    "Training snapshot contains a single class even after using "
                    f"all {len(y_analysis)} analysis samples."
                )
            print(f"   Success! New training snapshot size: {len(y_train)}")

        return X_train, y_train

    @staticmethod
    def _positive_class_shap(shap_values) -> np.ndarray:
        """Return the 2-D SHAP matrix for class 1 from the explainer output.

        Accepted layouts: a list of two per-class arrays, a single 2-D array,
        or a 3-D array whose last axis has length 2.

        Raises:
            TypeError: For any other structure.
        """
        if isinstance(shap_values, list) and len(shap_values) == 2:
            return shap_values[1]
        if isinstance(shap_values, np.ndarray):
            if shap_values.ndim == 2:
                return shap_values
            # Newer SHAP releases may stack per-class values on a trailing
            # axis instead of returning a list - see the SHAP release notes.
            if shap_values.ndim == 3 and shap_values.shape[-1] == 2:
                return shap_values[:, :, 1]
        raise TypeError(f"Unexpected SHAP values structure: {type(shap_values)}")

    def _lagged_feature_names(self) -> list[str]:
        """Name each flattened window column as ``<feature>_t-<lag>``.

        The order (oldest lag first, features cycling within each lag) mirrors
        the row-major flattening done in ``_make_windows``.
        """
        lags = self.config['lags']
        return [f"{feat}_t-{i}" for i in range(lags, 0, -1) for feat in self.feat_cols]

    @staticmethod
    def _save_summary_plot(sv, X_df, *, title, filename, label, plot_type=None):
        """Draw one SHAP summary plot, save it as a 300-dpi PNG and report it."""
        # Only forward plot_type when requested so the dot plot keeps SHAP's default.
        extra = {} if plot_type is None else {"plot_type": plot_type}
        shap.summary_plot(sv, X_df, show=False, **extra)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(filename, dpi=300)
        plt.close()
        print(f"   ✅ {label} saved as '{filename}'")

    def run_shap_analysis(self):
        """Train the configured LightGBM model and save SHAP summary plots.

        The model is fitted on the standardized training snapshot and SHAP
        values are computed on that same data. Two files are written to the
        working directory: ``shap_lgbm_summary_dot.png`` and
        ``shap_lgbm_summary_bar.png``.

        Raises:
            ValueError: If no usable two-class training snapshot exists.
            TypeError: If the SHAP output has an unrecognised structure.
        """
        print("\n3. Preparing data for SHAP analysis...")
        X_train, y_train = self._select_training_snapshot()
        print(f"   Training model on a snapshot of {len(y_train)} samples.")

        scaler = StandardScaler().fit(X_train)
        X_train_std = scaler.transform(X_train)

        champion_params = self.config['lgbm_champion_params']
        model = lgb.LGBMClassifier(**champion_params).fit(X_train_std, y_train)

        print("\n4. Setting up SHAP TreeExplainer...")
        explainer = shap.TreeExplainer(model)

        print("5. Calculating SHAP values...")
        shap_values = explainer.shap_values(X_train_std)

        print("6. Generating SHAP summary plots...")
        sv_to_plot = self._positive_class_shap(shap_values)
        X_train_df = pd.DataFrame(X_train_std, columns=self._lagged_feature_names())

        self._save_summary_plot(
            sv_to_plot,
            X_train_df,
            title="SHAP Summary for Optuna-Tuned LightGBM (Distress Class)",
            filename="shap_lgbm_summary_dot.png",
            label="Dot plot",
        )
        self._save_summary_plot(
            sv_to_plot,
            X_train_df,
            title="Mean Absolute SHAP for Optuna-Tuned LightGBM",
            filename="shap_lgbm_summary_bar.png",
            label="Bar plot",
            plot_type="bar",
        )

        print("\nSHAP Analysis Complete!")

if __name__ == "__main__":
    # Prerequisites (shap, matplotlib and a GPU-enabled LightGBM build):
    #   pip install shap matplotlib
    #   pip install lightgbm --config-settings=cmake.define.USE_GPU=ON
    CONFIG = dict(
        csv_path="cvm_indicators_dataset_2011-2021.csv",
        id_col="ID",
        quarter_col="QUARTER",
        target_col="LABEL",
        meta_cols=["ID", "QUARTER", "LABEL"],
        lags=4,
        seed=42,
        sliding_win_size=3000,
        # LightGBM settings; the tuned values are described in this notebook
        # as the champion parameters found by Optuna.
        lgbm_champion_params=dict(
            objective="binary",
            metric="auc",
            random_state=42,
            verbose=-1,
            # Train on the GPU; requires the GPU-enabled build noted above.
            device="gpu",
            # Tuned hyper-parameters.
            n_estimators=400,
            learning_rate=0.010268965803862608,
            num_leaves=146,
            scale_pos_weight=50,
        ),
    )

    # Build the analyzer (loads the CSV and windows the data), then run it.
    shap_analyzer = LightGBM_SHAP_Analysis(config=CONFIG)
    shap_analyzer.run_shap_analysis()

────────────────────────────────────────────────────────────
1. Loading and cleaning data...
2. Preparing sequence data...

3. Preparing data for SHAP analysis...
   Training model on a snapshot of 3000 samples.





4. Setting up SHAP TreeExplainer...
5. Calculating SHAP values...
6. Generating SHAP summary plots...
   ✅ Dot plot saved as 'shap_lgbm_summary_dot.png'
   ✅ Bar plot saved as 'shap_lgbm_summary_bar.png'

SHAP Analysis Complete!
