In [1]:
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
#tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# as_supervised=True gives us the (image, label) as a tuple instead of a dict
data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

data_train = data['train']
data_test  = data['test']


2024-10-30 18:10:11.859445: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1730326211.870911  110388 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1730326211.874391  110388 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-30 18:10:11.885684: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
I0000 00:00:1730326213.540410  110388 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 96

In [2]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes

In [3]:
import jax.numpy as jnp
from flax import linen as nn
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax import linen as nn
from flax.training import train_state
from clu import metrics
import flax
import optax


In [4]:
def preprocess(img, label):
  """Resize and preprocess images."""
  return (tf.cast(img, tf.float32)/255.0), label
     
data_train_vis = data_train.map(preprocess)


In [5]:
train_data = tfds.as_numpy(data_train.map(preprocess).batch(32).prefetch(1))
test_data  = tfds.as_numpy(data_test.map(preprocess).batch(32).prefetch(1))


In [6]:
class MLP(nn.Module):
  """A simple MLP model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(5,5))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2,2))
    x = nn.Conv(features=64, kernel_size=(5,5))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2,2))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=1024)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x


In [7]:
model = MLP()
key1, key2 = random.split(random.PRNGKey(0))
random_flattened_image = random.normal(key1, (1,28,28,1))
params = model.init(key2, random_flattened_image) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes

{'params': {'Conv_0': {'bias': (32,), 'kernel': (5, 5, 1, 32)},
  'Conv_1': {'bias': (64,), 'kernel': (5, 5, 32, 64)},
  'Dense_0': {'bias': (1024,), 'kernel': (43264, 1024)},
  'Dense_1': {'bias': (10,), 'kernel': (1024, 10)}}}

In [8]:
model.apply(params, random_flattened_image)

Array([[ 0.5565213 ,  1.2295414 , -0.245256  , -0.25400254,  1.2998863 ,
        -0.5383994 , -0.08894709, -0.41814116, -0.9940413 ,  0.5447468 ]],      dtype=float32)

In [9]:
@flax.struct.dataclass
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output('loss')

class TrainState(train_state.TrainState):
  metrics: Metrics

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.sgd(learning_rate=0.01, momentum=0.9),
    metrics=Metrics.empty())


In [10]:
@jax.jit
def compute_metrics(state, x, y):
  logits = state.apply_fn(state.params, x)
  loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=y).mean()
  metric_updates = state.metrics.single_from_model_output(
    logits=logits, labels=y, loss=loss)
  metrics = state.metrics.merge(metric_updates)
  state = state.replace(metrics=metrics)
  return state

In [11]:
@jax.jit
def update(train_state, x, y):
  """A single training step"""
  def loss(params, images, targets):
    """Categorical cross entropy loss function."""
    logits = train_state.apply_fn(params, images)
    loss_ce = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=targets).mean()
    return loss_ce
  loss_value, grads = jax.value_and_grad(loss)(train_state.params, x, y)
  train_state = train_state.apply_gradients(grads=grads)
  return train_state, loss_value

In [12]:
num_epochs = 25

In [14]:
import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in train_data:
    y = y.astype(jnp.int32)
    #y = jax.nn.one_hot(y, NUM_LABELS)
    state, loss_value = update(state, x, y)
    state = compute_metrics(state, x, y)
  epoch_time = time.time() - start_time
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))

  for metric,value in state.metrics.compute().items():
    print(f"Training set {metric} {value}")
  state = state.replace(metrics=state.metrics.empty())

  test_state = state
  for x, y in test_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = y.astype(jnp.int32)
    test_state = compute_metrics(test_state, x, y)

  for metric,value in test_state.metrics.compute().items():
    print(f"Test set {metric} {value}")

2024-10-30 18:11:16.696841: W tensorflow/core/kernels/data/cache_dataset_ops.cc:914] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


ScopeParamShapeError: Initializer expected to generate shape (5, 5, 1, 32) but got shape (5, 5, 784, 32) instead for parameter "kernel" in "/Conv_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)