In [None]:
import pathlib

import foscat.Synthesis as synthe
import foscat.xarray as foscat
import numpy as np
import xarray as xr
from rich.console import Console
from rich.progress import track

# Shared rich console; used for log messages in the statistics cell.
console = Console()
# Collapse attrs/data sections in xarray reprs and keep attributes through
# operations. The trailing `;` suppresses the cell's output.
xr.set_options(display_expand_attrs=False, display_expand_data=False, keep_attrs=True);

In [None]:
# Parameter set shared by the reference-statistics computation and, through
# `params.cache`, by the synthesis cells further down.
params = foscat.Parameters(
    n_orientations=4, kernel_size=3, jmax_delta=0, dtype="float32", backend="torch"
)

In [None]:
# Input stores are searched for below `data_root`; both paths are relative to
# the working directory of the kernel.
data_root = pathlib.Path("data/healpix")
stats_root = pathlib.Path("data/stats")
# Create the statistics directory up front; a no-op if it already exists.
stats_root.mkdir(parents=True, exist_ok=True)

## compute stats

In [None]:
# Zarr stores one directory level below `data_root`, in sorted path order.
# Only the first four are kept; every later cell works on this subset.
data_files = sorted(data_root.glob("*/*.zarr"))[:4]
data_files

In [None]:
def detect_key(ds, options):
    """Return the name of the first variable matching one of the standard names.

    The lookup uses ``ds.cf.standard_names``, read here as a mapping from a
    standard name to a sequence of variable names. ``options`` is tried in
    order; for the first standard name present in the mapping, the first
    variable listed under it is returned.

    Parameters
    ----------
    ds
        Object exposing ``ds.cf.standard_names``.
    options
        Standard names to try, in order of preference.

    Returns
    -------
    The first variable name registered under the first matching standard name.

    Raises
    ------
    ValueError
        If none of ``options`` is present in the mapping.
    """
    # NOTE(review): `.cf` is presumably the cf_xarray accessor; nothing in the
    # visible import cell imports it explicitly — confirm it gets registered.
    variables_by_standard_name = ds.cf.standard_names

    for candidate in options:
        variable_names = variables_by_standard_name.get(candidate)
        if variable_names is None:
            continue
        return variable_names[0]

    raise ValueError(f"could not find a variable using the standard names {options}")


# One temperature array per input store, in the same order as `data_files`.
arrs = []
for path in track(data_files):
    # Open without dask chunking, pull the data into memory, then decode the
    # DGGS metadata through the `dggs` accessor.
    ds = xr.open_dataset(path, chunks=None, decode_timedelta=True)
    ds = ds.load().dggs.decode()

    # The variable name differs between products, so look it up by standard name.
    key = detect_key(ds, ["sea_water_temperature", "sea_surface_subskin_temperature"])

    # Mask missing values and drop the labels that end up entirely missing.
    temperature = ds[key].where(lambda values: values.notnull(), drop=True)

    # Remember which store the array was read from.
    temperature.encoding["source"] = path

    # Products with a DEPTH dimension: keep only the first level.
    if "DEPTH" in temperature.dims:
        temperature = temperature.isel(DEPTH=0)

    arrs.append(temperature)

In [None]:
# Rebuild the `cell_ids` index on every array (drop it, then set it again),
# squeeze out size-1 dimensions, and align with an outer join so that all arrays
# end up on the union of their cell ids.
aligned_arrs_ = xr.align(
    *[arr.drop_indexes("cell_ids").set_xindex("cell_ids").squeeze() for arr in arrs],
    join="outer",
)
# Decode again after alignment — presumably to restore the DGGS index that
# `drop_indexes` removed; confirm.
aligned_arrs = [x.dggs.decode() for x in aligned_arrs_]
aligned_arrs

In [None]:
# Reference statistics, one entry per aligned array (same order as `aligned_arrs`).
stats_ = []

