# Fine-tuning Vision Transformers (ViT) On FER-2013 and Optuna Hyperparameter Optimization


Download checkpoint-6000 (70% validation accuracy after 3.35 epochs; learning rate 1e-4, AdamW optimizer (`adamw_hf`), batch size 16) from https://drive.google.com/file/d/1nyPRecZq_-5rWXBLs_nAdpONAl1Xy2tJ/view?usp=sharing

### Data loading and initial data processing
We load the data from Hugging Face and convert it to the format required by the ViT. We can optionally enlarge the dataset via data augmentation or by adding the DiffusionFER dataset (not used in this notebook).

In [None]:
!pip install datasets torchx
!pip install accelerate==0.2.1
!pip install transformers==4.18
from datasets import load_dataset

In [None]:
# Images generated with Stable Diffusion: https://huggingface.co/datasets/FER-Universe/DiffusionFER
# They could supplement training and increase robustness (2581 training examples, no test images).
# Not used anywhere else in this notebook, so the download is opt-in instead of
# fetching data that is never consumed. Set the flag to True to load it.
USE_DIFFUSION_FER = False

diffusion_fer_dataset = None
if USE_DIFFUSION_FER:
    diffusion_fer_dataset = load_dataset(
        'FER-Universe/DiffusionFER',
        split='train'
    )
diffusion_fer_dataset

In [None]:
# Use only the original FER-2013 dataset for now.
# Load it directly from the Hugging Face Hub to get the expected format.
# Both splits are loaded with the same arguments. The former
# `ignore_verifications=False` on the test split only restated the default and
# is deprecated (superseded by `verification_mode`) in newer `datasets`
# releases, which the unpinned install cell may pull in.
dataset_train = load_dataset(
    'Jeneral/fer-2013',
    split='train',
)

dataset_train
dataset_test = load_dataset(
    'Jeneral/fer-2013',
    split='test',
)

dataset_test

In [None]:
from PIL import Image
import io

# Convert img_bytes to img in PIL format for both training and testing datasets
def convert_to_pil(image_bytes):
    """Decode encoded image bytes into a 3-channel grayscale PIL image.

    ViT expects 3-channel input while FER-2013 faces are grayscale, so the
    single luminance channel is replicated into the R, G and B bands.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        A PIL ``Image`` in mode ``"RGB"`` whose three bands are identical.
    """
    image = Image.open(io.BytesIO(image_bytes))
    # `Image.merge` requires single-band "L" inputs and raises ValueError
    # otherwise. Normalising the mode first means palette ("P"), alpha ("LA")
    # or already-RGB images are handled too. "L" inputs are unaffected.
    gray = image.convert("L")
    return Image.merge("RGB", (gray, gray, gray))

# Decode the raw `img_bytes` column into PIL images stored under `img`.
# Then rename `labels` to `label`, the column name used by the rest of the notebook.
def _decode_img_bytes(example):
    return {'img': convert_to_pil(example['img_bytes'])}

dataset_train = dataset_train.map(_decode_img_bytes, remove_columns=['img_bytes']).rename_column("labels", "label")
dataset_test = dataset_test.map(_decode_img_bytes, remove_columns=['img_bytes']).rename_column("labels", "label")

In [None]:
# check how many labels/number of classes
# Inspect the label feature and count the distinct classes present in the training split.
labels = dataset_train.features['label']
num_classes = len(set(dataset_train['label']))
num_classes, labels


In [None]:
# Inspect the first training example (decoded `img` plus integer `label`)
dataset_train[0]

In [None]:
# Render the first training image
dataset_train[0]['img']


In [None]:
# Integer label of the first training example and its class name
dataset_train[0]['label'], labels.names[dataset_train[0]['label']]

### Loading ViT Feature Extractor
We use the `google/vit-base-patch16-224-in21k` model from the Hugging Face Hub. Its name describes the architecture: a base-sized ViT with 16x16 patches and an input resolution of 224x224 (`in21k` denotes pretraining on ImageNet-21k). We examine the pretrained feature extractor and use it to transform dataset_train and dataset_test.

In [None]:
from transformers import ViTFeatureExtractor

# Load the feature extractor published with the pretrained checkpoint; it turns
# PIL images into the `pixel_values` tensors the model consumes.
model_id = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(
    model_id
)

In [None]:
# Show the feature extractor's preprocessing configuration
feature_extractor

In [None]:
# Preprocess a single training image to see what the feature extractor returns
example = feature_extractor(
    dataset_train[0]['img'],
    return_tensors='pt'
)
example

