In [3]:
from transfer_learning_publication.models import ModelFactory
from transfer_learning_publication.data import CaravanDataSource, LSHDataModule
from transfer_learning_publication.transforms import PipelineBuilder, Log, ZScore
from transfer_learning_publication.models import ModelFactory, ModelEvaluator
from transfer_learning_publication.models.tide import LitTiDE
from transfer_learning_publication.callbacks import HFUploadCallback
import lightning as pl
from litmodels.integrations import LightningModelCheckpoint
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, Callback, LearningRateMonitor, RichProgressBar
import matplotlib.pyplot as plt
import joblib
import os
from huggingface_hub import upload_file

## Transform data and sink to disk

In [4]:
# Time-series column names used throughout the notebook.
ts_features = [
    "temperature_2m_max",
    "temperature_2m_min",
    "potential_evaporation_sum_FAO_PENMAN_MONTEITH",
    "potential_evaporation_sum_ERA5_LAND",
    "temperature_2m_mean",
    "total_precipitation_sum",
    "streamflow",
    "streamflow_was_filled",
    "sin_day_of_year",
    "cos_day_of_year",
]

# Static (per-gauge) attribute column names.
static_features = [
    "p_mean",
    "area",
    "ele_mt_sav",
    "high_prec_dur",
    "frac_snow",
    "high_prec_freq",
    "slp_dg_sav",
    "cly_pc_sav",
    "aridity_ERA5_LAND",
    "aridity_FAO_PM",
]

# Names of the data splits, in processing order.
stages = ["train", "val", "test"]

In [5]:
# NOTE(review): this whole cell is currently disabled (commented out).
# When enabled it would, for each stage in `stages`:
#   - read time series and static attributes for the first three gauge IDs only,
#   - fit both pipelines on the "train" stage and only apply them to the other stages,
#   - write the transformed data under /Users/cooper/Desktop/first-test/{stage}.
# Presumably the written files are what the config used further down points at —
# TODO confirm before deleting or re-enabling this cell.
#
# ts_preprocessing_pipelines = (
#     PipelineBuilder(group_identifier="gauge_id")
#     .add_per_basin(Log(), columns=["streamflow"])
#     .add_global(ZScore(), columns=["temperature_2m_max",
#             "temperature_2m_min",
#             "potential_evaporation_sum_FAO_PENMAN_MONTEITH",
#             "potential_evaporation_sum_ERA5_LAND",
#             "temperature_2m_mean",
#             "total_precipitation_sum",
#             "streamflow"])
#     .build()
# )

# static_preprocessing_pipelines = (
#     PipelineBuilder(group_identifier="gauge_id")
#     .add_global(ZScore(), columns=static_features)
#     .build()
# )

# for stage in stages:
#     caravan = CaravanDataSource(f"/Users/cooper/Desktop/CARAVAN_CLEAN/{stage}", region="tajikkyrgyz")
#     gauge_ids = caravan.list_gauge_ids()
#     ts_data = caravan.get_timeseries(
#         gauge_ids[:3], 
#         columns=ts_features
#     )
#     static_data = caravan.get_static_attributes(
#         gauge_ids[:3], 
#         columns=static_features
#     )

#     if stage == "train":
#         ts_transformed = ts_preprocessing_pipelines.fit_transform(ts_data.collect())
#         static_transformed = static_preprocessing_pipelines.fit_transform(static_data.collect())
#     else:
#         ts_transformed = ts_preprocessing_pipelines.transform(ts_data.collect())
#         static_transformed = static_preprocessing_pipelines.transform(static_data.collect())

#     caravan.write_timeseries(ts_transformed, f"/Users/cooper/Desktop/first-test/{stage}", overwrite=True)
#     caravan.write_static_attributes(static_transformed, f"/Users/cooper/Desktop/first-test/{stage}", overwrite=True)

#     print(f"Completed {stage} stage")

In [6]:
# Disabled (commented out): dump the fitted time-series and static pipelines
# to disk with joblib, next to the transformed data.
# joblib.dump(ts_preprocessing_pipelines, "/Users/cooper/Desktop/first-test/ts_pipeline.joblib")
# joblib.dump(static_preprocessing_pipelines, "/Users/cooper/Desktop/first-test/static_pipeline.joblib")

## Create datamodule and instantiate model for training

In [7]:
# Path to the experiment YAML passed to both the datamodule and the model factory.
# The original machine-specific location stays the default, so existing runs are
# unchanged; set the TLP_CONFIG_PATH environment variable to run this notebook
# from another checkout or machine without editing the cell.
config_path = os.environ.get(
    "TLP_CONFIG_PATH",
    "/Users/cooper/Desktop/transfer-learning-publication/configs/first_run.yaml",
)

In [8]:
# Build the data module from the experiment config file.
datamodule = LSHDataModule(config_path=config_path)

# Build the model from the same config file, so data and model settings
# come from a single source.
model_factory = ModelFactory()
model = model_factory.create_from_config(config_path=config_path)

In [None]:
# Keep only the single best checkpoint, ranked by the lowest "val_loss".
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

# Project callback targeting a non-private Hugging Face model repo.
# NOTE(review): presumably uploads the saved checkpoints — confirm in the callback source.
hf_upload_callback = HFUploadCallback(
    repo_id="CooperBigFoot/test-repo",
    repo_type="model",
    path_in_repo="checkpoints/{filename}",
    private=False,
)

# NOTE(review): with max_epochs=1 below, a patience of 5 is unlikely to ever trigger.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5,
    verbose=True,
)

progress_bar_callback = RichProgressBar()

