In [1]:
import tensorflow as tf
from tensorflow.keras import datasets

from model import ViT
from train_config import TrainerConfig, Trainer

In [2]:
# Load CIFAR-10 as NumPy arrays. Keras documents the images as uint8 arrays of shape
# (N, 32, 32, 3) and the labels as uint8 arrays of shape (N, 1), with 50k train / 10k test rows.
cifar10_train, cifar10_test = datasets.cifar10.load_data()
x_train, y_train = cifar10_train
x_test, y_test = cifar10_test

In [3]:
def _scale_to_unit(pixels):
    """Return `pixels` as an (N, 32, 32, 3) float32 tensor with values divided by 255.

    For CIFAR-10 the reshape is a no-op (load_data already yields NHWC 32x32 RGB),
    but it keeps the layout the model expects explicit and fails loudly on mismatched input.
    """
    nhwc = pixels.reshape(-1, 32, 32, 3)
    return tf.cast(nhwc, tf.float32) / 255.0


train_images = _scale_to_unit(x_train)
test_images = _scale_to_unit(x_test)

In [4]:
# Build (image, label) input pipelines. There is no .batch() here; batch_size is passed
# via TrainerConfig, so batching presumably happens inside Trainer — TODO confirm.
SHUFFLE_SEED = 42  # op-level seed: makes the per-epoch shuffle order reproducible across runs

train_x = tf.data.Dataset.from_tensor_slices(train_images)
train_y = tf.data.Dataset.from_tensor_slices(y_train)
# Size the buffer from the data instead of hard-coding 50000: a buffer covering the whole
# split gives a uniform shuffle, and stays correct if the training split is resized.
train_dataset = tf.data.Dataset.zip((train_x, train_y)).shuffle(
    buffer_size=len(train_images),
    seed=SHUFFLE_SEED,
)

# The evaluation pipeline is intentionally left unshuffled.
test_x = tf.data.Dataset.from_tensor_slices(test_images)
test_y = tf.data.Dataset.from_tensor_slices(y_test)
test_dataset = tf.data.Dataset.zip((test_x, test_y))

In [5]:
# Optimisation settings handed to Trainer.
train_config = TrainerConfig(
    max_epochs=10,
    batch_size=128,
    learning_rate=1e-3,
)

# ViT hyperparameters, passed to Trainer together with the ViT class (presumably expanded
# into ViT's constructor — TODO confirm against train_config.Trainer).
# Assuming non-overlapping square patches, image_size=32 / patch_size=4 gives
# (32 // 4) ** 2 = 64 patches per image.
model_config = dict(
    image_size=32,
    patch_size=4,
    num_classes=10,  # CIFAR-10 has ten classes
    dim=64,
    depth=3,
    heads=4,
    mlp_dim=128,
)

In [6]:
# Example counts for each split, passed alongside the datasets (presumably used by
# Trainer for steps-per-epoch / metric averaging — TODO confirm).
num_train_examples = len(train_images)
num_test_examples = len(test_images)

# Arguments are positional, grouped by role: model class + config, train split,
# test split, then the optimisation config.
trainer = Trainer(
    ViT, model_config,
    train_dataset, num_train_examples,
    test_dataset, num_test_examples,
    train_config,
)

In [7]:
# Run training. The log below shows per-epoch train and test loss/accuracy as printed by Trainer.
# NOTE(review): the loss plateaus around 1.94 while accuracy reaches ~0.50. ln(10) ≈ 2.30 is
# chance level, and ≈ 1.46 is the lowest loss reachable if softmax probabilities are fed to a
# from_logits=True cross-entropy. Worth checking ViT/Trainer for a double softmax —
# not verifiable from this notebook.
trainer.train()

epoch 1: train loss 2.19346. train accuracy 0.25130
epoch 1: test loss 2.10025. test accuracy 0.33000
epoch 2: train loss 2.12030. train accuracy 0.33092
epoch 2: test loss 2.06970. test accuracy 0.36050
epoch 3: train loss 2.06793. train accuracy 0.38444
epoch 3: test loss 2.01803. test accuracy 0.41440
epoch 4: train loss 2.03636. train accuracy 0.41598
epoch 4: test loss 1.98865. test accuracy 0.44430
epoch 5: train loss 2.00910. train accuracy 0.44458
epoch 5: test loss 1.99753. test accuracy 0.43560
epoch 6: train loss 1.99210. train accuracy 0.46246
epoch 6: test loss 1.96426. test accuracy 0.47010
epoch 7: train loss 1.98623. train accuracy 0.46832
epoch 7: test loss 1.96036. test accuracy 0.47430
epoch 8: train loss 1.97644. train accuracy 0.47846
epoch 8: test loss 1.96823. test accuracy 0.46700
epoch 9: train loss 1.95811. train accuracy 0.49702
epoch 9: test loss 1.93989. test accuracy 0.49540
epoch 10: train loss 1.95162. train accuracy 0.50340
epoch 10: test loss 1.93902. 