In [None]:
import sys
from pathlib import Path

# Append the parent of the current working directory to the import search path.
# Presumably this is what lets the project-local packages imported in later
# cells (e.g. `src`) resolve when the notebook runs from a subfolder — confirm
# against the repository layout.
sys.path.append(str(Path().absolute().parent))
import time

In [None]:
import optuna
from pytorch_lightning.callbacks import EarlyStopping
import pytorch_lightning as pl

In [None]:
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)
import torch
from torch.nn import MSELoss
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import glob
from pathlib import Path

from src.data_models.camels_ch import CamelsCH, CamelsCHConfig, get_all_gauge_ids
from src.data_models.dataset import HydroDataset
from src.data_models.preprocessing import (
    scale_time_series,
    scale_static_attributes,
    inverse_scale_static_attributes,
    inverse_scale_time_series,
)
from src.data_models.caravanify import Caravanify, CaravanifyConfig

from utils.metrics import nash_sutcliffe_efficiency
from src.data_models.datamodule import HydroDataModule

from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler

from src.preprocessing.transformers import GroupedTransformer, LogTransformer

In [None]:
from src.models.lstm import LitLSTM
from src.models.ealstm import LitEALSTM
from src.models.TSMixer import LitTSMixer
from src.models.evaluators import TSForecastEvaluator
from torch.optim import Adam
from torch.nn import MSELoss

---

In [None]:
# Root folder of the post-processed Caravan export for the "CA" gauges.
# Edit this single constant to point the notebook at another location.
CARAVAN_ROOT = Path("/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CA/post_processed")
# How many gauges to load; kept small so the tuning run below stays quick.
N_STATIONS = 3

config = CaravanifyConfig(
    attributes_dir=str(CARAVAN_ROOT / "attributes"),
    timeseries_dir=str(CARAVAN_ROOT / "timeseries" / "csv"),
    gauge_id_prefix="CA",
    use_hydroatlas_attributes=True,
    use_caravan_attributes=True,
    use_other_attributes=True,
)


caravan = Caravanify(config)
all_gauge_ids = caravan.get_all_gauge_ids()
ids_for_training = all_gauge_ids[:N_STATIONS]

# Report the size of the subset that is actually loaded next to the number of
# gauges available, so the message cannot be mistaken for the dataset size.
print(f"Using {len(ids_for_training)} of {len(all_gauge_ids)} available stations")

caravan.load_stations(ids_for_training)

# Time series and static attributes of the loaded stations.
ts_data = caravan.get_time_series()
static_data = caravan.get_static_attributes()

In [None]:
# Parse the date column once and derive the day of the year from the parsed values.
parsed_dates = pd.to_datetime(ts_data["date"])
ts_data["date"] = parsed_dates
ts_data["julian_day"] = parsed_dates.dt.dayofyear

# Time-series columns used for modelling: streamflow plus precipitation.
# Candidates that are currently switched off:
#   "potential_evaporation_sum_ERA5_LAND"
#   "potential_evaporation_sum_FAO_PENMAN_MONTEITH"
#   "julian_day"
#   "temperature_2m_mean"
ts_columns = ["streamflow", "total_precipitation_sum"]

In [None]:
# Keep only the selected time-series columns plus the station id and the date.
whole_data = [*ts_columns, "gauge_id", "date"]
ts_data = ts_data.loc[:, whole_data]

In [None]:
# Static attributes retained for modelling ("gauge_id" stays so rows can be
# matched to their station).
statics_to_keep = [
    "gauge_id",
    "p_mean",
    "area",
    "ele_mt_sav",
    "high_prec_dur",
    "frac_snow",
    "high_prec_freq",
    "slp_dg_sav",
    "cly_pc_sav",
    "aridity_ERA5_LAND",
    "aridity_FAO_PM",
]

# Select the wanted columns in the order they appear in the frame; names in
# statics_to_keep that the frame does not have are simply skipped.
is_wanted = static_data.columns.isin(statics_to_keep)
static_columns = static_data.columns[is_wanted].tolist()

static_data = static_data.loc[:, static_columns]

In [None]:
# Input features are all time-series columns except the identifiers and the target.
features = [
    name
    for name in ts_data.columns
    if name not in ("gauge_id", "date", "streamflow")
]
# The target is appended after the features, so ts_columns DOES contain "streamflow".
ts_columns = [*features, "streamflow"]

