In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from flax.training import train_state
import optax
import tensorflow_datasets as tfds
import tensorflow as tf
from tqdm.notebook import tqdm

# 1. Load Data
def get_datasets(downsampling_factor=8):
    """Build the batched MNIST train and test `tf.data` pipelines.

    Every image is cast to float32 and divided by 255. When
    `downsampling_factor` is greater than 1, the pixels are regrouped into
    rows of `downsampling_factor` consecutive values and each row is replaced
    by its mean; otherwise the image is only flattened.

    Args:
        downsampling_factor: Number of consecutive pixels averaged into one
            feature. Values <= 1 disable downsampling.

    Returns:
        A `(train_ds, test_ds)` tuple of datasets batched by 128. Only the
        training dataset is shuffled (buffer of 10000).
    """

    def preprocess(image, label):
        scaled = tf.cast(image, tf.float32) / 255.
        if downsampling_factor <= 1:
            # No downsampling requested: just flatten.
            features = tf.reshape(scaled, [-1])
        else:
            dims = tf.shape(scaled)
            # Pixel count is taken from the first two image dimensions.
            # NOTE(review): assumes that count is divisible by the factor — confirm.
            n_pixels = dims[0] * dims[1]
            grouped = tf.reshape(
                scaled, [n_pixels // downsampling_factor, downsampling_factor]
            )
            features = tf.reduce_mean(grouped, axis=1)
        return features, label

    raw_train = tfds.load(name='mnist', split='train', as_supervised=True)
    raw_test = tfds.load(name='mnist', split='test', as_supervised=True)

    train_pipeline = raw_train.map(preprocess).cache()
    train_pipeline = train_pipeline.shuffle(10000).batch(128).prefetch(1)
    test_pipeline = raw_test.map(preprocess).cache().batch(128).prefetch(1)
    return train_pipeline, test_pipeline

# mnist98: a downsampling_factor of 8 turns the 784 pixels into 784/8 = 98 features.
# Each pipeline is passed through tfds.as_numpy before being bound to its name.
train_ds, test_ds = map(tfds.as_numpy, get_datasets(downsampling_factor=8))

# 2. Define MLP model
class MLP(nn.Module):
    """Fully connected classifier: (num_layers - 1) ReLU hidden layers plus a linear head.

    With `num_layers=1` no hidden layer is created and the module is a single
    Dense layer producing `num_classes` outputs.
    """

    # Width of the final Dense layer (one logit per class).
    num_classes: int = 10
    # Width of every hidden Dense layer.
    num_features: int = 512
    # Total number of Dense layers, counting the output layer.
    num_layers: int = 2

    @nn.compact
    def __call__(self, x):
        """Apply the hidden stack and the output layer to `x`; returns raw logits."""
        hidden = x
        n_hidden = self.num_layers - 1
        for _ in range(n_hidden):
            hidden = nn.relu(nn.Dense(features=self.num_features)(hidden))
        return nn.Dense(features=self.num_classes)(hidden)

# 3. Define loss and accuracy
def cross_entropy_loss(logits, labels):
    """Return the mean softmax cross-entropy between `logits` and integer `labels`.

    The number of classes is read from the last axis of `logits` rather than
    being fixed at 10, so the loss stays correct when the model is built with
    a different `num_classes`.

    Args:
        logits: Unnormalized class scores; the last axis indexes classes.
        labels: Integer class ids, one per row of `logits`.

    Returns:
        Scalar mean of the per-example cross-entropy.
    """
    # The shape is static under jit, so this is a plain Python int at trace time.
    num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits, one_hot_labels).mean()

def compute_metrics(logits, labels):
    """Return a dict with the batch `'loss'` and `'accuracy'`.

    Loss comes from `cross_entropy_loss`; accuracy is the fraction of rows
    whose arg-max over the last axis of `logits` equals the label.
    """
    predicted = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predicted == labels),
    }

# 4. Training Step
@jax.jit
def train_step(state, batch):
    """Run one gradient update on `batch` and return the new state with metrics.

    The function is jitted, so Python side effects in the body would run only
    while JAX traces it (once per distinct input shape, e.g. again for a
    smaller final batch), not on every call. The former debug
    `print(images.shape)` was removed for that reason.

    Args:
        state: Train state providing `apply_fn`, `params` and `apply_gradients`.
        batch: `(images, labels)` pair.

    Returns:
        `(state, metrics)`: the state after applying the gradients, and the
        `compute_metrics` dict evaluated on the logits produced *before* the
        update.
    """
    images, labels = batch
    images, labels = jnp.array(images), jnp.array(labels)

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        loss = cross_entropy_loss(logits, labels)
        # Logits are returned as auxiliary output so metrics can reuse them.
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # The loss value itself is not needed here; compute_metrics recomputes it.
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, labels)
    return state, metrics