In [None]:
# Shape of the preprocessed pixel tensor for that single image
example['pixel_values'].shape

In [None]:
# Load PyTorch and choose where training and inference run.
import torch

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
def preprocess(batch):
    """Turn a batch of raw examples into ViT model inputs.

    Args:
        batch: mapping with an ``'img'`` list of PIL images and a ``'label'`` list.

    Returns:
        The feature extractor's output (``'pixel_values'`` as PyTorch tensors)
        extended with the untouched ``'label'`` entry.
    """
    encoded = feature_extractor(batch['img'], return_tensors='pt')
    # Carry the labels through so the collator can build the `labels` tensor.
    encoded['label'] = batch['label']
    return encoded

In [None]:
# Attach `preprocess` as an on-the-fly transform to both the training and the testing split.
prepared_train, prepared_test = (
    split.with_transform(preprocess) for split in (dataset_train, dataset_test)
)

### Model Fine-Tuning

We use the Hugging Face `Trainer`, which is optimized for Transformers. We define the necessary arguments, such as a `collate_fn` that builds batch tensors and an accuracy function passed as `compute_metrics`. We also set `remove_unused_columns=False`, because the otherwise-unused `img` column is needed to create `pixel_values`. We use a batch size of 16, evaluate (and save a checkpoint) every 500 steps, and train with a learning rate of $10^{-4}$ (`1e-4`) and the AdamW optimizer (`adamw_hf`), which are reasonable choices.


In [None]:
def collate_fn(batch):
    """Stack a list of preprocessed examples into one model-ready batch.

    Args:
        batch: list of dicts, each holding a ``'pixel_values'`` tensor and an
            integer ``'label'``.

    Returns:
        dict with ``'pixel_values'`` (per-example tensors stacked along a new
        leading dimension) and ``'labels'`` (a tensor of the labels in order).
    """
    pixel_batch = torch.stack([sample['pixel_values'] for sample in batch])
    label_batch = torch.tensor([sample['label'] for sample in batch])
    return {'pixel_values': pixel_batch, 'labels': label_batch}

In [None]:
import numpy as np
from datasets import load_metric

# accuracy metric
# NOTE(review): `datasets.load_metric` is deprecated (superseded by the `evaluate`
# library) and removed in newer `datasets` releases, while the install cell at the
# top of the notebook installs `datasets` unpinned — confirm the installed version.
metric = load_metric("accuracy")
def compute_metrics(p):
    """Compute classification accuracy for the Hugging Face ``Trainer``.

    Args:
        p: ``EvalPrediction``-like object with ``predictions`` (per-class
            scores, classes along axis 1) and ``label_ids``.

    Returns:
        ``{"accuracy": float}``: the fraction of argmax predictions that equal
        the reference labels.
    """
    # Computed directly with NumPy instead of via the `metric` object from
    # `datasets.load_metric`, which is deprecated and removed in newer
    # `datasets` releases. The result has the same {"accuracy": float} shape.
    predictions = np.argmax(p.predictions, axis=1)
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predictions == references))}

In [None]:
from transformers import TrainingArguments

# Training configuration. Evaluation and checkpointing both run every 500 steps,
# so `load_best_model_at_end` can restore the best evaluated checkpoint.
training_args = TrainingArguments(
    output_dir="./fer_2013",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=5,
    save_steps=500,
    eval_steps=500,
    learning_rate=1e-4,
    save_total_limit=2,
    # Keep the `img` column: `preprocess` needs it to build `pixel_values`.
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
    # Pick the best checkpoint by the accuracy reported by `compute_metrics`,
    # not the default criterion (evaluation loss).
    metric_for_best_model="accuracy",
    optim="adamw_hf",
    logging_dir='./logs',
)

We can now load the pre-trained model. We pass ```num_labels``` at initialization so the model creates a classification head with the right number of units. We move the model to the GPU if one is available, and resume from a previous checkpoint if one exists.

In [None]:
from transformers import ViTForImageClassification

labels = dataset_train.features['label'].names

# Load the pretrained backbone with a classification head sized to the number of classes.
# Storing the id <-> name mapping in the config makes saved checkpoints report
# real class names through `model.config.id2label` (used by the real-time
# inference code) instead of generic `LABEL_<i>` placeholders.
model = ViTForImageClassification.from_pretrained(
    model_id,
    num_labels=len(labels),
    id2label={i: name for i, name in enumerate(labels)},
    label2id={name: i for i, name in enumerate(labels)},
)

