# Finetuning
**Skintone prediction**
- Reference: HuggingFace's finetuning tutorial

## Install/Import libraries

In [None]:
%%capture

! pip install transformers pytorch-lightning --quiet

In [None]:
import math
import pandas as pd
from PIL import Image, UnidentifiedImageError
import os
from pathlib import Path
import ast
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy
from torchvision.datasets import ImageFolder
from transformers import ViTFeatureExtractor, ViTForImageClassification

In [None]:
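# Uncomment the two lines below when running on Google Colab (data_dir below points into the mounted Drive); leave them commented otherwise.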

# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Constants
# NOTE: despite its name, label_dir is the path of the labels CSV *file*, not a directory.
label_dir = './labels_for_more_training_data.csv' # Path of the labels CSV file (read with pd.read_csv)
data_dir = '/content/drive/MyDrive/private_test/data' # Directory holding the image files referenced by the CSV's file_name column

In [None]:
# Load the label table once and show its first five rows as a sanity check.
df = pd.read_csv(filepath_or_buffer=label_dir)
df.head(n=5)

Unnamed: 0,file_name,height,width,image_id,bbox,skintone,age,race,emotion,gender,masked
0,10003832.jpg,2000,1459,1,"[584.1895944369563, 301.32785213219023, 265.74...",mid-light,20-30s,Mongoloid,Anger,Male,unmasked
1,10005259.jpg,1395,2000,2,"[1131.1132364709713, 312.5498771883628, 285.10...",light,20-30s,Mongoloid,Neutral,Male,unmasked
2,10005527.jpg,1507,2000,3,"[548.0171526364226, 265.9999999999995, 246.980...",mid-light,20-30s,Mongoloid,Sadness,Female,unmasked
3,100086002.jpg,1334,2000,4,"[900.5677208085174, 57.13482704531668, 163.848...",light,20-30s,Mongoloid,Neutral,Female,unmasked
4,100148503.jpg,1561,2000,5,"[862.5207825161339, 478.9999999999999, 210.264...",light,20-30s,Caucasian,Happiness,Female,unmasked


## Init Dataset and Split into Training and Validation Sets
- We create a custom dataset that loads each image, crops it to its bounding box (`bbox`) and resizes it. Each item is returned together with the image's label.
- Then we split the dataset into a training set and a validation set with an 85/15 ratio.

In [None]:
  # Creating a custom dataset class
class ImageDataset(Dataset):
    """Dataset of cropped, resized images paired with one label column of a CSV.

    For each row of the labels CSV an item is produced by opening the image
    named in the ``file_name`` column, cropping it to the row's ``bbox``
    (a string holding a list ``[x, y, width, height]``), resizing the crop to
    ``new_size`` and returning it with the row's value of ``target_attr``.

    Args:
        dir: Directory that contains the image files. (The name shadows the
            built-in ``dir``; it is kept so existing callers keep working.)
        labels_dir: Path of the labels CSV file.
        target_attr: Name of the CSV column returned as the label.
        transform: Optional callable applied to the resized PIL image.
        new_size: ``(width, height)`` every crop is resized to.
    """

    def __init__(self, dir, labels_dir, target_attr, transform=None, new_size=(128, 128)):
        self.data_dir = dir
        self.target_attr = target_attr
        labels = pd.read_csv(labels_dir)
        # NOTE(review): the labels CSV spells this value 'Mongoloid', so this
        # comparison currently matches no row and nothing is filtered out.
        # Confirm whether that race was really meant to be excluded before
        # correcting the spelling: doing so would change the training data.
        self.labels = labels[labels['race'] != 'Mongolid']
        # Directory listing; not used by the dataset itself, kept for callers.
        self.images = os.listdir(dir)
        self.transform = transform
        self.new_size = tuple(new_size)

    def __len__(self):
        """Return the number of label rows left after filtering."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``.

        The image is a PIL image (or whatever ``transform`` returns); the label
        is the raw value of the ``target_attr`` column.
        """
        # Look the row up once instead of once per column.
        row = self.labels.iloc[index]
        image_path = os.path.join(self.data_dir, row['file_name'])
        label = row[self.target_attr]

        # bbox is stored as text, e.g. "[x, y, width, height]".
        bbox = ast.literal_eval(row['bbox'])
        left, top = bbox[0], bbox[1]
        right, bottom = left + bbox[2], top + bbox[3]

        # The context manager closes the file handle even if decoding fails;
        # convert() returns an independent image that outlives the handle.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        image = image.crop((left, top, right, bottom))
        image = image.resize(self.new_size)

        # `is not None` so a transform object that happens to be falsy still runs.
        if self.transform is not None:
            image = self.transform(image)

        return (image, label)


In [None]:
ds = ImageDataset(data_dir, label_dir, 'skintone')

# Random 85/15 train/validation split.
# A dedicated, seeded generator makes the split reproducible across kernel
# restarts (the permutation is drawn before any global seed is set).
split_generator = torch.Generator().manual_seed(42)
indices = torch.randperm(len(ds), generator=split_generator).tolist()
n_val = math.floor(len(indices) * .15)
# Slice by the train count rather than by -n_val: when n_val is 0,
# indices[:-0] is empty and indices[-0:] is the whole list, which would
# silently swap the two subsets.
n_train = len(indices) - n_val
train_ds = torch.utils.data.Subset(ds, indices[:n_train])
val_ds = torch.utils.data.Subset(ds, indices[n_train:])

## Preparing Labels for Our Model's Config

By adding `label2id` + `id2label` to our model's config, we'll get friendlier labels in the inference API.

In [None]:
# Class name -> class id. Ids are kept as strings here; the collator converts
# them to integers when it builds the label tensor.
skintone_label2id = {
    'dark': '0',
    'light': '1',
    'mid-dark': '2',
    'mid-light': '3',
}
# Class id -> class name, derived from the mapping above so the two cannot drift apart.
skintone_id2label = {class_id: name for name, class_id in skintone_label2id.items()}

## Image Classification Collator

To apply our transforms to images, we'll use a custom collator class. We'll initialize it using an instance of `ViTFeatureExtractor` and pass the collator instance to `torch.utils.data.DataLoader`'s `collate_fn` kwarg.

In [None]:
class ImageClassificationCollator:
    """Collate ``(image, label)`` pairs into a batch of model inputs.

    Args:
        feature_extractor: Callable invoked as
            ``feature_extractor(images, return_tensors='pt')``; its result must
            support item assignment (the ``labels`` entry is added to it).
        label2id: Optional mapping from label name to class id (int or numeric
            string). When omitted, the notebook-level ``skintone_label2id``
            mapping is looked up at call time, as before.
    """

    def __init__(self, feature_extractor, label2id=None):
        self.feature_extractor = feature_extractor
        self.label2id = label2id

    def __call__(self, batch):
        """Return the encodings for ``batch`` with a ``labels`` LongTensor added.

        Raises:
            KeyError: If a label in the batch is missing from the mapping.
        """
        # Resolve the mapping lazily so the default keeps following the global.
        label2id = self.label2id if self.label2id is not None else skintone_label2id
        images = [image for image, _ in batch]
        class_ids = [int(label2id[label]) for _, label in batch]

        encodings = self.feature_extractor(images, return_tensors='pt')
        encodings['labels'] = torch.tensor(class_ids, dtype=torch.long)
        return encodings

## Init Feature Extractor, Model, Data Loaders
- We'll init a pretrained model `google/vit-base-patch16-224-in21k` to finetune.

In [None]:
# The preprocessing config and the backbone weights come from the same checkpoint.
checkpoint = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
model = ViTForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(skintone_label2id),
    label2id=skintone_label2id,
    id2label=skintone_id2label,
)

# Both loaders share every setting except shuffling, which is train-only.
collator = ImageClassificationCollator(feature_extractor)
loader_options = {'batch_size': 8, 'collate_fn': collator, 'num_workers': 2}
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
val_loader = DataLoader(val_ds, **loader_options)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


If we're resuming training from a previous model, load it here. Otherwise skip this step.

In [None]:
# model.load_state_dict(torch.load('./drive/MyDrive/skintone.pth'))
# model.eval();

# Training

We'll use [PyTorch Lightning](https://pytorchlightning.ai/) to fine-tune our model.


In [None]:
class Classifier(pl.LightningModule):
    """LightningModule that fine-tunes a wrapped classification model.

    The wrapped model is called with the batch unpacked as keyword arguments
    and must return an object exposing ``loss`` and ``logits``.

    Args:
        model: Model to fine-tune; ``model.config.num_labels`` selects between
            a binary and a multiclass accuracy metric.
        lr: Learning rate passed to Adam.
        **kwargs: Extra hyperparameters; their names are recorded through
            ``save_hyperparameters`` together with ``lr``.
    """

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        self.val_acc = Accuracy(
            task='multiclass' if model.config.num_labels > 2 else 'binary',
            num_classes=model.config.num_labels
        )

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model."""
        # A real method instead of assigning a bound method to ``self.forward``.
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        outputs = self(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and update the accuracy metric for one batch."""
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        # Update the metric, then log the metric object itself so Lightning
        # computes the accuracy over the whole epoch and resets the metric
        # afterwards, instead of averaging per-batch accuracies.
        self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", self.val_acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with the saved learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [None]:
# Fix the random seeds before the Lightning module and trainer are created.
pl.seed_everything(42)

classifier = Classifier(model, lr=2e-5)

# Single-GPU run with 16-bit precision, limited to three epochs.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    precision=16,
    max_epochs=3,
)
trainer.fit(classifier, train_loader, val_loader)

INFO:lightning_fabric.utilities.seed:Seed set to 42
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type                      | Params
------------------------------------------------------
0 | model   | ViTForImageClassification | 85.8 M
1 | val_acc | MulticlassAccuracy        | 0     
------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.207   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.


## Save model

In [None]:
import datetime

# Show the current local date and time as text.
str(datetime.datetime.now())

In [None]:
# Build the timestamp without spaces or colons so the file name is valid on
# every filesystem (str(datetime.today()) yields e.g. "2024-01-01 12:34:56.789").
timestamp = datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')
torch.save(model.state_dict(), f'model_weights_skintone_{timestamp}.pth')