# In [None]:
import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
from collections import defaultdict, Counter
import gc
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
import contextlib
import io
from scipy import interp
from sklearn.preprocessing import label_binarize
from sklearn.utils.class_weight import compute_sample_weight
           
# ================== Library Imports ==================
try:
    from skopt import BayesSearchCV
    from skopt.space import Real, Integer, Categorical
    SKOPT_AVAILABLE = True
except ImportError:
    SKOPT_AVAILABLE = False
    print("Warning: scikit-optimize is not installed. Bayesian optimization will not be available.")

try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    print("Warning: shap is not installed. Model interpretation will not be available.")

try:
    import xlsxwriter
    XLSXWRITER_AVAILABLE = True
except ImportError:
    XLSXWRITER_AVAILABLE = False
    print("Warning: xlsxwriter is not installed. Saving SHAP values to Excel will not be available.")

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, cohen_kappa_score
from sklearn.preprocessing import StandardScaler

# Import with proper error handling
try:
    import xgboost as xgb
    XGB_AVAILABLE = True
except ImportError:
    XGB_AVAILABLE = False

try:
    import lightgbm as lgb
    LGB_AVAILABLE = True
except ImportError:
    LGB_AVAILABLE = False

warnings.filterwarnings('ignore')

# ================== Configuration ==================
class Config:
    """Static settings shared by every stage of the analysis.

    All values are plain class attributes and are read as ``Config.NAME``.
    """

    # --- Paths -----------------------------------------------------------
    BASE_DIR = Path('.')
    DATA_PATH = BASE_DIR / "cleaned_data.csv"
    OUTPUT_DIR = BASE_DIR / "Global_teeth_selection"

    # --- Column names and cohort filter ----------------------------------
    TARGET_COLUMN = 'perio_label_cdc'   # label column written by DataProcessor
    AGE_COLUMN = 'RIDAGEYR'
    WEIGHT_COLUMN = 'WTMEC2YR'
    MIN_AGE = 35                        # rows with a smaller age are dropped

    # --- Resampling ------------------------------------------------------
    RANDOM_SEEDS = [42, 123, 456, 999, 2025]   # one shuffled CV split per seed
    CV_FOLDS = 5

    # Last character of a site column name; only these site codes are used
    # when the classification label is derived.
    INTERPROXIMAL_SITES = {'D', 'S', 'P', 'A'}
    CLASS_LABELS = {0: 'Healthy', 1: 'Other', 2: 'Severe'}

    # --- Feature selection / SHAP ----------------------------------------
    TOP_FEATURES_COUNT = 10       # number of top-ranked teeth kept for the consensus set
    SHAP_SAMPLE_SIZE = 1000       # upper bound on rows explained per SHAP run
    SHAP_PLOT_MAX_DISPLAY = 20

    # Reference tooth subsets, given as FDI tooth numbers.
    CPI_RAMFJORD_TOOTH_NUMBERS = {
        'Ramfjord': ['16', '14', '21', '24', '26', '36', '34', '41', '44', '46'],
        'CPI': ['11', '16', '17', '26', '27', '31', '36', '37', '46', '47']
    }

    # Two-digit tooth code used in the data column names -> FDI tooth number.
    NHANES_TO_FDI_MAPPING = {
        '01': '18', '02': '17', '03': '16', '04': '15', '05': '14', '06': '13', '07': '12', '08': '11',
        '09': '21', '10': '22', '11': '23', '12': '24', '13': '25', '14': '26', '15': '27', '16': '28',
        '17': '38', '18': '37', '19': '36', '20': '35', '21': '34', '22': '33', '23': '32', '24': '31',
        '25': '41', '26': '42', '27': '43', '28': '44', '29': '45', '30': '46', '31': '47', '32': '48'
    }

    # Optimized Bayesian Hyperparameter Space.
    # ``Integer``/``Real`` only exist when scikit-optimize was imported, so the
    # spaces are built lazily: the conditional expression never evaluates the
    # dict literal when SKOPT_AVAILABLE is False, and an empty mapping is used
    # instead (callers then fall back to default model parameters).
    BAYESIAN_HYPERPARAMETER_SPACES = {
        'XGBoost': {
            'n_estimators': Integer(1000, 1300),
            'max_depth': Integer(2, 4),
            'learning_rate': Real(0.05, 0.15, 'log-uniform'),
            'subsample': Real(0.7, 0.9, 'uniform'),
            'colsample_bytree': Real(0.85, 0.95, 'uniform'),
            'gamma': Real(0.2, 0.8, 'uniform'),
            'reg_alpha': Real(1e-3, 1e-2, 'log-uniform'),
            'reg_lambda': Real(1e-3, 1e-2, 'log-uniform'),
            'min_child_weight': Integer(8, 12)
        },
        'LightGBM': {
            'n_estimators': Integer(800, 1200),
            'max_depth': Integer(1, 3),
            'num_leaves': Integer(80, 150),
            'learning_rate': Real(0.08, 0.20, 'log-uniform'),
            'subsample': Real(0.7, 0.9, 'uniform'),
            'colsample_bytree': Real(0.6, 0.8, 'uniform'),
            'reg_alpha': Real(1e-5, 1e-3, 'log-uniform'),
            'reg_lambda': Real(5.0, 15.0, 'log-uniform'),
            'min_child_samples': Integer(10, 25)
        }
    } if SKOPT_AVAILABLE else {}

    # Number of parameter settings sampled by the Bayesian search.
    BAYESIAN_N_ITER = 30

# Create the output directory (and any missing parents) at import time so the
# later CSV/PNG writes do not fail; exist_ok makes re-runs harmless.
Config.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ================== Utility Functions ==================
def create_model(model_name, random_state=42, **kwargs):
    """Build an unfitted multiclass classifier for the requested backend.

    Args:
        model_name: Either ``'XGBoost'`` or ``'LightGBM'``.
        random_state: Seed forwarded to the estimator.
        **kwargs: Extra estimator parameters, passed through unchanged.

    Returns:
        A new ``xgb.XGBClassifier`` or ``lgb.LGBMClassifier`` instance.

    Raises:
        ValueError: If ``model_name`` is not one of the two supported names.
        ImportError: If the library for the requested backend is not installed.
    """
    # Reject unsupported names first so the backend branches stay flat.
    if model_name not in ('XGBoost', 'LightGBM'):
        raise ValueError(f"Unknown model: {model_name}")

    if model_name == 'LightGBM':
        if not LGB_AVAILABLE:
            raise ImportError("LightGBM is not available")
        lgb_options = dict(
            objective='multiclass',
            random_state=random_state,
            n_jobs=-1,
            verbose=-1,
            class_weight=None,
            verbosity=-1,
        )
        return lgb.LGBMClassifier(**lgb_options, **kwargs)

    if not XGB_AVAILABLE:
        raise ImportError("XGBoost is not available")
    xgb_options = dict(
        objective='multi:softprob',
        random_state=random_state,
        use_label_encoder=False,
        eval_metric='mlogloss',
        n_jobs=1,
        verbosity=0,
    )
    return xgb.XGBClassifier(**xgb_options, **kwargs)

