# Using DALI with JAX

### Overview

This simple example shows how to train a neural network implemented in JAX with DALI pipelines. It builds on the MNIST training example from the JAX codebase, which can be found [here](https://github.com/google/jax/blob/main/examples/mnist_classifier_fromscratch.py).

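We will use MNIST in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra). The paths below assume that the `DALI_EXTRA_PATH` environment variable points to a local copy of that repository.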

```python
import os

training_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
validation_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/testing/')
```
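If `DALI_EXTRA_PATH` is not set, the cell above fails with a bare `KeyError`. A small helper can report the problem more clearly. `dali_extra_path` below is hypothetical and not part of DALI. It is just a convenience sketch that joins path components onto the `DALI_EXTRA_PATH` root.

```python
import os


def dali_extra_path(*parts: str) -> str:
    """Join `parts` onto the DALI_extra checkout given by DALI_EXTRA_PATH.

    Hypothetical helper: raises a descriptive error instead of a bare KeyError
    when the environment variable is missing or empty.
    """
    root = os.environ.get('DALI_EXTRA_PATH')
    if not root:
        raise RuntimeError(
            'Set DALI_EXTRA_PATH to the root of a local DALI_extra checkout, '
            'e.g. a clone of https://github.com/NVIDIA/DALI_extra')
    return os.path.join(root, *parts)


# Usage, equivalent to the paths defined in the cell above:
# training_data_path = dali_extra_path('db/MNIST/training/')
# validation_data_path = dali_extra_path('db/MNIST/testing/')
```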

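The first step is to define a function that will later be used to create instances of DALI pipelines. The `@pipeline_def` decorator turns it into a pipeline factory. Each call returns a new pipeline configured with the decorator's arguments (device, batch size, number of threads) and with the function's own arguments.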

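```python
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types


batch_size = 100
image_size = 28
num_classes = 10

@pipeline_def(device_id=0, batch_size=batch_size, num_threads=4)
def mnist_pipeline(data_path, random_shuffle):
    # Read encoded images and labels from the Caffe2 database. The reader is
    # named so that the iterator can later query it for the epoch size.
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader")
    # Decode with the "mixed" backend (CPU input, GPU output) to grayscale.
    images = fn.decoders.image(
        jpegs, device='mixed', output_type=types.GRAY)
    # Convert to float, divide pixel values by 255 and switch to CHW layout.
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.], output_layout="CHW")
    # Flatten each image into a 784-element vector for the fully connected network.
    images = fn.reshape(images, shape=[image_size * image_size])

    # Move the labels to the GPU so that both outputs live on the same device.
    labels = labels.gpu()

    # Only the training pipeline (the one with shuffling enabled) one-hot encodes
    # its labels; the validation pipeline returns integer class labels.
    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels
```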
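To make the per-sample transformations explicit, the NumPy sketch below reproduces on the CPU what the pipeline does to one decoded image and one label. It assumes that the decoded grayscale image arrives as a 28x28x1 `uint8` array in HWC layout. It also assumes that `crop_mirror_normalize` uses its default mean of 0, so `std=[255.]` simply divides by 255. The functions `preprocess_image` and `one_hot` are hypothetical illustrations, not DALI APIs.

```python
import numpy as np

IMAGE_SIZE = 28
NUM_CLASSES = 10


def preprocess_image(gray_hwc: np.ndarray) -> np.ndarray:
    """CPU sketch of crop_mirror_normalize(std=[255.], CHW) followed by reshape."""
    # HWC (28, 28, 1) -> CHW (1, 28, 28), mirroring output_layout="CHW".
    chw = np.transpose(gray_hwc, (2, 0, 1))
    # (x - mean) / std with mean 0 and std 255 maps uint8 pixels to [0, 1].
    normalized = chw.astype(np.float32) / 255.0
    # Flatten to a 784-element vector, like fn.reshape(shape=[28 * 28]).
    return normalized.reshape(IMAGE_SIZE * IMAGE_SIZE)


def one_hot(label: int, num_classes: int = NUM_CLASSES) -> np.ndarray:
    """CPU sketch of one-hot encoding a single integer label."""
    encoded = np.zeros(num_classes, dtype=np.float32)
    encoded[label] = 1.0
    return encoded


# A white image becomes a vector of ones; label 3 becomes a one-hot vector.
image = np.full((IMAGE_SIZE, IMAGE_SIZE, 1), 255, dtype=np.uint8)
features = preprocess_image(image)
print(features.shape, features.max())  # (784,) 1.0
print(one_hot(3))
```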

The next step is to instantiate the pipelines and build them. Building creates and initializes the pipeline internals.

```python
print('Creating pipelines')
training_pipeline = mnist_pipeline(data_path=training_data_path, random_shuffle=True)
validation_pipeline = mnist_pipeline(data_path=validation_data_path, random_shuffle=False)

print('Building pipelines')
training_pipeline.build()
validation_pipeline.build()

print(training_pipeline)
print(validation_pipeline)
```

```
Creating pipelines
Building pipelines
<nvidia.dali.pipeline.Pipeline object at 0x7f34081594e0>
<nvidia.dali.pipeline.Pipeline object at 0x7f340815b010>
```


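For a DALI pipeline to work with JAX, it needs to be wrapped in an appropriate DALI iterator. A JAX-compatible iterator, `DALIGenericIterator`, is provided by the DALI JAX plugin. The `reader_name` argument lets the iterator query the named reader for the epoch size. Setting `auto_reset=True` makes the iterator reset itself at the end of each epoch, so it can be iterated over again in the next one.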

```python
from nvidia.dali.plugin import jax as dax


print('Creating iterators')
training_iterator = dax.DALIGenericIterator(
    training_pipeline,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

validation_iterator = dax.DALIGenericIterator(
    validation_pipeline,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

print(training_iterator)
print(f"Training iterator size = {training_iterator.size}")
print(f"Validation iterator size = {validation_iterator.size}")
```

```
Creating iterators
<nvidia.dali.plugin.jax.DALIGenericIterator object at 0x7f340815bcd0>
Training iterator size = 60000
Validation iterator size = 10000
```
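The printed sizes translate directly into epoch lengths. With `batch_size = 100`, one pass over the 60000 training samples takes 600 iterations, and one pass over the 10000 validation samples takes 100. To experiment with the training loop on a machine without a GPU or DALI, you could use a stand-in that mimics only the parts of the interface this notebook relies on:

- a `size` attribute;
- iteration over dictionaries keyed by the `output_map` names;
- restarting on every new `for` loop.

`FakeMNISTIterator` below is a hypothetical sketch that yields random data. It assumes that batches are dictionaries of arrays. It is not part of DALI, and its handling of a partial last batch is not meant to match DALI's.

```python
import math

import numpy as np


class FakeMNISTIterator:
    """Hypothetical CPU stand-in for the DALI iterator, yielding random batches."""

    def __init__(self, size: int, batch_size: int = 100, one_hot: bool = True,
                 num_classes: int = 10, seed: int = 0):
        self.size = size
        self.batch_size = batch_size
        self.one_hot = one_hot
        self.num_classes = num_classes
        self._rng = np.random.default_rng(seed)

    def __len__(self) -> int:
        # Number of batches per epoch, counting a smaller last batch if any.
        return math.ceil(self.size / self.batch_size)

    def __iter__(self):
        # A fresh generator per `for` loop plays the role of auto_reset=True.
        for start in range(0, self.size, self.batch_size):
            n = min(self.batch_size, self.size - start)
            images = self._rng.random((n, 28 * 28), dtype=np.float32)
            labels = self._rng.integers(0, self.num_classes, size=n)
            if self.one_hot:
                labels = np.eye(self.num_classes, dtype=np.float32)[labels]
            yield {"images": images, "labels": labels}


train_iter = FakeMNISTIterator(size=60000)
print(len(train_iter))  # 600
```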


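The DALI iterators are now ready for training. The model code (`init_random_params`, `update`, and `accuracy`) is imported from a local `model` module.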

```python
from model import init_random_params, update, accuracy
```
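The `model` module is not shown here. Suppose it follows the JAX example this notebook builds on: a fully connected network with `tanh` hidden layers and a log-softmax output. A NumPy sketch of `init_random_params`, a forward pass, and an iterator-based `accuracy` could then look like the code below. It is illustrative only. The real module uses JAX, and several details here are assumptions:

- the `predict` helper;
- the fixed seed;
- accepting both one-hot and integer labels.

```python
import numpy as np


def init_random_params(scale, layer_sizes, seed=0):
    """One (weights, biases) pair per layer, drawn from a scaled normal distribution."""
    rng = np.random.default_rng(seed)
    return [(scale * rng.standard_normal((m, n)), scale * rng.standard_normal(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]


def predict(params, inputs):
    """Forward pass: tanh hidden layers, then log-softmax over the output logits."""
    activations = inputs
    for w, b in params[:-1]:
        activations = np.tanh(activations @ w + b)
    final_w, final_b = params[-1]
    logits = activations @ final_w + final_b
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def accuracy(params, iterator):
    """Fraction of correctly classified samples over one pass of `iterator`."""
    correct, total = 0, 0
    for batch in iterator:
        labels = np.asarray(batch["labels"])
        if labels.ndim == 2 and labels.shape[1] > 1:
            target = labels.argmax(axis=1)  # one-hot rows
        else:
            target = labels.reshape(-1)     # integer class ids, shape (B,) or (B, 1)
        predicted = predict(params, np.asarray(batch["images"])).argmax(axis=1)
        correct += int((predicted == target).sum())
        total += target.shape[0]
    return correct / total
```

Because `accuracy` only compares class indices, it works with the validation pipeline's integer labels and does not need one-hot targets.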

```python
print('Starting training')
layer_sizes = [784, 1024, 1024, 10]  # 784 = image_size * image_size inputs, 10 classes
param_scale = 0.1
step_size = 0.001
num_epochs = 10

params = init_random_params(param_scale, layer_sizes)

for epoch in range(num_epochs):
    # One pass over the training set; auto_reset=True restarts the iterator afterwards.
    for batch in training_iterator:
        params = update(params, batch, step_size)

    test_acc = accuracy(params, validation_iterator)
    print(f"Epoch {epoch}")
    print(f"Test set accuracy {test_acc}")
```

```
Starting training
Epoch 0
Test set accuracy 0.7761000394821167
Epoch 1
Test set accuracy 0.835800051689148
Epoch 2
Test set accuracy 0.8623000383377075
Epoch 3
Test set accuracy 0.8778000473976135
Epoch 4
Test set accuracy 0.8875000476837158
Epoch 5
Test set accuracy 0.896600067615509
Epoch 6
Test set accuracy 0.9018000364303589
Epoch 7
Test set accuracy 0.9058000445365906
Epoch 8
Test set accuracy 0.9095000624656677
Epoch 9
Test set accuracy 0.913800060749054
```
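Each call to `update(params, batch, step_size)` performs one step of gradient descent on a mini-batch. This notebook builds on a JAX example that minimizes the mean cross-entropy between the network's log-softmax outputs and the one-hot training labels. Assuming `update` does the same, one step can be written as follows. Here B = 100 is the batch size, eta is `step_size` (0.001 here), theta stands for all weights and biases, and y is the one-hot label:

```latex
\mathcal{L}(\theta) = -\frac{1}{B} \sum_{i=1}^{B} \sum_{k=1}^{10} y_{ik} \log p_{\theta}(k \mid x_i),
\qquad
\theta \leftarrow \theta - \eta \, \nabla_{\theta} \mathcal{L}(\theta)
```

Each epoch contains 600 such steps (60000 training samples in batches of 100), so the 10 epochs above amount to 6000 parameter updates.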
