In [None]:
#@title LICENSE
# Licensed under the Apache License, Version 2.0

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] ='0'

In [3]:
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

2023-06-30 03:10:34.348079: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
import torch
torch.cuda.is_available()

True

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/jaxpruner/blob/main/colabs/mnist_pruning.ipynb)

## Imports / Helpers

In [5]:
from absl import logging
import flax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import optax
import tensorflow as tf
from tqdm import tqdm
# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], "GPU")


logging.set_verbosity(logging.INFO)

In [6]:
!pip3 install git+https://github.com/google-research/jaxpruner
import jaxpruner
import ml_collections

Collecting git+https://github.com/google-research/jaxpruner
  Cloning https://github.com/google-research/jaxpruner to /tmp/pip-req-build-1gywb8jb
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/jaxpruner /tmp/pip-req-build-1gywb8jb
  Resolved https://github.com/google-research/jaxpruner to commit d92a781c7e7c55a06c8b88d5bcf22b51cf44a890
  Preparing metadata (setup.py) ... [?25ldone


## Dataset

In [7]:
DATASET_PATH = "./data"
train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0,1,2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0,1,2))
print("Data mean", DATA_MEANS)
print("Data std", DATA_STD)

Files already downloaded and verified
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]


In [8]:
def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img

# We need to stack the batch elements
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)


In [9]:
test_transform = image_to_numpy
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
                                      image_to_numpy
                                     ])
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))

Files already downloaded and verified
Files already downloaded and verified


In [10]:
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)
train_loader = data.DataLoader(train_set,
                               batch_size=128,
                               shuffle=True,
                               drop_last=True,
                               collate_fn=numpy_collate,
                               num_workers=8,
                               persistent_workers=True)
val_loader   = data.DataLoader(val_set,
                               batch_size=128,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=4,
                               persistent_workers=True)
test_loader  = data.DataLoader(test_set,
                               batch_size=128,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=4,
                               persistent_workers=True)

Files already downloaded and verified


In [11]:
# def get_datasets():
#   """Load MNIST train and test datasets into memory."""
#   ds_builder = tfds.builder('mnist')
#   ds_builder.download_and_prepare()
#   train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
#   test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
#   train_ds['image'] = jnp.float32(train_ds['image']) / 255.
#   test_ds['image'] = jnp.float32(test_ds['image']) / 255.
#   return train_ds, test_ds

# # Helper functions for images.

# def show_img(img, ax=None, title=None):
#   """Shows a single image."""
#   if ax is None:
#     ax = plt.gca()
#   ax.imshow(img[..., 0], cmap='gray')
#   ax.set_xticks([])
#   ax.set_yticks([])
#   if title:
#     ax.set_title(title)

# def show_img_grid(imgs, titles):
#   """Shows a grid of images."""
#   n = int(np.ceil(len(imgs)**.5))
#   _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
#   for i, (img, title) in enumerate(zip(imgs, titles)):
#     show_img(img, axs[i // n][i % n], title)

In [None]:
# Get datasets as dict of JAX arrays.
train_ds, test_ds = get_datasets()

## Training

In [12]:
import ml_collections

# class CNN(nn.Module):
#   """A simple CNN model."""

#   @nn.compact
#   def __call__(self, x):
#     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
#     x = nn.relu(x)
#     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
#     x = nn.Conv(features=64, kernel_size=(3, 3))(x)
#     x = nn.relu(x)
#     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
#     x = x.reshape((x.shape[0], -1))  # flatten
#     x = nn.Dense(features=256)(x)
#     x = nn.relu(x)
#     x = nn.Dense(features=10)(x)
#     return x
from resnet import ResNet20 as res20


@jax.jit
def apply_model_train(state, images, labels):
  """Computes gradients, loss and accuracy for a single batch."""
  def loss_fn(params):
  
    logits,updates = state.apply_fn({'params': params, 'batch_stats': state.batch_stats}, images, train = True, mutable = ['batch_stats'])
    
    one_hot = jax.nn.one_hot(labels, 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, (logits, updates['batch_stats'])

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, new_batch_stats)), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy, new_batch_stats
@jax.jit
def apply_model_test(state, images, labels,batch_stats):
  """Computes gradients, loss and accuracy for a single batch."""
  def loss_fn(params):
  
    logits = state.apply_fn({'params': params, 'batch_stats': batch_stats}, images, train = False)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy


