In [1]:
"""
Advanced Multiclass Classification Pipeline for tfopwg_disp
============================================================
This module implements a production-ready machine learning solution for
multiclass classification with comprehensive data processing, feature engineering,
advanced modeling, and detailed evaluation.

Author: ML Pipeline
Date: October 4, 2025
"""

import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from datetime import datetime

# Preprocessing
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.preprocessing import StandardScaler, RobustScaler, LabelEncoder
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.feature_selection import SelectKBest, mutual_info_classif

# Models
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
import xgboost as xgb
import lightgbm as lgb

# Metrics
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    precision_recall_fscore_support, roc_auc_score, roc_curve,
    matthews_corrcoef, cohen_kappa_score, log_loss
)

# Utilities
import joblib
from scipy import stats
from collections import Counter

# Set random seed for reproducibility
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

# Styling
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")


class DataProcessor:
    """Load the TOI CSV, report on its quality, and split it into features/target.

    Attributes:
        filepath: Path of the CSV file read by ``load_data``.
        df: The loaded DataFrame (``None`` until ``load_data`` has run).
        target_col: Name of the label column (``'tfopwg_disp'``).
        label_encoder: ``LabelEncoder`` that ``prepare_features_target`` fits
            on the target column.
    """

    def __init__(self, filepath):
        self.filepath = filepath
        self.df = None
        self.target_col = 'tfopwg_disp'
        self.label_encoder = LabelEncoder()

    def load_data(self):
        """Read ``self.filepath`` into ``self.df`` and print the target's class balance.

        Returns:
            The loaded DataFrame (the same object stored on ``self.df``).
        """
        rule = "=" * 70
        for line in (rule, "LOADING DATA", rule):
            print(line)

        self.df = pd.read_csv(self.filepath)
        n_rows, n_cols = self.df.shape
        print(f"\n✓ Data loaded successfully: {n_rows} rows, {n_cols} columns")

        # Absolute counts first, then the same distribution as proportions.
        target = self.df[self.target_col]
        print(f"\n📊 Target Variable Distribution ({self.target_col}):")
        print(target.value_counts())
        print("\nTarget Class Proportions:")
        print(target.value_counts(normalize=True).round(4))

        return self.df

    def analyze_data_quality(self):
        """Print missing-value, dtype and numeric/categorical summaries of ``self.df``.

        Returns:
            Tuple ``(missing_df, numeric_cols, categorical_cols)``:
            ``missing_df`` has one row per column with ``Missing_Count`` and
            ``Missing_Percentage`` (sorted by percentage, descending);
            ``numeric_cols`` / ``categorical_cols`` split the column names by
            whether their dtype is a numpy number.
        """
        rule = "=" * 70
        print("\n" + rule)
        print("DATA QUALITY ANALYSIS")
        print(rule)

        n_missing = self.df.isnull().sum()
        pct_missing = (n_missing / len(self.df) * 100).round(2)
        missing_df = pd.DataFrame(
            {'Missing_Count': n_missing, 'Missing_Percentage': pct_missing}
        ).sort_values('Missing_Percentage', ascending=False)

        # Only list columns that actually have gaps (at most 20 of them).
        print("\n📋 Columns with Missing Values:")
        only_missing = missing_df.loc[missing_df['Missing_Count'] > 0]
        print(only_missing.head(20))

        print("\n📝 Data Types:")
        print(self.df.dtypes.value_counts())

        numeric_cols = self.df.select_dtypes(include=[np.number]).columns.tolist()
        categorical_cols = self.df.select_dtypes(exclude=[np.number]).columns.tolist()

        print(f"\n✓ Numeric columns: {len(numeric_cols)}")
        print(f"✓ Categorical columns: {len(categorical_cols)}")

        return missing_df, numeric_cols, categorical_cols

    def prepare_features_target(self):
        """Split ``self.df`` into a feature frame and a label-encoded target.

        Identifier, coordinate-string and timestamp columns are excluded from
        the features together with the target itself.

        Returns:
            Tuple ``(X, y_encoded, feature_cols)`` where ``X`` is a copy of the
            feature columns, ``y_encoded`` is the integer-encoded target, and
            ``feature_cols`` is the list of retained column names.
        """
        rule = "=" * 70
        print("\n" + rule)
        print("PREPARING FEATURES AND TARGET")
        print(rule)

        non_predictive = ('toi', 'tid', 'rastr', 'decstr', 'toi_created', 'rowupdate',
                          self.target_col)
        feature_cols = [name for name in self.df.columns if name not in non_predictive]

        X = self.df[feature_cols].copy()
        raw_target = self.df[self.target_col].copy()
        y_encoded = self.label_encoder.fit_transform(raw_target)

        classes = self.label_encoder.classes_
        print(f"\n✓ Features shape: {X.shape}")
        print(f"✓ Target shape: {y_encoded.shape}")
        print(f"✓ Number of classes: {len(classes)}")
        print(f"✓ Classes: {list(classes)}")

        return X, y_encoded, feature_cols


