# Exported from a Jupyter notebook cell (prompt was "In [None]:").
import pandas as pd
import numpy as np
import joblib
import json
from sklearn.model_selection import GridSearchCV, TimeSeriesSplit
from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
import os

# Make sure every output directory exists before any model, prediction or
# result file is written (existing directories are left untouched).
for _output_dir in (
    'models_v2/baseline',
    'models_v2/with_typhoon',
    'models_v2/with_typhoon_enhanced',
    'predictions_v2',
    'results_v2',
):
    os.makedirs(_output_dir, exist_ok=True)

# ================== Step 1: Typhoon Feature Engineering ==================

# Substring rules for typhoon intensity labels, ordered most-specific first so
# that short tokens ('TS', 'TY') cannot shadow the longer labels that contain
# them ('STS', 'SEVERE TYPHOON').
_TYPHOON_INTENSITY_RULES = (
    (('SUPER TYPHOON', 'STY'), 6),
    (('SEVERE TYPHOON',), 5),
    (('SEVERE TROPICAL STORM', 'STS'), 3),
    (('TROPICAL DEPRESSION', 'TD'), 1),
    (('TROPICAL STORM', 'TS'), 2),
    (('TYPHOON', 'TY'), 4),
)

# Exclusive upper wind-speed bounds of grades 1..5; anything at or above the
# last bound is grade 6.
_WIND_GRADE_UPPER_BOUNDS = (17.2, 24.5, 32.7, 41.5, 51.0)


def _encode_typhoon_intensity(val):
    """Map a raw typhoon intensity label (or number) to a numeric code 0-6."""
    if pd.isna(val) or val == 0 or val == '0':
        return 0

    label = str(val).upper()
    for tokens, code in _TYPHOON_INTENSITY_RULES:
        if any(token in label for token in tokens):
            return code

    # No known label: accept values that are already numeric, else treat as 0.
    try:
        return float(val)
    except (TypeError, ValueError):
        return 0


def _classify_wind_speed(speed):
    """Return the wind force grade (0-6) for a wind speed; 0 or missing -> 0."""
    if pd.isna(speed) or speed == 0:
        return 0
    for grade, upper_bound in enumerate(_WIND_GRADE_UPPER_BOUNDS, start=1):
        if speed < upper_bound:
            return grade
    return len(_WIND_GRADE_UPPER_BOUNDS) + 1