@jax.jit
def update_model(state, grads, batch_stats):
  # updates, new_opt_state = state.tx.update(
  #     grads, state.opt_state, state.params)
  # new_params = optax.apply_updates(state.params, updates)
  state = state.apply_gradients(grads=grads, batch_stats=batch_stats)
  return state
def test_epoch(state,testloader, batch_stats):
  epoch_loss = []
  epoch_accuracy = []
  for batch in tqdm(testloader,desc = 'Testing' , leave = False):
    batch_images = batch[0]
    batch_labels = batch[1]
    _,loss,accuracy = apply_model_test(state,batch_images, batch_labels, batch_stats = batch_stats)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  return np.mean(epoch_loss), np.mean(epoch_accuracy)

def train_epoch(state, trainloader, sparsity_updater,batch_stats):
  """Train for a single epoch."""
  # train_ds_size = len(train_ds['image'])
  # steps_per_epoch = train_ds_size // batch_size

  # perms = jax.random.permutation(rng, len(train_ds['image']))
  # perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  # perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []
  is_ste = isinstance(sparsity_updater, (jaxpruner.SteMagnitudePruning,
                                         jaxpruner.SteRandomPruning))
  pre_op = jax.jit(sparsity_updater.pre_forward_update)
  # for i, perm in enumerate(perms):
  i = 1
  for batch in tqdm(trainloader, desc = 'Training' , leave = False):
    i+=1;
    batch_images = batch[0]
    batch_labels = batch[1]
    # Following is only needed for STE.
    new_params = pre_op(state.params, state.opt_state[2])
    forward_state = state.replace(params=new_params)

    grads, loss, accuracy, batch_stats = apply_model_train(forward_state, batch_images,
                                        batch_labels)
    state = update_model(state, grads, batch_stats)
    post_params = sparsity_updater.post_gradient_update(
        state.params, state.opt_state[2])
    state = state.replace(params=post_params)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
    if i % 100 == 0:
      if is_ste:
        print(jaxpruner.summarize_sparsity(
            new_params, only_total_sparsity=True))
      else:
        print(jaxpruner.summarize_sparsity(
            state.params, only_total_sparsity=True))
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy


def create_train_state(rng, config, num_epochs, num_steps_per_epoch):
  """Creates initial `TrainState`."""
  cnn = res20()
  variables = cnn.init(rng, jnp.ones([1, 32,32, 3]),train = False)
  params = variables['params']
  batch_stats = variables['batch_stats']
  sparsity_updater = jaxpruner.create_updater_from_config(config.sparsity_config)
  # tx = optax.adam(config.learning_rate, config.momentum)
  # tx = sparsity_updater.wrap_optax(tx)
  opt_class = optax.sgd
  lr_schedule = optax.piecewise_constant_schedule(
            init_value=config.learning_rate,
            boundaries_and_scales=
                {int(num_steps_per_epoch*num_epochs*0.6): 0.1,
                 int(num_steps_per_epoch*num_epochs*0.85): 0.1}
        )
        # Clip gradients at max value, and evt. apply weight decay
  transf = [optax.clip(1.0)]
  hparam_dict = { "momentum" : config.momentum, "weight_decay" : 1e-4}
  if opt_class == optax.sgd :  # wd is integrated in adamw
            transf.append(optax.add_decayed_weights(1e-4))
  optimizer = optax.chain(
            *transf,
            sparsity_updater.wrap_optax(opt_class(lr_schedule))
        )
        # Initialize training state
  # self.state = TrainState.create(apply_fn=self.model.apply,
  #                                      params=self.init_params if self.state is None else self.state.params,
  #                                      batch_stats=self.init_batch_stats if self.state is None else self.state.batch_stats,
  #                                      tx=optimizer)


  return TrainState.create(
      apply_fn=cnn.apply, params=params, tx=optimizer, batch_stats = batch_stats), sparsity_updater


import time