class FeatureEngineer:
    """Advanced feature engineering and selection.

    Holds the working feature matrix ``X`` and target ``y`` and replaces them
    as each step runs (imputation, feature creation, outlier removal, scaling,
    selection). Fitted transformers are kept on the instance (``imputer``,
    ``scaler``) and the chosen column names on ``selected_features`` so they
    can be persisted afterwards.
    """

    def __init__(self, X, y):
        # X: feature DataFrame; y: labels aligned row-for-row with X. y must
        # support boolean-mask indexing (e.g. a numpy array or pandas Series).
        self.X = X
        self.y = y
        self.imputer = None  # set by handle_missing_values when imputation runs
        self.scaler = None  # set by scale_features
        self.selected_features = None  # set by select_features

    def handle_missing_values(self, strategy='advanced'):
        """Drop near-empty columns, then impute the remaining missing values.

        Columns with more than 95% missing values are removed first. What is
        left is imputed with a distance-weighted ``KNNImputer(n_neighbors=5)``
        when ``strategy == 'advanced'``, and with a median ``SimpleImputer``
        otherwise. The fitted imputer is stored on ``self.imputer``; it stays
        ``None`` when nothing needed imputing.

        Returns:
            The updated ``self.X``.
        """
        print("\n" + "="*70)
        print("HANDLING MISSING VALUES")
        print("="*70)

        missing_cols = self.X.columns[self.X.isnull().any()].tolist()

        if len(missing_cols) == 0:
            print("\n✓ No missing values detected")
            return self.X

        # Columns that are almost entirely empty carry too little signal to
        # impute meaningfully, so drop them (>95% missing).
        missing_pct = (self.X.isnull().sum() / len(self.X) * 100)
        cols_to_drop = missing_pct[missing_pct > 95].index.tolist()

        if len(cols_to_drop) > 0:
            print(f"\n🗑️  Dropping {len(cols_to_drop)} columns with >95% missing values:")
            for col in cols_to_drop:
                print(f"  • {col} ({missing_pct[col]:.1f}% missing)")
            self.X = self.X.drop(columns=cols_to_drop)
            missing_cols = self.X.columns[self.X.isnull().any()].tolist()

        if len(missing_cols) == 0:
            print("\n✓ No missing values remaining after dropping high-missing columns")
            return self.X

        if strategy == 'advanced':
            print(f"\n🔧 Applying KNN Imputation (k=5) to {len(missing_cols)} columns...")
            self.imputer = KNNImputer(n_neighbors=5, weights='distance')
        else:
            print(f"\n🔧 Applying Median Imputation to {len(missing_cols)} columns...")
            self.imputer = SimpleImputer(strategy='median')

        # The imputer returns an ndarray; restore column names and row index.
        X_imputed = self.imputer.fit_transform(self.X)
        self.X = pd.DataFrame(X_imputed, columns=self.X.columns, index=self.X.index)

        print(f"✓ Missing values handled successfully")
        print(f"✓ Final feature count: {self.X.shape[1]}")

        return self.X

    def create_features(self):
        """Append domain-inspired derived features to ``self.X``.

        Each feature is added only when its source columns are present.
        Infinite values produced by the transforms are turned into NaN. Any
        NaN (including those from log1p/sqrt of out-of-domain inputs) is then
        filled with its column's median, or with 0 when the column has no
        finite values at all.

        Returns:
            The updated ``self.X``.
        """
        print("\n" + "="*70)
        print("FEATURE ENGINEERING")
        print("="*70)

        X_new = self.X.copy()

        # 1. Stellar luminosity proxy (~ R^2 * T^4)
        if 'st_rad' in X_new.columns and 'st_teff' in X_new.columns:
            X_new['luminosity_proxy'] = X_new['st_rad']**2 * X_new['st_teff']**4
            print("✓ Created: luminosity_proxy")

        # 2. Ratio features (1e-10 guards against division by zero)
        if 'pl_rade' in X_new.columns and 'st_rad' in X_new.columns:
            X_new['planet_star_radius_ratio'] = X_new['pl_rade'] / (X_new['st_rad'] + 1e-10)
            print("✓ Created: planet_star_radius_ratio")

        # 3. Transit depth proxy (~ radius ratio squared)
        if 'pl_rade' in X_new.columns and 'st_rad' in X_new.columns:
            X_new['transit_depth_proxy'] = (X_new['pl_rade'] / (X_new['st_rad'] + 1e-10))**2
            print("✓ Created: transit_depth_proxy")

        # 4. Distance features
        if 'st_dist' in X_new.columns:
            X_new['log_distance'] = np.log1p(X_new['st_dist'])
            X_new['inv_distance'] = 1 / (X_new['st_dist'] + 1e-10)
            print("✓ Created: log_distance, inv_distance")

        # 5. Temperature ratio
        if 'pl_eqt' in X_new.columns and 'st_teff' in X_new.columns:
            X_new['temp_ratio'] = X_new['pl_eqt'] / (X_new['st_teff'] + 1e-10)
            print("✓ Created: temp_ratio")

        # 6. Insolation features
        if 'pl_insol' in X_new.columns:
            X_new['log_insol'] = np.log1p(X_new['pl_insol'])
            X_new['sqrt_insol'] = np.sqrt(X_new['pl_insol'] + 1e-10)
            print("✓ Created: log_insol, sqrt_insol")

        # 7. Orbital period features
        if 'pl_orbper' in X_new.columns:
            X_new['log_orbper'] = np.log1p(X_new['pl_orbper'])
            print("✓ Created: log_orbper")

        # 8. Relative uncertainty of selected measurements (|error| / |value|)
        error_pairs = [
            ('st_pmra', 'st_pmraerr1'),
            ('st_pmdec', 'st_pmdecerr1'),
            ('pl_rade', 'pl_radeerr1'),
        ]

        for base, error in error_pairs:
            if base in X_new.columns and error in X_new.columns:
                X_new[f'{base}_error_ratio'] = abs(X_new[error]) / (abs(X_new[base]) + 1e-10)
                print(f"✓ Created: {base}_error_ratio")

        # 9. Row-wise statistics over the first 10 numeric columns. Engineered
        # columns are appended at the end, so in practice these are original
        # columns.
        numeric_cols = X_new.select_dtypes(include=[np.number]).columns[:10]
        if len(numeric_cols) > 3:
            X_new['feature_mean'] = X_new[numeric_cols].mean(axis=1)
            X_new['feature_std'] = X_new[numeric_cols].std(axis=1)
            X_new['feature_max'] = X_new[numeric_cols].max(axis=1)
            X_new['feature_min'] = X_new[numeric_cols].min(axis=1)
            print("✓ Created: aggregated statistical features")

        # Treat infinities as missing so the fill below handles them.
        X_new = X_new.replace([np.inf, -np.inf], np.nan)

        # Assign the filled column back instead of calling
        # ``X_new[col].fillna(..., inplace=True)``. That is chained assignment,
        # which does not write through under pandas Copy-on-Write.
        for col in X_new.columns[X_new.isnull().any()]:
            fill_value = X_new[col].median()
            if pd.isna(fill_value):
                # Column has no observed values at all: there is no median.
                fill_value = 0.0
            X_new[col] = X_new[col].fillna(fill_value)

        print(f"\n✓ Total features after engineering: {X_new.shape[1]}")

        self.X = X_new
        return self.X

    def remove_outliers(self, contamination=0.05, iqr_multiplier=3.0):
        """Drop rows that fall far outside the interquartile range of any column.

        A row is removed when, in at least one column, its value is below
        ``Q1 - iqr_multiplier * IQR`` or above ``Q3 + iqr_multiplier * IQR``.
        Columns whose IQR is 0 (e.g. mostly-constant flag columns) are not
        used for the test. With zero spread, every value different from the
        quartile would count as an outlier, which would discard whole
        sub-populations. NaN values never mark a row as an outlier.

        Args:
            contamination: Unused; kept so existing callers keep working. The
                cut-off is governed by ``iqr_multiplier``.
            iqr_multiplier: Fence width in IQR units (default 3.0).

        Returns:
            ``(self.X, self.y)`` after filtering; both are also stored back
            on the instance.
        """
        print("\n" + "="*70)
        print("OUTLIER DETECTION AND REMOVAL")
        print("="*70)

        original_shape = self.X.shape[0]

        q1 = self.X.quantile(0.25)
        q3 = self.X.quantile(0.75)
        iqr = q3 - q1

        lower_bound = q1 - iqr_multiplier * iqr
        upper_bound = q3 + iqr_multiplier * iqr

        # Only columns with a positive spread can define outliers.
        informative = iqr.index[(iqr > 0).to_numpy()]
        checked = self.X[informative]
        is_outlier = ((checked < lower_bound[informative]) |
                      (checked > upper_bound[informative])).any(axis=1)

        # Use a plain boolean array so X and y are filtered positionally.
        keep = (~is_outlier).to_numpy()
        self.X = self.X[keep]
        self.y = self.y[keep]

        removed = original_shape - self.X.shape[0]
        removed_pct = removed / original_shape * 100 if original_shape else 0.0
        print(f"\n✓ Removed {removed} outlier samples ({removed_pct:.2f}%)")
        print(f"✓ Remaining samples: {self.X.shape[0]}")

        return self.X, self.y

    def scale_features(self, method='robust'):
        """Fit a scaler on ``self.X`` and replace it with the scaled values.

        ``method='robust'`` uses ``RobustScaler``; any other value uses
        ``StandardScaler``. The fitted scaler is kept on ``self.scaler``.

        Returns:
            The scaled ``self.X`` (same columns and index).
        """
        print("\n" + "="*70)
        print("FEATURE SCALING")
        print("="*70)

        if method == 'robust':
            self.scaler = RobustScaler()
            print("\n🔧 Applying Robust Scaling (resistant to outliers)...")
        else:
            self.scaler = StandardScaler()
            print("\n🔧 Applying Standard Scaling...")

        X_scaled = self.scaler.fit_transform(self.X)
        self.X = pd.DataFrame(X_scaled, columns=self.X.columns, index=self.X.index)

        print(f"✓ Features scaled: {self.X.shape[1]} features")

        return self.X

    def select_features(self, k=60):
        """Keep the ``k`` columns with the highest mutual information with ``y``.

        When ``self.X`` already has ``k`` or fewer columns, all of them are
        kept and every score is reported as 1.0.

        Returns:
            Tuple ``(self.X, feature_scores)``, where ``feature_scores`` is a
            DataFrame with ``feature`` and ``score`` columns.
        """
        print("\n" + "="*70)
        print("FEATURE SELECTION")
        print("="*70)

        if self.X.shape[1] <= k:
            print(f"\n✓ Number of features ({self.X.shape[1]}) <= k ({k}), keeping all features")
            self.selected_features = self.X.columns.tolist()
            feature_scores = pd.DataFrame({
                'feature': self.X.columns,
                'score': [1.0] * len(self.X.columns)
            })
            return self.X, feature_scores

        print(f"\n🔧 Selecting top {k} features using Mutual Information...")

        selector = SelectKBest(mutual_info_classif, k=k)
        X_selected = selector.fit_transform(self.X, self.y)

        selected_indices = selector.get_support(indices=True)
        self.selected_features = self.X.columns[selected_indices].tolist()

        feature_scores = pd.DataFrame({
            'feature': self.X.columns,
            'score': selector.scores_
        }).sort_values('score', ascending=False)

        print(f"\n✓ Selected {k} features")
        print("\n📊 Top 15 Most Important Features:")
        print(feature_scores.head(15).to_string(index=False))

        self.X = pd.DataFrame(X_selected, columns=self.selected_features, index=self.X.index)

        return self.X, feature_scores