def create_typhoon_features(df):
    """
    Return a copy of ``df`` extended with engineered typhoon features.

    Required input columns: 'typhoon_grade', 'typhoon_intensity',
    'distance_to_boluo', 'wind_speed', 'air_pressure',
    'days_since_typhoon_start', 'typhoon_longitude', 'typhoon_latitude'
    and 'date'. The input frame itself is not modified.

    Added columns:
    - typhoon_active: 1 where 'typhoon_grade' differs from 0
    - typhoon_intensity_encoded: numeric code of the intensity label
    - typhoon_distance_impact: sigmoid weight of the distance (0.5 at 250)
    - typhoon_comprehensive_impact: wind_speed * typhoon_grade * distance weight
    - wind_force_grade: wind speed bucketed into grades 0-6
    - air_pressure_intensity: (1013.25 - air_pressure) / 50
    - typhoon_duration_days: copy of 'days_since_typhoon_start'
    - typhoon_accumulated_3d / _7d: rolling sums of typhoon_active
    - typhoon_intensity_accumulated_3d / _7d: rolling means of the impact index
    - typhoon_longitude_change / typhoon_latitude_change / typhoon_movement_speed
    - typhoon_relative_longitude / typhoon_relative_latitude
    - month, is_typhoon_season, typhoon_season_intensity

    'date' is converted to datetime in the returned frame. The rolling and
    diff features operate on consecutive rows, so rows are presumably expected
    to be in chronological order - verify against the data file.
    """
    df = df.copy()

    # 1. Typhoon active flag (basic binary feature).
    # NOTE(review): a missing (NaN) grade compares unequal to 0 and is
    # therefore flagged as active - confirm that is intended.
    df['typhoon_active'] = (df['typhoon_grade'] != 0).astype(int)

    # 2. Numerical encoding of the typhoon intensity grade.
    df['typhoon_intensity_encoded'] = df['typhoon_intensity'].apply(_encode_typhoon_intensity)

    # 3. Distance impact weight: a sigmoid that equals 0.5 at a distance of 250
    #    and grows towards 1 as the typhoon gets closer. Non-positive or
    #    missing distances get weight 0.
    distance = df['distance_to_boluo']
    df['typhoon_distance_impact'] = np.where(
        distance > 0,
        1 / (1 + np.exp((distance - 250) / 50)),
        0.0,
    )

    # 4. Comprehensive impact index combining wind speed, grade and distance.
    df['typhoon_comprehensive_impact'] = (
        df['wind_speed'] * df['typhoon_grade'] * df['typhoon_distance_impact']
    )

    # 5. Wind force grade.
    # NOTE(review): the thresholds look like tropical-cyclone grade boundaries
    # in m/s rather than Beaufort numbers - confirm the intended scale.
    df['wind_force_grade'] = df['wind_speed'].apply(_classify_wind_speed)

    # 6. Air pressure intensity: deviation from 1013.25 scaled by 1/50, so a
    #    lower pressure gives a larger value. Non-positive or missing
    #    pressures give 0.
    pressure = df['air_pressure']
    df['air_pressure_intensity'] = np.where(
        pressure > 0,
        (1013.25 - pressure) / 50,
        0.0,
    )

    # 7. Typhoon duration feature.
    df['typhoon_duration_days'] = df['days_since_typhoon_start']

    # 8. Typhoon occurrence accumulation (active rows in a trailing window).
    df['typhoon_accumulated_3d'] = df['typhoon_active'].rolling(window=3, min_periods=1).sum()
    df['typhoon_accumulated_7d'] = df['typhoon_active'].rolling(window=7, min_periods=1).sum()

    # 9. Impact intensity accumulation (mean impact index in a trailing
    #    window), capturing strength and not just presence.
    df['typhoon_intensity_accumulated_3d'] = df['typhoon_comprehensive_impact'].rolling(window=3, min_periods=1).mean()
    df['typhoon_intensity_accumulated_7d'] = df['typhoon_comprehensive_impact'].rolling(window=7, min_periods=1).mean()

    # 10. Typhoon movement features (row-to-row change in position).
    df['typhoon_longitude_change'] = df['typhoon_longitude'].diff().fillna(0)
    df['typhoon_latitude_change'] = df['typhoon_latitude'].diff().fillna(0)
    df['typhoon_movement_speed'] = np.sqrt(df['typhoon_longitude_change']**2 + df['typhoon_latitude_change']**2)

    # 11. Typhoon position relative to Boluo Station.
    boluo_lon = 114.2967
    boluo_lat = 23.15881
    df['typhoon_relative_longitude'] = df['typhoon_longitude'] - boluo_lon
    df['typhoon_relative_latitude'] = df['typhoon_latitude'] - boluo_lat

    # 12. Season-typhoon interaction feature (June-October counts as season).
    df['date'] = pd.to_datetime(df['date'])
    df['month'] = df['date'].dt.month
    df['is_typhoon_season'] = df['month'].isin([6, 7, 8, 9, 10]).astype(int)
    df['typhoon_season_intensity'] = df['typhoon_comprehensive_impact'] * df['is_typhoon_season']

    return df

# ================== Step 2: Create Lagged Features ==================

