In [1]:
import os

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

In [1]:
from typing import Sequence

from flax import linen as nn
import jax
import jax.numpy as jnp
import optax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

2024-11-28 17:05:29.803195: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732784729.812104 1086326 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732784729.814582 1086326 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [32]:
# @markdown The learning rate for the optimizer:
LEARNING_RATE = 0.002 # @param{type:"number"}
# @markdown Number of samples in each batch:
BATCH_SIZE = 128 # @param{type:"integer"}
# @markdown Total number of epochs to train for:
N_EPOCHS = 5 # @param{type:"integer"}

In [39]:
(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)

min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)
train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

NUM_CLASSES = info.features["label"].num_classes
IMG_SIZE = info.features["image"].shape

train_loader_batched = train_loader.shuffle(
    buffer_size=10000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)

test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)

In [76]:
class MLP(nn.Module):
    """A simple multilayer perceptron model for image classification."""
    hidden_sizes: Sequence[int] = (100, 100)

    @nn.compact
    def __call__(self, x):
        # Flatten the input images
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=self.hidden_sizes[0])(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_sizes[1])(x)
        x = nn.relu(x)
        x = nn.Dense(features=NUM_CLASSES)(x)
        return x


class LeNet(nn.Module):
    "A simple LeNet model for image classification."
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Conv(features=6, kernel_size=(5, 5))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(5, 5))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=120)(x)
        x = nn.relu(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dropout(rate=0.25, deterministic=not train)(x)
        x = nn.Dense(features=84)(x)
        x = nn.relu(x)
        x = nn.Dense(features=NUM_CLASSES)(x)
        return x

In [77]:
# net = MLP()
net = LeNet()

In [73]:
@jax.jit
def predict(params, inputs, batch_states):
    return net.apply({"params": params, "batch_stats": batch_states}, inputs, train=True)

def loss_accuracy(params, data):
    """Computes loss and accuracy over a mini-batch.

    Args:
        params: the model parameters
        data: a tuple of (images, labels)
    Returns:
        loss: the average loss over the mini-batch (float)
    """

    inputs, labels = data
    logits, updates = predict(params, inputs)
    batch_stats = updates['batch_stats']
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, {"accuracy": accuracy}

@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)

In [74]:
solver = optax.adam(LEARNING_RATE)
rng1 = jax.random.PRNGKey(0)
rng2 = jax.random.PRNGKey(1)
dummy_data = jnp.ones((1,)+IMG_SIZE, dtype=jnp.float32)
vars = net.init({"params": rng1, ""}, dummy_data, train=False)
params = vars["params"]

solver_state = solver.init(params)

def dataset_stats(params, data_loader):
    """Compute the loss and accuracy over a dataset."""
    all_accuracy = []
    all_loss = []
    for batch in data_loader.as_numpy_iterator():
        batch_loss, batch_aux = loss_accuracy(params, batch)
        all_loss.append(batch_loss)
        all_accuracy.append(batch_aux['accuracy'])

    return {"loss": np.mean(all_loss), "accuracy": np.mean(all_accuracy)}

In [75]:
train_accuracy = []
train_losses = []

# Compute test set accuracy before training
test_stats = dataset_stats(params, test_loader_batched)
test_accuracy = [test_stats["accuracy"]]
test_losses = [test_stats["loss"]]

@jax.jit
def train_step(params, solver_state, batch):
    # performs a one-step update, aux is the accuracy
    (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(params, batch, solver_state)
    updates, solver_state = solver.update(grad, solver_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, solver_state, loss, aux

for epoch in range(N_EPOCHS):
    train_accuracy_epoch = []
    train_losses_epoch = []

    for step, train_batch in enumerate(train_loader_batched.as_numpy_iterator()):
        params, solver_state, loss, aux = train_step(params, solver_state, train_batch)
        train_accuracy_epoch.append(aux['accuracy'])
        train_losses_epoch.append(loss)
        if step % 100 == 0:
            print(f"Step {step:<4}, Loss: {loss:<8.4e}, Accuracy: {aux['accuracy']:<5.2f}")

    test_stats = dataset_stats(params, test_loader_batched)
    test_accuracy.append(test_stats["accuracy"])
    test_losses.append(test_stats["loss"])
    train_accuracy.append(np.mean(train_accuracy_epoch))
    train_losses.append(np.mean(train_losses_epoch))

    print(f"Epoch {epoch:<4}, Loss: {train_losses[-1]:<8.4e}, Accuracy: {train_accuracy[-1]:<5.2f}, Test Accuracy: {test_accuracy[-1]:<5.2f}")

InvalidRngError: Dropout_0 needs PRNG for "dropout" (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.InvalidRngError)

In [56]:
f"Improved accuracy on test DS from {test_accuracy[0]} to {test_accuracy[-1]}"

'Improved accuracy on test DS from 0.10136217623949051 to 0.9888821840286255'