# Fine Tuning / Transfer learning of ViT model
Use case: classifier Lego/Duplo

In [1]:
from collections import defaultdict
import imageio.v2 as iio
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow.keras as ks
from transformers import ViTImageProcessor, TFViTForImageClassification, TFViTModel

2024-07-25 12:43:33.888779: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-25 12:43:33.932101: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-25 12:43:33.932135: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-25 12:43:33.933612: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-25 12:43:33.940467: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-25 12:43:33.941225: I tensorflow/core/platform/cpu_feature_guard.cc:1

## Load pretrained model and corresponding ImageProcessor

In [2]:
# Load image processor and ViT backbone from the SAME checkpoint, so that the
# preprocessing configuration belongs to the weights that are fine-tuned below.
VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
# Create an instance of ViTImageProcessor (image preprocessing for the ViT model)
feature_extractor = ViTImageProcessor.from_pretrained(VIT_CHECKPOINT)
# Create an instance of TFViTModel (ViT without classification head) as pretrained base model
ViT_base_model = TFViTModel.from_pretrained(VIT_CHECKPOINT)

All PyTorch model weights were used when initializing TFViTModel.

All the weights of TFViTModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFViTModel for predictions without further training.


## Subclass of keras.Model
Create a model class containing the pretrained ViT and a dense classifier on top of the pooler layer of the ViT.

In [3]:
class ViTLegoDuplo(tf.keras.Model):
    '''Classifier built from a pretrained ViT model and one dense layer on top.

    The dense layer is fed with the second output of the ViT model (index 1), which
    is expected to be the output of the ViT pooler layer.
    - Binary classifier (no_classes='binary'): output are raw logits from one neuron,
      to be combined with a loss that is configured with from_logits=True.
    - One-hot-encoded classification (no_classes=<int>): output are probabilities
      from a Softmax activation, one per class.

    A method is provided to declare different layers of the ViT model as trainable
    (see ".choose_trainables()" for details).
    '''
    def __init__(self, vit_base_model, no_classes='binary'):
        # vit_base_model should be a ViT model with pretrained pooler layer
        # no_classes: Pass "binary" to obtain a single neuron output, providing logits
        #     Pass: integer number of classes for one-hot-encoded data, layer provides
        #     probabilities from Softmax-function in this case
        super().__init__()
        self.vit = vit_base_model
        if no_classes == 'binary':
            # No activation here: a bounded activation like tanh would restrict the
            # "logits" to [-1, 1] and thereby the predicted probability to about
            # [0.27, 0.73], so the loss could never get close to zero.
            self.dense_layer = tf.keras.layers.Dense(1)
        else:
            # One single layer holds weights AND Softmax activation. Assigning a
            # separate Softmax layer to the same attribute would replace the Dense
            # layer and leave the model without any classifier weights.
            self.dense_layer = tf.keras.layers.Dense(no_classes, activation='softmax')

    def choose_trainables(self, embeddings=True, encoder=True, layernorm=True, pooler=True):
        # Method to set trainable argument of components of ViT model.
        # Each argument is a bool: True = weights are updated in training, False = frozen.
        # Call it before compile(), as Keras picks up changes of "trainable" on compile.

        # Components live in the main layer of the ViT model (its first layer)
        main_layer = self.vit.layers[0]
        # Set trainable parameter of embedding layer
        main_layer.embeddings.trainable = embeddings
        # Set trainable parameter of encoder layers
        main_layer.encoder.trainable = encoder
        # Set trainable parameter of layer normalization
        main_layer.layernorm.trainable = layernorm
        # Set trainable parameter of pooler
        main_layer.pooler.trainable = pooler

    def call(self, x):
        x = self.vit(x)[1] # take only output of pooler
        return self.dense_layer(x)

## Instantiate model for Lego/Duplo problem

In [4]:
# Instantiate model from ViTLegoDuplo class and define trainable argument for components of ViT model
classifier = ViTLegoDuplo(ViT_base_model) # for binary classification
# Freeze the complete ViT model, only the dense layer on top is trained
classifier.choose_trainables(embeddings=False, encoder=False, layernorm=False, pooler=False)
# Define appropriate loss: model output is handed over as logits, not as probabilities
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# The shortcut metrics=['accuracy'] would threshold the model output at 0.5, which is only
# correct for probabilities. For logits the decision boundary (probability 0.5) is at 0.
# name='accuracy' keeps the key in training logs / history unchanged.
accuracy = tf.keras.metrics.BinaryAccuracy(name='accuracy', threshold=0.0)
# Compile and build model.
classifier.compile(optimizer='adam', loss=loss, metrics=[accuracy])
classifier.build(input_shape=(None, 3, 224, 224)) # (batch, channels, height, width)
classifier.summary(expand_nested=True)

Model: "vi_t_lego_duplo"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               multiple                  769       
                                                                 
Total params: 86390017 (329.55 MB)
Trainable params: 769 (3.00 KB)
Non-trainable params: 86389248 (329.55 MB)
_________________________________________________________________


## Data preparation
From pictures in directories

In [25]:
# List of directories with data to use in training, validation and testing
# One directory per class, the name of the directory is used as label of the class.
# The base directory can be overridden with the environment variable LEGO_DUPLO_DATA_DIR,
# so that the notebook is not bound to one machine. Default is the original location.
data_root = os.environ.get('LEGO_DUPLO_DATA_DIR', '/home/felbus/transformers/data')
datadir_list = [
    os.path.join(data_root, 'Lego'),
    os.path.join(data_root, 'Duplo')
]

