# Standard PyTree Registration of Xarray Types Breaks Tree Manipulation Patterns

This note responds to a [discussion on GitHub](https://github.com/pydata/xarray/discussions/8164#discussioncomment-10660889).

Consider the following way of registering xarray types as PyTree nodes, where the unflatten functions try to reconstruct the original xarray object from the flattened leaves:

In [None]:
import equinox as eqx
import xarray
from typing import Tuple, Hashable, Mapping, Optional
import jax
from xarray_jax.structs import _HashableCoords


def _flatten_variable(
    v: xarray.Variable,
) -> Tuple[Tuple[jax.typing.ArrayLike], Tuple[Hashable, ...]]:
    children = (v.data,)
    aux = v.dims
    return children, aux


def _unflatten_variable(
    aux: Tuple[Hashable, ...], children: Tuple[jax.typing.ArrayLike]
) -> xarray.Variable:
    dims = aux
    return xarray.Variable(dims=dims, data=children[0])


def _flatten_data_array(
    v: xarray.DataArray,
) -> Tuple[
    # Children (the data variable only; coords travel in aux):
    Tuple[xarray.Variable],
    # Static auxiliary data (name, static_coord_vars):
    Tuple[Optional[Hashable], _HashableCoords],
]:
    """Flattens a DataArray for jax.tree_util."""
    children = (v.variable,)
    aux = (v.name, _HashableCoords(v.coords))
    return children, aux


def _unflatten_data_array(
    aux: Tuple[Optional[Hashable], _HashableCoords],
    children: Tuple[xarray.Variable],
) -> xarray.DataArray:
    """Unflattens a DataArray for jax.tree_util."""
    (variable,) = children
    name, coord_vars = aux
    return xarray.DataArray(variable, name=name, coords=coord_vars)


def _flatten_dataset(
    dataset: xarray.Dataset,
) -> Tuple[
    # Children (the data variables only; coords travel in aux):
    Tuple[Mapping[Hashable, xarray.Variable]],
    # Static auxiliary data (static_coord_vars):
    _HashableCoords,
]:
    """Flattens a Dataset for jax.tree_util."""
    variables = {
        name: data_array.variable for name, data_array in dataset.data_vars.items()
    }
    children = (variables,)
    aux = _HashableCoords(dataset.coords)
    return children, aux


def _unflatten_dataset(
    aux: _HashableCoords,
    children: Tuple[Mapping[Hashable, xarray.Variable]],
) -> xarray.Dataset:
    """Unflattens a Dataset for jax.tree_util."""
    (data_vars,) = children
    coords = aux
    dataset = xarray.Dataset(data_vars, coords=coords)
    return dataset


jax.tree_util.register_pytree_node(
    xarray.Variable, _flatten_variable, _unflatten_variable
)
jax.tree_util.register_pytree_node(
    xarray.IndexVariable, _flatten_variable, _unflatten_variable
)
jax.tree_util.register_pytree_node(
    xarray.DataArray, _flatten_data_array, _unflatten_data_array
)

jax.tree_util.register_pytree_node(xarray.Dataset, _flatten_dataset, _unflatten_dataset)


## Breaking it

Now let's work with an example dataset that contains both 3D and 2D data. Suppose we want to use `eqx.partition` to split the dataset into its 2D variables and everything else. The call fails:

In [7]:
import equinox as eqx
import jax

air_ds = xarray.tutorial.load_dataset("air_temperature")
air_ds["average_air"] = air_ds.air.mean(dim="time")


eqx.partition(air_ds, filter_spec=lambda x: x.ndim == 2)

ValueError: dimensions ('time', 'lat', 'lon') must have the same length as the number of data dimensions, ndim=0
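
To see what the failing call is trying to do, it helps to look at partitioning on a container that JAX handles natively. The sketch below is a simplified illustration of the idea (build a boolean mask tree, then split the leaves by it) on a plain dict. It is not Equinox's actual implementation, and the dict keys and array shapes are placeholders chosen to mirror the dataset above.

```python
import jax
import jax.numpy as jnp

# Placeholder stand-in for the dataset: one 3D leaf and one 2D leaf.
tree = {"air": jnp.zeros((4, 2, 3)), "average_air": jnp.zeros((2, 3))}

# Step 1: build a mask tree with the same structure as `tree`,
# but with a Python bool in place of every array leaf.
mask = jax.tree_util.tree_map(lambda x: x.ndim == 2, tree)

# Step 2: split the leaves according to the mask, using None as a filler.
selected = jax.tree_util.tree_map(lambda m, x: x if m else None, mask, tree)
rest = jax.tree_util.tree_map(lambda m, x: None if m else x, mask, tree)
```

A dict accepts any leaf type, so step 1 succeeds here. With the xarray registration above, the container has to be rebuilt around bool leaves, and that is where the error is raised.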

## The Root Issue

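The breakage above has a lower-level explanation: the unflatten functions pass whatever leaves they receive straight to the xarray constructors, and those constructors reject a Python bool where an array is expected. A bool has `ndim=0`, which cannot match the stored dims `('time', 'lat', 'lon')`. An even more minimal example shows this: constructing a boolean mask tree with `jax.tree.map` raises the same error.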

In [9]:
jax.tree.map(lambda x: x.ndim == 2, air_ds)

ValueError: dimensions ('time', 'lat', 'lon') must have the same length as the number of data dimensions, ndim=0
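
The same failure can be reproduced without xarray. The sketch below uses a hypothetical `LabeledArray` class (not part of xarray or JAX) whose constructor performs a dims-versus-ndim check similar to the one reported in the error message. It is registered with the same flatten/unflatten pattern as `xarray.Variable` above, under the assumption that the constructor check is what rejects the bool leaves.

```python
import jax
import jax.numpy as jnp
import numpy as np


class LabeledArray:
    """Hypothetical stand-in for xarray.Variable that validates dims."""

    def __init__(self, dims, data):
        if len(dims) != np.ndim(data):
            raise ValueError(
                f"dimensions {dims} must have the same length as the "
                f"number of data dimensions, ndim={np.ndim(data)}"
            )
        self.dims = dims
        self.data = data


def _flatten(v):
    # The array is the only child; dims are static auxiliary data.
    return (v.data,), v.dims


def _unflatten(aux, children):
    # Rebuilds through the validating constructor, as in the registration above.
    (data,) = children
    return LabeledArray(dims=aux, data=data)


jax.tree_util.register_pytree_node(LabeledArray, _flatten, _unflatten)

arr = LabeledArray(("lat", "lon"), jnp.zeros((2, 3)))

# Array-valued leaves round-trip through the constructor without trouble.
doubled = jax.tree_util.tree_map(lambda x: x * 2, arr)

# A bool leaf has ndim 0, so unflattening trips the constructor's check.
try:
    jax.tree_util.tree_map(lambda x: x.ndim == 2, arr)
    mask_error = None
except ValueError as err:
    mask_error = err
```

The mapped function itself runs fine. The error comes from `_unflatten`, which insists on rebuilding a fully validated object even when the leaves are no longer arrays.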