# Acknowledgement:
I looked at this youtube video for model architecture and training code: https://www.youtube.com/watch?v=Q0vvh95wes8

## Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.display import Javascript  # Restrict height of output cell.
from sklearn.metrics import ConfusionMatrixDisplay

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils import data

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax
from jax.tree_util import tree_map
from flax.training import train_state

In [2]:
from utils import image_to_numpy
from utils import numpy_collate

In [3]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [19]:
# model saving
from flax import serialization
from flax.training import train_state

## Initialization

In [5]:
IMAGE_SIZE = 32
BATCH_SIZE = 128
DATA_MEANS = np.array([0.49139968, 0.48215841, 0.44653091])
DATA_STD = np.array([0.24703223, 0.24348513, 0.26158784])
CROP_SCALES = (0.8, 1.0)
CROP_RATIO = (0.9, 1.1)

SEED = 42

plt.style.use('dark_background')

In [6]:
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x
     

## Load CIFAR 10 Dataset

In [7]:
classes = ('airplane, automobile, ship, truck', 'bird, dog, frog, horse')
# images in the test set will only be converted into numpy arrays
test_transform = image_to_numpy
# images in the train set will be randomly flipped, cropped, and then converted to numpy arrays
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO),
    image_to_numpy
])

# Validation set should not use train_transform.
train_dataset = torchvision.datasets.CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = torchvision.datasets.CIFAR10('data', train=True, transform=test_transform, download=True)
# We're going to splite the train and test sets by seed 0 since we're exploring the differences between S_init and S_batch
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(0))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(0))
test_set = torchvision.datasets.CIFAR10('data', train=False, transform=test_transform, download=True)

train_data_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_data_loader = torch.utils.data.DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_data_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
     

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [8]:
print('CHECK CHECK CHECK')
print(f'number of samples in train_set:         {len(train_set)}')
print(f'number of batches in train_data_loader: {len(train_data_loader)}')
print(f'number of samples / batch size:         {len(train_set)} / {BATCH_SIZE} = {len(train_set)/BATCH_SIZE}')
print(f'number of samples in test_set:          {len(test_set)}')
print(f'number of batches in test_data_loader:  {len(test_data_loader)}')
print(f'number of samples / batch size:         {len(test_set)} / {BATCH_SIZE} = {len(test_set)/BATCH_SIZE}')

CHECK CHECK CHECK
number of samples in train_set:         45000
number of batches in train_data_loader: 351
number of samples / batch size:         45000 / 128 = 351.5625
number of samples in test_set:          10000
number of batches in test_data_loader:  79
number of samples / batch size:         10000 / 128 = 78.125


In [9]:
print(f'size of images in the first train batch: {next(iter(train_data_loader))[0].shape}')
print(f'type of images in the first train batch: {next(iter(train_data_loader))[0].dtype}')
print(f'size of labels in the first train batch: {next(iter(train_data_loader))[1].shape}')
print(f'type of labels in the first train batch: {next(iter(train_data_loader))[1].dtype}')

size of images in the first train batch: (128, 32, 32, 3)
type of images in the first train batch: float64
size of labels in the first train batch: (128,)
type of labels in the first train batch: int64


## Initializing the Model (RANDOM SEEDS)

In [10]:
model = CNN()

In [11]:
# Thi
optimizer = optax.adam(learning_rate=1e-4)

# inp_rng affects the batches
# init_rng affects the initialization of the model
rng, inp_rng, init_rng = jax.random.split(jax.random.PRNGKey(SEED), 3)
params = model.init(jax.random.PRNGKey(SEED),
                    jax.random.normal(inp_rng, (BATCH_SIZE, 32, 32, 3)))

model_state = train_state.TrainState.create(apply_fn=model.apply,
                                            params=params,
                                            tx=optimizer)


## Training

In [12]:
@jax.jit
def apply_model(state, images, labels):
  """Computes gradients, loss and accuracy for a single batch."""

  def loss_fn(params):
    logits = state.apply_fn(params, images)
    one_hot = jax.nn.one_hot(labels, logits.shape[1])
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

In [13]:
writer = SummaryWriter('runs/experiment_1')

def train_epoch(state, data_loader, current_epoch):
    """Train for a single epoch."""
    epoch_loss = []
    epoch_accuracy = []

    progress_bar = tqdm(data_loader, desc="Training")
    
    for batch in progress_bar:
        batch_images, batch_labels = batch
        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        state = update_model(state, grads)
        
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

        # Update progress bar with current loss and accuracy
        progress_bar.set_postfix(loss=np.mean(epoch_loss), accuracy=np.mean(epoch_accuracy))

    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)

    # Log metrics to TensorBoard
    writer.add_scalar('Loss/train', train_loss, current_epoch)  # Use the current_epoch variable
    writer.add_scalar('Accuracy/train', train_accuracy, current_epoch)

    return state, train_loss, train_accuracy

