## Xarray to PyTorch loader example

A brief example of feeding xarray data into PyTorch using xbatcher.

In [None]:
from nerc_eds.apis import loader_api

# Extra keyword arguments forwarded to the xbatcher dataset builder; none are needed yet.
# TODO(design): could loader_api use datapoint under the hood?
xbatcher_kwargs = dict()

# Criteria for finding candidate datasets.
# TODO(design): should loader_api.search be backed by geocroissant/STAC metadata?
search_criteria = {
    "keywords": ["climate", "CMIP6", "temperature", "precipitation"],
    "spatial_coverage": "global",
    "temporal_range": ("2015-01-01", "2100-12-31"),
}
search = loader_api.search(**search_criteria)

# load_dataset is intended to wrap the xbatcher/MapDataset construction shown in the next cell.
xb_dataset = search.load_dataset(**xbatcher_kwargs)

In [None]:
# Example xbatcher pipeline: roughly what `search.load_dataset` is expected to do behind the scenes.
# Adapted from https://xbatcher.readthedocs.io/en/latest/user-guide/training-a-neural-network-with-Pytorch-and-xbatcher.html
import xarray as xr
import xbatcher as xb
import xbatcher.loaders.torch

# Dimension along which features and labels are cut into batches, and the number of
# samples per batch. Both generators below share these so they yield the same number
# of batches and batch i of X lines up with batch i of y.
BATCH_DIM = 'sample'
BATCH_SIZE = 2000

ds = xr.open_dataset(
    's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',
    engine='zarr',
    chunks={},  # open lazily as dask arrays instead of reading everything into memory
    backend_kwargs={'storage_options': {'anon': True}},  # anonymous access, no credentials
)

# Feature batches: BATCH_SIZE entries along BATCH_DIM, with every other dimension of
# the feature variable kept at full extent. To adapt this to another dataset, change
# the variable names and BATCH_DIM; the remaining dims are picked up from the data.
X_bgen = xb.BatchGenerator(
    ds['images'],
    input_dims={
        dim: (BATCH_SIZE if dim == BATCH_DIM else size)
        for dim, size in ds['images'].sizes.items()
    },
    preload_batch=False,  # load each batch on demand
)
# Label batches, cut along the same dimension with the same batch size.
y_bgen = xb.BatchGenerator(
    ds['labels'], input_dims={BATCH_DIM: BATCH_SIZE}, preload_batch=False
)

# Pair feature and label batches into a map-style PyTorch dataset.
xb_dataset = xbatcher.loaders.torch.MapDataset(X_bgen, y_bgen)


In [None]:
import multiprocessing

import torch

# DataLoader tuning knobs.
NUM_WORKERS = 4  # parallel worker processes loading batches
PREFETCH_FACTOR = 3  # batches loaded in advance *per worker* (PyTorch semantics)
# "forkserver" avoids forking the (possibly multi-threaded) kernel process directly,
# but that start method does not exist on every platform (e.g. Windows), where
# DataLoader would raise ValueError. Fall back to "spawn" there.
WORKER_START_METHOD = (
    'forkserver' if 'forkserver' in multiprocessing.get_all_start_methods() else 'spawn'
)

# Wrap the xbatcher-backed dataset in a standard PyTorch DataLoader.
train_dataloader = torch.utils.data.DataLoader(
    xb_dataset,
    batch_size=None,  # xbatcher already yields whole batches, so disable automatic batching
    prefetch_factor=PREFETCH_FACTOR,
    num_workers=NUM_WORKERS,
    persistent_workers=True,  # keep workers alive between epochs for faster subsequent epochs
    multiprocessing_context=WORKER_START_METHOD,
)