## Simple Example with the `air_temperature` Dataset

In [1]:
%load_ext autoreload
%autoreload 2

from xarray_jax import XJDataset
import jax
import xarray as xr
import jax.numpy as jnp

# Fetch the xarray tutorial "air_temperature" dataset and load it into memory.
air_ds = xr.tutorial.load_dataset(name="air_temperature")


@jax.jit
def process_ds(ds: XJDataset):  # xr.Dataset gets converted to XJDataset automatically
    """Add a time-mean of ``air`` and build a copy that keeps only 2-D leaves.

    Args:
        ds: Input dataset. An ``xr.Dataset`` passed by the caller arrives here
            as an ``XJDataset``.

    Returns:
        A pair ``(full, two_dim)``. ``full`` is the unwrapped ``xr.Dataset`` with
        an extra ``air_mean`` variable. ``two_dim`` is that same dataset mapped
        with ``jax.tree.map`` so that every leaf whose ``ndim != 2`` becomes ``None``.
    """
    # Unwrap to a regular xr.Dataset so the normal xarray API can be used.
    dataset = ds.to_xarray()

    # jnp.mean is not dispatched automatically, so it goes through apply_ufunc.
    # apply_ufunc moves the core dim ("time") to the last axis, hence axis=-1.
    # Discussion on making this smoother via automatic dispatching:
    # https://github.com/pydata/xarray/issues/7848
    time_mean = xr.apply_ufunc(
        jnp.mean,
        dataset["air"],
        input_core_dims=[["time"]],
        kwargs={"axis": -1},
    )
    dataset["air_mean"] = time_mean

    def keep_if_2d(leaf):
        # Leaves that are not 2-D are masked out by replacing them with None.
        return leaf if leaf.ndim == 2 else None

    two_dim = jax.tree.map(keep_if_2d, dataset)
    return dataset, two_dim


xjds, two_dim_xjds = process_ds(air_ds)  # Output of type XJDataset

ds, two_dim_ds = xjds.to_xarray(), two_dim_xjds.to_xarray()

# Sanity check: averaging over "time" should leave a (lat, lon) field.
assert ds["air_mean"].dims == ("lat", "lon"), ds["air_mean"].dims

print(ds)
print(two_dim_ds)

<xarray.Dataset> Size: 16MB
Dimensions:   (time: 2920, lat: 25, lon: 53)
Coordinates:
  * lat       (lat) float32 100B 75.0 72.5 70.0 67.5 ... 22.5 20.0 17.5 15.0
  * lon       (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
  * time      (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air       (time, lat, lon) float32 15MB ...
    air_mean  (lat, lon) float32 5kB ...
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
<xarray.Dataset> Size: 29kB
Dimensions:   (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat       (lat) float32 100B 75.0 72.5 70.0 67.5 ... 22.5 20.0 17.5 15.0
  * lon       (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
  * time      (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12

## Using an `xr.DataArray` with Diffrax

In [9]:
import diffrax
import jax
import jax.numpy as jnp

# Initial ODE state: a 1-D DataArray along "x". The array is created as floating
# point explicitly (arange(10.0)) rather than integer, so the solve does not rely
# on diffrax silently promoting an int32 initial condition. The saved solution is
# f32 either way.
da = xr.DataArray(jnp.arange(10.0), dims=["x"])


@jax.jit
def fn(t, y, args):
    """Vector field ``dy/dt = -y**2`` for ``diffrax.ODETerm``.

    ``t`` and ``args`` are part of the ODETerm signature and are unused here.
    ``y`` is the state handed in by diffrax. It exposes ``to_xarray()`` and is
    presumably an ``XJDataArray`` (TODO confirm).
    """
    # Use a distinct local name so the module-level `da` is not shadowed.
    state = y.to_xarray()
    return -1.0 * xr.apply_ufunc(jnp.square, state)


# Integrate dy/dt = -y**2 from t0 to t1 with Dopri5 (dt0=0.01), saving the state
# on a uniform grid of n_save points. The save grid is built from the same
# (t0, t1) passed to diffeqsolve, so the SaveAt times always stay inside the
# integration interval if either endpoint is edited.
t0, t1 = 0.0, 1.0
n_save = 100

term = diffrax.ODETerm(fn)
solver = diffrax.Dopri5()
ts = jnp.linspace(t0, t1, n_save)
sol = diffrax.diffeqsolve(
    term, solver, t0=t0, t1=t1, dt0=0.01, y0=da, saveat=diffrax.SaveAt(ts=ts)
)

# TODO: handle the fact that we introduced new coords/dims. SaveAt(ts=...)
# stacks one state per save time, so the data gains a leading axis of length
# len(ts) (f32[100,10] in the output below) while `dims` is still ('x',).
print(sol.ys)

XJDataArray(
  variable=XJVariable(data=f32[100,10], dims=('x',), attrs={}),
  coords=_HashableCoords(Coordinates:
    *empty*),
  name=None
)
