In [1]:
import numpy as np
import xarray as xr

# Collapse the data section of xarray reprs so cell outputs stay short; the trailing `;`
# hides the returned options object in the notebook.
xr.set_options(**{"display_expand_data": False});

In [2]:
# Fixed seed so "Restart & Run All" regenerates the same random data, and therefore
# the same outputs, on every run.
RNG_SEED = 0
rng = np.random.default_rng(RNG_SEED)

In [3]:
# Four sector labels along a single "sector" dimension.
categories = xr.DataArray(
    ["Tech", "Finances", "Food", "Energy"],
    dims=["sector"],
)
# One random sector index per observation: 77 integers in [0, categories.size).
sector_idx = rng.choice(np.arange(categories.size), size=77)
# Standard-normal values laid out as (chain, draw, obs_id) = (4, 500, 77), with each
# observation's sector index attached as a non-dimension coordinate on obs_id.
obs = xr.DataArray(
    rng.normal(size=(4, 500, sector_idx.size)),
    dims=("chain", "draw", "obs_id"),
    coords={"sector_idx": ("obs_id", sector_idx)},
)
obs

In [4]:
# This cell and the next three all compute `categories.values[index]`; what differs is
# the names and dimensions of the output. A plain 1-D NumPy array keeps the original
# `sector` dimension name.
categories.isel({"sector": sector_idx})

In [5]:
# DataArray indexer: the output takes the indexer's dimension (`obs_id`) instead of `sector`.
categories.isel({"sector": obs["sector_idx"]})

In [6]:
# (dims, values) tuple: labels a plain array on the fly, so the output dimension is `new_dim`.
categories.isel({"sector": (("new_dim",), sector_idx)})

In [7]:
# A 2-D (dims, values) tuple works the same way: output dims are (`new_dim`, `new_obs`),
# with shape (7, 11).
categories.isel(
    {"sector": (("new_dim", "new_obs"), sector_idx.reshape(7, 11))}
)

In [8]:
# Not supported, which is nice imo: a bare (unlabeled) 2-D integer array gives xarray no
# names for the output dimensions, so it refuses instead of guessing (compare with the
# labeled ("new_dim", "new_obs") tuple form in the previous cell).
# The expected IndexError is caught and shown so "Restart & Run All" does not stop here.
try:
    categories.isel(sector=sector_idx.reshape(7, 11))
except IndexError as err:
    print(f"IndexError: {err}")

IndexError: Unlabeled multi-dimensional array cannot be used for indexing: sector

In [9]:
# Random stand-ins for 200 divergent (chain, draw) locations. Both indexers share the
# `divergence_id` dimension, so indexing with both pairs them up pointwise.
# Upper bounds are read from `obs` instead of hard-coding 4 / 500, so they stay in
# range if the shape of `obs` changes.
divergences_chain_idxs = xr.DataArray(rng.choice(np.arange(obs.sizes["chain"]), size=200), dims=["divergence_id"])
divergences_draw_idxs = xr.DataArray(rng.choice(np.arange(obs.sizes["draw"]), size=200), dims=["divergence_id"])

In [10]:
# Pointwise selection: both indexers live on `divergence_id`, so each (chain, draw) pair
# is taken together, like
#   obs.values[divergences_chain_idxs.values, divergences_draw_idxs.values, :]
obs.isel({"chain": divergences_chain_idxs, "draw": divergences_draw_idxs})

In [11]:
# Boolean mask over observations: True where the observation belongs to sector 3
# (`categories[3]` is "Energy").
mask = obs["sector_idx"] == 3
mask

In [12]:
# Same pointwise (chain, draw) selection, additionally restricted to the observations where
# `mask` is True; xarray accepts the boolean mask directly because it lives on the
# `obs_id` dimension it indexes.
obs.isel(
    {
        "chain": divergences_chain_idxs,
        "draw": divergences_draw_idxs,
        "obs_id": mask,
    }
)

In [13]:
# NumPy equivalent of the previous cell. As far as I can tell, this needs the boolean mask
# converted to its integer positions first. The three index arrays are then given
# explicit axes, (n, 1), (n, 1) and (1, k), so they broadcast to an (n, k) result.
(mask_idx,) = np.nonzero(mask.values)
obs.values[
    divergences_chain_idxs.values[:, np.newaxis],
    divergences_draw_idxs.values[:, np.newaxis],
    mask_idx[np.newaxis, :],
]

array([[-1.293101  , -0.71500273,  0.33632339, ...,  0.93487119,
        -0.99257948, -0.25500398],
       [-0.10294998,  0.66230001, -0.0227726 , ...,  0.03384014,
         0.67932258,  0.71239728],
       [ 0.6063118 , -1.59812246, -0.75991642, ...,  1.69820961,
         1.01326507,  0.55442668],
       ...,
       [-0.88487083, -0.556576  ,  1.482051  , ...,  0.97568127,
        -0.9281762 , -1.66976961],
       [-0.19259503,  0.23339986,  0.73059913, ...,  0.35347429,
         0.88142908, -0.7993959 ],
       [ 0.57106014,  0.05834826,  0.35048907, ...,  1.06661333,
        -0.04908627, -0.57607001]], shape=(200, 16))

