In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns

from gr6j import ForcingData, ObservedData, Parameters, calibrate, run
from gr6j.utils import CaravanDataSource, analyze_dem, compute_mean_annual_solid_precip


In [2]:
# Root folder of the basin dataset. Set the GR6J_DATA_DIR environment variable
# to run this notebook on another machine; the default is the original location.
dem_root = os.environ.get(
    "GR6J_DATA_DIR",
    "/Users/nicolaslazaro/Desktop/work/GR6J/data/mountainous-us-basins",
)

# DEM raster inside the dataset's partitioned folder layout (the file name
# suggests it belongs to CAMELS gauge 06224000).
dem_path = os.path.join(dem_root, "REGION_NAME=camels", "data_type=DEM", "camels_06224000_DEM.tif")

# Compute the elevation summary of the raster and show it.
dem = analyze_dem(dem_path)

print(dem)

DEMStatistics(
  min_elevation=1792.50,
  max_elevation=4138.39,
  mean_elevation=3131.41,
  median_elevation=3177.64,
  hypsometric_curve=<array shape=(101,)>
)


In [3]:
# Apply seaborn's "paper" plotting context with fonts scaled by 1.3 for all
# figures drawn after this cell.
sns.set_context(context="paper", font_scale=1.3)

In [None]:
# Root folder of the Caravan-style dataset. Set the GR6J_DATA_DIR environment
# variable to point somewhere else; the default is the original location, so
# behaviour is unchanged when the variable is not set.
data_source = os.environ.get(
    "GR6J_DATA_DIR",
    "/Users/nicolaslazaro/Desktop/work/GR6J/data/mountainous-us-basins",
)

# Handle used by the following cells to list gauges and read time series.
caravan_ds = CaravanDataSource(data_source)

In [None]:
# All gauge identifiers known to the data source.
gids = caravan_ds.list_gauge_ids()

# Columns requested for the chosen gauge and how the notebook uses them:
#   precipitation      -> era5_land_precipitation_multimet
#   pet                -> potential_evaporation_sum_FAO_PENMAN_MONTEITH
#   temperature        -> temperature_2m_mean
#   observed discharge -> streamflow
requested_columns = [
    "era5_land_precipitation_multimet",
    "potential_evaporation_sum_FAO_PENMAN_MONTEITH",
    "temperature_2m_mean",
    "streamflow",
]

# The third gauge in the listing is the one analysed in this notebook.
selected_gauge = gids[2]

# Build the query for that single gauge, then materialise it with .collect().
series_query = caravan_ds.get_timeseries(
    gauge_ids=[selected_gauge],
    columns=requested_columns,
)
time_series = series_query.collect()

print(time_series.head())

In [None]:
# Display the time-series variables the data source exposes, which is handy
# when choosing the columns to request.
available_variables = caravan_ds.list_timeseries_variables()
available_variables

In [None]:
# Columns expected in `time_series` (dtypes as noted by the original author;
# TODO confirm against `time_series.schema`):
#   date                                           Date
#   temperature_2m_mean                            Float64
#   potential_evaporation_sum_FAO_PENMAN_MONTEITH  Float64
#   era5_land_precipitation_multimet               Float64
#   streamflow                                     Float64
#   REGION_NAME                                    String
#   gauge_id                                       String

# Convert the model inputs to NumPy arrays before handing them to the model.
forcing_dates = np.array(time_series["date"])
forcing_precip = np.array(time_series["era5_land_precipitation_multimet"])
forcing_pet = np.array(time_series["potential_evaporation_sum_FAO_PENMAN_MONTEITH"])

# Bundle time axis, precipitation and potential evaporation for the model.
forcing = ForcingData(
    time=forcing_dates,
    precip=forcing_precip,
    pet=forcing_pet,
)

# Number of entries in the forcing record.
print(len(forcing))

In [None]:
# Hand-picked (uncalibrated) parameter values for a first model run.
initial_values = {
    "x1": 350.0,  # Production store capacity [mm]
    "x2": 0.0,    # Intercatchment exchange coefficient [mm/day]
    "x3": 90.0,   # Routing store capacity [mm]
    "x4": 1.7,    # Unit hydrograph time constant [days]
    "x5": 0.0,    # Intercatchment exchange threshold [-]
    "x6": 5.0,    # Exponential store scale parameter [mm]
}
params = Parameters(**initial_values)

# Simulate with these parameters over the full forcing record.
output = run(params, forcing)

# Simulated streamflow produced by the run.
print(output.gr6j.streamflow)

In [None]:
# Observed streamflow as a NumPy array, aligned with the forcing time axis.
observations = np.array(time_series["streamflow"])

# Overlay observed and simulated streamflow over the whole record.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(forcing.time, observations, label="Observed", color="blue")
ax.plot(forcing.time, output.gr6j.streamflow, label="Simulated", color="orange")
ax.set_xlabel("Time")
ax.set_ylabel("Streamflow (mm/day)")
ax.legend()
sns.despine()  # drop the top and right spines
plt.show()

In [None]:
# Number of leading time steps treated as model warm-up.
warmup_period = 365  # days

# Everything after the warm-up window.
post_warmup = slice(warmup_period, None)

# Observations with the warm-up steps removed.
observed = ObservedData(
    time=np.array(time_series["date"])[post_warmup],
    streamflow=np.array(time_series["streamflow"])[post_warmup],
)

# Search range (lower, upper) for each parameter during calibration.
bounds = dict(
    x1=(1, 2500),    # Production store capacity [mm]
    x2=(-5, 5),      # Intercatchment exchange [mm/day]
    x3=(1, 1000),    # Routing store capacity [mm]
    x4=(0.5, 10),    # UH time constant [days]
    x5=(-4, 4),      # Exchange threshold [-]
    x6=(0.01, 20),   # Exponential store parameter [mm]
)

In [None]:
# Calibrate the model against the observed streamflow using NSE as the only
# objective. `warmup` reuses `warmup_period`, the same value that was used to
# trim `observed`, so the two cannot drift apart if the warm-up length changes.
# A fixed seed is passed so repeated runs use the same seed value.
result = calibrate(
    forcing=forcing,
    observed=observed,
    objectives=["nse"],
    bounds=bounds,
    warmup=warmup_period,
    population_size=20,
    generations=500,
    seed=42,
)

# Report the best score and the calibrated production store capacity.
print(f"Best NSE: {result.score['nse']:.3f}")
print(f"X1: {result.parameters.x1:.1f}")

In [None]:
# Re-run the model with the calibrated parameter set.
output = run(result.parameters, forcing)

# Only time steps 365..729 are shown so individual events stay readable.
window = slice(365, 365 * 2)

# Overlay observed and simulated streamflow for that window.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(forcing.time[window], observations[window], label="Observed", color="blue")
ax.plot(forcing.time[window], output.gr6j.streamflow[window], label="Simulated", color="orange")
ax.set_xlabel("Time")
ax.set_ylabel("Streamflow (mm/day)")
ax.legend()
sns.despine()  # drop the top and right spines
plt.show()