# Single-epoch run on one automatically selected device, logging every step.
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="auto",
    devices=1,
    logger=True,
    callbacks=[
        checkpoint_callback,
        hf_upload_callback,
        early_stopping_callback,
        progress_bar_callback,
    ],
    enable_progress_bar=True,
    log_every_n_steps=1,
)

trainer.fit(model=model, datamodule=datamodule)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/cooper/Desktop/transfer-learning-publication/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name      | Type      | Params | Mode 
------------------------------------------------
0 | criterion | 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/cooper/Desktop/transfer-learning-publication/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/cooper/Desktop/transfer-learning-publication/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [10]:
# Register the trained model and its datamodule under the label "tide".
models_and_datamodules = {"tide": (model, datamodule)}

# Run the evaluation trainer on CPU.
evaluator = ModelEvaluator(
    models_and_datamodules=models_and_datamodules,
    trainer_kwargs={"accelerator": "cpu"},
)

# Always recompute (ignore any cached results) and request inverse-transformed outputs.
results = evaluator.test_models(
    cache_dir="/Users/cooper/Desktop/transfer-learning-publication/data/cache/first_run/",
    force_recompute=True,
    apply_inverse_transform=True,
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/cooper/Desktop/transfer-learning-publication/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/cooper/Desktop/transfer-learning-publication/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

  predictions=torch.from_numpy(pred_reshaped).float(),


In [11]:
# Print the summary table for the collected evaluation results.
summary_table = results.summary()
print(summary_table)

shape: (1, 5)
┌────────────┬───────────┬──────────┬───────────────┬───────────┐
│ model_name ┆ n_samples ┆ n_basins ┆ output_length ┆ has_dates │
│ ---        ┆ ---       ┆ ---      ┆ ---           ┆ ---       │
│ str        ┆ i64       ┆ i64      ┆ i64           ┆ bool      │
╞════════════╪═══════════╪══════════╪═══════════════╪═══════════╡
│ tide       ┆ 6146      ┆ 3        ┆ 10            ┆ true      │
└────────────┴───────────┴──────────┴───────────────┴───────────┘


In [12]:
# Collect the raw results, then print per-column descriptive statistics.
raw_results_frame = results.raw_data.collect()
print(raw_results_frame.describe())

shape: (9, 8)
┌────────────┬────────────┬────────────┬───────────┬───────────┬───────────┬───────────┬───────────┐
│ statistic  ┆ model_name ┆ group_iden ┆ lead_time ┆ predictio ┆ observati ┆ issue_dat ┆ predictio │
│ ---        ┆ ---        ┆ tifier     ┆ ---       ┆ n         ┆ on        ┆ e         ┆ n_date    │
│ str        ┆ str        ┆ ---        ┆ f64       ┆ ---       ┆ ---       ┆ ---       ┆ ---       │
│            ┆            ┆ str        ┆           ┆ f64       ┆ f64       ┆ str       ┆ str       │
╞════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════╪═══════════╪═══════════╡
│ count      ┆ 61460      ┆ 61460      ┆ 61460.0   ┆ 61460.0   ┆ 61460.0   ┆ 61460     ┆ 61460     │
│ null_count ┆ 0          ┆ 0          ┆ 0.0       ┆ 0.0       ┆ 0.0       ┆ 0         ┆ 0         │
│ mean       ┆ null       ┆ null       ┆ 5.5       ┆ 1.474145  ┆ 1.510034  ┆ 2020-01-1 ┆ 2020-01-1 │
│            ┆            ┆            ┆           ┆           ┆           ┆ 

In [13]:
# Restrict the results to the "tide" model and one basin, then print the collected rows.
tide_single_basin = results.filter(model_name="tide", basin_id="tajikkyrgyz_15013")
print(tide_single_basin.collect())

shape: (20_610, 7)
┌────────────┬───────────────┬───────────┬────────────┬─────────────┬───────────────┬──────────────┐
│ model_name ┆ group_identif ┆ lead_time ┆ prediction ┆ observation ┆ issue_date    ┆ prediction_d │
│ ---        ┆ ier           ┆ ---       ┆ ---        ┆ ---         ┆ ---           ┆ ate          │
│ str        ┆ ---           ┆ i64       ┆ f64        ┆ f64         ┆ datetime[μs]  ┆ ---          │
│            ┆ str           ┆           ┆            ┆             ┆               ┆ datetime[μs] │
╞════════════╪═══════════════╪═══════════╪════════════╪═════════════╪═══════════════╪══════════════╡
│ tide       ┆ tajikkyrgyz_1 ┆ 1         ┆ 3.940384   ┆ 2.98        ┆ 2017-04-30    ┆ 2017-05-01   │
│            ┆ 5013          ┆           ┆            ┆             ┆ 22:00:00      ┆ 22:00:00     │
│ tide       ┆ tajikkyrgyz_1 ┆ 2         ┆ 2.436957   ┆ 2.81        ┆ 2017-04-30    ┆ 2017-05-02   │
│            ┆ 5013          ┆           ┆            ┆             ┆ 22

In [14]:
# Debugging aid: show which class the datamodule's private `_pipeline` attribute holds.
pipeline_object = datamodule._pipeline
type(pipeline_object)

transfer_learning_publication.transforms.composite.CompositePipeline

In [17]:
# Show the forcing feature names recorded in the datamodule's config.
forcing_feature_names = datamodule.config["features"]["forcing"]
forcing_feature_names

['temperature_2m_max',
 'temperature_2m_min',
 'potential_evaporation_sum_FAO_PENMAN_MONTEITH',
 'potential_evaporation_sum_ERA5_LAND',
 'temperature_2m_mean',
 'total_precipitation_sum',
 'streamflow',
 'streamflow_was_filled',
 'sin_day_of_year',
 'cos_day_of_year']