In [14]:
# Naive NumPy breaks: the boolean mask is turned into its integer positions and then
# broadcast against the two 200-long index arrays. The shapes (200,), (200,) and
# (n_selected,) do not broadcast here.
# `mask.values` is used for consistency with the rest of the NumPy expression.
# The expected IndexError is caught and shown so "Restart & Run All" does not stop here.
try:
    obs.values[divergences_chain_idxs.values, divergences_draw_idxs.values, mask.values]
except IndexError as err:
    print(f"IndexError: {err}")

IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (200,) (200,) (16,) 

In [15]:
# Attempting the same broadcasting with the boolean mask instead of integer indices
# breaks too. A 2-D boolean index consumes two axes, so together with the two integer
# indexers NumPy sees 4 indices for a 3-D array.
# `.values` matters here: on the DataArray itself, `mask[None, :]` appears to fail earlier,
# inside xarray's positional indexing (which has no `None`/newaxis support). That is the
# likely source of the bare "too many indices" message, before NumPy was ever involved.
# The expected IndexError is caught and shown so "Restart & Run All" does not stop here.
try:
    obs.values[divergences_chain_idxs.values[:, None], divergences_draw_idxs.values[:, None], mask.values[None, :]]
except IndexError as err:
    print(f"IndexError: {err}")

IndexError: too many indices

In [16]:
# Not supported, also nice imo: a boolean indexer must be unlabeled or on the same
# dimension it indexes, so a mask renamed to "other_name" is rejected for obs_id.
# The expected IndexError is caught and shown so "Restart & Run All" does not stop here.
try:
    obs.isel(chain=divergences_chain_idxs, draw=divergences_draw_idxs, obs_id=mask.rename(obs_id="other_name"))
except IndexError as err:
    print(f"IndexError: {err}")

IndexError: Boolean indexer should be unlabeled or on the same dimension to the indexed array. Indexer is on ('other_name',) but the target dimension is obs_id.

In [17]:
# Redraw the draw indices with a different length (121) on a different dimension
# (`other_dim`). Chain and draw indexers then no longer share a dimension, so every
# chain/draw combination is selected; the NumPy equivalents below give (200, 121, 77).
# NOTE: this rebinds `divergences_draw_idxs`. Cells 10 and 12 were written for the earlier
# 200-long `divergence_id` version and change meaning if re-run after this cell.
# The upper bound is read from `obs` instead of hard-coding 500.
divergences_draw_idxs = xr.DataArray(rng.choice(np.arange(obs.sizes["draw"]), size=121), dims=["other_dim"])
obs.isel(chain=divergences_chain_idxs, draw=divergences_draw_idxs)

In [18]:
# Two consecutive integer-array indexing steps work, since each one acts on a single
# axis. In the next example, though, it gets quite confusing.
obs.values[divergences_chain_idxs.values][:, divergences_draw_idxs.values].shape

(200, 121, 77)

In [19]:
# Broadcasting also works and seems quite robust, as long as there are no boolean masks
# (or they are first converted to integer positions, as above). Chain indices shaped
# (n, 1) against draw indices shaped (1, m) select every chain/draw combination.
obs.values[
    divergences_chain_idxs.values[:, np.newaxis],
    divergences_draw_idxs.values[np.newaxis, :],
    :,
].shape

(200, 121, 77)

In [20]:
# Naive NumPy breaks: integer index arrays are broadcast against each other, and the 200
# chain indices and 121 draw indices are not broadcast-compatible (compare with the
# explicit [:, None] / [None, :] version in the previous cell).
# The expected IndexError is caught and shown so "Restart & Run All" does not stop here.
try:
    obs.values[divergences_chain_idxs.values, divergences_draw_idxs.values, :]
except IndexError as err:
    print(f"IndexError: {err}")

IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (200,) (121,) 

In [21]:
# Index draw along `new_dim` and obs_id along (`new_dim`, `new_obs`). The shared `new_dim`
# pairs draw i with row i of the 7x11 obs_id index.
obs.isel(
    {
        "draw": (("new_dim",), np.arange(7)),
        "obs_id": (("new_dim", "new_obs"), sector_idx.reshape(7, 11)),
    }
)

In [22]:
# Cross-check the xarray result against plain NumPy advanced indexing.
# `np.array_equal` is used instead of `np.allclose` for two reasons:
# - allclose broadcasts its arguments, so arrays with different but broadcast-compatible
#   shapes can still compare True;
# - indexing only copies values, so the match should be exact.
np.array_equal(
    obs.isel(draw=(("new_dim",), np.arange(7)), obs_id=(("new_dim", "new_obs"), sector_idx.reshape(7, 11))),
    obs.values[:, np.arange(7)[:, None], sector_idx.reshape(7, 11)],
)

True

In [23]:
# Same check via a two-step NumPy route:
# 1. index obs_id with the full 7x11 array, giving (chain, draw, 7, 11);
# 2. pair draw i with row i by indexing the draw axis and the first new axis with the
#    same arange(7).
# `np.array_equal` (exact and shape-strict) replaces `np.allclose`, which broadcasts its
# arguments and could hide a shape mismatch.
np.array_equal(
    obs.isel(draw=(("new_dim",), np.arange(7)), obs_id=(("new_dim", "new_obs"), sector_idx.reshape(7, 11))),
    obs.values[:, :, sector_idx.reshape(7, 11)][:, np.arange(7), np.arange(7), :],
)

True