In [1]:
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

import numpy as np

from typing import Any

2023-03-03 02:50:16.110978: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-03-03 02:50:16.111058: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64


In [2]:
rng = jax.random.PRNGKey(0)
batch = jax.random.normal(rng, (4, 32, 32, 3))

# Datasets

In [3]:
def get_datasets(batch_size=32):
    """
    Creates train, validation, and test datasets.
    Applies data normalization to all datasets and augmentation to training only.
    """
    train_ds, val_ds, test_ds = tfds.load(
        "cifar10", 
        split=["train[:90%]", "train[90%:]", "test"],
        as_supervised=True
    )

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    var = [x ** 2 for x in std]

    augment_pipeline = tf.keras.Sequential([
        tf.keras.layers.Rescaling(scale=1./255),
        tf.keras.layers.Normalization(mean=mean, variance=var),
        tf.keras.layers.ZeroPadding2D(padding=(4, 4)),
        tf.keras.layers.RandomFlip(mode="horizontal"),
        tf.keras.layers.RandomCrop(height=32, width=32)
    ])

    evaluate_pipeline = tf.keras.Sequential([
        tf.keras.layers.Rescaling(scale=1./255),
        tf.keras.layers.Normalization(mean=mean, variance=var),
    ])

    augment_pipeline.compile()
    evaluate_pipeline.compile()

    AUTOTUNE = tf.data.AUTOTUNE

    train_ds = train_ds.batch(batch_size, drop_remainder=True).map(lambda x, y: (augment_pipeline(x, training=True), y))
    train_ds = train_ds.cache().shuffle(1000).prefetch(AUTOTUNE)

    val_ds = val_ds.batch(batch_size, drop_remainder=True).map(lambda x, y: (evaluate_pipeline(x, training=False), y))
    val_ds = val_ds.cache().prefetch(AUTOTUNE)
    
    test_ds = test_ds.batch(batch_size, drop_remainder=True).map(lambda x, y: (evaluate_pipeline(x, training=False), y))
    test_ds = test_ds.cache().prefetch(AUTOTUNE)
    
    return train_ds, val_ds, test_ds

# ResNet Layer

In [4]:
from typing import Dict

class IdentityResidual(nn.Module):
    out_channels: int
    stride: int = 1

    def __call__(self, x):
        _, _, _, c = x.shape  # BHWC
        x = x[:, ::self.stride, ::self.stride, :]  # Downsample spatial dims
        if c != self.out_channels:  # Pad extra channels
            b, h, w, c = x.shape
            pad = jnp.zeros((b, h, w, self.out_channels - c))
            x = jnp.concatenate([x, pad], axis=-1)
        return x

class ResNetV2Layer(nn.Module):
    out_channels: int
    stride: int = 1

    def setup(self):
        conv_kwargs = {"padding": "SAME", "use_bias": False, "kernel_size": (3, 3)}
        self.conv1 = nn.Conv(self.out_channels, strides=self.stride, **conv_kwargs)
        self.conv2 = nn.Conv(self.out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm(use_running_average=True, momentum=0.9)  # Momentum set to match PyTorch
        self.bn2 = nn.BatchNorm(use_running_average=True, momentum=0.9)
        self.residual = IdentityResidual(self.out_channels, self.stride)
        self.relu = nn.relu

    def __call__(self, x):
        residual = self.residual(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x + residual

In [5]:
layer = ResNetV2Layer(32, 2)

In [6]:
variables = layer.init(rng, batch)
output = layer.apply(variables, batch)
output.shape

(4, 16, 16, 32)

# ResNet Model

In [7]:
from functools import partial

class ResNetV2Model(nn.Module):
    output_classes: int = 10

    @nn.compact
    def __call__(self, x):
        return nn.Sequential([
            nn.Conv(16, kernel_size=(3, 3), padding="SAME", use_bias=False),
            ResNetV2Layer(16),
            ResNetV2Layer(16),
            ResNetV2Layer(16),
            ResNetV2Layer(32, stride=2),
            ResNetV2Layer(32),
            ResNetV2Layer(32),
            ResNetV2Layer(64, stride=2),
            ResNetV2Layer(64),
            ResNetV2Layer(64),
            partial(jnp.mean, axis=(1, 2)),  # Global average pooling over spatial dims
            nn.Dense(self.output_classes)
        ])(x)

In [8]:
model = ResNetV2Model()
variables = model.init(rng, batch)
output = model.apply(variables, batch)
output.shape

(4, 10)

# Training Functions

In [10]:
def loss_fn(logits, labels):
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

def compute_metrics(*, logits, labels):
    loss = loss_fn(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
      'loss': loss,
      'accuracy': accuracy,
    }
    return metrics


class TrainState(train_state.TrainState):
    """Custom train state for BatchNorm stats"""
    batch_stats: Any

def create_train_state(rng, learning_rate, batch_size, weight_decay=1e-4):
    """Creates initial `TrainState`."""
    model = ResNetV2Model()
    variables = model.init(rng, jnp.ones([batch_size, 32, 32, 3]))
    params = variables["params"]
    batch_stats = variables["batch_stats"]
    tx = optax.adamw(learning_rate, weight_decay=weight_decay)
    
    return TrainState.create(apply_fn=model.apply, params=params, batch_stats=batch_stats, tx=tx)

@jax.jit
def train_step(state, inputs, labels):
    """Train for a single step."""
    
    def objective(params):
        logits, updates = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            inputs, mutable=['batch_stats']  # Mutate batch stats during train step
        )
        loss = loss_fn(logits=logits, labels=labels)
        return loss, (logits, updates)
    
    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (loss, (logits, updates)), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates['batch_stats'])  # Update with new batch stats
    metrics = compute_metrics(logits=logits, labels=labels)
    
    return state, metrics

