In [2]:
import warnings
# Suppress every Python warning for the rest of this kernel session.
warnings.filterwarnings("ignore")

In [14]:
import tensorflow as tf
import numpy as np
%pip install --upgrade tensorflow

Note: you may need to restart the kernel to use updated packages.


In [5]:
# Load the CIFAR-10 dataset bundled with Keras as (train, test) pairs of
# image and label arrays.
(x_train,y_train),(x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [6]:
# Display the shape of the training image array as the cell output.
x_train.shape

(50000, 32, 32, 3)

In [7]:
# --- Optimisation ---------------------------------------------------------
learning_rate = 0.001
weight_decay = 0.001
batch_size = 256
num_epochs = 100

# --- Input geometry -------------------------------------------------------
input_shape = (32, 32, 3)  # per-sample shape of the CIFAR-10 images loaded above
image_size = 72  # images are resized to image_size x image_size before patching
patch_size = 6  # side length of each square patch cut from the resized image
num_patches = (image_size // patch_size) ** 2  # patches per resized image
num_patch = num_patches  # alias kept so code using the original name still works

# --- Transformer architecture --------------------------------------------
projection_dim = 64
num_heads = 4
# Widths of the dense layers inside each transformer block's MLP.
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 8
# Widths of the dense layers in the final classification head.
mlp_head_units = [2048, 1024]

In [8]:
class Patches(tf.keras.layers.Layer):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length, in pixels, of each square patch. The same
            value is used as the stride, so patches do not overlap.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the patches of ``images`` as ``[batch, num_patches, patch_dims]``."""
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            # The keyword is `sizes`, and the window must come from the
            # layer's own patch_size rather than a notebook-level global.
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # The last axis holds one flattened patch; collapse the patch grid
        # into a single sequence axis.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Include ``patch_size`` so the layer can be re-created from its config."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
        

In [9]:
class PatchEncoder(tf.keras.layers.Layer):
    """Linearly project patches and add a learned positional embedding.

    Args:
        num_patches: Number of patches per image; also the size of the
            positional-embedding table.
        projection_dim: Output width of the projection and of each
            positional embedding vector.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = tf.keras.layers.Dense(units=projection_dim)
        self.positional_embedding = tf.keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return ``projection(patch)`` plus the embedding of each patch index."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The embedding has no batch axis; the addition broadcasts it
        # across every example in the batch.
        return self.projection(patch) + self.positional_embedding(positions)

    def get_config(self):
        """Include constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

In [None]:
# Preprocessing and augmentation pipeline that is embedded in the model.
# Layers are referenced through `tf.keras.layers` because this notebook only
# imports `tensorflow as tf`; a bare `layers` name is not defined.
data_augmentation = tf.keras.Sequential(
    [
        tf.keras.layers.Normalization(),
        tf.keras.layers.Resizing(image_size, image_size),
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(factor=0.02),
        tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Fit the Normalization layer's statistics on the training images.
data_augmentation.layers[0].adapt(x_train)

In [None]:
def mlp(x, hidden_units, droupout_rate):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Args:
        x: Input tensor.
        hidden_units: Iterable of layer widths; one Dense + Dropout pair is
            created per entry, in order.
        droupout_rate: Dropout rate used after every Dense layer. The
            spelling is kept because callers pass it by keyword.

    Returns:
        The output of the last Dropout layer, or ``x`` unchanged when
        ``hidden_units`` is empty.
    """
    for units in hidden_units:
        x = tf.keras.layers.Dense(units, activation=tf.nn.gelu)(x)
        x = tf.keras.layers.Dropout(droupout_rate)(x)
    # Return only after every hidden layer has been stacked.
    return x

In [None]:
import matplotlib.pyplot as plt

# Show one randomly chosen training image.
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(len(x_train))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

# Resize that image (as a batch of one) and split it into patches.
resized_image = tf.image.resize(
    tf.convert_to_tensor([image]),
    size=(image_size, image_size),
)
patches = Patches(patch_size)(resized_image)

# Draw the patches on an n x n grid, in the order the layer produced them.
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    ax.imshow(patch_img.numpy().astype("uint8"))
    ax.axis("off")

In [None]:
def create_vit_classifier(input_shape=None, num_classes=10):
    """Build the Vision Transformer classifier as a Keras functional model.

    Pipeline: augmentation -> patch extraction -> patch encoding ->
    ``transformer_layers`` pre-norm transformer blocks -> MLP head -> logits.

    Args:
        input_shape: Per-sample input shape. Defaults to ``x_train.shape[1:]``.
        num_classes: Number of output logits. Defaults to 10.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to class logits.
    """
    layers = tf.keras.layers
    if input_shape is None:
        input_shape = x_train.shape[1:]
    # Derived locally so the count always agrees with the resize/patch sizes.
    num_patches = (image_size // patch_size) ** 2

    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        # Self-attention sub-block: query and value are the same normalized input.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        # Feed-forward sub-block with its own residual connection.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, droupout_rate=0.1)
        encoded_patches = layers.Add()([x3, x2])

    # Classification head over the flattened patch representations.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    features = mlp(representation, hidden_units=mlp_head_units, droupout_rate=0.5)
    logits = layers.Dense(num_classes)(features)

    return tf.keras.Model(inputs=inputs, outputs=logits)

In [None]:
import tensorflow_addons as tfa
def run_program(model):
    """Compile, train and evaluate ``model``, keeping the best weights.

    Trains on ``x_train``/``y_train`` with a 10% validation split, saves the
    weights with the best validation accuracy, reloads them, then evaluates
    on ``x_test``/``y_test`` and prints top-1 and top-5 accuracy.

    Args:
        model: An uncompiled Keras model that outputs class logits.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    keras = tf.keras
    # AdamW is the Addons optimizer that takes `weight_decay`.
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # NOTE(review): newer Keras releases may require a ".weights.h5" suffix
    # for weights-only checkpoints - confirm against the installed version.
    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    # Evaluate with the best checkpoint rather than the final-epoch weights.
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(accuracy, '\n')
    print(top_5_accuracy)

    return history

In [None]:
# Build the classifier, then hand it to the training/evaluation routine and
# keep whatever that routine returns.
vit_classifier = create_vit_classifier()
history = run_program(vit_classifier)