In [1]:
# -*- coding: utf-8 -*-
"""
EBM Training, Evaluation, and Interpretation for Coastal Wave Forecasting

This script executes Phases 2-4 of the research plan for predicting
significant wave height (buoy_main_hs) using an Explainable Boosting Machine (EBM).

Phase 2: Hyperparameter Optimization with Nested, Blocked Cross-Validation.
Phase 3: Final Model Training, Saving, and Out-of-Sample (OOS) Evaluation.
Phase 4: Interpretation and Visualization for Publication.

Environment: Google Colab
Libraries: pandas, numpy, scikit-learn, interpret, optuna, pickle, matplotlib, seaborn
"""

# =============================================================================
# Step 1: Setup and Data Preparation
# =============================================================================
print("--- Step 1: Setup and Data Preparation ---")

# IPython/Jupyter/Colab shell escape (NOT plain Python): installs optuna and
# interpret so the imports below succeed. This line is a syntax error if the
# file is executed as an ordinary .py script.
!pip install optuna interpret

# --- 1.1. Imports and Setup ---
import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import optuna

from google.colab import drive
from sklearn.model_selection import cross_val_score
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection._split import BaseCrossValidator
from interpret.glassbox import ExplainableBoostingRegressor
from interpret import show

# Optuna logs every finished trial at INFO level; keep only warnings and errors.
optuna.logging.set_verbosity(optuna.logging.WARNING)

# --- 1.2. Mount Google Drive ---
# A failed mount is only reported; execution continues regardless.
print("Mounting Google Drive...")
try:
    drive.mount('/content/drive', force_remount=True)
except Exception as e:
    print(f"Error mounting Google Drive: {e}")
else:
    print("Google Drive mounted successfully.")


# --- 1.3. Plotting Style Configuration ---
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'figure.dpi': 450,
})

# --- 1.4. Paths and Directories ---
PROJECT_ROOT = '/content/drive/My Drive/Paper_3_New'
INPUT_FILE_PATH = os.path.join(PROJECT_ROOT, 'Outputs/Feature_Engineering_v1/final_engineered_features_v3.csv')
OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'Outputs/Modeling_v1/EBM')

print(f"Output directory will be: {OUTPUT_DIR}")
# exist_ok=True so re-running the notebook does not fail on an existing folder.
os.makedirs(OUTPUT_DIR, exist_ok=True)


# --- 1.5. Data Loading and Preparation ---
def load_and_prepare_data(filepath):
    """
    Load the engineered feature table and split it into Train_Val and OOS sets.

    The CSV must contain a 'time' column (parsed to datetime and used as the
    index) and a 'split' column whose values 'Train_Val' / 'OOS' select the
    two partitions. The 'split' column is dropped from both outputs. Rows are
    sorted chronologically because the blocked time-series CV used later
    splits samples by position, not by timestamp.

    Parameters
    ----------
    filepath : str or path-like
        Path to the engineered-features CSV.

    Returns
    -------
    tuple
        ``(df_train_val, df_oos)`` as pandas DataFrames, or ``(None, None)``
        when the file does not exist.
    """
    print(f"Attempting to load data from: {filepath}")
    try:
        df = pd.read_csv(filepath)
    except FileNotFoundError:
        print("\nCRITICAL ERROR: Input file not found at the specified path.")
        print("Please ensure the feature engineering script has run successfully.")
        return None, None

    df['time'] = pd.to_datetime(df['time'])
    # Enforce time order instead of trusting the row order on disk; a stable
    # sort keeps the original relative order of any duplicate timestamps.
    df = df.set_index('time').sort_index(kind='mergesort')

    # 'Train_Val' / 'OOS' are the labels written by the feature-engineering step.
    df_train_val = df[df['split'] == 'Train_Val'].drop(columns=['split'])
    df_oos = df[df['split'] == 'OOS'].drop(columns=['split'])

    print("Data loaded successfully.")
    print(f"Training/Validation set shape: {df_train_val.shape}")
    print(f"Out-of-Sample (OOS) set shape: {df_oos.shape}")

    return df_train_val, df_oos

