In [1]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from resblock import ResBlock

# dataset loading
import tensorflow as tf
import tensorflow_datasets as tfds

# training
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax                           # Common loss functions and optimizers


[3m                                ResBlock Summary                                [0m
┏━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃[1m [0m[1mpath      [0m[1m [0m┃[1m [0m[1mmodule   [0m[1m [0m┃[1m [0m[1minputs     [0m[1m [0m┃[1m [0m[1moutputs   [0m[1m [0m┃[1m [0m[1mparams     [0m[1m [0m┃[1m [0m[1mbatch_sta…[0m[1m [0m┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│            │ ResBlock  │ [2mfloat32[0m[1,… │ [2mfloat32[0m[1… │             │            │
├────────────┼───────────┼─────────────┼────────────┼─────────────┼────────────┤
│ Conv_0     │ Conv      │ [2mfloat32[0m[1,… │ [2mfloat32[0m[1… │ bias:       │            │
│            │           │             │            │ [2mfloat32[0m[64] │            │
│            │           │             │            │ kernel:     │            │
│            │           │             │            │ [2mfloat32[0m[3,… │   

2024-03-21 14:45:46.936482: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-21 14:45:46.936544: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-21 14:45:46.963576: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


### Dataset Exploration

In [2]:
def get_datasets(num_epochs, batch_size):
  """Build batched MNIST input pipelines for the train and test splits.

  Args:
    num_epochs: number of times the training split is repeated.
    batch_size: examples per batch; an incomplete final batch is dropped
      for both splits.

  Returns:
    A `(train_ds, test_ds)` pair. Each element is a dict with a float32
    `'image'` (pixel values divided by 255) and the untouched `'label'`.
  """
  def to_float_image(sample):
    # Cast pixels to float32 and divide by 255; the label passes through.
    pixels = tf.cast(sample['image'], tf.float32) / 255.
    return {'image': pixels, 'label': sample['label']}

  def shuffle_and_batch(ds):
    # Draw from a 1024-element shuffle buffer, group into fixed-size batches
    # (skipping the incomplete one) and prefetch one batch to hide latency.
    return ds.shuffle(1024).batch(batch_size, drop_remainder=True).prefetch(1)

  train_ds = tfds.load('mnist', split='train').map(to_float_image)
  test_ds = tfds.load('mnist', split='test').map(to_float_image)

  # Only the training split is repeated; the repeat happens before shuffling.
  train_ds = shuffle_and_batch(train_ds.repeat(num_epochs))
  test_ds = shuffle_and_batch(test_ds)

  return train_ds, test_ds


### Network

In [3]:
class ResNet(nn.Module):
    """
    ResNet-style classifier: three stacks of residual blocks followed by
    pooling and a linear head that outputs class logits.

    The default stack sizes (3 / 4 / 6 blocks at 64 / 128 / 256 features)
    echo the stage layout of the original ResNet paper, but this module is
    NOT identical to it: it defines only three stacks, has no stem of its
    own, and uses a 2x2 average pool followed by a flatten instead of global
    pooling.

    Attributes:
        stack_s_size: number of ResBlock(64) blocks in the first stack.
        stack_m_size: number of ResBlock(128) blocks in the second stack.
        stack_l_size: number of ResBlock(256) blocks in the third stack.
        num_classes: number of output logits.
        pool: pooling callable, invoked as `pool(x, (2, 2), (2, 2))`.
        linear: layer class used to build the head, invoked as
            `linear(num_classes)`.

    Always construct with keyword arguments: positionally, the first field
    is `stack_s_size`, not `num_classes`.
    """
    stack_s_size: int = 3
    stack_m_size: int = 4
    stack_l_size: int = 6
    num_classes: int = 10
    # NOTE: despite the annotations, these defaults are a plain function and a
    # Module *class* respectively, not Module instances.
    pool: nn.Module = nn.avg_pool
    linear: nn.Module = nn.Dense

    def setup(self):
        self.stack_s = nn.Sequential(
            [ResBlock(64) for _ in range(self.stack_s_size)]
        )
        self.stack_m = nn.Sequential(
            [ResBlock(128) for _ in range(self.stack_m_size)]
        )
        self.stack_l = nn.Sequential(
            [ResBlock(256) for _ in range(self.stack_l_size)]
        )

        # Output logits. Built from the configurable `linear` field (default
        # nn.Dense) so that overriding it actually takes effect.
        self.fc_final = self.linear(self.num_classes)

    def __call__(self, x):
        """Map a batch `x` (leading axis = batch) to `(batch, num_classes)` logits."""
        # assumes x is NHWC, as in the notebook's own calls — TODO confirm
        batch = x.shape[0]

        x = self.stack_s(x)
        x = self.stack_m(x)
        x = self.stack_l(x)

        # 2x2 window with stride 2, then flatten everything but the batch axis.
        x = self.pool(x, (2, 2), (2, 2))
        x = x.reshape((batch, -1))
        return self.fc_final(x)


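Test the forward pass on a dummy input.

Because `apply` is called with `mutable=['batch_stats']`, it returns a tuple:
- [0] Final FC layer activations (the logits)
- [1] the updated mutable collections (`batch_stats`)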

In [4]:
# Default-configured model and a fixed PRNG key (both reused by later cells).
b = ResNet()
rng = jax.random.PRNGKey(0)

# One all-ones dummy batch serves both initialisation and the forward pass.
dummy_batch = jnp.ones((5, 28, 28, 3))
params = b.init(rng, dummy_batch)

# With a mutable collection, apply returns (output, updated_collections);
# index 0 is the model output.
logits = b.apply(params, dummy_batch, mutable=['batch_stats'])
print(logits[0].shape)

(5, 10)


### Evaluation Metrics

In [31]:
def get_accuracy(logits, labels):
  """Return the fraction of rows whose arg-max over the last axis equals the label."""
  predicted = jnp.argmax(logits, axis=-1)
  hits = predicted == labels
  return jnp.mean(hits)

### Training State
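A Flax `TrainState` bundles everything needed for one training run:
- the model's apply function
- model parameters
- the optimizer and its state
- the step counter

Hyperparameters (learning rate, batch size, number of epochs) are kept in the separate `hp` dict.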

In [37]:
def create_train_state(model, rng, in_shape, hp):
    """
    Initialise `model` and bundle it with an Adam optimizer in a TrainState.

    model: nn.Module to initialise
    rng: PRNGKey handed to `model.init`
    in_shape: shape of the all-ones dummy input used for initialisation
    hp: hyperparameters dict; only `hp['lr']` is read here

    Note: whatever `model.init` returns is stored as-is under `params`.
    """
    dummy_input = jnp.ones(in_shape)
    variables = model.init(rng, dummy_input)

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=optax.adam(learning_rate=hp['lr']),
    )

# hyperparameters shared by the training cells
hp = dict(
    lr=1e-4,        # learning rate given to the Adam optimizer
    batch_size=32,  # examples per batch
    num_epochs=10,  # number of times the training split is repeated
)


### Training step

In [38]:
# Not wrapped in jax.jit: the step prints the batch accuracy eagerly, which
# under jit would only run at trace time and show a tracer.
def train_step(train_state, X, t):
    """
    Run one optimisation step on a single batch.

    train_state: flax TrainState; its `params` are passed unchanged to `apply_fn`
    X: batch of model inputs
    t: integer class labels for the batch

    Prints the batch accuracy and returns `(new_state, loss)`, where
    `new_state` has updated params and optimizer state and an incremented
    step counter, and `loss` is the mean cross-entropy over the batch.
    """
    def loss_fn(params):
        # With a mutable collection, apply returns (logits, updated_collections).
        # NOTE(review): the updated batch_stats are discarded here, so any
        # running statistics never change across steps — confirm whether
        # ResBlock relies on them at evaluation time.
        logits, _ = train_state.apply_fn(params, X, mutable=['batch_stats'])
        # Integer-label cross-entropy: no one-hot encoding, and the number of
        # classes comes from the logits, not from a hard-coded constant.
        per_example = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=t
        )
        return jnp.mean(per_example), logits

    # Loss value, auxiliary logits, and gradient w.r.t. the stored params.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(train_state.params)

    print("accuracy: ", get_accuracy(logits, t))

    # apply_gradients runs the optimizer update and also increments `step`,
    # which the previous hand-written update left untouched.
    new_state = train_state.apply_gradients(grads=grads)
    return new_state, loss

**Training loop**

In [39]:
def train_resnet(train_state):
    """
    Train on MNIST and return the final training state.

    train_state: initial flax TrainState (see create_train_state)

    Batch size and number of epochs are read from the module-level `hp`
    dict. The loss is printed after every batch. The updated TrainState is
    returned so the trained parameters are not lost when the loop ends.
    """
    # The test split is not used during training.
    train, _ = get_datasets(hp['num_epochs'], hp['batch_size'])
    for i, batch in enumerate(train):
        X = jnp.array(batch['image'])
        t = jnp.array(batch['label'])

        train_state, loss = train_step(train_state, X, t)
        print(f'Batch {i} loss: {loss}')

    return train_state

# The keyword is required: positionally, the first dataclass field of ResNet
# is `stack_s_size`, so `ResNet(10)` would build ten small blocks instead of
# setting the number of classes.
model = ResNet(num_classes=10)
# Train the model created in this cell, not the instance left over from the
# dummy-input test.
ts = create_train_state(model, rng, (hp['batch_size'], 28, 28, 1), hp)
ts = train_resnet(ts)

accuracy:  0.0625
Batch 0 loss: 2.3010637760162354
accuracy:  0.03125
Batch 1 loss: 2.3257012367248535
accuracy:  0.0
Batch 2 loss: 2.321312665939331
accuracy:  0.1875
Batch 3 loss: 2.2949609756469727
accuracy:  0.28125
Batch 4 loss: 2.2787766456604004
accuracy:  0.25
Batch 5 loss: 2.2432780265808105
accuracy:  0.125
Batch 6 loss: 2.2875819206237793
accuracy:  0.15625
Batch 7 loss: 2.300485610961914
accuracy:  0.375
Batch 8 loss: 2.128314971923828
accuracy:  0.375
Batch 9 loss: 2.0713255405426025
accuracy:  0.375
Batch 10 loss: 1.9635430574417114
accuracy:  0.46875
Batch 11 loss: 1.5138872861862183
accuracy:  0.40625
Batch 12 loss: 1.3324579000473022
accuracy:  0.5
Batch 13 loss: 1.4267290830612183
accuracy:  0.625
Batch 14 loss: 1.0227885246276855
accuracy:  0.65625
Batch 15 loss: 1.4378671646118164
accuracy:  0.59375
Batch 16 loss: 1.1241247653961182
accuracy:  0.625
Batch 17 loss: 1.2397819757461548
accuracy:  0.71875
Batch 18 loss: 0.8037732839584351
accuracy:  0.5625
Batch 19 loss

2024-03-21 15:14:52.922419: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


KeyboardInterrupt: 