# Using DALI with JAX

This simple example shows how to train a neural network implemented in JAX with DALI pipelines. It builds on the MNIST training example from the JAX codebase, which can be found [here](https://github.com/google/jax/blob/jax-v0.4.13/examples/mnist_classifier_fromscratch.py).

We will use MNIST in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra).

In [1]:
import os

training_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
validation_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/testing/')

The first step is to create a pipeline definition function that will later be used to create instances of DALI pipelines. It defines all the preprocessing steps. In this simple example we use `fn.readers.caffe2` to read data in Caffe2 format, `fn.decoders.image` to decode the images, `fn.crop_mirror_normalize` to normalize them, and `fn.reshape` to adjust the shape of the output tensors. We also move the labels from CPU to GPU memory with `labels.gpu()` and apply one-hot encoding to them with `fn.one_hot` for training.
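To make the effect of these steps concrete, here is a minimal NumPy sketch of what happens to one sample. It is not DALI code: the synthetic image, the `label` value, and all variable names are assumptions made for illustration, and it mirrors only the normalization (division by a standard deviation of 255), the layout change, the flattening, and the one-hot encoding.

```python
import numpy as np

image_size = 28
num_classes = 10

# Synthetic stand-in for a decoded grayscale image in HWC layout (values 0..255).
rng = np.random.default_rng(0)
decoded = rng.integers(0, 256, size=(image_size, image_size, 1), dtype=np.uint8)

# Normalize with mean 0 and std 255, then move channels first (HWC -> CHW).
normalized = (decoded.astype(np.float32) / 255.0).transpose(2, 0, 1)

# Flatten the image to a vector of 28 * 28 = 784 values.
flat = normalized.reshape(image_size * image_size)

# One-hot encode a hypothetical label.
label = 3
one_hot = np.zeros(num_classes, dtype=np.float32)
one_hot[label] = 1.0

print(flat.shape, one_hot)
```

After these steps each image is a flat vector with values in the range 0 to 1, and each training label is a vector of length `num_classes`.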

This example focuses on how to use a DALI pipeline with JAX. For more information on DALI pipelines, see [Getting started](../../getting_started.ipynb) and the [pipeline documentation](../../../pipeline.rst).

The next step is to instantiate the pipelines and build them. Building creates and initializes the pipeline internals.

For a DALI pipeline to work with JAX, it needs to be wrapped with an appropriate DALI iterator. To get an iterator compatible with JAX, we import it from the DALI JAX plugin. In addition to the pipeline, we can pass the `output_map`, `reader_name` and `auto_reset` parameters to the iterator.

**Here is a quick explanation of how these parameters work:**

 - `output_map`: iterators return a dictionary with the pipeline outputs as its values. The keys of this dictionary are defined by `output_map`. For example, the `labels` output returned from the DALI pipeline defined above is accessible as `iterator_output['labels']`,
 - `reader_name`: setting this parameter introduces the notion of an epoch to the iterator. A DALI pipeline itself is infinite: it returns data indefinitely, wrapping around the dataset. DALI readers (such as `fn.readers.caffe2` used in this example) know the size of the dataset. To pass this information to the iterator, we point it to the operator that should be queried for the dataset size. We do this by naming the operator (note `name="mnist_caffe2_reader"`) and passing the same name as the value of the `reader_name` argument,
 - `auto_reset`: this argument controls the behaviour of the iterator after the end of an epoch. If set to `True`, the iterator automatically resets its state and prepares to start the next epoch.
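The interplay between an endless data source, an epoch size, and `auto_reset` can be illustrated with a toy model in plain Python. This is only a sketch of the semantics described above, not DALI's implementation: `ToyEpochIterator`, `endless_batches`, and the three-batch dataset are hypothetical names and values chosen for illustration.

```python
import itertools


class ToyEpochIterator:
    """Toy stand-in for an epoch-aware iterator (not DALI's implementation)."""

    def __init__(self, source, epoch_size, auto_reset=False):
        self._source = source          # endless stream of batches
        self._epoch_size = epoch_size  # number of batches in one epoch
        self._auto_reset = auto_reset
        self._served = 0               # batches returned in the current epoch

    def __iter__(self):
        return self

    def __len__(self):
        return self._epoch_size

    def __next__(self):
        if self._served >= self._epoch_size:
            if self._auto_reset:
                self.reset()           # get ready for the next epoch
            raise StopIteration
        self._served += 1
        return next(self._source)

    def reset(self):
        self._served = 0


# An endless "pipeline" that wraps around a dataset of three batches.
endless_batches = itertools.cycle([
    {"images": "batch0", "labels": 0},
    {"images": "batch1", "labels": 1},
    {"images": "batch2", "labels": 2},
])

toy_iterator = ToyEpochIterator(endless_batches, epoch_size=3, auto_reset=True)
first_epoch = [batch["labels"] for batch in toy_iterator]
second_epoch = [batch["labels"] for batch in toy_iterator]
print(first_epoch, second_epoch)
```

In this toy model, an iterator created with `auto_reset=False` yields nothing in a second loop until `reset()` is called explicitly.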

### Multiple GPUs

This section shows how to train the network with multiple GPUs.


We start by modifying the pipeline definition function.

Note the new arguments passed to `fn.readers.caffe2`: `num_shards` and `shard_id`. They are used to control sharding:
 - `num_shards` sets the total number of shards,
 - `shard_id` tells the pipeline which shard of the training data it is responsible for.

Also, the `device_id` argument was removed from the decorator. Since we want these pipelines to run on different GPUs, we pass a particular `device_id` when creating each pipeline. Most often `device_id` and `shard_id` have the same value, but this is not a requirement.

If you want to learn more about DALI sharding behaviour, see the [DALI sharding docs page](../../general/getting_started.ipynb).
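As a rough illustration of what `num_shards` and `shard_id` mean, the sketch below splits a dataset into contiguous, nearly equal index ranges. The helper `shard_range` is hypothetical, and the contiguous split is an assumption made for illustration; the sharding documentation is the authority on the exact behaviour, including how shuffling interacts with shards.

```python
def shard_range(dataset_size, num_shards, shard_id):
    """Return the [start, end) sample indices assigned to one shard."""
    start = dataset_size * shard_id // num_shards
    end = dataset_size * (shard_id + 1) // num_shards
    return start, end


mnist_training_size = 60000  # standard MNIST training set size
num_shards = 2               # for example, one shard per GPU

for shard_id in range(num_shards):
    start, end = shard_range(mnist_training_size, num_shards, shard_id)
    print(f"shard {shard_id}: samples [{start}, {end})")
```

Every sample belongs to exactly one shard, so pipelines with different `shard_id` values read disjoint parts of the dataset.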

In [2]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import jax


batch_size = 200
image_size = 28
num_classes = 10
batch_size_per_gpu = batch_size // jax.device_count()


@pipeline_def(batch_size=batch_size_per_gpu, num_threads=4)
def mnist_sharded_pipeline(data_path, random_shuffle, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id)
    images = fn.decoders.image(
        jpegs, device='mixed', output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.], output_layout="CHW")
    images = fn.reshape(images, shape=[image_size * image_size])

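    # Move the labels from CPU to GPU memory.
    labels = labels.gpu()

    # In this example random_shuffle is True only for the training pipeline,
    # so only the training labels are one-hot encoded.
    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)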

    return images, labels

Next, we create the pipelines. Note the `device_id` values passed to place each pipeline on a different device.

In [3]:
from nvidia.dali.plugin import jax as dax

print('Creating training pipelines')

pipelines = []
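# Create one pipeline per device; the shard and the device share the same index.
for device_index in range(jax.device_count()):
    pipeline = mnist_sharded_pipeline(
        data_path=training_data_path,
        random_shuffle=True,
        num_shards=jax.device_count(),
        shard_id=device_index,
        device_id=device_index)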
    print(f'Pipeline {pipeline} working on device {pipeline.device_id}')
    pipelines.append(pipeline)

print('Creating training iterator')

training_iterator = dax.DALIGenericIterator(
    pipelines,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

print(f"Number of batches in training iterator = {len(training_iterator)}")

Creating training pipelines
Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7fdbeb5638e0> working on device 0
Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7fdbeb45c1f0> working on device 1
Creating training iterator
Number of batches in training iterator = 300


In [4]:
print('Creating validation iterator')
validation_pipeline = mnist_sharded_pipeline(data_path=validation_data_path, random_shuffle=False, num_shards=1, shard_id=0, device_id=0)

validation_iterator = dax.DALIGenericIterator(
    validation_pipeline,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

print(f"Number of batches in validation iterator = {len(validation_iterator)}")

Creating validation iterator
Number of batches in validation iterator = 100
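The batch counts reported for the two iterators are consistent with simple arithmetic on the dataset sizes. The sketch below assumes the standard MNIST split sizes (60000 training and 10000 test samples), two devices as in the training output above, and that a partial last batch would round the count up. The helper `batches_per_epoch` is hypothetical.

```python
import math


def batches_per_epoch(dataset_size, global_batch_size):
    """Number of iterations needed to see the whole dataset once."""
    return math.ceil(dataset_size / global_batch_size)


num_devices = 2                                 # two pipelines in the run above
batch_size = 200
batch_size_per_gpu = batch_size // num_devices  # 100 samples per pipeline

# Training: each iteration consumes one batch from every pipeline.
training_batches = batches_per_epoch(60000, batch_size_per_gpu * num_devices)

# Validation: a single pipeline with the per-GPU batch size.
validation_batches = batches_per_epoch(10000, batch_size_per_gpu)

print(training_batches, validation_batches)
```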


In [5]:
import jax.numpy as jnp
from model import init_model, update_parallel, accuracy


model = init_model()
# Replicate every parameter along a new leading axis, one copy per device.
model = jax.tree_util.tree_map(lambda x: jnp.array([x] * jax.device_count()), model)
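The replication step gives every parameter a new leading axis with one entry per device. The following sketch shows the resulting shapes on a small, hypothetical parameter tree; the real structure comes from `init_model()` in `model.py`, which is not shown here, so the layer sizes below are assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer parameter tree: a list of (weights, bias) pairs.
params = [
    (jnp.zeros((784, 32)), jnp.zeros((32,))),
    (jnp.zeros((32, 10)), jnp.zeros((10,))),
]

num_devices = jax.device_count()

# Stack one copy of each leaf per device along a new leading axis.
replicated = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * num_devices), params)

# Taking index 0 along that axis recovers a single copy of the parameters.
single = jax.tree_util.tree_map(lambda x: x[0], replicated)

print(replicated[0][0].shape)  # (num_devices, 784, 32)
print(single[0][0].shape)      # (784, 32)
```

Indexing with `x[0]` is the same pattern used later to pass a single copy of the model to `accuracy`.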

In [6]:
from model import update_parallel

num_epochs = 10

for epoch in range(num_epochs):
    for it, batch in enumerate(training_iterator):
        model = update_parallel(model, batch)
        
    # Evaluate with the copy of the parameters held for the first device.
    test_acc = accuracy(jax.tree_util.tree_map(lambda x: x[0], model), validation_iterator)
    
    print(f"Epoch {epoch} sec")
    print(f"Test set accuracy {test_acc}")

Epoch 0 sec
Test set accuracy 0.6746000051498413
Epoch 1 sec
Test set accuracy 0.7845000624656677
Epoch 2 sec
Test set accuracy 0.8244000673294067
Epoch 3 sec
Test set accuracy 0.8450000286102295
Epoch 4 sec
Test set accuracy 0.8600000143051147
Epoch 5 sec
Test set accuracy 0.8705000281333923
Epoch 6 sec
Test set accuracy 0.878000020980835
Epoch 7 sec
Test set accuracy 0.8832000494003296
Epoch 8 sec
Test set accuracy 0.8874000310897827
Epoch 9 sec
Test set accuracy 0.8919000625610352
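The implementation of `update_parallel` lives in `model.py` and is not shown in this example. As a rough idea of what a data-parallel update step can look like in JAX, here is a minimal sketch built on `jax.pmap` and `jax.lax.pmean`. Everything in it is an assumption made for illustration: the linear classifier, the learning rate of 0.1, the synthetic batch, and the names `predict`, `loss`, and `toy_update_parallel`. It also takes images and labels as separate arrays with a leading device axis, which may differ from the batch format the real function expects.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.nn import log_softmax


def predict(params, images):
    """Hypothetical linear classifier returning log-probabilities."""
    weights, bias = params
    return log_softmax(images @ weights + bias, axis=1)


def loss(params, images, labels):
    """Cross-entropy between predictions and one-hot labels."""
    return -jnp.mean(jnp.sum(predict(params, images) * labels, axis=1))


@partial(jax.pmap, axis_name="devices")
def toy_update_parallel(params, images, labels):
    grads = jax.grad(loss)(params, images, labels)
    # Average gradients across devices so every replica applies the same step.
    grads = lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


num_devices = jax.device_count()
per_device_batch = 4

# Replicated parameters and a synthetic batch, both with a leading device axis.
params = (jnp.zeros((num_devices, 784, 10)), jnp.zeros((num_devices, 10)))
images = jnp.ones((num_devices, per_device_batch, 784))
labels = jnp.tile(jnp.eye(10)[:per_device_batch], (num_devices, 1, 1))

params = toy_update_parallel(params, images, labels)
print(params[0].shape, params[1].shape)
```

Because the gradients are averaged with `pmean` before the update, all replicas stay identical after each step, which is why reading the parameters of the first device is sufficient for evaluation.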
