In [None]:
# !jupyter lab build

In [None]:
import logging

# Silence logging for the whole notebook: logging.disable(level) suppresses every
# logging call of severity `level` and below, and CRITICAL is the highest standard level.
logging.disable(logging.CRITICAL)
# To see INFO-level logs instead, comment out the line above and enable this config:
# logging.basicConfig(
#     level=logging.INFO,
#     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
#     datefmt="%Y-%m-%d %H:%M:%S",
# )

from hydro_forecasting.experiment_utils.seed_manager import SeedManager  # noqa: E402

# Create the seed manager with seed 42 (the datamodule config further down also uses
# 42 as its "random_seed") and apply it.
# NOTE(review): presumably set_global_seeds() seeds the RNGs of every library in use —
# confirm against SeedManager.
seed_manager = SeedManager(42)
seed_manager.set_global_seeds()

In [None]:
from hydro_forecasting.data.caravanify_parquet import (
    CaravanifyParquet,
    CaravanifyParquetConfig,
)
from hydro_forecasting.experiment_utils.finetune_pretrained_model import finetune_pretrained_models
from hydro_forecasting.preprocessing.pipeline_builder import PipelineBuilder

---

## Experiment constants

In [None]:
# Region codes whose data directories are used by the datamodule.
REGIONS = ["CA"]

# Target country to fine-tune on. Must be lowercase: it is interpolated verbatim into
# directory names (data cache, YAML files, pretrained checkpoints).
COUNTRY = "kyrgyzstan"

# Fail fast on typos or wrong casing. Without this check any value other than
# "tajikistan" silently falls through to ON_COUNTRY = "tajikistan", so e.g.
# COUNTRY = "Tajikistan" would configure a Tajikistan -> Tajikistan "transfer".
if COUNTRY not in ("kyrgyzstan", "tajikistan"):
    raise ValueError(f"Unsupported COUNTRY {COUNTRY!r}; expected 'kyrgyzstan' or 'tajikistan'")

# The other country of the pair, i.e. the one the pretrained models come from.
ON_COUNTRY = "kyrgyzstan" if COUNTRY == "tajikistan" else "tajikistan"

## Loading the data (as gauge_ids)

In [None]:
def load_basin_ids(
    country: str,
    *,
    attributes_dir: str = "/workspace/CaravanifyParquet/CA/post_processed/attributes",
    timeseries_dir: str = "/workspace/CaravanifyParquet/CA/post_processed/timeseries/csv",
    gauge_id_prefix: str = "CA",
) -> list[str]:
    """
    Return the gauge IDs of all Central Asia basins located in ``country``.

    Args:
        country: Country name, matched case-insensitively and ignoring surrounding
            whitespace. Only "Tajikistan" and "Kyrgyzstan" are supported.
        attributes_dir: Directory holding the static attribute files.
        timeseries_dir: Directory holding the time series files.
        gauge_id_prefix: Gauge ID prefix passed to CaravanifyParquetConfig.

    Returns:
        The unique ``gauge_id`` values whose ``country`` attribute equals the
        normalised country name. An empty list is returned (after printing a
        message) when the country is not supported, so callers should check for
        an empty result.
    """
    supported_countries = ("Tajikistan", "Kyrgyzstan")

    # str.capitalize() upper-cases the first character and lower-cases the rest,
    # so no separate lower() call is needed to reach the "Xxxx" form used in the
    # attribute table.
    country = country.strip().capitalize()

    if country not in supported_countries:
        print("Country not supported")
        return []

    configs = CaravanifyParquetConfig(
        attributes_dir=attributes_dir,
        timeseries_dir=timeseries_dir,
        gauge_id_prefix=gauge_id_prefix,
        use_hydroatlas_attributes=True,
        use_caravan_attributes=True,
        use_other_attributes=True,
    )

    # Load every station first; the country filter is applied afterwards on the
    # static attribute table.
    caravan = CaravanifyParquet(configs)
    ca_basins = caravan.get_all_gauge_ids()
    caravan.load_stations(ca_basins)
    static_data = caravan.get_static_attributes()

    return list(static_data[static_data["country"] == country]["gauge_id"].unique())


basin_ids = load_basin_ids(COUNTRY)
print(f"Basins for {COUNTRY}: {len(basin_ids)}")

# load_basin_ids returns an empty list for an unsupported country (and would also
# do so if no attribute row matches). Stop here rather than launching fine-tuning
# runs on zero basins.
if not basin_ids:
    raise ValueError(f"No basins found for {COUNTRY!r}; check COUNTRY and the CaravanifyParquet data directories")

## Datamodule Configs

In [None]:
# Per-region directory with the time series files, keyed by region code.
region_time_series_base_dirs = dict(
    (region_code, f"/workspace/CaravanifyParquet/{region_code}/post_processed/timeseries/csv/{region_code}")
    for region_code in REGIONS
)

# Per-region directory with the static attribute files, keyed by region code.
region_static_attributes_base_dirs = dict(
    (region_code, f"/workspace/CaravanifyParquet/{region_code}/post_processed/attributes/{region_code}")
    for region_code in REGIONS
)

# Output directory for the preprocessing step, one per target country.
path_to_preprocessing_output_directory = (
    f"/workspace/hydro-forecasting/experiments/finetune/data_cache/{COUNTRY}"
)

