# Basic settings and imports

In [None]:
## JAX
import jax
import jax.numpy as jnp
from jax import random

from functools import partial

In [None]:
# Seed shared by JAX (main_rng) and the seeded torch train/validation split further below.
seed = 42
main_rng = random.PRNGKey(seed)

# Where CIFAR10 lives on disk; the dataset cells download it here if it is missing.
DATASET_PATH = "data"
# Directory handed to the trainers as `checkpoint_dir` for saving/reloading models.
CHECKPOINT_PATH = "saved_models"

print("Device:", jax.devices()[0])

# Prepare datasets

In [None]:
import torch
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

from dataset import image_to_numpy, numpy_collate

In [None]:
# Evaluation preprocessing: image_to_numpy with CIFAR10 settings only, no augmentation.
test_transform = partial(image_to_numpy, dataset_name="CIFAR10")
# Training preprocessing: random horizontal flip plus a mild random resized crop back to
# 32x32, then the same numpy conversion. The augmentation counters overfitting.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    partial(image_to_numpy, dataset_name="CIFAR10"),
])


def _seeded_split(dataset):
    """Split `dataset` into (45000, 5000) parts with a freshly seeded generator.

    Because the generator is re-created from `seed` on every call, splitting two datasets
    of equal length yields the same index permutation for both.
    """
    return torch.utils.data.random_split(dataset, [45000, 5000],
                                         generator=torch.Generator().manual_seed(seed))


# The CIFAR10 training files are wrapped twice: once with the augmenting transform and once
# with the plain one. Taking the first part of the first split and the second part of the
# second split gives disjoint train/validation sets, and validation images are never augmented.
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
train_set = _seeded_split(train_dataset)[0]
val_set = _seeded_split(val_dataset)[1]

# Held-out test split, evaluation transform only.
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

In [None]:
# Fraction of each split kept for quick initial experiments.
SMALL_DATASET_SIZE = 0.3


def _leading_fraction(dataset, fraction=SMALL_DATASET_SIZE):
    """Return a Subset with the first ``int(fraction * len(dataset))`` samples of `dataset`.

    For train/val these are already randomly permuted Subsets, so the head is a random
    sample; for the test set it is simply the first samples in file order.
    """
    keep = int(fraction * len(dataset))
    return torch.utils.data.Subset(dataset, range(keep))


train_set_small = _leading_fraction(train_set)
val_set_small = _leading_fraction(val_set)
test_set_small = _leading_fraction(test_set)

