In [16]:
# pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # For NVIDIA GPU
# pip install flax optax
# pip install tensorflow-datasets==4.9.3
# pip install tfds-nightly


In [29]:
from flax import linen as nn
import numpy as np

class SimpleCNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        # First convolutional block
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Second convolutional block
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten and pass to dense layers
        x = x.reshape((x.shape[0], -1)) # Flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x) # 10 output classes (e.g., for MNIST)
        return x

In [30]:
## Create Keys for Randomness
## JAX operations are deterministic. 
# JAX need to explicitly create and pass around keys for any random operations, like weight initialization.
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


key = jax.random.PRNGKey(0)  

In [31]:
train_dsmodel = SimpleCNN()
# Dummy input for an MNIST image (batch size 1, 28x28 pixels, 1 channel)
dummy_input = jnp.ones([1, 28, 28, 1])
params = train_dsmodel.init(key, dummy_input)['params']

In [32]:
# Define the optimizer
tx = optax.adam(learning_rate=1e-3)

# Create the training state
state = train_state.TrainState.create(
    apply_fn=train_dsmodel.apply,
    params=params,
    tx=tx,
)

In [33]:
# Just in time compilation. Compile this whole function for performance
@jax.jit 
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        ).mean()
        return loss

    # Calculate loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Update the state 
    state = state.apply_gradients(grads=grads)
    return state, loss

In [34]:
@jax.jit
def eval_step(params, batch):
    """Computes accuracy for a batch."""
    logits = SimpleCNN().apply({'params': params}, batch['image'])
    accuracy = jnp.mean(jnp.argmax(logits, -1) == batch['label'])
    return accuracy

### Dataset Loader and helper funcs

In [35]:
import tensorflow_datasets as tfds
def get_datasets():
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    
    # Normalize and add channel dimension
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds



def create_batches(data, batch_size):
    """Yields batches of data."""
    num_samples = data['image'].shape[0]
    num_batches = num_samples // batch_size
    # Create a random permutation of indices
    perm = np.random.permutation(num_samples)
    for i in range(num_batches):
        batch_indices = perm[i * batch_size : (i + 1) * batch_size]
        yield {
            'image': data['image'][batch_indices],
            'label': data['label'][batch_indices]
        }

In [36]:
# Training loop
num_epochs = 10
learning_rate = 1e-3
batch_size = 128


# Load data
train_ds, test_ds = get_datasets()

In [37]:
import time
print("Starting JAX/Flax training...")
start_time = time.time()


for epoch in range(num_epochs):
    epoch_loss = 0.
    for batch in create_batches(train_ds, batch_size):
        state, loss = train_step(state, batch)
        # Evaluation step
        test_accuracy = eval_step(state.params, test_ds)
        print(f"Epoch {epoch + 1}, Test Accuracy: {test_accuracy * 100:.2f}%")

 # IMPORTANT: JAX is asynchronous. block_until_ready() ensures all computations are finished.
jax.block_until_ready(state) 
end_time = time.time()
print("-" * 30)
print(f"JAX/Flax Training Time: {end_time - start_time:.4f} seconds")
print("-" * 30)

Starting JAX/Flax training...
Epoch 1, Test Accuracy: 26.45%
Epoch 1, Test Accuracy: 47.54%
Epoch 1, Test Accuracy: 46.77%
Epoch 1, Test Accuracy: 58.24%
Epoch 1, Test Accuracy: 67.62%
Epoch 1, Test Accuracy: 71.54%
Epoch 1, Test Accuracy: 72.61%
Epoch 1, Test Accuracy: 72.47%
Epoch 1, Test Accuracy: 77.96%
Epoch 1, Test Accuracy: 81.49%
Epoch 1, Test Accuracy: 79.06%
Epoch 1, Test Accuracy: 81.12%
Epoch 1, Test Accuracy: 81.46%
Epoch 1, Test Accuracy: 80.49%
Epoch 1, Test Accuracy: 81.10%
Epoch 1, Test Accuracy: 82.37%
Epoch 1, Test Accuracy: 82.76%
Epoch 1, Test Accuracy: 82.91%
Epoch 1, Test Accuracy: 84.80%
Epoch 1, Test Accuracy: 86.59%
Epoch 1, Test Accuracy: 86.93%
Epoch 1, Test Accuracy: 85.02%
Epoch 1, Test Accuracy: 82.87%
Epoch 1, Test Accuracy: 84.23%
Epoch 1, Test Accuracy: 85.90%
Epoch 1, Test Accuracy: 87.33%
Epoch 1, Test Accuracy: 88.57%
Epoch 1, Test Accuracy: 87.74%
Epoch 1, Test Accuracy: 85.18%
Epoch 1, Test Accuracy: 85.05%
Epoch 1, Test Accuracy: 87.62%
Epoch 1, 

KeyboardInterrupt: 