# ================== Data Processor ==================
class DataProcessor:
    """Load the survey CSV, derive the periodontal label and build per-tooth features.

    ``pd_columns`` and ``cal_columns`` start empty and are filled by
    ``_identify_periodontal_features`` with the site-level column names
    matching ``OHX##PC[ADSP]`` and ``OHX##LA[ADSP]`` respectively.
    """

    def __init__(self):
        self.pd_columns = []
        self.cal_columns = []

    def load_and_preprocess(self, filepath):
        """Read the CSV, filter rows by age and weight, and attach the class label.

        Args:
            filepath: Path of the CSV file to read.

        Returns:
            The filtered DataFrame with ``Config.TARGET_COLUMN`` added, or an
            empty DataFrame when the file does not exist.
        """
        print(f"Loading data from: {filepath}")
        try:
            df = pd.read_csv(filepath, low_memory=False)
        except FileNotFoundError:
            print(f"Error: Data file not found at {filepath}. Please ensure 'cleaned_data.csv' is in the correct directory.")
            return pd.DataFrame()

        required_cols = [Config.AGE_COLUMN, Config.WEIGHT_COLUMN]
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            print(f"Warning: Missing required columns: {missing_cols}")
            # Fall back to constants so the filters below keep every row:
            # unit weights, and an age that passes the MIN_AGE filter.
            for col in missing_cols:
                if col == Config.WEIGHT_COLUMN:
                    df[col] = 1.0
                elif col == Config.AGE_COLUMN:
                    df[col] = 40

        print(f"Initial sample size: {len(df)}")
        df = df[df[Config.AGE_COLUMN] >= Config.MIN_AGE].copy()
        print(f"Sample size after age filtering (>= {Config.MIN_AGE}): {len(df)}")

        initial_weighted_size = len(df)
        df = df[(df[Config.WEIGHT_COLUMN].notna()) & (df[Config.WEIGHT_COLUMN] > 0)].copy()
        print(f"Sample size after weight filtering: {len(df)} (removed {initial_weighted_size - len(df)} records)")

        df = self._identify_periodontal_features(df)
        df = self._handle_missing_values(df)
        df = self._apply_cdc_classification(df)

        print("\nData preprocessing complete.")
        return df

    def _identify_periodontal_features(self, df):
        """Record the PD/CAL site columns and coerce them to numeric in place.

        Values that cannot be parsed as numbers become NaN.
        """
        pd_pattern = re.compile(r'^OHX\d{2}PC[ADSP]$')
        cal_pattern = re.compile(r'^OHX\d{2}LA[ADSP]$')
        self.pd_columns = sorted(c for c in df.columns if pd_pattern.search(c))
        self.cal_columns = sorted(c for c in df.columns if cal_pattern.search(c))
        for col in self.pd_columns + self.cal_columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        return df

    def _handle_missing_values(self, df):
        """Return ``df`` unchanged; missing site values stay NaN.

        The previous implementation filled NaN with NaN, which is a no-op.
        The method is kept as a hook because ``load_and_preprocess`` calls it.
        """
        return df

    def _apply_cdc_classification(self, df):
        """Add ``Config.TARGET_COLUMN`` (0, 1 or 2) derived from per-tooth site maxima.

        For every tooth the maximum over the sites listed in
        ``Config.INTERPROXIMAL_SITES`` is taken separately for PD and CAL;
        the label then depends on how many teeth reach each threshold.
        """
        site_cols = self.pd_columns + self.cal_columns
        # The value 99 is replaced by NaN before any maximum is taken so it
        # cannot count as a measurement.
        sites = df[site_cols].replace(99, np.nan)

        # The first five characters ("OHX" + two digits) identify the tooth.
        tooth_prefixes = sorted({c[:5] for c in site_cols})

        def per_tooth_max(columns):
            """Return a (rows x teeth) table of the per-tooth maximum over ``columns``."""
            by_tooth = {tp: [] for tp in tooth_prefixes}
            for col in columns:
                if col[-1] in Config.INTERPROXIMAL_SITES:
                    by_tooth[col[:5]].append(col)
            maxima = {
                tp: sites[cols].max(axis=1) if cols else pd.Series(np.nan, index=df.index)
                for tp, cols in by_tooth.items()
            }
            # Passing the index keeps one row per participant even when no
            # periodontal columns exist; otherwise the counts below would be
            # empty and the label assignment would fail on a length mismatch.
            return pd.DataFrame(maxima, index=df.index)

        tooth_pd_max_df = per_tooth_max(self.pd_columns)
        tooth_cal_max_df = per_tooth_max(self.cal_columns)

        # Count the number of teeth meeting a threshold (NaN never counts).
        def count_teeth_ge(mat: pd.DataFrame, thr: float) -> pd.Series:
            return (mat >= thr).sum(axis=1)

        n_CAL3 = count_teeth_ge(tooth_cal_max_df, 3.0)
        n_CAL6 = count_teeth_ge(tooth_cal_max_df, 6.0)
        n_PPD4 = count_teeth_ge(tooth_pd_max_df, 4.0)
        n_PPD5 = count_teeth_ge(tooth_pd_max_df, 5.0)

        # Severe is assessed first; "other" only applies to non-severe rows.
        # NOTE(review): a CAL >= 4 count was computed here before but never
        # used by either rule - confirm the rules below match the intended
        # case definitions.
        severe = (n_CAL6 >= 2) & (n_PPD5 >= 1)
        other = (~severe) & (n_CAL3 >= 2) & ((n_PPD4 >= 2) | (n_PPD5 >= 1))

        df[Config.TARGET_COLUMN] = np.where(severe, 2, np.where(other, 1, 0)).astype(int)
        return df

    def _create_max_features(self, df, column_list, feature_type):
        """Collapse site columns into one ``<prefix>_max`` feature per tooth and measure.

        Args:
            df: Frame holding the site columns.
            column_list: Site column names; columns sharing everything but the
                last character are grouped together.
            feature_type: Unused; kept for interface compatibility.

        Returns:
            A DataFrame indexed like ``df`` with one column per group, or a
            column-less DataFrame when ``column_list`` is empty.
        """
        if not column_list:
            return pd.DataFrame(index=df.index)

        data_for_agg = df[column_list].replace(99, np.nan)
        grouped_cols = defaultdict(list)
        for col in column_list:
            grouped_cols[col[:-1]].append(col)

        feature_series_list = [
            data_for_agg[columns].max(axis=1).rename(f'{prefix}_max')
            for prefix, columns in grouped_cols.items()
        ]
        return pd.concat(feature_series_list, axis=1)

    def get_combined_feature_set(self, df):
        """Build the PD+CAL per-tooth feature matrix together with labels and weights.

        Returns:
            ``(X, y, sample_weights)`` where ``sample_weights`` is ``None``
            when ``Config.WEIGHT_COLUMN`` is absent from ``df``.
        """
        X_pd_max = self._create_max_features(df, self.pd_columns, "PD")
        print(f"PD features shape: {X_pd_max.shape}")

        X_cal_max = self._create_max_features(df, self.cal_columns, "CAL")
        print(f"CAL features shape: {X_cal_max.shape}")

        if X_pd_max.empty and X_cal_max.empty:
            X_combined_max = pd.DataFrame(index=df.index)
        elif X_pd_max.empty:
            X_combined_max = X_cal_max.copy()
        elif X_cal_max.empty:
            X_combined_max = X_pd_max.copy()
        else:
            X_pd_max = X_pd_max.reindex(df.index)
            X_cal_max = X_cal_max.reindex(df.index)
            X_combined_max = pd.concat([X_pd_max, X_cal_max], axis=1)

        print(f"Combined features shape: {X_combined_max.shape}")

        y = df[Config.TARGET_COLUMN]
        sample_weights = df[Config.WEIGHT_COLUMN] if Config.WEIGHT_COLUMN in df.columns else None
        return X_combined_max, y, sample_weights

    def nhanes_to_fdi(self, nhanes_tooth):
        """Translate a two-digit tooth code to FDI; unknown codes are returned unchanged."""
        return Config.NHANES_TO_FDI_MAPPING.get(nhanes_tooth, nhanes_tooth)

    def extract_tooth_numbers_from_features(self, feature_list):
        """Return the sorted, de-duplicated FDI numbers referenced by ``*_max`` feature names."""
        tooth_numbers_fdi = set()
        for feature in feature_list:
            if 'OHX' in feature and ('PC_max' in feature or 'LA_max' in feature):
                tooth_num = feature.replace('OHX', '').replace('PC_max', '').replace('LA_max', '')
                if tooth_num.isdigit():
                    tooth_numbers_fdi.add(self.nhanes_to_fdi(tooth_num))
        return sorted(tooth_numbers_fdi)