class ModelBuilder:
    """Build, train, cross-validate and ensemble several classifiers.

    All models share the train/test split passed to the constructor.
    Test-set predictions and class probabilities are collected per model
    name in ``predictions`` / ``probabilities`` so they can be handed to an
    evaluator.
    """

    def __init__(self, X_train, X_test, y_train, y_test):
        self.X_train = X_train
        self.X_test = X_test
        self.y_train = y_train
        self.y_test = y_test
        self.models = {}  # name -> estimator, filled by build_models
        self.predictions = {}  # name -> predicted labels on X_test
        self.probabilities = {}  # name -> predict_proba output on X_test
        self.ensemble_model = None  # VotingClassifier, set by create_ensemble

    def build_models(self):
        """Instantiate the candidate classifiers (unfitted) into ``self.models``.

        Returns:
            dict mapping model name -> estimator.
        """
        # Local import, the same way this file imports learning_curve and
        # label_binarize, so the module-level import block stays unchanged.
        from sklearn.ensemble import ExtraTreesClassifier

        print("\n" + "="*70)
        print("BUILDING MODELS (OPTIMIZED HYPERPARAMETERS)")
        print("="*70)

        self.models = {
            'LightGBM': lgb.LGBMClassifier(
                n_estimators=800,
                learning_rate=0.03,
                max_depth=10,
                num_leaves=50,
                min_child_samples=15,
                subsample=0.85,
                # LightGBM only performs row subsampling (bagging) when the
                # bagging frequency is > 0. With the default of 0, the
                # ``subsample`` value above would be silently ignored.
                subsample_freq=1,
                colsample_bytree=0.85,
                reg_alpha=0.1,
                reg_lambda=0.1,
                min_split_gain=0.01,
                random_state=RANDOM_STATE,
                verbose=-1,
                n_jobs=-1,
                class_weight='balanced'
            ),
            'XGBoost': xgb.XGBClassifier(
                n_estimators=800,
                learning_rate=0.03,
                max_depth=9,
                min_child_weight=2,
                subsample=0.85,
                colsample_bytree=0.85,
                colsample_bylevel=0.85,
                gamma=0.1,
                reg_alpha=0.1,
                reg_lambda=1.0,
                scale_pos_weight=1,
                random_state=RANDOM_STATE,
                eval_metric='mlogloss',
                n_jobs=-1,
                tree_method='hist'
            ),
            'RandomForest': RandomForestClassifier(
                n_estimators=700,
                max_depth=25,
                min_samples_split=4,
                min_samples_leaf=1,
                max_features='sqrt',
                bootstrap=True,
                oob_score=True,
                class_weight='balanced',
                random_state=RANDOM_STATE,
                n_jobs=-1
            ),
            # A real Extremely Randomized Trees model (random split thresholds).
            # Previously this was a second RandomForestClassifier with
            # bootstrap=False, so the name did not match the model.
            'ExtraTrees': ExtraTreesClassifier(
                n_estimators=700,
                max_depth=30,
                min_samples_split=3,
                min_samples_leaf=1,
                max_features='sqrt',
                bootstrap=False,
                class_weight='balanced',
                criterion='gini',
                random_state=RANDOM_STATE,
                n_jobs=-1
            ),
            # multi_class is left at its default. For the 'saga' solver on a
            # multiclass target this already means multinomial, and the
            # explicit parameter is deprecated and removed in newer scikit-learn.
            'LogisticRegression': LogisticRegression(
                max_iter=2000,
                solver='saga',
                C=0.5,
                penalty='l2',
                class_weight='balanced',
                random_state=RANDOM_STATE,
                n_jobs=-1
            )
        }

        print(f"\n✓ Built {len(self.models)} models with optimized parameters:")
        for name in self.models:
            print(f"  • {name}")

        return self.models

    def train_models(self):
        """Fit every model on the training split and record test predictions.

        Stores predicted labels and class probabilities per model name, then
        prints each model's test accuracy and wall-clock fit time.

        Returns:
            dict mapping model name -> predicted labels on ``X_test``.
        """
        print("\n" + "="*70)
        print("TRAINING MODELS")
        print("="*70)

        for name, model in self.models.items():
            print(f"\n🔧 Training {name}...")
            start_time = datetime.now()

            model.fit(self.X_train, self.y_train)

            self.predictions[name] = model.predict(self.X_test)
            self.probabilities[name] = model.predict_proba(self.X_test)

            accuracy = accuracy_score(self.y_test, self.predictions[name])

            elapsed_time = (datetime.now() - start_time).total_seconds()
            print(f"✓ {name} trained in {elapsed_time:.2f}s - Accuracy: {accuracy:.4f}")

        return self.predictions

    def create_ensemble(self):
        """Fit a weighted soft-voting ensemble of the four tree-based models.

        ``VotingClassifier`` clones and refits its estimators, so this trains
        fresh copies rather than reusing the models fitted in
        ``train_models``. Requires ``build_models`` to have been called.

        Returns:
            The fitted ``VotingClassifier`` (also on ``self.ensemble_model``).
        """
        print("\n" + "="*70)
        print("CREATING ENSEMBLE MODEL")
        print("="*70)

        ensemble_models = [
            ('lgb', self.models['LightGBM']),
            ('xgb', self.models['XGBoost']),
            ('rf', self.models['RandomForest']),
            ('et', self.models['ExtraTrees'])
        ]

        print("\n🔧 Building Weighted Soft Voting Classifier...")
        # Gradient-boosted models get a higher vote weight than the forests.
        self.ensemble_model = VotingClassifier(
            estimators=ensemble_models,
            voting='soft',
            weights=[1.5, 1.5, 1.0, 1.0],
            n_jobs=-1
        )

        print("🔧 Training Ensemble...")
        self.ensemble_model.fit(self.X_train, self.y_train)

        self.predictions['Ensemble'] = self.ensemble_model.predict(self.X_test)
        self.probabilities['Ensemble'] = self.ensemble_model.predict_proba(self.X_test)

        accuracy = accuracy_score(self.y_test, self.predictions['Ensemble'])
        print(f"✓ Ensemble trained - Accuracy: {accuracy:.4f}")

        return self.ensemble_model

    def cross_validate(self, cv_folds=5):
        """Run stratified k-fold accuracy cross-validation on the training split.

        Every model is scored on the same folds, because the splitter is built
        once with a fixed seed.

        Returns:
            dict mapping model name -> ``{'mean', 'std', 'scores'}``.
        """
        print("\n" + "="*70)
        print("CROSS-VALIDATION")
        print("="*70)

        cv_results = {}
        skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=RANDOM_STATE)

        for name, model in self.models.items():
            print(f"\n🔧 Cross-validating {name}...")

            scores = cross_val_score(model, self.X_train, self.y_train, cv=skf, scoring='accuracy', n_jobs=-1)

            cv_results[name] = {
                'mean': scores.mean(),
                'std': scores.std(),
                'scores': scores
            }

            # The +/- band is two standard deviations across folds.
            print(f"  Mean Accuracy: {scores.mean():.4f} (+/- {scores.std()*2:.4f})")

        return cv_results


