Based on the Keras example "Learning to tokenize in Vision Transformers": https://keras.io/examples/vision/token_learner/

In [1]:
import keras 
from keras import layers
from keras import ops
from tensorflow import data as tf_data

from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import math

## Hyperparameters

In [2]:
# DATA
BATCH_SIZE = 256
AUTO = tf_data.AUTOTUNE  # Let tf.data tune prefetch buffer sizes dynamically.
INPUT_SHAPE = (32, 32, 3)  # (height, width, channels) of the raw CIFAR-10 images.
NUM_CLASSES = 10

# OPTIMIZER
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# TRAINING
EPOCHS = 1

# AUGMENTATION
IMAGE_SIZE = 48  # Resize input images to this size.
PATCH_SIZE = 6  # Size of the patches to be extracted from the input images.
# NUM_PATCHES uses floor division, so require an exact tiling of the resized
# image; otherwise the patch count would silently ignore a partial border.
assert IMAGE_SIZE % PATCH_SIZE == 0, "IMAGE_SIZE must be divisible by PATCH_SIZE"
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
NUM_PACHES = NUM_PATCHES  # Misspelled legacy name, kept so existing references keep working.

# VIT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 4

# Layer widths of the MLP, presumably the Transformer feed-forward block
# (expand to 2x the projection dim, then project back).
MLP_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]

# TOKENLEARNER
NUM_TOKENS = 4  # Presumably the number of tokens TokenLearner outputs — confirm in the model cell.

## Load and prepare the CIFAR-10 dataset

In [6]:
# Fetch CIFAR-10, then hold out the last 10,000 of the 50,000 training images
# as a validation split (the first 40,000 remain for training).
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_val, y_val = x_train[40000:], y_train[40000:]
x_train, y_train = x_train[:40000], y_train[:40000]

print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

# Wrap each split in a tf.data pipeline; only the training split is shuffled.
train_ds = (
    tf_data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
val_ds = (
    tf_data.Dataset.from_tensor_slices((x_val, y_val))
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
test_ds = (
    tf_data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

Training samples: 40000
Validation samples: 10000
Testing samples: 10000


## Data augmentation

The augmentation pipeline consists of:
- Rescaling
- Resizing
- Random cropping (fixed-sized or random sized)
- Random horizontal flipping

In [11]:
# Augmentation pipeline:
#   1. rescale pixel values by 1/255 (0-255 inputs map to [0, 1]);
#   2. resize to 20 px larger than the input along each spatial axis;
#   3. take a random IMAGE_SIZE x IMAGE_SIZE crop;
#   4. randomly mirror images left/right.
data_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0),
        # Use INPUT_SHAPE[0] for height and INPUT_SHAPE[1] for width so the
        # resize stays correct for non-square inputs.
        layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[1] + 20),
        layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
        layers.RandomFlip("horizontal"),
    ],
    name="data_augmentation",
)

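## Positional embedding module

A Transformer is built from two main components:
- multi-head self attention layers
- fully-connected feed forward networks (MLP)

Neither component is aware of the order of its input tokens, so positional information has to be injected into the tokens explicitly.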