# ================== Bayesian Hyperparameter Tuning ==================
class HyperparameterTuner:
    """Bayesian hyperparameter search with stratified cross-validation."""

    def __init__(self, cv_folds=5, random_state=42):
        self.cv_folds = cv_folds
        self.random_state = random_state

    def tune_model_bayes(self, model_name, base_model, X, y, sample_weight, param_space, scoring='roc_auc_ovr'):
        """Run ``BayesSearchCV`` and apply the best parameters to ``base_model``.

        Args:
            model_name: Name used in progress messages.
            base_model: Estimator to tune; it is updated in place with the
                best parameters and returned.
            X: Feature matrix.
            y: Class labels.
            sample_weight: Optional per-row weights; multiplied by balanced
                class weights before fitting. ``None`` means class weights only.
            param_space: Search space passed to ``BayesSearchCV``.
            scoring: Scoring name passed to ``BayesSearchCV``.

        Returns:
            ``(tuned_model, best_params)``.

        Raises:
            ImportError: If scikit-optimize is not installed.
        """
        if not SKOPT_AVAILABLE:
            # Fail with a clear message instead of a NameError on BayesSearchCV.
            raise ImportError("scikit-optimize is not available")

        print(f"  Tuning hyperparameters for {model_name} with Bayesian Optimization...")
        cv = StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state)
        bayes_search = BayesSearchCV(
            base_model, param_space, n_iter=Config.BAYESIAN_N_ITER,
            cv=cv, scoring=scoring, n_jobs=-1, random_state=self.random_state, verbose=0
        )

        # Balanced class weights are always applied; caller-supplied weights,
        # when given, are multiplied in on top of them.
        class_sample_weights = compute_sample_weight(class_weight='balanced', y=y)
        if sample_weight is not None:
            fit_weights = sample_weight * class_sample_weights
        else:
            fit_weights = class_sample_weights
        print(f"    Applied class balancing during hyperparameter tuning for {model_name}")

        bayes_search.fit(X, y, sample_weight=fit_weights)

        tuned_model = base_model.set_params(**bayes_search.best_params_)
        print(f"    Best params: {bayes_search.best_params_}")
        print(f"    Best score: {bayes_search.best_score_:.4f}")
        gc.collect()
        return tuned_model, bayes_search.best_params_