class ModelEvaluator:
    """Compute metrics and write evaluation plots/reports for fitted models.

    Args:
        models_dict: name -> fitted estimator (used for feature importances
            and learning curves).
        predictions_dict: name -> predicted encoded labels on the test set.
        probabilities_dict: name -> ``predict_proba`` output on the test set,
            one column per encoded class.
        y_test: encoded true labels of the test set.
        label_encoder: fitted encoder; encoded label ``i`` is displayed as
            ``label_encoder.classes_[i]``.
        results_dir: directory for all outputs (created if missing).
    """

    def __init__(self, models_dict, predictions_dict, probabilities_dict,
                 y_test, label_encoder, results_dir):
        self.models = models_dict
        self.predictions = predictions_dict
        self.probabilities = probabilities_dict
        self.y_test = y_test
        self.label_encoder = label_encoder
        self.results_dir = Path(results_dir)
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.metrics_summary = {}

    # ------------------------------------------------------------------ helpers

    def _class_indices(self):
        """Return the encoded label ids ``0..n_classes-1`` known to the encoder."""
        return np.arange(len(self.label_encoder.classes_))

    @staticmethod
    def _row_normalize(cm):
        """Scale each confusion-matrix row so it sums to 1.

        Rows with no true samples (all zeros) stay zero instead of becoming
        NaN through a 0/0 division.
        """
        cm = np.asarray(cm, dtype=float)
        row_sums = cm.sum(axis=1, keepdims=True)
        return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)

    @staticmethod
    def _grid(n_panels, ncols=3):
        """Return the subplot row count for ``n_panels`` panels in ``ncols`` columns.

        Uses at least 2 rows (a 2x3 grid for up to six panels) and adds rows
        as needed so that no model is silently left out of the plot.
        """
        return max(2, -(-n_panels // ncols))

    def _save_figure(self, filename):
        """Tighten the layout, save the current figure at 300 dpi, close it, return its path."""
        plt.tight_layout()
        plot_file = self.results_dir / filename
        plt.savefig(plot_file, dpi=300, bbox_inches='tight')
        plt.close()
        return plot_file

    # ------------------------------------------------------------------ metrics

    def calculate_metrics(self):
        """Compute per-model test metrics, print them, and save them as JSON.

        Precision, recall and F1 are support-weighted averages. AUC is the
        weighted one-vs-rest ROC AUC, recorded as NaN when scikit-learn cannot
        compute it (e.g. a class has no positive test samples).

        Returns:
            dict mapping model name -> dict of metric name -> float.
        """
        print("\n" + "="*70)
        print("CALCULATING METRICS")
        print("="*70)

        for name, y_pred in self.predictions.items():
            y_proba = self.probabilities[name]
            # One label per probability column, so log loss still works when a
            # class never occurs in y_test (otherwise sklearn raises).
            proba_labels = np.arange(np.asarray(y_proba).shape[1])

            accuracy = accuracy_score(self.y_test, y_pred)
            precision, recall, f1, _ = precision_recall_fscore_support(
                self.y_test, y_pred, average='weighted'
            )

            mcc = matthews_corrcoef(self.y_test, y_pred)
            kappa = cohen_kappa_score(self.y_test, y_pred)
            logloss = log_loss(self.y_test, y_proba, labels=proba_labels)

            try:
                auc_ovr = roc_auc_score(self.y_test, y_proba, multi_class='ovr', average='weighted')
            except ValueError:
                # AUC is undefined for this split; record NaN instead of aborting.
                auc_ovr = np.nan

            # Built-in floats keep the JSON dump independent of numpy scalar types.
            self.metrics_summary[name] = {
                'Accuracy': float(accuracy),
                'Precision': float(precision),
                'Recall': float(recall),
                'F1-Score': float(f1),
                'MCC': float(mcc),
                'Cohen_Kappa': float(kappa),
                'Log_Loss': float(logloss),
                'AUC_OVR': float(auc_ovr)
            }

            print(f"\n📊 {name} Metrics:")
            print(f"  Accuracy:     {accuracy:.4f}")
            print(f"  Precision:    {precision:.4f}")
            print(f"  Recall:       {recall:.4f}")
            print(f"  F1-Score:     {f1:.4f}")
            print(f"  MCC:          {mcc:.4f}")
            print(f"  Cohen Kappa:  {kappa:.4f}")
            print(f"  Log Loss:     {logloss:.4f}")
            print(f"  AUC (OVR):    {auc_ovr:.4f}")

        # NaN AUC values are written as the non-standard JSON token NaN.
        metrics_file = self.results_dir / 'metrics_summary.json'
        with open(metrics_file, 'w') as f:
            json.dump(self.metrics_summary, f, indent=4)
        print(f"\n✓ Metrics saved to {metrics_file}")

        return self.metrics_summary

    # ------------------------------------------------------------------ plots

    def plot_model_comparison(self):
        """Bar-chart six metrics across models (requires ``calculate_metrics``)."""
        print("\n" + "="*70)
        print("CREATING MODEL COMPARISON PLOTS")
        print("="*70)

        models_list = list(self.metrics_summary.keys())
        metrics_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'MCC', 'Cohen_Kappa']
        colors = plt.cm.viridis(np.linspace(0, 1, len(models_list)))

        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        fig.suptitle('Model Performance Comparison', fontsize=20, fontweight='bold')
        axes = axes.ravel()

        for ax, metric in zip(axes, metrics_names):
            values = [self.metrics_summary[model][metric] for model in models_list]

            bars = ax.bar(models_list, values, color=colors)
            ax.set_title(f'{metric}', fontsize=14, fontweight='bold')
            ax.set_ylabel('Score', fontsize=12)
            # MCC and Cohen's kappa range over [-1, 1]. Extend the axis below 0
            # when needed so negative scores are not clipped out of view.
            ax.set_ylim([min(0.0, float(np.nanmin([0.0, *values]))), 1])
            ax.grid(axis='y', alpha=0.3)
            ax.tick_params(axis='x', rotation=45)

            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                        f'{height:.3f}',
                        ha='center', va='bottom', fontsize=10)

        plot_file = self._save_figure('model_comparison.png')
        print(f"✓ Model comparison plot saved to {plot_file}")

    def plot_confusion_matrices(self):
        """Plot a row-normalized confusion matrix for every model."""
        print("\n" + "="*70)
        print("CREATING CONFUSION MATRICES")
        print("="*70)

        class_ids = self._class_indices()
        class_names = self.label_encoder.classes_
        n_models = len(self.predictions)
        n_rows = self._grid(n_models)

        fig, axes = plt.subplots(n_rows, 3, figsize=(20, 7 * n_rows))
        fig.suptitle('Confusion Matrices', fontsize=20, fontweight='bold')
        axes = axes.ravel()

        for ax, (name, y_pred) in zip(axes, self.predictions.items()):
            # Fix the label set so the matrix is always n_classes x n_classes
            # and lines up with the tick labels, even if a class is absent.
            cm = confusion_matrix(self.y_test, y_pred, labels=class_ids)
            cm_normalized = self._row_normalize(cm)

            sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                        ax=ax, cbar_kws={'label': 'Proportion'},
                        xticklabels=class_names,
                        yticklabels=class_names)

            accuracy = self.metrics_summary.get(name, {}).get('Accuracy')
            if accuracy is None:
                accuracy = accuracy_score(self.y_test, y_pred)
            ax.set_title(f'{name}\nAccuracy: {accuracy:.4f}',
                         fontsize=12, fontweight='bold')
            ax.set_ylabel('True Label', fontsize=11)
            ax.set_xlabel('Predicted Label', fontsize=11)

        for ax in axes[n_models:]:
            ax.axis('off')

        plot_file = self._save_figure('confusion_matrices.png')
        print(f"✓ Confusion matrices saved to {plot_file}")

    def plot_classification_reports(self):
        """Print each model's classification report and save it as a text file."""
        print("\n" + "="*70)
        print("CREATING CLASSIFICATION REPORTS")
        print("="*70)

        class_ids = self._class_indices()

        for name, y_pred in self.predictions.items():
            print(f"\n📊 {name} Classification Report:")
            # Explicit labels keep target_names aligned with the encoded ids
            # even when some class is missing from y_test and y_pred.
            report = classification_report(
                self.y_test, y_pred,
                labels=class_ids,
                target_names=self.label_encoder.classes_,
                digits=4
            )
            print(report)

            report_file = self.results_dir / f'classification_report_{name}.txt'
            with open(report_file, 'w') as f:
                f.write(f"Classification Report: {name}\n")
                f.write("="*70 + "\n\n")
                f.write(report)

            print(f"✓ Classification report saved to {report_file}")

    def plot_feature_importance(self, feature_names, top_n=20):
        """Plot the ``top_n`` feature importances of each tree-based model.

        ``feature_names`` must list the training columns in model order.
        Panels for models that are missing or have no ``feature_importances_``
        are hidden.
        """
        print("\n" + "="*70)
        print("CREATING FEATURE IMPORTANCE PLOTS")
        print("="*70)

        importance_models = ['LightGBM', 'XGBoost', 'RandomForest', 'ExtraTrees']

        fig, axes = plt.subplots(2, 2, figsize=(20, 16))
        fig.suptitle('Feature Importance Analysis', fontsize=20, fontweight='bold')
        axes = axes.ravel()

        for ax, name in zip(axes, importance_models):
            model = self.models.get(name)
            if model is None or not hasattr(model, 'feature_importances_'):
                ax.axis('off')
                continue

            importance_df = pd.DataFrame({
                'feature': feature_names,
                'importance': model.feature_importances_
            }).sort_values('importance', ascending=False).head(top_n)

            positions = range(len(importance_df))
            ax.barh(positions, importance_df['importance'],
                    color=plt.cm.plasma(np.linspace(0, 1, len(importance_df))))
            ax.set_yticks(positions)
            ax.set_yticklabels(importance_df['feature'], fontsize=9)
            ax.set_xlabel('Importance', fontsize=11)
            ax.set_title(f'{name} - Top {top_n} Features', fontsize=13, fontweight='bold')
            ax.invert_yaxis()  # most important feature on top
            ax.grid(axis='x', alpha=0.3)

        plot_file = self._save_figure('feature_importance.png')
        print(f"✓ Feature importance plot saved to {plot_file}")

    def plot_learning_curves(self, X_train, y_train):
        """Plot 3-fold CV learning curves (accuracy) for three tree-based models."""
        print("\n" + "="*70)
        print("CREATING LEARNING CURVES")
        print("="*70)

        from sklearn.model_selection import learning_curve

        models_to_plot = ['LightGBM', 'XGBoost', 'RandomForest']

        fig, axes = plt.subplots(1, 3, figsize=(20, 6))
        fig.suptitle('Learning Curves', fontsize=20, fontweight='bold')

        for ax, name in zip(axes, models_to_plot):
            if name not in self.models:
                ax.axis('off')
                continue

            print(f"🔧 Computing learning curve for {name}...")

            # NOTE: learning_curve only uses random_state when shuffle=True.
            train_sizes, train_scores, val_scores = learning_curve(
                self.models[name], X_train, y_train, cv=3,
                train_sizes=np.linspace(0.1, 1.0, 10),
                scoring='accuracy', n_jobs=-1, random_state=RANDOM_STATE
            )

            train_mean = np.mean(train_scores, axis=1)
            train_std = np.std(train_scores, axis=1)
            val_mean = np.mean(val_scores, axis=1)
            val_std = np.std(val_scores, axis=1)

            ax.plot(train_sizes, train_mean, label='Training score',
                    color='blue', marker='o', linewidth=2)
            ax.fill_between(train_sizes, train_mean - train_std,
                            train_mean + train_std, alpha=0.15, color='blue')

            ax.plot(train_sizes, val_mean, label='Validation score',
                    color='red', marker='s', linewidth=2)
            ax.fill_between(train_sizes, val_mean - val_std,
                            val_mean + val_std, alpha=0.15, color='red')

            ax.set_xlabel('Training Set Size', fontsize=11)
            ax.set_ylabel('Accuracy Score', fontsize=11)
            ax.set_title(f'{name}', fontsize=13, fontweight='bold')
            ax.legend(loc='best', fontsize=10)
            ax.grid(alpha=0.3)

        plot_file = self._save_figure('learning_curves.png')
        print(f"✓ Learning curves saved to {plot_file}")

    def plot_roc_curves(self):
        """Plot one-vs-rest ROC curves for the 'Ensemble' model, if present.

        Classes whose test labels contain only one outcome (no positives or
        no negatives) are skipped, because ROC/AUC is undefined for them.
        """
        print("\n" + "="*70)
        print("CREATING ROC CURVES")
        print("="*70)

        from sklearn.preprocessing import label_binarize

        if 'Ensemble' not in self.probabilities:
            return

        class_names = self.label_encoder.classes_
        n_classes = len(class_names)
        y_test_bin = label_binarize(self.y_test, classes=range(n_classes))
        if n_classes == 2:
            # label_binarize returns a single column for binary problems. Add
            # the complementary column so column i always means "is class i".
            y_test_bin = np.hstack([1 - y_test_bin, y_test_bin])

        y_proba = np.asarray(self.probabilities['Ensemble'])

        fig, ax = plt.subplots(figsize=(12, 10))

        for i, class_name in enumerate(class_names):
            y_true_i = y_test_bin[:, i]
            if np.unique(y_true_i).size < 2:
                print(f"  ⚠️  Skipping ROC for class {class_name}: only one outcome in y_test")
                continue

            fpr, tpr, _ = roc_curve(y_true_i, y_proba[:, i])
            auc_score = roc_auc_score(y_true_i, y_proba[:, i])
            ax.plot(fpr, tpr, linewidth=2,
                    label=f'{class_name} (AUC = {auc_score:.3f})')

        ax.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Random Classifier')
        ax.set_xlabel('False Positive Rate', fontsize=13)
        ax.set_ylabel('True Positive Rate', fontsize=13)
        ax.set_title('ROC Curves - Ensemble Model (One-vs-Rest)',
                     fontsize=15, fontweight='bold')
        ax.legend(loc='lower right', fontsize=10)
        ax.grid(alpha=0.3)

        plot_file = self._save_figure('roc_curves_ensemble.png')
        print(f"✓ ROC curves saved to {plot_file}")

    def plot_prediction_distribution(self):
        """Compare actual vs predicted class counts for every model."""
        print("\n" + "="*70)
        print("CREATING PREDICTION DISTRIBUTION PLOTS")
        print("="*70)

        class_ids = self._class_indices()
        n_models = len(self.predictions)
        n_rows = self._grid(n_models)

        fig, axes = plt.subplots(n_rows, 3, figsize=(20, 6 * n_rows))
        fig.suptitle('Prediction Distribution Analysis', fontsize=20, fontweight='bold')
        axes = axes.ravel()

        actual_counts = pd.Series(self.y_test).value_counts()
        x = np.arange(len(class_ids))
        width = 0.35

        for ax, (name, y_pred) in zip(axes, self.predictions.items()):
            # Reindex to every encoded class so each bar pair lines up with its
            # tick label, even when a class is never predicted or observed.
            comparison_df = pd.DataFrame({
                'Actual': actual_counts,
                'Predicted': pd.Series(y_pred).value_counts()
            }).reindex(class_ids).fillna(0)

            ax.bar(x - width/2, comparison_df['Actual'], width,
                   label='Actual', alpha=0.8, color='steelblue')
            ax.bar(x + width/2, comparison_df['Predicted'], width,
                   label='Predicted', alpha=0.8, color='coral')

            ax.set_xlabel('Class', fontsize=11)
            ax.set_ylabel('Count', fontsize=11)
            ax.set_title(f'{name}', fontsize=13, fontweight='bold')
            ax.set_xticks(x)
            ax.set_xticklabels(self.label_encoder.classes_)
            ax.legend(fontsize=10)
            ax.grid(axis='y', alpha=0.3)

        for ax in axes[n_models:]:
            ax.axis('off')

        plot_file = self._save_figure('prediction_distribution.png')
        print(f"✓ Prediction distribution plot saved to {plot_file}")


