In [5]:
from typing import Any, Callable

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from resblock import ResBlock

# dataset loading
import tensorflow as tf
import tensorflow_datasets as tfds

### Dataset Exploration

In [4]:
def get_datasets(num_epochs, batch_size):
  """Build batched `tf.data` pipelines for the MNIST train and test splits.

  Args:
    num_epochs: number of times the training split is repeated.
    batch_size: examples per batch; a trailing incomplete batch is dropped
      for both splits.

  Returns:
    A `(train_ds, test_ds)` tuple of datasets whose elements are dicts with
    an `'image'` entry (float32, divided by 255) and a `'label'` entry.
  """
  def _scale_image(sample):
    # Cast the image to float32 and divide by 255; the label passes through.
    return {'image': tf.cast(sample['image'], tf.float32) / 255.,
            'label': sample['label']}

  def _batch_and_prefetch(ds):
    # Fixed-size batches only (incomplete final batch is skipped), with one
    # batch prefetched ahead of the consumer.
    return ds.batch(batch_size, drop_remainder=True).prefetch(1)

  train_ds = tfds.load('mnist', split='train').map(_scale_image)
  test_ds = tfds.load('mnist', split='test').map(_scale_image)

  # Training split: repeat for the requested number of epochs, then shuffle
  # by drawing from a 1024-element buffer.
  train_ds = _batch_and_prefetch(train_ds.repeat(num_epochs).shuffle(1024))
  # Test split: same shuffle buffer size, but no repetition.
  test_ds = _batch_and_prefetch(test_ds.shuffle(1024))

  return train_ds, test_ds

# One pass over the training split, in batches of 32; print the resulting
# dataset object to inspect its element spec.
train, test = get_datasets(num_epochs=1, batch_size=32)
print(train)

<_PrefetchDataset element_spec={'image': TensorSpec(shape=(32, 28, 28, 1), dtype=tf.float32, name=None), 'label': TensorSpec(shape=(32,), dtype=tf.int64, name=None)}>


### Network

In [7]:
class ResNet(nn.Module):
    """
    Residual classifier: three sequential stacks of `ResBlock`s followed by
    spatial pooling and a dense layer that emits one logit per class.

    NOTE(review): this was previously described as identical to the
    architecture in the original ResNet paper. Only what is visible here
    (three `ResBlock` stacks, a pooling step and a dense head) is certain;
    confirm against the paper before relying on that claim.

    Attributes:
        stack_s_size: number of `ResBlock(64)` blocks in the first stack.
        stack_m_size: number of `ResBlock(128)` blocks in the second stack.
        stack_l_size: number of `ResBlock(256)` blocks in the third stack.
        pool: pooling function, called as `pool(inputs, window_shape)`.
        linear: factory for the output layer, called as `linear(features)`.
        num_classes: number of logits produced per example.
    """
    stack_s_size: int = 3
    stack_m_size: int = 4
    stack_l_size: int = 6
    # `nn.avg_pool` is a plain function and `nn.Dense` is a class, so both
    # fields are typed as callables rather than module instances.
    pool: Callable[..., Any] = nn.avg_pool
    linear: Callable[..., Any] = nn.Dense
    num_classes: int = 10

    def setup(self):
        """Create the three residual stacks and the output layer."""
        self.stack_s = nn.Sequential(
            [ResBlock(64) for _ in range(self.stack_s_size)]
        )
        self.stack_m = nn.Sequential(
            [ResBlock(128) for _ in range(self.stack_m_size)]
        )
        self.stack_l = nn.Sequential(
            [ResBlock(256) for _ in range(self.stack_l_size)]
        )

        # output logits; built through the configurable `linear` factory so
        # the field is actually honoured.
        self.fc_final = self.linear(self.num_classes)

    def __call__(self, x):
        """Run the forward pass.

        Args:
            x: input batch. Assumed to be laid out as
                (batch, height, width, channels) and to keep that rank
                through the residual stacks — TODO confirm against ResBlock.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.stack_s(x)
        x = self.stack_m(x)
        x = self.stack_l(x)
        # Pool over the whole spatial extent (global pooling). The previous
        # window of (1, 2) only averaged adjacent column pairs and left the
        # spatial dimensions in place.
        x = self.pool(x, x.shape[1:3])
        # Collapse the remaining 1x1 spatial dims: (batch, 1, 1, C) -> (batch, C).
        x = x.reshape((x.shape[0], -1))
        # `setup` registers the head as `fc_final`; `self.fc` never existed.
        x = self.fc_final(x)
        return x
