This notebook can be run in Google Colab.

It implements CIFAR-10 training using the JAX deep learning framework.

In [1]:
# This notebook is inspired by the JAX MNIST example:
# https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html

In [3]:
from pathlib import Path
from typing import Tuple

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

# Hide every GPU from TensorFlow (used here for dataset loading) so that it
# does not pre-allocate GPU memory that the JAX training needs.
tf.config.set_visible_devices(devices=[], device_type="GPU")


def get_data_from_tfds(
    name: str,
    data_dir: Path | str,
) -> Tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
    """Load the full train and test splits of a TFDS dataset into memory.

    Args:
        name: Dataset name passed to ``tfds.load`` (e.g. ``"cifar10"``).
        data_dir: Directory where TFDS reads/writes the downloaded data.

    Returns:
        A ``(train_data, test_data)`` pair. Because ``batch_size=-1`` is used,
        each split is returned as one single batch, and ``tfds.as_numpy``
        converts it to a dict mapping feature names to NumPy arrays whose
        leading axis indexes examples.
    """
    # Ask only for the two splits we return: with batch_size=-1 every loaded
    # split is materialized in memory in full, so loading all splits (the
    # default when `split` is omitted) could waste a lot of memory for
    # datasets that have additional splits.
    train_tensors, test_tensors = tfds.load(
        name=name,
        split=["train", "test"],
        batch_size=-1,
        data_dir=data_dir,
    )
    return tfds.as_numpy(train_tensors), tfds.as_numpy(test_tensors)
