In [1]:
import flax
nn = flax.linen

import jax
import jax.numpy as jnp
from jax import random, value_and_grad


In [2]:
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x, **kwargs):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x


In [3]:
import torch
from torchvision.datasets import FashionMNIST, CIFAR10
from torchvision import transforms
from torch.utils import data
import torchvision
import numpy as np

from typing import Sequence, Union

In [4]:
train_dataset = CIFAR10(root='data', train=True, download=True)
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0,1,2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0,1,2))
print("Data mean", DATA_MEANS)
print("Data std", DATA_STD)


Files already downloaded and verified
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]


In [5]:
# Transformations applied on each image => bring them into a numpy array
def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img

# We need to stack the batch elements
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)


test_transform = image_to_numpy
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
                                      image_to_numpy
                                     ])
# Loading the training dataset. We need to split it into a training and validation part
# We need to do a little trick because the validation set should not use the augmentation.
train_dataset = CIFAR10(root='data' , train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root='data', train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))

# Loading the test set
test_set = CIFAR10(root='data', train=False, transform=test_transform, download=True)

# We define a set of data loaders that we can use for training and validation
train_loader = data.DataLoader(
    train_set,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    collate_fn=numpy_collate,
    num_workers=32,
    persistent_workers=True,
    generator=torch.Generator().manual_seed(1024))
val_loader   = data.DataLoader(val_set,
    batch_size=32,
    shuffle=False,
    drop_last=False,
    collate_fn=numpy_collate,
    num_workers=8,
    persistent_workers=True,
    generator=torch.Generator().manual_seed(1024))

test_loader  = data.DataLoader(test_set,
    batch_size=32,
    shuffle=False,
    drop_last=False,
    collate_fn=numpy_collate,
    num_workers=8,
    persistent_workers=True,
    generator=torch.Generator().manual_seed(1024))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified




In [6]:
imgs, _ = next(iter(train_loader))
print("Batch mean", imgs.mean(axis=(0,1,2)))
print("Batch std", imgs.std(axis=(0,1,2)))

Batch mean [-0.12891286 -0.22238259 -0.19784406]
Batch std [0.99327482 1.01667114 0.98516537]


In [7]:
from FlaxTrainer.trainer import TrainerModule
from FlaxTrainer.trainstates import TrainState
import optax

flax not found, run 'pip install --upgrade git+https://github.com/google/flax.git'


In [8]:
class cnnTrainer(TrainerModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    
    def create_functions(self):
        def cross_entropy_loss(params, apply_fn, batch):
            x, y = batch
            y = jax.nn.one_hot(y, num_classes=10)
            logit = apply_fn({'params':params}, x)
            loss = optax.softmax_cross_entropy(logits=logit , labels=y).mean()
            return loss
    
        def train_step(state, batch):
            loss_fn = lambda params: cross_entropy_loss(params, state.apply_fn, batch)
            loss, grads = jax.value_and_grad(loss_fn)(state.params)
            state = state.apply_gradients(grads=grads)
            metrics = {'loss': loss}
            return state, metrics
        
        def eval_step(state, batch):
            loss = cross_entropy_loss(state.params, state.apply_fn, batch)
            return {'loss': loss}

        return train_step, eval_step




model = CNN()    

In [9]:
CHECKPOINT_PATH = "./saved_models/"
# TODO: Solve conflict of check_val_every_n_epoch and num_epochs
#mock = mockedcallback.MockedCallback(stop_train=False)
trainer = cnnTrainer(optimizer_hparams={'lr': 4e-3},
                            logger_params={'base_log_dir': CHECKPOINT_PATH},                           
                            check_val_every_n_epoch=5,
                            enable_progress_bar=True)
 #                           callbacks=[mock])

In [10]:
state = trainer.init_model(
    model=model,exmp_input=next(iter(train_loader))[0]
)







In [11]:
b = model.apply({'params': state.params}, next(iter(train_loader))[0])
b.shape

(8, 10)

In [12]:
trainer.train_model(model, state, train_loader=train_loader, val_loader=val_loader, num_epochs=100)

Epochs:   4%|▍         | 4/100 [00:53<21:35, 13.50s/it]


NameError: name 'checkpoints' is not defined

In [269]:
for batch in train_dataset:
    x, y = batch
    print( y)
    break

6


In [275]:
len(train_loader)

6250

In [13]:
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x


def cross_entropy_loss(*, logits, labels):

  labels_onehot = jax.nn.one_hot(labels, num_classes=10)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()


def compute_metrics(*, logits, labels):

  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics






In [14]:
def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, next(iter(train_loader))[0])['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)




In [15]:
from time import time



@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch[0])
    loss = cross_entropy_loss(logits=logits, labels=batch[1])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch[1])
  return state, metrics


@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch[0])
  return compute_metrics(logits=logits, labels=batch[1])


def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds)
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for batch in train_ds:
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state


def eval_model(params, test_ds):
  
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']





rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

num_epochs = 10
batch_size = 32
for epoch in range(1, num_epochs + 1):
  t = time()
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  state = train_epoch(state, train_loader, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch 
  for batch in test_loader:
    test_loss, test_accuracy = eval_model(state.params, batch)
  #print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
  #    epoch, test_loss, test_accuracy * 100))
  print("elpased_time: ", (time() - t))

train epoch: 1, loss: 2.3343, accuracy: 9.69
elpased_time:  11.437186002731323
train epoch: 2, loss: 2.3312, accuracy: 9.92
elpased_time:  8.585633754730225
train epoch: 3, loss: 2.3322, accuracy: 9.75
elpased_time:  8.396543979644775
train epoch: 4, loss: 2.3300, accuracy: 10.15
elpased_time:  8.821658611297607


KeyboardInterrupt: 