In [None]:
# Dynamic (time-varying) input columns, grouped by theme. Order is preserved.
forcing_features = (
    # Snow and radiation
    [
        "snow_depth_water_equivalent_mean",
        "surface_net_solar_radiation_mean",
        "surface_net_thermal_radiation_mean",
    ]
    # Potential evaporation (two estimates)
    + [
        "potential_evaporation_sum_ERA5_LAND",
        "potential_evaporation_sum_FAO_PENMAN_MONTEITH",
    ]
    # 2 m air temperature
    + [
        "temperature_2m_mean",
        "temperature_2m_min",
        "temperature_2m_max",
    ]
    # Precipitation
    + ["total_precipitation_sum"]
)

# Static (per-basin) attribute columns. Order is preserved.
static_features = (
    ["p_mean", "area", "ele_mt_sav", "high_prec_dur", "frac_snow"]
    + ["high_prec_freq", "slp_dg_sav", "cly_pc_sav"]
    + ["aridity_ERA5_LAND", "aridity_FAO_PM"]
)

# Column the models are trained to predict.
target = "streamflow"

In [None]:
# Assemble the preprocessing configuration through PipelineBuilder's fluent API.
builder = PipelineBuilder()

# Forcing features: "standard_scale" with the "unified" strategy and fit_on_n_basins=200.
# NOTE(review): presumably a single scaler is fitted on a sample of 200 basins and shared
# by all basins — confirm against PipelineBuilder.
feature_section = (
    builder.features().transforms(["standard_scale"]).strategy("unified", fit_on_n_basins=200).columns(forcing_features)
)

# Target column: "log_scale" and "standard_scale" (listed in that order) with the
# "per_group" strategy grouped by gauge_id.
# NOTE(review): presumably each gauge gets its own fitted transforms — confirm.
target_section = (
    builder.target()
    .transforms(["log_scale", "standard_scale"])
    .strategy("per_group", group_by="gauge_id")
    .columns([target])
)

# Static attributes: "standard_scale" with the "unified" strategy.
static_section = builder.static_features().transforms(["standard_scale"]).strategy("unified").columns(static_features)

# NOTE(review): the three *_section variables are not referenced again in this part of
# the notebook; presumably each chain registers itself on `builder`, so build() picks
# them up — confirm.
preprocessing_config = builder.build()

# Cell output: the keys of the "static_features" entry, as a quick look at the built config.
preprocessing_config["static_features"].keys()

In [None]:
# Settings handed to the datamodule. Built with dict(...) keyword syntax; keys and
# their insertion order are the same as in a literal.
datamodule_config = dict(
    # Data locations
    region_time_series_base_dirs=region_time_series_base_dirs,
    region_static_attributes_base_dirs=region_static_attributes_base_dirs,
    path_to_preprocessing_output_directory=path_to_preprocessing_output_directory,
    # Columns and batching
    group_identifier="gauge_id",
    batch_size=2048,
    forcing_features=forcing_features,
    static_features=static_features,
    target=target,
    num_workers=4,
    # Splitting: the three proportions sum to 1.0
    min_train_years=5,
    train_prop=0.5,
    val_prop=0.25,
    test_prop=0.25,
    # Gap handling and chunking
    max_imputation_gap_size=5,
    chunk_size=100,
    validation_chunk_size=100,
    is_autoregressive=True,
    # Preprocessing pipeline and reproducibility
    preprocessing_configs=preprocessing_config,
    random_seed=42,
)

## Training Configs

In [None]:
# Trainer settings for every fine-tuning run.
training_config = {
    "max_epochs": 200,
    "accelerator": "cuda",
    "devices": 1,
    "early_stopping_patience": 15,
    # This option is an epoch interval (an integer), not a flag: 0 means "never reload".
    # It was previously the bool False, which only worked because False == 0.
    # NOTE(review): assumes this dict is forwarded to a Lightning-style Trainer — confirm.
    "reload_dataloaders_every_n_epochs": 0,
}

## Remaining Configs

In [None]:
# Root directory for the fine-tuning outputs.
output_dir = "/workspace/hydro-forecasting/experiments/finetune"
# Checkpoints of the models that were pretrained for this country pair.
pretrained_checkpoint_dir = (
    f"/workspace/hydro-forecasting/experiments/benchmark/pretrain_{COUNTRY}_on_{ON_COUNTRY}/checkpoints"
)


# Architectures to fine-tune.
model_types = [
    "tide",
    "ealstm",
    "tsmixer",
    "tft",
]
# One YAML file per entry in model_types, in the same order. Deriving the paths
# from model_types keeps the two lists aligned when a model is added or removed.
yaml_paths = [
    f"/workspace/hydro-forecasting/experiments/yaml-files/unified/{COUNTRY}/{model_type}.yaml"
    for model_type in model_types
]
experiment_name = f"regional_transfer_to_{COUNTRY}_from_{ON_COUNTRY}"
# Number of runs per model type.
num_runs = 5
override_previous_attempts = False

## Fine-tuning the pretrained models

In [None]:
# Launch fine-tuning for every model type on the selected basins.
train_results = finetune_pretrained_models(
    # What to fine-tune and where the pretrained artefacts live
    gauge_ids=basin_ids,
    model_types=model_types,
    pretrained_yaml_paths=yaml_paths,
    pretrained_checkpoint_dir=pretrained_checkpoint_dir,
    # Data and trainer settings
    datamodule_config=datamodule_config,
    training_config=training_config,
    # Experiment bookkeeping
    experiment_name=experiment_name,
    output_dir=output_dir,
    num_runs=num_runs,
    override_previous_attempts=override_previous_attempts,
    # Fine-tuning behaviour
    # NOTE(review): presumably the pretrained learning rate is divided by 25 — confirm.
    lr_reduction_factor=25,
    select_best_from_pretrained=True,
)