In [None]:
import sys
from pathlib import Path

# Make the project's ``src`` directory importable. The project root is taken
# to be the parent of the current working directory.
project_root = Path.cwd().parent
src_path = project_root.joinpath("src")
src_entry = str(src_path)
if src_entry not in sys.path:
    # Prepend so this entry is searched before the rest of sys.path.
    sys.path.insert(0, src_entry)
    print(f"Added {src_path} to Python path")

In [None]:
import polars as pl
from sklearn.pipeline import Pipeline

from hydro_forecasting.data.caravanify_parquet import (
    CaravanifyParquet,
    CaravanifyParquetConfig,
)
from hydro_forecasting.data.in_memory_datamodule import HydroInMemoryDataModule
from hydro_forecasting.model_evaluation.evaluators import TSForecastEvaluator
from hydro_forecasting.model_evaluation.hp_from_yaml import hp_from_yaml
from hydro_forecasting.models.tide import LitTiDE, TiDEConfig
from hydro_forecasting.preprocessing.grouped import GroupedPipeline
from hydro_forecasting.preprocessing.normalize import NormalizeTransformer
from hydro_forecasting.preprocessing.standard_scale import StandardScaleTransformer

---

In [None]:
# TiDE hyperparameter file, resolved relative to the repository root rather
# than a user-specific absolute path so the notebook runs from any checkout.
# NOTE(review): assumes `project_root` (set in the first cell to the parent of
# the working directory) is the hydro-forecasting repository root — confirm.
yaml_path = project_root / "experiments" / "yaml-files" / "tajikistan" / "tide.yaml"
if not yaml_path.is_file():
    raise FileNotFoundError(f"TiDE hyperparameter file not found: {yaml_path}")

# Load the hyperparameters registered under the "tide" model key; the result
# is indexed like a mapping further down (e.g. tide_hp["input_len"]).
tide_hp = hp_from_yaml("tide", yaml_path)
tide_hp

In [None]:
# Dataset configuration for the "CA" region: attribute, time-series and
# shapefile locations plus the attribute families to include.
ca_config_kwargs = dict(
    attributes_dir="/Users/cooper/Desktop/CaravanifyParquet/CA/post_processed/attributes",
    timeseries_dir="/Users/cooper/Desktop/CaravanifyParquet/CA/post_processed/timeseries/csv",
    shapefile_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CA/post_processed/shapefiles",
    gauge_id_prefix="CA",
    use_hydroatlas_attributes=True,
    use_caravan_attributes=True,
    use_other_attributes=True,
)
config_ca = CaravanifyParquetConfig(**ca_config_kwargs)
caravan_ca = CaravanifyParquet(config_ca)

# Work with the first 50 gauges only.
all_ca_gauge_ids = caravan_ca.get_all_gauge_ids()
basin_ids = all_ca_gauge_ids[:50]

# To exclude an individual gauge (e.g. "CA_15030"), filter `basin_ids` here
# before the stations are loaded.

caravan_ca.load_stations(basin_ids)

In [None]:
# Dataset configuration for the "USA" region, mirroring the CA configuration.
config_us = CaravanifyParquetConfig(
    attributes_dir="/Users/cooper/Desktop/CaravanifyParquet/USA/post_processed/attributes",
    timeseries_dir="/Users/cooper/Desktop/CaravanifyParquet/USA/post_processed/timeseries/csv",
    shapefile_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/USA/post_processed/shapefiles",
    gauge_id_prefix="USA",
    use_hydroatlas_attributes=True,
    use_caravan_attributes=True,
    use_other_attributes=True,
)

caravan_us = CaravanifyParquet(config_us)

# Add the first 50 USA gauges to the selection.
# The merge rebinds `basin_ids` and drops repeats while keeping first-seen
# order, so re-running this cell does not append the USA gauges a second time
# (the original in-place `+=` did) and the list object already handed to
# `caravan_ca.load_stations` is left untouched.
# NOTE(review): unlike the CA cell, `load_stations` is not called on
# `caravan_us` — confirm whether that is intentional.
us_basin_ids = caravan_us.get_all_gauge_ids()[:50]
basin_ids = list(dict.fromkeys([*basin_ids, *us_basin_ids]))

In [None]:
# Column names handed to the datamodule as `forcing_features`; the grouped
# feature pipeline is also applied to exactly these columns.
forcing_features = [
    "snow_depth_water_equivalent_mean",
    "surface_net_solar_radiation_mean",
    "surface_net_thermal_radiation_mean",
    "potential_evaporation_sum_ERA5_LAND",
    "potential_evaporation_sum_FAO_PENMAN_MONTEITH",
    "temperature_2m_mean",
    "temperature_2m_min",
    "temperature_2m_max",
    "total_precipitation_sum",
]