In [None]:
# Move the model to the selected device (GPU when CUDA is available, else CPU)
model.to(device)

In [None]:
from transformers import Trainer
import os

# Convenience helper for resuming training or reloading a fine-tuned model
# from the newest checkpoint in the output directory.
def get_latest_checkpoint(output_dir=None):
    """Return the path of the newest ``checkpoint-<step>`` directory, or ``None``.

    Args:
        output_dir: Directory to search. Defaults to ``training_args.output_dir``
            (the notebook-level training configuration).

    Returns:
        The path (joined onto ``output_dir``) of the checkpoint directory with
        the highest step number, or ``None`` when the directory is missing or
        contains no checkpoints.
    """
    import re

    if output_dir is None:
        output_dir = training_args.output_dir

    # Only directories named exactly `checkpoint-<step>` count. Stray files that
    # merely contain "checkpoint" in their name are ignored.
    pattern = re.compile(r"^checkpoint-(\d+)$")
    checkpoints = {}
    if os.path.isdir(output_dir):
        for entry in os.listdir(output_dir):
            match = pattern.match(entry)
            path = os.path.join(output_dir, entry)
            if match and os.path.isdir(path):
                checkpoints[int(match.group(1))] = path

    if not checkpoints:
        print("No checkpoint found, will start training from scratch")
        return None

    # Select by step number, not modification time. Directory mtimes change
    # when checkpoints are copied, unzipped or have entries added, so they do
    # not reliably reflect training progress.
    latest_checkpoint = checkpoints[max(checkpoints)]
    print(f"Will resume training from checkpoint: {latest_checkpoint}")
    return latest_checkpoint

# Wire the model, data, collator and metric into a Hugging Face Trainer.
# The feature extractor is passed as `tokenizer` so it is saved next to the model.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    data_collator=collate_fn,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
# Resume from the newest checkpoint when one exists, otherwise start fresh.
# `resume_from_checkpoint=True` raises ValueError on a first run, because no
# checkpoint exists yet in `output_dir`. Passing None trains from scratch.
train_results = trainer.train(resume_from_checkpoint=get_latest_checkpoint())
# Save the model (and, via `tokenizer=`, the feature extractor) to output_dir.
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
# save the trainer state
trainer.save_state()

### Hyperparameter Optimization
We install Optuna to optimize hyperparameters. We define an objective function and an Optuna study object that runs 50 trials to maximize accuracy on the held-out split (the FER-2013 test split, which also serves as the validation set in this notebook). Each trial fine-tunes a fresh copy of the pretrained model. We vary the learning rate from 1e-6 to 1e-3 and use either `adamw_hf` or `adafactor` as the optimizer.

In [None]:
%pip install optuna

In [None]:
import optuna
from transformers import ViTForImageClassification, TrainingArguments, Trainer

def objective(trial: optuna.Trial):
    """Optuna objective: briefly fine-tune a fresh ViT and return its accuracy.

    Searches the learning rate (log-uniform in [1e-6, 1e-3]) and the optimizer
    (``adamw_hf`` or ``adafactor``). Each trial starts from the pretrained
    checkpoint and trains for 0.1 epoch.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        Accuracy on ``prepared_test`` (``test_accuracy`` from ``Trainer.predict``).
    """
    # Derive the head size from the dataset instead of hard-coding 7 classes.
    class_names = dataset_train.features['label'].names
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k",
        num_labels=len(class_names),
    )
    model.to(device)

    training_args = TrainingArguments(
        # One directory per trial, so trials never overwrite each other's outputs.
        output_dir=f"./optuna/trial-{trial.number}",
        # The range spans three orders of magnitude. Sampling it log-uniformly
        # explores small learning rates as often as large ones.
        learning_rate=trial.suggest_float("learning_rate", low=1e-6, high=1e-3, log=True),
        optim=trial.suggest_categorical("optimizer", ["adamw_hf", "adafactor"]),
        # Same per-device batch size as the main fine-tuning run, so the tuned
        # learning rate transfers to it.
        per_device_train_batch_size=16,
        num_train_epochs=0.1,
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,
    )

    trainer.train()

    # `predict` prefixes metric keys with "test_".
    outputs = trainer.predict(prepared_test)
    return outputs.metrics['test_accuracy']