In [None]:
def _make_loader(dataset, train):
    """Wrap `dataset` in a DataLoader using this notebook's shared settings.

    Every loader uses batches of 128, `numpy_collate` and persistent worker processes.
    Training loaders (``train=True``) shuffle, drop the incomplete last batch and use
    8 workers. Evaluation loaders keep every sample in order and use 4 workers.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        shuffle=train,
        drop_last=train,
        collate_fn=numpy_collate,
        num_workers=8 if train else 4,
        persistent_workers=True,
    )


# Loaders over the full splits, used for training, validation and testing.
train_loader = _make_loader(train_set, train=True)
val_loader = _make_loader(val_set, train=False)
test_loader = _make_loader(test_set, train=False)

# Same settings over the reduced subsets used for quick experiments.
small_train_loader = _make_loader(train_set_small, train=True)
small_val_loader = _make_loader(val_set_small, train=False)
small_test_loader = _make_loader(test_set_small, train=False)

In [None]:
import numpy as np

# Sanity check: a training batch should be roughly standardized along its last axis
# (assumes imgs is (batch, H, W, channels) with 3 channels — TODO confirm against image_to_numpy).
imgs, _ = next(iter(train_loader))
channel_mean = imgs.mean(axis=(0, 1, 2))
channel_std = imgs.std(axis=(0, 1, 2))
print("Batch mean", channel_mean)
print("Batch std", channel_std)
np.testing.assert_almost_equal(channel_mean, np.zeros(3), decimal=1)
np.testing.assert_almost_equal(channel_std, np.ones(3), decimal=1)

In [None]:
from PIL import Image
from utils import visualize_img

NUM_IMAGES = 4
# Untouched originals: raw training arrays with only the deterministic test-time conversion.
orig_images = [test_transform(Image.fromarray(train_dataset.data[i])) for i in range(NUM_IMAGES)]
# The same indices drawn through train_dataset, i.e. with the random training augmentation.
images = [train_dataset[i][0] for i in range(NUM_IMAGES)]

visualize_img(orig_images + images)

In [None]:
from train_modules import TrainerModule, TrainerModuleBatch
from network import GoogleNet
from flax import linen as nn

In [None]:
num_epochs = 2
# Baseline: GoogleNet with ReLU activations and AdamW, trained on the reduced (30%) splits.
googlenet_trainer = TrainerModuleBatch(
    model_class=GoogleNet,
    model_name="GoogleNet",
    model_hparams=dict(num_classes=10, act_fn=nn.relu),
    optimizer_name="adamw",
    checkpoint_dir=CHECKPOINT_PATH,
    optimizer_hparams=dict(lr=1e-3, weight_decay=1e-4),
    # One batch from the full training loader, passed as the example input.
    exmp_imgs=jax.device_put(next(iter(train_loader))[0]),
)
googlenet_trainer.train_model(small_train_loader, small_val_loader, num_epochs=num_epochs)
# Reload saved weights before evaluating (which checkpoint is up to TrainerModuleBatch).
googlenet_trainer.load_model()
val_acc = googlenet_trainer.eval_model(small_val_loader)
test_acc = googlenet_trainer.eval_model(small_test_loader)
val_acc, test_acc

In [None]:
num_epochs = 2
# GoogleNet with batch_norm=False. It uses the plain TrainerModule rather than
# TrainerModuleBatch, presumably because there are no batch statistics to track — confirm.
googlenet_trainer_layer = TrainerModule(
    model_class=GoogleNet,
    model_name="GoogleNet_layer",
    model_hparams=dict(num_classes=10, batch_norm=False, act_fn=nn.relu),
    optimizer_name="adamw",
    checkpoint_dir=CHECKPOINT_PATH,
    optimizer_hparams=dict(lr=1e-3, weight_decay=1e-4),
    # One batch from the full training loader, passed as the example input.
    exmp_imgs=jax.device_put(next(iter(train_loader))[0]),
)
googlenet_trainer_layer.train_model(small_train_loader, small_val_loader, num_epochs=num_epochs)
# Reload saved weights before evaluating (which checkpoint is up to TrainerModule).
googlenet_trainer_layer.load_model()
val_acc = googlenet_trainer_layer.eval_model(small_val_loader)
test_acc = googlenet_trainer_layer.eval_model(small_test_loader)
val_acc, test_acc

In [None]:
num_epochs = 2
# Same setup as the ReLU GoogleNet baseline, but with GELU activations.
googlenet_trainer_gelu = TrainerModuleBatch(
    model_class=GoogleNet,
    model_name="GoogleNet_gelu",
    model_hparams=dict(num_classes=10, act_fn=nn.gelu),
    optimizer_name="adamw",
    checkpoint_dir=CHECKPOINT_PATH,
    optimizer_hparams=dict(lr=1e-3, weight_decay=1e-4),
    # One batch from the full training loader, passed as the example input.
    exmp_imgs=jax.device_put(next(iter(train_loader))[0]),
)
googlenet_trainer_gelu.train_model(small_train_loader, small_val_loader, num_epochs=num_epochs)
# Reload saved weights before evaluating (which checkpoint is up to TrainerModuleBatch).
googlenet_trainer_gelu.load_model()
val_acc = googlenet_trainer_gelu.eval_model(small_val_loader)
test_acc = googlenet_trainer_gelu.eval_model(small_test_loader)
val_acc, test_acc

In [None]:
from network import ResNet, ResNetBlock, PreActResNetBlock

num_epochs = 2
# Plain ResNet: 3 stages of 3 ResNetBlocks (16/32/64 channels), SGD with momentum,
# trained on the reduced (30%) splits.
resnet_trainer = TrainerModuleBatch(
    model_class=ResNet,
    model_name="ResNet",
    model_hparams=dict(num_classes=10,
                       c_hidden=(16, 32, 64),
                       num_blocks=(3, 3, 3),
                       act_fn=nn.relu,
                       block_class=ResNetBlock),
    optimizer_name="SGD",
    checkpoint_dir=CHECKPOINT_PATH,
    optimizer_hparams=dict(lr=0.1, momentum=0.9, weight_decay=1e-4),
    # One batch from the full training loader, passed as the example input.
    exmp_imgs=jax.device_put(next(iter(train_loader))[0]),
)
resnet_trainer.train_model(small_train_loader, small_val_loader, num_epochs=num_epochs)
# Reload saved weights before evaluating (which checkpoint is up to TrainerModuleBatch).
resnet_trainer.load_model()
val_acc = resnet_trainer.eval_model(small_val_loader)
test_acc = resnet_trainer.eval_model(small_test_loader)
val_acc, test_acc

In [None]:
num_epochs = 2
# Pre-activation ResNet: same architecture sizes and SGD settings as the plain ResNet run,
# but built from PreActResNetBlock, trained on the reduced (30%) splits.
# model_name must be distinct from the plain run's "ResNet". It is passed to the trainer
# alongside checkpoint_dir, presumably to name this run's checkpoints. Reusing "ResNet"
# would mislabel this run and could overwrite, or reload, the plain ResNet's weights.
preactresnet_trainer = TrainerModuleBatch(model_class=ResNet,
                                model_name="PreActResNet",
                                model_hparams={"num_classes": 10,
                                                "c_hidden": (16, 32, 64),
                                                "num_blocks": (3, 3, 3),
                                                "act_fn": nn.relu,
                                                "block_class": PreActResNetBlock},
                                optimizer_name="SGD",
                                checkpoint_dir=CHECKPOINT_PATH,
                                optimizer_hparams={"lr": 0.1,
                                                    "momentum": 0.9,
                                                    "weight_decay": 1e-4},
                                # One batch from the full training loader, passed as the example input.
                                exmp_imgs=jax.device_put(next(iter(train_loader))[0]),)
preactresnet_trainer.train_model(small_train_loader, small_val_loader, num_epochs=num_epochs)
# Reload saved weights before evaluating (which checkpoint is up to TrainerModuleBatch).
preactresnet_trainer.load_model()
val_acc = preactresnet_trainer.eval_model(small_val_loader)
test_acc = preactresnet_trainer.eval_model(small_test_loader)
val_acc, test_acc

In [None]:
from network import DenseNet

num_epochs = 2
# DenseNet: four dense blocks of 6 layers each (growth rate 16, bottleneck factor 2),
# AdamW, trained on the reduced (30%) splits.
densenet_trainer = TrainerModuleBatch(
    model_class=DenseNet,
    model_name="DenseNet",
    model_hparams=dict(num_classes=10,
                       num_layers=[6, 6, 6, 6],
                       bn_size=2,
                       act_fn=nn.relu,
                       growth_rate=16),
    optimizer_name="adamw",
    checkpoint_dir=CHECKPOINT_PATH,
    optimizer_hparams=dict(lr=1e-3, weight_decay=1e-4),
    # One batch from the full training loader, passed as the example input.
    exmp_imgs=jax.device_put(next(iter(train_loader))[0]),
)
densenet_trainer.train_model(small_train_loader, small_val_loader, num_epochs=num_epochs)
# Reload saved weights before evaluating (which checkpoint is up to TrainerModuleBatch).
densenet_trainer.load_model()
val_acc = densenet_trainer.eval_model(small_val_loader)
test_acc = densenet_trainer.eval_model(small_test_loader)
val_acc, test_acc