In [4]:
#!/usr/bin/env python
# Naive Bayes SHAP – fixed dimension matching
from __future__ import annotations
import os, warnings
from pathlib import Path
from typing import Dict, Any, Tuple
import numpy as np, pandas as pd, shap, matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import StandardScaler

# Import-time environment and warning configuration.
# NOTE(review): presumably meant to hide SHAP's progress bar, but the recorded
# notebook output of this cell still shows a tqdm bar, so this variable may
# have no effect on the installed shap version — confirm.
os.environ.update(SHAP_PROGRESS_BAR="off")
# Silence every UserWarning for the whole process (not only shap/sklearn ones).
warnings.filterwarnings(action="ignore", category=UserWarning)

class NB_SHAP:
    """Explain a Gaussian Naive Bayes classifier with KernelSHAP on lagged panel data.

    Pipeline: load a CSV, build flattened lag windows per entity, fit a
    standardized ``GaussianNB`` on a tail slice of the windows, compute
    KernelSHAP values for the positive class and save dot, bar and
    waterfall plots into ``cfg["out"]``.

    Config keys read by this class: ``csv_path``, ``id_col``,
    ``quarter_col``, ``target_col``, ``meta_cols``, ``lags``, ``win``,
    ``seed``, ``out``, ``max_display``.
    """

    def __init__(self, cfg: Dict[str, Any]):
        self.c = cfg
        self.df = self._load(cfg["csv_path"])
        # Every column not listed in cfg["meta_cols"] is treated as a feature.
        self.feats = [f for f in self.df.columns if f not in cfg["meta_cols"]]
        self.X, self.y = self._windows()

    def _load(self, p) -> pd.DataFrame:
        """Load the CSV at *p*, sorted by (id, quarter) with NaN rows dropped.

        The quarter column is parsed with ``pd.to_datetime``; columns whose
        name repeats an earlier column name are dropped.
        """
        df = pd.read_csv(p)  # parse the file once (it used to be parsed twice)
        # NOTE(review): pandas.read_csv normally renames repeated header names
        # (``x``, ``x.1``), so this filter may never drop anything — confirm.
        df = df.loc[:, ~df.columns.duplicated()].copy()
        id_col, q_col = self.c["id_col"], self.c["quarter_col"]
        df[q_col] = pd.to_datetime(df[q_col])
        # Sorting makes rows time-ordered inside each id, which _windows relies on.
        return df.sort_values([id_col, q_col]).dropna()

    def _windows(self) -> Tuple[np.ndarray, np.ndarray]:
        """Build one flattened lag window per (entity, time step).

        For each ``id_col`` group, the window for row ``i`` (``i >= lags``)
        is rows ``i-lags .. i-1`` of the feature columns, flattened
        oldest-first (all features at t-lags, then t-lags+1, ..., t-1). Its
        label is the target value at row ``i``.

        Returns:
            ``(X, y)`` with ``X`` of shape ``(n_windows, lags * n_features)``.
            When no group is longer than ``lags`` both arrays are empty, but
            ``X`` still has ``lags * n_features`` columns.
        """
        lags = self.c["lags"]
        windows, labels = [], []
        for _, g in self.df.groupby(self.c["id_col"]):
            a = g[self.feats].to_numpy()
            lbl = g[self.c["target_col"]].to_numpy()
            for i in range(lags, len(g)):
                windows.append(a[i - lags:i].ravel())
                labels.append(lbl[i])
        if not windows:
            return np.empty((0, lags * len(self.feats))), np.asarray(labels)
        return np.asarray(windows), np.asarray(labels)

    def _feature_names(self) -> list:
        """Column names for the flattened windows, in ``_windows`` order."""
        lags = self.c["lags"]
        return [f"{f}_t-{k}" for k in range(lags, 0, -1) for f in self.feats]

    def _fit_snapshot(self):
        """Select the training snapshot, standardize it and fit GaussianNB.

        The snapshot is the last ``win`` windows of the final 40% of windows.

        Returns:
            ``(Xtr_std, model)``: the standardized snapshot and the fitted model.

        Raises:
            ValueError: if the snapshot contains fewer than two classes
                (message: "the window contains only a single class").
        """
        # NOTE(review): windows are appended entity by entity (groupby order),
        # so this index-based tail is the last entities, not the most recent
        # quarters — confirm this is the intended snapshot.
        s = int(len(self.y) * 0.6)
        win = self.c["win"]
        Xtr, ytr = self.X[s:][-win:], self.y[s:][-win:]

        if len(np.unique(ytr)) < 2:
            raise ValueError("窗口只有单一类别！")

        Xtr_std = StandardScaler().fit_transform(Xtr)
        model = GaussianNB().fit(Xtr_std, ytr)
        return Xtr_std, model

    @staticmethod
    def _positive_class_index(classes, n_outputs: int) -> int:
        """Return the output index treated as the positive class.

        Uses the position of label ``1`` in *classes* when present, otherwise
        index 1 when there are at least two outputs, else 0.
        """
        labels = np.asarray(classes).tolist()
        if 1 in labels:
            return labels.index(1)
        return 1 if n_outputs > 1 else 0

    @classmethod
    def _select_class_values(cls, raw, classes):
        """Reduce raw ``shap_values`` output to a 2-D array for one class.

        Handles both a list with one array per class and a 3-D array shaped
        ``(n_samples, n_features, n_classes)``; a 2-D array is returned as is.

        Returns:
            ``(sv, class_idx)``; ``class_idx`` is ``None`` for 2-D input.
        """
        if isinstance(raw, list):
            idx = cls._positive_class_index(classes, len(classes))
            return np.asarray(raw[idx]), idx
        raw = np.asarray(raw)
        if raw.ndim == 3:
            idx = cls._positive_class_index(classes, raw.shape[2])
            return raw[:, :, idx], idx
        return raw, None

    @staticmethod
    def _drop_bias_column(sv, n_features: int):
        """Drop a trailing bias column if present; fail on any other mismatch."""
        if sv.shape[1] == n_features + 1:
            print("Removing bias column from SHAP values")
            return sv[:, :-1]
        if sv.shape[1] != n_features:
            raise ValueError(f"Shape mismatch: SHAP values {sv.shape[1]} vs features {n_features}")
        return sv

    @staticmethod
    def _base_value(expected_value, class_idx):
        """Pick the explainer's expected value for *class_idx*.

        Scalar (including 0-d array) expected values are returned unchanged;
        vector ones are indexed by *class_idx*, or by 0 when it is ``None``.
        """
        if np.ndim(expected_value) == 0:
            return expected_value
        return np.asarray(expected_value)[0 if class_idx is None else class_idx]

    @staticmethod
    def _save_current(path: Path) -> None:
        """Lay out, save (300 dpi, white background) and close the current figure."""
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
        plt.close()

    def _save_plots(self, sv, Xexp_df, fnames, base_value, out: Path) -> None:
        """Save dot, bar and waterfall SHAP plots into *out*.

        Plotting is best-effort: a failure is reported with a traceback
        instead of aborting the run.
        """
        try:
            # Cap at 20 features so the plots stay readable.
            top_features = min(self.c["max_display"], 20)

            # -- dot plot --
            plt.figure(figsize=(14, max(8, top_features * 0.4)))
            shap.summary_plot(sv, Xexp_df, plot_type="dot",
                              max_display=top_features, show=False)
            self._save_current(out / "NB_dot.png")

            # -- bar plot --
            plt.figure(figsize=(12, max(6, top_features * 0.3)))
            shap.summary_plot(sv, Xexp_df, plot_type="bar",
                              max_display=top_features, show=False)
            self._save_current(out / "NB_bar.png")

            # -- waterfall plot for the first explained row --
            plt.figure(figsize=(12, 8))
            shap.waterfall_plot(shap.Explanation(values=sv[0],
                                                 base_values=base_value,
                                                 data=Xexp_df.iloc[0],
                                                 feature_names=fnames),
                                max_display=15, show=False)
            self._save_current(out / "NB_waterfall.png")

            print(f"✅ plots saved to {out}")
            print(f"📊 Showing top {top_features} features out of {sv.shape[1]} total")

        except Exception as e:
            print(f"❌ Plot generation failed: {e}")
            print(f"SHAP shape: {sv.shape}, Data shape: {Xexp_df.shape}")
            import traceback
            traceback.print_exc()

    def run(self):
        """Fit on the snapshot, compute KernelSHAP values and save the plots."""
        Xtr_std, model = self._fit_snapshot()

        # Explain up to 300 sampled snapshot rows.
        Xexp = shap.sample(Xtr_std, 300, random_state=self.c["seed"])
        fnames = self._feature_names()

        print(f"Xexp shape: {Xexp.shape}")
        print(f"Feature names length: {len(fnames)}")
        print(f"Expected features: {len(self.feats) * self.c['lags']}")

        background = shap.sample(Xtr_std, 100, random_state=self.c["seed"])
        expl = shap.KernelExplainer(model.predict_proba, background)
        # silent=True hides the per-row progress bar the env var did not disable.
        shap_values_all = expl.shap_values(Xexp, nsamples="auto", silent=True)
        print(f"Raw SHAP values shape: {np.array(shap_values_all).shape}")

        sv, class_idx = self._select_class_values(shap_values_all, model.classes_)
        print(f"SHAP values shape after processing: {sv.shape}")

        sv = self._drop_bias_column(sv, Xexp.shape[1])
        if len(fnames) != sv.shape[1]:
            # Fail loudly instead of truncating names and mislabeling features.
            raise ValueError(f"Feature names {len(fnames)} vs SHAP columns {sv.shape[1]}")

        out = Path(self.c["out"])
        out.mkdir(parents=True, exist_ok=True)

        Xexp_df = pd.DataFrame(Xexp, columns=fnames)
        print(f"Final shapes - SHAP: {sv.shape}, Features: {Xexp_df.shape}")

        base_value = self._base_value(expl.expected_value, class_idx)
        self._save_plots(sv, Xexp_df, fnames, base_value, out)

if __name__ == "__main__":
    # Run configuration for the CVM indicators dataset (2011-2021):
    # 4 quarterly lags, snapshot of at most 10000 windows, fixed seed.
    CFG = dict(
        csv_path="cvm_indicators_dataset_2011-2021.csv",
        id_col="ID",
        quarter_col="QUARTER",
        target_col="LABEL",
        meta_cols=["ID", "QUARTER", "LABEL"],
        lags=4,
        win=10000,
        seed=42,
        out="nb_kernel_shap_results",
        max_display=30,
    )
    explainer = NB_SHAP(CFG)
    explainer.run()

Xexp shape: (300, 336)
Feature names length: 336
Expected features: 336


  0%|          | 0/300 [00:00<?, ?it/s]

Raw SHAP values shape: (300, 336, 2)
SHAP values shape after processing: (300, 336)
Final shapes - SHAP: (300, 336), Features: (300, 336)
✅ plots saved to nb_kernel_shap_results
📊 Showing top 20 features out of 336 total
