In [1]:
import sys
from pathlib import Path

# Add src directory to Python path
project_root = Path.cwd().parent
src_path = project_root / "src"
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))
    print(f"Added {src_path} to Python path")

Added /Users/cooper/Desktop/hydro-forecasting/src to Python path


In [2]:
import random
import time

import numpy as np
import pandas as pd
from tqdm import tqdm

from hydro_forecasting.data.caravanify import Caravanify, CaravanifyConfig
from hydro_forecasting.data.caravanify_parquet import CaravanifyParquet, CaravanifyParquetConfig
from hydro_forecasting.data.preprocessing import Config, split_data

---

In [3]:
config = CaravanifyParquetConfig(
    attributes_dir="/Users/cooper/Desktop/CaravanifyParquet/USA/post_processed/attributes",
    timeseries_dir="/Users/cooper/Desktop/CaravanifyParquet/USA/post_processed/timeseries/csv",
    shapefile_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/USA/post_processed/shapefiles",
    # human_influence_path="/Users/cooper/Desktop/CAMELS-CH/src/human_influence_index/results/human_influence_classification.csv",
    gauge_id_prefix="USA",
    use_hydroatlas_attributes=True,
    use_caravan_attributes=True,
    use_other_attributes=True,
)

caravan = CaravanifyParquet(config)
basins = caravan.get_all_gauge_ids()[:20]

caravan.load_stations(basins)

static = caravan.get_static_attributes()
static["ele_mt_sav"]

0     276.198974
1     103.274219
2     174.744464
3     303.634479
4     378.958111
5     644.379447
6     626.000000
7     548.306359
8     282.385672
9      43.392693
10    369.000000
11     62.022910
12    187.342696
13    140.423908
14    527.042664
15    739.598304
16    422.189310
17    463.621079
18    456.324693
19    450.489605
Name: ele_mt_sav, dtype: float64

---

In [4]:
config = CaravanifyConfig(
    attributes_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CA/post_processed/attributes",
    timeseries_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CA/post_processed/timeseries/csv",
    shapefile_dir="/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CA/post_processed/shapefiles",
    # human_influence_path="/Users/cooper/Desktop/CAMELS-CH/src/human_influence_index/results/human_influence_classification.csv",
    gauge_id_prefix="CA",
    use_hydroatlas_attributes=True,
    use_caravan_attributes=True,
    use_other_attributes=True,
)

caravan = Caravanify(config)
basins = caravan.get_all_gauge_ids()
basins = [basin for basin in basins if basin != "CA_15030"]

In [5]:
def load_gauge_parquet(gauge_ids: list[str], base_dir: Path) -> pd.DataFrame:
    """
    Loads the .parquet file for a given list of gauge_ids.

    Args:
        gauge_ids (list[str]): Gauge IDs with the 'USA_' prefix.
        base_dir (Path): Path to the directory containing the parquet files.

    Returns:
        pd.DataFrame: Combined data from the corresponding parquet files.
    """
    data = []

    for gauge_id in gauge_ids:
        file_path = base_dir / f"{gauge_id}.parquet"
        if not file_path.exists():
            raise FileNotFoundError(f"No parquet file found for gauge ID {gauge_id} at {file_path}")
        try:
            df = pd.read_parquet(file_path)
            df["gauge_id"] = gauge_id  # Assign here
            data.append(df)
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            continue

    combined_data = pd.concat(data, ignore_index=True)
    return combined_data

In [6]:
def get_split_boundaries(train, val, test, gauge_ids):
    """
    Determine the date boundaries between train/val/test splits for each gauge ID.

    Args:
        train: Training DataFrame
        val: Validation DataFrame
        test: Test DataFrame
        gauge_ids: List of gauge IDs to process

    Returns:
        Dictionary mapping gauge_ids to their split boundary dates
    """
    split_boundaries = {}

    for gauge_id in gauge_ids:
        # Get min dates for val and test splits for this gauge
        gauge_val = val[val["gauge_id"] == gauge_id]
        gauge_test = test[test["gauge_id"] == gauge_id]

        val_start = gauge_val["date"].min() if not gauge_val.empty else None
        test_start = gauge_test["date"].min() if not gauge_test.empty else None

        split_boundaries[gauge_id] = {"val_start": val_start, "test_start": test_start}

    return split_boundaries


