# Multithreaded Sharding

## Quick-note on the project directory setup

The main root dir `~/.kosmoss` is structured as follows:
* `data/` contains raw and preprocessed data. 
    * `raw/` is actually a symbolic link to the same repo for all candidates, DO NOT TOUCH IT!
    * `processed/` will be created when data is preprocessed and will contain all transformed data
* `logs` contains the logs generated by the frameworks we will use throughout
* `artifacts` contains the artifacts generated by our experiments

It additionally contains:
* `config.yaml`, which contains the basic configuration setup for our experiments, including the `timestep` that sets the sampling coefficient
* `metadata.json` that will be generated by the session on dataproc and that contains the data splitting setup

In [1]:
from kosmoss import CACHED_DATA_PATH, CONFIG, PROCESSED_DATA_PATH
from kosmoss.utils import save_metadata, prime_factors, purgedirs

## Download dataset

The data has already been downloaded for you with the `climetlab` library, provided by the ECMWF. We'll just load it.

In [None]:
import climetlab as cml
import dask
import dask.array as da
import numpy as np
import os
import os.path as osp
from pprint import pprint

step = CONFIG['timestep']

cml.settings.set("cache-directory", CACHED_DATA_PATH)
cmlds = cml.load_dataset(
    'maelstrom-radiation', 
    dataset='3dcorrection', 
    raw_inputs=False, 
    timestep=list(range(0, 3501, step)), 
    minimal_outputs=False,
    patch=list(range(0, 16, 1)),
    hr_units='K d-1',
)

By downloading data from this dataset, you agree to the terms and conditions defined at https://apps.ecmwf.int/datasets/licences/general/ If you do not agree with such terms, do not download the data. 



Internally, climetlab checks that all of the requested files have been downloaded. To process the data, we need to convert it into a usable format. Xarray is a framework widely used in scientific machine learning that can delegate lazy, chunked computation to Dask. It reads netCDF4 files, a format layered over the popular HDF5 file format that adds metadata to carry additional information. There's already a lot of tech involved:

* [Dask is a Python framework](https://dask.org/) that 'provides advanced parallelism for analytics'.
* [Xarray](https://docs.xarray.dev/en/stable/) works with Dask for chunked arrays and provides an abstraction for HDF5/netCDF4.
* HDF5 is both a file format and a library to process large, n-dimensional, datasets. More info on [the format initiative](https://www.hdfgroup.org/solutions/hdf5/) and specifically for [the Python library](https://docs.h5py.org/en/stable/)
* [netCDF4 is an extension](https://unidata.github.io/netcdf4-python/) of the HDF5 file format that provides additional metadata.

In [None]:
xr_array = cmlds.to_xarray()
xr_array

As you can see, the 'loading' is almost instant, because the data itself is not read into memory yet. There's still a bit of overhead each time a file's metadata is read, but nothing like loading everything into memory.

The dataset contains several variables, only a few of which we'll use for modeling. To give you a sense of scale, we've taken only 3 instants (snapshots at a particular time), but the data is already quite large for a DL use-case.

In [None]:
print(f"num of instants: {3500 // step} /3500")
print(f"size: {xr_array.nbytes / float(1 << 30):,.0f} GB")

If we were to download the entire set, it would result in some 24 TB of data. This does not fit in memory, let alone on disk in the current setup! For the sake of this session, we will work with a reduced subset, but the principles apply to the full dataset.

Let's have a look at the data for `sca_inputs`.

In [None]:
xr_array.sca_inputs

The object is an `xarray.DataArray` totaling 70 MB of data for 1,085,440 rows of 17 features each. On the right, the chunk setup shows how the data is chunked in memory. Here, `sca_inputs` will be virtually split along all axes into 384 chunks of equal size. Still, **the data is seen in memory as a continuous block, but all of the operations on that data will be parallelized if possible.**

For instance, the `mean` operation is parallelizable and can be computed on each chunk individually, then reassembled.

It means that:
* Computations are parallelizable, provided a math formula exists to distribute the computation
* The entire dataset does not need to fit in memory, so large datasets can be processed this way
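
To make the chunk-wise `mean` concrete, here is a minimal sketch in plain Python. It is not how Dask implements it internally. The helper name `chunked_mean` and the toy values are assumptions made for illustration. Each chunk reports a partial sum and a count, and the partials are combined at the end, which stays correct even when the last chunk is smaller than the others.

```python
from concurrent.futures import ThreadPoolExecutor


def chunked_mean(values, chunk_size, max_workers=4):
    """Mean of `values`, computed from per-chunk partial sums (hypothetical helper)."""
    chunks = [values[i:i + chunk_size] for i in range(0, len(values), chunk_size)]

    def partial(chunk):
        # Each chunk only needs to report its sum and its length.
        return sum(chunk), len(chunk)

    # The chunks are independent, so they can be processed by separate threads.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        partials = list(pool.map(partial, chunks))

    total = sum(s for s, _ in partials)
    count = sum(n for _, n in partials)
    return total / count


values = [float(v) for v in range(10)]
mean_chunked = chunked_mean(values, chunk_size=3)
print(mean_chunked)
```

Averaging the per-chunk means directly would only be correct if every chunk had the same length, which is why the sketch carries the counts along.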

In [None]:
xr_array.col_inputs

In [None]:
features = [
    'sca_inputs',
    'col_inputs',
    'hl_inputs',
    'inter_inputs',
    'flux_dn_sw',
    'flux_up_sw',
    'flux_dn_lw',
    'flux_up_lw',
]

for feat in features:
    print(f'{feat}: {xr_array[feat].data}')

A few thoughts.

To later build a batch of data (let's say 256 elements) and improve training performance with parallelization, we need to be able to randomly read 256 elements in parallel, so ideally 256 times 1 element (let's say we can have 256 threads). As we noticed before, reading the file metadata is not in nanoseconds territory, so to achieve parallelization, we need to physically split the data into smaller pieces containing only a few rows each and save those to disk. This operation is called sharding.
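
As a rough illustration of why small shards help random reads, the sketch below writes a few toy `.npy` shards to a temporary directory and reads a random batch with a thread pool. The shard sizes, the `read_row` helper and the file layout are assumptions for the example, not the loader used later in the session.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

import numpy as np

shard_size = 4   # rows per shard (toy value)
num_shards = 8

# Write toy shards named 0.npy, 1.npy, ... in a temporary directory.
shard_dir = tempfile.mkdtemp()
for shard_id in range(num_shards):
    start = shard_id * shard_size
    rows = np.arange(start, start + shard_size, dtype=np.float32).reshape(-1, 1)
    np.save(os.path.join(shard_dir, f"{shard_id}.npy"), rows)


def read_row(index):
    """Open only the shard holding `index` and return a copy of that row."""
    shard_id, offset = divmod(index, shard_size)
    shard = np.load(os.path.join(shard_dir, f"{shard_id}.npy"), mmap_mode="r")
    return np.array(shard[offset])


rng = np.random.default_rng(0)
batch_indices = rng.choice(shard_size * num_shards, size=8, replace=False)

# One task per element: each thread touches a single small file.
with ThreadPoolExecutor(max_workers=8) as pool:
    batch = np.stack(list(pool.map(read_row, batch_indices.tolist())))

print(batch.shape)
```

In this toy setup each row stores its own global index, so it is easy to check that the threads returned the requested rows in order.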

To address this with Dask, **we need to align shard size with chunk size**, then rechunk the data into smaller bits.

As we can see above, each variable is chunked equally on the first axis, but some are chunked along the last axis as well. If we save the shards as is, **we won't have the full row of data for each of the variables on a single shard**. What we want is to have the data chunked along the first axis only.
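
The difference between the two chunk layouts can be seen on a tiny NumPy array. This is only a sketch with made-up shapes: splitting along both axes leaves each block with part of a row, while splitting along the first axis keeps rows whole.

```python
import numpy as np

table = np.arange(24).reshape(4, 6)  # 4 rows of 6 features

# Chunked along both axes: each block holds only half of a row's features.
blocks = [np.hsplit(half, 2) for half in np.vsplit(table, 2)]
top_left = blocks[0][0]

# Chunked along the first axis only: each shard holds complete rows.
row_shards = np.vsplit(table, 2)

print(top_left.shape, row_shards[0].shape)
```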

## One Data representation, Two Models

We will demonstrate the use-case along two types of models, that take two different kinds of data as inputs:
* An MLP that takes flat data (vector)
* A GNN that takes a graph (nodes + connectivity + edge attributes)

We COULD just save the data as flat vectors, and generate the graphs on-the-fly at **loading time** or directly within a **preprocessing layer of the model**. But for the sake of this presentation, we decided to just redundantly save the data into separate files: one stack for the flattened data, and one stack for the graph data.

## Flattened data

To chunk the data into equal pieces along the first axis, we need to divide 1,085,440 by one of its factors.

In [None]:
dataset_len = xr_array.dims['column']
print(f"dataset len: {dataset_len}")
print(f"prime factor decomposition: {prime_factors(dataset_len)}")

Sharding is a subtle balance. 

We're going to split a single file into multiple pieces. From a loading perspective, the more the better, but from a filesystem perspective, and even more so on NFS, storing a large number of files can congest the system or network.

Still, for the sake of training, we choose to favor a larger number of files, which dramatically speeds up the data loading process.

In [None]:
num_shards = 53 * 2 ** 6
shard_size = dataset_len // num_shards
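
The arithmetic behind this choice can be checked by hand. The sketch below assumes the 1,085,440 rows quoted earlier. It uses a stand-in `prime_factors_sketch` function, because the actual implementation of `kosmoss.utils.prime_factors` is not shown here and may differ.

```python
def prime_factors_sketch(n):
    """Return the prime factors of n in ascending order, with multiplicity."""
    factors = []
    divisor = 2
    while divisor * divisor <= n:
        while n % divisor == 0:
            factors.append(divisor)
            n //= divisor
        divisor += 1
    if n > 1:
        factors.append(n)
    return factors


dataset_len = 1_085_440          # row count quoted in the text (assumed here)
factors = prime_factors_sketch(dataset_len)

num_shards = 53 * 2 ** 6         # built only from factors of dataset_len
assert dataset_len % num_shards == 0, "shards would be uneven"
shard_size = dataset_len // num_shards

print(factors)
print(num_shards, shard_size)
```

Building `num_shards` from the prime factors guarantees that every shard holds exactly the same number of rows, with no smaller leftover shard at the end.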

First, we configure Dask to execute in a multithreaded environment, as opposed to multiprocessed. 

Just making the implicit explicit here, since it's the default value.

In [None]:
dask.config.set(scheduler='threads')

Let's first flatten the data.

In [None]:
data = {}
for feat in features:
    array = xr_array[feat].data
    array = da.reshape(array, shape=(array.shape[0], -1))
    data.update({feat: array})

Then concatenate:

* `hl_inputs`, `inter_inputs`, `sca_inputs`, and `col_inputs` into x
* `flux_dn_sw`, `flux_up_sw`, `flux_dn_lw`, and `flux_up_lw` into y

In [None]:
data['hl_inputs'].shape, data['inter_inputs'].shape, data['sca_inputs'].shape

In [None]:
x = da.concatenate([
    data['hl_inputs'],
    data['inter_inputs'],
    data['sca_inputs'],
    data['col_inputs']
], axis=-1)

y = da.concatenate([
    data['flux_dn_sw'],
    data['flux_up_sw'],
    data['flux_dn_lw'],
    data['flux_up_lw'],
], axis=-1)

In [None]:
x

As mentioned before, `x` and `y` are chunked along all axes, which is far from ideal, so let's rechunk the data along the first axis only.

In [None]:
x_ = da.rechunk(x, chunks=(shard_size, *x.shape[1:]))
y_ = da.rechunk(y, chunks=(shard_size, *y.shape[1:]))

In [None]:
x_

**Looking better!**

We can now save the chunks into shards on disk.

In [None]:
out_dir = osp.join(PROCESSED_DATA_PATH, f'flattened-{step}')

x_path, y_path = purgedirs([
    osp.join(out_dir, 'x'), 
    osp.join(out_dir, 'y')
])
    
# Write one .npy file per chunk along the first axis
da.to_npy_stack(x_path, x_, axis=0)
da.to_npy_stack(y_path, y_, axis=0)

And save the metadata in a JSON so we can use it later without even opening a single file. It'll make sense later.

In [None]:
metadata_flattened = {
    "dtype": x_.dtype.name,
    "dataset_len": len(x_),
    "num_shards": len(x_.chunks[0]),
    "x_shape": x_.chunksize,
    "y_shape": y_.chunksize,
}
pprint(metadata_flattened)

save_metadata(
    step, 
    metadata_flattened, 
    'flattened'
)
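
One way such metadata might be used later is to locate a row without opening any shard. The sketch below is hypothetical: the `locate` helper and the toy values are assumptions. It relies on the `<i>.npy` file naming that `da.to_npy_stack` produces, and on equally sized shards.

```python
import json

# Toy metadata mirroring the keys saved above (values are illustrative).
metadata = json.loads(json.dumps({
    "dataset_len": 12,
    "num_shards": 3,
    "x_shape": (4, 5),   # tuples become lists after a JSON round trip
}))


def locate(index, metadata):
    """Map a global row index to (shard file name, row offset in that shard)."""
    if not 0 <= index < metadata["dataset_len"]:
        raise IndexError(index)
    rows_per_shard = metadata["x_shape"][0]
    shard_id, offset = divmod(index, rows_per_shard)
    return f"{shard_id}.npy", offset


print(locate(0, metadata), locate(7, metadata), locate(11, metadata))
```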

## Feature engineering

As for the graph data, we will be building a path graph with 138 nodes and 137 two-way connections (undirected edge index). 

For that, we need to prepare 3 pieces of data:

* Node-level inputs in `x`
* Node-level outputs in `y`
* Edge attributes in `edge`
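
For reference, the connectivity of such a path graph can be written down in a few lines. This is a sketch only: the function name `path_graph_edges` and the edge ordering are assumptions, and the ordering used later in the session may differ. Each of the 137 connections is stored in both directions, giving 274 directed edges.

```python
def path_graph_edges(num_nodes):
    """Directed edge list (both directions) for a path graph."""
    sources, targets = [], []
    for i in range(num_nodes - 1):
        # Forward edge i -> i+1 and backward edge i+1 -> i.
        sources += [i, i + 1]
        targets += [i + 1, i]
    return [sources, targets]


edge_index = path_graph_edges(138)
print(len(edge_index[0]))
```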

First, going back to the original `xarray.DataArray` structure, let's rechunk the data for each feature.

In [None]:
data = {}
for feat in features:
    array = xr_array[feat].data
    array = da.rechunk(array, chunks=(shard_size, *array.shape[1:]))
    data.update({feat: array})

In [None]:
def broadcast_features(array: da.Array):
    a = da.repeat(array, 138, axis=-1)
    a = da.moveaxis(a, -2, -1)
    return a

def pad_tensor(array: da.Array):
    a = da.pad(array, ((0, 0), (1, 1), (0, 0)))
    return a
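
The effect of these two helpers on array shapes is easier to see on toy data. The sketch below uses NumPy equivalents of the same calls with a made-up node count of 4 instead of 138. The `_np` function names and the input shapes are assumptions for illustration.

```python
import numpy as np

num_nodes = 4  # toy stand-in for the 138 nodes


def broadcast_features_np(array, num_nodes):
    # (rows, features, 1) -> (rows, features, num_nodes) -> (rows, num_nodes, features)
    a = np.repeat(array, num_nodes, axis=-1)
    return np.moveaxis(a, -2, -1)


def pad_tensor_np(array):
    # Add one zero-filled entry before and after along the second axis.
    return np.pad(array, ((0, 0), (1, 1), (0, 0)))


sca = np.ones((2, 3))                    # 2 rows, 3 scalar features
inter = np.ones((2, num_nodes - 2, 1))   # two fewer entries than nodes

sca_nodes = broadcast_features_np(sca[..., np.newaxis], num_nodes)
inter_nodes = pad_tensor_np(inter)

print(sca_nodes.shape, inter_nodes.shape)
```

Every node receives a copy of the per-row scalar features, and the padded array gains a zero entry at each end so that it lines up with the node axis.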

We'll push `hl_inputs`, `inter_inputs` and `sca_inputs` to the nodes, and `col_inputs` to the edge. There's a bit of feature engineering.

In [None]:
x = da.concatenate([
    data['hl_inputs'],
    pad_tensor(data['inter_inputs']),
    broadcast_features(data['sca_inputs'][..., np.newaxis])
], axis=-1)

y = da.concatenate([
    data['flux_dn_sw'][..., np.newaxis],
    data['flux_up_sw'][..., np.newaxis],
    data['flux_dn_lw'][..., np.newaxis],
    data['flux_up_lw'][..., np.newaxis],
], axis=-1)

edge = data['col_inputs']

print(f"x of shape: {x.shape}")
print(f"y of shape: {y.shape}")
print(f"edge of shape: {edge.shape}")

In [None]:
x

As with the flattened data, we need to rechunk the data along the first axis only.

In [None]:
x_ = da.rechunk(x, chunks=(shard_size, *x.shape[1:]))
y_ = da.rechunk(y, chunks=(shard_size, *y.shape[1:]))
edge_ = da.rechunk(edge, chunks=(shard_size, *edge.shape[1:]))

### To a single HDF5 file

Saving to a single HDF5 file will allow us to experiment later with a toy-example in MPI to shard the file differently. 

Create a dataset for `x`, `y`, and `edge`.

In [None]:
out_file = osp.join(PROCESSED_DATA_PATH, f'features-{step}.h5')
if osp.isfile(out_file):
    os.remove(out_file)
    
x_.to_hdf5(out_file, '/x')
y_.to_hdf5(out_file, '/y')
edge_.to_hdf5(out_file, '/edge')

### To a stack of NumPy files

Finally, we also want to shard the data to `.npy` files.

In [None]:
out_dir = osp.join(PROCESSED_DATA_PATH, f'features-{step}')

x_path, y_path, edge_path = purgedirs([
    osp.join(out_dir, 'x'), 
    osp.join(out_dir, 'y'), 
    osp.join(out_dir, 'edge')
])
    
# Save the rechunked arrays so the shards match the saved metadata
da.to_npy_stack(x_path, x_, axis=0)
da.to_npy_stack(y_path, y_, axis=0)
da.to_npy_stack(edge_path, edge_, axis=0)

And save the metadata for later use.

In [None]:
metadata_features = {
    "dtype": x_.dtype.name,
    "dataset_len": len(x_),
    "num_shards": len(x_.chunks[0]),
    "x_shape": x_.chunksize,
    "y_shape": y_.chunksize,
    "edge_shape": edge_.chunksize,
}
pprint(metadata_features)

save_metadata(
    step, 
    metadata_features, 
    'features'
)