def save_artifacts(results_dir, models, scaler, imputer, label_encoder,
                   feature_names, selected_features, feature_scores):
    """Persist trained models, preprocessing objects and feature metadata.

    Parameters
    ----------
    results_dir : str or Path
        Output directory. It is created (including parents) if missing.
    models : dict
        Mapping of model name -> fitted estimator. Each estimator is written
        with ``joblib.dump`` to ``model_<name>.pkl``, where the name is
        lower-cased and spaces are replaced by underscores.
    scaler, imputer, label_encoder : object
        Fitted preprocessing objects, dumped to ``scaler.pkl``,
        ``imputer.pkl`` and ``label_encoder.pkl``.
    feature_names, selected_features : iterable
        All feature names and the subset kept by feature selection. They are
        converted to plain lists of ``str`` so that pandas ``Index`` objects
        or numpy arrays can be serialized to ``feature_info.json``.
    feature_scores : object with ``to_csv`` (presumably a pandas DataFrame)
        Written to ``feature_scores.csv`` without the index.
    """
    print("\n" + "="*70)
    print("SAVING MODELS AND ARTIFACTS")
    print("="*70)

    results_path = Path(results_dir)
    results_path.mkdir(parents=True, exist_ok=True)

    # Save models
    for name, model in models.items():
        model_file = results_path / f'model_{name.lower().replace(" ", "_")}.pkl'
        joblib.dump(model, model_file)
        print(f"✓ Saved {name} model to {model_file}")

    # Save preprocessing objects under fixed file names
    for label, obj, filename in (
        ('scaler', scaler, 'scaler.pkl'),
        ('imputer', imputer, 'imputer.pkl'),
        ('label encoder', label_encoder, 'label_encoder.pkl'),
    ):
        target = results_path / filename
        joblib.dump(obj, target)
        print(f"✓ Saved {label} to {target}")

    # Save feature information. Normalise to list[str] first, because json
    # cannot serialize pandas Index / numpy array containers directly.
    all_features = [str(feature) for feature in feature_names]
    kept_features = [str(feature) for feature in selected_features]
    feature_info = {
        'all_features': all_features,
        'selected_features': kept_features,
        'n_features': len(all_features),
        'n_selected': len(kept_features),
    }

    info_file = results_path / 'feature_info.json'
    with open(info_file, 'w', encoding='utf-8') as f:
        json.dump(feature_info, f, indent=4)
    print(f"✓ Saved feature info to {info_file}")

    # Save feature scores
    scores_file = results_path / 'feature_scores.csv'
    feature_scores.to_csv(scores_file, index=False)
    print(f"✓ Saved feature scores to {scores_file}")