# Definition of useful classes and functions for later pipeline creation

class ParseCat:
    '''Callable that translates category strings into integer numbers.

    The mapping is kept in the dictionary "cat_dict" (category string -> number),
    which is empty at the beginning. Calling the object with a category returns
    its number. A category that is not known yet gets the next free number
    (0, 1, 2, ...) and is stored, so later calls return the same number again.
    '''
    def __init__(self):
        # Mapping 'category' -> number, filled on demand by __call__
        self.cat_dict = {}

    def __call__(self, category):
        # Register a category seen for the first time with the next free number
        if category not in self.cat_dict:
            self.cat_dict[category] = len(self.cat_dict)
        return self.cat_dict[category]

    def __str__(self):
        # Table for print(): header, separator line, one row "number, category" per entry
        header = '#\t\tCategory\n-------------------------------------\n'
        rows = [f'{number}\t\t{name}\n' for name, number in self.cat_dict.items()]
        return header + ''.join(rows)
        

@tf.py_function(Tout=tf.float32)
def apply_feature_extractor(x):
    '''Run the notebook-wide "feature_extractor" on image data and return the pixel values.
    The tf.py_function decorator makes the function usable inside tf.data.Dataset.map().
    Parameters:
    x <tf.Tensor>: image data

    Return:
    <tf.Tensor>: entry 'pixel_values' of the feature_extractor result, dtype float32
    '''
    processed = feature_extractor.preprocess(images=x, return_tensors='tf')
    pixel_values = processed['pixel_values']
    return pixel_values

def list_file_generator(dir_list, label_parser=None):
    '''Generator function will take a list with directory paths and yield tuples
    (filepath, label) for every file found directly inside these directories.
    The label is taken from the lowest part of the directory path and is parsed
    to an integer category number.
    Parameters:
    dir_list <iterable <str or bytes>>: paths to directories. Byte strings (as handed
        over by tf.data.Dataset.from_generator) are decoded as UTF-8.
    label_parser <callable>: maps the label string to an integer. Default None
        uses the notebook-wide parser object "lego_parse".

    Return:
    <tuple (<str>, <int>)>: Tuple containing file path and label category
    '''
    if label_parser is None:
        # Looked up when the generator starts running, so "lego_parse" may be
        # defined after this function.
        label_parser = lego_parse
    # Loop over all directories passed in dir_list
    for directory in dir_list:
        # Decode bytestring to string, plain strings are used as they are
        if isinstance(directory, bytes):
            directory = directory.decode('utf-8')
        # Take lowest directory name as label and parse it to integer. normpath removes
        # a trailing separator, which would otherwise result in an empty label string.
        label_str = os.path.basename(os.path.normpath(directory))
        label = label_parser(label_str)
        # os.listdir returns entries in arbitrary order, sort them for reproducible runs
        for file in sorted(os.listdir(directory)):
            # join path and filename to file path of this file
            file_path = os.path.join(directory, file)
            # Sub-directories cannot be loaded as image, skip them
            if os.path.isfile(file_path):
                # yield tuple (file path to image file, label)
                yield file_path, label
            
def load_image_data(path):
    '''Function will load image data and process it with tensorflow decode_image function.
    Parameters:
    path <str>: file path string to load the image from

    Return:
    <tf.Tensor>: Decoded image with shape (height, width, channels), channels = 3
    '''
    image = tf.io.read_file(path)
    # expand_animations=False guarantees a 3-D tensor for every file type. With the
    # default (True) GIF files are returned as 4-D tensor (frames, height, width, channels),
    # which does not fit to the other images. Of an animated GIF only the first frame is used.
    return tf.io.decode_image(image, channels=3, expand_animations=False)

# Creation of the tf.data.Dataset pipeline

# Parser object, that manages the translation of label strings to integers
lego_parse = ParseCat()

# First part of the pipeline, written as one chain of Dataset methods
image_dataset = (
    # Source: generator function yields tuples of two Tensors (file path, integer label)
    tf.data.Dataset.from_generator(
        list_file_generator,
        args=[datadir_list],
        output_signature=(
            tf.TensorSpec(shape=(), dtype=tf.string), # file path
            tf.TensorSpec(shape=(), dtype=tf.int16), # label category
        ),
    )
    # Replace file path by the decoded image
    .map(lambda path, label: (load_image_data(path), label))
    # ViTImageProcessor adds a dimension, that the Dataset does not recognize as batch
    # dimension. Therefore its call is enclosed by batch() and unbatch().
    .batch(2)
    # Replace images by the output of feature_extractor
    .map(lambda images, labels: (apply_feature_extractor(images), labels))
    .unbatch()
)

# Work-around: Dataset coming out of the ViTImageProcessor step is not accepted by Model,
# most probably because Tensor shape information is missing. A new Dataset is created from
# an iterator over the old one, with shapes declared explicitly in output_signature
# (images: float32 with shape (3, 224, 224), labels: int16 scalar). Finally batches are built.
image_dataset = tf.data.Dataset.from_generator(
    image_dataset.__iter__,
    output_signature=(
        tf.TensorSpec(shape=(3, 224, 224), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int16),
    ),
).batch(3)