# Features: log transform followed by standardisation.
feature_pipeline = Pipeline(
    steps=[
        ("log", LogTransformer()),
        ("scaler", StandardScaler()),
    ]
)

# Target: the same two steps wrapped in a GroupedTransformer keyed on "gauge_id".
# NOTE(review): presumably this fits the wrapped pipeline separately per gauge —
# confirm in src.preprocessing.transformers.
target_pipeline = GroupedTransformer(
    Pipeline(
        steps=[
            ("log", LogTransformer()),
            ("scaler", StandardScaler()),
        ]
    ),
    columns=["streamflow"],
    group_identifier="gauge_id",
)

# Static attributes: standardisation only.
static_pipeline = Pipeline(steps=[("scaler", StandardScaler())])

# Bundle of pipelines that is handed to the data module as preprocessing_config.
preprocessing_configs = dict(
    features=dict(pipeline=feature_pipeline),
    target=dict(pipeline=target_pipeline),
    static_features=dict(pipeline=static_pipeline),
)

## Tuning Hyperparameters

In [None]:
# Output length passed to both the data module and the model in `objective`.
output_length = 10
# Drop the station identifier from the static feature names; it is not a model input.
static_columns = [name for name in static_columns if name != "gauge_id"]


def objective(trial):
    """Optuna objective: train one TSMixer configuration and return its score.

    Samples ``batch_size``, ``input_length`` and ``hidden_size`` from the trial,
    builds a ``HydroDataModule`` and a ``LitTSMixer`` from them, fits the model
    and returns the validation loss to be minimised.

    Reads the notebook-level names ``ts_data``, ``static_data``,
    ``preprocessing_configs``, ``ts_columns``, ``static_columns`` and
    ``output_length``.

    Args:
        trial: The ``optuna`` trial used to sample the hyperparameters.

    Returns:
        float: The best (lowest) ``val_loss`` tracked by the early-stopping
        callback during the fit. If the callback never evaluated the metric,
        the last logged ``val_loss`` is returned instead.
    """
    # Hyperparameters explored by the study.
    batch_size = trial.suggest_int("batch_size", 16, 128)
    input_length = trial.suggest_int("input_length", 14, 60)
    hidden_size = trial.suggest_int("hidden_size", 32, 256)

    # Data module built with the sampled batch size and input length.
    data_module = HydroDataModule(
        time_series_df=ts_data,
        static_df=static_data,
        group_identifier="gauge_id",
        preprocessing_config=preprocessing_configs,
        batch_size=batch_size,
        input_length=input_length,
        output_length=output_length,
        num_workers=4,
        features=ts_columns,
        static_features=static_columns,
        target="streamflow",
        min_train_years=2,
        val_years=1,
        test_years=3,
        max_missing_pct=10,
    )

    # Model dimensions must agree with what the data module produces.
    model = LitTSMixer(
        input_len=input_length,
        output_len=output_length,
        input_size=len(ts_columns),
        static_size=len(static_columns),
        hidden_size=hidden_size,
    )

    # Keep a handle on the callback so its best score can be read after the fit.
    early_stopping = EarlyStopping(monitor="val_loss", patience=3, mode="min")

    trainer = pl.Trainer(
        max_epochs=1,  # Keep this low for initial testing
        accelerator="cpu",
        devices=1,
        callbacks=[early_stopping],
        enable_progress_bar=False,  # Reduce output clutter during optimization
    )

    trainer.fit(model, data_module)

    # With patience > 0 the last logged val_loss can be worse than the best one
    # (training continues for up to `patience` epochs after the best epoch), so
    # report the best value tracked by the callback rather than the final one.
    best_val_loss = early_stopping.best_score
    if not torch.isfinite(best_val_loss):
        # The callback never recorded an improvement; use the last logged value.
        best_val_loss = trainer.callback_metrics["val_loss"]

    return best_val_loss.item()


# Minimise the value returned by `objective`.
study = optuna.create_study(direction="minimize")

# Small trial budget while the setup is still being tested.
study.optimize(objective, n_trials=10)

# Best configuration found and its score.
print("Best parameters:", study.best_params)
print("Best validation loss:", study.best_value)

# Tally trials by their final state for the summary below.
n_pruned = sum(
    t.state == optuna.trial.TrialState.PRUNED for t in study.trials
)
n_complete = sum(
    t.state == optuna.trial.TrialState.COMPLETE for t in study.trials
)

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", n_pruned)
print("  Number of complete trials: ", n_complete)