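# Demo of the xarray interface to foscat

This notebook computes foscat reference statistics on a HEALPix sea-surface-temperature field. It then computes cross statistics between that field and a noisy time series derived from it.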

## Set up example data

In [None]:
import numpy as np
import xarray as xr

In [None]:
# Load the HEALPix-gridded AVHRR SST sample from its zarr store, pull all of it
# into memory with `.compute()`, then decode the DGGS cell coordinates.
# NOTE(review): `.dggs` is an xarray accessor that none of the imports above this
# cell registers (only numpy and xarray are imported); it presumably comes from
# the xdggs package — confirm that package is imported before running this cell.
sst_store = "data/healpix/avhrr-sst-metop_b/0E0FEB4C-D050-11EC-ACC4-48DF3747D358.zarr"
ds = xr.open_dataset(
    sst_store, engine="zarr", chunks={}, decode_timedelta=True
).compute().dggs.decode()
ds

In [None]:
# Mask missing SST values, drop labels that end up entirely masked, and remove
# length-1 dimensions so the field can be handled as a single snapshot.
sst = ds["sea_surface_temperature"]
arr = sst.where(lambda field: field.notnull(), drop=True).squeeze()
arr

## Compute reference statistics

In [None]:
import foscat.xarray as foscat

In [None]:
# Shared foscat configuration, reused by both the reference- and the
# cross-statistics cells below.
params = foscat.Parameters(
    n_orientations=4,
    kernel_size=5,
    jmax_delta=0,
    dtype="float32",
    backend="torch",
)

In [None]:
%%time
stats = foscat.reference_statistics(
    arr - arr.median(), parameters=params, variances=True
)
stats.attrs["foscat_backend"] = params.cache.backend
stats

In [None]:
# Plot the reference statistics via the `foscat` accessor. The call is kept as
# the cell's last expression so the notebook displays whatever it returns.
stats.foscat.plot()

## Compute cross statistics between the field and a noisy time series derived from it

In [None]:
# Build a synthetic time series: `n_timesteps` copies of `arr` spaced 3 minutes
# apart, each perturbed with Gaussian noise of standard deviation 0.1.
rng = np.random.default_rng(seed=0)  # fixed seed so the demo output is reproducible
n_timesteps = 5
# Take the scalar time as a numpy datetime64. For nanosecond-resolution data,
# `.item()` would return a bare integer (ns since the epoch) instead.
start_time = arr["time"].values[()]
noise = xr.DataArray(
    rng.normal(scale=0.1, size=(n_timesteps, arr.sizes["cells"])),
    dims=["time", "cells"],
    coords={
        "cell_ids": arr["cell_ids"],
        "time": xr.date_range(start_time, freq="3min", periods=n_timesteps),
    },
)
# `arr` was already squeezed when it was extracted, so it is used directly here.
arr2 = arr + noise
arr2

In [None]:
%%time
stats = foscat.cross_statistics(
    arr - arr.median(dim="cells"),
    arr2 - arr2.median(dim="cells"),
    parameters=params,
    variances=True,
)
stats.attrs["foscat_backend"] = params.cache.backend
stats

In [None]:
# Plot the cross statistics via the `foscat` accessor. The call is kept as the
# cell's last expression so the notebook displays whatever it returns.
stats.foscat.plot()