In [1]:
# import os
# os.environ['JAX_PLATFORMS'] = 'cpu'

from multiprocessing import set_start_method
set_start_method('spawn', force=True)

In [2]:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import jax
from pathlib import Path
from importlib import reload

import config, data, models, train, evaluate

In [3]:
run_dir = Path("/nas/cee-water/cjgleason/ted/swot-ml/runs/Ohio/masked_assimilation/e5_sr_sl_gs_20250819_134746")

fig_dir = run_dir / "figures" / "feature_importance"
fig_dir.mkdir(exist_ok=True, parents=True)

In [4]:
trainer = train.Trainer.load_last_checkpoint(run_dir)
cfg = trainer.cfg

Model contains 33,149 parameters, using 129.49KB memory.
Creating new Trainer instance...
Logging at /nas/cee-water/cjgleason/ted/swot-ml/runs/Ohio/masked_assimilation/e5_sr_sl_gs_20250819_134746


In [5]:
cfg.quiet = False
dataset = data.HydroDataset(cfg)
dataloader = data.HydroDataLoader(cfg, dataset)

Loading graph network file
Loading static attributes
Loading dynamic data
Data Hash: b276fccb0bccc2e09f7b6b931938fb863d54fd920d2264a01fc69e18b2108215
Using cached basin dataset.
Dataloader using 1 parallel CPU worker(s).
Batch sharding set to 1 gpu(s)


In [6]:
model = trainer.model
target = 'discharge'
max_iter = 10

In [7]:
df = evaluate.get_intgrads_df(cfg, model, dataloader, target, max_iter=max_iter)
df

Calculating IG for discharge:   0%|          | 0/10 [00:07<?, ?it/s]

TypeError: Indexer must have integer or boolean type, got indexer with type float32 at position 0, indexer value Traced<ShapedArray(float32[1961])>with<DynamicJaxprTrace>

In [8]:
for basin, date, batch in dataloader:
    break
batch.keys()

dict_keys(['dynamic', 'static', 'graph', 'y'])

In [14]:
batch['dynamic'].keys()

dict_keys(['era5', 'swot_r', 'swot_l', 'glow'])

In [20]:
def get_single(b):
    single = {}
    for k, v in b.items():
        if isinstance(v, dict):
            single[k] = {kk: vv[0,...] for kk, vv in v.items()}
        elif k == 'graph':
            single[k] = v
        else:
            single[k] = v[0,...]

    return single

single = get_single(batch)


print(single['dynamic']['era5'].shape)
print(single['y'].shape)

(90, 1962, 20)
(90, 1962, 1)


In [21]:
key = jax.random.PRNGKey(0)
model(single, key)

Array([[[ 1.9637854 ],
        [ 3.375012  ],
        [-0.17155433],
        ...,
        [-1.2903222 ],
        [-2.5174625 ],
        [-2.4821103 ]],

       [[ 5.3253927 ],
        [ 6.0375004 ],
        [-0.01480161],
        ...,
        [-1.2726734 ],
        [-2.3978026 ],
        [-2.397909  ]],

       [[ 5.6072664 ],
        [ 5.9878926 ],
        [-0.8873311 ],
        ...,
        [-0.04824136],
        [-2.2399478 ],
        [-2.315516  ]],

       ...,

       [[ 6.004189  ],
        [ 6.0954647 ],
        [-0.475321  ],
        ...,
        [-0.00909531],
        [-0.9732019 ],
        [-1.1797688 ]],

       [[ 5.970257  ],
        [ 6.029078  ],
        [-0.57309586],
        ...,
        [-0.05207691],
        [-1.0274509 ],
        [-1.2709379 ]],

       [[ 5.951088  ],
        [ 6.0028963 ],
        [-0.64526165],
        ...,
        [-0.09896851],
        [-1.1272274 ],
        [-1.3353819 ]]], dtype=float32)

In [9]:
import jax
import jax.numpy as jnp
import equinox as eqx


