# Training Vision Transformer for Flower Classification
In this notebook I use a Keras-based implementation of pretrained Vision Transformers (ViT) for flower classification. Part of the code is taken from <a href='https://www.kaggle.com/dimitreoliveira/flower-classification-with-tpus-eda-and-baseline#Model'>this notebook</a>, which is a very good reference on the same problem.

In [1]:
# Report the GPU attached to this runtime (driver version, memory, utilisation).
!nvidia-smi

Mon Jul  5 14:17:46 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Importing libraries

In [2]:
!pip install vit-keras --quiet
!pip install tensorflow_addons --quiet

[?25l[K     |▌                               | 10kB 24.7MB/s eta 0:00:01[K     |█                               | 20kB 20.7MB/s eta 0:00:01[K     |█▌                              | 30kB 16.4MB/s eta 0:00:01[K     |██                              | 40kB 14.7MB/s eta 0:00:01[K     |██▍                             | 51kB 7.7MB/s eta 0:00:01[K     |███                             | 61kB 9.0MB/s eta 0:00:01[K     |███▍                            | 71kB 8.7MB/s eta 0:00:01[K     |███▉                            | 81kB 9.5MB/s eta 0:00:01[K     |████▍                           | 92kB 9.3MB/s eta 0:00:01[K     |████▉                           | 102kB 7.4MB/s eta 0:00:01[K     |█████▎                          | 112kB 7.4MB/s eta 0:00:01[K     |█████▉                          | 122kB 7.4MB/s eta 0:00:01[K     |██████▎                         | 133kB 7.4MB/s eta 0:00:01[K     |██████▊                         | 143kB 7.4MB/s eta 0:00:01[K     |███████▎               

In [37]:
import os
import re
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf
from functools import partial
from vit_keras import vit

from sklearn.metrics import roc_auc_score

## Configuration

In [4]:
def seed_everything(seed=0):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer handed to Python's `random`, NumPy and TensorFlow, and
            written (as a string) to the PYTHONHASHSEED environment variable.

    Side effects:
        Sets the environment variables PYTHONHASHSEED and TF_DETERMINISTIC_OPS.
    """
    # NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so
    # setting it here presumably only affects child processes -- confirm.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })

    # Apply the same seed to each library's global generator.
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)


seed_everything(1234)

In [70]:
class Config:
  """Paths and hyper-parameters shared by the cells of this notebook."""

  # Project directories: input data, saved models and TensorBoard logs.
  INPUT_DIR = '/content/drive/MyDrive/Projects/Flower_Classification/input'
  MODEL_DIR = '/content/drive/MyDrive/Projects/Flower_Classification/models'
  LOG_DIR = '/content/drive/MyDrive/Projects/Flower_Classification/logs'

  # Passed as num_parallel_reads / num_parallel_calls / prefetch buffer size
  # in the tf.data pipeline.
  AUTOTUNE = tf.data.experimental.AUTOTUNE

  # Images are square; all three names carry the same side length.
  IMAGE_SIZE = IMAGE_HEIGHT = IMAGE_WIDTH = 224

  # Batch sizes for the training and validation pipelines.
  TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = 8

  # Buffer size handed to Dataset.shuffle() for the training data.
  SHUFFLE = 1234

  # Stage 1: training the classification head.
  LR = 0.001
  NUM_CLASSES = 104  # units of the final softmax layer
  EPOCHS = 50

  # Stage 2: fine-tuning the whole network.
  FINETUNE_LR = 0.0001
  FINETUNE_EPOCHS = 30

In [71]:
# Shared settings object read by the data-pipeline and training cells below.
config = Config()

## Creating the Dataset

In [7]:
def decode_image(image):
    """Turn encoded JPEG bytes into a float image tensor.

    Args:
        image: Scalar string tensor holding the JPEG-encoded image.

    Returns:
        A float32 tensor of shape
        [config.IMAGE_HEIGHT, config.IMAGE_WIDTH, 3] whose values are the
        decoded 8-bit pixels divided by 255.
    """
    target_shape = [config.IMAGE_HEIGHT, config.IMAGE_WIDTH, 3]
    pixels = tf.image.decode_jpeg(image, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    # The reshape pins the static shape; it does not resize the image.
    return tf.reshape(scaled, target_shape)

In [8]:
def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an (image, label) pair.

    Args:
        example: Serialized `tf.train.Example` containing an "image"
            bytestring and an int64 "class" feature.

    Returns:
        Tuple of the decoded image (see `decode_image`) and the class label
        cast to int32.
    """
    feature_spec = {
        # Scalar bytestring with the encoded image.
        "image": tf.io.FixedLenFeature([], tf.string),
        # Scalar integer class id.
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an (image, id) pair.

    Unlabeled records carry no "class" feature; they hold an "id" string
    instead.

    Args:
        example: Serialized `tf.train.Example` containing an "image"
            bytestring and a string "id" feature.

    Returns:
        Tuple of the decoded image (see `decode_image`) and the id string
        tensor, unchanged.
    """
    feature_spec = {
        # Scalar bytestring with the encoded image.
        "image": tf.io.FixedLenFeature([], tf.string),
        # Scalar string identifying the example.
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['id']

In [9]:
def load_dataset(filenames, labeled=True, ordered=False):
  """Build a tf.data pipeline of parsed examples from TFRecord files.

  Args:
    filenames: TFRecord file path(s) passed to `tf.data.TFRecordDataset`.
    labeled: When True records are parsed with `read_labeled_tfrecord`,
      otherwise with `read_unlabeled_tfrecord`.
    ordered: When False the pipeline is allowed to yield elements in a
      non-deterministic order; when True the default options are kept.

  Returns:
    A dataset of (image, label) or (image, id) pairs, unbatched.
  """
  options = tf.data.Options()
  if not ordered:
    # Give up deterministic ordering so parallel reads need not wait on each other.
    options.experimental_deterministic = False

  parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
  records = tf.data.TFRecordDataset(filenames, num_parallel_reads=config.AUTOTUNE)
  return records.with_options(options).map(parse_fn, num_parallel_calls=config.AUTOTUNE)

In [10]:
# Train Dataset
# The pattern is built from config.INPUT_DIR (it resolves to the same path as
# before). glob returns matches in arbitrary order, so the list is sorted to
# make the file order reproducible between runs.
train_filenames = sorted(glob.glob(config.INPUT_DIR + '/tfrecords-jpeg-224x224/train/*.tfr*'))
if not train_filenames:
    # Fail here with a clear message instead of training on an empty dataset.
    raise FileNotFoundError(
        'No training TFRecord files found under '
        + config.INPUT_DIR + '/tfrecords-jpeg-224x224/train')

In [11]:
# Validation Dataset
# The pattern is built from config.INPUT_DIR (it resolves to the same path as
# before). glob returns matches in arbitrary order, so the list is sorted to
# make the file order reproducible between runs.
valid_filenames = sorted(glob.glob(config.INPUT_DIR + '/tfrecords-jpeg-224x224/val/*.tfr*'))
if not valid_filenames:
    # Fail here with a clear message instead of validating on an empty dataset.
    raise FileNotFoundError(
        'No validation TFRecord files found under '
        + config.INPUT_DIR + '/tfrecords-jpeg-224x224/val')

In [12]:
def data_augment(image, label):
    """Apply random augmentations to one training image.

    The image is randomly flipped (horizontally and vertically), its
    saturation is randomly scaled by a factor drawn between 0 and 2, a random
    square patch is cropped, and the patch is resized back to
    config.IMAGE_HEIGHT x config.IMAGE_WIDTH.

    Args:
        image: Image tensor with 3 channels.
        label: Passed through untouched.

    Returns:
        Tuple (augmented_image, label).
    """
    full_side = config.IMAGE_HEIGHT
    smallest_side = int(full_side * .7)
    # Side length of the square crop, drawn once per image. The height is
    # used as the bound for both dimensions.
    crop_size = tf.random.uniform([], smallest_side, full_side, dtype=tf.int32)

    flipped = tf.image.random_flip_up_down(tf.image.random_flip_left_right(image))
    recoloured = tf.image.random_saturation(flipped, lower=0, upper=2)
    patch = tf.image.random_crop(recoloured, size=[crop_size, crop_size, 3])
    restored = tf.image.resize(patch, size=[config.IMAGE_HEIGHT, config.IMAGE_WIDTH])

    return restored, label

In [13]:
def get_training_dataset():
  """Build the endless, augmented, shuffled and batched training pipeline.

  Returns:
    A `tf.data.Dataset` yielding batches of config.TRAIN_BATCH_SIZE
    (image, label) pairs. The dataset repeats indefinitely, so the caller
    must bound an epoch with `steps_per_epoch`.
  """
  dataset = load_dataset(train_filenames)
  dataset = dataset.map(data_augment, num_parallel_calls=config.AUTOTUNE)
  # Shuffle *before* repeating: the shuffle buffer then only ever holds
  # samples from a single pass, so each pass covers every example once
  # before the next begins. Repeating first lets samples from neighbouring
  # passes mix in the buffer.
  dataset = dataset.shuffle(config.SHUFFLE)
  dataset = dataset.repeat()
  dataset = dataset.batch(config.TRAIN_BATCH_SIZE)
  dataset = dataset.prefetch(config.AUTOTUNE)

  return dataset

In [14]:
def get_valid_dataset():
  """Build the batched, cached validation pipeline.

  Returns:
    A `tf.data.Dataset` yielding batches of config.VALID_BATCH_SIZE
    (image, label) pairs. No augmentation, shuffling or repetition is
    applied.
  """
  dataset = load_dataset(valid_filenames)
  # Use the dedicated validation batch size rather than the training one.
  dataset = dataset.batch(config.VALID_BATCH_SIZE)
  # Cache the decoded batches so later validation passes skip parsing/decoding.
  dataset = dataset.cache()
  dataset = dataset.prefetch(config.AUTOTUNE)

  return dataset

## Creating the Model
We will use Vision Transformer Base 16 as our base model.

In [21]:
# Pretrained ViT-B/16 backbone without its classification head.
# NOTE(review): with include_top=False the `activation` and `classes`
# arguments presumably have no effect -- confirm against vit-keras.
vit_model = vit.vit_b16(
    image_size=config.IMAGE_SIZE,
    activation='relu',
    pretrained=True,
    include_top=False,
    pretrained_top=False,
    classes=config.NUM_CLASSES,
)

# Freeze the backbone so that only layers stacked on top of it are trained
# in the first training stage.
vit_model.trainable = False



In [22]:
# Classification head stacked on the frozen backbone:
# BatchNorm -> Dense(1024, GELU) -> BatchNorm -> Dense(NUM_CLASSES, softmax).
model = tf.keras.Sequential([
    vit_model,
    # The model summary shows the backbone output is already (None, 768),
    # so Flatten leaves the shape unchanged.
    tf.keras.layers.Flatten(),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1024, activation=tf.keras.activations.gelu),
    tf.keras.layers.BatchNormalization(),
    # Outputs class probabilities (softmax), not logits.
    tf.keras.layers.Dense(config.NUM_CLASSES, activation='softmax'),
])

In [24]:
# Print the layer stack and the trainable / non-trainable parameter counts.
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
vit-b16 (Functional)         (None, 768)               85798656  
_________________________________________________________________
flatten_1 (Flatten)          (None, 768)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 768)               3072      
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              787456    
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_3 (Dense)              (None, 104)               106600    
Total params: 86,699,880
Trainable params: 897,640
Non-trainable params: 85,802,240
____________________________________

## Training the model

### Training the last layers

In [27]:
# Labels are integer class ids, hence the "sparse" loss and metric.
# from_logits=False is spelled out because the last layer already applies softmax.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
metric = tf.keras.metrics.SparseCategoricalAccuracy()
# Adam with the head-training learning rate.
optimizer = tf.keras.optimizers.Adam(learning_rate=config.LR)

In [28]:
# Attach the optimizer, loss and accuracy metric for the head-training stage.
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=[metric],
)

In [29]:
# Keep the best full model seen so far (lowest validation loss) on disk.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    config.MODEL_DIR + '/ViT_training.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
)
# Stop after 3 epochs without validation-loss improvement and roll the model
# back to its best weights. Without restore_best_weights the in-memory model
# (which is saved right after training) keeps the weights of the last, i.e.
# not the best, epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)
# Write training curves for TensorBoard.
logger = tf.keras.callbacks.TensorBoard(log_dir=config.LOG_DIR)
# Lower the learning rate as soon as one epoch brings no validation-loss improvement.
lr_schedular = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=1)

In [30]:
# Train only the head (the backbone is frozen). The training dataset repeats
# forever, so the length of an epoch is set explicitly.
# NOTE(review): 21000 is presumably the approximate number of training
# images -- confirm.
model.fit(
    x=get_training_dataset(),
    validation_data=get_valid_dataset(),
    epochs=config.EPOCHS,
    steps_per_epoch=21000 // config.TRAIN_BATCH_SIZE,
    callbacks=[model_checkpoint, early_stopping, logger, lr_schedular],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50


<tensorflow.python.keras.callbacks.History at 0x7fa8ba972150>

In [31]:
# Persist the model as it stands after the head-training run; the fine-tuning
# section reloads this file.
model.save(config.MODEL_DIR + '/ViT_trained.h5')

### Finetuning whole model

In [63]:
# Reload the head-trained model from disk as the starting point for fine-tuning.
model = tf.keras.models.load_model(config.MODEL_DIR + '/ViT_trained.h5')

In [67]:
# Unfreeze every top-level layer, including the ViT backbone, so the whole
# network is updated during fine-tuning.
for top_level_layer in model.layers:
  top_level_layer.trainable = True

# The summary should now report almost all parameters as trainable.
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
vit-b16 (Functional)         (None, 768)               85798656  
_________________________________________________________________
flatten_1 (Flatten)          (None, 768)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 768)               3072      
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              787456    
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_3 (Dense)              (None, 104)               106600    
Total params: 86,699,880
Trainable params: 86,696,296
Non-trainable params: 3,584
______________________________________

In [72]:
# from_logits=False is spelled out because the last layer already applies softmax.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
# Fresh Adam instance using the (smaller) fine-tuning learning rate.
optimizer = tf.keras.optimizers.Adam(learning_rate=config.FINETUNE_LR)

In [75]:
# Keep the best full model seen so far (lowest validation loss) on disk.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    config.MODEL_DIR + '/ViTfinetuning.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
)
# Stop after 2 epochs without validation-loss improvement and roll the model
# back to its best weights. Without restore_best_weights the in-memory model
# (which is saved right after training) keeps the weights of the last, i.e.
# not the best, epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=2,
    restore_best_weights=True,
)
# Write training curves for TensorBoard.
logger = tf.keras.callbacks.TensorBoard(log_dir=config.LOG_DIR)
# Lower the learning rate when validation loss improves by less than 0.003
# over one epoch.
lr_schedular = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    patience=1,
    min_delta=0.003,
)

In [76]:
# Re-compile so the unfrozen layers and the fine-tuning optimizer take effect.
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=['sparse_categorical_accuracy'],
)

# Fine-tune the whole network. `model_checkpoint` is now passed to fit();
# it was configured in the previous cell but never used, so the best
# fine-tuned model was never written to disk.
# NOTE(review): 2700 steps per epoch differs from the 21000 // 8 = 2625 used
# for head training -- confirm which value is intended.
model.fit(
    x=get_training_dataset(),
    validation_data=get_valid_dataset(),
    epochs=config.FINETUNE_EPOCHS,
    steps_per_epoch=2700,
    callbacks=[model_checkpoint, early_stopping, logger, lr_schedular],
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30


<tensorflow.python.keras.callbacks.History at 0x7fa6849af650>

In [77]:
model.save(config.MODEL_DIR + '/ViT_finetuned.h5')

## Conclusion:
Our model performs well, reaching a validation loss of 0.2822 and an accuracy of 0.9529. The model does show signs of overfitting, which could be reduced with further work, but for now I will stop at this point.