Imports

In [None]:
import os
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import numpy as np
import cv2
from glob import glob
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from patchify import patchify
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras import layers

Load the data

In [3]:
flowers_root = keras.utils.get_file('flower_photos', 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', untar=True)

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
[1m228813984/228813984[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step


ViT

In [4]:
class ClassToken(Layer):
    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value = w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32),
            trainable = True
        )

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        hidden_dim = self.w.shape[-1]

        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        cls = tf.cast(cls, dtype=inputs.dtype)
        return cls

In [5]:
def mlp(x, cf):
    x = Dense(cf["mlp_dim"], activation="gelu")(x)
    x = Dropout(cf["dropout_rate"])(x)
    x = Dense(cf["hidden_dim"])(x)
    x = Dropout(cf["dropout_rate"])(x)
    return x

In [6]:
def transformer_encoder(x, cf):
    skip_1 = x
    x = LayerNormalization()(x)
    x = MultiHeadAttention(
        num_heads=cf["num_heads"], key_dim=cf["hidden_dim"]
    )(x, x)
    x = Add()([x, skip_1])

    skip_2 = x
    x = LayerNormalization()(x)
    x = mlp(x, cf)
    x = Add()([x, skip_2])

    return x

In [7]:
def ViT(cf):
    # Inputs
    input_shape = (cf["num_patches"], cf["patch_size"]*cf["patch_size"]*cf["num_channels"])
    inputs = Input(input_shape)

    # Patch + Position Embeddings
    patch_embed = Dense(cf["hidden_dim"])(inputs)

    positions = tf.range(start=0, limit=cf["num_patches"], delta=1)
    pos_embed = Embedding(input_dim=cf["num_patches"], output_dim=cf["hidden_dim"])(positions)
    embed = patch_embed + pos_embed

    # Adding Class Token
    token = ClassToken()(embed)
    x = Concatenate(axis=1)([token, embed])

    for _ in range(cf["num_layers"]):
        x = transformer_encoder(x, cf)

    # Classification Head
    x = LayerNormalization()(x)
    x = x[:, 0, :]
    x = Dense(cf["num_classes"], activation="softmax")(x)

    model = Model(inputs, x)
    return model

Hyperparameter

In [8]:
hp = {}
hp["image_size"] = 200
hp["num_channels"] = 3
hp["patch_size"] = 25
hp["num_patches"] = (hp["image_size"]**2) // (hp["patch_size"]**2)
hp["flat_patches_shape"] = (hp["num_patches"], hp["patch_size"]*hp["patch_size"]*hp["num_channels"])

hp["batch_size"] = 16
hp["lr"] = 1e-4
hp["num_epochs"] = 500
hp["num_classes"] = 5
hp["class_names"] = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]

hp["num_layers"] = 6
hp["hidden_dim"] = 512
hp["mlp_dim"] = 1024
hp["num_heads"] = 8
hp["dropout_rate"] = 0.1

model = ViT(hp)
model.summary()

Preprocessing

In [9]:
def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

In [10]:
def load_data(path, split=0.1):
    images = shuffle(glob(os.path.join(path, "*", "*.jpg")))

    split_size = int(len(images) * split)
    train_x, valid_x = train_test_split(images, test_size=split_size, random_state=42)
    train_x, test_x = train_test_split(train_x, test_size=split_size, random_state=42)

    return train_x, valid_x, test_x

In [15]:
def process_image_label(path):
    # Reading images
    path = path.decode()
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (hp["image_size"], hp["image_size"]))
    image = image/255.0

    # Preprocessing to patches
    patch_shape = (hp["patch_size"], hp["patch_size"], hp["num_channels"])
    patches = patchify(image, patch_shape, hp["patch_size"])

    patches = np.reshape(patches, hp["flat_patches_shape"])
    patches = patches.astype(np.float32)

    # Label
    class_name = path.split("/")[-2]
    class_idx = hp["class_names"].index(class_name)
    class_idx = np.array(class_idx, dtype=np.int32)

    return patches, class_idx

In [16]:
def parse(path):
    patches, labels = tf.numpy_function(process_image_label, [path], [tf.float32, tf.int32])
    labels = tf.one_hot(labels, hp["num_classes"])

    patches.set_shape(hp["flat_patches_shape"])
    labels.set_shape(hp["num_classes"])

    return patches, labels

In [17]:
def tf_dataset(images, batch=32):
    ds = tf.data.Dataset.from_tensor_slices((images))
    ds = ds.map(parse).batch(batch).prefetch(8)
    return ds

Training

In [14]:
np.random.seed(42)
tf.random.set_seed(42)

# Directory for storing files
create_dir("files")

# Paths
dataset_path = flowers_root + "/flower_photos"
model_path = os.path.join("files", "model.keras")
csv_path = os.path.join("files", "log.csv")

# Dataset
train_x, valid_x, test_x = load_data(dataset_path)
print(f"Train: {len(train_x)} - Valid: {len(valid_x)} - Test: {len(test_x)}")

train_ds = tf_dataset(train_x, batch=hp["batch_size"])
valid_ds = tf_dataset(valid_x, batch=hp["batch_size"])

# Model
model = ViT(hp)
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(hp["lr"], clipvalue=1.0),
    metrics=["acc"]
)

callbacks = [
    ModelCheckpoint(model_path, monitor='val_loss', verbose=1, save_best_only=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, min_lr=1e-10, verbose=1),
    CSVLogger(csv_path),
    EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=False),
]

model.fit(
    train_ds,
    epochs=hp["num_epochs"],
    validation_data=valid_ds,
    callbacks=callbacks
)

Train: 2936 - Valid: 367 - Test: 367
Epoch 1/500
[1m184/184[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 368ms/step - acc: 0.3116 - loss: 2.1429
Epoch 1: val_loss improved from inf to 1.15974, saving model to files/model.keras
[1m184/184[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m114s[0m 415ms/step - acc: 0.3119 - loss: 2.1400 - val_acc: 0.5204 - val_loss: 1.1597 - learning_rate: 1.0000e-04
Epoch 2/500
[1m184/184[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 283ms/step - acc: 0.4964 - loss: 1.2127
Epoch 2: val_loss improved from 1.15974 to 1.02589, saving model to files/model.keras
[1m184/184[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 313ms/step - acc: 0.4965 - loss: 1.2125 - val_acc: 0.5777 - val_loss: 1.0259 - learning_rate: 1.0000e-04
Epoch 3/500
[1m184/184[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 282ms/step - acc: 0.5709 - loss: 1.0918
Epoch 3: val_loss improved from 1.02589 to 0.96357, saving model to files/model.keras
[1m184/18

KeyboardInterrupt: 