def find_valid_sequences(basin_data, input_length, output_length, cols_to_check=None):
    """
    Find valid sequence starting positions in the basin data.

    Args:
        basin_data: DataFrame containing basin time series data
        input_length: Length of input sequence
        output_length: Length of output sequence
        cols_to_check: Columns to check for NaN values

    Returns:
        Tuple of (valid_positions, dates) arrays
    """
    if cols_to_check is None:
        cols_to_check = ["streamflow", "total_precipitation_sum"]

    total_seq_length = input_length + output_length

    if len(basin_data) < total_seq_length:
        return np.array([]), np.array([])

    # Extract needed data as arrays
    basin_values = basin_data[cols_to_check].to_numpy()
    dates = basin_data["date"].to_numpy()

    # Combined valid mask: 1 if all cols not NaN, 0 otherwise
    combined_valid = (~np.isnan(basin_values).any(axis=1)).astype(int)

    # Convolve to find valid input sequences
    input_conv = np.convolve(combined_valid, np.ones(input_length, dtype=int), mode="valid")
    input_valid = input_conv == input_length

    # Convolve for output sequences, shifted by input_length
    output_conv = np.convolve(combined_valid, np.ones(output_length, dtype=int), mode="valid")
    output_valid = output_conv == output_length
    output_valid_shifted = np.pad(output_valid, (input_length, 0), constant_values=False)[: len(input_valid)]

    # Find valid sequence starts
    valid_mask = input_valid & output_valid_shifted
    valid_positions = np.where(valid_mask)[0]

    return valid_positions, dates


def determine_stage(input_end_date, boundaries):
    """
    Determine which stage (train/val/test) a sequence belongs to based on its end date.

    Args:
        input_end_date: End date of the input sequence
        boundaries: Dictionary with val_start and test_start dates

    Returns:
        String: 'train', 'val', or 'test'
    """
    val_start = boundaries["val_start"]
    test_start = boundaries["test_start"]

    if test_start is not None and input_end_date >= test_start:
        return "test"
    elif val_start is not None and input_end_date >= val_start:
        return "val"
    else:
        return "train"


def create_basin_index(
    gauge_ids: list[str],
    base_dir: Path,
    static_file_path: Path,
    input_length=70,
    output_length=10,
):
    """
    Create index entries for valid sequences, identifying which stage (train/val/test) each sequence belongs to.

    Args:
        gauge_ids: List of gauge IDs to process
        base_dir: Base directory containing parquet files
        static_file_path: Path to static attributes file
        input_length: Length of input sequence
        output_length: Length of forecast horizon

    Returns:
        List of index entries with stage identification
    """
    valid_data = load_gauge_parquet(gauge_ids, base_dir)

    # Create splits
    split_config = Config(
        required_columns=["streamflow"],
        preprocessing_config={},
        min_train_years=10,
        max_imputation_gap_size=5,
    )
    train, val, test = split_data(df=valid_data, config=split_config)

    # Get split boundaries for each gauge
    split_boundaries = get_split_boundaries(train, val, test, gauge_ids)

    all_index_entries = []
    total_seq_length = input_length + output_length

    # Process each basin
    for gauge_id, basin_data in tqdm(valid_data.groupby("gauge_id"), desc="Processing basins"):
        # Create actual file path for this gauge
        ts_file_path = base_dir / f"{gauge_id}.parquet"

        # Get split boundaries for this gauge
        gauge_bounds = split_boundaries.get(gauge_id, {"val_start": None, "test_start": None})

        # Find valid sequences in this basin's data
        valid_positions, dates = find_valid_sequences(basin_data, input_length, output_length)

        # Create index entries with stage identification
        for idx in valid_positions:
            if idx + total_seq_length > len(basin_data):
                continue

            # Get the input_end_date for this sequence
            input_end_date = dates[idx + input_length - 1]

            # Determine stage based on input_end_date
            stage = determine_stage(input_end_date, gauge_bounds)

            # Create entry with stage information
            entry = {
                "file_path": str(ts_file_path),
                "static_file_path": str(static_file_path),
                "gauge_id": gauge_id,
                "start_idx": idx,
                "end_idx": idx + total_seq_length,
                "input_end_date": input_end_date,
                "valid_sequence": True,
                "stage": stage,
            }

            all_index_entries.append(entry)

    return all_index_entries


