In [1]:
# !jupyter lab build

In [2]:
import logging

# Uncomment the next line to silence all log output for this session.
# logging.disable(logging.CRITICAL)

# Root logger configuration: INFO level, records rendered as
# "<timestamp> - <logger name> - <level> - <message>" with second resolution.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Deliberately imported after logging is configured (hence the E402 suppression).
# NOTE(review): presumably so that log records emitted by the project use the
# format above - confirm before moving this import to the top of the cell.
from hydro_forecasting.experiment_utils.seed_manager import SeedManager  # noqa: E402

# Master seed for the notebook session. The same literal (42) is also given as
# "random_seed" in datamodule_config further down.
seed_manager = SeedManager(42)
# Per the saved cell output, this call results in Lightning reporting "Seed set to 42".
seed_manager.set_global_seeds()

2025-07-08 13:49:25 - matplotlib.font_manager - INFO - generated new fontManager
  warn(
2025-07-08 13:49:25 - hydro_forecasting.experiment_utils.seed_manager - INFO - SeedManager initialized with master seed: 42
2025-07-08 13:49:25 - lightning_fabric.utilities.seed - INFO - Seed set to 42


In [3]:
from hydro_forecasting.data.caravanify_parquet import (
    CaravanifyParquet,
    CaravanifyParquetConfig,
)
from hydro_forecasting.experiment_utils.finetune_pretrained_model import finetune_pretrained_models
from hydro_forecasting.preprocessing.pipeline_builder import PipelineBuilder

---

## Experiment constants

In [4]:
# Region codes; each one is expanded into a time-series and a static-attribute
# directory in the datamodule configuration cells.
REGIONS = ["CA"]

# Country whose basins are loaded and fine-tuned on. It also selects the
# pretrained checkpoint folder, the YAML folder and the experiment name.
COUNTRY = "kyrgyzstan"

# Counterpart of COUNTRY: "kyrgyzstan" when COUNTRY is "tajikistan", and
# "tajikistan" for every other value of COUNTRY.
if COUNTRY == "tajikistan":
    ON_COUNTRY = "kyrgyzstan"
else:
    ON_COUNTRY = "tajikistan"

## Loading the data (as gauge_ids)

In [5]:
def load_basin_ids(country: str) -> list[str]:
    """
    Return the gauge IDs of the Central Asian (CA) basins located in ``country``.

    The name is matched case-insensitively. Only "Tajikistan" and "Kyrgyzstan"
    are accepted; for any other name a notice is printed and an empty list is
    returned without touching the dataset.
    """
    # Normalise to the capitalised spelling that is compared against the
    # "country" column of the static attribute table.
    normalised = country.lower().capitalize()

    if normalised not in ("Tajikistan", "Kyrgyzstan"):
        print("Country not supported")
        return []

    loader = CaravanifyParquet(
        CaravanifyParquetConfig(
            attributes_dir="/workspace/CaravanifyParquet/CA/post_processed/attributes",
            timeseries_dir="/workspace/CaravanifyParquet/CA/post_processed/timeseries/csv",
            gauge_id_prefix="CA",
            use_hydroatlas_attributes=True,
            use_caravan_attributes=True,
            use_other_attributes=True,
        )
    )

    # Load every station known to the dataset, then read the static attributes.
    all_gauge_ids = loader.get_all_gauge_ids()
    loader.load_stations(all_gauge_ids)
    attributes = loader.get_static_attributes()

    # Keep only the rows of the requested country and de-duplicate their IDs.
    in_country = attributes["country"] == normalised
    rows_for_country = attributes[in_country]
    return list(rows_for_country["gauge_id"].unique())


# Gauge IDs for the configured country; passed as `gauge_ids` to
# finetune_pretrained_models in the final cell.
basin_ids = load_basin_ids(COUNTRY)
# NOTE(review): the saved output reports 62 basins here, while the saved
# training log reports 59 basins inside the datamodule - presumably three are
# dropped during data preparation; confirm.
print(f"Basins for {COUNTRY}: {len(basin_ids)}")

Basins for kyrgyzstan: 62


## Datamodule Configs

In [6]:
# Region code -> directory with that region's time-series files.
region_time_series_base_dirs = {
    region: f"/workspace/CaravanifyParquet/{region}/post_processed/timeseries/csv/{region}" for region in REGIONS
}

# Region code -> directory with that region's static attribute files.
region_static_attributes_base_dirs = {
    region: f"/workspace/CaravanifyParquet/{region}/post_processed/attributes/{region}" for region in REGIONS
}

# Passed to the datamodule as "path_to_preprocessing_output_directory";
# one folder per COUNTRY.
path_to_preprocessing_output_directory = f"/workspace/hydro-forecasting/experiments/finetune/data_cache/{COUNTRY}"

In [7]:
# Columns used as forcing inputs. They are given to the feature section of the
# preprocessing pipeline and to the datamodule as "forcing_features".
forcing_features = [
    "snow_depth_water_equivalent_mean",
    "surface_net_solar_radiation_mean",
    "surface_net_thermal_radiation_mean",
    "potential_evaporation_sum_ERA5_LAND",
    "potential_evaporation_sum_FAO_PENMAN_MONTEITH",
    "temperature_2m_mean",
    "temperature_2m_min",
    "temperature_2m_max",
    "total_precipitation_sum",
]

# Per-basin attribute columns. They are given to the static-feature section of
# the preprocessing pipeline and to the datamodule as "static_features".
static_features = [
    "p_mean",
    "area",
    "ele_mt_sav",
    "high_prec_dur",
    "frac_snow",
    "high_prec_freq",
    "slp_dg_sav",
    "cly_pc_sav",
    "aridity_ERA5_LAND",
    "aridity_FAO_PM",
]

# Name of the column to predict; used as the target column of the pipeline and
# as the datamodule's "target".
target = "streamflow"

In [8]:
builder = PipelineBuilder()

# Forcing features: "standard_scale" then "normalize", fitted per gauge_id.
feature_section = (
    builder.features()
    .transforms(["standard_scale", "normalize"])
    .strategy("per_group", group_by="gauge_id")
    .columns(forcing_features)
)

# Target column: same transforms and per-gauge strategy as the forcing features.
target_section = (
    builder.target()
    .transforms(["standard_scale", "normalize"])
    .strategy("per_group", group_by="gauge_id")
    .columns([target])
)

# Static attributes: "standard_scale" only, with the "unified" strategy.
static_section = (
    builder.static_features()
    .transforms(["standard_scale"])
    .strategy("unified")
    .columns(static_features)
)

# Final configuration object handed to the datamodule as "preprocessing_configs".
preprocessing_config = builder.build()

# Cell output: the keys stored for the static-feature section.
preprocessing_config["static_features"].keys()

dict_keys(['pipeline', 'strategy', 'columns'])

In [9]:
# Settings for the datamodule; passed to finetune_pretrained_models as
# `datamodule_config`.
datamodule_config = {
    # Data locations, built from REGIONS and COUNTRY in the directory cell.
    "region_time_series_base_dirs": region_time_series_base_dirs,
    "region_static_attributes_base_dirs": region_static_attributes_base_dirs,
    "path_to_preprocessing_output_directory": path_to_preprocessing_output_directory,
    # Column identifying a basin; the same column the per-group scalers group by.
    "group_identifier": "gauge_id",
    "batch_size": 2048,
    # Model inputs and prediction target.
    "forcing_features": forcing_features,
    "static_features": static_features,
    "target": target,
    "num_workers": 4,
    # NOTE(review): presumably basins with fewer training years are excluded - confirm.
    "min_train_years": 5,
    # Train / validation / test proportions; the three values sum to 1.0.
    "train_prop": 0.5,
    "val_prop": 0.25,
    "test_prop": 0.25,
    # NOTE(review): presumably the longest gap that gets imputed - confirm the unit.
    "max_imputation_gap_size": 5,
    # The saved training log shows all 59 basins handled as a single chunk ("chunk 1/1").
    "chunk_size": 100,
    "validation_chunk_size": 100,
    "is_autoregressive": True,
    "preprocessing_configs": preprocessing_config,
    # Same literal as the master seed given to SeedManager in the first code cell.
    "random_seed": 42,
}

## Training Configs

In [None]:
# Trainer-level settings; passed to finetune_pretrained_models as `training_config`.
training_config = {
    # Upper bound on epochs per run; early stopping normally ends a run sooner.
    "max_epochs": 200,
    # Train on a single CUDA device.
    "accelerator": "cuda",
    "devices": 1,
    # NOTE(review): the saved training log reports early stopping after
    # "the last 15 records", which does not match the value configured here -
    # confirm whether the saved output is stale or the runner overrides this value.
    "early_stopping_patience": 10,
    # Integer epoch interval. Presumably forwarded to the Lightning Trainer
    # argument of the same name, where 0 (the default) disables reloading - confirm.
    # Previously given as the boolean False, which only worked because False == 0.
    "reload_dataloaders_every_n_epochs": 0,
}

## Remaining Configs

In [11]:
# Root directory passed to finetune_pretrained_models as `output_dir`.
output_dir = "/workspace/hydro-forecasting/experiments/finetune"
# Directory of the pretrained "similar catchments" checkpoints for COUNTRY.
pretrained_checkpoint_dir = (
    f"/workspace/hydro-forecasting/experiments/similar_catchments/similar-catchments_{COUNTRY}/checkpoints"
)

# Architectures to fine-tune; comment entries in or out to change the selection.
# NOTE(review): model_types and yaml_paths are maintained by hand as parallel
# lists - presumably paired by position, so keep the same entries enabled in
# both; confirm against finetune_pretrained_models.
model_types = [
    # "tide",
    # "ealstm",
    # "tsmixer",
    "tft",
]
# One YAML file per enabled model type, located under the COUNTRY folder.
yaml_paths = [
    # f"/workspace/hydro-forecasting/experiments/yaml-files/{COUNTRY}/tide.yaml",
    # f"/workspace/hydro-forecasting/experiments/yaml-files/{COUNTRY}/ealstm.yaml",
    # f"/workspace/hydro-forecasting/experiments/yaml-files/{COUNTRY}/tsmixer.yaml",
    f"/workspace/hydro-forecasting/experiments/yaml-files/{COUNTRY}/tft.yaml",
]
experiment_name = f"finetune_similar_{COUNTRY}"
# Number of runs per model type (the saved training log shows "Starting run 2/5").
num_runs = 5
# NOTE(review): presumably False keeps earlier attempts and writes a new
# attempt folder - confirm.
override_previous_attempts = False

## Fine-tuning the pretrained models

In [None]:
# Fine-tune each configured model type on the selected basins, starting from
# the checkpoints found under pretrained_checkpoint_dir.
train_results = finetune_pretrained_models(
    gauge_ids=basin_ids,
    pretrained_checkpoint_dir=pretrained_checkpoint_dir,
    datamodule_config=datamodule_config,
    training_config=training_config,
    output_dir=output_dir,
    model_types=model_types,
    pretrained_yaml_paths=yaml_paths,
    experiment_name=experiment_name,
    num_runs=num_runs,
    override_previous_attempts=override_previous_attempts,
    # NOTE(review): presumably the pretrained learning rate is divided by this factor - confirm.
    lr_reduction_factor=25,
    # NOTE(review): presumably selects the best checkpoint among the pretrained runs - confirm.
    select_best_from_pretrained=True,
)

2025-07-08 13:49:26 - hydro_forecasting.experiment_utils.finetune_pretrained_model - INFO - Found pre-trained checkpoint for tft: /workspace/hydro-forecasting/experiments/similar_catchments/similar-catchments_kyrgyzstan/checkpoints/tft/run_0/attempt_0/tft-run0-attempt_0-epoch=113-val_loss=0.0365.ckpt
2025-07-08 13:49:26 - hydro_forecasting.experiment_utils.training_runner - INFO - Starting experiment 'finetune_similar_kyrgyzstan' from ExperimentRunner.
2025-07-08 13:49:26 - hydro_forecasting.experiment_utils.training_runner - INFO - Processing model (1/1): tft
2025-07-08 13:49:26 - hydro_forecasting.experiment_utils.training_runner - INFO - Processing model: tft using HPs from /workspace/hydro-forecasting/experiments/yaml-files/kyrgyzstan/tft.yaml
2025-07-08 13:49:26 - hydro_forecasting.experiment_utils.seed_manager - INFO - SeedManager initialized without master seed (non-deterministic mode)
2025-07-08 13:49:26 - hydro_forecasting.data.in_memory_datamodule - INFO - Initialized SeedMan

The following parameters were not found in the YAML file and will use defaults:
  - hidden_continuous_size (model-specific)
  - quantiles (model-specific)
  - scheduler_factor (model-specific)
  - scheduler_patience (model-specific)


2025-07-08 13:49:26 - hydro_forecasting.data.in_memory_datamodule - INFO - Found 59 basins for synchronized train/val chunking and validation pool selection.
2025-07-08 13:49:26 - hydro_forecasting.data.in_memory_datamodule - INFO - Found 59 basins for test split.
2025-07-08 13:49:26 - hydro_forecasting.data.in_memory_datamodule - INFO - Data preparation finished.
2025-07-08 13:49:26 - hydro_forecasting.data.in_memory_datamodule - INFO - Loading static data cache and converting to Tensors...
2025-07-08 13:49:26 - hydro_forecasting.data.in_memory_datamodule - INFO - Loaded and tensorized static data for 59 basins.
2025-07-08 13:49:26 - hydro_forecasting.data.in_memory_datamodule - INFO - Created fixed validation pool with 59 basins: ['CA_15212', 'CA_15189', 'CA_16068', 'CA_15034', 'CA_16139']...
2025-07-08 13:49:26 - hydro_forecasting.data.in_memory_datamodule - INFO - Loading and caching data for 59 validation basins...
2025-07-08 13:49:27 - hydro_forecasting.data.in_memory_datamodule 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:49:27 - hydro_forecasting.data.in_memory_datamodule - INFO - Epoch 0: Val Dataloader using cached validation data with 76914 samples from 59 basins.
2025-07-08 13:49:28 - hydro_forecasting.data.in_memory_datamodule - INFO - Epoch 0: Train Dataloader using chunk 1/1 with 59 basins.
2025-07-08 13:49:29 - hydro_forecasting.data.in_memory_datamodule - INFO - Stage 'train' chunk data loaded for 59 basins. Shape: (193981, 12). Est. Mem: 8.88 MB


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:49:56 - pytorch_lightning.callbacks.early_stopping - INFO - Metric val_loss improved. New best score: 0.025


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:50:50 - pytorch_lightning.callbacks.early_stopping - INFO - Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.025


Validation: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:51:17 - pytorch_lightning.callbacks.early_stopping - INFO - Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.025


Validation: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:51:44 - pytorch_lightning.callbacks.early_stopping - INFO - Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.025


Validation: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:52:11 - pytorch_lightning.callbacks.early_stopping - INFO - Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.025


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:58:54 - pytorch_lightning.callbacks.early_stopping - INFO - Monitored metric val_loss did not improve in the last 15 records. Best score: 0.025. Signaling Trainer to stop.
2025-07-08 13:58:54 - hydro_forecasting.experiment_utils.training_runner - INFO - Run 0 completed. Best val_loss: 0.024699846282601357, Path: /workspace/hydro-forecasting/experiments/finetune/finetune_similar_kyrgyzstan/checkpoints/tft/run_0/attempt_2/tft-run0-attempt_2-epoch=05-val_loss=0.0247.ckpt
2025-07-08 13:58:54 - hydro_forecasting.experiment_utils.training_runner - INFO - Starting run 2/5 for tft
2025-07-08 13:58:54 - hydro_forecasting.experiment_utils.seed_manager - INFO - SeedManager initialized with master seed: 42
2025-07-08 13:58:54 - hydro_forecasting.experiment_utils.seed_manager - INFO - SeedManager initialized with master seed: 987689484
2025-07-08 13:58:54 - lightning_fabric.utilities.seed - INFO - Seed set to 987689484
2025-07-08 13:58:54 - hydro_forecasting.experiment_utils.training_

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:58:54 - hydro_forecasting.data.in_memory_datamodule - INFO - Epoch 0: Val Dataloader using cached validation data with 76914 samples from 59 basins.
2025-07-08 13:58:54 - hydro_forecasting.data.in_memory_datamodule - INFO - Completed full pass through training shared chunks. Recomputing.
2025-07-08 13:58:54 - hydro_forecasting.data.in_memory_datamodule - INFO - Initializing/Re-initializing training shared chunks from 59 basins.
2025-07-08 13:58:54 - hydro_forecasting.data.in_memory_datamodule - INFO - Created 1 training shared chunks.
2025-07-08 13:58:54 - hydro_forecasting.data.in_memory_datamodule - INFO - Epoch 0: Train Dataloader using chunk 1/1 with 59 basins.
2025-07-08 13:58:55 - hydro_forecasting.data.in_memory_datamodule - INFO - Stage 'train' chunk data loaded for 59 basins. Shape: (193981, 12). Est. Mem: 8.88 MB


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

2025-07-08 13:59:22 - pytorch_lightning.callbacks.early_stopping - INFO - Metric val_loss improved. New best score: 0.025
