In [12]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
from vit_model import VisionTransformer

In [2]:
# Ask TensorFlow to grow GPU memory on demand instead of reserving it all up
# front. This has to run before any GPU is initialised (i.e. before the first
# tensor/model is created). Every visible GPU is configured because TensorFlow
# expects the memory-growth setting to be identical across GPUs; on a CPU-only
# machine the list is empty and the loop is simply skipped.
devices = tf.config.list_physical_devices('GPU')
for gpu in devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [5]:
# Load CIFAR-100 and rescale the images (labels are passed through unchanged).
def load_cifar():
    """Fetch the CIFAR-100 train/test splits via Keras and normalise the images.

    The image arrays are cast to ``float32`` and divided by 255; assuming the
    raw pixels are 0-255 integers (as the Keras dataset docs describe), this
    maps them into ``[0, 1]``. Label arrays are returned exactly as Keras
    provides them.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``.
    """
    train_split, test_split = keras.datasets.cifar100.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    return (
        train_images.astype("float32") / 255,
        train_labels,
        test_images.astype("float32") / 255,
        test_labels,
    )

In [7]:
# Load the CIFAR-100 splits once into notebook-level names. X_train / y_train
# feed the hyperparameter and training cells below; X_test / y_test are not
# used by any cell visible in this notebook.
X_train, y_train, X_test, y_test = load_cifar()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz


In [10]:
# --- Values derived from the loaded data ---
input_shape = X_train.shape[1:]            # per-sample shape (leading sample axis dropped)
output_classes = np.unique(y_train).size   # number of distinct labels in the training set

# --- Optimisation settings (forwarded to train()) ---
learning_rate = 1e-3
weight_decay = 1e-4
batch_size = 256
num_epochs = 100

# --- Architecture settings (forwarded to VisionTransformer) ---
# NOTE(review): image_size is larger than the raw sample shape; presumably the
# model resizes inputs to image_size x image_size before patching - confirm in
# vit_model.
image_size = 72
patch_size = 6
projection_dim = 64
num_attention_heads = 4
num_transformer_layers = 8
num_mlp_heads = [2048, 1024]

(32, 32, 3)

In [None]:
# Build the Vision Transformer from the hyperparameters defined in the
# configuration cell above. VisionTransformer is imported from the local
# `vit_model` module, so the exact meaning of each keyword is defined there
# rather than in this notebook.
transformer = VisionTransformer(
    input_shape=input_shape, 
    image_size=image_size, 
    patch_size=patch_size, 
    projection_dim=projection_dim, 
    num_transformer_layers=num_transformer_layers,
    num_attention_heads=num_attention_heads,
    num_mlp_heads=num_mlp_heads,
    output_classes=output_classes
    )

In [None]:
def train(model, X_train, y_train, batch_size, epochs, learning_rate, weight_decay):
    """Print the model summary, compile the model, fit it, and return it.

    Args:
        model: Object exposing ``summary()``, ``compile(...)`` and ``fit(...)``.
        X_train: Training inputs, passed positionally to ``model.fit``.
        y_train: Training targets, passed positionally to ``model.fit``.
        batch_size: Forwarded to ``model.fit`` as ``batch_size``.
        epochs: Forwarded to ``model.fit`` as ``epochs``.
        learning_rate: Forwarded to ``model.compile`` as ``learning_rate``.
        weight_decay: Forwarded to ``model.compile`` as ``weight_decay``.

    Returns:
        The same ``model`` object that was passed in.
    """
    compile_options = dict(learning_rate=learning_rate, weight_decay=weight_decay)
    fit_options = dict(batch_size=batch_size, epochs=epochs)

    model.summary()

    # NOTE(review): the stock keras.Model.compile takes optimizer/loss/metrics,
    # not these keywords; presumably VisionTransformer overrides compile() to
    # build its own optimizer - confirm in vit_model.
    model.compile(**compile_options)
    model.fit(X_train, y_train, **fit_options)

    return model

In [None]:
# Run the training job with the configured hyperparameters. train() returns
# the model object it was given, so `trained_transformer` refers to the same
# object as `transformer`.
trained_transformer = train(
    model=transformer,
    X_train=X_train,
    y_train=y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

In [None]:
# Persist the trained model under the relative path "model" (resolved against
# the kernel's current working directory).
# NOTE(review): the on-disk format depends on how VisionTransformer.save behaves
# and on the installed Keras version - confirm before relying on it.
trained_transformer.save("model")