def create_lagged_features(df, target_column, lag_days=7, feature_type='baseline'):
    """
    Build a supervised dataset from sliding windows of the previous ``lag_days`` rows.

    For every row ``t`` with at least ``lag_days`` predecessors, the feature
    vector is the concatenation of the selected columns of rows
    ``t - lag_days`` ... ``t - 1`` (oldest first) and the label is
    ``target_column`` at row ``t``. Columns are named ``'<col>(T-k)'`` where
    ``k`` is the true lag of the value, so the first columns are
    ``(T-lag_days)`` and the last ones ``(T-1)``.

    NOTE: earlier revisions labelled the same column order ``T-1`` first,
    i.e. with reversed lag numbers. Estimators cached on disk under those
    names may reject the renamed columns and need retraining.

    Args:
        df: input frame; rows are assumed to be in chronological order.
        target_column: name of the column to predict. The target itself and
            'date' are never used as features.
        lag_days: window length in rows (must be >= 1).
        feature_type: 'baseline' (no typhoon columns), 'with_typhoon' (raw
            typhoon columns only) or 'with_typhoon_enhanced' (everything
            except the raw string intensity column).

    Returns:
        (features, labels): a DataFrame with ``len(df) - lag_days`` rows and
        a Series named ``target_column``, both with a fresh 0..n-1 index.

    Raises:
        ValueError: if ``lag_days`` < 1 or ``feature_type`` is unknown.
    """
    if lag_days < 1:
        raise ValueError(f"lag_days must be >= 1, got {lag_days}")

    # Raw typhoon columns present in the input data
    typhoon_base_cols = ['days_since_typhoon_start', 'typhoon_longitude', 'typhoon_latitude',
                         'typhoon_grade', 'typhoon_intensity', 'wind_speed', 'air_pressure', 'distance_to_boluo']

    delete_cols = ['boluo_discharge', 'baipenzhu_outflow', 'boluo', 'boluo_T', 'boluo_p', 'boluo_E', 'andun',
                   'dabeibu', 'pingshan']
    delete_cols2 = ['boluo_discharge', 'baipenzhu_outflow', 'boluo', 'boluo_T', 'boluo_p', 'boluo_E']

    # Engineered columns produced by create_typhoon_features()
    typhoon_enhanced_cols = [
        'typhoon_active',                    # Typhoon active flag
        'typhoon_intensity_encoded',         # Intensity numerical encoding
        'typhoon_distance_impact',           # Sigmoid distance impact
        'typhoon_comprehensive_impact',      # Comprehensive impact index
        'wind_force_grade',                  # Wind force classification
        'air_pressure_intensity',            # Standardized air pressure
        'typhoon_duration_days',             # Duration days
        'typhoon_accumulated_3d',            # Active accumulation
        'typhoon_accumulated_7d',            # Active accumulation
        'typhoon_intensity_accumulated_3d',  # Intensity accumulation
        'typhoon_intensity_accumulated_7d',  # Intensity accumulation
        'typhoon_longitude_change',          # Movement feature
        'typhoon_latitude_change',           # Movement feature
        'typhoon_movement_speed',            # Movement feature
        'typhoon_relative_longitude',        # Relative position
        'typhoon_relative_latitude',         # Relative position
        'month',                             # Season feature
        'is_typhoon_season',                 # Typhoon season flag
        'typhoon_season_intensity'           # Season interaction
    ]

    # NOTE(review): besides the typhoon columns, the scenarios also differ in
    # which other columns they drop ('baseline' drops delete_cols,
    # 'with_typhoon' drops delete_cols2, 'with_typhoon_enhanced' drops
    # neither), so scenario comparisons are not typhoon-only - confirm this
    # is intended.
    exclusions_by_type = {
        # Baseline: no typhoon information at all
        'baseline': typhoon_base_cols + typhoon_enhanced_cols + delete_cols,
        # Raw typhoon columns only; the raw intensity column holds strings
        'with_typhoon': typhoon_enhanced_cols + ['typhoon_intensity'] + delete_cols2,
        # Everything except the raw string intensity column
        'with_typhoon_enhanced': ['typhoon_intensity'],
    }
    if feature_type not in exclusions_by_type:
        raise ValueError(
            f"Unknown feature_type {feature_type!r}; expected one of {sorted(exclusions_by_type)}"
        )

    exclude_cols = {'date', target_column, *exclusions_by_type[feature_type]}
    feature_cols = [col for col in df.columns if col not in exclude_cols]

    # Stack the lag_days row-shifted views side by side: row i of the result is
    # rows i .. i+lag_days-1 of the feature matrix flattened oldest-first.
    values = df[feature_cols].to_numpy()
    n_samples = max(len(df) - lag_days, 0)
    windows = np.hstack([values[offset:offset + n_samples] for offset in range(lag_days)])

    # The first block of columns is the oldest row of the window (T-lag_days),
    # the last block is the most recent one (T-1).
    feature_columns = [
        f'{col}(T-{day})'
        for day in range(lag_days, 0, -1)
        for col in feature_cols
    ]

    features = pd.DataFrame(windows, columns=feature_columns)
    labels = pd.Series(df[target_column].to_numpy()[lag_days:], name=target_column)

    return features, labels

