In [16]:
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
import tensorflow as tf


In [20]:
# Load CIFAR-10 from TFDS. Each split is a tf.data.Dataset whose elements are
# feature dicts; pixel scaling to float32 in [0, 1] is done per example by
# `normalize_img` via `Dataset.map` below. (A tf.data.Dataset does not support
# arithmetic such as `ds / 255.0`, so scaling must happen inside a map.)
ds_builder = tfds.builder('cifar10')
ds_builder.download_and_prepare()
train_ds = ds_builder.as_dataset(split='train', shuffle_files=True)
test_ds = ds_builder.as_dataset(split='test', shuffle_files=False)

# Normalization function
def normalize_img(data):
    """Convert the example's 'image' entry from `uint8` to `float32` scaled by 1/255.

    Args:
        data: A feature dict from the dataset. Its 'image' value is replaced
            in place; all other keys are left untouched.

    Returns:
        The same dict object, with 'image' cast to float32 and divided by 255
        (i.e. in [0, 1] when the input pixels are uint8).
    """
    pixels = tf.cast(data['image'], tf.float32)
    data['image'] = pixels / 255.0
    return data

# Apply normalize_img to every example of both splits; AUTOTUNE lets tf.data
# choose the degree of parallelism for the map.
train_ds, test_ds = [
    split.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    for split in (train_ds, test_ds)
]


In [19]:
# Build the model

class CNN(nn.Module):
    """Small two-stage convolutional classifier.

    Each stage is Conv(3x3) -> BatchNorm -> ReLU -> 2x2 max-pool (stride 2).
    The pooled features are flattened, passed through a 256-unit hidden Dense
    layer with ReLU, then a final Dense layer; the output is log-probabilities.

    Attributes:
        num_classes: Width of the output layer (10 for CIFAR-10).
        bn_momentum: Flax BatchNorm decay rate for the running statistics.
            Flax updates ``ra = momentum * ra + (1 - momentum) * batch_stat``,
            so 0.9 here corresponds to PyTorch's ``momentum=0.1``. (A Flax
            value of 0.1 would make the running stats ~90% the latest batch.)
    """

    num_classes: int = 10
    bn_momentum: float = 0.9

    @nn.compact
    def __call__(self, x, is_training: bool):
        """Run the network on a batch.

        Args:
            x: Batch of images with a leading batch axis; Flax ``nn.Conv`` is
                channels-last, so presumably (batch, height, width, channels)
                -- TODO confirm against the training loop.
            is_training: If True, BatchNorm normalizes with the current batch
                statistics and updates its running averages; if False it uses
                the stored running averages.

        Returns:
            Log-softmax scores of shape (batch, num_classes).
        """
        for features in (32, 64):
            # Submodules are auto-named in creation order (Conv_0, BatchNorm_0,
            # Conv_1, BatchNorm_1), the same tree an unrolled layout produces.
            x = nn.Conv(features=features, kernel_size=(3, 3))(x)
            x = nn.BatchNorm(use_running_average=not is_training,
                             momentum=self.bn_momentum)(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten all but the batch axis
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_classes)(x)
        return nn.log_softmax(x)