# Column names handed to the datamodule as `static_features` and listed as the
# columns of the static preprocessing pipeline.
static_features = [
    "p_mean",
    "area",
    "ele_mt_sav",
    "high_prec_dur",
    "frac_snow",
    "high_prec_freq",
    "slp_dg_sav",
    "cly_pc_sav",
    "aridity_ERA5_LAND",
    "aridity_FAO_PM",
]

# Name of the column passed to the datamodule as the prediction target.
target = "streamflow"

In [None]:
def _scale_then_normalize() -> Pipeline:
    """Build a fresh, unfitted scaler -> normalizer pipeline.

    Returns:
        A new ``Pipeline`` with a ``StandardScaleTransformer`` step named
        "scaler" followed by a ``NormalizeTransformer`` step named
        "normalizer". A new instance is created on every call so the feature
        and target pipelines do not share transformer objects.
    """
    return Pipeline([("scaler", StandardScaleTransformer()), ("normalizer", NormalizeTransformer())])


# Forcing columns are transformed per gauge (grouped by "gauge_id").
feature_pipeline = GroupedPipeline(
    _scale_then_normalize(),
    columns=forcing_features,
    group_identifier="gauge_id",
)

# The target column is transformed per gauge as well. It is referenced through
# `target` rather than a repeated string literal, so changing the target name
# in one place keeps this pipeline and the datamodule in agreement.
target_pipeline = GroupedPipeline(
    _scale_then_normalize(),
    columns=[target],
    group_identifier="gauge_id",
)

# Static attributes use a plain (ungrouped) scaling pipeline.
static_pipeline = Pipeline([("scaler", StandardScaleTransformer())])

# Preprocessing configuration consumed by the datamodule.
preprocessing_config = {
    "features": {"pipeline": feature_pipeline},
    "target": {"pipeline": target_pipeline},
    "static_features": {"pipeline": static_pipeline, "columns": static_features},
}

In [None]:
# Per-region input directories, keyed by region prefix.
region_time_series_base_dirs = dict(
    CA="/Users/cooper/Desktop/CaravanifyParquet/CA/post_processed/timeseries/csv/CA",
    USA="/Users/cooper/Desktop/CaravanifyParquet/USA/post_processed/timeseries/csv/USA",
)
region_static_attributes_base_dirs = dict(
    CA="/Users/cooper/Desktop/CaravanifyParquet/CA/post_processed/attributes/CA",
    USA="/Users/cooper/Desktop/CaravanifyParquet/USA/post_processed/attributes/USA",
)

# Training chunk size; validation uses twice this value.
DATA_CHUNK_SIZE = 10

datamodule_kwargs = {
    # Data locations
    "region_time_series_base_dirs": region_time_series_base_dirs,
    "region_static_attributes_base_dirs": region_static_attributes_base_dirs,
    "path_to_preprocessing_output_directory": "/Users/cooper/Desktop/hydro-forecasting/tests/yolo_6",
    "list_of_gauge_ids_to_process": basin_ids,
    "group_identifier": "gauge_id",
    # Columns and preprocessing
    "forcing_features": forcing_features,
    "static_features": static_features,
    "target": target,
    "preprocessing_configs": preprocessing_config,
    "max_imputation_gap_size": 5,
    # Window lengths come from the TiDE hyperparameters
    "input_length": tide_hp["input_len"],
    "output_length": tide_hp["output_len"],
    "is_autoregressive": True,
    # Train / validation / test split
    "min_train_years": 5,
    "train_prop": 0.5,
    "val_prop": 0.25,
    "test_prop": 0.25,
    # Loading
    "batch_size": 2048,
    "num_workers": 4,
    "chunk_size": DATA_CHUNK_SIZE,
    "validation_chunk_size": 2 * DATA_CHUNK_SIZE,
}
datamodule = HydroInMemoryDataModule(**datamodule_kwargs)

# Run preprocessing, then set the datamodule up for use.
datamodule.prepare_data()
datamodule.setup()

In [None]:
# Fetch the training dataloader from the datamodule.
dataloader = datamodule.get_train_dataloader()

# Keep a handle on the dataset behind the loader for inspection below.
dataset = dataloader.dataset

In [None]:
# Display the dataset's `chunk_column_tensors` attribute for a manual check.
# NOTE(review): this may print a large structure — consider summarising it.
dataset.chunk_column_tensors

## Verify static data

In [None]:
# Disabled debugging cell, kept for reference: fetches one sample from the test
# dataset and reports, per tensor, whether it contains NaNs.
# NOTE(review): the code below uses `torch`, which no visible cell of this
# notebook imports — add `import torch` before re-enabling it.
# test_dataset = datamodule.test_dataset
# if not test_dataset:
#     print("Test dataset not found.")
# elif len(test_dataset) == 0:
#     print("Test dataset is empty.")
# else:
#     print(f"Test dataset size: {len(test_dataset)}")
#     # --- Get a Sample ---
#     sample_index = 1654
#     print(f"Getting sample at index {sample_index}...")
#     try:
#         sample = test_dataset[sample_index]