In [14]:
def train_model(state, train_data_loader, num_epochs):
    # Training loop
    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch + 1}/{num_epochs}')
        
        # Call the train_epoch function, passing the current epoch
        state, train_loss, train_accuracy = train_epoch(state, train_data_loader, epoch)

        # Print training loss and accuracy
        print(f'epoch: {epoch + 1:03d}, train loss: {train_loss:.4f}, train accuracy: {train_accuracy:.4f}')
    
    return state


In [15]:
trained_model_state = train_model(model_state, train_data_loader, num_epochs=35)
writer.close()

Starting epoch 1/35


Training: 100%|████| 351/351 [00:15<00:00, 22.05it/s, accuracy=0.402, loss=1.69]


epoch: 001, train loss: 1.6871, train accuracy: 0.4016
Starting epoch 2/35


Training: 100%|████| 351/351 [00:15<00:00, 22.00it/s, accuracy=0.496, loss=1.43]


epoch: 002, train loss: 1.4269, train accuracy: 0.4965
Starting epoch 3/35


Training: 100%|████| 351/351 [00:16<00:00, 21.55it/s, accuracy=0.531, loss=1.33]


epoch: 003, train loss: 1.3343, train accuracy: 0.5313
Starting epoch 4/35


Training: 100%|████| 351/351 [00:16<00:00, 21.51it/s, accuracy=0.553, loss=1.28]


epoch: 004, train loss: 1.2762, train accuracy: 0.5534
Starting epoch 5/35


Training: 100%|████| 351/351 [00:16<00:00, 21.37it/s, accuracy=0.573, loss=1.22]


epoch: 005, train loss: 1.2217, train accuracy: 0.5731
Starting epoch 6/35


Training: 100%|████| 351/351 [00:16<00:00, 21.44it/s, accuracy=0.586, loss=1.18]


epoch: 006, train loss: 1.1838, train accuracy: 0.5855
Starting epoch 7/35


Training: 100%|████| 351/351 [00:16<00:00, 21.46it/s, accuracy=0.602, loss=1.15]


epoch: 007, train loss: 1.1454, train accuracy: 0.6019
Starting epoch 8/35


Training: 100%|████| 351/351 [00:16<00:00, 21.09it/s, accuracy=0.613, loss=1.11]


epoch: 008, train loss: 1.1137, train accuracy: 0.6131
Starting epoch 9/35


Training: 100%|█████| 351/351 [00:16<00:00, 21.09it/s, accuracy=0.62, loss=1.09]


epoch: 009, train loss: 1.0886, train accuracy: 0.6199
Starting epoch 10/35


Training: 100%|█████| 351/351 [00:16<00:00, 20.94it/s, accuracy=0.63, loss=1.06]


epoch: 010, train loss: 1.0626, train accuracy: 0.6303
Starting epoch 11/35


Training: 100%|████| 351/351 [00:16<00:00, 20.93it/s, accuracy=0.639, loss=1.04]


epoch: 011, train loss: 1.0393, train accuracy: 0.6389
Starting epoch 12/35


Training: 100%|████| 351/351 [00:16<00:00, 20.96it/s, accuracy=0.645, loss=1.02]


epoch: 012, train loss: 1.0204, train accuracy: 0.6454
Starting epoch 13/35


Training: 100%|███████| 351/351 [00:16<00:00, 20.93it/s, accuracy=0.649, loss=1]


epoch: 013, train loss: 1.0004, train accuracy: 0.6495
Starting epoch 14/35


Training: 100%|████| 351/351 [00:16<00:00, 21.19it/s, accuracy=0.66, loss=0.981]


epoch: 014, train loss: 0.9814, train accuracy: 0.6604
Starting epoch 15/35


Training: 100%|███| 351/351 [00:16<00:00, 20.87it/s, accuracy=0.667, loss=0.962]


epoch: 015, train loss: 0.9616, train accuracy: 0.6671
Starting epoch 16/35


Training: 100%|███| 351/351 [00:16<00:00, 20.92it/s, accuracy=0.672, loss=0.947]


epoch: 016, train loss: 0.9469, train accuracy: 0.6716
Starting epoch 17/35


Training: 100%|███| 351/351 [00:16<00:00, 20.83it/s, accuracy=0.675, loss=0.936]


epoch: 017, train loss: 0.9360, train accuracy: 0.6745
Starting epoch 18/35


Training: 100%|███| 351/351 [00:16<00:00, 20.96it/s, accuracy=0.682, loss=0.918]


epoch: 018, train loss: 0.9183, train accuracy: 0.6825
Starting epoch 19/35


Training: 100%|███| 351/351 [00:16<00:00, 20.99it/s, accuracy=0.687, loss=0.906]


epoch: 019, train loss: 0.9059, train accuracy: 0.6868
Starting epoch 20/35


Training: 100%|███| 351/351 [00:16<00:00, 20.87it/s, accuracy=0.692, loss=0.893]


epoch: 020, train loss: 0.8925, train accuracy: 0.6919
Starting epoch 21/35


Training: 100%|███| 351/351 [00:16<00:00, 20.79it/s, accuracy=0.698, loss=0.875]


epoch: 021, train loss: 0.8753, train accuracy: 0.6977
Starting epoch 22/35


