In [1]:
# Select the CPU platform for JAX. The variable is placed in the environment
# here, before `jax` is imported in the next cell.
import os
os.environ['JAX_PLATFORMS'] = 'cpu'

# This is required to run multiple processes with JAX.
# `force=True` replaces any start method that was already chosen, so this
# cell can be re-run in a live kernel without raising RuntimeError.
from multiprocessing import set_start_method
set_start_method('spawn', force=True)

In [2]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from tqdm import tqdm
from pathlib import Path

# jax.config.update("jax_debug_nans", True)

from config import Config

In [None]:
# Location of the experiment's YAML configuration; edit here to point the
# notebook at a different run.
CONFIG_PATH = "/nas/cee-water/cjgleason/ted/swot-ml/runs/reservoirs/e5.yml"
cfg = Config.from_file(CONFIG_PATH)

In [5]:
import data

# Notebook-only overrides applied on top of the values read from the config
# file. They are set in this order before the dataset is built.
# NOTE(review): the exact effect of `log` / `quiet` is defined in the project
# code, not here -- confirm against Config.
notebook_overrides = {"log": False, "quiet": False, "num_epochs": 2}
for option, value in notebook_overrides.items():
    setattr(cfg, option, value)

# Build the dataset first, then wrap it in the loader that yields batches.
dataset = data.HydroDataset(cfg)
dataloader = data.HydroDataLoader(cfg, dataset)

Loading graph network file
Loading static attributes
Loading dynamic data
Data Hash: 6b4776391d600d71baf58187379f9cff82407f551d2911e96118668dd32b43df
No matching cached dataset.


Loading Basins: 100%|██████████| 1962/1962 [30:53<00:00,  1.06it/s]


Dataloader using 1 parallel CPU worker(s).
Batch sharding set to 1 cpu(s)


In [5]:
import train
# Build the trainer from the config and data loader defined above, then run
# training. Per-epoch training/validation losses and early-stopping messages
# are emitted to the cell output.
trainer = train.Trainer(cfg, dataloader)
trainer.start_training()

Model contains 29,565 parameters, using 115.49KB memory.


Epoch:001: 100%|██████████| 108/108 [08:27<00:00,  4.70s/it, Loss:0.5774]

2025-08-18 20:07:52,326 - INFO - Epoch: 1, Loss: 0.6132



Validating Epoch:001: 100%|██████████| 108/108 [09:57<00:00,  5.53s/it]

2025-08-18 20:17:49,841 - INFO - Epoch: 1, Validation Loss: 0.1815
2025-08-18 20:17:49,842 - INFO - EarlyStopper: Patience reset. New best loss 0.1815. Improvement was 0.00%.



Epoch:002: 100%|██████████| 108/108 [07:32<00:00,  4.19s/it, Loss:0.5770]

2025-08-18 20:25:22,316 - INFO - Epoch: 2, Loss: 0.6032



Validating Epoch:002: 100%|██████████| 108/108 [09:17<00:00,  5.16s/it]

2025-08-18 20:34:39,947 - INFO - Epoch: 2, Validation Loss: 0.1812
2025-08-18 20:34:39,949 - INFO - EarlyStopper: Patience 1/5. Improvement was 0.14%.
2025-08-18 20:34:39,950 - INFO - ~~~ training done ~~~





In [14]:
# Pull only the first batch from the loader for inspection.
# `next(iter(...))` raises StopIteration if the loader yields nothing, whereas
# the previous `for ...: break` form would silently leave `batch` undefined
# (or stale from an earlier run of this cell).
basin, date, batch = next(iter(dataloader))

# Shape of the target array in this batch.
batch['y'].shape

(16, 30, 1962, 1)

In [20]:
# Re-display the shape of the batch targets (same check as the previous cell).
batch['y'].shape

(16, 30, 1962, 1)

In [21]:
import equinox as eqx

# Run the trained model over every sample in `batch` with jax.vmap.
#
# `batch` is used directly rather than being rebound to `data`: that name is
# the project module imported earlier (`import data`) and must not be shadowed.

# One PRNG key per sample. The count comes from the batch's own leading axis
# so it always matches the axis vmap maps over, even for a short batch.
n_samples = batch['y'].shape[0]
keys = jax.random.split(jax.random.PRNGKey(0), n_samples)

# Entries listed in `static_keys` are passed to the model unmapped
# (in_axes=None); every other entry is mapped over its leading axis.
static_keys = ["graph"]
in_axes_data = {k: (None if k in static_keys else 0) for k in batch}
in_axes_keys = 0
y_pred = jax.vmap(trainer.model, in_axes=(in_axes_data, in_axes_keys))(batch, keys)

In [26]:
# Keep only the last entry of the trailing axis of the predictions.
# The rank check makes this cell idempotent: the raw model output has the same
# rank as `batch['y']`, so once the trailing axis has been dropped a re-run of
# the cell is a no-op instead of slicing away another axis.
if y_pred.ndim == batch['y'].ndim:
    y_pred = y_pred[..., -1]
y_pred.shape

