In [None]:
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
import tensorflow as tf

import numpy as np


In [None]:
# Load CIFAR-10 through TensorFlow Datasets.
ds_builder = tfds.builder('cifar10')
ds_builder.download_and_prepare()
# Build the train split first, then the test split; only the training split
# requests shuffled input files.
train_ds, test_ds = (
    ds_builder.as_dataset(split=split_name, shuffle_files=shuffle_files)
    for split_name, shuffle_files in (('train', True), ('test', False))
)

# Normalization function
def normalize_img(data):
    """Normalize an example's image: `uint8` -> `float32` divided by 255.

    Args:
        data: Mapping holding at least an 'image' entry that `tf.cast` can
            convert; every other entry is passed through unchanged.

    Returns:
        A new dict with the same entries as `data`, except that 'image' is
        replaced by its float32 value divided by 255.0. The input mapping
        itself is left unmodified.
    """
    scaled_image = tf.cast(data['image'], tf.float32) / 255.0
    # Build a fresh dict instead of assigning into `data`, so a caller that
    # uses this function outside `Dataset.map` does not see its input mutated.
    return {**data, 'image': scaled_image}

# Normalize every example, then group examples into batches of up to
# `batch_size` (the final batch of a split may be smaller).
batch_size = 32
train_ds = train_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size)
test_ds = test_ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size)

# Materialize a batched dataset as NumPy arrays so it can be fed to JAX
def tfds_to_numpy(dataset):
    """Convert every batch of `dataset` into a dict of NumPy arrays.

    Args:
        dataset: Iterable of batches, each indexable by 'image' and 'label'.

    Returns:
        A list with one dict per batch, holding only the 'image' and 'label'
        entries, each converted with `np.array`.
    """
    numpy_batches = []
    for batch in dataset:
        numpy_batches.append({
            'image': np.array(batch['image']),
            'label': np.array(batch['label']),
        })
    return numpy_batches

# Replace both tf.data pipelines with in-memory lists of NumPy batches
# (training split is converted first, then the test split).
train_ds, test_ds = tfds_to_numpy(train_ds), tfds_to_numpy(test_ds)


In [None]:
#build a model

class CNN(nn.Module):
    """Small convolutional classifier that emits 10 logits per input image.

    Architecture: three 3x3 convolution + ReLU stages (32, 64, 128 features),
    with 2x2 max-pooling after the first two stages, followed by a 256-unit
    dense layer with ReLU and a 10-way linear output layer.
    """

    @nn.compact
    def __call__(self, x, is_training: bool):
        """Run the forward pass.

        Args:
            x: Batch of images; axis 0 is treated as the batch axis when
                flattening. Assumed to be NHWC, e.g. (batch, 32, 32, 3) as
                used when the model is initialised in this notebook.
            is_training: Accepted so callers can pass the mode explicitly;
                currently unused because no layer behaves differently
                between training and evaluation.

        Returns:
            Unnormalised class scores (logits) of shape (batch, 10).
        """
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        # `strides` must be passed explicitly: nn.max_pool defaults to a
        # stride of 1, which would only trim one pixel per side instead of
        # halving the spatial resolution.
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # Flatten everything but the batch axis
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)  # 10 classes for CIFAR-10
        return x

In [None]:
# Define the loss function
def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy between logits and integer class labels.

    Args:
        logits: Array of unnormalised class scores; the last axis indexes
            classes.
        labels: Integer class indices, one per leading position of `logits`.

    Returns:
        Scalar: the cross-entropy averaged over all examples.
    """
    # Take the class count from the logits rather than hard-coding 10, so the
    # one-hot width always matches the model's output width.
    num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels).mean()

# Adam optimizer with a fixed step size
LEARNING_RATE = 1e-3
optimizer = optax.adam(learning_rate=LEARNING_RATE)


In [None]:
# Train state bundling the model parameters and the optimizer
class TrainState(train_state.TrainState):
    """Notebook-local subclass of Flax's `train_state.TrainState`; adds no fields or methods."""

# One jit-compiled optimisation step
@jax.jit
def train_step(state, batch):
    """Apply a single gradient update computed from one batch.

    Args:
        state: Train state providing `apply_fn`, `params` and
            `apply_gradients`.
        batch: Mapping with 'image' (model inputs) and 'label' (targets).

    Returns:
        A `(new_state, loss)` pair: the state after the update and the loss
        measured on `batch` before the update.
    """
    def compute_loss(params):
        # The logits ride along as the auxiliary output of the loss.
        predictions = state.apply_fn({'params': params}, batch['image'], is_training=True)
        return cross_entropy_loss(predictions, batch['label']), predictions

    loss_and_grad = jax.value_and_grad(compute_loss, has_aux=True)
    (batch_loss, _), gradients = loss_and_grad(state.params)
    return state.apply_gradients(grads=gradients), batch_loss


In [None]:
# Per-batch evaluation (jit-compiled)
@jax.jit
def evaluate(state, batch):
    """Compute loss and accuracy for a single batch.

    Args:
        state: Train state providing `apply_fn` and `params`.
        batch: Mapping with 'image' (model inputs) and 'label' (targets).

    Returns:
        A `(loss, accuracy)` pair, where accuracy is the fraction of examples
        whose arg-max logit equals the label.
    """
    images, targets = batch['image'], batch['label']
    logits = state.apply_fn({'params': state.params}, images, is_training=False)
    predicted_classes = jnp.argmax(logits, axis=-1)
    return cross_entropy_loss(logits, targets), jnp.mean(predicted_classes == targets)

def evaluate_model(state, test_dataset):
    """Evaluate `state` over a whole dataset and return example-weighted averages.

    Args:
        state: Train state forwarded to `evaluate` for every batch.
        test_dataset: Iterable of batches; each batch is a mapping whose
            'label' entry has one element per example.

    Returns:
        A `(avg_loss, avg_accuracy)` pair averaged over all examples.

    Raises:
        ValueError: If `test_dataset` yields no examples.
    """
    total_loss = 0.0
    total_accuracy = 0.0
    num_examples = 0

    for batch in test_dataset:
        loss, accuracy = evaluate(state, batch)
        # `evaluate` returns per-batch means, so weight each by its batch size;
        # otherwise a smaller final batch would count as much as a full one.
        examples_in_batch = len(batch['label'])
        total_loss += loss * examples_in_batch
        total_accuracy += accuracy * examples_in_batch
        num_examples += examples_in_batch

    if num_examples == 0:
        raise ValueError("evaluate_model received no examples to evaluate")

    avg_loss = total_loss / num_examples
    avg_accuracy = total_accuracy / num_examples
    return avg_loss, avg_accuracy

In [None]:
# Build the network, initialise its parameters, and wrap them in a train state
model = CNN()
rng = jax.random.PRNGKey(0)
# A single all-ones 32x32 RGB image is used to initialise the parameters.
init_variables = model.init(rng, jnp.ones([1, 32, 32, 3]), is_training=True)
params = init_variables['params']

state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

In [None]:
# Training loop
num_epochs = 10  # Number of full passes over the training batches
for epoch in range(num_epochs):
    for train_batch in train_ds:
        state, loss = train_step(state, train_batch)

    # `loss` is the value from the last batch processed in this epoch.
    epoch_number = epoch + 1
    print(f"Epoch {epoch_number}, Loss: {loss}")

In [None]:
# Evaluate the trained state on the test batches and report both averages.
avg_loss, avg_accuracy = evaluate_model(state, test_ds)

print(f"Test Loss: {avg_loss}", f"Test Accuracy: {avg_accuracy}", sep="\n")