# ================== Multi-Seed Stability Analyzer ==================
class MultiSeedStabilityAnalyzer:
    """Rank teeth by mean(|SHAP|) aggregated over repeated, re-seeded CV fits.

    Attributes:
        data_processor: Object providing ``nhanes_to_fdi``.
        random_seeds: Seeds; each one drives a shuffled ``StratifiedKFold``.
        stability_results: Per-model analysis dicts from the last runs.
        all_run_shap_values: feature name -> list of per-run mean(|SHAP|).
    """

    def __init__(self, data_processor, random_seeds=Config.RANDOM_SEEDS):
        self.data_processor = data_processor
        self.random_seeds = random_seeds
        self.stability_results = {}
        self.all_run_shap_values = defaultdict(list)

    @staticmethod
    def _mean_abs_shap(shap_values):
        """Reduce raw SHAP output to one mean(|SHAP|) value per feature.

        Handles a list of per-class ``(samples, features)`` arrays, a single
        2-D ``(samples, features)`` array, and a 3-D array, which is assumed
        to be laid out as ``(samples, features, classes)``.
        """
        if isinstance(shap_values, list):
            return np.mean([np.abs(sv).mean(axis=0) for sv in shap_values], axis=0)
        magnitudes = np.abs(np.asarray(shap_values))
        if magnitudes.ndim == 3:
            # Average over samples and classes so each feature contributes a
            # single number per run (otherwise std_shap would also include
            # the spread between classes).
            return magnitudes.mean(axis=(0, 2))
        return magnitudes.mean(axis=0)

    def analyze_top_teeth_stability(self, X_train, y_train, sample_weights, model_name, best_params, current_output_dir):
        """Fit one model per (seed, fold), explain a fixed sample and aggregate the results.

        Args:
            X_train: Feature DataFrame.
            y_train: Label Series aligned with ``X_train``.
            sample_weights: Optional weight Series aligned with ``X_train``.
            model_name: Backend name passed to ``create_model``.
            best_params: Estimator parameters applied to every model.
            current_output_dir: Directory (``Path``) receiving CSVs and figures.

        Returns:
            The analysis dict produced by
            ``_calculate_and_visualize_stability_results``; an empty dict when
            SHAP is not installed.
        """
        print(f"\n--- Analyzing Feature Stability for {model_name} (Robust Multi-Run SHAP) ---")

        if not SHAP_AVAILABLE:
            print("SHAP is not available. Skipping stability analysis.")
            return {}

        self.all_run_shap_values.clear()

        print(f"Creating a fixed background sample of size {Config.SHAP_SAMPLE_SIZE} for SHAP analysis.")
        # The same rows are explained in every run so runs are comparable.
        X_sample = X_train.sample(min(len(X_train), Config.SHAP_SAMPLE_SIZE), random_state=42)

        total_runs = len(self.random_seeds) * Config.CV_FOLDS
        run_count = 0

        for seed_idx, seed in enumerate(self.random_seeds):
            print(f"\n  Processing Seed {seed_idx + 1}/{len(self.random_seeds)}: {seed}")
            cv = StratifiedKFold(n_splits=Config.CV_FOLDS, shuffle=True, random_state=seed)

            # Only the training part of each split is used here.
            for fold_idx, (train_idx, _) in enumerate(cv.split(X_train, y_train)):
                run_count += 1
                print(f"    Fold {fold_idx + 1}/{Config.CV_FOLDS} (Overall Run {run_count}/{total_runs})")

                X_train_sub = X_train.iloc[train_idx]
                y_train_sub = y_train.iloc[train_idx]

                class_sample_weights = compute_sample_weight('balanced', y=y_train_sub)
                if sample_weights is not None:
                    fit_weights = sample_weights.iloc[train_idx] * class_sample_weights
                else:
                    fit_weights = class_sample_weights

                model = create_model(model_name, random_state=seed)
                model.set_params(**best_params)
                model.set_params(verbosity=0)

                # Swallow anything the backend prints while fitting.
                with contextlib.redirect_stdout(io.StringIO()):
                    model.fit(X_train_sub, y_train_sub, sample_weight=fit_weights)

                try:
                    explainer = shap.TreeExplainer(model, model_output="probability", feature_perturbation="interventional")
                except Exception:
                    # Best effort: fall back to the explainer's defaults when
                    # the probability-space configuration is rejected.
                    explainer = shap.TreeExplainer(model)
                shap_values = explainer.shap_values(X_sample)

                mean_shap = self._mean_abs_shap(shap_values)
                for feature_name, shap_val in zip(X_train.columns, mean_shap):
                    self.all_run_shap_values[feature_name].append(shap_val)

                del model, explainer, shap_values
                gc.collect()

        stability_analysis = self._calculate_and_visualize_stability_results(model_name, current_output_dir)
        self.stability_results[model_name] = stability_analysis
        return stability_analysis

    def _calculate_and_visualize_stability_results(self, model_name, current_output_dir):
        """Aggregate the collected SHAP values, write CSVs and draw the figures.

        Returns:
            A dict with ``feature_level_stability`` and ``detailed_importance``
            and, when at least one feature name matches ``OHX##(PC|LA)_max``,
            ``tooth_level_importance``.
        """
        print(f"\n  Aggregating and analyzing SHAP results for {model_name}...")

        feature_stability_data = [{'feature': f, 'mean_shap': np.mean(s), 'std_shap': np.std(s)}
                                  for f, s in self.all_run_shap_values.items()]
        # Explicit columns keep sort_values working when no run was recorded.
        feature_df = pd.DataFrame(
            feature_stability_data, columns=['feature', 'mean_shap', 'std_shap']
        ).sort_values('mean_shap', ascending=False)

        detailed_importance_data = []
        for _, row in feature_df.iterrows():
            match = re.search(r'OHX(\d{2})(PC|LA)_max', row['feature'])
            if match:
                nhanes_num = match.group(1)
                measurement_type = 'PD' if match.group(2) == 'PC' else 'CAL'
                fdi_num = self.data_processor.nhanes_to_fdi(nhanes_num)

                detailed_importance_data.append({
                    'tooth_fdi': fdi_num,
                    'measurement_type': measurement_type,
                    'feature': row['feature'],
                    'mean_shap': row['mean_shap'],
                    'shap_error': row['std_shap']
                })
        detailed_importance_df = pd.DataFrame(
            detailed_importance_data,
            columns=['tooth_fdi', 'measurement_type', 'feature', 'mean_shap', 'shap_error']
        )

        if detailed_importance_df.empty:
            # Nothing to rank or plot; omitting 'tooth_level_importance' lets
            # callers that check for that key skip the consensus feature set.
            print(f"  No tooth-level features found for {model_name}; skipping tooth ranking and plots.")
            return {
                'feature_level_stability': feature_df,
                'detailed_importance': detailed_importance_df
            }

        tooth_df_aggregated = detailed_importance_df.groupby('tooth_fdi').agg(
            total_mean_shap=('mean_shap', 'sum'),
            shap_error=('shap_error', 'mean')
        ).sort_values('total_mean_shap', ascending=False).reset_index()

        print("\n--- Full Tooth Importance Ranking (Aggregated) ---")
        print(tooth_df_aggregated.to_string(index=False))

        detailed_importance_df.to_csv(current_output_dir / f"{model_name}_detailed_feature_importance.csv", index=False)
        tooth_df_aggregated.to_csv(current_output_dir / f"{model_name}_aggregated_tooth_importance.csv", index=False)
        print(f"\nSaved detailed and aggregated importance to: {current_output_dir}")

        self.visualize_detailed_importance(detailed_importance_df, tooth_df_aggregated, model_name, current_output_dir)

        return {
            'tooth_level_importance': tooth_df_aggregated,
            'feature_level_stability': feature_df,
            'detailed_importance': detailed_importance_df
        }

    def visualize_detailed_importance(self, detailed_df, aggregated_df, model_name, current_output_dir):
        """Save a paired PD/CAL bar chart and a two-row dental heatmap as PNG files.

        Args:
            detailed_df: Per-feature table with ``tooth_fdi``,
                ``measurement_type`` and ``mean_shap`` columns.
            aggregated_df: Per-tooth table with ``tooth_fdi`` and
                ``total_mean_shap``, sorted by importance.
            model_name: Used in titles and file names.
            current_output_dir: Directory (``Path``) receiving the PNG files.
        """
        print(f"Generating detailed visualizations for {model_name}...")
        # Try the newer style name first, then the older one; keep the
        # default style when neither is known to this matplotlib version.
        for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
            try:
                plt.style.use(style_name)
            except (OSError, ValueError):
                continue
            break

        # Paired Bar Chart (PD vs. CAL)
        top_teeth = aggregated_df.head(10)['tooth_fdi'].tolist()
        plot_data = detailed_df[detailed_df['tooth_fdi'].isin(top_teeth)]

        pivot_df = plot_data.pivot_table(index='tooth_fdi', columns='measurement_type', values='mean_shap')
        # Force both measurement columns to exist (a missing one becomes 0)
        # and order the rows by importance.
        pivot_df = pivot_df.reindex(index=top_teeth, columns=['PD', 'CAL']).fillna(0)

        fig, ax = plt.subplots(figsize=(18, 9))

        bar_width = 0.35
        index = np.arange(len(pivot_df.index))

        ax.bar(index - bar_width/2, pivot_df['PD'], bar_width, label='PD (Probing Depth)', color='royalblue')
        ax.bar(index + bar_width/2, pivot_df['CAL'], bar_width, label='CAL (Clinical Attachment Level)', color='skyblue')

        ax.set_ylabel('Robust Mean(|SHAP|) Value', fontsize=14)
        ax.set_xlabel('Tooth (FDI Notation)', fontsize=14)
        ax.set_title(f'PD vs. CAL Importance for Top 10 Teeth - {model_name}', fontsize=16, pad=20)
        ax.set_xticks(index)
        ax.set_xticklabels(pivot_df.index, rotation=45, ha="right")
        ax.legend()
        fig.tight_layout()

        output_path_paired = current_output_dir / f"{model_name}_paired_importance_barchart.png"
        fig.savefig(output_path_paired, dpi=800)
        plt.close(fig)
        print(f"Saved paired bar chart to: {output_path_paired}")

        # Dental Heatmap: FDI tooth number -> (row, column) cell of a 2 x 16 grid.
        FDI_LAYOUT = {'18':(0,0),'17':(0,1),'16':(0,2),'15':(0,3),'14':(0,4),'13':(0,5),'12':(0,6),'11':(0,7),
                    '21':(0,8),'22':(0,9),'23':(0,10),'24':(0,11),'25':(0,12),'26':(0,13),'27':(0,14),'28':(0,15),
                    '48':(1,0),'47':(1,1),'46':(1,2),'45':(1,3),'44':(1,4),'43':(1,5),'42':(1,6),'41':(1,7),
                    '31':(1,8),'32':(1,9),'33':(1,10),'34':(1,11),'35':(1,12),'36':(1,13),'37':(1,14),'38':(1,15)}
        heatmap_data = np.full((2, 16), np.nan)

        importance_dict = aggregated_df.set_index('tooth_fdi')['total_mean_shap'].to_dict()
        for fdi, pos in FDI_LAYOUT.items():
            if fdi in importance_dict:
                heatmap_data[pos] = importance_dict[fdi]

        fig = plt.figure(figsize=(16, 6))
        ax1 = fig.add_axes([0.05, 0.35, 0.9, 0.55])
        im = ax1.imshow(heatmap_data, cmap='Reds', interpolation='nearest', aspect='auto')
        cbar = fig.colorbar(im, ax=ax1, fraction=0.02, pad=0.04)
        cbar.set_label('Aggregated Robust Mean(|SHAP|)', rotation=270, labelpad=15)

        # Text switches to white on cells darker than half of the maximum;
        # the threshold is only needed when at least one cell is filled.
        half_max = np.nan if np.isnan(heatmap_data).all() else np.nanmax(heatmap_data) / 2
        for fdi, (row, col) in FDI_LAYOUT.items():
            val = heatmap_data[row, col]
            if not np.isnan(val):
                txt_color = 'white' if val > half_max else 'black'
                ax1.text(col, row - 0.15, fdi, ha='center', va='center', color=txt_color, weight='bold', fontsize=12)
                importance_str = f"{val:.3f}"
                ax1.text(col, row + 0.15, importance_str, ha='center', va='center', color=txt_color, fontsize=10)

        ax1.set_title(f'Aggregated Tooth Importance Heatmap for {model_name}', fontsize=16, pad=10)
        ax1.set_xticks([])
        ax1.set_yticks([])

        output_path_heatmap = current_output_dir / f"{model_name}_aggregated_importance_heatmap_with_values.png"
        fig.savefig(output_path_heatmap, dpi=800, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved heatmap with value table to: {output_path_heatmap}")

# ================== Model Evaluation Pipeline ==================
class ModelPipeline:
    """Tune, cross-validate and compare a model on several tooth-level feature sets.

    Attributes:
        random_seeds: Seeds; each one drives a shuffled ``StratifiedKFold``.
        cv_folds: Number of folds per seed.
        tuner: ``HyperparameterTuner`` used for the Bayesian search.
        best_params_cache: ``"<model>_<feature set>"`` -> tuned parameters.
    """

    def __init__(self, random_seeds=Config.RANDOM_SEEDS, cv_folds=Config.CV_FOLDS):
        self.random_seeds = random_seeds
        self.cv_folds = cv_folds
        self.tuner = HyperparameterTuner(cv_folds=cv_folds)
        self.best_params_cache = {}

    def _calculate_weighted_metrics(self, y_true, y_pred, y_pred_proba, sample_weight=None):
        """Return macro OvR AUC, accuracy, macro F1 and quadratic-weighted kappa."""
        auc_macro = roc_auc_score(y_true, y_pred_proba, multi_class='ovr', average='macro', sample_weight=sample_weight)
        accuracy = accuracy_score(y_true, y_pred, sample_weight=sample_weight)
        f1_macro = f1_score(y_true, y_pred, average='macro', sample_weight=sample_weight)
        qwk = cohen_kappa_score(y_true, y_pred, weights='quadratic', sample_weight=sample_weight)
        return {'auc_macro': auc_macro, 'accuracy': accuracy, 'f1_macro': f1_macro, 'qwk': qwk}

    def tune_hyperparameters_once(self, model_name, X, y, sample_weight, feature_set_name='Combined'):
        """Return tuned parameters for (model, feature set), searching at most once.

        Results are cached per ``"<model_name>_<feature_set_name>"``. When no
        search space is configured or scikit-optimize is missing, an empty
        dict (default parameters) is cached and returned.
        """
        param_key = f"{model_name}_{feature_set_name}"
        if param_key in self.best_params_cache:
            print(f"  Using cached hyperparameters for {param_key}")
            return self.best_params_cache[param_key]

        if model_name in Config.BAYESIAN_HYPERPARAMETER_SPACES and SKOPT_AVAILABLE:
            base_model = create_model(model_name)
            _, best_params = self.tuner.tune_model_bayes(
                model_name, base_model, X, y, sample_weight,
                Config.BAYESIAN_HYPERPARAMETER_SPACES[model_name]
            )
            del base_model
            gc.collect()
        else:
            print(f"  Skipping Bayesian tuning for {model_name} (scikit-optimize not available or not configured). Using default parameters.")
            best_params = {}

        self.best_params_cache[param_key] = best_params
        return best_params

    def evaluate_feature_sets_multiseed(self, X, y, sample_weights, model_name, best_params, stability_analyzer, current_output_dir):
        """Run the SHAP stability analysis, then tune and cross-validate each feature set.

        Args:
            X: Full feature DataFrame.
            y: Label Series aligned with ``X``.
            sample_weights: Optional weight Series aligned with ``X``.
            model_name: Backend name passed to ``create_model``.
            best_params: Parameters used for the stability analysis models.
            stability_analyzer: Object providing ``analyze_top_teeth_stability``.
            current_output_dir: Output directory handed to the stability analysis.

        Returns:
            ``(all_results, stability_results)`` where ``all_results`` is a
            list of per-metric summary dicts.
        """
        print(f"\n{'='*50}\nMULTI-SEED EVALUATION: {model_name}\n{'='*50}")
        all_results = []

        # SHAP stability analysis
        stability_results = stability_analyzer.analyze_top_teeth_stability(
            X, y, sample_weights, model_name, best_params, current_output_dir
        )

        # Create feature sets based on stability results
        if 'tooth_level_importance' in stability_results:
            consensus_df = self._create_consensus_feature_sets(X, stability_results, model_name).get('SHAP_Consensus', pd.DataFrame())
        else:
            consensus_df = pd.DataFrame()

        # Define feature sets
        feature_sets = {
            'Combined': X,
            'CPI': self._get_clinical_index_features(X, 'CPI'),
            'Ramfjord': self._get_clinical_index_features(X, 'Ramfjord'),
            f'SHAP_Consensus_{model_name}': consensus_df
        }

        # Evaluate each feature set
        roc_data = {}
        for fs_name, X_fs in feature_sets.items():
            if X_fs.empty:
                print(f"Skipping {fs_name} (no features)")
                continue

            print(f"\n>>> Tuning {model_name} / {fs_name} ...")
            # The model name is already part of the cache key, so strip it
            # from the consensus set's name.
            cache_name = 'SHAP_Consensus' if fs_name.startswith('SHAP_Consensus_') else fs_name
            best_params_fs = self.tune_hyperparameters_once(
                model_name, X_fs, y, sample_weights,
                feature_set_name=cache_name
            )

            print(f"\nEvaluating {fs_name} ({X_fs.shape[1]} features)...")
            metrics_across_seeds, roc_curves = self._evaluate_with_multiseed_cv(
                X_fs, y, sample_weights, model_name, best_params_fs
            )

            for metric, vals in metrics_across_seeds.items():
                all_results.append({
                    'Model': model_name,
                    'Feature_Set': fs_name,
                    'Metric': metric,
                    'Mean': np.mean(vals),
                    'Std': np.std(vals),
                    'N_Features': X_fs.shape[1]
                })

            roc_data[fs_name] = roc_curves

        # Every non-empty feature set contributed an entry above.
        if roc_data:
            self._plot_cv_roc_curves(roc_data, model_name)

        return all_results, stability_results

    def _plot_cv_roc_curves(self, roc_data, model_name):
        """Draw one panel per class with the fold-averaged ROC curve of each feature set."""
        class_labels = list(Config.CLASS_LABELS.values())
        colors = ['blue', 'red', 'green', 'orange', 'purple']

        # figsize is (width, height): give each side-by-side panel its own width.
        fig, axes = plt.subplots(1, len(class_labels), figsize=(6 * len(class_labels), 5))
        axes = np.atleast_1d(axes).ravel()

        for ax, class_label in zip(axes, class_labels):
            for color_idx, (fs_name, curves) in enumerate(roc_data.items()):
                fpr_tpr_list = curves[class_label]
                # Interpolate every fold's curve onto the union of all FPR
                # grid points, then average the TPR values.
                all_fpr = np.unique(np.concatenate([fpr for fpr, _ in fpr_tpr_list]))
                mean_tpr = np.zeros_like(all_fpr)
                for fpr, tpr in fpr_tpr_list:
                    mean_tpr += np.interp(all_fpr, fpr, tpr)
                mean_tpr /= len(fpr_tpr_list)
                mean_auc = auc(all_fpr, mean_tpr)

                display_name = fs_name
                if fs_name.startswith('SHAP_Consensus_'):
                    display_name = f"SHAP_{model_name}"

                ax.plot(all_fpr, mean_tpr,
                        label=f"{display_name} (AUC={mean_auc:.2f})",
                        linewidth=2,
                        color=colors[color_idx % len(colors)])

            ax.plot([0,1],[0,1],'--', color='gray', alpha=0.5)
            ax.set_title(f"{class_label} - {model_name}")
            ax.set_xlabel("False Positive Rate")
            ax.set_ylabel("True Positive Rate")
            ax.legend(loc="lower right", fontsize='small')

        fig.suptitle(f"{model_name} - Cross-validation ROC-AUC (Training Set)", fontsize=16)
        fig.tight_layout(rect=[0,0,1,0.96])
        # NOTE(review): saved under Config.OUTPUT_DIR rather than the
        # per-call output directory used for the SHAP artefacts - confirm
        # this is intended.
        out_path = Config.OUTPUT_DIR / f"roc_group_{model_name}_train.png"
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved grouped ROC figure: {out_path}")

    @staticmethod
    def _features_for_fdi_teeth(X, fdi_numbers):
        """Return the ``OHX##PC_max``/``OHX##LA_max`` columns of ``X`` for the given FDI teeth."""
        # Reverse the mapping once; setdefault keeps the first match per FDI number.
        fdi_to_nhanes = {}
        for nhanes, fdi in Config.NHANES_TO_FDI_MAPPING.items():
            fdi_to_nhanes.setdefault(fdi, nhanes)

        selected = []
        for fdi_num in fdi_numbers:
            nhanes_num = fdi_to_nhanes.get(fdi_num)
            if not nhanes_num:
                continue
            for p_type in ('PC', 'LA'):
                feature = f'OHX{nhanes_num}{p_type}_max'
                if feature in X.columns:
                    selected.append(feature)
        return selected

    def _create_consensus_feature_sets(self, X_full, stability_results, model_name):
        """Build the SHAP consensus feature set plus the CPI and Ramfjord reference sets.

        The consensus set holds the PD and CAL features of the
        ``Config.TOP_FEATURES_COUNT`` highest-ranked teeth; it is omitted when
        no ranking is available or none of its features exist in ``X_full``.
        """
        feature_sets = {}

        if 'tooth_level_importance' in stability_results:
            tooth_df = stability_results['tooth_level_importance']
            top_teeth_fdi = tooth_df.head(Config.TOP_FEATURES_COUNT)['tooth_fdi'].tolist()

            shap_features = self._features_for_fdi_teeth(X_full, top_teeth_fdi)
            if shap_features:
                key = 'SHAP_Consensus'
                feature_sets[key] = X_full[shap_features]
                print(f"  Created {key} from robust ranking with teeth: {sorted(top_teeth_fdi)}")

        feature_sets['CPI_Reference'] = self._get_clinical_index_features(X_full, 'CPI')
        feature_sets['Ramfjord_Reference'] = self._get_clinical_index_features(X_full, 'Ramfjord')

        return feature_sets

    def _get_clinical_index_features(self, X_combined, method_name):
        """Return the columns of ``X_combined`` for the teeth of a reference index.

        ``method_name`` must be a key of ``Config.CPI_RAMFJORD_TOOTH_NUMBERS``.
        An empty DataFrame is returned when none of the features exist.
        """
        fdi_numbers = Config.CPI_RAMFJORD_TOOTH_NUMBERS[method_name]
        selected_features = self._features_for_fdi_teeth(X_combined, fdi_numbers)
        return X_combined[selected_features] if selected_features else pd.DataFrame()

    def _evaluate_with_multiseed_cv(self, X, y, sample_weights, model_name, best_params):
        """Cross-validate the model once per seed and collect metrics and ROC curves.

        Returns:
            ``(all_metrics, roc_curves)``: metric name -> list with one value
            per (seed, fold), and class label -> list of ``(fpr, tpr)`` pairs.
        """
        all_metrics = defaultdict(list)
        roc_curves = {cls: [] for cls in Config.CLASS_LABELS.values()}

        classes = list(Config.CLASS_LABELS.keys())
        y_bin_full = label_binarize(y, classes=classes)

        for seed in self.random_seeds:
            cv = StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=seed)
            for train_idx, val_idx in cv.split(X, y):
                X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
                y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
                w_train = sample_weights.iloc[train_idx] if sample_weights is not None else None
                w_val = sample_weights.iloc[val_idx] if sample_weights is not None else None

                model = create_model(model_name, random_state=seed)
                model.set_params(**best_params)

                # Balanced class weights, multiplied by the caller's weights
                # when those are available.
                fit_weights = compute_sample_weight(class_weight='balanced', y=y_train)
                if w_train is not None:
                    fit_weights = w_train * fit_weights

                model.fit(X_train, y_train, sample_weight=fit_weights)

                y_score = model.predict_proba(X_val)
                y_val_bin = y_bin_full[val_idx, :]

                # Column i of the scores is paired with the i-th class label.
                for i, class_name in enumerate(Config.CLASS_LABELS.values()):
                    fpr_i, tpr_i, _ = roc_curve(
                        y_val_bin[:, i],
                        y_score[:, i],
                        sample_weight=w_val
                    )
                    roc_curves[class_name].append((fpr_i, tpr_i))

                metrics = self._calculate_weighted_metrics(
                    y_val, model.predict(X_val), y_score, w_val
                )
                for name, v in metrics.items():
                    all_metrics[name].append(v)

        return all_metrics, roc_curves

