# Introduction

In this notebook, we will be training an image classifier on [the Oxford-IIIT Pet dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/) using flaim. The reader is assumed to be proficient at deep learning and familiar with JAX, Flax, and Optax, although this notebook might also serve as a helpful guide for those new to the JAX/Flax/Optax ecosystem who are seeking a basic training script free of more advanced ingredients like distributed or mixed-precision training.

This notebook was run in Colab on a T4 GPU. The Colab notebook can be accessed [here](https://colab.research.google.com/drive/1U02GNjWUUmjLt9gTHvxOikYvGdA5qFIz?usp=share_link).

_If you run out of memory during training, restart the kernel and re-run the notebook, skipping every training cell that precedes the one where the out-of-memory error occurred._

# Dependencies

Colab already comes with popular deep learning frameworks like PyTorch and JAX, so flaim is the only library that needs to be installed to train our Flax image classifier. We will also be re-implementing our code in PyTorch + timm for comparison and will therefore need to install timm as well.

In [None]:
%%capture
!pip install git+https://github.com/BobMcDear/flaim.git
!pip install git+https://github.com/rwightman/pytorch-image-models.git

In [None]:
import typing as T

# Flax + flaim training
import flaim
import jax
import optax
from flax.core.frozen_dict import FrozenDict
from flax.training.train_state import TrainState
from jax import numpy as jnp

# PyTorch + timm training
import timm
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, _LRScheduler

# Data loading
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet

# Data

TensorFlow-based input pipelines tend to be more efficient than their PyTorch counterparts, especially in multi-worker settings, and it is recommended that JAX developers process data using TensorFlow. For the purpose of learning flaim though, either would do, and we shall opt for PyTorch data loaders thanks to their simplicity.

Pre-processing will be kept to a minimum and consists of random resized cropping for the training set and resizing followed by a center crop for the validation set, in addition to converting the images to tensors and normalizing them. One caveat is that the input data must be channels-last JAX or NumPy arrays for Flax, so torchvision cannot be wholly relied upon. Instead, we'll write ```img_to_numpy``` for converting images to NumPy arrays, ```NumPyNormalize``` for normalizing them, and ```numpy_collate``` to be used as the collate function for the data loaders. Alternatively, we could've developed the data pipeline in pure PyTorch and converted the inputs and targets to JAX arrays during training on the fly.
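The alternative mentioned above, keeping the pipeline in pure PyTorch and converting on the fly, mostly amounts to an axis permutation. Below is a minimal sketch, assuming the batch arrives in channels-first (N, C, H, W) layout; a NumPy array stands in for the PyTorch tensor, and `to_channels_last` is a hypothetical helper rather than part of flaim or this notebook.

```python
import numpy as np
from jax import numpy as jnp


def to_channels_last(batch: np.ndarray) -> jnp.ndarray:
  """Converts an (N, C, H, W) batch to an (N, H, W, C) JAX array."""
  return jnp.asarray(np.transpose(batch, (0, 2, 3, 1)))


# Stand-in for a batch produced by a channels-first PyTorch pipeline
nchw_batch = np.zeros((8, 3, 224, 224), dtype=np.float32)
nhwc_batch = to_channels_last(nchw_batch)
print(nhwc_batch.shape)  # (8, 224, 224, 3)
```

With a real CPU tensor `x`, the NumPy array would typically be obtained through `x.numpy()` before the transpose.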

In [None]:
def img_to_numpy(img) -> np.ndarray:
  """
  Converts an image to a float NumPy array with range [0, 1].

  Args:
    img: Image to convert.
  
  Returns (np.ndarray): Image converted to a float NumPy array
  with range [0, 1]. 
  """
  return np.asarray(img)/255.
  

class NumPyNormalize:
  """
  Normalizes a NumPy array along the last axis.

  Args:
    mean (T.Tuple[float, ...]): Mean for normalization.
    Default is (0.485, 0.456, 0.406).
    std (T.Tuple[float, ...]): Standard deviation for normalization.
    Default is (0.229, 0.224, 0.225).
  """
  def __init__(
      self,
      mean: T.Tuple[float, ...] = (0.485, 0.456, 0.406),
      std: T.Tuple[float, ...] = (0.229, 0.224, 0.225),
      ) -> None:
    self.mean = np.array(mean)
    self.std = np.array(std)
    
  def __call__(self, input: np.ndarray) -> np.ndarray:
    return (input - self.mean) / self.std


def numpy_collate(
    batch: T.List[T.Tuple[np.ndarray, int]],
    ) -> T.List[np.ndarray]:
  """
  Collates an input batch using NumPy for a PyTorch data loader.

  Args:
    batch (T.List[T.Tuple[np.ndarray, int]]): Batch of samples to collate.
  
  Returns (T.List[np.ndarray]): Collated batch.
  """
  transposed_batch = list(zip(*batch))
  return [np.stack(sample) for sample in transposed_batch]
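As a quick sanity check of the three helpers above, the following sketch pushes a few fake images through the same steps: scaling to [0, 1], normalizing along the last axis, and stacking. The 8x8 random images and their labels are made up for illustration; real inputs would come from the dataset.

```python
import numpy as np

mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

# Four fake RGB images in channels-last uint8 layout, paired with integer labels
rng = np.random.default_rng(0)
samples = [
    (rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8), label)
    for label in range(4)
    ]

# Same arithmetic as img_to_numpy followed by NumPyNormalize;
# mean and std broadcast over the trailing channel axis
processed = [((img / 255. - mean) / std, label) for img, label in samples]

# Same logic as numpy_collate: regroup by field, then stack each field
inputs, targets = [np.stack(field) for field in zip(*processed)]
print(inputs.shape, targets.shape)  # (4, 8, 8, 3) (4,)
```

Note that the inputs come out as float64 arrays; `jnp.array` converts them to float32 under JAX's default configuration.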

```get_pets_dls``` fetches the data loaders for us.

In [None]:
def get_pets_dls(
    root: str = '.',
    val_resize: int = 256,
    size: int = 224,
    numpy: bool = True,
    bs: int = 64,
    norm_mean: T.Tuple[float, ...] = (0.485, 0.456, 0.406),
    norm_std: T.Tuple[float, ...] = (0.229, 0.224, 0.225),
    ) -> T.Tuple[DataLoader, DataLoader]:
  """
  Gets training and validation PyTorch data loaders for the Oxford-IIIT Pets dataset.

  Args:
    root (str): Root directory for storing the dataset.
    Default is '.'
    val_resize (int): Size to which the validation set is resized to before
    being center cropped.
    Default is 256.
    size (int): Random resized crop size for the training set and center
    crop size for the validation set.
    Default is 224.
    numpy (bool): Whether the data should be returned as NumPy channels-last
    arrays. If False, it is returned as channels-first PyTorch tensors.
    Default is True.
    bs (int): Batch size.
    Default is 64.
    norm_mean (T.Tuple[float, ...]): Mean for normalization.
    Default is (0.485, 0.456, 0.406).
    norm_std (T.Tuple[float, ...]): Standard deviation for normalization.
    Default is (0.229, 0.224, 0.225).
  
  Returns (T.Tuple[DataLoader, DataLoader]): Training and validation data loaders.
  """
  if numpy:
    to_tensor_and_normalize = (img_to_numpy, NumPyNormalize(norm_mean, norm_std))
    collate_fn = numpy_collate
  
  else:
    to_tensor_and_normalize = (transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std))
    collate_fn = None

  train_ds = OxfordIIITPet(
      root=root,
      split='trainval',
      transform=transforms.Compose([
          transforms.RandomResizedCrop(size),
          *to_tensor_and_normalize,
          ]),
      download=True,
      )
  valid_ds = OxfordIIITPet(
      root=root,
      split='test',
      transform=transforms.Compose([
          transforms.Resize(val_resize),
          transforms.CenterCrop(size),
          *to_tensor_and_normalize,
          ]),
      download=True,
      )

  train_dl = DataLoader(
      dataset=train_ds,
      batch_size=bs,
      shuffle=True,
      collate_fn=collate_fn,
      drop_last=True,
      )
  valid_dl = DataLoader(
      dataset=valid_ds,
      batch_size=bs,
      collate_fn=collate_fn,
      drop_last=True,
      )
  
  return train_dl, valid_dl

# Training (Flax + flaim)

We are ready to train our classifier, which will be based on [ConvNeXt](https://arxiv.org/abs/2201.03545), a convolutional neural network (CNN) that borrows ideas from the transformer literature for state-of-the-art visual recognition. To construct models, flaim provides [```flaim.get_model```](https://github.com/bobmcdear/flaim#usage), which returns a model, its parameters, and optionally the corresponding normalization statistics. The training process is more or less identical for other networks, one major exception being models that incorporate batch normalization (BN), which will be studied later.

In [None]:
model, vars, norm_stats = flaim.get_model(
    model_name='convnext_small',
    pretrained='in1k_224',
    n_classes=37, # The number of breeds in the Pets dataset
    )

Using ```norm_stats```, we can create our data loaders.

In [None]:
%%capture
train_dl, valid_dl = get_pets_dls(
    norm_mean=norm_stats['mean'],
    norm_std=norm_stats['std'],
    )

Next, we set up an AdamW optimizer with a cosine decay learning rate scheduler using Optax. The scheduler requires the number of training iterations, which in turn depends on the number of epochs - here, we will train for 5 epochs.

In [None]:
n_epochs = 5
lr = 6e-4
wd = 1e-2

lr_scheduler = optax.cosine_decay_schedule(
  init_value=lr,
  decay_steps=n_epochs*len(train_dl), # Number of training iterations
)
optim = optax.adamw(
    learning_rate=lr_scheduler,
    weight_decay=wd,
    )
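To make the shape of the schedule concrete, here is a pure-Python sketch of cosine decay to zero. It is meant to mirror what `optax.cosine_decay_schedule` computes with its default settings as far as I understand them, but it is an illustration, not Optax's implementation; `cosine_decay` and the figure of 50 iterations per epoch are hypothetical.

```python
import math


def cosine_decay(step: int, init_value: float, decay_steps: int) -> float:
  """Learning rate at a given step under cosine decay to zero."""
  step = min(step, decay_steps)  # The rate stays at its final value afterwards
  return init_value * 0.5 * (1 + math.cos(math.pi * step / decay_steps))


init_lr = 6e-4
total_steps = 5 * 50  # 5 epochs of a hypothetical 50 iterations each

# Start, midpoint, and end of training
for step in (0, total_steps // 2, total_steps):
  print(step, cosine_decay(step, init_lr, total_steps))
```

The rate starts at `init_lr`, passes through half of it at the midpoint, and reaches zero on the final iteration, which is why the scheduler needs the total number of iterations up front.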

Currently, the three major elements of training - the model, parameters, and optimizer - are completely decoupled, so we should bunch them together using Flax's ```TrainState``` class.

In [None]:
state = TrainState.create(
    apply_fn=model.apply,
    params=vars,
    tx=optim,
    )

Finally, training can commence. The central component of our training script is ```train_iter```, a function that receives ```state``` and a batch of inputs and targets, calculates the model's loss, and updates the parameters. ```train_iter``` itself contains an inner function, ```get_loss```, that only computes the cross-entropy loss, and JAX's ```value_and_grad``` automatically evaluates the gradients for us. Also, we just-in-time (JIT) compile our code for a substantial performance boost.

In [None]:
@jax.jit
def train_iter(
    state: TrainState,
    input,
    target,
    ) -> T.Tuple[TrainState, float]:
  """
  Calculates the model's loss on the current batch
  and updates its parameters.

  Args:
    state (TrainState): State.
    input: Input.
    target: Target.
  
  Returns (T.Tuple[TrainState, float]): Updated state and loss.
  """
  def get_loss(vars):
    output = state.apply_fn(vars, input)
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(output, target))
    return loss
  
  loss, grads = jax.value_and_grad(get_loss)(state.params)
  return state.apply_gradients(grads=grads), loss
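The same pattern, an inner loss function wrapped in `value_and_grad` inside a jitted step, can be seen in isolation on a toy linear classifier. This is only a sketch: `cross_entropy` and `sgd_step` are hypothetical names, plain gradient descent replaces AdamW, and the random data is made up.

```python
import jax
from jax import numpy as jnp

LR = 0.1  # Step size for plain gradient descent


def cross_entropy(logits, labels):
  """Mean cross-entropy between logits and integer class labels."""
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
  return -jnp.mean(picked)


@jax.jit
def sgd_step(weights, inputs, labels):
  def get_loss(w):
    return cross_entropy(inputs @ w, labels)

  # Returns the loss and its gradient with respect to the weights
  loss, grads = jax.value_and_grad(get_loss)(weights)
  return weights - LR * grads, loss


inputs = jax.random.normal(jax.random.PRNGKey(0), (16, 4))
labels = jnp.arange(16) % 3
weights = jnp.zeros((4, 3))

weights, first_loss = sgd_step(weights, inputs, labels)
_, second_loss = sgd_step(weights, inputs, labels)
print(first_loss, second_loss)
```

With all-zero weights the first loss equals log(3), the cross-entropy of a uniform prediction over three classes, and the second call reports a lower loss after one update.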

A similar function is necessary for validation. It differs from ```train_iter``` in that it does not update the parameters and also returns accuracy.

In [None]:
@jax.jit
def valid_iter(
    state: TrainState,
    input,
    target,
    ) -> T.Tuple[float, float]:
  """
  Calculates the model's loss and accuracy on the current batch.

  Args:
    state (TrainState): State.
    input: Input.
    target: Target.
  
  Returns (T.Tuple[float, float]): Loss and accuracy.
  """
  output = state.apply_fn(state.params, input)
  loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(output, target))
  accuracy = jnp.mean(jnp.argmax(output, axis=-1) == target)
  return loss, accuracy

The remainder is standard deep learning - a training epoch consists of iterating through each batch and updating the model, after which the model is validated on the validation set.

In [None]:
def train_epoch(
    state: TrainState,
    train_dl: DataLoader,
    ) -> T.Tuple[TrainState, float]:
  """
  Performs one training epoch.

  Args:
    state (TrainState): State.
    train_dl (DataLoader): Training data loader.
  
  Returns (T.Tuple[TrainState, float]): Updated state and mean loss over the epoch.
  """
  n_samples = 0
  loss = 0

  for ind, (input, target) in enumerate(train_dl):
    if ind%10 == 0:
      print('\r', end='')
      print(f'Training iteration {ind}/{len(train_dl)}', end='')

    # jnp.array transfers the data to the GPU automatically
    input, target = jnp.array(input), jnp.array(target)
    state, curr_loss = train_iter(state, input, target)

    n_samples += len(input)
    loss += len(input) * curr_loss
  
  return state, loss/n_samples


def validate(
    state: TrainState,
    valid_dl: DataLoader,
    ) -> T.Tuple[float, float]:
  """
  Validates the model.

  Args:
    state (TrainState): State.
    valid_dl (DataLoader): Validation data loader.
  
  Returns (T.Tuple[float, float]): Mean loss and accuracy over the validation set.
  """
  n_samples = 0
  loss = 0
  accuracy = 0

  for ind, (input, target) in enumerate(valid_dl):
    if ind%10 == 0:
      print('\r', end='')
      print(f'Validation iteration {ind}/{len(valid_dl)}', end='')

    # jnp.array transfers the data to the GPU automatically
    input, target = jnp.array(input), jnp.array(target)
    curr_loss, curr_accuracy = valid_iter(state, input, target)

    n_samples += len(input)
    loss += len(input) * curr_loss
    accuracy += len(input) * curr_accuracy
  
  return loss/n_samples, accuracy/n_samples


def train(
    state: TrainState,
    train_dl: DataLoader,
    valid_dl: DataLoader,
    n_epochs: int = 5,
    ) -> TrainState:
  """
  Trains model.

  Args:
    state (TrainState): State.
    train_dl (DataLoader): Training data loader.
    valid_dl (DataLoader): Validation data loader.
    n_epochs (int): Number of epochs to train for.
    Default is 5.
  
  Returns (TrainState): Trained state.
  """
  for epoch in range(n_epochs):
    print(f'Epoch {epoch+1}')

    state, train_loss = train_epoch(state, train_dl)
    print('\r', end='')
    print(f'Training loss: {train_loss}')

    valid_loss, valid_accuracy = validate(state, valid_dl)
    print('\r', end='')
    print(f'Validation loss: {valid_loss}')
    print(f'Validation accuracy: {valid_accuracy}')
  
  return state
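The epoch functions above weight each batch's mean loss by the batch size before dividing by the number of samples. The following NumPy sketch, using made-up per-sample losses, shows why: the weighted form recovers the true per-sample mean even when the final batch is smaller, whereas a plain average of batch means does not.

```python
import numpy as np

# 150 made-up per-sample losses, split into batches of 64, 64, and 22
per_sample_losses = np.arange(150) / 100
batches = [
    per_sample_losses[:64],
    per_sample_losses[64:128],
    per_sample_losses[128:],
    ]

n_samples = 0
weighted_sum = 0.
for batch in batches:
  n_samples += len(batch)
  weighted_sum += len(batch) * batch.mean()

weighted_mean = weighted_sum / n_samples
naive_mean = np.mean([batch.mean() for batch in batches])

print(weighted_mean, per_sample_losses.mean())  # Both are the true mean
print(naive_mean)  # Over-weights the small final batch
```

Since both data loaders here use `drop_last=True`, every batch has the same size and the two averages coincide; the weighting only matters if that flag is turned off.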

Let's train:

In [None]:
%%time
state = train(
    state=state,
    train_dl=train_dl,
    valid_dl=valid_dl,
    n_epochs=n_epochs,
    )

Epoch 1
Training loss: 1.2170429229736328
Validation loss: 0.29066792130470276
Validation accuracy: 0.9125548601150513
Epoch 2
Training loss: 0.38544732332229614
Validation loss: 0.2744150757789612
Validation accuracy: 0.9163925647735596
Epoch 3
Training loss: 0.2568410038948059
Validation loss: 0.21699966490268707
Validation accuracy: 0.9317434430122375
Epoch 4
Training loss: 0.17679345607757568
Validation loss: 0.18962512910366058
Validation accuracy: 0.9413377046585083
Epoch 5
Training loss: 0.15397047996520996
Validation loss: 0.18540027737617493
Validation accuracy: 0.9440789818763733
CPU times: user 8min 33s, sys: 3min 11s, total: 11min 45s
Wall time: 11min 48s


Training lasted roughly 12 minutes - as we will see, that is quite fast compared to PyTorch. It is important to bear in mind that the first epoch or so can be markedly slower because JAX is JITting ```train_iter``` and ```valid_iter``` for the first time. In the grand scheme of training, this extra overhead is typically negligible and is overshadowed by the actual training runtime, but one should be cognizant of it nonetheless.
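The compilation overhead can be observed directly: Python side effects inside a jitted function run only while JAX traces it, so counting them reveals how often tracing (and hence compilation) happens. A small sketch, with `scaled_sum` and `trace_log` as hypothetical names:

```python
import jax
from jax import numpy as jnp

trace_log = []


@jax.jit
def scaled_sum(x):
  # Python side effect: executed only while JAX traces the function
  trace_log.append(x.shape)
  return jnp.sum(2. * x)


scaled_sum(jnp.ones((64, 8)))  # First call: traced and compiled
scaled_sum(jnp.ones((64, 8)))  # Same shape and dtype: compiled version reused
print(len(trace_log))  # 1

scaled_sum(jnp.ones((22, 8)))  # New shape: traced and compiled again
print(len(trace_log))  # 2
```

This is one reason fixed batch shapes are convenient in JAX: a smaller final batch would trigger an additional compilation of the training or validation step.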

## Batch normalization

Many networks contain layers such as batch normalization that exhibit different behaviour during training and inference. In PyTorch, modules have ```train``` and ```eval``` methods that signal to these submodules to enter training and inference mode respectively, but this strategy would not blend well with Flax's stateless nature. Consequently, Flax models - flaim included - often accept a ```training``` argument in their ```apply``` method that places the network in training mode if ```True``` and in inference mode otherwise.

Additionally, batch normalization involves a set of running means and variances, stored in the dictionary of parameters with key ```batch_stats```, that are updated not through backpropagation but as part of the forward pass. That would not normally be permitted since Flax modules are immutable, so we must inform Flax that this collection is mutable by setting ```mutable = ['batch_stats']``` in ```state.apply_fn```.

In short, to adapt our code for models with BN, there are two adjustments we need to make. First, we must provide the appropriate values for ```training``` and ```mutable``` in each training or validation iteration to ensure the model is generating predictions correctly. Updating the model's parameters demands another modification: The training state, ```state```, cannot hold the running BN statistics in ```state.params``` because they are not updated through backpropagation; rather, the forward pass outputs the new means and variances, and one must manually replace the old ones with them. Hence, we will develop ```BNTrainState```, a child class of ```TrainState``` that has an additional ```batch_stats``` attribute for managing these means and variances. Note that ```train_epoch```, ```validate```, and ```train``` are not changed.



In [None]:
class BNTrainState(TrainState):
  """
  Training state with a batch_stats attribute for storing
  batch normalization statistics.
  """
  batch_stats: FrozenDict


@jax.jit
def train_iter(
    state: BNTrainState,
    input,
    target,
    ) -> T.Tuple[BNTrainState, float]:
  """
  Calculates the model's loss on the current batch
  and updates its parameters.

  Args:
    state (TrainState): State.
    input: Input.
    target: Target.
  
  Returns (T.Tuple[TrainState, float]): Updated state and loss.
  """
  def get_loss(params):
    vars = {'params': params, 'batch_stats': state.batch_stats}
    output, new_mutable_state = state.apply_fn(vars, input, training=True, mutable=['batch_stats'])
    loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(output, target))
    return loss, new_mutable_state
  
  (loss, new_mutable_state), grads = jax.value_and_grad(get_loss, has_aux=True)(state.params)
  return state.apply_gradients(grads=grads, batch_stats=new_mutable_state['batch_stats']), loss


@jax.jit
def valid_iter(
    state: BNTrainState,
    input,
    target,
    ) -> T.Tuple[float, float]:
  """
  Calculates the model's loss and accuracy on the current batch.

  Args:
    state (TrainState): State.
    input: Input.
    target: Target.
  
  Returns (T.Tuple[float, float]): Loss and accuracy.
  """
  vars = {'params': state.params, 'batch_stats': state.batch_stats}
  output = state.apply_fn(vars, input, training=False, mutable=False)
  loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(output, target))
  accuracy = jnp.mean(jnp.argmax(output, axis=-1) == target)
  return loss, accuracy
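The `has_aux=True` mechanism used in `train_iter` above can be isolated with a toy example in which a running mean is updated during the forward pass while only a scale parameter receives gradients. This is a simplified stand-in for batch normalization, not Flax's implementation; `get_loss`, `MOMENTUM`, and the data are hypothetical.

```python
import jax
from jax import numpy as jnp

MOMENTUM = 0.9


def get_loss(scale, running_mean, x):
  batch_mean = jnp.mean(x)
  # Updated as part of the forward pass, not through backpropagation
  new_running_mean = MOMENTUM * running_mean + (1 - MOMENTUM) * batch_mean
  loss = jnp.mean((scale * (x - batch_mean)) ** 2)
  # With has_aux=True, only the first output is differentiated
  return loss, new_running_mean


x = jnp.array([1., 2., 3., 4.])
scale = jnp.array(1.)
running_mean = jnp.array(0.)

# Gradients are taken with respect to the first argument (scale) only
(loss, running_mean), grad = jax.value_and_grad(get_loss, has_aux=True)(
    scale, running_mean, x)
print(loss, running_mean, grad)  # 1.25 0.25 2.5
```

The auxiliary output plays the role of `new_mutable_state`: it is returned alongside the loss and must be written back into the state by hand.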

This time, we will be training ECA-ResNet, a derivative of ResNet that benefits from channel adaptability through [efficient channel attention (ECA)](https://arxiv.org/abs/1910.03151). ECA-ResNet is much faster than ConvNeXt, so we will double the number of epochs.

In [None]:
%%capture
n_epochs = 10
lr = 4e-4
wd = 1e-2

model, vars, norm_stats = flaim.get_model(
    model_name='ecaresnet50t',
    pretrained='in1k_256',
    n_classes=37, # The number of breeds in the Pets dataset
    )

train_dl, valid_dl = get_pets_dls(
    norm_mean=norm_stats['mean'],
    norm_std=norm_stats['std'],
    )

lr_scheduler = optax.cosine_decay_schedule(
  init_value=lr,
  decay_steps=n_epochs*len(train_dl), # Number of training iterations
)
optim = optax.adamw(
    learning_rate=lr_scheduler,
    weight_decay=wd,
    )

state = BNTrainState.create(
    apply_fn=model.apply,
    params=vars['params'],
    batch_stats=vars['batch_stats'],
    tx=optim,
    )

In [None]:
%%time
state = train(
    state=state,
    train_dl=train_dl,
    valid_dl=valid_dl,
    n_epochs=n_epochs,
    )

Epoch 1
Training loss: 1.6559876203536987
Validation loss: 0.43099743127822876
Validation accuracy: 0.8637609481811523
Epoch 2
Training loss: 0.4938824772834778
Validation loss: 0.38565394282341003
Validation accuracy: 0.8780153393745422
Epoch 3
Training loss: 0.3837474584579468
Validation loss: 0.31008198857307434
Validation accuracy: 0.9018640518188477
Epoch 4
Training loss: 0.2917676270008087
Validation loss: 0.2872237265110016
Validation accuracy: 0.9084429740905762
Epoch 5
Training loss: 0.2586759328842163
Validation loss: 0.28441449999809265
Validation accuracy: 0.9125548601150513
Epoch 6
Training loss: 0.2273043394088745
Validation loss: 0.28030940890312195
Validation accuracy: 0.9152960777282715
Epoch 7
Training loss: 0.18541376292705536
Validation loss: 0.2546655833721161
Validation accuracy: 0.9207785129547119
Epoch 8
Training loss: 0.18228289484977722
Validation loss: 0.2613741457462311
Validation accuracy: 0.9191337823867798
Epoch 9
Training loss: 0.15694083273410797
Valida

92% accuracy in 12 minutes - not bad, but no match for a cutting-edge CNN like ConvNeXt. Note that signs of overfitting emerge in the latter half of learning, so you can experiment with increased regularization and augmentation for better generalization if you'd like. We shall train this network in PyTorch as well.

# Training (PyTorch + timm)

Below is a PyTorch implementation of our application. The overall layout is the same, and this tutorial is not concerned with PyTorch anyhow, so we will not dive any deeper into it.

In [None]:
def train_iter(
    model: nn.Module,
    optim: Optimizer,
    input: torch.Tensor,
    target: torch.Tensor,
    scheduler: T.Optional[_LRScheduler] = None,
    ) -> float:
  """
  Calculates the model's loss on the current batch
  and updates its parameters.

  Args:
    model (nn.Module): Model.
    optim (Optimizer): Optimizer.
    input (torch.Tensor): Input.
    target (torch.Tensor): Target.
    scheduler (T.Optional[_LRScheduler]): Optional learning
    rate scheduler. If None, a constant learning rate is used.
    Default is None.

  Returns (float): Loss.
  """
  output = model(input)
  loss = F.cross_entropy(output, target)
  loss.backward()

  optim.step()
  optim.zero_grad()

  if scheduler:
    scheduler.step()

  return loss.item()


def valid_iter(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    ) -> T.Tuple[float, float]:
  """
  Calculates the model's loss and accuracy on the current batch.

  Args:
    model (nn.Module): Model.
    input (torch.Tensor): Input.
    target (torch.Tensor): Target.
  
  Returns (T.Tuple[float, float]): Loss and accuracy.
  """
  output = model(input)
  loss = F.cross_entropy(output, target)
  accuracy = (torch.argmax(output, dim=-1) == target).float().mean()
  return loss.item(), accuracy.item()


def train_epoch(
    model: nn.Module,
    optim: Optimizer,
    train_dl: DataLoader,
    scheduler: T.Optional[_LRScheduler] = None,
    ) -> float:
  """
  Performs one training epoch.

  Args:
    model (nn.Module): Model.
    optim (Optimizer): Optimizer.
    train_dl (DataLoader): Training data loader.
    scheduler (T.Optional[_LRScheduler]): Optional learning
    rate scheduler. If None, a constant learning rate is used.
    Default is None.
  
  Returns (float): Mean loss over the epoch.
  """
  model.train()
  n_samples = 0
  loss = 0

  for ind, (input, target) in enumerate(train_dl):
    if ind%10 == 0:
      print('\r', end='')
      print(f'Training iteration {ind}/{len(train_dl)}', end='')

    input, target = input.cuda(), target.cuda()
    curr_loss = train_iter(
        model=model,
        optim=optim,
        input=input,
        target=target,
        scheduler=scheduler,
        )

    n_samples += len(input)
    loss += len(input) * curr_loss
  
  return loss/n_samples


def validate(
    model: nn.Module,
    valid_dl: DataLoader,
    ) -> T.Tuple[float, float]:
  """
  Validates the model.

  Args:
    model (nn.Module): Model.
    valid_dl (DataLoader): Validation data loader.
  
  Returns (T.Tuple[float, float]): Mean loss and accuracy over the validation set.
  """
  model.eval()
  n_samples = 0
  loss = 0
  accuracy = 0

  with torch.no_grad():
    for ind, (input, target) in enumerate(valid_dl):
      if ind%10 == 0:
        print('\r', end='')
        print(f'Validation iteration {ind}/{len(valid_dl)}', end='')

      input, target = input.cuda(), target.cuda()
      curr_loss, curr_accuracy = valid_iter(model, input, target)

      n_samples += len(input)
      loss += len(input) * curr_loss
      accuracy += len(input) * curr_accuracy
  
  return loss/n_samples, accuracy/n_samples


def train(
  model: nn.Module,
  optim: Optimizer,
  train_dl: DataLoader,
  valid_dl: DataLoader,
  scheduler: T.Optional[_LRScheduler] = None,
  n_epochs: int = 5,
  ) -> nn.Module:
  """
  Trains model.

  Args:
    model (nn.Module): Model.
    optim (Optimizer): Optimizer.
    train_dl (DataLoader): Training data loader.
    valid_dl (DataLoader): Validation data loader.
    scheduler (T.Optional[_LRScheduler]): Optional learning
    rate scheduler. If None, a constant learning rate is used.
    Default is None.
    n_epochs (int): Number of epochs to train for.
    Default is 5.
  
  Returns (nn.Module): Trained model.
  """
  for epoch in range(n_epochs):
    print(f'Epoch {epoch+1}')

    train_loss = train_epoch(
        model=model,
        optim=optim,
        train_dl=train_dl,
        scheduler=scheduler,
        )
    print('\r', end='')
    print(f'Training loss: {train_loss}')

    valid_loss, valid_accuracy = validate(model, valid_dl)
    print('\r', end='')
    print(f'Validation loss: {valid_loss}')
    print(f'Validation accuracy: {valid_accuracy}')
  
  return model

Let's create ConvNeXt-Small with timm and train. 

In [None]:
%%capture
n_epochs = 5
lr = 6e-4
wd = 1e-2

model = timm.create_model(
    model_name='convnext_small.fb_in1k',
    pretrained=True,
    num_classes=37, # The number of breeds in the Pets dataset
    ).cuda()
optim = AdamW(
    params=model.parameters(),
    lr=lr,
    weight_decay=wd,
    )

train_dl, valid_dl = get_pets_dls(
    numpy=False,
    norm_mean=model.default_cfg['mean'],
    norm_std=model.default_cfg['std'],
    )

scheduler = CosineAnnealingLR(
    optimizer=optim,
    T_max=n_epochs*len(train_dl), # Number of training iterations
    )

In [None]:
%%time
model = train(
    model=model,
    optim=optim,
    train_dl=train_dl,
    valid_dl=valid_dl,
    scheduler=scheduler,
    n_epochs=n_epochs,
    )

Epoch 1
Training loss: 1.354935941466114
Validation loss: 0.3003519731001896
Validation accuracy: 0.9139254385964912
Epoch 2
Training loss: 0.3941231843149453
Validation loss: 0.27561301937359467
Validation accuracy: 0.9114583333333334
Epoch 3
Training loss: 0.2761638104392771
Validation loss: 0.220189911280677
Validation accuracy: 0.9322916666666666
Epoch 4
Training loss: 0.20301196985600287
Validation loss: 0.1899036405232261
Validation accuracy: 0.9377741228070176
Epoch 5
Training loss: 0.1642190142812436
Validation loss: 0.18487075274287348
Validation accuracy: 0.9413377192982456
CPU times: user 14min 2s, sys: 9min 13s, total: 23min 15s
Wall time: 23min 16s


Final accuracy is similar to what we obtained with Flax + flaim, but the latter was 2x as fast, which attests to the potential of JAX and JIT. Dramatic speedups like this, however, are not always realistic; in the following section, we will examine a case where JAX's speed gains are far more modest.

## Batch normalization

Since PyTorch modules are stateful, no alterations are necessary to train models with batch normalization. We will reuse the code above to train ECA-ResNet.

In [None]:
%%capture
n_epochs = 10
lr = 4e-4
wd = 1e-2

model = timm.create_model(
    model_name='ecaresnet50t',
    pretrained=True,
    num_classes=37, # The number of breeds in the Pets dataset
    ).cuda()
optim = AdamW(
    params=model.parameters(),
    lr=lr,
    weight_decay=wd,
    )

train_dl, valid_dl = get_pets_dls(
    numpy=False,
    norm_mean=model.default_cfg['mean'],
    norm_std=model.default_cfg['std'],
    )

scheduler = CosineAnnealingLR(
    optimizer=optim,
    T_max=n_epochs*len(train_dl), # Number of training iterations
    )

In [None]:
%%time
model = train(
    model=model,
    optim=optim,
    train_dl=train_dl,
    valid_dl=valid_dl,
    scheduler=scheduler,
    n_epochs=n_epochs,
    )

Epoch 1
Training loss: 1.7298050559403604
Validation loss: 0.45278381732733625
Validation accuracy: 0.8552631578947368
Epoch 2
Training loss: 0.49622445435900436
Validation loss: 0.3360696582842553
Validation accuracy: 0.8928179824561403
Epoch 3
Training loss: 0.37976034469248954
Validation loss: 0.3354140745982397
Validation accuracy: 0.8939144736842105
Epoch 4
Training loss: 0.30646385644611557
Validation loss: 0.2665331764636855
Validation accuracy: 0.9141995614035088
Epoch 5
Training loss: 0.28369567124989997
Validation loss: 0.2571168984041402
Validation accuracy: 0.9177631578947368
Epoch 6
Training loss: 0.22420173778868557
Validation loss: 0.2767221699142012
Validation accuracy: 0.9183114035087719
Epoch 7
Training loss: 0.1982107046141959
Validation loss: 0.25447199858823105
Validation accuracy: 0.9202302631578947
Epoch 8
Training loss: 0.17676571939598051
Validation loss: 0.24388352461289942
Validation accuracy: 0.9229714912280702
Epoch 9
Training loss: 0.17798631664430886
Vali

Once again, the model's loss and accuracy are close to those of Flax + flaim, but PyTorch training was 30% slower - a noticeable gap, yet not on par with a 2x speedup. This should be a reminder that JAX performance boosts vary wildly depending on the network and can fall anywhere between being an order of magnitude faster and being virtually no better than PyTorch. Other factors that influence the degree of speedup JAX offers include hardware, batch size, and the types of layers used. Unfortunately, it is difficult to ascertain the extent of this speedup a priori, and the only way to do so accurately is by empirically benchmarking the code.
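When benchmarking JAX code, two details are easy to miss: the first call includes compilation time, and JAX dispatches work asynchronously, so results must be awaited before the clock is stopped. Below is a minimal timing sketch that accounts for both, assuming the benchmarked function returns a single array; `benchmark` and `matmul_tanh` are hypothetical names.

```python
import time

import jax
from jax import numpy as jnp


def benchmark(fn, *args, n_runs: int = 10) -> float:
  """Returns the mean wall-clock seconds per call, excluding compilation."""
  fn(*args).block_until_ready()  # Warm-up call triggers JIT compilation

  start = time.perf_counter()
  for _ in range(n_runs):
    # Wait for the asynchronously dispatched computation to finish
    fn(*args).block_until_ready()
  return (time.perf_counter() - start) / n_runs


@jax.jit
def matmul_tanh(x):
  return jnp.tanh(x @ x)


x = jnp.ones((256, 256))
seconds_per_call = benchmark(matmul_tanh, x)
print(f'{seconds_per_call * 1e3:.3f} ms per call')
```

The measured figure depends on hardware and input size, so such a comparison is only meaningful when both frameworks are timed on the same machine with the same batch shapes.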

# Conclusion

In this notebook, we trained a ConvNeXt pet breed classifier using Flax + flaim, doing so twice as fast as an equivalent PyTorch + timm script whilst achieving comparable scores. We also examined the special case of models with batch normalization, which require extra attention in Flax due to Flax's stateless philosophy. We trained ECA-ResNet as an example of a model with BN and discovered PyTorch was only 30% slower than JAX. Accordingly, we must be aware that despite JAX being generally more efficient than PyTorch, how much faster it is exactly depends on factors such as the network architecture and hardware.