In [None]:
# ==================== Section 5: Complete Training with Random Forest ====================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (accuracy_score, f1_score, roc_auc_score,
                           confusion_matrix, precision_score, recall_score,
                           cohen_kappa_score, matthews_corrcoef, roc_curve,
                           precision_recall_curve, average_precision_score)
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFECV
import warnings
warnings.filterwarnings('ignore')
import os
import joblib
import json

print("Complete Training with Random Forest - ECF-TST vs Global Model")
print("="*80)

# ==================== 1. Load and Process Real Data ====================
print("Step 1: Loading and processing real data...")

from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): the ')' before '.csv' looks like a leftover from a name such as
# 'merged_data_by_year(1).csv' — confirm the exact file name on Drive.
data_path = '/content/drive/MyDrive/merged_data_by_year).csv'

# Load data
df = pd.read_csv(data_path)
print(f"Data loaded successfully: {df.shape[0]} rows, {df.shape[1]} columns")

# Display basic information.
# DataFrame.info() prints its report itself and returns None, so it must not be
# wrapped in print() (that used to emit an extra "None" line).
print("\nData basic information:")
df.info()

# Check for missing values (count and percentage per column)
print("\nChecking for missing values...")
missing_counts = df.isnull().sum()
missing_pct = (missing_counts / len(df)) * 100
missing_summary = pd.DataFrame({
    'Missing Count': missing_counts,
    'Missing Percentage': missing_pct
})

# Show only columns with missing values
missing_summary = missing_summary[missing_summary['Missing Count'] > 0]
if not missing_summary.empty:
    print("Columns with missing values:")
    print(missing_summary)
else:
    print("No missing values found")

# Handle missing values: DELETE rows with missing values.
# NOTE(review): dropna() considers every column, not only the model features,
# target ('b1') and ecoregion ID — rows can be lost because of unused columns.
print("\nHandling missing values...")
original_shape = df.shape
df = df.dropna()
rows_removed = original_shape[0] - df.shape[0]
print(f"  Rows with missing values removed: {rows_removed}")
print(f"  New data shape: {df.shape[0]} rows, {df.shape[1]} columns")

# Check class distribution of the target (0 = stable, 1 = disturbance)
print("\nChecking class distribution...")
class_dist = df['b1'].value_counts()
print("Class distribution (b1):")
print(class_dist)
# normalize=True avoids dividing by len(df) (zero for an empty frame) and .get()
# avoids a KeyError when one of the two classes is absent.
class_pct = df['b1'].value_counts(normalize=True) * 100
print(f"Class ratio (0:stable / 1:disturbance): {class_pct.get(0, 0.0):.1f}% / {class_pct.get(1, 0.0):.1f}%")

# Check ecoregion distribution (column 'ID' identifies the ecoregion)
print("\nChecking ecoregion distribution...")
ecoregion_counts = df['ID'].value_counts()
print(f"Total ecoregions: {len(ecoregion_counts)}")
print("\nTop 10 ecoregions by sample count:")
print(ecoregion_counts.head(10))

# ==================== 2. Define Complete Feature Set ====================
print("\nStep 2: Defining complete feature set...")

# Feature groups as described in the paper. The insertion order of the groups
# (and of the names inside each group) fixes the final column order.
feature_groups = {
    'Remote Sensing Features': [
        'NBR', 'NDVI',
        'NBR_con_texture', 'NBR_cor_texture', 'NBR_ent_texture',
        'NDVI_con_texture', 'NDVI_cor_texture', 'NDVI_ent_texture',
        'NBR_rol_3y_temporal', 'NBR_rol_5y_temporal', 'NBR_vola_5y_temporal',
        'NDVI_rol_3y_temporal', 'NDVI_rol_5y_temporal', 'NDVI_vola_5y_temporal'
    ],

    'Topographic Features': ['aspect', 'elevation', 'slope', 'tpi'],

    'Climate Features': [
        'annual_precip', 'annual_temp', 'prev_year_precip',
        'summer_temp', 'temp_anomaly'
    ],

    'Socioeconomic Features': ['gdp', 'population'],

    'Landcover': ['landcover']
}

# Flatten every group, in order, into one list of feature names.
feature_cols = [name for group in feature_groups.values() for name in group]

# Report how many features each group contributes.
for group_name, features in feature_groups.items():
    print(f"  {group_name}: {len(features)} features")

print(f"\nTotal: {len(feature_cols)} features")

# Features requested above but absent from the loaded data.
missing_features = [col for col in feature_cols if col not in df.columns]

if not missing_features:
    print("\n✓ All features exist in the data")
