In [1]:
import sys
from pathlib import Path

# Make the project root (the parent of the notebook's working directory)
# importable so the `src.*` imports below resolve. The membership check keeps
# the cell idempotent: re-running it does not append duplicate sys.path entries.
PROJECT_ROOT = str(Path.cwd().parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)


In [None]:
from pytorch_lightning.callbacks import (

    EarlyStopping,
)
import matplotlib.pyplot as plt


import seaborn as sns

import pandas as pd
from pathlib import Path


from src.data_models.caravanify import Caravanify, CaravanifyConfig

from src.data_models.datamodule import HydroTransferDataModule, HydroDataModule

from sklearn.pipeline import Pipeline

from src.preprocessing.grouped import GroupedTransformer
from src.preprocessing.standard_scale import StandardScaleTransformer

In [None]:
from src.models.TSMixerDomainAdaptation import LitTSMixerDomainAdaptation
from src.models.evaluators import TSForecastEvaluator

---

## Central Asia (target domain)

In [None]:
# Configuration for loading Central Asian (CA) hydrology data
CA_config = CaravanifyConfig(
    attributes_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CA/post_processed/attributes",
    timeseries_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CA/post_processed/timeseries/csv",
    gauge_id_prefix="CA",
    use_hydroatlas_attributes=True,
    use_caravan_attributes=True,
    use_other_attributes=True,
)

# Initialize the Caravan loader and load a small subset of stations
# (positions 14-16 of the gauge-id list) for a quick experiment.
CA_caravan = Caravanify(CA_config)
ids_for_training = CA_caravan.get_all_gauge_ids()[14:17]
print(f"Number of stations selected for training: {len(ids_for_training)}")
CA_caravan.load_stations(ids_for_training)

# Get time series and static data
CA_ts_data = CA_caravan.get_time_series()
CA_static_data = CA_caravan.get_static_attributes()

# Parse dates into pandas datetimes. Only the columns selected below are kept,
# so any derived calendar feature (e.g. day of year) must also be added to
# `ts_columns` to actually reach the model.
CA_ts_data["date"] = pd.to_datetime(CA_ts_data["date"])

# Keep the target and the dynamic forcing, plus the identifier columns
ts_columns = ["streamflow", "total_precipitation_sum"]
CA_ts_data = CA_ts_data[ts_columns + ["gauge_id", "date"]]

# Static catchment attributes; "gauge_id" comes first and is excluded from
# scaling below (static_columns[1:]).
static_columns = [
    "gauge_id",
    "p_mean",
    "area",
    "ele_mt_sav",
    "high_prec_dur",
    "frac_snow",
    "high_prec_freq",
    "slp_dg_sav",
    "cly_pc_sav",
    "aridity_ERA5_LAND",
    "aridity_FAO_PM",
]
CA_static_data = CA_static_data[static_columns]

# Dynamic input features: every time-series column except identifiers and target
features = [
    col for col in CA_ts_data.columns if col not in ["gauge_id", "date", "streamflow"]
]
# Column list handed to the data module: inputs first, target last
ts_columns = features + ["streamflow"]

# Preprocessing: standard-scale the dynamic inputs, the target (wrapped in a
# GroupedTransformer keyed on gauge_id, presumably one scaler per gauge), and
# the static attributes.
feature_pipeline = Pipeline([("scaler", StandardScaleTransformer(columns=features))])

target_pipeline = GroupedTransformer(
    Pipeline([("scaler", StandardScaleTransformer(columns=["streamflow"]))]),
    columns=["streamflow"],
    group_identifier="gauge_id",
)

static_pipeline = Pipeline(
    [("scaler", StandardScaleTransformer(columns=static_columns[1:]))]
)

# Define preprocessing configurations
preprocessing_configs = {
    "features": {"pipeline": feature_pipeline},
    "target": {"pipeline": target_pipeline},
    "static_features": {"pipeline": static_pipeline},
}

In [None]:
# Hyperparameters shared by both domains' data modules and both models below
batch_size = 128
output_length = 10  # number of forecast steps (evaluated as horizons 1..output_length)
input_length = 64  # number of input time steps
hidden_size = 32

