In [1]:
import xarray as xr
import numpy as np

# Example data
- 10 data points
- 3 trajectories of lengths 2, 3 and 5

In [2]:
# 10 points forming 3 trajectories of lengths 2, 3 and 5; each trajectory is
# stored as its own dask chunk along ``x`` (three chunks of different size).
dt = xr.Dataset(
    data_vars={"value": ("x", [1, 1, 2, 2, 2, 3, 3, 3, 3, 3])},
    coords={"lon": ("x", np.linspace(0, 1, 10))},
).chunk({"x": (2, 3, 5)})

`apply_ufunc` can also be used when the input and output have different sizes, but this requires a bit more work.
The documentation is limited, but there is this [example](https://xarray.pydata.org/en/stable/examples/apply_ufunc_vectorize_1d.html).

# Operation on the full array

In [3]:
# Anomaly relative to the mean of the *whole* array (all chunks together),
# as a reference for the per-chunk versions below.
r = dt.value - dt.value.mean()

In [4]:
r.compute()  # evaluate the lazy (dask-backed, since `dt` is chunked) result and display it

# Constant size between input and output

## first example

In [5]:
def func(array):
    """Subtract, inside every chunk, that chunk's own mean.

    ``array`` is expected to be a dask array: with ``dask='allowed'`` the
    chunked data is handed to this function as-is. The number of chunks along
    the first axis is printed, which shows that the function receives the
    whole chunked array at once. The chunking of the result is unchanged.
    """
    print(len(array.chunks[0]))

    def _demean(block):
        # runs once per chunk, on a plain in-memory block
        return block - np.mean(block)

    return array.map_blocks(_demean, chunks=array.chunks)

In [6]:
# With dask="allowed" the chunked array is passed to `func` unchanged, so
# `func` itself maps over the chunks (the printed 3 below is the chunk count
# it receives).
per_chunk_anomaly = xr.apply_ufunc(
    func,
    dt,
    input_core_dims=[["x"]],
    output_core_dims=[["x"]],
    dask="allowed",
)
per_chunk_anomaly.compute()

3


## second example
The function expects a 1D array, so if we want to combine two variables, we can do it as follows:

In [7]:
def per_chunk_fft(array):
    """Apply ``np.fft.fft`` independently to every chunk of a dask array.

    Each chunk keeps its length, so the output reuses the input chunking.
    """
    return array.map_blocks(np.fft.fft, chunks=array.chunks)

In [8]:
# Combine variables into one complex signal (any other variables of the
# dataset could be combined the same way). `.data` hands apply_ufunc a bare
# dask array, so no xarray labels are attached and the result shown below is
# a plain array.
complex_signal = (dt.value + dt.value * 1j).data
xr.apply_ufunc(
    per_chunk_fft,
    complex_signal,
    input_core_dims=[["x"]],   # core dimension consumed by per_chunk_fft
    output_core_dims=[["x"]],  # a per-chunk FFT keeps the length along x
    dask="allowed",            # per_chunk_fft maps over the chunks itself
).compute()

array([ 2. +2.j,  0. +0.j,  6. +6.j,  0. +0.j,  0. +0.j, 15.+15.j,
        0. +0.j,  0. +0.j,  0. +0.j,  0. +0.j])

# Different size between input and output

In [9]:
def per_block_mean(array):
    """Reduce every chunk of a dask array to its mean along the last axis.

    xarray's ``apply_ufunc`` moves core dimensions (``x`` here) to the last
    axis. The reduction is therefore done along ``axis=-1``, and any leading
    (broadcast) axes keep their chunking. Each chunk along the last axis
    collapses to a single value. For the 1-D case used below, this is the mean
    of each whole block.

    Parameters
    ----------
    array : dask array
        Chunked input; the last axis is the one being reduced.

    Returns
    -------
    dask array
        Same leading chunks as ``array`` and ``len(array.chunks[-1])`` elements
        (one per chunk) along the last axis.
    """
    # one output value per chunk along the reduced (last) axis
    output_chunks = array.chunks[:-1] + ((1,) * len(array.chunks[-1]),)
    # must return an array, not a scalar (https://github.com/dask/dask/issues/8822),
    # hence keepdims=True
    return array.map_blocks(
        lambda block: np.mean(block, axis=-1, keepdims=True),
        chunks=output_chunks,
    )

In [10]:
xr.apply_ufunc(
    per_block_mean,
    dt,
    input_core_dims=[["x"]],  # input dimension to the per_block function
    output_core_dims=[["x"]],  # output still has one dimension
    # The size of x changes (one value per chunk), so x must be listed in
    # exclude_dims. Use a set literal: set("x") only works by accident because
    # the name is one character (set("lon") == {"l", "o", "n"}).
    exclude_dims={"x"},
    dask="allowed",
).compute()