In [1]:
from pathlib import Path

import pandas as pd

from hydro_forecasting.data.caravanify_parquet import (
    CaravanifyParquet,
    CaravanifyParquetConfig,
)

In [2]:
# CSV holding the basin-to-cluster assignment; the rest of this cell reads its
# "cluster" and "gauge_id" columns.
# NOTE(review): absolute, machine-specific path — the notebook only runs as-is
# on a machine with this exact directory layout.
PATH_TO_CLUSTER_ASSIGNED = Path(
    "/Users/cooper/Desktop/hydro-forecasting/scripts/cluster_basins/clustering_results/cluster_assignments_shifted_refactor.csv"
)

# Loaded once and reused for every cluster lookup below.
df = pd.read_csv(PATH_TO_CLUSTER_ASSIGNED)


def get_ids_for_cluster(df, cluster_id):
    """Return the ``gauge_id`` values of all rows assigned to ``cluster_id``.

    Args:
        df: DataFrame with at least a ``cluster`` and a ``gauge_id`` column.
        cluster_id: Value compared with ``==`` against the ``cluster`` column.

    Returns:
        A plain list of gauge ids in row order; empty when no row matches.
    """
    in_cluster = df["cluster"] == cluster_id
    return df.loc[in_cluster, "gauge_id"].tolist()


def filter_ids_by_country(ids, country_prefix):
    """Return the ids that start with the given country prefix.

    Args:
        ids: Iterable of gauge ids. Entries that are not ``str`` (for example
            NaN produced by an empty CSV cell, ``None`` or numeric ids) are
            skipped instead of raising ``AttributeError``.
        country_prefix: Prefix string tested with ``str.startswith``.

    Returns:
        A list of the matching ids, in their original order.
    """
    return [
        gauge_id
        for gauge_id in ids
        # Non-string ids cannot carry a textual prefix, so they never match.
        if isinstance(gauge_id, str) and gauge_id.startswith(country_prefix)
    ]


# Dataset prefixes to load; each one is used both as the directory name under
# the CaravanifyParquet root and as the gauge-id prefix of its stations.
countries = [
    "CH",
    "CL",
    "USA",
    "camelsaus",
    "camelsgb",
    "camelsbr",
    "hysets",
    "lamah",
]

# Static attribute columns kept for the per-cluster summary.
statics_of_interest = [
    "gauge_id",  # Important: include gauge_id for filtering
    "p_mean",
    "pet_mean_FAO_PM",
    "frac_snow",
    "aridity_FAO_PM",
    "seasonality_FAO_PM",
    "area",
    "ele_mt_sav",
]

# One CaravanifyParquetConfig per dataset prefix, keyed by that prefix.
configs = {
    country: CaravanifyParquetConfig(
        attributes_dir=f"/Users/cooper/Desktop/CaravanifyParquet/{country}/post_processed/attributes",
        timeseries_dir=f"/Users/cooper/Desktop/CaravanifyParquet/{country}/post_processed/timeseries/csv",
        shapefile_dir=f"/Users/cooper/Desktop/CaravanifyParquet/{country}/post_processed/shapefiles",
        gauge_id_prefix=country,
        use_hydroatlas_attributes=True,
        use_caravan_attributes=True,
        use_other_attributes=True,
    )
    for country in countries
}

# Build one DataFrame of static attributes per cluster and collect them in
# `all_clusters_data`. Clusters are visited in sorted order of their id.
clusters = sorted(df["cluster"].unique())
all_clusters_data = []

print(f"Processing {len(clusters)} clusters...")

