# Training a neural network with DALI and Pax

This simple example shows how to train a neural network implemented in JAX with DALI pipelines. It builds on the MNIST training example from the Pax codebase, which can be found [here](https://github.com/google/paxml/blob/paxml-v1.1.0/paxml/tasks/vision/params/mnist.py).

We will use the MNIST dataset in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra).

This example focuses on how to use a DALI pipeline with Pax. For more information on DALI pipelines, see [Getting started](../../getting_started.ipynb) and the [pipeline documentation](../../../pipeline.rst).

To train the model, run the `dali_pax_example.MnistExperiment` experiment with `paxml.main`, writing the logs to `/tmp/dali_pax_logs`:

```shell
python -m paxml.main --job_log_dir=/tmp/dali_pax_logs --exp dali_pax_example.MnistExperiment
```

```
I0925 23:48:12.906274 139756425134720 py_utils.py:1015] [PAX STATUS]: E2E time: Starting timer for <_main> @ <.../paxml/main.py:445>
I0925 23:48:12.906384 139756425134720 main.py:450] [PAX STATUS]: Program start.
I0925 23:48:12.906458 139756425134720 py_utils.py:1015] [PAX STATUS]: Starting timer for <setup_jax> @ <.../paxml/main.py:465>
I0925 23:48:12.972933 139756425134720 xla_bridge.py:622] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter
I0925 23:48:12.973237 139756425134720 xla_bridge.py:622] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0925 23:48:12.973349 139756425134720 setup_jax.py:78] JAX process: 0 / 1
I0925 23:48:12.973390 139756425134720 setup_jax.py:79] JAX devices: [gpu(id=0)]
I0925 23:48:12.973586 139756425134720 setup_jax.py:80] jax.device_count(): 1
I0925 23:48:12.973684 139756425134720 setup_jax.py:81] jax.local_d
```

Next, we create a helper function that prints the training accuracy recorded in the logs.

```python
import os

from tensorflow.core.util import event_pb2
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import tf_record


def print_logs(path):
    """Print the training accuracy stored in an event file written by the Pax example."""

    def summary_iterator():
        # Each record in the event file is a serialized Event protobuf.
        for r in tf_record.tf_record_iterator(path):
            yield event_pb2.Event.FromString(r)

    for summary in summary_iterator():
        for value in summary.summary.value:
            # Keep only the summaries logged under the accuracy tag.
            if value.tag == 'Metrics/accuracy':
                # The metric is stored as a TensorProto; convert it to a NumPy array.
                t = tensor_util.MakeNdarray(value.tensor)
                print(f"Iteration: {summary.step}, accuracy: {t}")
```
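To make it clearer what `tf_record.tf_record_iterator` does with an event file, here is a stdlib-only sketch that splits a TFRecord stream into raw records. It assumes the standard TFRecord framing (a little-endian `uint64` length, a `uint32` masked CRC32C of the length, the payload, then a `uint32` masked CRC32C of the payload) and deliberately skips CRC validation, so it is a teaching aid rather than a replacement for TensorFlow's reader. The names `iter_tfrecord_payloads` and `_frame` are hypothetical helpers introduced only for this illustration.

```python
import io
import struct
from typing import BinaryIO, Iterator


def iter_tfrecord_payloads(stream: BinaryIO) -> Iterator[bytes]:
    """Yield the raw payload of each record in a TFRecord stream.

    Assumed framing, repeated until end of file:
      uint64 length (little-endian)
      uint32 masked CRC32C of the length bytes
      byte   data[length]
      uint32 masked CRC32C of the data
    The CRC fields are read but NOT verified (CRC32C is not in the standard
    library), so corrupted files are not detected, unlike TensorFlow's reader.
    """
    while True:
        header = stream.read(12)  # 8-byte length + 4-byte CRC of the length
        if not header:
            return  # clean end of file
        if len(header) < 12:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header[:8])
        payload = stream.read(length)
        footer = stream.read(4)  # masked CRC32C of the payload (ignored)
        if len(payload) < length or len(footer) < 4:
            raise ValueError("truncated record body")
        yield payload


def _frame(payload: bytes) -> bytes:
    """Hypothetical test helper: build one record with zeroed CRC fields."""
    return struct.pack("<Q", len(payload)) + b"\0" * 4 + payload + b"\0" * 4


# In-memory usage; with a real event file, open it in "rb" mode and pass
# each payload to event_pb2.Event.FromString, as print_logs does.
buf = io.BytesIO(_frame(b"first") + _frame(b"") + _frame(b"third"))
records = list(iter_tfrecord_payloads(buf))
print(records)
```

Decoding each payload into an `Event` still requires the protobuf definitions from TensorFlow, which is why the helper above imports `event_pb2`.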

With the `print_logs` helper, we can print the training accuracy:

```python
log_dir = '/tmp/dali_pax_logs/summaries/train/'
for file in os.listdir(log_dir):
    print_logs(os.path.join(log_dir, file))
```

```
Iteration: 100, accuracy: 0.4111328125
Iteration: 200, accuracy: 0.5380859375
Iteration: 300, accuracy: 0.7451171875
Iteration: 400, accuracy: 0.8232421875
Iteration: 500, accuracy: 0.8857421875
Iteration: 600, accuracy: 0.888671875
Iteration: 700, accuracy: 0.888671875
Iteration: 800, accuracy: 0.8984375
Iteration: 900, accuracy: 0.9150390625
Iteration: 1000, accuracy: 0.904296875
```
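`print_logs` prints values as it reads them, and `os.listdir` returns files in arbitrary order, so if the log directory holds several event files (for example, after a restarted run) the printed lines may not be sorted by step. A minimal sketch, assuming you adapt `print_logs` to collect `(summary.step, float(t))` pairs instead of printing them, is to sort the pairs afterwards and report both the final and the best accuracy. `summarize_accuracy` is a hypothetical helper, not part of Pax or TensorBoard.

```python
from typing import Iterable, List, Tuple

Point = Tuple[int, float]  # (training step, accuracy)


def summarize_accuracy(points: Iterable[Point]) -> Tuple[List[Point], Point, Point]:
    """Sort (step, accuracy) pairs by step and return (ordered, final, best)."""
    # Points gathered from several event files may arrive out of step order.
    ordered = sorted(points, key=lambda p: p[0])
    if not ordered:
        raise ValueError("no accuracy points found")
    final = ordered[-1]  # accuracy at the last logged step
    best = max(ordered, key=lambda p: p[1])  # earliest point with the highest accuracy
    return ordered, final, best


# A few values from the training output above, shuffled on purpose.
points = [(1000, 0.904296875), (100, 0.4111328125), (900, 0.9150390625), (500, 0.8857421875)]
ordered, final, best = summarize_accuracy(points)
print(final, best)
```

In this run, the accuracy at the last logged step (0.904296875 at iteration 1000) is slightly below the best one (0.9150390625 at iteration 900), so tracking both gives a more complete picture than reading only the last line.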
