# DINOv2 model fine-tuning

This notebook describes the process of fine-tuning a DINOv2 model on a modified ImageNet-100 dataset (ImageNet-200).

## Import libraries

In [1]:
# Ignore warning while running the code
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

In [2]:
# Handling path
import os
from pathlib import Path

# PyTorch
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from copy import deepcopy

# DINOv2 ViT model
from dinov2.models.vision_transformer import vit_small

# Timing
import time

# Numpy 
import numpy as np

## Dataset path

First, populate the training and validation datasets by creating two folders, one named `train` and one named `val`. Inside each folder, group the images by class, placing every image in a subfolder named after its class.

In [3]:
# Current working directory, plus the (relative) roots of the training and
# validation image folders
local_directory = os.getcwd()
train_dataset_dir, valid_dataset_dir = (
    Path(location) for location in ("../../data/train", "../../data/val")
)

## Image Resizing

Here we define a resize-and-pad transform to make sure that the size of every image fits our model (each side is padded up to a multiple of the patch size).

In [4]:
class ResizeAndPad:
    def __init__(self, target_size, multiple):
        """
        Transform that resizes an image and then pads it so that both of its
        sides are divisible by `multiple`

        :param target_size: size forwarded to `transforms.Resize`
        :param multiple: value the padded width and height must be divisible by
        """
        self.target_size = target_size
        self.multiple = multiple

    def __call__(self, img):
        """
        Resize `img`, then pad each axis up to the next multiple

        :param img: image exposing `width` and `height` attributes
        :return: the resized and padded image
        """
        resized = transforms.Resize(self.target_size)(img)

        # Pixels missing on each axis to reach the next multiple
        # (0 when the side is already aligned)
        extra_w = -resized.width % self.multiple
        extra_h = -resized.height % self.multiple

        # Split the missing pixels between the two sides of each axis; the
        # odd pixel, if any, goes to the right / bottom edge
        left = extra_w // 2
        top = extra_h // 2
        right = extra_w - left
        bottom = extra_h - top

        padder = transforms.Pad((left, top, right, bottom))
        return padder(resized)

In [5]:
# Side length images are resized to, and the square size tuple built from it
IMAGE_SIZE = 256
TARGET_SIZE = (IMAGE_SIZE,) * 2