(16, 30, 1962)

In [25]:
# Targets with the trailing axis indexed at its last entry, mirroring the
# slicing applied to `y_pred` so the two arrays have the same shape.
y = batch['y'][...,-1]
y.shape

(16, 30, 1962)

In [31]:
# True wherever a target value is present (i.e. not NaN).
mask = jnp.logical_not(jnp.isnan(y))

# Zero-fill both arrays at positions without a target so that those positions
# contribute nothing to later sums.
masked_y, masked_y_pred = (jnp.where(mask, values, 0) for values in (y, y_pred))

Array([[ 0.      ,  0.      ,  0.      , ..., 76.15781 ,  0.      ,
         0.      ],
       [ 0.      ,  0.      ,  0.      , ..., 50.950108,  0.      ,
         0.      ],
       [ 0.      ,  0.      ,  0.      , ..., 26.99347 ,  0.      ,
         0.      ],
       ...,
       [ 0.      ,  0.      ,  0.      , ..., 48.004665,  0.      ,
         0.      ],
       [ 0.      ,  0.      ,  0.      , ..., 23.192217,  0.      ,
         0.      ],
       [ 0.      ,  0.      ,  0.      , ..., 30.40829 ,  0.      ,
         0.      ]], dtype=float32)

In [70]:
# Weighted, variance-normalised squared error computed by hand.
seq_len = masked_y.shape[1]

# Sigmoid ramp along axis 1: weights rise from ~0 at the first step to ~1 at
# the last, so late steps dominate the error.
weights = jax.nn.sigmoid(jnp.linspace(-10, 10, seq_len))

# Standard deviation of the targets along axis 1, using only unmasked entries.
# `mask` is already boolean, so it is passed to `where` as-is.
std_y = jnp.std(masked_y, axis=1, where=mask)
# NaN std values are replaced by 0, and 0.1 is added before squaring so the
# denominator is never zero.
denom = jnp.square(jnp.nan_to_num(std_y) + 0.1)

# Squared error, zeroed where there is no target and weighted along axis 1,
# then summed over that axis.
sq_error = jnp.square(masked_y - masked_y_pred) * mask * weights[None, :, None]
se_sum = sq_error.sum(axis=1)

# Mean of the normalised error over all remaining entries.
jnp.mean(se_sum / denom)

Array(41.677998, dtype=float32)

In [72]:
# Scratch check: `in` on a list compares against whole elements, not
# substrings, so "test" is not found in ["test1", "test2"].
"test" in ["test1", "test2"]

False

In [64]:
# Inspect the standard deviations with NaNs replaced by 0.
# NOTE(review): the NaNs presumably come from columns with no unmasked
# targets -- confirm.
jnp.nan_to_num(std_y)

Array([[0.        , 0.        , 0.        , ..., 0.47595596, 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.38656533, 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.54881686, 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.366854  , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.6478458 , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.49915627, 0.        ,
        0.        ]], dtype=float32)

In [None]:
import geopandas as gpd

# Locate the basin metadata file under the project's data directory.
proj_dir = Path("/nas/cee-water/cjgleason/ted/swot-ml/data/distributed")
metadata_dir = proj_dir / "metadata"
matchups_path = metadata_dir / 'matchups.geojson'

# Read the basin geometries and index them by HYBAS_ID.
basins = gpd.read_file(matchups_path)
basins = basins.set_index("HYBAS_ID")
# Cast the index to strings.
# NOTE(review): presumably so it lines up with the basin ids in the
# prediction table when the two are merged later -- confirm.
basins.index = basins.index.astype(str)
basins

In [None]:
import evaluate
# Run the trained model over the data loader and collect the predictions.
# NOTE(review): `denormalize=False` presumably leaves the outputs in the
# model's normalised units -- confirm against evaluate.predict.
pred = evaluate.predict(trainer.model, dataloader, denormalize=False)

In [None]:
# Display the prediction table returned by evaluate.predict.
pred

In [None]:
# Take the cross-section of the predictions for a single date.
snapshot_date = '2021-02-08'
x = pred.xs(snapshot_date, level='date')

# Predicted discharge for that date.
x['pred']['discharge']

In [None]:
# Histogram of the predicted discharge values on the selected date.
x['pred']['discharge'].hist()

In [None]:
# Attach the predicted discharge to the basin geometries by matching the two
# indexes, then draw the basins coloured by that value.
discharge_pred = x['pred']['discharge']
basins_pred = basins.merge(discharge_pred, left_index=True, right_index=True)
basins_pred.plot(column='discharge')

In [None]:
# Observed vs. predicted discharge for the selected date.
# An explicit figure/axes pair with axis labels and a title is used so the
# plot can be read on its own when the notebook is skimmed.
fig, ax = plt.subplots()
ax.scatter(x['obs']['discharge'], x['pred']['discharge'])
# ax.plot([0,500], [0,500], 'r--')
ax.set(
    xlabel="Observed discharge",
    ylabel="Predicted discharge",
    title="Observed vs. predicted discharge",
);