# JAX Tree Manipulations + Xarray

One challenge with registering Xarray types as JAX PyTrees is that tree manipulations are common in JAX-land. For example, a common pattern is to build a tree of boolean masks with `jax.tree.map`. Doing so rebuilds each container through its constructor, and those constructors perform checks (dimension checks, for instance) that the replaced leaves may not pass.

In the case of `xr.Variable` and `xr.DataArray`, we can manually disable a couple of checks, and then things work as expected.
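To make the constructor problem concrete without touching Xarray itself, here is a minimal sketch using a hypothetical `Labeled` class (not part of Xarray or `xarray_jax`). It assumes the container validates its input in `__init__`, and it shows one possible way around that: rebuilding the object in `tree_unflatten` without calling `__init__`. This is an illustration of the general pattern, not a description of how the checks were disabled for the cell below.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Labeled:
    """Hypothetical stand-in for an xarray-like container: data plus dimension names."""

    def __init__(self, dims, data):
        # Constructor-time validation, similar in spirit to a dimension check.
        if len(dims) != jnp.ndim(data):
            raise ValueError(f"expected {len(dims)} dims, got {jnp.ndim(data)}")
        self.dims = tuple(dims)
        self.data = data

    def tree_flatten(self):
        # The array is the only leaf; dimension names are static metadata.
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        # Skip __init__ so that non-array leaves such as `True` are accepted.
        obj = object.__new__(cls)
        obj.dims = dims
        obj.data = children[0]
        return obj


labeled = Labeled(("x",), jnp.array([1, 2, 3]))

# Works because tree_unflatten never runs the dimension check.
mask = jax.tree.map(lambda leaf: True, labeled)
```

If `tree_unflatten` instead called `cls(dims, children[0])`, the same `jax.tree.map` call would raise the `ValueError` from the constructor, because `True` has zero dimensions while `dims` has one entry.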

In [4]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp
import xarray as xr
import jax
import xarray_jax

# After manually commenting out a couple lines of code in Xarray, the following are ok!
var = xr.Variable("x", jnp.array([1, 2, 3]))

var_mask = jax.tree.map(lambda x: True, var)

da = xr.DataArray(var, dims=["x"], coords={"x": [0, 1, 2]})

da_mask = jax.tree.map(lambda x: True, da)




## `xr.Dataset`

`xr.Dataset` is more involved because `calculate_dimensions`, the function that maps dimension names to sizes, seems to be needed, so it cannot simply be commented out like the checks above.

In [5]:
ds = xr.Dataset({"da": da})

ds_mask = jax.tree.map(lambda x: True, ds)

ValueError: zip() argument 2 is shorter than argument 1