def prepare_data(data_file, lag_days=7):
    """
    Load the CSV and build the three lagged datasets used by the experiment.

    Steps: read ``data_file``, clamp negative values of every numeric column
    to 0, add the engineered typhoon features, then create one lagged dataset
    per scenario with 'boluo_discharge' as the target.

    Args:
        data_file: path of the input CSV file.
        lag_days: number of previous rows used as features for each sample.

    Returns:
        (combined_baseline, combined_typhoon, combined_enhanced): one frame
        per scenario holding the lagged features, the 'boluo_discharge'
        label and the 'date' of the label row.
    """
    # Read data
    print("Reading data...")
    df = pd.read_csv(data_file)

    # Data preprocessing: replace values below 0 with 0 (NaN is left as-is).
    # clip() is the vectorised equivalent of the deprecated element-wise
    # applymap() call used previously.
    numeric_columns = df.select_dtypes(include=['number']).columns
    df[numeric_columns] = df[numeric_columns].clip(lower=0)

    # Create enhanced typhoon features
    print("Creating enhanced typhoon features...")
    df_enhanced = create_typhoon_features(df)

    # (progress message, feature_type) per scenario, in return order
    scenarios = (
        ("Creating baseline dataset (without typhoon features)...", 'baseline'),
        ("Creating dataset with original typhoon features...", 'with_typhoon'),
        ("Creating dataset with enhanced typhoon features...", 'with_typhoon_enhanced'),
    )

    datasets = []
    for message, feature_type in scenarios:
        print(message)
        features, labels = create_lagged_features(
            df_enhanced, 'boluo_discharge', lag_days=lag_days, feature_type=feature_type
        )
        combined = pd.concat([features, labels], axis=1)
        # The first lag_days rows have no complete window, so sample i is
        # dated by row i + lag_days of the source frame.
        combined['date'] = df_enhanced['date'][lag_days:].values
        datasets.append(combined)

    combined_baseline, combined_typhoon, combined_enhanced = datasets
    return combined_baseline, combined_typhoon, combined_enhanced

def split_data(combined_data, train_start='1985-01-01', train_end='2004-12-31',
               test_start='2005-01-01', test_end='2013-12-31'):
    """
    Cut the dataset into a training and a testing period by date.

    Both periods include their start and end dates. The 'date' column of
    ``combined_data`` is converted to datetime in place; the returned frames
    are independent copies.

    Returns:
        (train_data, test_data)
    """
    combined_data['date'] = pd.to_datetime(combined_data['date'])
    stamps = combined_data['date']

    def _period(first_day, last_day):
        # Inclusive on both ends: first_day <= date <= last_day
        in_period = stamps.between(first_day, last_day)
        return combined_data[in_period].copy()

    return _period(train_start, train_end), _period(test_start, test_end)

# ================== Step 3: Flood Event Identification ==================