df_train_val, df_oos = load_and_prepare_data(INPUT_FILE_PATH)

# --- 1.6. Feature and Target Separation and Main Execution Block ---
# Everything below runs only if both splits were loaded and Train_Val has rows.
if not (df_train_val is None or df_oos is None or df_train_val.empty):
    TARGET = 'buoy_main_hs'

    # Features = every column EXCEPT those prefixed 'buoy_main_'. This removes
    # the target and any sibling 'buoy_main_*' columns (presumably same-time
    # buoy observations that would leak the answer — confirm against the
    # feature-engineering script).
    feature_cols = list(filter(lambda name: not name.startswith('buoy_main_'), df_train_val.columns))

    print(f"Target variable: {TARGET}")
    print(f"Number of features: {len(feature_cols)}")

    X_train_val, y_train_val = df_train_val[feature_cols], df_train_val[TARGET]
    X_oos, y_oos = df_oos[feature_cols], df_oos[TARGET]

    # =============================================================================
    # Step 2: Implement the Nested, Blocked Cross-Validation Framework
    # =============================================================================
    print("\n--- Step 2: Implementing Blocked Time Series CV ---")

    class BlockedTimeSeriesSplit(BaseCrossValidator):
        """
        Blocked time-series cross-validator with a gap between train and test.

        Samples are cut, by position, into ``n_splits`` contiguous and
        non-overlapping blocks of ``len(X) // n_splits`` rows. The last block
        also absorbs any remainder rows. Each block is divided in three parts:

        - the first ``train_fraction`` of the block size is used for training;
        - the next ``gap`` rows are discarded;
        - the rest of the block is the test set.

        Because the split is positional, rows must already be in chronological
        order.

        Parameters
        ----------
        n_splits : int, default=5
            Number of blocks (= number of folds); must be >= 1.
        gap : int, default=24
            Rows discarded between the train and test part of each block; >= 0.
        train_fraction : float, default=0.8
            Share of each block used for training; must lie in (0, 1).

        Raises
        ------
        ValueError
            On invalid parameters. Also raised from ``split`` when a block would
            yield an empty train or test set (too few rows for the requested
            ``n_splits``/``gap``), instead of silently producing an empty fold.
        """

        def __init__(self, n_splits=5, gap=24, train_fraction=0.8):
            if n_splits < 1:
                raise ValueError(f"n_splits must be >= 1, got {n_splits}.")
            if gap < 0:
                raise ValueError(f"gap must be >= 0, got {gap}.")
            if not 0.0 < train_fraction < 1.0:
                raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction}.")
            self.n_splits = n_splits
            self.gap = gap
            self.train_fraction = train_fraction

        def get_n_splits(self, X=None, y=None, groups=None):
            """Return the number of folds (independent of the data)."""
            return self.n_splits

        def split(self, X, y=None, groups=None):
            """Yield ``(train_indices, test_indices)`` positional index arrays, one pair per block."""
            n_samples = len(X)
            block_size = n_samples // self.n_splits
            indices = np.arange(n_samples)

            for i in range(self.n_splits):
                start = i * block_size
                # The last block runs to the end so no trailing rows are lost.
                stop = n_samples if i == self.n_splits - 1 else start + block_size
                mid = start + int(self.train_fraction * block_size)

                train_indices = indices[start:mid]
                test_indices = indices[mid + self.gap:stop]
                if train_indices.size == 0 or test_indices.size == 0:
                    raise ValueError(
                        f"Fold {i + 1}/{self.n_splits} would be empty "
                        f"(train={train_indices.size}, test={test_indices.size}) "
                        f"for {n_samples} samples with gap={self.gap}; "
                        "use fewer splits or a smaller gap."
                    )
                yield train_indices, test_indices

    # Outer splitter: estimates generalisation error over 5 blocks.
    # Inner splitter: scores Optuna trials inside each outer training segment (3 blocks).
    # Both discard 24 rows between every train and test segment.
    cv_gap = 24
    outer_cv = BlockedTimeSeriesSplit(n_splits=5, gap=cv_gap)
    inner_cv = BlockedTimeSeriesSplit(n_splits=3, gap=cv_gap)
    print("BlockedTimeSeriesSplit class defined and CV splitters instantiated.")


    # =============================================================================
    # Step 3: Hyperparameter Optimization with Nested CV
    # =============================================================================
    for _banner_line in ("\n--- Step 3: Hyperparameter Optimization with Nested CV ---",
                         "This step will take a significant amount of time."):
        print(_banner_line)

    def objective(trial, X, y, cv_splitter, n_jobs=1):
        """
        Optuna objective: mean cross-validated score of one EBM configuration.

        Parameters
        ----------
        trial : optuna.trial.Trial
            Trial used to sample the EBM hyperparameters below.
        X, y : pandas DataFrame / Series
            Training features and target to cross-validate on.
        cv_splitter : cross-validator
            Splitter passed to ``cross_val_score`` (e.g. BlockedTimeSeriesSplit).
        n_jobs : int, default=1
            Parallelism of ``cross_val_score`` across folds. The default is 1
            because each EBM already trains with ``n_jobs=-1`` (all cores).
            Parallelising folds on top of that multiplies the worker count far
            beyond the CPU count.

        Returns
        -------
        float
            Mean negative RMSE over the folds. Higher is better, so the study
            must use ``direction='maximize'``.
        """
        params = {
            'interactions': trial.suggest_int('interactions', 0, 15),
            'max_bins': trial.suggest_categorical('max_bins', [128, 256, 512]),
            'learning_rate': trial.suggest_float('learning_rate', 1e-3, 1e-1, log=True),
            'outer_bags': trial.suggest_int('outer_bags', 8, 16),
            'inner_bags': trial.suggest_int('inner_bags', 0, 8),
            # The EBM is the single level that uses every core.
            'n_jobs': -1,
            'random_state': 42
        }
        ebm = ExplainableBoostingRegressor(**params)
        scores = cross_val_score(
            ebm, X, y,
            scoring='neg_root_mean_squared_error',
            cv=cv_splitter,
            n_jobs=n_jobs,
        )
        return float(np.mean(scores))

    # Outer loop of the nested CV. For every outer block:
    #   1. tune on its training segment with the inner splitter;
    #   2. refit with the best parameters;
    #   3. score on the held-back outer test segment.
    # The mean / (population) std of these outer RMSEs is the reported estimate.
    outer_fold_scores = []
    n_outer_folds = outer_cv.get_n_splits()
    outer_splits = outer_cv.split(X_train_val, y_train_val)
    for fold_counter, (train_idx, val_idx) in enumerate(outer_splits, start=1):
        print(f"\n--- Starting Outer Fold {fold_counter}/{n_outer_folds} ---")
        X_train_outer, y_train_outer = X_train_val.iloc[train_idx], y_train_val.iloc[train_idx]
        X_val_outer, y_val_outer = X_train_val.iloc[val_idx], y_train_val.iloc[val_idx]

        # Trials run one at a time (n_jobs=1). Every trial already fits EBMs
        # with n_jobs=-1, so parallel Optuna threads on top of that oversubscribe
        # the CPU massively. TPE also learns less from trials that run
        # concurrently. Seeding the sampler makes the search reproducible.
        study = optuna.create_study(
            direction='maximize',
            sampler=optuna.samplers.TPESampler(seed=42),
        )
        # Default arguments bind this fold's data explicitly (no late binding).
        study.optimize(
            lambda trial, X=X_train_outer, y=y_train_outer: objective(trial, X, y, inner_cv),
            n_trials=50,
            n_jobs=1,
        )

        best_params = study.best_params
        print(f"Best params for fold {fold_counter}: {best_params}")
        print(f"Best score for fold {fold_counter}: {-study.best_value:.4f} (RMSE)")

        # Refit on the full outer training segment; score on rows the search never saw.
        ebm_fold = ExplainableBoostingRegressor(**best_params, n_jobs=-1, random_state=42)
        ebm_fold.fit(X_train_outer, y_train_outer)
        preds = ebm_fold.predict(X_val_outer)
        rmse_score = np.sqrt(mean_squared_error(y_val_outer, preds))
        outer_fold_scores.append(rmse_score)
        print(f"RMSE on outer validation set for fold {fold_counter}: {rmse_score:.4f}")

    mean_rmse = np.mean(outer_fold_scores)
    std_rmse = np.std(outer_fold_scores)
    print("\n--- Nested Cross-Validation Results ---")
    print("Unbiased EBM Performance Estimate (RMSE)")
    print(f"Mean: {mean_rmse:.4f}")
    print(f"Standard Deviation: {std_rmse:.4f}")


    # =============================================================================
    # Step 4: Final Model Training, Saving, and OOS Evaluation
    # =============================================================================
    print("\n--- Step 4: Final Model Training and OOS Evaluation ---")

    print("Running final, extensive hyperparameter search on all training data...")
    # Sequential trials (n_jobs=1): each trial already trains EBMs on all cores,
    # so parallel trial threads would only oversubscribe the CPU. The seeded
    # sampler makes the chosen hyperparameters reproducible.
    final_study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=42),
    )
    final_study.optimize(
        lambda trial: objective(trial, X_train_val, y_train_val, inner_cv),
        n_trials=100,
        n_jobs=1,
    )
    final_best_params = final_study.best_params
    print(f"Absolute best hyperparameters found: {final_best_params}")
    print(f"Best CV score (RMSE): {-final_study.best_value:.4f}")

    print("Training final model on the entire X_train_val dataset...")
    final_ebm = ExplainableBoostingRegressor(**final_best_params, n_jobs=-1, random_state=42)
    final_ebm.fit(X_train_val, y_train_val)
    print("Final model training complete.")

    model_path = os.path.join(OUTPUT_DIR, 'ebm_final_model_hs.pkl')
    with open(model_path, 'wb') as f:
        pickle.dump(final_ebm, f)
    print(f"Final model saved to: {model_path}")

    print("Evaluating final model on the held-out Out-of-Sample (OOS) data...")
    # Reload from disk so the OOS evaluation exercises the saved artifact.
    # Unpickling is acceptable only because this run wrote the file itself;
    # never unpickle model files from untrusted sources.
    with open(model_path, 'rb') as f:
        loaded_ebm = pickle.load(f)
    oos_preds = loaded_ebm.predict(X_oos)
    oos_rmse = np.sqrt(mean_squared_error(y_oos, oos_preds))
    oos_r2 = r2_score(y_oos, oos_preds)

    def calculate_csi(y_true, y_pred, threshold):
        """
        Critical Success Index (threat score) for exceedances of ``threshold``.

        CSI = hits / (hits + false alarms + misses). An "event" is a value
        strictly greater than ``threshold``, and correct negatives do not enter
        the score. A pair where either value is NaN counts in none of the three
        categories.

        Parameters
        ----------
        y_true, y_pred : array-like
            Observed and predicted values of equal shape. They are compared
            element by element in positional order, so pandas index labels are
            ignored. Two differently-indexed Series would otherwise be aligned
            by label and silently miscounted.
        threshold : float
            Event threshold.

        Returns
        -------
        float
            CSI in [0, 1]; 0.0 when neither series contains an event.

        Raises
        ------
        ValueError
            If ``y_true`` and ``y_pred`` have different shapes.
        """
        observed = np.asarray(y_true, dtype=float)
        forecast = np.asarray(y_pred, dtype=float)
        if observed.shape != forecast.shape:
            raise ValueError(f"Shape mismatch: y_true {observed.shape} vs y_pred {forecast.shape}.")

        hits = np.count_nonzero((observed > threshold) & (forecast > threshold))
        false_alarms = np.count_nonzero((observed <= threshold) & (forecast > threshold))
        misses = np.count_nonzero((observed > threshold) & (forecast <= threshold))
        denominator = hits + false_alarms + misses
        return hits / denominator if denominator else 0.0

    hazard_threshold = 2.5
    oos_csi = calculate_csi(y_oos, oos_preds, hazard_threshold)

    # Build the report once so the console output and the saved file cannot drift apart.
    report_lines = [
        "--- Final OOS Performance Report ---",
        f"RMSE: {oos_rmse:.4f}",
        f"R-squared (R²): {oos_r2:.4f}",
        f"Critical Success Index (CSI) for Hm0 > {hazard_threshold}m: {oos_csi:.4f}",
    ]
    print("\n" + "\n".join(report_lines))

    report_path = os.path.join(OUTPUT_DIR, 'oos_performance_report.txt')
    # Explicit UTF-8: the report contains a non-ASCII character ("²"), and the
    # platform default encoding is not guaranteed to be able to encode it.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(report_lines) + "\n")
    print(f"Performance report saved to: {report_path}")


    # =============================================================================
    # Step 5: Interpretation and Visualization for Publication
    # =============================================================================
    print("\n--- Step 5: Generating Interpretation Plots ---")

    print("Generating global explanation plots...")
    global_exp = loaded_ebm.explain_global(name='EBM_Global')

    # data() with no key returns the overall summary: one importance value per
    # model term (main effects and any pairwise interaction terms).
    importance_data = global_exp.data()
    feature_importances = pd.DataFrame({
        'feature': importance_data['names'],
        'importance': importance_data['scores']
    }).sort_values(by='importance', ascending=False)

    # Plain matplotlib bars: seaborn >= 0.13 deprecates `palette` without `hue`.
    # Rows are reversed so the most important term is drawn at the top, and
    # each bar keeps the viridis colour seaborn would have assigned to it.
    top20 = feature_importances.head(20)
    bar_colors = list(sns.color_palette('viridis', len(top20)))
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.barh(
        top20['feature'].iloc[::-1].to_list(),
        top20['importance'].iloc[::-1].to_numpy(),
        color=bar_colors[::-1],
    )
    ax.set_title('Top 20 Feature Importances (Global)', fontsize=16)
    ax.set_xlabel('Mean Absolute Value (Impact on Model Output)', fontsize=12)
    ax.set_ylabel('Feature', fontsize=12)
    fig.tight_layout()
    fig_path = os.path.join(OUTPUT_DIR, 'global_feature_importance_top20.png')
    fig.savefig(fig_path, dpi=450)
    plt.close(fig)
    print(f"Saved feature importance plot to: {fig_path}")

    import re  # local import: only needed to build filesystem-safe file names

    # Response curves only make sense for main-effect terms (1-D score arrays).
    # Pairwise interaction terms have 2-D score tables and get heatmaps instead.
    term_index = {name: idx for idx, name in enumerate(global_exp.feature_names)}
    top_10_features = [
        name for name in feature_importances['feature']
        if np.ndim(global_exp.data(term_index[name])['scores']) == 1
    ][:10]
    print("Generating response curves for top 10 features...")
    for feature_name in top_10_features:
        feature_data = global_exp.data(term_index[feature_name])
        names = list(feature_data['names'])
        scores = np.asarray(feature_data['scores'], dtype=float)
        lower = feature_data.get('lower_bounds')
        upper = feature_data.get('upper_bounds')
        has_bounds = lower is not None and upper is not None

        fig, ax = plt.subplots(figsize=(8, 5))
        if len(names) == len(scores) + 1:
            # Continuous feature: 'names' holds bin EDGES (one more than the
            # scores), so the shape function is piecewise constant per bin.
            edges = np.asarray(names, dtype=float)
            ax.step(edges, np.append(scores, scores[-1]), where='post', color='b')
            if has_bounds:
                lower = np.asarray(lower, dtype=float)
                upper = np.asarray(upper, dtype=float)
                ax.fill_between(
                    edges, np.append(lower, lower[-1]), np.append(upper, upper[-1]),
                    step='post', color='b', alpha=0.2,
                )
        else:
            # Categorical feature: one score per category label.
            positions = np.arange(len(scores))
            ax.bar(positions, scores, color='b', alpha=0.6)
            if has_bounds:
                yerr = np.vstack([
                    scores - np.asarray(lower, dtype=float),
                    np.asarray(upper, dtype=float) - scores,
                ])
                ax.errorbar(positions, scores, yerr=np.clip(yerr, 0, None), fmt='none', ecolor='k', capsize=3)
            if len(names) == len(scores):
                ax.set_xticks(positions)
                ax.set_xticklabels([str(n) for n in names], rotation=45, ha='right')
        ax.set_title(f'Response Curve for: {feature_name}', fontsize=14)
        ax.set_xlabel(f'Value of {feature_name}', fontsize=12)
        # Regression EBM: contributions are additive in target units, not log-odds.
        ax.set_ylabel(f'Contribution to predicted {TARGET}', fontsize=12)
        fig.tight_layout()
        safe_name = re.sub(r'[^\w.-]+', '_', feature_name)
        fig_path = os.path.join(OUTPUT_DIR, f'response_curve_{safe_name}.png')
        fig.savefig(fig_path, dpi=450)
        plt.close(fig)
    print(f"Saved {len(top_10_features)} response curves to: {OUTPUT_DIR}")

    import re  # local import: only needed to build filesystem-safe file names

    print("Generating interaction heatmaps for top 5 interactions...")
    # EBMExplanation.data() is keyed by integer term index; it has no
    # 'interactions' key. Pairwise terms are the ones whose score table is 2-D,
    # taken here in the global-importance order computed above.
    term_index = {name: idx for idx, name in enumerate(global_exp.feature_names)}
    top_interactions = [
        name for name in feature_importances['feature']
        if np.ndim(global_exp.data(term_index[name])['scores']) == 2
    ][:5]

    def _heatmap_ticks(bin_names, n_cells, max_ticks=10):
        """
        Return ``(positions, labels)`` for one heatmap axis with ``n_cells`` cells.

        Continuous features come with ``n_cells + 1`` bin edges, which are
        placed on the cell borders; otherwise one label per cell is assumed.
        Labels are thinned to roughly ``max_ticks`` so the axis stays legible.
        """
        bin_names = list(bin_names)
        if len(bin_names) == n_cells + 1:
            positions = np.arange(n_cells + 1) - 0.5
        else:
            positions = np.arange(len(bin_names))
        labels = [
            f"{v:.2f}" if isinstance(v, (int, float, np.number)) and not isinstance(v, bool) else str(v)
            for v in bin_names
        ]
        step = max(1, int(np.ceil(len(positions) / max_ticks)))
        return positions[::step], labels[::step]

    for i, term_name in enumerate(top_interactions):
        interaction_exp = global_exp.data(term_index[term_name])
        scores = np.asarray(interaction_exp['scores'], dtype=float)
        # EBM names pairwise terms "<first feature> & <second feature>".
        feature_1_name, _, feature_2_name = term_name.partition(' & ')

        # Symmetric colour limits so white always means "no contribution".
        finite_abs = np.abs(scores[np.isfinite(scores)])
        limit = float(finite_abs.max()) if finite_abs.size and finite_abs.max() > 0 else 1.0

        fig, ax = plt.subplots(figsize=(8, 6))
        # Rows of 'scores' follow the first feature ('left_names', y axis);
        # columns follow the second feature ('right_names', x axis).
        cax = ax.imshow(scores, cmap='RdBu', origin='lower', aspect='auto', vmin=-limit, vmax=limit)
        fig.colorbar(cax, label='Interaction Contribution')
        x_pos, x_labels = _heatmap_ticks(interaction_exp['right_names'], scores.shape[1])
        y_pos, y_labels = _heatmap_ticks(interaction_exp['left_names'], scores.shape[0])
        ax.set_xticks(x_pos)
        ax.set_xticklabels(x_labels, rotation=45, ha="right")
        ax.set_yticks(y_pos)
        ax.set_yticklabels(y_labels)
        ax.set_xlabel(feature_2_name)
        ax.set_ylabel(feature_1_name)
        ax.set_title(f'Interaction: {feature_1_name} vs {feature_2_name}')
        fig.tight_layout()
        safe_1 = re.sub(r'[^\w.-]+', '_', feature_1_name)
        safe_2 = re.sub(r'[^\w.-]+', '_', feature_2_name)
        fig_path = os.path.join(OUTPUT_DIR, f'interaction_{i+1}_{safe_1}_vs_{safe_2}.png')
        fig.savefig(fig_path, dpi=450)
        plt.close(fig)

    print("Generating local explanations for the 3 highest wave events in OOS...")
    # Use row positions rather than timestamps: `.loc[[timestamp]]` would
    # return several rows if a timestamp were duplicated in the index.
    peak_positions = y_oos.reset_index(drop=True).nlargest(3).index
    for i, pos in enumerate(peak_positions):
        timestamp = y_oos.index[pos]
        print(f"Generating explanation for storm peak #{i+1} at {timestamp}...")
        instance_to_explain = X_oos.iloc[[pos]]
        local_exp = loaded_ebm.explain_local(instance_to_explain, name=f'EBM_Local_Storm_{i+1}')

        # interpret's show() only renders an interactive view and returns None,
        # so its result cannot be saved. Instead, plot this single row's
        # per-term contributions (data(0)) with matplotlib.
        local_data = local_exp.data(0)
        contributions = pd.Series(local_data['scores'], index=local_data['names'], dtype=float)
        top_terms = contributions.loc[contributions.abs().nlargest(15).index].iloc[::-1]
        predicted = float(loaded_ebm.predict(instance_to_explain)[0])
        observed = float(y_oos.iloc[pos])

        fig, ax = plt.subplots(figsize=(10, 7))
        ax.barh(
            top_terms.index.to_list(),
            top_terms.to_numpy(),
            color=['tab:red' if v < 0 else 'tab:blue' for v in top_terms],
        )
        ax.axvline(0.0, color='k', linewidth=0.8)
        ax.set_xlabel(f'Contribution to predicted {TARGET}')
        ax.set_title(f'Storm peak #{i+1} at {timestamp}: observed {observed:.2f}, predicted {predicted:.2f}')
        fig.tight_layout()
        fig_path = os.path.join(OUTPUT_DIR, f'local_explanation_storm_{i+1}_{timestamp.strftime("%Y%m%d")}.png')
        fig.savefig(fig_path, dpi=450, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved local explanation plot to: {fig_path}")

    print("\n--- All Phases Complete ---")
    print(f"All outputs (model, reports, figures) are saved in: {OUTPUT_DIR}")

else:
    print("\nHalting script because data loading failed or the training dataframe is empty.")
    print("Please check the input file path and ensure the 'split' column contains 'Train_Val' and 'OOS' values.")



--- Step 1: Setup and Data Preparation ---
Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully.
Output directory will be: /content/drive/My Drive/Paper_3_New/Outputs/Modeling_v1/EBM
Attempting to load data from: /content/drive/My Drive/Paper_3_New/Outputs/Feature_Engineering_v1/final_engineered_features_v3.csv
Data loaded successfully.
Training/Validation set shape: (10538, 239)
Out-of-Sample (OOS) set shape: (7932, 239)
Target variable: buoy_main_hs
Number of features: 202

--- Step 2: Implementing Blocked Time Series CV ---
BlockedTimeSeriesSplit class defined and CV splitters instantiated.

--- Step 3: Hyperparameter Optimization with Nested CV ---
This step will take a significant amount of time.

--- Starting Outer Fold 1/5 ---


[W 2025-07-28 16:17:41,276] Trial 11 failed with parameters: {'interactions': 15, 'max_bins': 256, 'learning_rate': 0.0013218114173333237, 'outer_bags': 8, 'inner_bags': 8} because of the following error: KeyboardInterrupt().
joblib.externals.loky.process_executor._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/joblib/externals/loky/process_executor.py", line 490, in _process_worker
    r = call_item()
        ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/joblib/externals/loky/process_executor.py", line 291, in __call__
    return self.fn(*self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/joblib/parallel.py", line 607, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/joblib/parallel.py", line 6

KeyboardInterrupt: 