#         # --- Check for NaNs in the Sample Tensors ---
#         print("\n--- Checking for NaNs in sample tensors ---")
#         for key, tensor in sample.items():
#             if isinstance(tensor, torch.Tensor):
#                 has_nan = torch.isnan(tensor).any().item()
#                 print(f"Tensor'{key}' shape: {tensor.shape}, Contains NaNs: {has_nan}")
#                 print(f"  Sample tensor '{key}': {tensor[:5]}")
#                 if has_nan:
#                     # Optional: Print where NaNs occur
#                     nan_indices = torch.nonzero(torch.isnan(tensor))
#                     print(f"  NaN indices in '{key}': {nan_indices.tolist()[:5]}...") # Print first 5
#             else:
#                 print(f"Item '{key}' is not a tensor (type: {type(tensor)})")

#     except IndexError:
#         print(f"Error: Index {sample_index} out of bounds for dataset size {len(test_dataset)}")
#     except Exception as e:
#         print(f"An error occurred while getting or checking the sample: {e}")

In [None]:
# Disabled debugging cell, kept for reference: looks up one validation index
# entry, reads its parquet file and shows the "streamflow" values of the slice.
# NOTE(review): the code below uses `pd`, which no visible cell of this
# notebook imports — add `import pandas as pd` before re-enabling it.
# ie = datamodule.val_index_entries[1661]

# file_path = ie["file_path"]
# start_idx = ie["start_idx"]
# end_idx = ie["end_idx"]
# gauge_id = ie["gauge_id"]

# data = pd.read_parquet(file_path)

# # Slice the data
# data_slice = data.iloc[start_idx:end_idx]
# print(f"Data slice shape: {data_slice.shape}")

# data_slice["streamflow"]

## Let's try training a model

In [None]:
# Window lengths as configured on the datamodule; `output_length` defines the
# evaluation horizons further down.
input_length, output_length = datamodule.input_length, datamodule.output_length

# Build the TiDE configuration from the loaded hyperparameters and wrap it in
# its Lightning module.
config = TiDEConfig(**tide_hp)
model = LitTiDE(config)

In [None]:
# Imported under its own alias: binding it to `pl` (as before) rebinds the
# polars alias from the import cell, so any later `pl.col(...)` would hit
# pytorch_lightning unless polars were re-imported first.
import pytorch_lightning as ptl

trainer = ptl.Trainer(
    max_epochs=5,
    accelerator="gpu",
    devices=1,
    # This option is an integer epoch interval; 1 reloads the dataloaders every
    # epoch. (The previous `True` only worked because bool is a subclass of int.)
    reload_dataloaders_every_n_epochs=1,
)

# Train the model on the datamodule prepared above.
trainer.fit(model, datamodule)

In [None]:
# Each model to evaluate is registered by name together with its datamodule.
models_and_datamodules = {"TiDE": (model, datamodule)}

# One horizon per forecast step: 1 .. output_length inclusive.
forecast_horizons = list(range(1, output_length + 1))

# Evaluation runs on a single CPU device.
evaluation_trainer_kwargs = {"accelerator": "cpu", "devices": 1}

evaluator = TSForecastEvaluator(
    horizons=forecast_horizons,
    models_and_datamodules=models_and_datamodules,
    trainer_kwargs=evaluation_trainer_kwargs,
)

In [None]:
# Run the evaluator on the registered models; the result is indexed by model
# name in the next cell.
results = evaluator.test_models()

In [None]:
# Pull the "df" item out of the TiDE entry of the results.
df = results["TiDE"]["df"]

In [None]:
# Display the results dataframe (later cells use the polars API on it).
df

In [None]:
# Basin ids present in the results, in order of first appearance.
# `maintain_order=True` makes the listing reproducible between runs; without it
# polars does not guarantee the order of the unique values.
unique_ids = df["basin_id"].unique(maintain_order=True).to_list()
unique_ids

In [None]:
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns

# Basin whose forecasts are plotted.
basin_id = "CA_15040"

# Keep only the rows belonging to that basin (polars filter expression).
df_basin = df.filter(pl.col("basin_id") == basin_id)

# Explicit figure/axes pair instead of the implicit pyplot state machine.
fig, ax = plt.subplots(figsize=(12, 6))

# The polars columns are passed to matplotlib directly; no pandas conversion.
series_to_plot = (
    ("prediction", "Prediction", "blue"),
    ("observed", "Observed", "orange"),
)
for column, label, color in series_to_plot:
    ax.plot(df_basin["date"], df_basin[column], label=label, color=color)

# Title, axis labels and legend.
ax.set_title(f"Observed vs Prediction for {basin_id}")
ax.set_xlabel("Date")
ax.set_ylabel("Streamflow")
ax.legend()

# One tick per month, labelled as year-month, with rotated date labels.
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
fig.autofmt_xdate()

# Drop the top and right spines.
sns.despine(ax=ax)

plt.show()