In [1]:
# pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # For NVIDIA GPU
# pip install flax optax
# pip install tensorflow-datasets==4.9.3
# pip install tfds-nightly


In [2]:
#!pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# For Google Colab TPU runtimes:
#!pip install -U "jax[tpu]"


In [3]:
!pip install flax optax



In [4]:
!pip install tensorflow-datasets==4.9.3

!pip install tfds-nightly



In [5]:
!pip install tensorflow



In [6]:
import tensorflow as tf
tf.__version__

'2.18.0'

In [7]:
from flax import linen as nn
import numpy as np

class SimpleCNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages and a dense head."""

    @nn.compact
    def __call__(self, x):
        """Map a batch of images to 10 unnormalised class scores.

        Args:
            x: Image batch with the batch on axis 0; this notebook feeds
                arrays shaped like ``(N, 28, 28, 1)``.

        Returns:
            Array with one row per input example and 10 columns (logits).
        """
        # Conv -> ReLU -> 2x2 average pool, first with 32 and then 64 features.
        # Layers are created in the same order as before, so the automatically
        # assigned submodule names are unaffected.
        for channels in (32, 64):
            x = nn.Conv(features=channels, kernel_size=(3, 3))(x)
            x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))

        # Collapse everything but the batch axis before the dense head.
        batch = x.shape[0]
        hidden = nn.relu(nn.Dense(features=256)(x.reshape((batch, -1))))

        # Final projection to the 10 output classes (e.g., for MNIST).
        return nn.Dense(features=10)(hidden)

In [8]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0


In [9]:
!nvidia-smi

Wed Jul 30 17:40:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   59C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [10]:
import jax
## Should show GPU if setup correctly
print(jax.devices())

[CudaDevice(id=0)]


In [11]:
## Create keys for randomness
## JAX's random functions are deterministic: the same key always yields the same values.
# JAX therefore needs keys to be created explicitly and passed around for any random operation, such as weight initialization.
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


key = jax.random.PRNGKey(0)

In [12]:
# Instantiate the network definition; its parameters are produced by `init` below.
train_dsmodel = SimpleCNN()
# Dummy input for an MNIST image (batch size 1, 28x28 pixels, 1 channel).
# Its values are all ones; it is only handed to `init` together with the PRNG key.
dummy_input = jnp.ones([1, 28, 28, 1])
# `init` returns a dict of variables; only its 'params' entry is kept.
params = train_dsmodel.init(key, dummy_input)['params']

In [13]:
# Define the optimizer
tx = optax.adam(learning_rate=1e-3)

# Create the training state
state = train_state.TrainState.create(
    apply_fn=train_dsmodel.apply,
    params=params,
    tx=tx,
)

In [14]:
# Just in time compilation. Compile this whole function for performance
# Just-in-time compile the whole step for performance.
@jax.jit
def train_step(state, batch):
    """Run one optimisation step on a single batch.

    Args:
        state: Training state exposing ``apply_fn``, ``params`` and
            ``apply_gradients`` (created with ``TrainState.create`` in this
            notebook).
        batch: ``(image, label)`` pair; labels are integer class ids.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean softmax cross-entropy
        of the batch evaluated with the parameters *before* the update.
    """
    images, labels = batch

    def compute_loss(params):
        # Forward pass followed by the per-example cross-entropy, averaged.
        predictions = state.apply_fn({'params': params}, images)
        per_example = optax.softmax_cross_entropy_with_integer_labels(
            logits=predictions, labels=labels
        )
        return per_example.mean(), predictions

    # has_aux=True: the function returns (loss, aux); only the loss is differentiated.
    differentiate = jax.value_and_grad(compute_loss, has_aux=True)
    (loss_value, _), gradients = differentiate(state.params)
    return state.apply_gradients(grads=gradients), loss_value

In [15]:
@jax.jit
def eval_step(params, batch):
    """Compute classification accuracy on a single batch.

    Args:
        params: Parameters for ``SimpleCNN``.
        batch: ``(image, label)`` pair; labels are integer class ids.

    Returns:
        Scalar array: the fraction of examples in the batch whose highest
        scoring class equals the label.
    """
    images, labels = batch
    # Highest-scoring class along the last axis of the logits.
    predicted = SimpleCNN().apply({'params': params}, images).argmax(axis=-1)
    return jnp.mean(predicted == labels)

### Dataset Loader and helper funcs

In [16]:
import tensorflow_datasets as tfds
def get_datasets(batch_size=None):
    """Load the MNIST train and test splits with pixel values divided by 255.

    This single definition replaces two same-named functions, the second of
    which silently shadowed the first. Both behaviours remain available.

    Args:
        batch_size: Number of examples per batch. When ``None`` (the default),
            each split is loaded whole and returned as one dict with
            ``'image'`` and ``'label'`` entries instead of an iterable of
            batches.

    Returns:
        ``(train_ds, test_ds)``.

        * With ``batch_size`` given: each element is the ``tfds.as_numpy``
          view of a batched ``tf.data`` pipeline yielding ``(image, label)``
          tuples. The training pipeline is shuffled; the test one is not.
        * With ``batch_size=None``: each element is a dict holding the entire
          split, with ``'image'`` converted to float32 and rescaled.
    """
    if batch_size is None:
        # Whole-split mode: batch_size=-1 asks TFDS for the full split at once.
        ds_builder = tfds.builder('mnist')
        ds_builder.download_and_prepare()
        train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
        test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))

        # Rescale pixel values by 1/255; the image shape is left untouched.
        train_ds['image'] = jnp.float32(train_ds['image']) / 255.
        test_ds['image'] = jnp.float32(test_ds['image']) / 255.
        return train_ds, test_ds

    # Batched mode. `as_supervised=True` makes each element an (image, label)
    # tuple, which is the layout `train_step` / `eval_step` unpack.
    ds_train, ds_test = tfds.load(
        'mnist',
        split=['train', 'test'],
        shuffle_files=True,
        as_supervised=True,
    )

    def preprocess(image, label):
        """Cast the image to float32 and rescale by 1/255; pass the label through."""
        image = tf.cast(image, tf.float32) / 255.
        return image, label

    # Use the same optimized pipeline as the Keras example:
    # map -> cache -> (shuffle, train only) -> batch -> prefetch.
    train_ds = (
        ds_train.map(preprocess)
        .cache()
        .shuffle(10000)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    test_ds = (
        ds_test.map(preprocess)
        .cache()
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Convert to numpy iterables for JAX.
    return tfds.as_numpy(train_ds), tfds.as_numpy(test_ds)

In [17]:
# Training configuration
num_epochs = 10
# NOTE(review): `learning_rate` is not read by the optimizer, which was built
# earlier with a literal `optax.adam(learning_rate=1e-3)`; changing this value
# alone has no effect on training. Keep the two in sync or wire this in.
learning_rate = 1e-3
batch_size = 128


# Load data: batched train/test iterables produced by `get_datasets`.
train_ds, test_ds = get_datasets(batch_size)

In [18]:
import time
print("Starting JAX/Flax training...")
# NOTE: the timer spans training *and* the per-epoch evaluation passes. Since
# `train_step` and `eval_step` are jitted, a cell that calls them for the first
# time also pays for their compilation inside this measurement.
start_time = time.time()


# Training loop
for epoch in range(num_epochs):
    # One full pass over the training batches.
    for batch in train_ds:
        state, loss = train_step(state, batch)

    # Evaluation over every test batch.
    accuracies = []
    batch_sizes = []
    for test_batch in test_ds:
        accuracies.append(eval_step(state.params, test_batch))
        batch_sizes.append(len(test_batch[1]))
    # Weight each batch's accuracy by its size: the last batch can be smaller
    # than the others, so a plain mean of per-batch accuracies would over-count it.
    test_accuracy = np.average(accuracies, weights=batch_sizes)
    print(f"Epoch {epoch + 1}, Test Accuracy: {test_accuracy * 100:.2f}%")

# IMPORTANT: JAX is asynchronous. block_until_ready() ensures all computations are finished.
jax.block_until_ready(state)
end_time = time.time()
print("-" * 30)
print(f"JAX/Flax Training Time: {end_time - start_time:.4f} seconds")
print("-" * 30)

Starting JAX/Flax training...
Epoch 1, Test Accuracy: 98.24%
Epoch 2, Test Accuracy: 98.86%
Epoch 3, Test Accuracy: 98.81%
Epoch 4, Test Accuracy: 99.01%
Epoch 5, Test Accuracy: 98.94%
Epoch 6, Test Accuracy: 99.01%
Epoch 7, Test Accuracy: 98.89%
Epoch 8, Test Accuracy: 99.19%
Epoch 9, Test Accuracy: 99.09%
Epoch 10, Test Accuracy: 99.21%
------------------------------
JAX/Flax Training Time: 27.1709 seconds
------------------------------


In [19]:
import time
print("Starting JAX/Flax training for Another Time...")
# NOTE: `state` is not re-created here, so this run continues training from
# wherever the previously executed training cell left off. The timer spans
# training *and* the per-epoch evaluation passes.
start_time = time.time()


# Training loop
for epoch in range(num_epochs):
    # One full pass over the training batches.
    for batch in train_ds:
        state, loss = train_step(state, batch)

    # Evaluation over every test batch.
    accuracies = []
    batch_sizes = []
    for test_batch in test_ds:
        accuracies.append(eval_step(state.params, test_batch))
        batch_sizes.append(len(test_batch[1]))
    # Weight each batch's accuracy by its size: the last batch can be smaller
    # than the others, so a plain mean of per-batch accuracies would over-count it.
    test_accuracy = np.average(accuracies, weights=batch_sizes)
    print(f"Epoch {epoch + 1}, Test Accuracy: {test_accuracy * 100:.2f}%")

# IMPORTANT: JAX is asynchronous. block_until_ready() ensures all computations are finished.
jax.block_until_ready(state)
end_time = time.time()
print("-" * 30)
print(f"JAX/Flax Training Time: {end_time - start_time:.4f} seconds")
print("-" * 30)

Starting JAX/Flax training for Another Time...
Epoch 1, Test Accuracy: 99.07%
Epoch 2, Test Accuracy: 99.14%
Epoch 3, Test Accuracy: 99.16%
Epoch 4, Test Accuracy: 99.20%
Epoch 5, Test Accuracy: 99.18%
Epoch 6, Test Accuracy: 99.26%
Epoch 7, Test Accuracy: 99.25%
Epoch 8, Test Accuracy: 99.18%
Epoch 9, Test Accuracy: 99.21%
Epoch 10, Test Accuracy: 99.19%
------------------------------
JAX/Flax Training Time: 10.6536 seconds
------------------------------


In [20]:
import time
print("Starting JAX/Flax training for Another Time...")
# NOTE: `state` is not re-created here, so this run continues training from
# wherever the previously executed training cell left off. The timer spans
# training *and* the per-epoch evaluation passes.
start_time = time.time()


# Training loop
for epoch in range(num_epochs):
    # One full pass over the training batches.
    for batch in train_ds:
        state, loss = train_step(state, batch)

    # Evaluation over every test batch.
    accuracies = []
    batch_sizes = []
    for test_batch in test_ds:
        accuracies.append(eval_step(state.params, test_batch))
        batch_sizes.append(len(test_batch[1]))
    # Weight each batch's accuracy by its size: the last batch can be smaller
    # than the others, so a plain mean of per-batch accuracies would over-count it.
    test_accuracy = np.average(accuracies, weights=batch_sizes)
    print(f"Epoch {epoch + 1}, Test Accuracy: {test_accuracy * 100:.2f}%")

# IMPORTANT: JAX is asynchronous. block_until_ready() ensures all computations are finished.
jax.block_until_ready(state)
end_time = time.time()
print("-" * 30)
print(f"JAX/Flax Training Time: {end_time - start_time:.4f} seconds")
print("-" * 30)

Starting JAX/Flax training for Another Time...
Epoch 1, Test Accuracy: 99.13%
Epoch 2, Test Accuracy: 98.98%
Epoch 3, Test Accuracy: 99.21%
Epoch 4, Test Accuracy: 99.25%
Epoch 5, Test Accuracy: 99.25%
Epoch 6, Test Accuracy: 99.30%
Epoch 7, Test Accuracy: 99.27%
Epoch 8, Test Accuracy: 99.35%
Epoch 9, Test Accuracy: 99.36%
Epoch 10, Test Accuracy: 99.34%
------------------------------
JAX/Flax Training Time: 12.6540 seconds
------------------------------