def plot_separate_test_rocs(all_test_predictions, model_names):
    """Plot per-class test-set ROC curves, one figure per model.

    For each model a single row of panels is drawn, one panel per entry of
    ``Config.CLASS_LABELS`` (one-vs-rest). Every feature set evaluated for that
    model contributes one curve per panel. The figure is written to
    ``Config.OUTPUT_DIR / f"roc_test_{model_name}_by_severity.png"``.

    Args:
        all_test_predictions: List of dicts with keys ``'model_name'``,
            ``'feature_set'``, ``'y_score'``, ``'y_true'`` and optionally
            ``'sample_weight'``. ``y_score`` is indexed as ``[:, class_position]``;
            its columns are assumed to follow the order of
            ``Config.CLASS_LABELS`` keys -- TODO confirm against the models'
            ``classes_``.
        model_names: Names of the models to plot. Models without any matching
            entry in ``all_test_predictions`` are skipped.

    Returns:
        None. The only outputs are the saved figures and a progress message.
    """
    palette = ['blue', 'red', 'green', 'orange', 'purple']

    for model_name in model_names:
        model_entries = [e for e in all_test_predictions if e['model_name'] == model_name]
        if not model_entries:
            continue

        classes = list(Config.CLASS_LABELS.keys())
        class_labels = list(Config.CLASS_LABELS.values())
        n_classes = len(class_labels)

        # Binarize each entry's labels once; the result does not depend on
        # which class panel is being drawn.
        binarized_truth = [
            label_binarize(entry['y_true'], classes=classes)
            for entry in model_entries
        ]

        # figsize is (width, height): a single row of panels has to be wide,
        # not tall (4 inches per panel, i.e. 12x4 for three classes).
        # squeeze=False keeps `axes` an array for any number of panels.
        fig, axes = plt.subplots(
            1, n_classes, figsize=(4 * n_classes, 4), squeeze=False
        )
        axes = axes.flatten()

        for idx, class_label in enumerate(class_labels):
            ax = axes[idx]

            for entry_no, (entry, y_true_bin) in enumerate(zip(model_entries, binarized_truth)):
                feature_set = entry['feature_set']

                fpr, tpr, _ = roc_curve(
                    y_true_bin[:, idx],
                    entry['y_score'][:, idx],
                    sample_weight=entry.get('sample_weight', None)
                )
                roc_auc = auc(fpr, tpr)

                # Drop the per-model suffix from the consensus set in the legend.
                if feature_set.startswith('SHAP_Consensus_'):
                    display_name = "SHAP_Consensus"
                else:
                    display_name = feature_set

                ax.plot(fpr, tpr,
                        label=f"{display_name} (AUC={roc_auc:.2f})",
                        linewidth=2,
                        color=palette[entry_no % len(palette)])

            # Chance-level diagonal for reference.
            ax.plot([0, 1], [0, 1], '--', color='gray', alpha=0.5)
            ax.set_title(class_label, fontsize=14)
            ax.set_xlabel("False Positive Rate")
            ax.set_ylabel("True Positive Rate")
            ax.legend(fontsize='small', loc='lower right')

        fig.suptitle(f"{model_name} - Test Set ROC Curves by Disease Severity", fontsize=16)
        # Leave the top 4% of the canvas free for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        output_path = Config.OUTPUT_DIR / f"roc_test_{model_name}_by_severity.png"
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved test ROC figure for {model_name}: {output_path}")