# Example usage:
data_folder = Path(
    "/Users/cooper/Desktop/CaravanifyParquet/CA/post_processed/timeseries/testing_run_hydro_processor/processed_data"
)
static_file = Path("/path/to/static_attributes.csv")
index_entries = create_basin_index(basins, data_folder, static_file)

Processing basins: 100%|██████████| 77/77 [00:00<00:00, 84.93it/s]


In [7]:
import polars as pl


def read_parquet_range(file_path: str, start_idx: int, end_idx: int) -> pl.DataFrame:
    """
    Efficiently read a row slice from a Parquet file using Polars.

    Args:
        file_path: Path to the Parquet file.
        start_idx: Start row index (inclusive).
        end_idx: End row index (exclusive).

    Returns:
        Polars DataFrame with the selected rows.
    """
    columns = [
        "date",
        "streamflow",
        "snow_depth_water_equivalent_mean",
        "surface_net_solar_radiation_mean",
        "surface_net_thermal_radiation_mean",
        "potential_evaporation_sum_ERA5_LAND",
        "potential_evaporation_sum_FAO_PENMAN_MONTEITH",
        "temperature_2m_mean",
        "temperature_2m_min",
        "temperature_2m_max",
        "total_precipitation_sum",
    ]
    return pl.read_parquet(file_path, columns=columns, row_count_name=None, use_pyarrow=False).slice(
        start_idx, end_idx - start_idx
    )


# Example usage
random_indices = random.sample(range(len(index_entries)), 2048)
dfs = []
start = time.time()
for i in tqdm(random_indices, desc="Reading ranges"):
    df = read_parquet_range(
        file_path=index_entries[i]["file_path"],
        start_idx=index_entries[i]["start_idx"],
        end_idx=index_entries[i]["end_idx"],
    )

    dfs.append(df)
    del df