def train_and_evaluate(trainloader,testloader,config: ml_collections.ConfigDict
                       ) -> train_state.TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
  # train_ds, test_ds = get_datasets()
  rng = jax.random.PRNGKey(0)

  rng, init_rng = jax.random.split(rng)

  state, sparsity_updater= create_train_state(init_rng, config, config.num_epochs, len(trainloader))

  # for epoch in range(1, config.num_epochs + 1):
  for epoch in tqdm(range(1,config.num_epochs+1)):
    s_time = time.time()
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(state, trainloader,
                                                    sparsity_updater)
    # Following is only needed for STE.
    new_params = sparsity_updater.pre_forward_update(
        state.params, state.opt_state[2])
    forward_state = state.replace(params=new_params)
      
    test_loss, test_accuracy = test_epoch(forward_state, testloader,batch_stats)

    print(
        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f, time: %.2f'
        % (epoch, train_loss, train_accuracy * 100, test_loss,
           test_accuracy * 100, time.time() - s_time))

  return state

# Jaxpruner API

In [13]:
jaxpruner.ALGORITHMS

('no_prune',
 'magnitude',
 'random',
 'saliency',
 'magnitude_ste',
 'random_ste',
 'global_magnitude',
 'global_saliency',
 'static_sparse',
 'rigl',
 'set')

In [14]:
config = ml_collections.ConfigDict()

config.learning_rate = 0.01
config.momentum = 0.9
config.batch_size = 128
config.num_epochs = 200 # 1 epoch is 468 steps for bs=128

config.sparsity_config = ml_collections.ConfigDict()
config.sparsity_config.algorithm = 'rigl'
config.sparsity_config.update_freq = 10
config.sparsity_config.update_end_step = 1000
config.sparsity_config.update_start_step = 200
config.sparsity_config.sparsity = 0.5
config.sparsity_config.dist_type = 'erk'

In [15]:
jaxpruner.create_updater_from_config(config.sparsity_config)

RigL(scheduler=PolynomialSchedule(update_freq=10, update_start_step=200, update_end_step=1000, power=3), skip_gradients=True, is_sparse_gradients=True, sparsity_type=Unstructured(), sparsity_distribution_fn=functools.partial(<function erk at 0x7f5ee129ca60>, sparsity=0.5), rng_seed=Array([0, 8], dtype=uint32), use_packed_masks=False, eps=1e-05, drop_fraction_fn=<function cosine_decay_schedule.<locals>.schedule at 0x7f6086f17eb0>, is_debug=False)

In [16]:
state = train_and_evaluate(train_loader,test_loader,config)

TypeError: TrainState.__init__() got an unexpected keyword argument 'batch_stats'

# One Shot pruning

In [None]:
config.sparsity_config.update_start_step = 0
config.sparsity_config.update_end_step = 0
config.sparsity_config.skip_gradients=True
state = train_and_evaluate(config)

# STE

In [None]:
# STE also supports gradual pruning schedules. 
# Here we train weights with sparse forward pass from the start.
config.sparsity_config.algorithm = 'magnitude_ste'
config.sparsity_config.sparsity = 0.95
config.sparsity_config.update_end_step = 0
config.sparsity_config.update_start_step = 0
config.sparsity_config.dist_type = 'erk'
state = train_and_evaluate(config)

# Global Pruning

In [None]:
config.sparsity_config.algorithm = 'global_magnitude'
config.sparsity_config.update_freq = 10
config.sparsity_config.update_end_step = 1000
config.sparsity_config.update_start_step = 200
config.sparsity_config.sparsity = 0.95
config.sparsity_config.dist_type = 'erk'

state = train_and_evaluate(config)

In [None]:
jaxpruner.summarize_sparsity(state.opt_state.masks)

# Dynamic Sparse Training

In [None]:
config.sparsity_config.algorithm = 'rigl'

config.sparsity_config.update_freq = 10
config.sparsity_config.update_end_step = 1000
config.sparsity_config.update_start_step = 1
config.sparsity_config.sparsity = 0.95
config.sparsity_config.dist_type = 'erk'

state = train_and_evaluate(config)

In [None]:
config.sparsity_config.algorithm = 'set'
state = train_and_evaluate(config)

In [None]:
config.sparsity_config.algorithm = 'static_sparse'

state = train_and_evaluate(config)

# Pruning After Training

In [None]:
config.sparsity_config.algorithm = 'no_prune'
state = train_and_evaluate(config)