@jax.jit
def eval_step(state, inputs, labels):
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},  # Use current batch stats in state
        inputs  # Don't mutate batch stats in eval
    )
    metrics = compute_metrics(logits=logits, labels=labels)
    return state, metrics

In [11]:
def train_epoch(state, train_ds, epoch):
    """Train for a single epoch."""
    batch_metrics = []

    for inputs, labels in train_ds:
        inputs = jnp.float32(inputs)
        labels = jnp.int32(labels)
        
        state, metrics = train_step(state, inputs, labels)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }

    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

    return state

def eval_model(state, test_ds, epoch):
    batch_metrics = []

    for inputs, labels in test_ds:
        inputs = jnp.float32(inputs)
        labels = jnp.int32(labels)

        state, metrics = eval_step(state, inputs, labels)
        batch_metrics.append(metrics)
    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

# Do training

In [15]:
train_ds, eval_ds, test_ds = get_datasets(128)

































In [16]:
rng, init_rng = jax.random.split(rng)
LR = 1e-3
EPOCHS = 10
BATCH_SIZE = 128
state = create_train_state(init_rng, LR, BATCH_SIZE)

In [17]:
for epoch in range(1, EPOCHS + 1):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    
    # Run an optimization step over a training batch
    state = train_epoch(state, train_ds, epoch)
    
    # Evaluate on the test set after each training epoch
    eval_model(state, test_ds, epoch)

train epoch: 1, loss: 1.7785, accuracy: 32.95
eval epoch: 1, loss: 1.4341, accuracy: 46.43
train epoch: 2, loss: 1.3733, accuracy: 49.56
eval epoch: 2, loss: 1.3047, accuracy: 54.34
train epoch: 3, loss: 1.1581, accuracy: 58.22
eval epoch: 3, loss: 1.0763, accuracy: 62.18
train epoch: 4, loss: 1.0057, accuracy: 63.97
eval epoch: 4, loss: 0.9932, accuracy: 64.03
train epoch: 5, loss: 0.8874, accuracy: 68.31
eval epoch: 5, loss: 0.9396, accuracy: 66.89
train epoch: 6, loss: 0.7865, accuracy: 71.81
eval epoch: 6, loss: 0.9030, accuracy: 68.25
train epoch: 7, loss: 0.6995, accuracy: 75.03
eval epoch: 7, loss: 0.8431, accuracy: 71.85
train epoch: 8, loss: 0.6120, accuracy: 78.25
eval epoch: 8, loss: 0.8117, accuracy: 72.45
train epoch: 9, loss: 0.5356, accuracy: 80.93
eval epoch: 9, loss: 0.8818, accuracy: 72.23
train epoch: 10, loss: 0.4742, accuracy: 83.04
eval epoch: 10, loss: 0.9214, accuracy: 72.46