In [None]:
# Maximize test_accuracy over 50 trials.
# Trials run sequentially (n_jobs=1). Each trial fine-tunes a full ViT on the
# same `device`, so parallel threads (n_jobs=-1) would compete for GPU memory
# and risk out-of-memory failures.
study = optuna.create_study(study_name='hyperparameter-search', direction='maximize')
study.optimize(func=objective, n_trials=50, n_jobs=1)
print(study.best_value)
print(study.best_params)
print(study.best_trial)

### Model Evaluation

We can now evaluate our model using the accuracy metric defined above. We print the metrics and plot the confusion matrix. We then take the first image in the test dataset and check whether the predicted label is correct by loading the fine-tuned model and running inference.

In [None]:
# Evaluate on the test split and log/save the results under the "eval" prefix.
# `trainer` is the notebook-level Trainer from the fine-tuning section; the
# Optuna trainers are local to `objective`.
metrics = trainer.evaluate(prepared_test)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [None]:
outputs = trainer.predict(prepared_test)

print(outputs.metrics)


from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Read class names straight from the dataset. The global `labels` is rebound to
# a ClassLabel object in later cells, which would break this cell on re-run.
class_names = dataset_test.features['label'].names

y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

# Fix the label set so the matrix is always n_classes x n_classes, even if a
# class never occurs in y_true or y_pred. Otherwise it shrinks and no longer
# matches `display_labels`.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(xticks_rotation=45)

In [None]:
# Show the first image of the test split, resized to 200x200 for display.
# Index the row first: `dataset_test["img"]` would decode every image in the column.
image = dataset_test[0]["img"].resize((200, 200))
image

In [None]:
# Ground-truth label of the first test example.
# Index the row first instead of materialising the whole `label` column.
actual_label = dataset_test[0]["label"]

labels = dataset_test.features['label']
actual_label, labels.names[actual_label]


In [None]:
from transformers import ViTForImageClassification, ViTFeatureExtractor

# Load the fine-tuned model and its feature extractor from the newest local
# checkpoint. The feature extractor is stored in each checkpoint because it was
# handed to the Trainer as `tokenizer`.
# NOTE: with `load_best_model_at_end=True`, the newest checkpoint is not
# necessarily the best one; `trainer.save_model()` wrote the best model to
# `training_args.output_dir` itself.
latest_checkpoint = get_latest_checkpoint()
if latest_checkpoint is None:
    # Fail with a clear message instead of an opaque error from `from_pretrained(None)`.
    raise FileNotFoundError(
        f"No checkpoint found in {training_args.output_dir!r}; "
        "train the model first or download a checkpoint."
    )
model_finetuned = ViTForImageClassification.from_pretrained(latest_checkpoint)
# import features from the same checkpoint
feature_extractor_finetuned = ViTFeatureExtractor.from_pretrained(latest_checkpoint)

In [None]:
# Preprocess the displayed test image with the checkpoint's feature extractor
inputs = feature_extractor_finetuned(image, return_tensors="pt")

# Forward pass without autograd bookkeeping; `logits` holds one score per class
with torch.no_grad():
    logits = model_finetuned(**inputs).logits

In [None]:
# Map the highest-scoring class index back to its class name.
labels = dataset_test.features['label']
predicted_label = int(logits.argmax(dim=-1))
labels.names[predicted_label]

In [None]:
import shutil

# Package a checkpoint directory as `<checkpoint-name>.zip` in the working
# directory, so it can be downloaded manually (right-click the created zip file).
# checkpoint-6000 is the published 70%-accuracy checkpoint. Use
# `get_latest_checkpoint()` here instead to export the newest one.
latest_checkpoint = 'fer_2013/checkpoint-6000'
# Check that the directory exists: with a hard-coded path, a bare truthiness
# test is always true and the archive step would fail on a missing directory.
if latest_checkpoint and os.path.isdir(latest_checkpoint):
    checkpoint_dir_name = os.path.basename(latest_checkpoint)
    zip_file_name = f"{checkpoint_dir_name}.zip"
    # Same layout as `zip -r <name>.zip <checkpoint dir>`: archive entries keep
    # the `fer_2013/checkpoint-6000/...` prefix. Pure Python, so no `zip`
    # binary is required.
    shutil.make_archive(checkpoint_dir_name, 'zip', root_dir='.', base_dir=latest_checkpoint)
    print(f"Created {zip_file_name}")
else:
    print("No checkpoint available to download.")

### Real-time inference

In [None]:
from IPython.display import display, Javascript
from base64 import b64decode
from PIL import Image
import numpy as np
import cv2
from io import BytesIO