In [6]:
# Define the DATA TRANSFORMATION process that images have to go through.
# Both splits are resized/padded (to a multiple of 14, the patch size passed to
# the ViT below), converted to tensors and normalized channel-wise.
# Only the training split is randomly augmented: applying random rotations and
# flips to the validation split would make its metrics non-deterministic and
# hard to compare from one epoch to the next.
DATA_TRANSFORM = {
    "train": transforms.Compose(
        [
            ResizeAndPad(TARGET_SIZE, 14),
            transforms.RandomRotation(360),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "valid": transforms.Compose(
        [
            ResizeAndPad(TARGET_SIZE, 14),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

In [7]:
# Define the DATASETS, DATALOADERS and CLASSNAME
# Each dataset reads one sub-folder per class from its root directory.
DATASETS = {
    "train": datasets.ImageFolder(train_dataset_dir, DATA_TRANSFORM["train"]),
    "valid": datasets.ImageFolder(valid_dataset_dir, DATA_TRANSFORM["valid"])
}

# Only the training loader is shuffled. The validation loss is averaged per
# batch, so a fixed order keeps the composition of every batch (including a
# possibly smaller last one) identical across epochs.
DATALOADERS = {
    "train": torch.utils.data.DataLoader(DATASETS["train"], batch_size=8, shuffle=True),
    "valid": torch.utils.data.DataLoader(DATASETS["valid"], batch_size=8, shuffle=False)
}

# Class names as discovered in the training folder
CLASSES = DATASETS["train"].classes

In [8]:
# Define the DEVICE for training the model: CUDA when available, CPU otherwise
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## DINOv2 Classification Model

In [14]:
class DINOClassificationModel(nn.Module):
    def __init__(self, hidden_size, num_classes,
                 weights_path="../../pretrained/dinov2_vits14_reg4_pretrain.pth"):
        """
        DINOv2 ViT-S/14 backbone (4 register tokens) initialized from a
        pretrained checkpoint, followed by a two-layer classification head

        :param hidden_size: width of the hidden layer of the classifier head
        :param num_classes: number of logits produced by the classifier head
        :param weights_path: path to the pretrained backbone state dict
        """
        # Initialize module
        super().__init__()

        # Build the backbone with register tokens
        backbone = vit_small(patch_size=14,
                             img_size=526,
                             init_values=1.0,
                             num_register_tokens=4,
                             block_chunks=0)
        # NOTE(review): presumably the embedding width / head count of vit_small;
        # verify against the dinov2 implementation
        self.embedding_size = 384
        self.number_of_heads = 6
        self.num_classes = num_classes
        self.hidden_size = hidden_size

        # Load the pre-trained weights. map_location="cpu" lets the checkpoint
        # be read on hosts without a GPU; load_state_dict copies the values
        # into the backbone's own parameters.
        backbone.load_state_dict(torch.load(weights_path, map_location="cpu"))

        # The backbone is local to this constructor, so it is registered
        # directly (no extra copy is required)
        self.transformers = backbone

        # Add the classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.embedding_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.num_classes)
        )

    def forward(self, inputs):
        """
        Compute the class logits for a batch of images

        :param inputs: batch of image tensors; presumably of size
            (batch_size, channels, image_height, image_width), which is the
            layout `transforms.ToTensor` produces - TODO confirm
        :return: logits with `num_classes` entries per sample
        """
        # Pass through the transformers and normalization
        # NOTE(review): the backbone's forward may already return a normalized
        # feature; confirm that applying `norm` a second time is intended
        outputs = self.transformers(inputs)
        outputs = self.transformers.norm(outputs)
        outputs = self.classifier(outputs)
        return outputs

## Create trainer and train function for the model

In [15]:
class Trainer:
    def __init__(self, model, device, train_loader, val_loader, args):
        """
        Initialize the trainer for the DINOv2 ViT model

        :param model: the module to optimize
        :param device: device the model and the batches are moved to
        :param train_loader: data loader yielding (image, label) training batches
        :param val_loader: data loader yielding (image, label) validation batches
        :param args: training arguments; only args["lr"] is read here
        """
        # Cache the parameters
        self.model = model
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.args = args

        # Model to device
        self.model.to(device)

        # Create optimizer and cross-entropy loss function
        self.optimizer = optim.Adam(self.model.parameters(), args["lr"])
        self.criterion = nn.CrossEntropyLoss()

    def train(self, epoch):
        """
        Train the Visual Transformer for one epoch
        :param epoch: the current epoch (only used for display)
        :return: epoch loss (mean of the per-batch losses) and accuracy (in percent)
        """
        # Get the current time
        current_time = time.time()

        # Get the number of batches and the number of samples of the train loader
        n_batches, n_samples = len(self.train_loader), len(self.train_loader.dataset)

        # Initialize the loss and accuracy
        epoch_loss = 0.0
        epoch_accuracy = 0.0

        # Put the model into train mode
        self.model.train()

        # Optimize the model batch by batch
        for image, label in self.train_loader:
            # Map image and label to device
            image = image.to(self.device)
            label = label.to(self.device)

            # Forward pass through visual transformer
            output = self.model(image)
            loss = self.criterion(output, label)

            # Backward pass through visual transformer
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Accumulate the number of correct predictions and the batch loss
            acc = (output.argmax(dim=1) == label).float().sum()
            epoch_accuracy += acc.item()
            epoch_loss += loss.item()

        # Average the loss over batches and the accuracy over samples
        epoch_loss = epoch_loss / n_batches
        epoch_accuracy = epoch_accuracy / n_samples * 100

        # Display the training time (in seconds)
        print(time.time() - current_time)

        # Display the current status
        print('Train Epoch: {}\t>\tLoss: {:.4f} / Acc: {:.1f}%'.format(epoch, epoch_loss, epoch_accuracy))

        # Return the loss first, mirroring validate(); the accuracy used to be
        # returned twice, so callers recorded it as the training loss
        return epoch_loss, epoch_accuracy

    def validate(self, epoch):
        """
        Perform the validation at epoch
        :param epoch: the current epoch (only used for display)
        :return: the epoch loss (mean of the per-batch losses) and accuracy (in percent)
        """
        # Get the number of batches and the number of samples of the validation loader
        n_batches, n_samples = len(self.val_loader), len(self.val_loader.dataset)

        # Put the model into eval mode
        self.model.eval()

        # Validate without tracking gradients
        with torch.no_grad():
            epoch_val_accuracy = 0.0
            epoch_val_loss = 0.0

            for data, label in self.val_loader:
                # Map image and label to device
                data = data.to(self.device)
                label = label.to(self.device)

                # Forward pass through the Visual Transformer
                val_output = self.model(data)
                val_loss = self.criterion(val_output, label)

                # Accumulate the number of correct predictions and the batch loss
                acc = (val_output.argmax(dim=1) == label).float().sum()
                epoch_val_accuracy += acc.item()
                epoch_val_loss += val_loss.item()

        # Calculate the validation accuracy and loss
        epoch_val_loss = epoch_val_loss / n_batches
        epoch_val_accuracy = epoch_val_accuracy / n_samples * 100

        # Display the current stats
        print('Validation Epoch: {}\t>\tLoss: {:.4f} / Acc: {:.1f}%'.format(epoch, epoch_val_loss,epoch_val_accuracy))

        return epoch_val_loss, epoch_val_accuracy

    def save(self, model_path, epoch):
        """
        Save the current model and optimizer state together with the epoch
        :param model_path: the saved model path
        :param epoch: the current epoch
        :return: None
        """
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, model_path)

In [16]:
# Create a train function
def train(model, datasets, dataloaders, args, device):
    """
    Train and validate `model` for args["epochs"] epochs, saving a checkpoint
    and the running statistics after every epoch

    :param model: the module to optimize
    :param datasets: dict with the "train" and "valid" datasets (only their
        sizes are displayed); note that this name shadows any module-level
        `datasets` inside the function
    :param dataloaders: dict with the "train" and "valid" data loaders
    :param args: training arguments; reads "epochs", "save_dir" and
        "output_model_prefix" here, and is forwarded to the Trainer
    :param device: device on which the model is trained
    :return: None
    """
    # Display info
    print('Train dataset of size %d' % len(datasets["train"]))
    print('Validation dataset of size %d' % len(datasets["valid"]))

    # Create trainer
    trainer = Trainer(
        model,
        device,
        dataloaders["train"],
        dataloaders["valid"],
        args
    )

    # Per-epoch history
    val_loss_array = []
    train_loss_array = []
    val_accuracy_array = []
    train_accuracy_array = []

    # Output locations
    stats_save_dir = args["save_dir"]
    stats_save_address = stats_save_dir + '/results.npy'
    model_save_address = args["output_model_prefix"]

    # Create the output folders up front: otherwise the first save would fail
    # only after a complete epoch has already been trained
    os.makedirs(stats_save_dir, exist_ok=True)
    os.makedirs(os.path.dirname(model_save_address) or ".", exist_ok=True)

    # Train & Validate
    for epoch in range(1, args["epochs"] + 1):
        # Train the model for this epoch
        epoch_loss, epoch_accuracy = trainer.train(epoch)

        # Validate the model
        epoch_val_loss, epoch_val_accuracy = trainer.validate(epoch)

        # Save the model (the same file is overwritten every epoch)
        trainer.save(model_save_address, epoch)

        # Save the training and validation accuracy
        val_accuracy_array.append(epoch_val_accuracy)
        train_accuracy_array.append(epoch_accuracy)

        # Save the validation and training loss
        val_loss_array.append(epoch_val_loss)
        train_loss_array.append(epoch_loss)

        # Save the training and validation result; rows are
        # train loss, validation loss, train accuracy, validation accuracy
        losses = np.asarray([train_loss_array, val_loss_array, train_accuracy_array, val_accuracy_array])
        np.save(stats_save_address, losses)

## Model training

In [17]:
# Define arguments
ARGS = {
    # Same value as the former 10e-6, written in normalized notation
    "lr": 1e-5,
    "save_dir": "./models",
    "output_model_prefix": "./models/model.pth",
    "epochs": 50,
    "hidden_size": 256,
    # Size the classifier head from the classes found in the training folder,
    # so the logits always match the labels produced by the dataset
    "num_classes": len(CLASSES)
}

# Create model
model = DINOClassificationModel(hidden_size=ARGS["hidden_size"], num_classes=ARGS["num_classes"])

# Create training pipeline
train(
    model=model,
    datasets=DATASETS,
    dataloaders=DATALOADERS,
    args=ARGS,
    device=DEVICE
)