## Simple Example with air_temperature Dataset

In [None]:
%load_ext autoreload
%autoreload 2

import jax
import xarray as xr
import jax.numpy as jnp
import xarray_jax
import ipdb

# Load xarray's "air_temperature" tutorial dataset eagerly into memory.
# NOTE(review): xarray's tutorial helpers may need network access to fetch
# the file on first use — confirm for offline runs.
air_ds = xr.tutorial.load_dataset("air_temperature")


@jax.jit
def process_ds(ds: xr.Dataset):
    """Add the time-mean of ``air`` and derive a copy keeping only 2-D leaves.

    The input Dataset is never modified; a new Dataset is built instead, so
    the function is also safe to call without ``@jax.jit`` (e.g. when
    debugging eagerly).

    Args:
        ds: Dataset holding an ``air`` variable that has a ``time`` dimension.

    Returns:
        A tuple ``(ds_with_mean, two_dim_ds)``. ``ds_with_mean`` is ``ds`` plus
        an ``air_mean`` variable (``air`` averaged over ``time``).
        ``two_dim_ds`` is the result of ``jax.tree.map`` over
        ``ds_with_mean`` with every leaf whose ``ndim != 2`` replaced by
        ``None`` (leaves as defined by xarray_jax's pytree registration).
    """
    # apply_ufunc moves the core dim ("time") to the last axis, so reducing
    # over axis=-1 averages over time.
    # jnp functions are not dispatched automatically on xarray objects yet;
    # see https://github.com/pydata/xarray/issues/7848
    air_mean = xr.apply_ufunc(
        jnp.mean, ds["air"], input_core_dims=[["time"]], kwargs={"axis": -1}
    )
    # Build a new Dataset rather than assigning into the argument, so the
    # caller's object is never mutated.
    ds = ds.assign(air_mean=air_mean)

    # The usual pytree mapping works on the Dataset: mask out non-2-D leaves.
    two_dim_ds = jax.tree.map(lambda x: x if x.ndim == 2 else None, ds)
    return ds, two_dim_ds


# Run the jitted function on the tutorial data; if anything raises, ipdb
# opens a post-mortem debugger at the point of failure.
with ipdb.launch_ipdb_on_exception():
    ds, two_dim_ds = process_ds(air_ds)  # Output of type XJDataset

In [None]:
# Display the original input dataset.
# NOTE(review): presumably shown to confirm that process_ds did not add
# "air_mean" to the caller's dataset — confirm intent.
air_ds

# Using an `xr.DataArray` with Diffrax

In [None]:
import diffrax
import jax
import jax.numpy as jnp

# Initial ODE state: ten values along dimension "x". The dtype is made
# floating point explicitly. jnp.arange(10) would be an integer array, and
# the state of a continuous ODE solve should not depend on the solver coping
# with (or promoting) an integer y0.
da = xr.DataArray(jnp.arange(10.0), dims=["x"])


@jax.jit
def fn(t, y, args):
    """Vector field of the ODE dy/dt = -y**2, in diffrax's ``(t, y, args)`` form.

    ``t`` and ``args`` are part of the signature expected by
    ``diffrax.ODETerm`` but are unused here. ``y`` is the current state
    (presumably an ``xr.DataArray``, since the solve's ``y0`` is one).
    """
    del t, args  # required by the ODETerm calling convention, not needed
    y_squared = xr.apply_ufunc(jnp.square, y)
    return -1.0 * y_squared


# Wrap the vector field for diffrax and choose the Dopri5 Runge-Kutta solver.
term = diffrax.ODETerm(fn)
solver = diffrax.Dopri5()

# 100 evenly spaced save times covering the integration interval [0, 1].
ts = jnp.linspace(0, 1, 100)

# Open an ipdb post-mortem debugger if the solve raises.
with ipdb.launch_ipdb_on_exception():
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=da,
        saveat=diffrax.SaveAt(ts=ts),
    )

# TODO: handle the fact that we introduced new coords/dims
print(sol.ys)