def identify_flood_events(y_data, dates, threshold_percentile=90):
    """
    Flag flood days as the observations at or above a percentile threshold.

    ``dates`` is accepted but not used. Prints the threshold and the number
    and share of flood days.

    Returns:
        (flood_indices, threshold): positional indices of the flood days and
        the threshold value.
    """
    threshold = np.percentile(y_data, threshold_percentile)
    is_flood = np.asarray(y_data >= threshold)
    flood_indices = np.flatnonzero(is_flood)

    n_flood_days = len(flood_indices)
    print(f"Flood threshold ({threshold_percentile}th percentile): {threshold:.2f} m³/s")
    print(f"Flood event days: {n_flood_days} ({n_flood_days/len(y_data)*100:.2f}%)")

    return flood_indices, threshold

# ================== Step 4: Evaluation Metrics ==================

def nse(y_true, y_pred):
    """Nash-Sutcliffe Efficiency: 1 minus the squared error relative to the variance of the observations."""
    squared_error = np.sum((y_true - y_pred)**2)
    squared_deviation = np.sum((y_true - np.mean(y_true))**2)
    return 1 - (squared_error / squared_deviation)

def kge(y_true, y_pred):
    """Kling-Gupta Efficiency built from correlation, variability ratio and bias ratio."""
    correlation = np.corrcoef(y_true, y_pred)[0, 1]
    variability_ratio = np.std(y_pred) / np.std(y_true)
    bias_ratio = np.mean(y_pred) / np.mean(y_true)

    distance_from_ideal = np.sqrt(
        (correlation - 1)**2 + (variability_ratio - 1)**2 + (bias_ratio - 1)**2
    )
    return 1 - distance_from_ideal

def evaluate_model(y_true, y_pred, flood_indices=None):
    """
    Score predictions overall and, optionally, on the flood days only.

    Args:
        y_true: observed values; indexed with ``.iloc`` when flood indices
            are given.
        y_pred: predicted values; indexed positionally when flood indices
            are given.
        flood_indices: positional indices of the flood days, or None.

    Returns:
        dict with 'RMSE', 'MAE', 'NSE' and 'KGE', plus the same four keys
        with a '_flood' suffix when ``flood_indices`` is non-empty.
    """
    def _score(observed, predicted, suffix=''):
        # Same four metrics for any subset; suffix distinguishes the subset
        return {
            f'RMSE{suffix}': np.sqrt(mean_squared_error(observed, predicted)),
            f'MAE{suffix}': mean_absolute_error(observed, predicted),
            f'NSE{suffix}': nse(observed, predicted),
            f'KGE{suffix}': kge(observed, predicted),
        }

    metrics = _score(y_true, y_pred)

    has_flood_days = flood_indices is not None and len(flood_indices) > 0
    if has_flood_days:
        metrics.update(
            _score(y_true.iloc[flood_indices], y_pred[flood_indices], suffix='_flood')
        )

    return metrics

# ================== Step 5: Model Definition ==================

def get_models_and_params():
    """
    Build the four candidate estimators together with their search grids.

    Returns:
        dict keyed by 'LR', 'ANN', 'RF' and 'XGB'; each value holds the
        estimator ('model'), its hyperparameter grid ('params') and the
        cross-validation splitter ('cv'). All entries share one
        5-split TimeSeriesSplit instance.
    """
    cv_splitter = TimeSeriesSplit(n_splits=5)

    # (key, estimator, hyperparameter grid)
    candidates = [
        (
            'LR',
            LinearRegression(),
            {'fit_intercept': [True, False]},
        ),
        (
            'ANN',
            MLPRegressor(max_iter=2000, random_state=42, early_stopping=True),
            {
                'hidden_layer_sizes': [(64,), (128,), (64, 32)],
                'activation': ['relu', 'tanh'],
                'alpha': [0.0001, 0.001],
                'learning_rate': ['constant', 'adaptive'],
            },
        ),
        (
            'RF',
            RandomForestRegressor(random_state=42, n_jobs=-1),
            {
                'n_estimators': [100, 200],
                'max_depth': [None, 15, 20],
                'min_samples_split': [2, 5],
                'min_samples_leaf': [1, 2],
                'max_features': ['sqrt', 'log2'],
            },
        ),
        (
            'XGB',
            XGBRegressor(objective='reg:squarederror', random_state=42, n_jobs=-1),
            {
                'n_estimators': [100, 200],
                'max_depth': [3, 5, 7],
                'learning_rate': [0.01, 0.1, 0.2],
                'subsample': [0.8, 1.0],
                'colsample_bytree': [0.8, 1.0],
            },
        ),
    ]

    return {
        key: {'model': estimator, 'params': grid, 'cv': cv_splitter}
        for key, estimator, grid in candidates
    }