def evaluate_on_test_set(pipeline, stability_analyzer, model_names, X_train_full, X_test, y_train_full, y_test, sw_train_full, sw_test):
    """Refit every model on the full training split and score it on the test split.

    For each model the candidate feature sets are the full feature frame
    ('Combined'), the 'CPI' and 'Ramfjord' subsets returned by the pipeline,
    and -- when the analyzer holds a 'tooth_level_importance' entry for the
    model -- the SHAP consensus subset. Empty feature sets are skipped.

    Returns:
        A pair ``(final_test_results, all_test_predictions)``:
        the first is a list of one dict per (model, feature set, metric) with
        keys 'Model', 'Feature_Set', 'Metric', 'Test_Score', 'N_Features';
        the second is a list of one dict per (model, feature set) holding the
        predicted probabilities, the true labels and the test sample weights.
    """
    score_rows = []
    prediction_log = []

    for model_name in model_names:
        model_stability = stability_analyzer.stability_results.get(model_name, {})

        # Assemble the candidate training frames for this model.
        candidate_sets = {'Combined': X_train_full}
        for index_name in ('CPI', 'Ramfjord'):
            candidate_sets[index_name] = pipeline._get_clinical_index_features(X_train_full, index_name)

        # The consensus set stays an empty frame (and is skipped below) when
        # no tooth-level ranking is available for this model.
        if 'tooth_level_importance' in model_stability:
            consensus_sets = pipeline._create_consensus_feature_sets(
                X_train_full, model_stability, model_name
            )
            shap_frame = consensus_sets.get('SHAP_Consensus', pd.DataFrame())
        else:
            shap_frame = pd.DataFrame()
        candidate_sets[f'SHAP_Consensus_{model_name}'] = shap_frame

        for set_name, train_frame in candidate_sets.items():
            if train_frame.empty:
                continue

            print(f"\nEvaluating {model_name} / {set_name} on test set...")

            # Tuned parameters are cached under the model-agnostic set name.
            if set_name.startswith('SHAP_Consensus_'):
                cached_name = 'SHAP_Consensus'
            else:
                cached_name = set_name
            tuned_params = pipeline.best_params_cache.get(f"{model_name}_{cached_name}", {})

            # Align test columns to the training frame; missing values become 0.
            test_frame = X_test.reindex(columns=train_frame.columns)
            nan_per_column = test_frame.isna().sum()
            if (nan_per_column > 0).any():
                missing_info = nan_per_column[nan_per_column > 0].to_dict()
                print(f"  Warning: Test set for {model_name}/{set_name} has NaNs in columns (filling with 0): {missing_info}")
                test_frame = test_frame.fillna(0)

            # Final model for this feature set.
            estimator = create_model(model_name, random_state=42)
            estimator.set_params(**tuned_params)

            # Balanced class weights, scaled by the survey weights when given.
            balance_weights = compute_sample_weight(class_weight='balanced', y=y_train_full)
            if sw_train_full is None:
                fit_weights = balance_weights
            else:
                fit_weights = sw_train_full * balance_weights

            estimator.fit(train_frame, y_train_full, sample_weight=fit_weights)

            predicted_labels = estimator.predict(test_frame)
            predicted_proba = estimator.predict_proba(test_frame)

            test_metrics = pipeline._calculate_weighted_metrics(
                y_test, predicted_labels, predicted_proba, sw_test
            )
            n_features = train_frame.shape[1]
            score_rows.extend(
                {
                    'Model': model_name,
                    'Feature_Set': set_name,
                    'Metric': metric_name,
                    'Test_Score': metric_value,
                    'N_Features': n_features
                }
                for metric_name, metric_value in test_metrics.items()
            )

            headline = (
                f"AUC={test_metrics['auc_macro']:.3f}, "
                f"Acc={test_metrics['accuracy']:.3f}, "
                f"QWK={test_metrics['qwk']:.3f}"
            )
            print(f"  {set_name} ({n_features} features): " + headline)

            prediction_log.append({
                'model_name': model_name,
                'feature_set': set_name,
                'y_score': predicted_proba,
                'y_true': y_test,
                'sample_weight': sw_test
            })

    return score_rows, prediction_log