def _calculate_ig_single(model, single_input_tree, baseline_tree, target_idx, m_steps):
    """Core Integrated Gradients calculation for one sequence."""
    key = jax.random.PRNGKey(0)

    @eqx.filter_grad
    def grad_target_output_fn(interpolated_input):
        pred = model(interpolated_input, key=key)
        return pred[target_idx]

    alphas = jnp.linspace(start=0.0, stop=1.0, num=m_steps + 1)

    # Interpolate inputs: Pytree leaves become shape (m_steps+1, *original_shape)
    interpolated_inputs = jax.vmap(
        lambda alpha: jax.tree.map(
            lambda x, b: b + alpha * (x - b),
            single_input_tree,
            baseline_tree,
        )
    )(alphas)

    # Calculate gradients at each interpolated step
    # Pytree leaves become shape (m_steps+1, *original_shape)
    static_keys = ["graph"]
    data_axes_spec = {key: (None if key in static_keys else 0) for key in interpolated_inputs}
    interpolated_grads = jax.vmap(grad_target_output_fn, in_axes=(data_axes_spec,))(interpolated_inputs)

    # Average grads using trapezoidal rule approximation
    def trapezoid_avg(g):
        integral_approx = jnp.sum(g[1:] + g[:-1], axis=0) / 2.0
        return integral_approx / m_steps

    avg_grads = jax.tree.map(trapezoid_avg, interpolated_grads)

    # Calculate diff and final attribution: (input - baseline) * avg_grads
    input_diff = jax.tree.map(
        lambda x, b: x - b,
        single_input_tree,
        baseline_tree,
    )
    ig_attribs = jax.tree.map(
        lambda diff, avg_grad: diff * avg_grad,
        input_diff,
        avg_grads,
    )

    return ig_attribs




In [12]:
from tqdm import tqdm

target = 'discharge'
m_steps = 10
max_iter = 50


# Set model to inference mode (no dropout)
model = eqx.nn.inference_mode(model)

targets = cfg.features.target
target_idx = targets.index(target)

baseline = jax.tree.map(jnp.zeros_like, batch)
if 'graph' in baseline.keys():
    baseline['graph'] = batch['graph']

# Create an in_axes pytree that matches the batch structure.
# Use `None` for static keys and `0` for keys to be batched over.
# TODO: If we start training with mixes of different basins we will need to fix this.
static_keys = ["graph"]
data_axes_spec = {key: (None if key in static_keys else 0) for key in batch}
# Vmap the single calculation over the batch
batched_ig_fn = jax.vmap(
    _calculate_ig_single,
    in_axes=(None, data_axes_spec, data_axes_spec, None, None),
)
batch_ig_attribs_tree = batched_ig_fn(
    model,
    batch,
    baseline,
    target_idx,
    m_steps,
)

Calculating IG for discharge:   0%|          | 0/50 [00:20<?, ?it/s]


ValueError: Incompatible shapes for broadcasting: shapes=[(1962, 1), (1962, 16), (11, 16)]

In [11]:
batch.keys()

dict_keys(['dynamic', 'static', 'graph', 'y'])

In [18]:
static_keys = ["graph"]
in_axes_data = {k: (None if k in static_keys else 0) for k in batch}
in_axes_keys = 0
(in_axes_data, in_axes_keys)

({'dynamic': 0, 'static': 0, 'graph': None, 'y': 0}, 0)

In [10]:
import jax

def test_fn(batch):
    print(batch['y'].shape)

static_keys = ["graph"]
axes_spec = {key: (None if key in static_keys else 0) for key in batch}
jax.vmap(test_fn, in_axes=(axes_spec,))(batch)

(90, 1962, 1)


In [11]:
import jax
import jax.numpy as jnp

baseline = jax.tree.map(jnp.zeros_like, batch)
if 'graph' in baseline.keys():
    baseline['graph'] = batch['graph']



In [13]:
alphas = jnp.linspace(start=0.0, stop=1.0, num=10 + 1)

# Interpolate inputs: Pytree leaves become shape (m_steps+1, *original_shape)
interpolated_inputs = jax.vmap(
    lambda alpha: jax.tree.map(
        lambda x, b: b + alpha * (x - b),
        batch,
        baseline,
    )
)(alphas)

In [16]:
interpolated_inputs['y'].shape

(11, 16, 90, 1962, 1)

In [14]:
if 'graph' in baseline.keys():
    baseline['graph'] = batch['graph']

baseline['graph']

GraphData(edge_index=array([[   1,    2,    3, ..., 1959, 1960, 1961],
       [   0,    0,    1, ..., 1957, 1959, 1959]], dtype=int32), edge_features=array([[-1.1752551],
       [-1.1686114],
       [-0.8497144],
       ...,
       [-0.5972544],
       [-0.1454836],
       [-0.1454836]], dtype=float32), node_features=array([[ 9.49360817,  8.59906121,  9.53725847, ..., 11.79588306,
        -1.81860692, -1.91744996],
       [ 9.48582837,  8.60398514,  9.53403527, ..., 11.79588306,
        -1.81860692,  1.03641061],
       [-0.30431284, -0.31547071, -0.30267431, ..., 11.79588306,
        -1.81860692,  1.03641061],
       ...,
       [-0.29283699, -0.24314263, -0.2676474 , ..., -0.08477534,
        -0.59581502, -0.73590573],
       [-0.29952549, -0.28103475, -0.28692021, ..., -0.08477534,
        -0.59581502, -0.73590573],
       [-0.29998211, -0.29098784, -0.29161852, ..., -0.08477534,
        -0.59581502, -0.73590573]]))