# Data module for the Central Asian stations, tagged as the target domain
CA_data_module = HydroDataModule(
    time_series_df=CA_ts_data,
    static_df=CA_static_data,
    group_identifier="gauge_id",
    preprocessing_config=preprocessing_configs,
    batch_size=batch_size,
    input_length=input_length,
    output_length=output_length,
    num_workers=4,
    features=ts_columns,
    static_features=static_columns,
    target="streamflow",
    min_train_years=2,
    val_years=1,
    test_years=3,
    max_missing_pct=10,
    domain_id="target",
)

CA_data_module.prepare_data()
CA_data_module.setup()

## Switzerland (source domain)

In [None]:
# Configuration for loading Swiss (CH) hydrology data
CH_config = CaravanifyConfig(
    attributes_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CH/post_processed/attributes",
    timeseries_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CH/post_processed/timeseries/csv",
    gauge_id_prefix="CH",
    use_hydroatlas_attributes=True,
    use_caravan_attributes=True,
    use_other_attributes=True,
)

# Initialize the Caravan loader and load a small subset of stations
# (positions 14-16 of the gauge-id list) for a quick experiment.
CH_caravan = Caravanify(CH_config)
ids_for_training = CH_caravan.get_all_gauge_ids()[14:17]
print(f"Number of stations selected for training: {len(ids_for_training)}")
CH_caravan.load_stations(ids_for_training)

# Get time series and static data
CH_ts_data = CH_caravan.get_time_series()
CH_static_data = CH_caravan.get_static_attributes()

# Parse dates into pandas datetimes. Only the columns selected below are kept,
# so any derived calendar feature (e.g. day of year) must also be added to
# `ts_columns` to actually reach the model.
CH_ts_data["date"] = pd.to_datetime(CH_ts_data["date"])

# Keep the target and the dynamic forcing, plus the identifier columns
ts_columns = ["streamflow", "total_precipitation_sum"]
CH_ts_data = CH_ts_data[ts_columns + ["gauge_id", "date"]]

# Static catchment attributes; "gauge_id" comes first and is excluded from
# scaling below (static_columns[1:]).
static_columns = [
    "gauge_id",
    "p_mean",
    "area",
    "ele_mt_sav",
    "high_prec_dur",
    "frac_snow",
    "high_prec_freq",
    "slp_dg_sav",
    "cly_pc_sav",
    "aridity_ERA5_LAND",
    "aridity_FAO_PM",
]
CH_static_data = CH_static_data[static_columns]

# Dynamic input features: every time-series column except identifiers and target
features = [
    col for col in CH_ts_data.columns if col not in ["gauge_id", "date", "streamflow"]
]
# Column list handed to the data module: inputs first, target last
ts_columns = features + ["streamflow"]

# Preprocessing: standard-scale the dynamic inputs, the target (wrapped in a
# GroupedTransformer keyed on gauge_id, presumably one scaler per gauge), and
# the static attributes. Fresh pipeline objects, so nothing fitted on CA data
# is reused for CH.
feature_pipeline = Pipeline([("scaler", StandardScaleTransformer(columns=features))])

target_pipeline = GroupedTransformer(
    Pipeline([("scaler", StandardScaleTransformer(columns=["streamflow"]))]),
    columns=["streamflow"],
    group_identifier="gauge_id",
)

static_pipeline = Pipeline(
    [("scaler", StandardScaleTransformer(columns=static_columns[1:]))]
)

# Define preprocessing configurations
preprocessing_configs = {
    "features": {"pipeline": feature_pipeline},
    "target": {"pipeline": target_pipeline},
    "static_features": {"pipeline": static_pipeline},
}

In [None]:
# Data module for the Swiss stations, tagged as the source domain. Uses the
# same window/batch hyperparameters as the Central Asian data module.
CH_data_module = HydroDataModule(
    time_series_df=CH_ts_data,
    static_df=CH_static_data,
    group_identifier="gauge_id",
    preprocessing_config=preprocessing_configs,
    batch_size=batch_size,
    input_length=input_length,
    output_length=output_length,
    num_workers=4,
    features=ts_columns,
    static_features=static_columns,
    target="streamflow",
    min_train_years=2,
    val_years=1,
    test_years=3,
    max_missing_pct=10,
    domain_id="source",
)

CH_data_module.prepare_data()
CH_data_module.setup()

## Combined source/target DataModule

In [None]:
# Combine both domains for domain-adaptation training. Switzerland (CH) is the
# source domain and Central Asia (CA) the target domain, matching the
# `domain_id` each data module was built with ("source" / "target") and the
# source/target roles used by the visualization cells.
transfer_dm = HydroTransferDataModule(
    source_datamodule=CH_data_module,
    target_datamodule=CA_data_module,
    num_workers=4,
    mode="min_size",
)