# ================== Main Execution ==================
def main():
    """Run the multi-seed teeth stability analysis end to end.

    Steps, in order:
      1. Load and preprocess ``Config.DATA_PATH``; stop early when neither
         XGBoost nor LightGBM is installed or the loaded frame is empty.
      2. Make one stratified 80/20 train/test split (``random_state=42``) and
         save the full test rows to ``test_set_full.csv``.
      3. For every available model, tune hyperparameters once and run the
         multi-seed feature-set evaluation on the training split.
      4. Refit on the training split, score on the test split, and write the
         result CSVs and ROC figures into ``Config.OUTPUT_DIR``.
    """
    print("="*80 + "\nMULTI-SEED TEETH STABILITY ANALYSIS (SHAP-FOCUSED, v8)\n" + "="*80)

    # Check for required libraries
    if not XGB_AVAILABLE and not LGB_AVAILABLE:
        print("ERROR: Neither XGBoost nor LightGBM is available. Please install at least one.")
        return

    data_processor = DataProcessor()
    df = data_processor.load_and_preprocess(Config.DATA_PATH)

    if df.empty:
        print("Execution halted due to missing data file.")
        return

    X_combined_unscaled, y, sample_weights = data_processor.get_combined_feature_set(df)

    X_combined_scaled = X_combined_unscaled  # no scaling for tree models

    # Every artifact below is written into OUTPUT_DIR; make sure it exists
    # before the first write.
    Config.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("\n" + "="*60 + "\nDATA SPLITTING\n" + "="*60)
    (X_train_full, X_test,
     y_train_full, y_test,
     sw_train_full, sw_test) = _split_train_test(X_combined_scaled, y, sample_weights)

    _save_full_test_set(df, X_test.index)

    print("Dataset sizes:")
    print(f"  Training: {len(X_train_full):,} samples ({len(X_train_full)/len(X_combined_scaled)*100:.1f}%)")
    print(f"  Test:     {len(X_test):,} samples ({len(X_test)/len(X_combined_scaled)*100:.1f}%)")

    print("\nClass distribution:")
    for split_name, y_split in [("Training", y_train_full), ("Test", y_test)]:
        dist = y_split.value_counts().sort_index()
        dist_pct = (dist / len(y_split) * 100).round(1)
        dist_summary = {
            Config.CLASS_LABELS[i]: f'{dist[i]} ({dist_pct[i]}%)' for i in dist.index
        }
        print(f"  {split_name}: {dist_summary}")

    print("\nCross-validation setup:")
    print(f"  Method: {Config.CV_FOLDS}-fold stratified cross-validation")
    print(f"  Seeds: {len(Config.RANDOM_SEEDS)} different random seeds")
    print(f"  Total CV runs: {Config.CV_FOLDS * len(Config.RANDOM_SEEDS)} per feature set per model")

    model_names = [name for name, available in [('XGBoost', XGB_AVAILABLE), ('LightGBM', LGB_AVAILABLE)] if available]
    print(f"Available models: {model_names}")

    pipeline = ModelPipeline()
    stability_analyzer = MultiSeedStabilityAnalyzer(data_processor)

    all_results = []
    all_stability_results = {}

    for model_name in model_names:
        print(f"\n{'='*60}\nPROCESSING MODEL: {model_name}\n{'='*60}")
        best_params = pipeline.tune_hyperparameters_once(
            model_name, X_train_full, y_train_full, sw_train_full
        )

        model_results, stability_results = pipeline.evaluate_feature_sets_multiseed(
            X_train_full, y_train_full, sw_train_full, model_name, best_params, stability_analyzer, Config.OUTPUT_DIR
        )

        all_results.extend(model_results)
        all_stability_results[model_name] = stability_results

    # evaluate_on_test_set reads stability_analyzer.stability_results to build
    # the SHAP consensus feature set, so the analyzer has to hold the collected
    # results *before* the test evaluation runs. Assigning them only afterwards
    # would skip the consensus set unless the analyzer had been filled elsewhere.
    stability_analyzer.stability_results = all_stability_results

    # Final test set evaluation
    final_test_results, all_test_predictions = evaluate_on_test_set(
        pipeline, stability_analyzer,
        model_names,
        X_train_full, X_test,
        y_train_full, y_test,
        sw_train_full, sw_test
    )

    # Save cross-validation results
    results_df = pd.DataFrame(all_results)
    results_df.to_csv(Config.OUTPUT_DIR / "cross_validation_results.csv", index=False)
    print("Saved: cross_validation_results.csv")

    # Plot combined ROC curves
    if all_test_predictions:
        plot_separate_test_rocs(all_test_predictions, model_names)

    # Build the test-result frame once; None marks "no test results".
    test_results_df = pd.DataFrame(final_test_results) if final_test_results else None

    if test_results_df is not None:
        test_results_df.to_csv(Config.OUTPUT_DIR / "final_test_results.csv", index=False)
        print("Saved: final_test_results.csv")
        _save_cv_vs_test_comparison(results_df, test_results_df)

    _save_consolidated_summary(all_stability_results, pipeline, test_results_df)

    print("\n" + "="*80 + "\nANALYSIS COMPLETE\n" + "="*80)
    print(f"All results, summaries, and plots have been saved to: {Config.OUTPUT_DIR}")
    print("\nKey outputs:")
    print("- cross_validation_results.csv: Multi-seed cross-validation performance (on training set)")
    print("- final_test_results.csv: Unbiased test set performance")
    print("- cv_vs_test_comparison.csv: Cross-validation vs test performance comparison")
    print("- final_consolidated_summary.csv: Complete summary with test scores")
    print("\nData usage summary:")
    print("- Training set (80%): Used for hyperparameter tuning, cross-validation, and SHAP analysis")
    print("- Test set (20%): Used only for final unbiased evaluation")