Training: 100%|███| 351/351 [00:16<00:00, 20.93it/s, accuracy=0.701, loss=0.869]


epoch: 022, train loss: 0.8690, train accuracy: 0.7014
Starting epoch 23/35


Training: 100%|███| 351/351 [00:16<00:00, 20.82it/s, accuracy=0.705, loss=0.851]


epoch: 023, train loss: 0.8511, train accuracy: 0.7052
Starting epoch 24/35


Training: 100%|███| 351/351 [00:16<00:00, 20.74it/s, accuracy=0.712, loss=0.839]


epoch: 024, train loss: 0.8385, train accuracy: 0.7116
Starting epoch 25/35


Training: 100%|███| 351/351 [00:16<00:00, 21.01it/s, accuracy=0.714, loss=0.829]


epoch: 025, train loss: 0.8290, train accuracy: 0.7138
Starting epoch 26/35


Training: 100%|███| 351/351 [00:17<00:00, 20.65it/s, accuracy=0.714, loss=0.821]


epoch: 026, train loss: 0.8214, train accuracy: 0.7137
Starting epoch 27/35


Training: 100%|███| 351/351 [00:16<00:00, 20.78it/s, accuracy=0.721, loss=0.809]


epoch: 027, train loss: 0.8090, train accuracy: 0.7208
Starting epoch 28/35


Training: 100%|███| 351/351 [00:16<00:00, 20.69it/s, accuracy=0.724, loss=0.797]


epoch: 028, train loss: 0.7970, train accuracy: 0.7241
Starting epoch 29/35


Training: 100%|███| 351/351 [00:17<00:00, 20.52it/s, accuracy=0.725, loss=0.793]


epoch: 029, train loss: 0.7927, train accuracy: 0.7250
Starting epoch 30/35


Training: 100%|███| 351/351 [00:17<00:00, 20.28it/s, accuracy=0.732, loss=0.778]


epoch: 030, train loss: 0.7779, train accuracy: 0.7322
Starting epoch 31/35


Training: 100%|███| 351/351 [00:17<00:00, 20.33it/s, accuracy=0.733, loss=0.772]


epoch: 031, train loss: 0.7715, train accuracy: 0.7328
Starting epoch 32/35


Training: 100%|███| 351/351 [00:17<00:00, 20.16it/s, accuracy=0.738, loss=0.757]


epoch: 032, train loss: 0.7567, train accuracy: 0.7385
Starting epoch 33/35


Training: 100%|███| 351/351 [00:16<00:00, 21.03it/s, accuracy=0.742, loss=0.748]


epoch: 033, train loss: 0.7483, train accuracy: 0.7422
Starting epoch 34/35


Training: 100%|███| 351/351 [00:17<00:00, 20.14it/s, accuracy=0.744, loss=0.742]


epoch: 034, train loss: 0.7417, train accuracy: 0.7436
Starting epoch 35/35


Training: 100%|███| 351/351 [00:17<00:00, 20.28it/s, accuracy=0.747, loss=0.732]

epoch: 035, train loss: 0.7322, train accuracy: 0.7472





# Testing

In [16]:
test_loss = []
test_accuracy = []

for batch in test_data_loader:
  batch_images, batch_labels = batch
  _, loss, accuracy = apply_model(trained_model_state, batch_images, batch_labels)
  test_loss.append(loss)
  test_accuracy.append(accuracy)

print(f'loss: {np.mean(test_loss):.4f}, accuracy: {np.mean(test_accuracy):.4f}')

loss: 0.8824, accuracy: 0.6957


## Saving/Loading Model

In [17]:
# Saving model
with open('trained_model.pkl','wb') as f:
    f.write(serialization.to_bytes(trained_model_state))

In [21]:
# Loading the model
with open('trained_model.pkl', 'rb') as f:
    loaded_model_state = serialization.from_bytes(train_state.TrainState, f.read())

In [22]:
loaded_model_state

{'step': array(12285, dtype=int32),
 'params': {'params': {'Conv_0': {'bias': array([-0.05509524, -0.01494891,  0.02423832,  0.17565621,  0.03038183,
            0.02479529, -0.02108995, -0.06242558,  0.12463821,  0.10383037,
            0.10325531,  0.06030441, -0.00580891, -0.07965636,  0.09768064,
            0.12192987, -0.10313437,  0.03542982, -0.05132829,  0.10119658,
            0.00237998, -0.01142217,  0.17714292,  0.0493258 , -0.02489948,
            0.11390908, -0.01538698,  0.12576143,  0.08806267,  0.00765791,
            0.12520294,  0.01378171], dtype=float32),
    'kernel': array([[[[ 1.84198797e-01, -4.88828160e-02, -3.60032320e-01,
               1.28478050e-01, -1.18141524e-01,  6.02924787e-02,
              -1.40627488e-01,  1.87220395e-01,  2.35586509e-01,
              -2.42053360e-01,  3.93796340e-02,  3.61317307e-01,
              -1.21444546e-01, -7.78316185e-02,  2.35576987e-01,
               1.25078216e-01,  1.38448432e-01, -1.73388034e-01,
              -3