In [None]:
config.sparsity_config.algorithm = 'magnitude'
config.sparsity_config.sparsity = 0.9
sparsity_updater = jaxpruner.create_updater_from_config(config.sparsity_config)
pruned_params, _ = sparsity_updater.instant_sparsify(state.params)
print(jaxpruner.summarize_sparsity(pruned_params, only_total_sparsity=True))

In [None]:
_, test_ds = get_datasets()
pruned_state = state.replace(params=pruned_params)
_, _, test_accuracy  = apply_model(pruned_state, test_ds['image'], test_ds['label'])
print(test_accuracy*100)

# N:M sparsity

In [None]:
config.sparsity_config.algorithm = 'magnitude'
config.sparsity_config.sparsity_type = 'nm_1,4'
sparsity_updater = jaxpruner.create_updater_from_config(config.sparsity_config)
pruned_params, masks = sparsity_updater.instant_sparsify(state.params)
print(jaxpruner.summarize_sparsity(pruned_params, only_total_sparsity=True))

In [None]:
masks['Dense_0']['kernel'][0][:16]

In [None]:
_, test_ds = get_datasets()
pruned_state = state.replace(params=pruned_params)
_, _, test_accuracy  = apply_model(pruned_state, test_ds['image'], test_ds['label'])
print(test_accuracy*100)

# Block Sparsity

In [None]:
config.sparsity_config.algorithm = 'magnitude'
config.sparsity_config.sparsity = 0.7
config.sparsity_config.sparsity_type = 'block_2,2'
sparsity_updater = jaxpruner.create_updater_from_config(config.sparsity_config)
pruned_params, masks = sparsity_updater.instant_sparsify(state.params)
print(jaxpruner.summarize_sparsity(pruned_params, only_total_sparsity=True))

In [None]:
masks['Dense_0']['kernel'][:4, :16]

In [None]:
_, test_ds = get_datasets()
pruned_state = state.replace(params=pruned_params)
_, _, test_accuracy  = apply_model(pruned_state, test_ds['image'], test_ds['label'])
print(test_accuracy*100)

In [11]:
from typing import Any
from collections import defaultdict
from flax.training import train_state, checkpoints
import time