# ================== Step 6: Training and Evaluation ==================

def train_and_compare(X_train, y_train, X_test, y_test, dates_test,
                     scenario_name, flood_indices):
    """
    Train (or load) every candidate model for one scenario and evaluate it.

    For each model returned by get_models_and_params():
    - if ``models_v2/<scenario_name>/<name>_best_model.pkl`` exists it is
      loaded and training is skipped; otherwise a grid search is run on the
      training data and the best estimator is saved to that path right away,
      so a later failure cannot discard the search result;
    - test predictions are written to
      ``predictions_v2/<scenario_name>_<name>_predictions.csv``;
    - metrics are computed with evaluate_model().

    All results are also written to ``results_v2/<scenario_name>_results.json``.

    NOTE: a cached model is reused as-is. If the feature set or the training
    data changed since it was saved, delete the file to force retraining.

    Returns:
        dict mapping model name to {'best_params': ..., 'metrics': ...};
        'best_params' is the string 'loaded_from_file' for cached models.
    """
    models_info = get_models_and_params()
    results = {}

    for name, info in models_info.items():
        model_path = f'models_v2/{scenario_name}/{name}_best_model.pkl'

        if os.path.exists(model_path):
            print(f"\n{name} model already exists, skipping training and loading directly...")
            # NOTE: joblib.load unpickles the file; only load model files
            # produced by this script.
            best_estimator = joblib.load(model_path)
            best_params = 'loaded_from_file'
            status = ''
        else:
            print(f"\nTraining {name} model ({scenario_name})...")

            # Grid search for hyperparameter tuning
            grid_search = GridSearchCV(
                estimator=info['model'],
                param_grid=info['params'],
                cv=info['cv'],
                scoring='neg_mean_squared_error',
                n_jobs=-1,
                verbose=0
            )
            grid_search.fit(X_train, y_train)

            best_estimator = grid_search.best_estimator_
            best_params = grid_search.best_params_

            # Persist immediately so the expensive search survives any
            # failure in the prediction / evaluation steps below.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            joblib.dump(best_estimator, model_path)
            status = ' completed'

        # Prediction and evaluation are identical for loaded and new models
        y_pred = best_estimator.predict(X_test)
        metrics = evaluate_model(y_test, y_pred, flood_indices)

        print(f"{name}{status} - NSE: {metrics['NSE']:.4f}, RMSE: {metrics['RMSE']:.2f}")

        # Save prediction results
        predictions_df = pd.DataFrame({
            'date': dates_test.values,
            'observed': y_test.values,
            'predicted': y_pred
        })
        predictions_df.to_csv(f'predictions_v2/{scenario_name}_{name}_predictions.csv', index=False)

        results[name] = {
            'best_params': best_params,
            'metrics': metrics
        }

    # Save scenario results
    with open(f'results_v2/{scenario_name}_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)

    return results

# ================== Step 7: Results Comparison Tables ==================

def print_comparison_tables(results_baseline, results_typhoon, results_enhanced):
    """
    Print four tables comparing the three scenarios and save them as CSV.

    Tables: overall NSE, overall RMSE, flood-period NSE and flood-period
    RMSE. Each row shows the three scenario values of one model and the
    percentage change of both typhoon scenarios relative to the baseline
    (increase for NSE, reduction for RMSE). Missing flood metrics are shown
    as 0. Finally the CSV export is delegated to save_tables_to_csv().
    """
    models = ['LR', 'ANN', 'RF', 'XGB']
    rule = "-"*100

    def _value(results, model, metric, required):
        scores = results[model]['metrics']
        # Overall metrics must exist; flood metrics fall back to 0
        return scores[metric] if required else scores.get(metric, 0)

    def _print_table(title, metric, required, higher_is_better, decimals, change_tag):
        original_head = 'Original ' + change_tag
        enhanced_head = 'Enhanced ' + change_tag

        print(title)
        print(rule)
        print(f"{'Model':<10} {'Baseline':>15} {'Original Typhoon':>18} {'Enhanced Typhoon':>18} {original_head:>12} {enhanced_head:>12}")
        print(rule)
        for model in models:
            baseline = _value(results_baseline, model, metric, required)
            typhoon = _value(results_typhoon, model, metric, required)
            enhanced = _value(results_enhanced, model, metric, required)

            # Small offset guards against a zero baseline
            scale = abs(baseline + 1e-10)
            if higher_is_better:
                change_typhoon = ((typhoon - baseline) / scale) * 100
                change_enhanced = ((enhanced - baseline) / scale) * 100
            else:
                change_typhoon = ((baseline - typhoon) / scale) * 100
                change_enhanced = ((baseline - enhanced) / scale) * 100

            print(f"{model:<10} {baseline:>15.{decimals}f} {typhoon:>18.{decimals}f} {enhanced:>18.{decimals}f} {change_typhoon:>11.2f}% {change_enhanced:>11.2f}%")
        print(rule)

    print("\n" + "="*100)
    print("Experimental Results Comparison")
    print("="*100)

    # Table 1: Overall NSE comparison
    _print_table("\n【Table 1】Overall NSE Metric Comparison", 'NSE',
                 required=True, higher_is_better=True, decimals=4, change_tag='Imp%')

    # Table 2: Overall RMSE comparison
    _print_table("\n【Table 2】Overall RMSE Metric Comparison", 'RMSE',
                 required=True, higher_is_better=False, decimals=2, change_tag='Red%')

    # Table 3: Flood period NSE comparison
    _print_table("\n【Table 3】Flood Period NSE Metric Comparison", 'NSE_flood',
                 required=False, higher_is_better=True, decimals=4, change_tag='Imp%')

    # Table 4: Flood period RMSE comparison
    _print_table("\n【Table 4】Flood Period RMSE Metric Comparison", 'RMSE_flood',
                 required=False, higher_is_better=False, decimals=2, change_tag='Red%')

    # Save tables to CSV
    save_tables_to_csv(results_baseline, results_typhoon, results_enhanced, models)

def save_tables_to_csv(results_baseline, results_typhoon, results_enhanced, models):
    """
    Write the overall and flood-period metric tables to CSV files.

    Each table has one row per model and, for every scenario, the columns
    NSE, RMSE, MAE and KGE (with a '_flood' suffix in the flood table).
    Overall metrics are required; missing flood metrics are written as NaN.
    Output files: results_v2/comparison_overall.csv and
    results_v2/comparison_flood.csv.
    """
    # (column prefix, results) in output column order
    scenarios = (
        ('Baseline', results_baseline),
        ('Original_Typhoon', results_typhoon),
        ('Enhanced_Typhoon', results_enhanced),
    )
    metric_names = ('NSE', 'RMSE', 'MAE', 'KGE')

    # Overall metrics table
    overall_rows = []
    for model in models:
        record = {'Model': model}
        for prefix, results in scenarios:
            scores = results[model]['metrics']
            for metric in metric_names:
                record[f'{prefix}_{metric}'] = scores[metric]
        overall_rows.append(record)

    pd.DataFrame(overall_rows).to_csv(
        'results_v2/comparison_overall.csv', index=False, encoding='utf-8-sig'
    )

    # Flood period metrics table
    flood_rows = []
    for model in models:
        record = {'Model': model}
        for prefix, results in scenarios:
            scores = results[model]['metrics']
            for metric in metric_names:
                record[f'{prefix}_{metric}_flood'] = scores.get(f'{metric}_flood', np.nan)
        flood_rows.append(record)

    pd.DataFrame(flood_rows).to_csv(
        'results_v2/comparison_flood.csv', index=False, encoding='utf-8-sig'
    )

    print("\nComparison tables saved to:")
    print("  - results_v2/comparison_overall.csv (Overall metrics)")
    print("  - results_v2/comparison_flood.csv (Flood period metrics)")

# ================== Main Program ==================

def main():
    """
    Run the full experiment: prepare data, train all scenarios, report results.

    The three scenarios (baseline, original typhoon features, enhanced
    typhoon features) are evaluated against the same test labels, dates and
    flood days, all taken from the baseline test set.
    """
    banner = "="*100
    target = 'boluo_discharge'
    non_feature_cols = [target, 'date']

    print(banner)
    print("Experiment: Impact of Typhoon Features on Runoff and Flood Prediction Accuracy")
    print("Compare three scenarios: Baseline, Original Typhoon Features, Enhanced Typhoon Features")
    print(banner)

    # 1. Prepare data
    print("\nStep 1: Preparing data...")
    data_file = '../data/typhoon_daily_boluo.csv'
    combined_baseline, combined_typhoon, combined_enhanced = prepare_data(data_file, lag_days=7)

    # 2. Split into training and testing sets
    print("\nStep 2: Splitting training and testing sets...")
    splits = {
        'baseline': split_data(combined_baseline),
        'with_typhoon': split_data(combined_typhoon),
        'with_typhoon_enhanced': split_data(combined_enhanced),
    }
    train_baseline, test_baseline = splits['baseline']

    print(f"Training set size: {len(train_baseline)} samples")
    print(f"Testing set size: {len(test_baseline)} samples")
    # Every frame carries the label and the date besides the features
    print(f"Baseline model features: {len(train_baseline.columns) - 2}")
    print(f"Original typhoon features model features: {len(splits['with_typhoon'][0].columns) - 2}")
    print(f"Enhanced typhoon features model features: {len(splits['with_typhoon_enhanced'][0].columns) - 2}")

    # 3. Identify flood events
    print("\nStep 3: Identifying flood events...")
    y_test = test_baseline[target]
    dates_test = test_baseline['date']
    flood_indices, _ = identify_flood_events(y_test, dates_test, threshold_percentile=90)

    # 4-6. Train and evaluate each scenario in turn
    steps = (
        ('baseline', "Step 4: Training baseline model (without typhoon features)"),
        ('with_typhoon', "Step 5: Training model with original typhoon features"),
        ('with_typhoon_enhanced', "Step 6: Training model with enhanced typhoon features"),
    )
    results = {}
    for scenario, headline in steps:
        print("\n" + banner)
        print(headline)
        print(banner)

        train_part, test_part = splits[scenario]
        X_train = train_part.drop(non_feature_cols, axis=1)
        y_train = train_part[target]
        X_test = test_part.drop(non_feature_cols, axis=1)

        results[scenario] = train_and_compare(
            X_train, y_train,
            X_test, y_test, dates_test,
            scenario, flood_indices
        )

    # 7. Print comparison tables
    print_comparison_tables(
        results['baseline'], results['with_typhoon'], results['with_typhoon_enhanced']
    )

    print("\n" + banner)
    print("Experiment completed!")
    print(banner)
    print("\nAll results saved to the following directories:")
    print("  - models_v2/baseline/: Baseline model files")
    print("  - models_v2/with_typhoon/: Original typhoon features model files")
    print("  - models_v2/with_typhoon_enhanced/: Enhanced typhoon features model files")
    print("  - predictions_v2/: Prediction results CSV files for all models")
    print("  - results_v2/: Evaluation metrics JSON files and comparison tables CSV files")

if __name__ == "__main__":
    main()