# `aligned_arrs` keeps the order of `data_files` (one array was appended per file
# and `xr.align` returns its inputs in order), so pair them up to log the file
# that is actually being processed. Without the pairing, `path` would be the
# variable left over from the loading loop and every message would name the
# last file. `zip` has no length, so tell `track` the total explicitly.
for path, arr in track(zip(data_files, aligned_arrs), total=len(aligned_arrs)):
    console.log(f"computing stats for {path.stem}")

    # The outer join padded each array to the common set of cells; drop the
    # missing values again before computing statistics.
    arr_ = arr.where(arr.notnull(), drop=True)

    # Statistics are computed on the median-centred field; `jmax=5` and
    # `norm="self"` mirror the arguments used in `The_loss`.
    stats = foscat.reference_statistics(
        arr_ - arr_.median(), parameters=params, variances=True, jmax=5, norm="self"
    )

    stats_.append(stats)

## synthesis

In [None]:
def The_loss(u, scat_operator, args):
    """Distance between the statistics of the synthesized map and a reference.

    Parameters
    ----------
    u
        Current synthesized map.
    scat_operator
        Object providing ``eval`` and ``reduce_distance``.
    args
        Indexable holding, in order: the reference statistics, the ``sigma``
        handed to ``reduce_distance``, the cell ids, and ``nside``.

    Returns
    -------
    Whatever ``scat_operator.reduce_distance`` returns for the statistics of
    ``u`` against the reference.
    """
    reference, sigma, cell_ids, nside = args[0], args[1], args[2], args[3]

    # Statistics of the current synthesized map on the given cells.
    current = scat_operator.eval(
        u, norm="self", cell_ids=cell_ids, nside=nside, Jmax=5
    )

    # Compare against the reference statistics, using `sigma` as the scale.
    return scat_operator.reduce_distance(current, reference, sigma=sigma)

In [None]:
# NOTE(review): `params.cache` is presumably the scattering operator built from
# `params`; it is passed to every loss and used for `to_numpy` later — confirm.
scat = params.cache
# One loss per input array, each with its own reference statistics.
loss_functions = []
for arr, stats in zip(aligned_arrs, stats_):
    # Convert the xarray statistics into the (ref, sref) pair that `The_loss`
    # forwards to `reduce_distance`. Note this is a private foscat helper.
    ref, sref = foscat.statistics._xarray_to_scat_cov(stats)

    nside = arr.dggs.grid_info.nside

    # The positional arguments after `scat` arrive in `The_loss` as `args`:
    # (ref, sref, cell ids of the aligned array, nside).
    loss_functions.append(
        synthe.Loss(The_loss, scat, ref, sref, arr.dggs.coord.data, nside)
    )

# Combine all losses into a single synthesis.
sy = synthe.Synthesis(loss_functions)

In [None]:
# Size the initial map from the *aligned* arrays: the cell ids handed to every
# loss and the coordinate attached to the synthesized output both come from
# `aligned_arrs`, which hold the outer-joined set of cells. The pre-alignment
# `arrs[0]` only has its own cells and can therefore be shorter.
n_cell_ids = aligned_arrs[0].sizes["cells"]

# Scale the starting noise by the largest standard deviation among the inputs.
std = np.max([arr.std().item() for arr in aligned_arrs])

# Seeded generator so that Restart & Run All starts from the same noise map.
rng = np.random.default_rng(0)
imap = rng.standard_normal((1, n_cell_ids)) * std

# Run the synthesis and convert the result to a numpy array.
omap = scat.to_numpy(sy.run(imap, EVAL_FREQUENCY=1, NUM_EPOCHS=100))

In [None]:
# Wrap the synthesized map in a DataArray along `cells`, attach the cell ids of
# the aligned arrays, and decode it with the grid info of the first input.
synthesized = xr.DataArray(
    np.squeeze(omap), dims="cells", coords={"cell_ids": aligned_arrs[0].dggs.coord}
).dggs.decode(arrs[0].dggs.grid_info)
synthesized

In [None]:
# Display the synthesized map through the `dggs` accessor's `explore()`.
synthesized.dggs.explore()