# Set to True to print the structure of one combined training batch.
INSPECT_BATCH = False
if INSPECT_BATCH:
    batch = next(iter(transfer_dm.train_dataloader()))
    print(f"Batch type: {type(batch)}")

    # The batch is a tuple whose first element maps each domain name
    # ("source" / "target") to that domain's batch dict.
    data_dict, _, _ = batch
    for domain in ("source", "target"):
        domain_batch = data_dict[domain]
        label = domain.capitalize()
        print(f"\n{label} batch keys:", domain_batch.keys())
        print(f"{label} X shape:", domain_batch["X"].shape)
        print(f"{label} static shape:", domain_batch["static"].shape)
        print(f"{label} y shape:", domain_batch["y"].shape)

In [None]:
import pytorch_lightning as pl
from src.models.TSMixerDomainAdaptation import TSMixerDomainAdaptationConfig

# Fix all RNG seeds (weight init, data shuffling, workers) so this run is reproducible.
pl.seed_everything(42, workers=True)

# 1. Domain-adaptation config. Input and static sizes are derived from the column
#    lists built above instead of being hard-coded: currently 2 dynamic inputs
#    (precipitation + streamflow) and 10 static attributes ("gauge_id" excluded).
domain_adaptation_config = TSMixerDomainAdaptationConfig(
    input_len=input_length,
    input_size=len(ts_columns),
    output_len=output_length,
    static_size=len(static_columns) - 1,
    hidden_size=hidden_size,
    lambda_adv=1.0,
    domain_loss_weight=0.1,
    group_identifier="gauge_id",
    use_target_labels=True,
)

# 2. Initialize the model with domain adaptation enabled
model = LitTSMixerDomainAdaptation(config=domain_adaptation_config)

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cpu",
    devices=1,
    callbacks=[EarlyStopping(monitor="val_loss", patience=5)],
    enable_progress_bar=True,
)

trainer.fit(model, transfer_dm)

In [None]:
# Visualize the domain-adapted model on source (CH) and target (CA) training data
fig = model.visualize_domain_adaptation(
    source_dataloader=CH_data_module.train_dataloader(),
    target_dataloader=CA_data_module.train_dataloader(),
)

# Create the output directory if needed so savefig does not fail on a fresh checkout
da_figure_path = Path("/Users/cooper/Desktop/CAMELS-CH/images/with_domain_adaptation.png")
da_figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(da_figure_path)

In [None]:
# Evaluate the domain-adapted model on the Central Asian test split
trainer.test(model, CA_data_module)
raw_results = model.test_results

# Metrics for every forecast horizon 1..output_len
evaluator = TSForecastEvaluator(
    CA_data_module, horizons=list(range(1, model.config.output_len + 1))
)
results_df, overall_metrics, basin_metrics = evaluator.evaluate(raw_results)

# Summaries across all basins and per basin
overall_summary = evaluator.summarize_metrics(overall_metrics)
basin_summary = evaluator.summarize_metrics(basin_metrics, per_basin=True)

# Store the raw results on the evaluator (presumably read by
# plot_rolling_forecast — verify against TSForecastEvaluator)
evaluator.test_results = raw_results

# Rolling forecast at horizon 1 for one example basin
fig, ax = evaluator.plot_rolling_forecast(
    horizon=1,
    group_identifier="CA_15081",
    datamodule=CA_data_module,
    y_label="Streamflow (m³/s)",
    debug=False,
    line_style_forecast="-",
)

plt.show()

