In [1]:
import torch
import jax_dataloader as jdl
import torchvision
import torchvision.transforms as transforms

  _torch_pytree._register_pytree_node(


In [2]:
class ToNumpy:
    def __call__(self, x: torch.Tensor):
        return x.numpy()

In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ToNumpy()])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)


batch_size = 128

trainloader = jdl.DataLoader(trainset, backend="pytorch", batch_size=batch_size,
                                          shuffle=True)
testloader = jdl.DataLoader(testset, backend="pytorch", batch_size=batch_size,
                                         shuffle=False)

# classes in cifar10
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified


In [4]:
import jax
import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn


from einops import rearrange

class ConvNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        # convs
        out = nn.Conv(features=6, kernel_size=(5, 5))(x)
        out = nn.max_pool(out, window_shape=(2, 2))
        out = nn.Conv(features=16, kernel_size=(5, 5))(out)
        out = nn.max_pool(out, window_shape=(2, 2))


        # flatten into a vector 
        # skip the batch dim
        if len(x.shape) > 3:
            out = rearrange(x, "batch c h w -> batch (c h w)")
        else:
            out = out.flatten()

        # dense
        out = nn.Dense(features=120)(out)
        out = nn.Dense(features=84)(out)
        out = nn.Dense(features=10)(out)

        return out

In [5]:
model = ConvNet()
rng = jax.random.key(0)
params = model.init(rng, jnp.empty((3, 32, 32)))

# run a sample forward pass
logits = model.apply(params, jnp.empty((3, 32, 32)))
logits.shape

(10,)

In [6]:
import optax
from tqdm.auto import trange, tqdm
from flax.training import train_state
from functools import partial

@jax.jit
def calculate_loss(params, x, y):
    logits = model.apply(params, x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return loss

@jax.jit
def batched_loss(params, xs, ys):
    batch_loss = jax.vmap(calculate_loss, in_axes=(None, 0, 0))(params, xs, ys)
    return batch_loss.mean(axis=-1)



optimiser = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimiser
)
criterion = jax.value_and_grad(batched_loss)

@jax.jit
def train_step(state, batch):
    loss_value, grads = criterion(state.params, *batch)
    updated_state = state.apply_gradients(grads=grads)
    return loss_value, updated_state

Setting up checkpointing for the model

In [96]:
import orbax
from flax.training import orbax_utils

# since everything in jax is a pytree
# the checkpoints are basically the pytree versions of the params
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

# checkpoint manager for managing how many checkpoints to keep
# keep a max of 2 checkpoints
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=5, create=True)

# add the save path
save_path = "/tmp/cnn_ckpt"
checkpoint_manager = orbax.checkpoint.CheckpointManager(save_path, orbax_checkpointer, options)


In [97]:
# create save args for checkpoint_manager
ckpt = {
    "model": state,
    # "model_prngs": rng
}
save_args = orbax_utils.save_args_from_target(ckpt)

In [98]:
from tqdm.notebook import trange

def train(state, epochs, train_loader, test_loader):
    steps = 0
    losses = []
    # f1s = []

    # =============
    for e in trange(epochs):
        for batch in tqdm(train_loader):
            loss, state = train_step(state, batch)
            steps += 1

            # log every 200 steps
            if steps % 200 == 0:
                losses.append(loss)
                
                # run evaluation
                print("Evaluating ... ")
                # score = evaluate(state, test_loader)
                
                # f1s.append(score)
    
                print(f"Epoch : {e + 1} :: Step : {steps} :: Loss : {loss}")
        
        # save model ckpt
        checkpoint_manager.save(steps, ckpt, save_kwargs={"save_args": save_args})
        # orbax_checkpointer.save(f"{save_path}/{steps}", ckpt, save_args=save_args)
    # ============
    return state, losses

In [95]:
!rm -rf /tmp/cnn_ckpt

In [99]:
state, losses = train(state, 1, trainloader, testloader)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

Evaluating ... 
Epoch : 1 :: Step : 200 :: Loss : 1.2740516662597656


https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html

In [101]:
latest_step = checkpoint_manager.latest_step()
latest_step

391

In [104]:
raw_restored = checkpoint_manager.restore(latest_step)
restored_params = raw_restored["model"]["params"]

In [105]:
restored_state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=restored_params,
        tx=optimiser
    )

In [106]:
# import os
# ckpt_path = os.path.join(save_path, str(latest_step))

# # but these are raw dicts and we need pytrees
# # the process outlined in the orbax docs is cumbersome
# # like why google? why?
# def restore_from_raw_checkpoint(latest_step):
#     empty_state = train_state.TrainState.create(
#         apply_fn=model.apply,
#         params=jax.tree_map(np.zeros_like, params),
#         tx=optimiser
#     )
    
#     # same as the save_args
#     target = {"model": empty_state}
    
#     latest_ckpt = checkpoint_manager.restore(latest_step, target)    
    
#     state_restored = orbax_checkpointer.restore(ckpt_path, item=target)
    
#     return state_restored

# restored_state = restore_from_raw_checkpoint(latest_step)

In [108]:
from sklearn.metrics import f1_score

@jax.jit
def test_step(state, xs):
    def infer(params, x):
        logits = model.apply(params, x)
        return jax.nn.softmax(logits, axis=-1) 

    preds = jax.vmap(jax.jit(infer), in_axes=(None, 0))(state.params, xs)
    return preds



def evaluate(state, test_loader):
    scores = list()
    for batch in tqdm(test_loader):
        xs, ys = batch
        preds = test_step(state, xs)
        preds = jnp.argmax(preds, axis=-1)
        f1 = f1_score(preds, ys, average="micro")
        scores.append(f1)

    return np.array(scores).mean(axis=-1)

In [110]:
# test the evaluation function
# should've named state but anyway
evaluate(restored_state, testloader)

  0%|          | 0/79 [00:00<?, ?it/s]

0.5327333860759493