In [2]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from resblock import ResBlock

# dataset loading
import tensorflow as tf
import tensorflow_datasets as tfds

# training
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax                           # Common loss functions and optimizers


[3m                                ResBlock Summary                                [0m
┏━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃[1m [0m[1mpath      [0m[1m [0m┃[1m [0m[1mmodule   [0m[1m [0m┃[1m [0m[1minputs     [0m[1m [0m┃[1m [0m[1moutputs   [0m[1m [0m┃[1m [0m[1mbatch_stats[0m[1m [0m┃[1m [0m[1mparams    [0m[1m [0m┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│            │ ResBlock  │ [2mfloat32[0m[1,… │ [2mfloat32[0m[1… │             │            │
├────────────┼───────────┼─────────────┼────────────┼─────────────┼────────────┤
│ Conv_0     │ Conv      │ [2mfloat32[0m[1,… │ [2mfloat32[0m[1… │             │ bias:      │
│            │           │             │            │             │ [2mfloat32[0m[6… │
│            │           │             │            │             │ kernel:    │
│            │           │             │            │             │ [2mfloat3

2024-03-14 14:29:47.205188: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-14 14:29:47.205214: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-14 14:29:47.205998: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


### Dataset Exploration

In [1]:
def get_datasets(num_epochs, batch_size):
  """Load MNIST train and test datasets into memory."""
  train_ds = tfds.load('mnist', split='train')
  test_ds = tfds.load('mnist', split='test')

  train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'],
                                                           tf.float32) / 255.,
                                          'label': sample['label']}) # normalize train set
  test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'],
                                                         tf.float32) / 255.,
                                        'label': sample['label']}) # normalize test set

  train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
  train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency
  test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from
  test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency

  return train_ds, test_ds


### Network

In [3]:
class ResNet(nn.Module):
    """
    Identical to architecture defined in original ResNet paper
    """
    stack_s_size: int = 3
    stack_m_size: int = 4
    stack_l_size: int = 6
    num_classes: int = 10
    pool: nn.Module = nn.avg_pool
    linear: nn.Module = nn.Dense

    def setup(self):
        self.stack_s = nn.Sequential(
            [ResBlock(64) for _ in range(self.stack_s_size)]
        )
        self.stack_m = nn.Sequential(
            [ResBlock(128) for _ in range(self.stack_m_size)]
        )
        self.stack_l = nn.Sequential(
            [ResBlock(256) for _ in range(self.stack_l_size)]
        )

        # output logits
        self.fc_final = nn.Dense(self.num_classes) 
    

    def __call__(self, x):

        B = x.shape[0]

        x = self.stack_s(x)
        x = self.stack_m(x)

        x = self.stack_l(x)
        x = self.pool(x, (2, 2), (2, 2))
        x = x.reshape((B, -1))
        x = self.fc_final(x)
        return x


Testing out on dummy input

Outputs:
- [0] Final FC layer activations
- [1] batch_stats

In [6]:
b = ResNet()
rng = jax.random.PRNGKey(0)

params = b.init(rng, jnp.ones((5, 28, 28, 3)))
# test the forward pass
logits = b.apply(params, jnp.ones((5, 28, 28, 3)), mutable=['batch_stats'])
# print(params)
print(logits[0].shape)

(5, 10)


### Training State

In [7]:
def create_train_state(model, rng, in_shape, hp):
    """
    create train state given
    model: nn.Module
    rng: PRNGKey
    in_shape: input shape
    hp: hyperparameters dict
    """

    params = model.init(rng, jnp.ones(in_shape))
    optim = optax.adam(learning_rate=hp['lr'])

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optim
    )

# hyperparameters
hp = {
    'lr': 0.001,
    'batch_size': 32,
    'num_epochs': 10
}

create_train_state(b, rng, (hp['batch_size'], 28, 28, 3), hp)


TrainState(step=0, apply_fn=<bound method Module.apply of ResNet(
    # attributes
    stack_s_size = 3
    stack_m_size = 4
    stack_l_size = 6
    num_classes = 10
    pool = avg_pool
    linear = Dense
)>, params={'params': {'stack_s': {'layers_0': {'Conv_0': {'kernel': Array([[[[-9.80935842e-02,  2.91641261e-02, -2.78061092e-01, ...,
           1.83324218e-01,  1.17900312e-01, -3.21348876e-01],
         [-2.89608657e-01,  8.63510668e-02, -1.13491669e-01, ...,
           5.38874827e-02,  2.30794638e-01, -2.28826746e-01],
         [ 2.80009322e-02,  1.80257767e-01, -1.94366530e-01, ...,
          -9.49707478e-02,  2.93350995e-01,  1.78598110e-02]],

        [[-2.22302496e-01,  4.00537133e-01,  1.75341636e-01, ...,
          -1.11044846e-01, -2.64691323e-01, -1.22244181e-02],
         [-8.55550393e-02,  4.74745184e-02,  5.56934662e-02, ...,
          -2.62117147e-01, -1.76127404e-01, -1.01350985e-01],
         [-4.05315787e-01, -6.63268846e-03, -8.64338577e-02, ...,
           9.0806

### Training step

In [8]:
# @jax.jit
def train_step(train_state, X, t):
    """
    train step
    """
    def loss_fn(params):
        logits = train_state.apply_fn(params, X, mutable=['batch_stats'])[0]
        # print(logits.shape)
        loss = optax.softmax_cross_entropy(logits=logits, labels=t)
        return loss, logits
    
    # function that computes function value and gradient
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grad = grad_fn(train_state.params)

    # update parameters
    updates, new_opt_state = train_state.tx.update(grad, train_state.opt_state, train_state.params)
    new_params = optax.apply_updates(train_state.params, updates)
    new_state = train_state.replace(params=new_params, opt_state=new_opt_state)
    return new_state, loss

**Training loop**

In [10]:
def train_resnet(train_state):
    train, val = get_datasets(hp['num_epochs'], hp['batch_size'])
    for i, batch in enumerate(train):
        X = batch['image']
        t = batch['label']
        train_state, loss = train_step(train_state, X, t)
        print(f'Batch {i} Loss: {loss}')

model = ResNet(10)
# print(model.tabulate(rng, jnp.ones((1, 28, 28, 1))))
ts = create_train_state(b, rng, (hp['batch_size'], 28, 28, 1), hp) 
train_resnet(ts)

2024-03-14 14:31:35.746697: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 70.19MiB (rounded to 73597184)requested by op 
2024-03-14 14:31:35.747243: W external/tsl/tsl/framework/bfc_allocator.cc:494] ************************************************************************************x***********____
E0314 14:31:35.747293  488499 pjrt_stream_executor_client.cc:2804] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 73597056 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   12.81MiB
              constant allocation:         0B
        maybe_live_out allocation:   12.25MiB
     preallocated temp allocation:   70.19MiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:   95.25MiB
              total fragmentation:       112B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 57.38MiB
		Operator: op_name="jit(c

KeyboardInterrupt: 