In [1]:
!pip install vit_keras



In [2]:
import tensorflow as tf
from tensorflow.keras import backend, optimizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GaussianNoise, Dense
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, LearningRateScheduler
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow import keras
import keras

from vit_keras import vit

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
train_directory='/content/drive/MyDrive/deep learning/Birddata/train'
val_directory='/content/drive/MyDrive/deep learning/Birddata/valid'
test_directory='/content/drive/MyDrive/deep learning/Birddata/test'

In [5]:
tf.random.set_seed(1234) #set seed to reproduce same results

In [6]:
# Data augmentation
train_datagen = ImageDataGenerator(
    rescale=1/255,
    horizontal_flip=True,
    rotation_range=15,
    zoom_range=0.1,
)
valid_datagen = ImageDataGenerator(rescale=1/255)
test_datagen = ImageDataGenerator(rescale=1/255)

train_generator = train_datagen.flow_from_directory(
    train_directory,
    target_size=(224, 224),
    batch_size=32,
    color_mode='rgb',
    class_mode='sparse',
    shuffle=True,
)
validation_generator = valid_datagen.flow_from_directory(
    val_directory,
    target_size=(224, 224),
    batch_size=32,
    color_mode='rgb',
    class_mode='sparse')

test_generator = test_datagen.flow_from_directory(
    test_directory,
    target_size=(224, 224),
    color_mode='rgb',
    class_mode='sparse')

Found 35216 images belonging to 250 classes.
Found 1250 images belonging to 250 classes.
Found 1250 images belonging to 250 classes.


In [7]:
backend.clear_session()

vit_model = vit.vit_l32(
    image_size=224,
    pretrained=True,
    include_top=False,
    pretrained_top=False
)

print(len(vit_model.layers))
print(vit_model.layers)



31
[<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fc88485a690>, <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fc8842051d0>, <tensorflow.python.keras.layers.core.Reshape object at 0x7fc886cdd790>, <vit_keras.layers.ClassToken object at 0x7fc8847a1cd0>, <vit_keras.layers.AddPositionEmbs object at 0x7fc884846dd0>, <vit_keras.layers.TransformerBlock object at 0x7fc8848462d0>, <vit_keras.layers.TransformerBlock object at 0x7fc8705e72d0>, <vit_keras.layers.TransformerBlock object at 0x7fc8705c3f10>, <vit_keras.layers.TransformerBlock object at 0x7fc8704f0d10>, <vit_keras.layers.TransformerBlock object at 0x7fc8704acd90>, <vit_keras.layers.TransformerBlock object at 0x7fc87048f150>, <vit_keras.layers.TransformerBlock object at 0x7fc8703e0850>, <vit_keras.layers.TransformerBlock object at 0x7fc8702f94d0>, <vit_keras.layers.TransformerBlock object at 0x7fc8702c8050>, <vit_keras.layers.TransformerBlock object at 0x7fc87023bf90>, <vit_keras.layers.Transfo

In [8]:
# Delay lr
def scheduler(epoch: int, lr: float) -> float:
    if epoch != 0 and epoch % 7 == 0:
        return lr * 0.1
    else:
        return lr
lr_scheduler_callback = LearningRateScheduler(scheduler)

In [9]:
finetune_at = 28

# fine-tuning
for layer in vit_model.layers[:finetune_at - 1]:
    layer.trainable = False

num_classes = len(validation_generator.class_indices)

# Add GaussianNoise layer for robustness
noise = GaussianNoise(0.01, input_shape=(224, 224, 3))

In [10]:
#add layers
# Classification head
head = Dense(num_classes, activation="softmax")
model = Sequential()
model.add(noise)
model.add(vit_model)
model.add(head)

In [11]:
#compile the model
model.compile(optimizer=optimizers.Adam(),
               loss="sparse_categorical_crossentropy",
               metrics=["accuracy"])

In [None]:
#train the model
tf.random.set_seed(100)
                      
history = model.fit(
          train_generator,
          epochs=50,
          validation_data=validation_generator,
          verbose=1, 
          shuffle=True,
          callbacks=[
              EarlyStopping(monitor="val_accuracy", patience=10, restore_best_weights=True),
              lr_scheduler_callback,
          ])

In [None]:
#save model
#from keras.models import load_model
#tf.keras.models.save_model(filepath='/content/drive/MyDrive/deep learning/bird models/vit_new',model=model)