# Efficient loading of data from zarr stores during training
This notebook demonstrates how data preprocessing can impact loading times of chunks of data during training.

In [1]:
import zarr
zarr.__version__

'2.18.2'

In [2]:
import xarray as xr
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


# Create some toy data:

In [3]:
# Reproducible toy data: uniform random values in [0, 1) on a 0.5-degree grid.
SEED = 42
N_TIME = 10_000
rng = np.random.default_rng(SEED)  # seeded so Restart & Run All yields the same data

lon = np.arange(0, 360, 0.5)   # 720 grid points
lat = np.arange(-90, 90, 0.5)  # 360 grid points
time = np.arange(N_TIME)

# The array shape is derived from the coordinates so the two cannot get out of sync.
# 720 x 360 x 10000 float64 values need about 19.3 GiB of RAM.
data = rng.random((lon.size, lat.size, time.size))
ds = xr.Dataset(
    {"temperature": (["longitude", "latitude", "time"], data)},
    coords={"longitude": lon, "latitude": lat, "time": time},
)

In [4]:
# Show the summary of the in-memory toy dataset created in the previous cell.
ds

# Chunk along time dimension:
Every timestep is one chunk of data. This will make loading individual time steps fast, because only that chunk will be loaded.

In [5]:
# One chunk per time step: -1 keeps a dimension in a single chunk, so each chunk
# holds the full longitude/latitude grid regardless of the grid size.
ds_timechunked = ds.chunk({"longitude": -1, "latitude": -1, "time": 1})
ds_timechunked

Unnamed: 0,Array,Chunk
Bytes,19.31 GiB,1.98 MiB
Shape,"(720, 360, 10000)","(720, 360, 1)"
Dask graph,10000 chunks in 1 graph layer,10000 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 19.31 GiB 1.98 MiB Shape (720, 360, 10000) (720, 360, 1) Dask graph 10000 chunks in 1 graph layer Data type float64 numpy.ndarray",10000  360  720,

Unnamed: 0,Array,Chunk
Bytes,19.31 GiB,1.98 MiB
Shape,"(720, 360, 10000)","(720, 360, 1)"
Dask graph,10000 chunks in 1 graph layer,10000 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray


## Save data as zarr store
The zarr store will also be chunked on disk.

In [6]:
!mkdir -p data
!rm -r data/data_timechunked.zarr

ds_timechunked.to_zarr("data/data_timechunked.zarr")

<xarray.backends.zarr.ZarrStore at 0x7ff91a15b7c0>

# Define pytorch Dataset
The dataset metadata is loaded in `__init__` with open_mfdataset, but the actual data is not.
In `__getitem__` the data is loaded with `.values`, but only after selecting a timestep. Thereby only that chunk of data needs to be loaded.

In this example, the data loading is optimised for an ML application in which each sample needs only one timestep but the whole lon/lat domain. For other applications you might, for example, need data cubes of 100 (lon) x 100 (lat) x 10 (time). In such a case it would be a good idea to chunk the data differently already during preprocessing, for example: `ds.chunk({"longitude" : 20, "latitude" : 20, "time" : 5})`. You might need to do some testing to find the optimal chunk sizes.

Therefore, you need to be aware of how you want to access your data during training to inform the preprocessing.

In [8]:
class DemoDatasetChunk(Dataset):
    '''
    Map-style PyTorch dataset that returns one time step of `temperature` per sample.

    The dataset is opened in `__init__`; the values of a sample are only read
    when that sample is indexed.
    '''

    def __init__(self, fname):
        '''
        Initializes the dataset. Note that data is not yet loaded into memory here.

        Arguments:
        fname : filename, passed on to xr.open_mfdataset with engine="zarr"
        '''
        # Deliberately NOT opened in a `with` block: leaving the block would close
        # the dataset while self.dataset still refers to it. Use close() instead.
        self.dataset = xr.open_mfdataset(fname, engine="zarr", parallel=True)
        print(f"opening {self.dataset}")

    def __getitem__(self, idx):
        '''
        Retrieve a sample from the dataset. This is where the data is actually loaded into memory.

        Arguments:
        idx (int): index in dataset

        Returns:
        torch.Tensor: 2D sample (temperature on lon/lat grid, one time step)
        torch.Tensor: time coordinate of sample
        '''
        # Select the time step once and reuse the selection for both outputs.
        sample = self.dataset.isel(time=idx)

        data_raw = sample.temperature.values  # .values actually loads the data chunk
        time = sample.time.values

        return torch.from_numpy(data_raw), torch.from_numpy(time)

    def __len__(self):
        '''Returns number of samples in the dataset'''
        return len(self.dataset.time)

    def close(self):
        '''Closes the underlying dataset. Call this once the dataset is no longer needed.'''
        self.dataset.close()

In [9]:
# Open the zarr store written above; the constructor prints the dataset summary.
demo_dataset = DemoDatasetChunk("data/data_timechunked.zarr")


opening <xarray.Dataset> Size: 21GB
Dimensions:      (latitude: 360, longitude: 720, time: 10000)
Coordinates:
  * latitude     (latitude) float64 3kB -90.0 -89.5 -89.0 ... 88.5 89.0 89.5
  * longitude    (longitude) float64 6kB 0.0 0.5 1.0 1.5 ... 358.5 359.0 359.5
  * time         (time) int64 80kB 0 1 2 3 4 5 ... 9994 9995 9996 9997 9998 9999
Data variables:
    temperature  (longitude, latitude, time) float64 21GB dask.array<chunksize=(720, 360, 1), meta=np.ndarray>


In [10]:
%%time
# Indexing calls __getitem__, which selects time step 42 and loads its values.
demo_dataset[42]

CPU times: user 10.9 ms, sys: 2.22 ms, total: 13.1 ms
Wall time: 179 ms


(tensor([[0.1745, 0.7161, 0.6609,  ..., 0.9635, 0.8790, 0.7593],
         [0.1843, 0.1990, 0.1547,  ..., 0.0976, 0.3290, 0.5591],
         [0.6066, 0.4874, 0.2607,  ..., 0.0339, 0.6090, 0.5383],
         ...,
         [0.3210, 0.3709, 0.0256,  ..., 0.4539, 0.4760, 0.8520],
         [0.9763, 0.5351, 0.3883,  ..., 0.0294, 0.8320, 0.3298],
         [0.3897, 0.2578, 0.0368,  ..., 0.6200, 0.5197, 0.2864]],
        dtype=torch.float64),
 tensor(42))

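Selecting along a dimension that the chunking does not split is much slower. A single longitude across all time steps touches every one of the 10000 time chunks, so all of them have to be read: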

In [12]:
%%time
# A single longitude spans all 10000 time steps; with one chunk per time step
# (chunksize (720, 360, 1)) every chunk has to be read to assemble this slice.
demo_dataset.dataset.isel(longitude=42).load()

CPU times: user 10.8 s, sys: 5.92 s, total: 16.7 s
Wall time: 8.01 s