In [None]:
class TrainerModule:

    def __init__(self,
                 model_name : str,
                 model_class : nn.Module,
                 model_hparams : dict,
                 optimizer_name : str,
                 optimizer_hparams : dict,
                 exmp_imgs : Any,
                 seed=42):
        """
        Module for summarizing all training functionalities for classification on CIFAR10.

        Inputs:
            model_name - String of the class name, used for logging and saving
            model_class - Class implementing the neural network
            model_hparams - Hyperparameters of the model, used as input to model constructor
            optimizer_name - String of the optimizer name, supporting ['sgd', 'adam', 'adamw']
            optimizer_hparams - Hyperparameters of the optimizer, including learning rate as 'lr'
            exmp_imgs - Example imgs, used as input to initialize the model
            seed - Seed to use in the model initialization
        """
        super().__init__()
        self.model_name = model_name
        self.model_class = model_class
        self.model_hparams = model_hparams
        self.optimizer_name = optimizer_name
        self.optimizer_hparams = optimizer_hparams
        self.seed = seed
        # Create empty model. Note: no parameters yet
        self.model = res20()
        # Prepare logging
        self.log_dir = os.path.join('./', self.model_name)
        self.logger = SummaryWriter(log_dir=self.log_dir)
        # Create jitted training and eval functions
        self.create_functions()
        # Initialize model
        self.init_model(exmp_imgs)

    def create_functions(self):
        # Function to calculate the classification loss and accuracy for a model
        def calculate_loss(params, batch_stats, batch, train):
            imgs, labels = batch
            # Run model. During training, we need to update the BatchNorm statistics.
            outs = self.model.apply({'params': params, 'batch_stats': batch_stats},
                                    imgs,
                                    train=train,
                                    mutable=['batch_stats'] if train else False)
            logits, new_model_state = outs if train else (outs, None)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
            acc = (logits.argmax(axis=-1) == labels).mean()
            return loss, (acc, new_model_state)
        # Training function
        def train_step(state, batch):
            loss_fn = lambda params: calculate_loss(params, state.batch_stats, batch, train=True)
            # Get loss, gradients for loss, and other outputs of loss function
            ret, grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
            loss, acc, new_model_state = ret[0], *ret[1]
            # Update parameters and batch statistics
            state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
            return state, loss, acc
        # Eval function
        def eval_step(state, batch):
            # Return the accuracy for a single batch
            _, (acc, _) = calculate_loss(state.params, state.batch_stats, batch, train=False)
            return acc
        # jit for efficiency
        self.train_step = jax.jit(train_step)
        self.eval_step = jax.jit(eval_step)

    def init_model(self, exmp_imgs):
        # Initialize model
        init_rng = jax.random.PRNGKey(self.seed)
        variables = self.model.init(init_rng, exmp_imgs, train=True)
        self.init_params, self.init_batch_stats = variables['params'], variables['batch_stats']
        self.state = None

    def init_optimizer(self, num_epochs, num_steps_per_epoch):
        # Initialize learning rate schedule and optimizer
        if self.optimizer_name.lower() == 'adam':
            opt_class = optax.adam
        elif self.optimizer_name.lower() == 'adamw':
            opt_class = optax.adamw
        elif self.optimizer_name.lower() == 'sgd':
            opt_class = optax.sgd
        else:
            assert False, f'Unknown optimizer "{opt_class}"'
        # We decrease the learning rate by a factor of 0.1 after 60% and 85% of the training
        lr_schedule = optax.piecewise_constant_schedule(
            init_value=self.optimizer_hparams.pop('lr'),
            boundaries_and_scales=
                {int(num_steps_per_epoch*num_epochs*0.6): 0.1,
                 int(num_steps_per_epoch*num_epochs*0.85): 0.1}
        )
        # Clip gradients at max value, and evt. apply weight decay
        # transf = [optax.clip(1.0)]
        if opt_class == optax.sgd and 'weight_decay' in self.optimizer_hparams:  # wd is integrated in adamw
            transf.append(optax.add_decayed_weights(self.optimizer_hparams.pop('weight_decay')))
        optimizer = optax.chain(
            *transf,
            opt_class(lr_schedule, **self.optimizer_hparams)
        )
        # Initialize training state
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=self.init_params if self.state is None else self.state.params,
                                       batch_stats=self.init_batch_stats if self.state is None else self.state.batch_stats,
                                       tx=optimizer)

    def train_model(self, train_loader, val_loader, num_epochs=200):
        # Train model for defined number of epochs
        # We first need to create optimizer and the scheduler for the given number of epochs
        self.init_optimizer(num_epochs, len(train_loader))
        # Track best eval accuracy
        best_eval = 0.0
        for epoch_idx in tqdm(range(1, num_epochs+1)):
            self.train_epoch(train_loader, epoch=epoch_idx)
            if epoch_idx % 2 == 0:
                eval_acc = self.eval_model(val_loader)
                self.logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
                if eval_acc >= best_eval:
                    best_eval = eval_acc
                    self.save_model(step=epoch_idx)
                self.logger.flush()

    def train_epoch(self, train_loader, epoch):
        # Train model for one epoch, and log avg loss and accuracy
        metrics = defaultdict(list)
        for batch in tqdm(train_loader, desc='Training', leave=False):
            self.state, loss, acc = self.train_step(self.state, batch)
            metrics['loss'].append(loss)
            metrics['acc'].append(acc)
        for key in metrics:
            avg_val = np.stack(jax.device_get(metrics[key])).mean()
            self.logger.add_scalar('train/'+key, avg_val, global_step=epoch)

    def eval_model(self, data_loader):
        # Test model on all images of a data loader and return avg loss
        correct_class, count = 0, 0
        for batch in data_loader:
            acc = self.eval_step(self.state, batch)
            correct_class += acc * batch[0].shape[0]
            count += batch[0].shape[0]
        eval_acc = (correct_class / count).item()
        return eval_acc

    def save_model(self, step=0):
        # Save current model at certain training iteration
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir,
                                    target={'params': self.state.params,
                                            'batch_stats': self.state.batch_stats},
                                    step=step,
                                   overwrite=True)

    def load_model(self, pretrained=False):
        # Load model. We use different checkpoint for pretrained models
        if not pretrained:
            state_dict = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir, target=None)
        else:
            state_dict = checkpoints.restore_checkpoint(ckpt_dir=os.path.join('./', f'{self.model_name}.ckpt'), target=None)
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=state_dict['params'],
                                       batch_stats=state_dict['batch_stats'],
                                       tx=self.state.tx if self.state else optax.sgd(0.1)   # Default optimizer
                                      )

    def checkpoint_exists(self):
        # Check whether a pretrained model exist for this autoencoder
        return os.path.isfile(os.path.join('./', f'{self.model_name}.ckpt'))