# Cache of (model, feature_extractor) pairs keyed by checkpoint directory.
# Used so repeated real-time calls do not reload the weights from disk.
_FINETUNED_CACHE = {}


def _load_finetuned(checkpoint_dir):
    """Load the fine-tuned ViT and its feature extractor once per checkpoint directory."""
    if checkpoint_dir not in _FINETUNED_CACHE:
        finetuned_model = ViTForImageClassification.from_pretrained(checkpoint_dir)
        finetuned_model.eval()
        finetuned_extractor = ViTFeatureExtractor.from_pretrained(checkpoint_dir)
        _FINETUNED_CACHE[checkpoint_dir] = (finetuned_model, finetuned_extractor)
    return _FINETUNED_CACHE[checkpoint_dir]


def real_time_inference(checkpoint_dir='./fer_2013/checkpoint-1000'):
    """Capture one webcam frame in the browser, detect faces and classify their emotion.

    For each face found by OpenCV's Haar cascade, it prints the predicted class
    and draws a box on the frame. The annotated frame is then displayed as HTML.

    Args:
        checkpoint_dir: Fine-tuned checkpoint to load the model and feature
            extractor from (loaded once and cached across calls).
    """
    # Imported here because the notebook's import cell does not provide them.
    from base64 import b64encode
    from IPython.display import HTML

    js = Javascript('''
    async function captureFrame() {
        const div = document.createElement('div');
        const video = document.createElement('video');
        video.style.display = 'block';
        const stream = await navigator.mediaDevices.getUserMedia({video: true});
        document.body.appendChild(div);
        div.appendChild(video);
        video.srcObject = stream;
        await video.play();

        // Wait for the video to start playing
        await new Promise(resolve => setTimeout(resolve, 1000));

        const canvas = document.createElement('canvas');
        canvas.width = video.videoWidth;
        canvas.height = video.videoHeight;
        canvas.getContext('2d').drawImage(video, 0, 0);
        const dataUrl = canvas.toDataURL('image/jpeg');
        stream.getTracks().forEach(track => track.stop()); // Stop the video stream
        div.remove();
        return dataUrl;
    }
    ''')
    display(js)
    # NOTE(review): `eval_js` is not imported anywhere in this notebook. In Google
    # Colab it comes from `google.colab.output`; import it before calling this function.
    data_url = eval_js('captureFrame()')
    image_data = b64decode(data_url.split(',')[1])
    # Force 3-channel RGB, so the array below is always (H, W, 3) in RGB order.
    image = Image.open(BytesIO(image_data)).convert('RGB')
    frame = np.array(image)  # Convert PIL Image to numpy array for OpenCV processing

    # Load pre-trained Haar Cascade for face detection
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

    # Loaded once per checkpoint and reused, since this function runs in a loop.
    model, feature_extractor = _load_finetuned(checkpoint_dir)

    # The frame comes from PIL, so its channels are RGB, not OpenCV's default BGR.
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)

    # Detect faces in the image
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

    # Process each face found
    for (x, y, w, h) in faces:
        # Training images were grayscale replicated into three channels (see
        # `convert_to_pil`), so feed each face crop to the model the same way.
        face_gray = gray[y:y+h, x:x+w]
        face_roi = np.stack([face_gray] * 3, axis=-1)

        inputs = feature_extractor(images=face_roi, return_tensors="pt")

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_idx = logits.argmax(-1).item()

        # Get the class label
        class_label = model.config.id2label[predicted_class_idx]

        # Display the prediction
        display_str = f'Predicted: {class_label}'
        print(display_str)

        # Draw a rectangle around the face
        cv2.rectangle(frame, (x, y), (x+w, y+h), (255, 0, 0), 2)

    # Convert numpy array back to PIL Image for displaying in Colab
    frame = Image.fromarray(frame)
    img_byte_arr = BytesIO()
    frame.save(img_byte_arr, format='JPEG')
    encoded_img = b64encode(img_byte_arr.getvalue())
    img_str = encoded_img.decode('utf-8')
    img_html = f'<img src="data:image/jpeg;base64,{img_str}" />'
    display(HTML(img_html))

# Call the real-time inference function repeatedly until the kernel is
# interrupted (stop button / Kernel > Interrupt raises KeyboardInterrupt).
# The interrupt is caught so the loop ends cleanly instead of with a traceback.
try:
    while True:
        real_time_inference()
except KeyboardInterrupt:
    print("Stopped real-time inference.")