In [1]:
from typing import Callable, Tuple, Any, Dict, List
from absl import logging
from functools import partial
from tqdm import tqdm
from dataclasses import dataclass

import os
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax.random import PRNGKey as jkey
from chex import Scalar, Array, PRNGKey, Shape
import flax
from flax import linen as nn
from flax.training.train_state import TrainState as RawTrainState
from flax.training.checkpoints import restore_checkpoint
import optax
import matplotlib.pyplot as plt
import tensorflow as tf

from training_cnn import get_CIFAR10, Metrices, checkpoint
from architectures import *

logging.set_verbosity(logging.WARN)


SEED = 42



In [2]:
# Loading CIFAR10 dataset
(x_train, y_train), (x_test, y_test) = get_CIFAR10(jkey(SEED), 1.0)

first_n = 8
dummy_batch = x_train[:first_n]
dummy_labels = y_train[:first_n]

In [3]:
class TrainState(RawTrainState):
    batch_stats: flax.core.FrozenDict


def conv_block_with_bn(x: Array, features: int, training: bool) -> Array:

    x = nn.Conv(features=features, kernel_size=(3, 3), padding='SAME')(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.relu(x)
    x = nn.Conv(features=features, kernel_size=(3, 3), padding='SAME')(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    return x


class BatchNormCNN(nn.Module):
    
    @nn.compact
    def __call__(self, batch: Array, training: bool):
        
        batch_size = batch.shape[0]
        x = batch / 255
        x = conv_block_with_bn(x, 20, training)
        x = conv_block_with_bn(x, 40, training)
        x = conv_block_with_bn(x, 80, training)
        x = conv_block_with_bn(x, 160, training)
        x = jnp.reshape(x, (batch_size, -1))
        x = nn.Dense(features=10)(x)
        
        return x


def create_BatchNormCNN(
    dummy_batch: Array,
    init_key: PRNGKey,
    lr: Scalar = 0.001,
    momentum: Scalar = 0.9
) -> TrainState:

    cnn = BatchNormCNN()
    variables = cnn.init(init_key, dummy_batch, training=False)

    return TrainState.create(
        apply_fn=cnn.apply,
        params=variables['params'],
        tx=optax.sgd(learning_rate=lr, momentum=momentum),
        batch_stats=variables['batch_stats'])


@jax.jit
def eval_BatchNormCNN(state: TrainState, batch: Array, labels: Array):
    
    logits = BatchNormCNN().apply(
        {'params': state.params, 'batch_stats': state.batch_stats},
        batch,
        training=False
    )
    one_hot = jax.nn.one_hot(labels, 10)

    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))

    return loss, accuracy

In [4]:
@jax.jit
def apply_model(state: TrainState, batch: Array, labels: Array):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params, batch_stats):

        logits, batch_stats = state.apply_fn(
            {'params': params, 'batch_stats': batch_stats},
            batch,
            training=True,
            mutable=['batch_stats']
        )
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))

        return loss, (logits, batch_stats)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, batch_stats)), grads = grad_fn(state.params, state.batch_stats)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    
    return grads, loss, accuracy


@jax.jit
def update_model(state: TrainState, grads: nn.FrozenDict):

    return state.apply_gradients(grads=grads)

In [5]:
def train_epoch(
	state: TrainState,
	x_train: Array,
	y_train: Array,
	batch_size: int,
	perm_key: PRNGKey
) -> Tuple[TrainState, Scalar, Scalar]:

	n_samples = x_train.shape[0]
	steps_per_epoch = n_samples // batch_size

	perms = jax.random.permutation(perm_key, n_samples)[:steps_per_epoch * batch_size]
	perms = jnp.reshape(perms, (steps_per_epoch, batch_size))

	epoch_loss = []
	epoch_accuracy = []

	for perm in perms:

		x_batch = x_train[perm, ...]
		y_batch = y_train[perm, ...]
		grads, loss, accuracy = apply_model(state, x_batch, y_batch)
		state = update_model(state, grads)
		epoch_loss.append(loss)
		epoch_accuracy.append(accuracy)
	
	return state, np.mean(epoch_loss), np.mean(epoch_accuracy)

In [6]:
def train_and_eval(
    seed: int,
	epochs: int,
	batch_size: int,
	create_state_fun: Callable,
	lr: Scalar = 0.001,
	momentum: Scalar = 0.9,
	ds_chunk_size = 1.0,
	log_every: int = 0,
    checkpoint_dir: str = "",
) -> Tuple[TrainState, Metrices, float]:

    # Create PRNG keys
    key = jkey(seed)
    key, ds_key, init_key = jax.random.split(key, 3)

    # Load CIFAE10 dataset
    (x_train, y_train), (x_test, y_test) = get_CIFAR10(ds_key, chunk_size=ds_chunk_size)

    # Create structures to accumulate metrices
    metrices = Metrices(epochs)

    # Get initial MinCNN training state
    state = create_state_fun(x_train, init_key, lr=lr, momentum=momentum)

    # Iterate through the dataset for epochs number of times
    start = time.time()
    for epoch in range(1, epochs + 1):

        key, epoch_key = jax.random.split(key)
        state, train_loss, train_accuracy = train_epoch(state, x_train, y_train, batch_size, epoch_key)
        test_loss, test_accuracy = eval_BatchNormCNN(state, x_test, y_test)
        metrices.update(train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
        
        if log_every and (epoch % log_every == 0 or epoch in {1, epochs}):
            print(
                'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
                % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
            )
        
        checkpoint(checkpoint_dir, state, metrices, epoch, time.time() - start)

    return state, metrices, time.time() - start

In [8]:
final_state, metrices, elapsed_time = train_and_eval(
    seed=42,
    epochs=5,
    batch_size=32,
    create_state_fun=create_BatchNormCNN,
    lr=0.001,
    momentum=0.9,
    ds_chunk_size=1.0,
    log_every=1,
    checkpoint_dir=os.path.join("checkpoints/batch_norm_cnn"),
)
print(f"Total training time: {elapsed_time:.3f}")

epoch:  1, train_loss: 1.3804, train_accuracy: 50.33, test_loss: 2.2957, test_accuracy: 13.71
epoch:  2, train_loss: 0.9800, train_accuracy: 65.12, test_loss: 2.2902, test_accuracy: 12.83
epoch:  3, train_loss: 0.7898, train_accuracy: 72.38, test_loss: 2.2879, test_accuracy: 11.44
epoch:  4, train_loss: 0.6582, train_accuracy: 76.94, test_loss: 2.2892, test_accuracy: 11.34
epoch:  5, train_loss: 0.5481, train_accuracy: 80.96, test_loss: 2.2849, test_accuracy: 13.11
Total training time: 364.011


In [7]:
key, init_key, epoch_key = jax.random.split(jkey(42), 3)
state = create_BatchNormCNN(x_train, init_key)

state, train_loss, train_accuracy = train_epoch(state, x_train, y_train, 32, epoch_key)
loss, accuracy = eval_BatchNormCNN(state, x_test, y_test)