In [None]:
class TrainState(train_state.TrainState):
    # A simple extension of TrainState to also include batch statistics
    batch_stats: Any

In [None]:
def train_classifier(*args, num_epochs=200, **kwargs):
    # Create a trainer module with specified hyperparameters
    trainer = TrainerModule(*args, **kwargs)
    if not trainer.checkpoint_exists():  # Skip training if pretrained model exists
        trainer.train_model(train_loader, val_loader, num_epochs=num_epochs)
        trainer.load_model()
    else:
        trainer.load_model(pretrained=True)
    # Test trained model
    val_acc = trainer.eval_model(val_loader)
    test_acc = trainer.eval_model(test_loader)
    return trainer, {'val': val_acc, 'test': test_acc}

In [None]:
resnet_kernel_init = nn.initializers.variance_scaling(2.0, mode='fan_out', distribution='normal')

class ResNetBlock(nn.Module):
    act_fn : callable  # Activation function
    c_out : int   # Output feature size
    subsample : bool = False  # If True, we apply a stride inside F

    @nn.compact
    def __call__(self, x, train=True):
        # Network representing F
        z = nn.Conv(self.c_out, kernel_size=(3, 3),
                    strides=(1, 1) if not self.subsample else (2, 2),
                    kernel_init=resnet_kernel_init,
                    use_bias=False)(x)
        z = nn.BatchNorm()(z, use_running_average=not train)
        z = self.act_fn(z)
        z = nn.Conv(self.c_out, kernel_size=(3, 3),
                    kernel_init=resnet_kernel_init,
                    use_bias=False)(z)
        z = nn.BatchNorm()(z, use_running_average=not train)

        if self.subsample:
            x = nn.Conv(self.c_out, kernel_size=(1, 1), strides=(2, 2), kernel_init=resnet_kernel_init)(x)

        x_out = self.act_fn(z + x)
        return x_out

In [None]:
class ResNet(nn.Module):
    num_classes : int
    act_fn : callable
    block_class : nn.Module
    num_blocks : tuple = (3, 3, 3)
    c_hidden : tuple = (16, 32, 64)

    @nn.compact
    def __call__(self, x, train=True):
        # A first convolution on the original image to scale up the channel size
        x = nn.Conv(self.c_hidden[0], kernel_size=(3, 3), kernel_init=resnet_kernel_init, use_bias=False)(x)
        if self.block_class == ResNetBlock:  # If pre-activation block, we do not apply non-linearities yet
            x = nn.BatchNorm()(x, use_running_average=not train)
            x = self.act_fn(x)

        # Creating the ResNet blocks
        for block_idx, block_count in enumerate(self.num_blocks):
            for bc in range(block_count):
                # Subsample the first block of each group, except the very first one.
                subsample = (bc == 0 and block_idx > 0)
                # ResNet block
                x = self.block_class(c_out=self.c_hidden[block_idx],
                                     act_fn=self.act_fn,
                                     subsample=subsample)(x, train=train)

        # Mapping to classification output
        x = x.mean(axis=(1, 2))
        x = nn.Dense(self.num_classes)(x)
        return x

In [None]:
resnet_trainer, resnet_results = train_classifier(model_name="ResNet",
                                                  model_class=ResNet,
                                                  model_hparams={"num_classes": 10,
                                                                 "c_hidden": (16, 32, 64),
                                                                 "num_blocks": (3, 3, 3),
                                                                 "act_fn": nn.relu,
                                                                 "block_class": ResNetBlock},
                                                  optimizer_name="SGD",
                                                  optimizer_hparams={"lr": 0.1,
                                                                     "momentum": 0.9,
                                                                     "weight_decay": 1e-4},
                                                  exmp_imgs=jax.device_put(
                                                      next(iter(train_loader))[0]),
                                                  num_epochs=200)


In [None]:
resnet_results