# 5. Evaluation Step
@jax.jit
def eval_step(state, batch):
    """Compute loss and accuracy for one `(images, labels)` batch without updating `state`."""
    images, labels = batch
    variables = {'params': state.params}
    logits = state.apply_fn(variables, jnp.array(images))
    return compute_metrics(logits, labels)

# 6. Training Loop
def train_one_epoch(state, dataloader):
    """Train on every batch of `dataloader` once.

    Returns:
        `(state, epoch_summary)` where `state` is the train state after the
        last batch and `epoch_summary` maps each metric name to the unweighted
        mean of its per-batch values.
    """
    collected = []
    for batch in dataloader:
        state, step_metrics = train_step(state, batch)
        collected.append(step_metrics)

    # Fetch all per-batch metrics to host memory in a single call.
    host_metrics = jax.device_get(collected)
    epoch_summary = {}
    for name in host_metrics[0]:
        epoch_summary[name] = np.mean([entry[name] for entry in host_metrics])
    return state, epoch_summary

def evaluate_model(state, test_ds):
    """Evaluate `state` on every batch of `test_ds`.

    Returns:
        A dict mapping each metric name to the unweighted mean of its
        per-batch values.
    """
    per_batch = [eval_step(state, batch) for batch in test_ds]

    # Fetch all per-batch metrics to host memory in a single call.
    host_metrics = jax.device_get(per_batch)
    summary = {}
    for name in host_metrics[0]:
        summary[name] = np.mean([entry[name] for entry in host_metrics])
    return summary

# Initialization
key = jax.random.PRNGKey(0)
# A single dummy example is enough for Flax to infer parameter shapes.
dummy_input = jnp.ones([1, 98]) # 784/8 = 98
# NOTE(review): the source line defining `model` was damaged (only `])` and the
# trailing comment survived); the definition is restored from that comment — confirm.
# num_layers=1 creates no hidden layer, so the model is a single Dense layer.
model = MLP(num_features=512, num_layers=1)
params = model.init(key, dummy_input)['params']
tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Run the full training schedule, reporting train and test metrics after each epoch.
num_epochs = 20
for epoch in range(1, num_epochs + 1):
    # One pass over train_ds; `state` is rebound to the state returned for the epoch.
    state, train_metrics = train_one_epoch(state, train_ds)
    print(f"Epoch {epoch} | Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.4f}")
    
    # Evaluate the just-updated state on the test set.
    test_metrics = evaluate_model(state, test_ds)
    print(f"Epoch {epoch} | Test Loss: {test_metrics['loss']:.4f}, Test Acc: {test_metrics['accuracy']:.4f}")


(128, 98)
(96, 98)
Epoch 1 | Train Loss: 1.6768, Train Acc: 0.6675
Epoch 1 | Test Loss: 1.2289, Test Acc: 0.8193
Epoch 2 | Train Loss: 1.0340, Train Acc: 0.8230
Epoch 2 | Test Loss: 0.8622, Test Acc: 0.8477
Epoch 3 | Train Loss: 0.7877, Train Acc: 0.8439
Epoch 3 | Test Loss: 0.6940, Test Acc: 0.8645
Epoch 4 | Train Loss: 0.6615, Train Acc: 0.8565
Epoch 4 | Test Loss: 0.5976, Test Acc: 0.8739
Epoch 5 | Train Loss: 0.5849, Train Acc: 0.8660
Epoch 5 | Test Loss: 0.5364, Test Acc: 0.8815
Epoch 6 | Train Loss: 0.5333, Train Acc: 0.8726
Epoch 6 | Test Loss: 0.4938, Test Acc: 0.8862
Epoch 7 | Train Loss: 0.4965, Train Acc: 0.8779
Epoch 7 | Test Loss: 0.4624, Test Acc: 0.8889
Epoch 8 | Train Loss: 0.4687, Train Acc: 0.8821
Epoch 8 | Test Loss: 0.4382, Test Acc: 0.8919
Epoch 9 | Train Loss: 0.4472, Train Acc: 0.8850
Epoch 9 | Test Loss: 0.4200, Test Acc: 0.8938
Epoch 10 | Train Loss: 0.4301, Train Acc: 0.8873
Epoch 10 | Test Loss: 0.4049, Test Acc: 0.8953
Epoch 11 | Train Loss: 0.4162, Train Ac