def main(data_file='TOI_2025.10.03_22.38.38.csv', results_dir='results'):
    """Run the end-to-end tfopwg_disp classification pipeline.

    The pipeline runs these steps in order:

    1. Load and inspect the data.
    2. Impute, engineer, filter outliers, scale and select features.
    3. Make a stratified 80/20 train/test split.
    4. Build and train the models, build an ensemble and cross-validate.
    5. Evaluate and plot on the test split.
    6. Save all artifacts.
    7. Print the models ranked by test accuracy.

    Parameters
    ----------
    data_file : str or Path, optional
        CSV file passed to ``DataProcessor``.
    results_dir : str or Path, optional
        Directory (created if missing) that receives plots, models and
        preprocessing artifacts.
    """
    print("\n")
    print("="*70)
    print(" " * 10 + "MULTICLASS CLASSIFICATION PIPELINE")
    print(" " * 15 + "Target: tfopwg_disp")
    print("="*70)
    print("\n")

    # Create results directory
    Path(results_dir).mkdir(parents=True, exist_ok=True)

    # ========================
    # 1. DATA LOADING
    # ========================
    processor = DataProcessor(data_file)
    processor.load_data()
    processor.analyze_data_quality()
    X, y, feature_cols = processor.prepare_features_target()

    # ========================
    # 2. FEATURE ENGINEERING
    # ========================
    # NOTE(review): imputation, outlier removal, scaling and feature
    # selection (constructed with y, so presumably supervised) are all fitted
    # on the FULL dataset before the train/test split below. Test rows
    # therefore influence preprocessing, and outliers are also dropped from
    # the test set, so the reported test metrics are likely optimistic.
    # Fixing this means splitting first and fitting these steps on the
    # training rows only, which needs a fit/transform API on FeatureEngineer.
    engineer = FeatureEngineer(X, y)
    X = engineer.handle_missing_values(strategy='advanced')
    X = engineer.create_features()
    X, y = engineer.remove_outliers(contamination=0.05)
    X = engineer.scale_features(method='robust')
    X, feature_scores = engineer.select_features(k=60)

    # ========================
    # 3. TRAIN-TEST SPLIT
    # ========================
    print("\n" + "="*70)
    print("TRAIN-TEST SPLIT")
    print("="*70)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=RANDOM_STATE, stratify=y
    )

    print(f"\n✓ Training set: {X_train.shape[0]} samples")
    print(f"✓ Test set: {X_test.shape[0]} samples")
    print(f"✓ Feature dimensions: {X_train.shape[1]}")

    # ========================
    # 4. MODEL BUILDING & TRAINING
    # ========================
    # The builder keeps its results on its own attributes (models,
    # predictions, probabilities), which are read below. The return values
    # of these calls were never used.
    builder = ModelBuilder(X_train, X_test, y_train, y_test)
    builder.build_models()
    builder.train_models()
    builder.create_ensemble()
    builder.cross_validate(cv_folds=5)

    # ========================
    # 5. MODEL EVALUATION
    # ========================
    evaluator = ModelEvaluator(
        models_dict=builder.models,
        predictions_dict=builder.predictions,
        probabilities_dict=builder.probabilities,
        y_test=y_test,
        label_encoder=processor.label_encoder,
        results_dir=results_dir
    )

    metrics = evaluator.calculate_metrics()
    evaluator.plot_model_comparison()
    evaluator.plot_confusion_matrices()
    evaluator.plot_classification_reports()
    evaluator.plot_feature_importance(feature_names=engineer.selected_features, top_n=20)
    evaluator.plot_learning_curves(X_train, y_train)
    evaluator.plot_roc_curves()
    evaluator.plot_prediction_distribution()

    # ========================
    # 6. SAVE ARTIFACTS
    # ========================
    save_artifacts(
        results_dir=results_dir,
        models=builder.models,
        scaler=engineer.scaler,
        imputer=engineer.imputer,
        label_encoder=processor.label_encoder,
        feature_names=feature_cols,
        selected_features=engineer.selected_features,
        feature_scores=feature_scores
    )

    # ========================
    # 7. FINAL SUMMARY
    # ========================
    print("\n" + "="*70)
    print("PIPELINE COMPLETED SUCCESSFULLY!")
    print("="*70)

    # NOTE(review): ranking is by accuracy. With a strongly imbalanced target,
    # the F1-Score column is usually the more informative ordering key.
    print("\n📊 FINAL MODEL RANKINGS:")
    rankings = sorted(metrics.items(), key=lambda item: item[1]['Accuracy'], reverse=True)
    for rank, (name, scores) in enumerate(rankings, 1):
        print(f"{rank}. {name:20s} - Accuracy: {scores['Accuracy']:.4f}, F1: {scores['F1-Score']:.4f}")

    print(f"\n✓ All results saved to '{results_dir}/' directory")
    print("✓ Models, scalers, and artifacts ready for production deployment")

    print("\n" + "="*70)
    print()


# Script entry point: run the full pipeline with the data file and results
# directory configured in main(). Nothing runs when the module is imported.
if __name__ == "__main__":
    main()



          MULTICLASS CLASSIFICATION PIPELINE
               Target: tfopwg_disp


LOADING DATA

✓ Data loaded successfully: 7703 rows, 65 columns

📊 Target Variable Distribution (tfopwg_disp):
tfopwg_disp
PC     4679
FP     1197
CP      684
KP      583
APC     462
FA       98
Name: count, dtype: int64

Target Class Proportions:
tfopwg_disp
PC     0.6074
FP     0.1554
CP     0.0888
KP     0.0757
APC    0.0600
FA     0.0127
Name: proportion, dtype: float64

DATA QUALITY ANALYSIS

📋 Columns with Missing Values:
              Missing_Count  Missing_Percentage
pl_insolerr2           7703              100.00
pl_eqterr1             7703              100.00
pl_insollim            7703              100.00
pl_eqterr2             7703              100.00
pl_eqtlim              7703              100.00
pl_insolerr1           7703              100.00
st_loggerr1            2271               29.48
st_loggerr2            2271               29.48
st_raderr1             1963               25.48
st_r