In [None]:
def plot_metric_summary(
    summary_df: pd.DataFrame, metric: str, per_basin: bool = False, figsize=(10, 6)
):
    """Bar-plot one metric from an evaluator summary against forecast horizon.

    Args:
        summary_df: Output of ``TSForecastEvaluator.summarize_metrics``. With
            ``per_basin=True`` it is assumed to have a two-level row index whose
            first level is the basin and second the horizon (implied by the
            ``unstack(level=0)`` / ``reset_index`` below — confirm against the
            evaluator); otherwise its index is the horizon.
        metric: Column to plot, e.g. ``"NSE"``.
        per_basin: If True, draw one bar per basin within each horizon, with
            basins ordered by their value at the first horizon (highest first).
        figsize: Matplotlib figure size.

    Returns:
        The matplotlib Axes the plot was drawn on.
    """
    fig, ax = plt.subplots(figsize=figsize)

    if per_basin:
        # Rows become horizons, columns become basins
        df_plot = summary_df[metric].unstack(level=0)

        # Order basins by their value at the first horizon, highest first
        sorted_basins = df_plot.iloc[0].sort_values(ascending=False).index
        df_plot = df_plot[sorted_basins]

        sns.barplot(
            data=df_plot.melt(ignore_index=False).reset_index(),
            x="horizon",
            y="value",
            hue="basin_id",
            palette="Blues",
            ax=ax,
        )
        ax.set_title(f"{metric} by Basin and Horizon")
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", title="Basin ID")
    else:
        sns.barplot(x=summary_df.index, y=summary_df[metric], color="steelblue", ax=ax)
        ax.set_title(f"Overall {metric} by Horizon")

        # Annotate each bar with its value; bars sit at categorical positions 0..n-1
        for i, v in enumerate(summary_df[metric]):
            ax.text(i, v, f"{v:.2f}", ha="center", va="bottom")

    ax.set_xlabel("Forecast Horizon")
    ax.set_ylabel(metric)
    fig.tight_layout()
    sns.despine(ax=ax)
    plt.show()
    return ax


# Overall and per-basin NSE for the domain-adapted model
plot_metric_summary(overall_summary, "NSE");
plot_metric_summary(basin_summary, "NSE", per_basin=True, figsize=(12, 6));

## Baseline without domain adaptation (`lambda_adv = 0`, `domain_loss_weight = 0`)

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from src.models.TSMixerDomainAdaptation import TSMixerDomainAdaptationConfig

# Fix all RNG seeds (weight init, data shuffling, workers) so this run is reproducible.
pl.seed_everything(42, workers=True)

# Baseline config: identical to the domain-adaptation run except that the
# adversarial weight and the domain-loss weight are both zero. Sizes are derived
# from the column lists (2 dynamic inputs, 10 static attributes; "gauge_id" excluded).
baseline_config = TSMixerDomainAdaptationConfig(
    input_len=input_length,
    input_size=len(ts_columns),
    output_len=output_length,
    static_size=len(static_columns) - 1,
    hidden_size=hidden_size,
    lambda_adv=0.0,
    domain_loss_weight=0.0,
    group_identifier="gauge_id",
    use_target_labels=True,
)

# Baseline model (no domain adaptation)
model2 = LitTSMixerDomainAdaptation(config=baseline_config)

# Same early-stopping patience as the domain-adaptation run (5), so the two runs
# differ only in the domain-adaptation weights.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cpu",
    devices=1,
    callbacks=[EarlyStopping(monitor="val_loss", patience=5)],
    enable_progress_bar=True,
)

trainer.fit(model2, transfer_dm)

In [None]:
# Visualize the baseline model (model2, no domain adaptation) on source (CH) and
# target (CA) training data. Must use model2: `model` is the domain-adapted one.
fig = model2.visualize_domain_adaptation(
    source_dataloader=CH_data_module.train_dataloader(),
    target_dataloader=CA_data_module.train_dataloader(),
)

# Create the output directory if needed so savefig does not fail on a fresh checkout
baseline_figure_path = Path(
    "/Users/cooper/Desktop/CAMELS-CH/images/without_domain_adaptation.png"
)
baseline_figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(baseline_figure_path)

In [None]:
# Evaluate the baseline model (model2) on the Central Asian test split.
# Must use model2: `model` is the domain-adapted model evaluated earlier.
trainer.test(model2, CA_data_module)
raw_results = model2.test_results

# Metrics for every forecast horizon 1..output_len
evaluator = TSForecastEvaluator(
    CA_data_module, horizons=list(range(1, model2.config.output_len + 1))
)
results_df, overall_metrics, basin_metrics = evaluator.evaluate(raw_results)

# Summaries across all basins and per basin
overall_summary = evaluator.summarize_metrics(overall_metrics)
basin_summary = evaluator.summarize_metrics(basin_metrics, per_basin=True)

In [None]:



# Plot the metrics summarized in the previous cell. `plot_metric_summary` is
# defined earlier in the notebook and deliberately not redefined here, so both
# result sets are plotted by exactly the same code.
plot_metric_summary(overall_summary, "NSE");
plot_metric_summary(basin_summary, "NSE", per_basin=True, figsize=(12, 6));