print(f"Time taken to read ranges: {time.time() - start:.5f}s")

  return pl.read_parquet(
Reading ranges: 100%|██████████| 2048/2048 [00:00<00:00, 2322.10it/s]

Time taken to read ranges: 0.88313s





In [8]:
index_entries[0]

{'file_path': '/Users/cooper/Desktop/CaravanifyParquet/CA/post_processed/timeseries/testing_run_hydro_processor/processed_data/CA_15013.parquet',
 'static_file_path': '/path/to/static_attributes.csv',
 'gauge_id': 'CA_15013',
 'start_idx': np.int64(70),
 'end_idx': np.int64(150),
 'input_end_date': np.datetime64('2000-05-20T00:00:00.000000000'),
 'valid_sequence': True,
 'stage': 'train'}

In [9]:
dfs[0]

date,streamflow,snow_depth_water_equivalent_mean,surface_net_solar_radiation_mean,surface_net_thermal_radiation_mean,potential_evaporation_sum_ERA5_LAND,potential_evaporation_sum_FAO_PENMAN_MONTEITH,temperature_2m_mean,temperature_2m_min,temperature_2m_max,total_precipitation_sum
datetime[ns],f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
2019-10-22 00:00:00,-0.442249,221.729996,100.93,-74.739998,-0.226614,0.48,-5.54,-9.3,-1.59,0.389522
2019-10-23 00:00:00,-0.398838,222.270004,117.529999,-95.989998,-0.149951,0.39,-7.44,-11.4,-1.68,-0.654873
2019-10-24 00:00:00,-0.442249,221.880005,127.089996,-110.260002,-0.085081,0.34,-6.97,-12.2,-0.56,-0.675271
2019-10-25 00:00:00,-0.463954,221.509995,130.270004,-115.809998,-0.067389,0.33,-7.28,-14.38,-0.15,-0.675271
2019-10-26 00:00:00,-0.463954,221.179993,128.919998,-110.620003,0.038761,0.48,-6.01,-12.76,0.61,-0.671192
…,…,…,…,…,…,…,…,…,…,…
2020-01-05 00:00:00,-0.724422,255.639999,41.389999,-47.349998,-0.928384,0.08,-19.700001,-27.0,-14.04,-0.475368
2020-01-06 00:00:00,-0.767834,257.290009,27.93,-27.25,-0.91659,0.1,-14.29,-17.309999,-10.6,0.316088
2020-01-07 00:00:00,-0.767834,258.570007,38.759998,-59.720001,-0.940178,0.0,-19.4,-25.01,-14.21,-0.548801
2020-01-08 00:00:00,-0.767834,258.630005,41.27,-59.540001,-0.940178,0.0,-23.98,-30.26,-16.98,-0.654873


## Splitting index entries

In [10]:
def split_index_entreis_by_stage(
    index_entries: list[dict],
) -> dict[str, list[dict]]:
    """
    Split index entries into train, val, and test sets based on their stage.

    Args:
        index_entries: List of index entries with stage information

    Returns:
        Dictionary with keys 'train', 'val', and 'test' mapping to lists of index entries
    """
    split_entries = {"train": [], "val": [], "test": []}

    for entry in index_entries:
        stage = entry["stage"]
        if stage in split_entries:
            split_entries[stage].append(entry)

    return split_entries


# Example usage
split_entries = split_index_entreis_by_stage(index_entries)

## Dealing with static attributes

In [11]:
path_to_static = Path(
    "/Users/cooper/Desktop/CaravanifyParquet/CA/post_processed/timeseries/testing_run_hydro_processor/processed_static_data/static_attributes.parquet"
)


def get_static_attributes_from_id(path, id):
    static_df = pd.read_parquet(path_to_static)
    static_df = static_df[static_df["gauge_id"] == id]
    return static_df


statics = []
start = time.time()
for id in basins:
    static = get_static_attributes_from_id(path_to_static, id)
    statics.append(static)
    del static
end = time.time()
print(f"Time taken to read static attributes: {end - start:.5f}s")

Time taken to read static attributes: 0.78807s


In [14]:
def get_static_attributes_from_id(static_df: pl.DataFrame, gauge_id: str, static_columns: list[str]) -> pl.DataFrame:
    """
    Retrieve static attributes for a specific gauge ID from a Polars DataFrame.

    Args:
        static_df: Polars DataFrame containing static attributes for all gauges.
        gauge_id: The gauge ID to filter for.

    Returns:
        Polars DataFrame with static attributes for the specified gauge.
    """
    static_df = pl.read_parquet(str(path_to_static), columns=static_columns)
    filtered_df = static_df.filter(pl.col("gauge_id") == gauge_id)
    return filtered_df


def get_all_statics_for_basins(
    path_to_static: Path, basin_ids: list[str], static_columns: list[str] = None
) -> list[pl.DataFrame]:
    """
    Efficiently retrieve static attributes for a list of basin IDs.

    Args:
        path_to_static: Path to the static attributes parquet file.
        basin_ids: List of basin (gauge) IDs.

    Returns:
        List of Polars DataFrames, one per basin ID.

    Example:
        >>> statics = get_all_statics_for_basins(Path("static_attributes.parquet"), ["CA-001", "CA-002"])
    """
    static_df = pl.read_parquet(path_to_static)
    statics = []
    start = time.time()
    for gauge_id in basin_ids:
        statics.append(get_static_attributes_from_id(static_df, gauge_id, static_columns))
    end = time.time()
    print(f"Time taken to read static attributes: {end - start:.5f}s")
    return statics


static_cols = [
    "gauge_id",
    "p_mean",
    "area",
    "ele_mt_sav",
    "high_prec_dur",
    "frac_snow",
    "high_prec_freq",
    "slp_dg_sav",
    "cly_pc_sav",
    "aridity_ERA5_LAND",
    "aridity_FAO_PM",
]

statics = get_static_attributes_from_id(path_to_static, "CA_15013", static_cols)

In [15]:
statics

gauge_id,p_mean,area,ele_mt_sav,high_prec_dur,frac_snow,high_prec_freq,slp_dg_sav,cly_pc_sav,aridity_ERA5_LAND,aridity_FAO_PM
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""CA_15013""",2.243786,254.78646,-0.099639,1.139241,0.378615,0.035202,174.240171,14.563168,0.971101,0.531601