for cluster in clusters:
    print(f"Processing cluster {cluster}...")

    # All gauge_ids assigned to this cluster, across every dataset
    cluster_gauge_ids = get_ids_for_cluster(df, cluster)
    cluster_static_dfs = []

    # Load the cluster's stations dataset by dataset. Ids whose prefix matches
    # none of the entries in `countries` are never loaded.
    for country in countries:
        # Keep only the ids that start with this dataset prefix
        country_gauge_ids = filter_ids_by_country(cluster_gauge_ids, country)

        # Nothing from this dataset in the current cluster
        if len(country_gauge_ids) == 0:
            continue

        print(f"  - Loading {len(country_gauge_ids)} stations for {country}")

        # Best-effort: any failure for one dataset is reported and skipped so
        # the remaining datasets/clusters are still processed.
        # NOTE(review): only the exception message is printed; the traceback is lost.
        try:
            # A fresh CaravanifyParquet instance is created for every
            # (cluster, country) pair.
            caravan = CaravanifyParquet(configs[country])

            # Load static attributes for this country's gauge_ids, then read
            # them back from the instance.
            # NOTE(review): `_load_static_attributes` is private by naming
            # convention — confirm there is no public entry point for this.
            caravan._load_static_attributes(country_gauge_ids)
            country_static = caravan.get_static_attributes()

            if country_static.empty:
                print(f"    Warning: No static attributes found for {country}")
                continue

            # Split the wanted attributes into those present in this dataset
            # and those absent, preserving the order of `statics_of_interest`.
            available_attrs = [attr for attr in statics_of_interest if attr in country_static.columns]
            missing_attrs = [attr for attr in statics_of_interest if attr not in country_static.columns]

            if missing_attrs:
                print(f"    Warning: Missing attributes for {country}: {missing_attrs}")

            # Keep the dataset even if only some attributes are available;
            # `.copy()` detaches the selection from `country_static`.
            if available_attrs:
                country_static_filtered = country_static[available_attrs].copy()
                cluster_static_dfs.append(country_static_filtered)

        except Exception as e:
            print(f"    Error loading data for {country}: {e}")
            continue

    # Combine all country static data for this cluster into one frame with a
    # fresh 0..n-1 index.
    if cluster_static_dfs:
        cluster_static = pd.concat(cluster_static_dfs, axis=0, ignore_index=True)

        # Tag every row with the cluster it belongs to
        cluster_static["cluster"] = cluster

        # Apply unit conversions (daily to annual): both columns are multiplied by 365.
        # NOTE(review): assumes the source values are per-day means — TODO confirm.
        if "p_mean" in cluster_static.columns:
            cluster_static["p_mean"] = cluster_static["p_mean"] * 365
        if "pet_mean_FAO_PM" in cluster_static.columns:
            cluster_static["pet_mean_FAO_PM"] = cluster_static["pet_mean_FAO_PM"] * 365

        all_clusters_data.append(cluster_static)
        # The reported count is the number of rows in the combined frame
        print(f"  - Added {len(cluster_static)} stations to cluster {cluster}")
    else:
        print(f"  - No data found for cluster {cluster}")

# Merge the per-cluster frames, summarise every attribute per cluster as
# "mean ± std (n=count)", print the table and write both tables to CSV.
if not all_clusters_data:
    print("No data was successfully processed!")
else:
    combined_df = pd.concat(all_clusters_data, axis=0, ignore_index=True)
    print(f"\nTotal stations processed: {len(combined_df)}")

    def _format_mean_std_count(row):
        """Render one aggregated row as 'mean ± std (n=count)'."""
        return f"{row['mean']:.1f} ± {row['std']:.1f} (n={int(row['count'])})"

    # Summary table: one row per cluster id, in sorted order.
    results = pd.DataFrame(index=sorted(combined_df["cluster"].unique()))

    # Identifier and grouping columns are not summarised.
    numeric_attrs = [attr for attr in statics_of_interest if attr not in ("gauge_id", "cluster")]

    for attr in numeric_attrs:
        if attr not in combined_df.columns:
            print(f"Warning: Attribute {attr} not found in combined data")
            continue

        stats = combined_df.groupby("cluster")[attr].agg(["mean", "std", "count"])
        formatted_stats = stats.apply(_format_mean_std_count, axis=1)
        results[attr] = formatted_stats

    # Row count per cluster (all rows, regardless of missing attribute values).
    station_counts = combined_df.groupby("cluster").size()
    results["number_of_stations"] = station_counts

    print("\nCluster Statistics:")
    print("=" * 50)
    print(results)

    # Persist the summary (cluster id as index) and the row-level data.
    results.to_csv("cluster_statistics.csv")
    combined_df.to_csv("cluster_detailed_data.csv", index=False)
    print("\nResults saved to 'cluster_statistics.csv' and 'cluster_detailed_data.csv'")

Processing 11 clusters...
Processing cluster 0...
  - Loading 5 stations for CL
  - Loading 54 stations for USA
  - Loading 83 stations for camelsbr
  - Loading 1012 stations for hysets
  - Loading 24 stations for lamah
  - Added 1178 stations to cluster 0
Processing cluster 1...
  - Loading 34 stations for CH
  - Loading 45 stations for CL
  - Loading 233 stations for USA
  - Loading 51 stations for camelsaus
  - Loading 42 stations for camelsgb
  - Loading 40 stations for camelsbr
  - Loading 3092 stations for hysets
  - Loading 132 stations for lamah
  - Added 3669 stations to cluster 1
Processing cluster 2...
  - Loading 53 stations for CH
  - Loading 59 stations for CL
  - Loading 75 stations for USA
  - Loading 27 stations for camelsbr
  - Loading 1536 stations for hysets
  - Loading 164 stations for lamah
  - Added 1914 stations to cluster 2
Processing cluster 3...
  - Loading 9 stations for CH
  - Loading 187 stations for CL
  - Loading 56 stations for USA
  - Loading 11 statio