def _split_train_test(X, y, sample_weights):
    """Make one stratified 80/20 split of features, labels and optional weights.

    All arrays go through a single ``train_test_split`` call, so they share the
    same row partition by construction instead of relying on two separate calls
    happening to agree.

    Returns:
        ``(X_train, X_test, y_train, y_test, sw_train, sw_test)``; the two
        weight entries are ``None`` when ``sample_weights`` is ``None``.
    """
    arrays = [X, y]
    if sample_weights is not None:
        arrays.append(sample_weights)

    parts = list(train_test_split(*arrays, test_size=0.2, random_state=42, stratify=y))
    if sample_weights is None:
        parts += [None, None]
    return tuple(parts)


def _save_full_test_set(df, test_index):
    """Write the complete test-set rows of ``df`` to ``test_set_full.csv``.

    A human-readable ``label_name`` column is added, and the target, label
    name, weight and age columns (those that exist) are moved to the front.
    """
    test_full = df.loc[test_index].copy()
    test_full["label_name"] = test_full[Config.TARGET_COLUMN].map(Config.CLASS_LABELS)

    front = [Config.TARGET_COLUMN, "label_name", Config.WEIGHT_COLUMN, Config.AGE_COLUMN]
    front_exist = [c for c in front if c in test_full.columns]
    test_full = test_full[front_exist + [c for c in test_full.columns if c not in front_exist]]

    test_csv_path = Config.OUTPUT_DIR / "test_set_full.csv"
    test_full.to_csv(test_csv_path, index=False)
    print(f"Full test set saved to: {test_csv_path}")


def _save_cv_vs_test_comparison(results_df, test_results_df):
    """Pair CV means with test scores per (Model, Feature_Set, Metric) and save them.

    ``Difference`` is test minus CV; ``Overfitting`` flags rows where the test
    score is more than 0.01 below the CV score.
    """
    # NOTE(review): test rows label the consensus set 'SHAP_Consensus_<model>'.
    # If the CV rows use a different label, those rows will not pair up in the
    # outer merge and their Difference will be NaN -- confirm the CV naming.
    keys = ['Model', 'Feature_Set', 'Metric']
    cv_summary = results_df.groupby(keys)['Mean'].first().reset_index()
    cv_summary = cv_summary.rename(columns={'Mean': 'CV_Score'})
    comparison_df = pd.merge(cv_summary, test_results_df, on=keys, how='outer')
    comparison_df['Difference'] = comparison_df['Test_Score'] - comparison_df['CV_Score']
    comparison_df['Overfitting'] = comparison_df['Difference'] < -0.01
    comparison_df.to_csv(Config.OUTPUT_DIR / "cv_vs_test_comparison.csv", index=False)
    print("Saved: cv_vs_test_comparison.csv")


def _test_scores_for(test_results_df, model_name, feature_set, prefix):
    """Return ``{prefix + metric: 'score'}`` for one model / feature set.

    Scores are formatted with three decimals. An empty dict is returned when
    there are no test results at all (``test_results_df`` is ``None``).
    """
    if test_results_df is None:
        return {}
    selected = test_results_df[
        (test_results_df['Model'] == model_name)
        & (test_results_df['Feature_Set'] == feature_set)
    ]
    return {
        f"{prefix}{metric}": f"{score:.3f}"
        for metric, score in zip(selected['Metric'], selected['Test_Score'])
    }


def _save_consolidated_summary(all_stability_results, pipeline, test_results_df):
    """Write one summary row per model to ``final_consolidated_summary.csv``.

    Each row holds the top ``Config.TOP_FEATURES_COUNT`` teeth from the model's
    tooth ranking, the cached 'Combined' hyperparameters, and the test scores
    of the 'Combined' and SHAP consensus feature sets. Models without a
    'tooth_level_importance' entry are left out; nothing is written when no
    model qualifies.
    """
    summary_rows = []
    for model_name, stability in all_stability_results.items():
        if 'tooth_level_importance' not in stability:
            continue

        tooth_ranking_df = stability['tooth_level_importance']
        consensus_teeth = tooth_ranking_df.head(Config.TOP_FEATURES_COUNT)['tooth_fdi'].tolist()

        summary_row = {
            'Model': model_name,
            'Consensus_Teeth_FDI': ', '.join(map(str, sorted(consensus_teeth))) if consensus_teeth else 'None',
            'N_Consensus_Teeth': len(consensus_teeth)
        }
        summary_row.update(pipeline.best_params_cache.get(f"{model_name}_Combined", {}))
        summary_row.update(_test_scores_for(test_results_df, model_name, 'Combined', 'Test_'))
        summary_row.update(
            _test_scores_for(test_results_df, model_name, f'SHAP_Consensus_{model_name}', 'SHAP_Test_')
        )
        summary_rows.append(summary_row)

    if not summary_rows:
        return

    final_summary_df = pd.DataFrame(summary_rows).fillna('N/A')
    # Column order: identity columns, then test scores, then hyperparameters.
    base_cols = ['Model', 'Consensus_Teeth_FDI', 'N_Consensus_Teeth']
    test_cols = sorted([c for c in final_summary_df.columns if c.startswith('Test_') or c.startswith('SHAP_Test_')])
    param_cols = sorted([c for c in final_summary_df.columns if c not in base_cols + test_cols])
    final_summary_df = final_summary_df[base_cols + test_cols + param_cols]
    final_summary_df.to_csv(Config.OUTPUT_DIR / "final_consolidated_summary.csv", index=False)
    print("Saved: final_consolidated_summary.csv")

# Run the full analysis only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()