else:
    print(f"\nWarning: The following {len(missing_features)} features are not in the data:")
    print("\n".join(f"  - {col}" for col in missing_features))

    # Keep only the features that are actually available
    feature_cols = [col for col in feature_cols if col in df.columns]
    print(f"\nUsing {len(feature_cols)} available features")

# ==================== 3. Training Class ====================
print("\nStep 3: Defining training class...")

class EcoregionModelTrainer:
    """Train ecoregion-specific (ECF-TST) Random Forests and a global baseline.

    Typical use:

    1. ``train_ecf_tst_models`` fits one ``RandomForestClassifier`` per
       ecoregion (column ``'ID'``) with at least ``min_samples`` rows, using a
       stratified train/test split (on ``'b1'``) inside each ecoregion.
    2. ``train_global_model`` fits one ``RandomForestClassifier`` on pooled
       data and scores it on exactly the same per-ecoregion test rows.
    3. ``create_table1_comparison`` / ``export_supplementary_tables`` report
       the comparison.

    Leakage control: every scaler and imputation median is fitted on training
    rows only, and the global model never trains on a row that an ecoregion
    model uses for testing. Rows are tracked through ``df``'s index labels, so
    ``df`` must have a unique index (a freshly loaded CSV does) and the same
    ``df`` must be passed to both training methods.
    """

    def __init__(self, test_size=0.2, random_state=42, n_jobs=-1):
        """Store split/parallelism settings and initialise empty result stores.

        Args:
            test_size: Fraction of each ecoregion's rows held out for testing.
            random_state: Seed for the train/test splits and the Random Forests.
            n_jobs: Passed to ``RandomForestClassifier`` (-1 = all cores).
        """
        self.test_size = test_size
        self.random_state = random_state
        self.n_jobs = n_jobs
        self.models = {}          # ecoregion ID -> fitted ECF-TST RandomForestClassifier
        self.results = {}         # ecoregion ID -> result dict from train_ecoregion_model
        self.summary_df = None    # per-ecoregion ECF-TST metrics (summarize_ecf_tst_results)
        self.global_model = None  # fitted global RandomForestClassifier
        self.global_results = {}  # ecoregion ID -> global-model metrics on that ecoregion's test rows

    # ------------------------------------------------------------------ helpers

    def _new_forest(self):
        """Return an unfitted Random Forest with the hyper-parameters used for every model."""
        return RandomForestClassifier(
            n_estimators=100,
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            max_features='sqrt',
            random_state=self.random_state,
            n_jobs=self.n_jobs,
            verbose=0
        )

    @staticmethod
    def _positive_proba(model, X):
        """Return P(label == 1) per row; all zeros if label 1 never appeared in training.

        ``predict_proba(X)[:, 1]`` would raise IndexError for a single-class model.
        """
        classes = list(model.classes_)
        if 1 not in classes:
            return np.zeros(len(X))
        return model.predict_proba(X)[:, classes.index(1)]

    @staticmethod
    def _fit_preprocess(X_train_raw, X_test_raw):
        """Impute with training medians and standardise with a training-fitted scaler.

        Returns:
            (X_train, X_test, scaler): transformed numpy arrays and the fitted
            ``StandardScaler``. Test rows never influence the fitted statistics.
        """
        missing_cols = X_train_raw.columns[
            X_train_raw.isnull().any() | X_test_raw.isnull().any()
        ].tolist()
        if missing_cols:
            print(f"    Handling missing values in {len(missing_cols)} features...")
        medians = X_train_raw.median()
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train_raw.fillna(medians))
        X_test = scaler.transform(X_test_raw.fillna(medians))
        return X_train, X_test, scaler

    def _holdout_split(self, index, y, test_size):
        """Split index labels into (train, test), stratified by ``y`` when possible.

        Falls back to an unstratified split when a class is too small to
        stratify, and puts everything in train when the rows cannot be split
        at all (e.g. fewer than two rows).
        """
        positions = np.arange(len(index))
        for stratify in (y, None):
            try:
                train_pos, test_pos = train_test_split(
                    positions, test_size=test_size,
                    random_state=self.random_state, stratify=stratify
                )
                return index[train_pos], index[test_pos]
            except ValueError:
                continue
        return index, index[:0]

    def _global_metric(self, key, default=0):
        """Return one global-model metric across all evaluated ecoregions as a float Series."""
        return pd.Series(
            [r['metrics'].get(key, default) for r in self.global_results.values()],
            dtype=float
        )

    # --------------------------------------------------------------- public API

    def prepare_data(self, df_region, feature_cols, target_col='b1'):
        """Return (X_scaled, y, scaler, feature_names) for ``df_region``.

        Missing feature values are filled with the column median and all rows
        are standardised with one scaler fitted on ``df_region`` itself. Kept for
        backward compatibility; the training methods instead fit preprocessing
        on their training rows only.
        """
        X = df_region[feature_cols].copy()
        y = df_region[target_col].copy()

        missing_cols = X.columns[X.isnull().any()].tolist()
        if missing_cols:
            print(f"    Handling missing values in {len(missing_cols)} features...")
            # Assign back: ``X[col].fillna(..., inplace=True)`` is a chained
            # assignment that silently does nothing under pandas copy-on-write.
            X[missing_cols] = X[missing_cols].fillna(X[missing_cols].median())

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        return X_scaled, y, scaler, X.columns.tolist()

    def calculate_comprehensive_metrics(self, y_true, y_pred, y_pred_proba=None):
        """Compute evaluation metrics for the binary task (0 = stable, 1 = disturbance).

        Args:
            y_true: True labels.
            y_pred: Predicted labels.
            y_pred_proba: Predicted probability of label 1, or None.

        Returns:
            (metrics, cm): ``metrics`` holds accuracy, precision, recall,
            f1_score, kappa, mcc, auc_roc, average_precision, producer's/user's
            accuracies (pa_*/ua_*), omission/commission errors (oe_*/ce_*),
            mean_pa, mean_ua, balanced_accuracy and the TN/FP/FN/TP counts.
            ``cm`` is the 2x2 confusion matrix (rows = true, cols = predicted,
            label order [0, 1]).
        """
        metrics = {}

        # Threshold-based scores; label 1 (disturbance) is the positive class.
        metrics['accuracy'] = accuracy_score(y_true, y_pred)
        metrics['precision'] = precision_score(y_true, y_pred, zero_division=0)
        metrics['recall'] = recall_score(y_true, y_pred, zero_division=0)
        metrics['f1_score'] = f1_score(y_true, y_pred, zero_division=0)
        metrics['kappa'] = cohen_kappa_score(y_true, y_pred)
        metrics['mcc'] = matthews_corrcoef(y_true, y_pred)

        # Ranking scores need probabilities and both classes in y_true;
        # otherwise they are undefined and chance level (0.5) is reported.
        if y_pred_proba is not None and len(np.unique(y_true)) > 1:
            metrics['auc_roc'] = roc_auc_score(y_true, y_pred_proba)
            metrics['average_precision'] = average_precision_score(y_true, y_pred_proba)
        else:
            metrics['auc_roc'] = 0.5
            metrics['average_precision'] = 0.5

        # labels=[0, 1] keeps the matrix 2x2 even when only one class occurs in
        # y_true and y_pred, so the per-class metrics below always exist.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        TN, FP, FN, TP = cm.ravel()

        def _ratio(num, den):
            return num / den if den > 0 else 0

        # Producer's Accuracy = 1 - Omission Error (share of each true class recovered)
        metrics['pa_stable'] = _ratio(TN, TN + FP)
        metrics['pa_disturbance'] = _ratio(TP, TP + FN)

        # User's Accuracy = 1 - Commission Error (share of each predicted class that is right)
        metrics['ua_stable'] = _ratio(TN, TN + FN)
        metrics['ua_disturbance'] = _ratio(TP, TP + FP)

        # Omission / commission errors
        metrics['oe_stable'] = 1 - metrics['pa_stable']
        metrics['oe_disturbance'] = 1 - metrics['pa_disturbance']
        metrics['ce_stable'] = 1 - metrics['ua_stable']
        metrics['ce_disturbance'] = 1 - metrics['ua_disturbance']

        # Class-averaged accuracies; balanced accuracy is the mean producer's accuracy
        metrics['mean_pa'] = (metrics['pa_stable'] + metrics['pa_disturbance']) / 2
        metrics['mean_ua'] = (metrics['ua_stable'] + metrics['ua_disturbance']) / 2
        metrics['balanced_accuracy'] = metrics['mean_pa']

        metrics['TN'] = TN
        metrics['FP'] = FP
        metrics['FN'] = FN
        metrics['TP'] = TP

        return metrics, cm

    def train_ecoregion_model(self, df, ecoregion_id, feature_cols):
        """Train and evaluate a Random Forest for one ecoregion.

        Rows with ``df['ID'] == ecoregion_id`` are split (stratified on ``'b1'``)
        into train/test with ``self.test_size``. Imputation and scaling are
        fitted on the training rows only.

        Returns:
            dict with the fitted model and scaler, metrics, confusion matrix,
            sizes, scaled ``X_test``, ``y_test``, predictions, feature names and
            the ``train_index`` / ``test_index`` labels of the rows used; or
            None if training failed (the traceback is printed).
        """
        print(f"\n{'='*70}")
        print(f"ECF-TST Training: Ecoregion {ecoregion_id}")
        print('='*70)

        try:
            df_region = df[df['ID'] == ecoregion_id]

            print(f"  Sample count: {len(df_region)}")
            print(f"  Class distribution: {df_region['b1'].value_counts().to_dict()}")

            X_raw = df_region[feature_cols]
            y = df_region['b1']

            # Split row positions (80/20). train_test_split's partition depends only
            # on n_samples, the stratify labels and the seed, so this is the same
            # split as splitting X and y directly.
            train_pos, test_pos = train_test_split(
                np.arange(len(df_region)), test_size=self.test_size,
                random_state=self.random_state, stratify=y
            )
            X_train, X_test, scaler = self._fit_preprocess(
                X_raw.iloc[train_pos], X_raw.iloc[test_pos]
            )
            y_train, y_test = y.iloc[train_pos], y.iloc[test_pos]

            print(f"  Training set: {X_train.shape}, Test set: {X_test.shape}")

            # Train Random Forest model (no SMOTE, no feature selection for simplicity)
            print("    Training Random Forest model...")
            model = self._new_forest()
            model.fit(X_train, y_train)

            y_pred = model.predict(X_test)
            y_pred_proba = self._positive_proba(model, X_test)

            metrics, cm = self.calculate_comprehensive_metrics(y_test, y_pred, y_pred_proba)

            print(f"\n  Test set performance:")
            print(f"    Accuracy: {metrics['accuracy']:.4f}")
            print(f"    F1 Score: {metrics['f1_score']:.4f}")
            print(f"    AUC-ROC: {metrics['auc_roc']:.4f}")
            print(f"    Producer's Accuracy (stable): {metrics['pa_stable']:.4f}")
            print(f"    Producer's Accuracy (disturbance): {metrics['pa_disturbance']:.4f}")

            return {
                'model': model,
                'scaler': scaler,
                'metrics': metrics,
                'confusion_matrix': cm,
                'test_size': len(y_test),
                'train_size': len(y_train),
                'X_test': X_test,
                'y_test': y_test,
                'y_pred': y_pred,
                'y_pred_proba': y_pred_proba,
                'feature_names': X_raw.columns.tolist(),
                # Index labels of the rows used, so the global model can be kept
                # off these test rows and evaluated on exactly the same ones.
                'train_index': y_train.index,
                'test_index': y_test.index,
            }

        except Exception as e:
            print(f"  Error training ecoregion {ecoregion_id}: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

    def train_ecf_tst_models(self, df, feature_cols, min_samples=1000):
        """Train one ECF-TST model per ecoregion having at least ``min_samples`` rows.

        Ecoregions are processed from largest to smallest. Successful results
        are stored in ``self.models`` / ``self.results``, and a summary table
        is built in ``self.summary_df`` when at least one model trained.

        Returns:
            (self.models, self.results)
        """
        ecoregion_counts = df['ID'].value_counts()  # sorted by count, descending
        eligible_ecoregions = ecoregion_counts[ecoregion_counts >= min_samples].index.tolist()

        print(f"\nEcoregions meeting minimum sample requirement ({min_samples}): {len(eligible_ecoregions)}")
        print(f"Training ALL eligible ecoregions...")

        for i, ecoregion_id in enumerate(eligible_ecoregions, 1):
            print(f"\n{'='*70}")
            print(f"Progress: {i}/{len(eligible_ecoregions)} - Ecoregion: {ecoregion_id}")
            print(f"Sample count: {ecoregion_counts[ecoregion_id]}")
            print('='*70)

            result = self.train_ecoregion_model(df, ecoregion_id, feature_cols)

            if result:
                self.models[ecoregion_id] = result['model']
                self.results[ecoregion_id] = result
                print(f"✓ Ecoregion {ecoregion_id} training completed")
            else:
                print(f"✗ Ecoregion {ecoregion_id} training failed")

        if self.results:
            self.summarize_ecf_tst_results()

        return self.models, self.results

    def train_global_model(self, df, feature_cols, test_size=0.2):
        """Train one pooled Random Forest and evaluate it overall and per ecoregion.

        The global training set is the union of every ECF-TST training set plus
        a ``(1 - test_size)`` share (stratified when possible) of the rows from
        ecoregions without an ECF-TST model; every ECF-TST test row goes to the
        global test set. Both approaches are therefore compared on the same,
        unseen rows, preprocessed with their own training-fitted scaler. If no
        ECF-TST model exists yet, this is a plain stratified split of ``df``.

        Returns:
            (global_model, self.global_results) where ``global_results`` maps
            ecoregion ID -> {'metrics', 'confusion_matrix', 'test_size'}.

        Raises:
            ValueError: if ``df`` has a non-unique index.
        """
        print("\n" + "="*70)
        print("Training Global Model")
        print("="*70)

        if not df.index.is_unique:
            raise ValueError(
                "train_global_model needs a DataFrame with a unique index; call "
                "df.reset_index(drop=True) before training the ECF-TST models."
            )

        X_raw = df[feature_cols]
        y_all = df['b1']

        # Ecoregions with an ECF-TST model keep their exact train/test rows; the
        # remaining ecoregions are split here.
        other_rows = df.index[~df['ID'].isin(list(self.results))]
        other_train, other_test = self._holdout_split(other_rows, y_all.loc[other_rows], test_size)
        train_index = other_train.append([r['train_index'] for r in self.results.values()])
        test_index = other_test.append([r['test_index'] for r in self.results.values()])

        X_train_all, X_test_all, scaler = self._fit_preprocess(
            X_raw.loc[train_index], X_raw.loc[test_index]
        )
        y_train_all = y_all.loc[train_index]
        y_test_all = y_all.loc[test_index]

        print(f"Global training set: {X_train_all.shape}")
        print(f"Global test set: {X_test_all.shape}")
        print(f"ECF-TST test rows excluded from global training: {len(test_index) - len(other_test)}")
        # astype(int): np.bincount rejects float labels ('b1' is float if it held NaNs on load)
        print(f"Class distribution in training: {np.bincount(y_train_all.astype(int))}")

        print("\nTraining Global Random Forest model...")
        global_model = self._new_forest()
        global_model.fit(X_train_all, y_train_all)

        # Predict every global test row once (with the global preprocessing);
        # per-ecoregion evaluation below selects its rows from these predictions.
        y_pred_all = pd.Series(global_model.predict(X_test_all), index=test_index)
        y_proba_all = pd.Series(self._positive_proba(global_model, X_test_all), index=test_index)

        global_metrics, global_cm = self.calculate_comprehensive_metrics(
            y_test_all, y_pred_all.to_numpy(), y_proba_all.to_numpy()
        )

        print(f"\nGlobal model performance on global test set:")
        print(f"  Accuracy: {global_metrics['accuracy']:.4f}")
        print(f"  F1 Score: {global_metrics['f1_score']:.4f}")
        print(f"  AUC-ROC: {global_metrics['auc_roc']:.4f}")

        print("\n" + "="*70)
        print("Evaluating Global Model on Each Ecoregion's Test Set")
        print("="*70)

        for ecoregion_id, result in self.results.items():
            eco_index = result['test_index']
            y_test_eco = y_all.loc[eco_index]

            eco_metrics, eco_cm = self.calculate_comprehensive_metrics(
                y_test_eco,
                y_pred_all.loc[eco_index].to_numpy(),
                y_proba_all.loc[eco_index].to_numpy()
            )

            self.global_results[ecoregion_id] = {
                'metrics': eco_metrics,
                'confusion_matrix': eco_cm,
                'test_size': len(y_test_eco)
            }

            print(f"\nEcoregion {ecoregion_id}:")
            print(f"  Accuracy: {eco_metrics['accuracy']:.4f}")
            print(f"  F1 Score: {eco_metrics['f1_score']:.4f}")

        # Average across ecoregions; pandas std (ddof=1) matches summarize_ecf_tst_results.
        if self.global_results:
            accuracies = self._global_metric('accuracy')
            f1_scores = self._global_metric('f1_score')
            auc_scores = self._global_metric('auc_roc')

            print(f"\nGlobal model average performance across ecoregions:")
            print(f"  Mean Accuracy: {accuracies.mean():.4f} (±{accuracies.std():.4f})")
            print(f"  Mean F1 Score: {f1_scores.mean():.4f} (±{f1_scores.std():.4f})")
            print(f"  Mean AUC-ROC: {auc_scores.mean():.4f} (±{auc_scores.std():.4f})")

        self.global_model = global_model
        return global_model, self.global_results

    def summarize_ecf_tst_results(self):
        """Build and print ``self.summary_df``: one row of ECF-TST metrics per ecoregion.

        Standard deviations use pandas' default sample estimator (ddof=1).

        Returns:
            The summary DataFrame (also stored in ``self.summary_df``).
        """
        print("\n" + "="*70)
        print("Summary of ECF-TST Model Results")
        print("="*70)

        summary_data = []
        for ecoregion_id, result in self.results.items():
            metrics = result['metrics']
            summary_data.append({
                'EcoregionID': ecoregion_id,
                'TrainSamples': result['train_size'],
                'TestSamples': result['test_size'],
                'Accuracy': metrics['accuracy'],
                'F1Score': metrics['f1_score'],
                'AUC_ROC': metrics['auc_roc'],
                'Kappa': metrics['kappa'],
                'MCC': metrics['mcc'],
                'PA_Stable': metrics.get('pa_stable', 0),
                'PA_Disturbance': metrics.get('pa_disturbance', 0),
                'UA_Stable': metrics.get('ua_stable', 0),
                'UA_Disturbance': metrics.get('ua_disturbance', 0),
                'OE_Stable': metrics.get('oe_stable', 0),
                'OE_Disturbance': metrics.get('oe_disturbance', 0),
                'CE_Stable': metrics.get('ce_stable', 0),
                'CE_Disturbance': metrics.get('ce_disturbance', 0),
                'Balanced_Accuracy': metrics.get('balanced_accuracy', 0),
                'TN': metrics.get('TN', 0),
                'FP': metrics.get('FP', 0),
                'FN': metrics.get('FN', 0),
                'TP': metrics.get('TP', 0)
            })

        self.summary_df = pd.DataFrame(summary_data)

        print("\nPerformance by Ecoregion:")
        print(self.summary_df[['EcoregionID', 'Accuracy', 'F1Score', 'AUC_ROC', 'Kappa', 'Balanced_Accuracy']].to_string(index=False))

        print(f"\nStatistical Analysis:")
        print(f"Mean Accuracy: {self.summary_df['Accuracy'].mean():.4f} (±{self.summary_df['Accuracy'].std():.4f})")
        print(f"Mean F1 Score: {self.summary_df['F1Score'].mean():.4f} (±{self.summary_df['F1Score'].std():.4f})")
        print(f"Mean AUC-ROC: {self.summary_df['AUC_ROC'].mean():.4f} (±{self.summary_df['AUC_ROC'].std():.4f})")
        print(f"Mean Kappa: {self.summary_df['Kappa'].mean():.4f} (±{self.summary_df['Kappa'].std():.4f})")
        print(f"Mean Balanced Accuracy: {self.summary_df['Balanced_Accuracy'].mean():.4f} (±{self.summary_df['Balanced_Accuracy'].std():.4f})")

        return self.summary_df

    def create_table1_comparison(self):
        """Build Table 1: mean (± sample std) of each metric across ecoregions, per model.

        Both models use the same spread estimator (pandas ``std``, ddof=1) so
        the columns are comparable. 'Improvement' is ECF-TST minus Global,
        formatted with an explicit sign (negative when the global model wins).

        Returns:
            The Table 1 DataFrame, or None when either model has no results yet.
        """
        print("\n" + "="*70)
        print("Table 1: Performance Comparison")
        print("="*70)

        if self.summary_df is None or not self.global_results:
            print("Both ECF-TST and global-model results are required; run training first.")
            return None

        # (row label, summary_df column, metrics-dict key)
        rows = [
            ('Overall Accuracy (OA)', 'Accuracy', 'accuracy'),
            ('F1 Score', 'F1Score', 'f1_score'),
            ("Cohen's Kappa", 'Kappa', 'kappa'),
            ('Balanced Accuracy (Producer\'s)', 'Balanced_Accuracy', 'balanced_accuracy'),
        ]

        ecf_cells, global_cells, improvement_cells = [], [], []
        for _, summary_col, metric_key in rows:
            ecf_values = self.summary_df[summary_col]
            global_values = self._global_metric(metric_key, default=0.5)
            ecf_cells.append(f'{ecf_values.mean():.3f} (±{ecf_values.std():.3f})')
            global_cells.append(f'{global_values.mean():.3f} (±{global_values.std():.3f})')
            improvement_cells.append(f'{ecf_values.mean() - global_values.mean():+.3f}')

        table1 = pd.DataFrame({
            'Metric': [label for label, _, _ in rows],
            'ECF-TST Model': ecf_cells,
            'Global Model': global_cells,
            'Improvement': improvement_cells,
        })

        print("\nTable 1: Performance comparison between ECF-TST and Global models")
        print(table1.to_string(index=False))

        return table1

    def export_supplementary_tables(self, output_path='/content/drive/MyDrive/ECF_TST_Results'):
        """Write Supplementary Tables S3-S6 as UTF-8 (with BOM) CSV files.

        S3/S5 are built from ``self.summary_df`` (ECF-TST) and S4/S6 from
        ``self.global_results`` (global model); each table is skipped when its
        source is empty. ``output_path`` is created if it does not exist.
        """
        os.makedirs(output_path, exist_ok=True)

        print(f"\nExporting supplementary tables to: {output_path}")

        # 1. Supplementary Table S3: ECF-TST confusion-matrix summary (per ecoregion)
        if self.summary_df is not None:
            table_s3 = self.summary_df[['EcoregionID', 'TN', 'FP', 'FN', 'TP',
                                       'PA_Stable', 'PA_Disturbance',
                                       'UA_Stable', 'UA_Disturbance',
                                       'OE_Stable', 'OE_Disturbance',
                                       'CE_Stable', 'CE_Disturbance']].copy()
            table_s3_path = f'{output_path}/Supplementary_Table_S3_ECF_TST_Confusion_Matrix.csv'
            table_s3.to_csv(table_s3_path, index=False, encoding='utf-8-sig')
            print(f"✓ Supplementary Table S3 saved: {table_s3_path}")

        # 2. Supplementary Table S4: global-model confusion-matrix summary (per ecoregion)
        if self.global_results:
            global_confusion_data = []
            for ecoregion_id, result in self.global_results.items():
                metrics = result['metrics']
                cm = result['confusion_matrix']

                if cm.shape == (2, 2):
                    TN, FP, FN, TP = cm.ravel()
                else:
                    TN, FP, FN, TP = 0, 0, 0, 0

                global_confusion_data.append({
                    'EcoregionID': ecoregion_id,
                    'TN': TN,
                    'FP': FP,
                    'FN': FN,
                    'TP': TP,
                    'PA_Stable': metrics.get('pa_stable', 0),
                    'PA_Disturbance': metrics.get('pa_disturbance', 0),
                    'UA_Stable': metrics.get('ua_stable', 0),
                    'UA_Disturbance': metrics.get('ua_disturbance', 0),
                    'OE_Stable': metrics.get('oe_stable', 0),
                    'OE_Disturbance': metrics.get('oe_disturbance', 0),
                    'CE_Stable': metrics.get('ce_stable', 0),
                    'CE_Disturbance': metrics.get('ce_disturbance', 0)
                })

            table_s4 = pd.DataFrame(global_confusion_data)
            table_s4_path = f'{output_path}/Supplementary_Table_S4_Global_Confusion_Matrix.csv'
            table_s4.to_csv(table_s4_path, index=False, encoding='utf-8-sig')
            print(f"✓ Supplementary Table S4 saved: {table_s4_path}")

        # 3. Supplementary Table S5: detailed accuracy metrics per ecoregion (ECF-TST model)
        if self.summary_df is not None:
            table_s5 = self.summary_df[['EcoregionID', 'Accuracy', 'F1Score', 'AUC_ROC', 'Kappa', 'MCC',
                                       'PA_Stable', 'PA_Disturbance', 'UA_Stable', 'UA_Disturbance',
                                       'Balanced_Accuracy', 'TestSamples']].copy()
            table_s5_path = f'{output_path}/Supplementary_Table_S5_ECF_TST_Detailed_Metrics.csv'
            table_s5.to_csv(table_s5_path, index=False, encoding='utf-8-sig')
            print(f"✓ Supplementary Table S5 saved: {table_s5_path}")

        # 4. Supplementary Table S6: detailed accuracy metrics per ecoregion (global model)
        if self.global_results:
            global_metrics_data = []
            for ecoregion_id, result in self.global_results.items():
                metrics = result['metrics']
                global_metrics_data.append({
                    'EcoregionID': ecoregion_id,
                    'Accuracy': metrics['accuracy'],
                    'F1Score': metrics['f1_score'],
                    'AUC_ROC': metrics['auc_roc'],
                    'Kappa': metrics['kappa'],
                    'MCC': metrics['mcc'],
                    'PA_Stable': metrics.get('pa_stable', 0),
                    'PA_Disturbance': metrics.get('pa_disturbance', 0),
                    'UA_Stable': metrics.get('ua_stable', 0),
                    'UA_Disturbance': metrics.get('ua_disturbance', 0),
                    'Balanced_Accuracy': metrics.get('balanced_accuracy', 0),
                    'TestSamples': result['test_size']
                })

            table_s6 = pd.DataFrame(global_metrics_data)
            table_s6_path = f'{output_path}/Supplementary_Table_S6_Global_Detailed_Metrics.csv'
            table_s6.to_csv(table_s6_path, index=False, encoding='utf-8-sig')
            print(f"✓ Supplementary Table S6 saved: {table_s6_path}")

        print(f"\nAll supplementary tables exported successfully!")

# ==================== 4. Execute Training ====================
print("\nStep 4: Executing complete training...")

# Create trainer: 80% training / 20% testing inside every ecoregion
trainer = EcoregionModelTrainer(
    test_size=0.2,
    random_state=42,
    n_jobs=-1
)

# Train ECF-TST models (separate model for each ecoregion)
print("\n" + "="*80)
print("Training ECF-TST Models (Ecoregion-specific)")
print("="*80)

ecf_tst_models, ecf_tst_results = trainer.train_ecf_tst_models(
    df,
    feature_cols,
    min_samples=1000  # Minimum samples required for training
)

# Train Global model and evaluate it on each ecoregion's ECF-TST test rows
print("\n" + "="*80)
print("Training and Evaluating Global Model")
print("="*80)

global_model, global_results = trainer.train_global_model(
    df,
    feature_cols,
    test_size=0.2
)

# ==================== 5. Generate Final Results ====================
print("\n" + "="*80)
print("Final Results Summary")
print("="*80)

# Create Table 1
table1 = trainer.create_table1_comparison()

# Export supplementary tables
output_path = '/content/drive/MyDrive/ECF_TST_Results_Revised'
trainer.export_supplementary_tables(output_path)

# Summary statistics
print("\n" + "="*80)
print("Key Findings Summary")
print("="*80)

if trainer.summary_df is not None and trainer.global_results:
    # ECF-TST performance (pandas std: sample std, ddof=1)
    ecf_acc_mean = trainer.summary_df['Accuracy'].mean()
    ecf_acc_std = trainer.summary_df['Accuracy'].std()
    ecf_f1_mean = trainer.summary_df['F1Score'].mean()
    ecf_f1_std = trainer.summary_df['F1Score'].std()

    # Global model performance. Use pandas Series so the std estimator (ddof=1)
    # matches the ECF-TST numbers above; np.std defaults to ddof=0.
    global_acc_values = pd.Series([r['metrics']['accuracy'] for r in trainer.global_results.values()], dtype=float)
    global_f1_values = pd.Series([r['metrics']['f1_score'] for r in trainer.global_results.values()], dtype=float)

    global_acc_mean = global_acc_values.mean()
    global_acc_std = global_acc_values.std()
    global_f1_mean = global_f1_values.mean()
    global_f1_std = global_f1_values.std()

    print(f"\nECF-TST Framework Performance:")
    print(f"  • Mean Accuracy: {ecf_acc_mean:.3f} (±{ecf_acc_std:.3f})")
    print(f"  • Mean F1 Score: {ecf_f1_mean:.3f} (±{ecf_f1_std:.3f})")
    print(f"  • Trained Ecoregions: {len(trainer.results)}")

    print(f"\nGlobal Model Performance:")
    print(f"  • Mean Accuracy: {global_acc_mean:.3f} (±{global_acc_std:.3f})")
    print(f"  • Mean F1 Score: {global_f1_mean:.3f} (±{global_f1_std:.3f})")

    # Differences are printed with an explicit sign: a hard-coded '+' used to
    # render a negative difference as '+-0.012'.
    acc_diff = ecf_acc_mean - global_acc_mean
    f1_diff = ecf_f1_mean - global_f1_mean

    print(f"\nECF-TST Improvement over Global Model:")
    print(f"  • Accuracy Improvement: {acc_diff:+.3f}")
    print(f"  • F1 Score Improvement: {f1_diff:+.3f}")

    # Relative improvement; undefined (NaN) when the global mean is zero
    acc_improvement_pct = acc_diff / global_acc_mean * 100 if global_acc_mean else float('nan')
    f1_improvement_pct = f1_diff / global_f1_mean * 100 if global_f1_mean else float('nan')

    print(f"  • Accuracy Improvement (%): {acc_improvement_pct:+.1f}%")
    print(f"  • F1 Score Improvement (%): {f1_improvement_pct:+.1f}%")

print("\n" + "="*80)
print("Training Completed Successfully!")
print("="*80)
print("\nGenerated output files:")
print(f"1. Table 1 (printed above)")
print(f"2. Supplementary Table S3: ECF-TST Confusion Matrix Summary")
print(f"3. Supplementary Table S4: Global Model Confusion Matrix Summary")
print(f"4. Supplementary Table S5: ECF-TST Detailed Metrics")
print(f"5. Supplementary Table S6: Global Model Detailed Metrics")
print(f"\nAll files saved to: {output_path}")