<center>
<table>
  <tr>
    <td><img src="https://portal.nccs.nasa.gov/datashare/astg/training/python/logos/nasa-logo.svg" width="100"/> </td>
     <td><img src="https://portal.nccs.nasa.gov/datashare/astg/training/python/logos/ASTG_logo.png?raw=true" width="80"/> </td>
     <td> <img src="https://www.nccs.nasa.gov/sites/default/files/NCCS_Logo_0.png" width="130"/> </td>
    </tr>
</table>
</center>

        
<center>
<h1><font color= "blue" size="+3">ASTG Python Course Series</font></h1>
</center>

---

<center>
    <h1><font color="red">Introduction to DINOv2 with PyTorch</font></h1>
</center>

# <font color="red">Objective</font>

- Provide an overview of DINOv2
   - Need for foundation models capable of generating features that work out of the box on any task.
   - How DINOv2, a self-supervised model, works to create vector embeddings for representing a large collection of images. DINOv2 can then be used to transfer the representations to AI models for better performance.
- Describe how DINOv2 can be combined with PyTorch to create models:
   - Show how to extract vector embeddings with DINOv2.
   - Create various models to solve the MNIST handwritten digit classification problem.
   - Show how to use PyTorch Lightning.

### <font color="green">Key concepts</font>

- __Foundation models__: AI models that are trained on vast datasets and can fulfill a broad range of general tasks. They serve as the base or building blocks for crafting more specialized applications.
- __Embeddings__: An embedding is a vector (list) of floating point numbers. The distance between two vectors measures their relatedness. Small distances suggest high relatedness and large distances suggest low relatedness.
- __PyTorch Lightning__: A PyTorch-based high-level Python framework that aims to simplify the training and deployment of models by providing a lightweight and standardized interface. 
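To make the notion of relatedness concrete, the sketch below compares three toy embeddings with cosine similarity, one common choice of measure (higher similarity means smaller angular distance). The vectors and the function name `cosine_similarity` are made up for illustration; real DINOv2 embeddings have hundreds of dimensions.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Hypothetical 4-dimensional embeddings (illustrative values only)
cat_a = [0.9, 0.1, 0.3, 0.0]
cat_b = [0.8, 0.2, 0.4, 0.1]
truck = [0.0, 0.9, 0.1, 0.8]

sim_cats = cosine_similarity(cat_a, cat_b)
sim_cat_truck = cosine_similarity(cat_a, truck)
print(f"cat_a vs cat_b: {sim_cats:.3f}")
print(f"cat_a vs truck: {sim_cat_truck:.3f}")
```

The two "cat" vectors point in nearly the same direction, so their similarity is much higher than that of the cat and truck vectors.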

# <font color="red">Background </font>

- Labeling images is one of the most time consuming parts of training a computer vision (CV) model:
   - Each object you want to identify needs to be labeled precisely.
   - Gathering and precisely labeling data is costly, especially when we deal with large datasets.
   - This is not only a limitation for subject matter-specific models, but also large general models that strive for high performance across a wider range of classes.
- The field of Natural Language Processing (NLP) has had rich featurization available in the form of vector embeddings.
   - Vector embeddings are numerical representations (keeping the meaning of the original data) of data points that express different types of data, including non-mathematical data such as words or images, as an array of numbers that ML models can process.
   - __They can be used as inputs to models that perform useful real-world tasks through mathematical operations that compare, transform, combine, sort or otherwise manipulate those numerical representations.__
   - Expressing data points as vectors enables the interoperability of different types of data, acting as a _lingua franca_ of sorts between different data formats by representing them in the same embedding space.
   - These foundational embeddings paved the way for a large range of applications.
      - Vector embeddings underpin nearly all modern machine learning, powering models used in the fields of NLP and CV, and serving as the fundamental building blocks of generative AI.
- We need CV foundation models capable of generating visual features that work out of the box on any task, both at the image level, e.g., image classification, and pixel level, e.g., segmentation.

# <font color="red">What is DINOv2? </font>

DINOv2 (self-__DIstillation of knowledge with NO labels v2__) is:

- A cutting-edge self-supervised vision transformer developed by [Meta AI](https://arxiv.org/abs/2304.07193?ref=blog.roboflow.com). 
- A model that is trained to learn from the data itself, without the need for human-labeled annotations. It generates its own supervisory signals from the input data, making it a form of unsupervised learning.
   - The model learns to understand the underlying structure and relationships within the data, effectively learning useful representations.
   - Traditional deep learning models for CV often rely on massive amounts of labeled data, which can be expensive and time-consuming to acquire. 

### <font color="green">Self-supervised model</font>
- DINOv2 is a self-supervised vision transformer model that consists of a __family of foundation models__ producing universal features suitable for image-level visual tasks (image classification, instance retrieval, video understanding) as well as pixel-level visual tasks (depth estimation, semantic segmentation).
- It is an advanced self-supervised learning technique to train models, enhancing computer vision by accurately identifying individual objects within images and video frames.

### <font color="green">Self-distillation framework</font>

- DINOv2 uses self-supervised learning (SSL) and knowledge (or model) distillation methods.
   - SSL is a “self-supervision” technique that involves a two-step process of pretraining and fine-tuning, where models learn representations from unlabeled data through auxiliary tasks and adapt to specific tasks using smaller amounts of labeled data. 
   - Knowledge distillation is the process of training a smaller model to mimic the larger model. In this case, you transfer the knowledge from the larger model (often called the “teacher”) to the smaller model (often called the “student”).
      - __Step 1__: Train the teacher model. In classical distillation the teacher is trained with labeled data, whereas DINOv2 trains it without labels. The teacher's outputs for each input then serve as targets that the smaller model learns to reproduce, while being more efficient in terms of model size and computational requirements.
      - __Step 2__: Use a large dataset of unlabeled data to train the student models to perform as well as or better than the teacher models. The idea here is to train the large models with your techniques and distill a set of smaller models. This technique is very good for saving computing costs, and DINOv2 is built with it. 
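As a rough illustration of a student mimicking a teacher, the sketch below computes a generic distillation loss: the cross-entropy between the teacher's softened output distribution and the student's. It is a minimal pure-Python sketch, not DINOv2's actual training recipe; the logits, the temperature value, and the function names are assumptions chosen for illustration.

```python
import math

def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; a higher temperature gives a softer distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(teacher_logits, student_logits, temperature=2.0):
    """Cross-entropy of the student's distribution against the teacher's soft targets."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(p_teacher, p_student))

teacher = [4.0, 1.0, 0.5]        # hypothetical teacher logits
good_student = [3.5, 1.2, 0.4]   # close to the teacher
poor_student = [0.5, 4.0, 1.0]   # disagrees with the teacher

loss_good = distillation_loss(teacher, good_student)
loss_poor = distillation_loss(teacher, poor_student)
print(f"good student: {loss_good:.3f}, poor student: {loss_poor:.3f}")
```

Minimizing this loss pushes the student's output distribution toward the teacher's, which is the sense in which the student "copies" the teacher.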

### <font color="green">Power of data</font>
- DINOv2 is trained on a colossal dataset comprising over 142 million images.
- The dataset encompasses a wide variety of scenes, objects, and viewpoints, crucial for learning representations applicable across different tasks.
- This massive scale training enables the model to learn richer, more generalizable visual representations that capture the intricate nuances of the visual world.
- Training with massive batches allows the model to learn from a more diverse set of examples simultaneously, leading to better generalization and faster convergence.

### <font color="green">Benefits</font>

- Being self-supervised, DINOv2 eliminates the need for labeled input data, allowing models built on this framework to acquire more comprehensive insights into image content.
- DINOv2 utilizes a self-supervised learning technique, enabling the model to be trained on unlabeled images, yielding two significant advantages:
   - The approach eliminates the need for substantial time and resource investment in labeling data.
   - The model gains more profound and meaningful representations of the image input since it is directly trained on the images themselves.
- Pre-training models with self-supervised learning and then fine-tuning them on specific downstream tasks has become a successful approach for transfer learning, enabling models to perform well even with limited labeled data for the target task.
   - DINOv2 can learn adaptable, high-quality, all-purpose visual features, enabling it to perform various computer vision tasks, such as classification, depth estimation, semantic segmentation, instance retrieval, and more, without task-specific fine-tuning. 

# <font color="red"> Python packages used</font>

- __Matplotlib__: Create visualization.
- __Pandas__: Data (two-dimensional labelled array) manipulation and analysis.
- __PyTorch__: Used to build, train, and evaluate deep learning models based on neural networks.
- __PyTorch Lightning__: A wrapper framework for PyTorch that makes it easy to develop and train deep learning models.

In [None]:
try:
    import google.colab
    print("Running in Google Colab")
except ImportError:
    print("Not running in Google Colab")
else:
    print("Installing modules in Google Colab")
    !pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata
    !pip3 install torch torchaudio torchvision torchtext torchdata
    !pip install pytorch_lightning
    !pip install transformers

In [None]:
import warnings
warnings.filterwarnings("ignore") 

In [None]:
import requests

In [None]:
import matplotlib.pyplot as plt

In [None]:
import numpy as np
import pandas as pd

In [None]:
from PIL import Image

In [None]:
import torch
#import torchvision.transforms as T
from torch import nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets
from torchvision import transforms 
#from transformers import AutoImageProcessor, AutoModel

In [None]:
print("Lightning version: ", pl.__version__)
print("Torch version:     ", torch.__version__)
print("CUDA is available: ", torch.cuda.is_available())

# <font color="red">Obtaining embeddings with DINOv2</font>

### <font color="blue">Choose your device</font>

Use CUDA if available, otherwise use CPU

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

### <font color="blue">Load the DINOv2 model</font>
- DINOv2 is a family of self-supervised Vision Transformer (ViT) models.
- The models vary in size and complexity, and include `ViT-S/14` and `ViT-L/14`:
   - `ViT-S/14` (Small, `dinov2_vits14`): This is a relatively smaller model in the DINOv2 family. It offers a good balance between performance and computational efficiency, making it suitable for applications with limited resources. It's a good choice when you need solid performance without the extensive computational cost of larger models.
   - `ViT-L/14` (Large, `dinov2_vitl14`): This is a larger model with potentially more Transformer layers and attention heads, which can lead to improved performance, especially on complex tasks.
      - It's the second largest model in the DINOv2 series, with `ViT-G/14` being the largest.
      - `ViT-L/14` excels in applications where accuracy is paramount, even with increased computational cost. 
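The `/14` in these names is the patch size: the ViT splits the input image into 14×14-pixel patches and turns each patch into a token. As a quick sanity check, the sketch below counts the patch tokens for a given input size; the helper name `num_patches` is hypothetical and not part of the DINOv2 API.

```python
def num_patches(height, width, patch_size=14):
    """Number of non-overlapping patches; each side must be a multiple of patch_size."""
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be multiples of the patch size")
    return (height // patch_size) * (width // patch_size)

n_tokens = num_patches(224, 224)   # 224 / 14 = 16 patches per side
print(n_tokens)
```

Since 224 is a multiple of 14, a 224×224 input divides evenly into patches.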

Here we use the small ViT: `dinov2_vits14`

In [None]:
vit_name_small = "dinov2_vits14"
vit_name_large = "dinov2_vitl14"
vit_name_grande = "dinov2_vitg14"

In [None]:
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", vit_name_small)

### <font color="blue">Move the model to the device</font>

In [None]:
dinov2_vits14.to(device)

### <font color="blue">Set the model to evaluation mode</font>

In [None]:
dinov2_vits14.eval()

### <font color="blue">Get the image of interest</font>

In [None]:
image_url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(image_url, stream=True).raw)

#### Access image properties

In [None]:
print(image.format)

In [None]:
print(image.size)

In [None]:
print(image.mode)

#### Display the image

In [None]:
plt.imshow(image);

### <font color="blue">Create an image preprocessor and apply it to the input image</font>

- Data transformation is an essential preprocessing step that prepares raw data for models.
- Transformations like resizing, converting images to tensors, or normalizing pixel values are common for image data. These transformations help the model to see data in a consistent, well-scaled format.
- We use the `torchvision.transforms` module to perform a series of manipulations on the image:
   - `Resize()`: Resize the input to the given size expected by DINOv2.
   - `ToTensor()`: Convert the image to a tensor. In PyTorch, models operate on tensors, so images (or any data) need to be converted into tensors before they can be fed into a model. 
      - A Tensor Image is a tensor with (`C`, `H`, `W`) shape, where `C` is the number of channels, and `H` and `W` are the image height and width. 
   - `Normalize()`: Adjust the pixel values of an image so that they fall within a specific range. It standardizes the data, making it easier for the model to learn patterns. Here, we normalize the tensor image with mean (mean values for the three channels) and standard deviation (std values for the three channels).
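`Normalize()` applies `(value - mean) / std` to each channel. The pure-Python sketch below reproduces that arithmetic for a single RGB pixel, using the ImageNet mean and std values that appear in the transform; the pixel values and the function name `normalize_pixel` are illustrative.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb_uint8):
    """Mimic ToTensor (scale to [0, 1]) followed by Normalize for one pixel."""
    scaled = [c / 255.0 for c in rgb_uint8]          # ToTensor step
    return [(c - m) / s                              # Normalize step
            for c, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]

white = normalize_pixel([255, 255, 255])
black = normalize_pixel([0, 0, 0])
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```

Note that the normalized values are not confined to [0, 1]; they are centered near zero, with black pixels mapping to negative values.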

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),       
    transforms.ToTensor(),              
    transforms.Normalize(                
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

In [None]:
input_image = transform(image).unsqueeze(0).to(device)

In [None]:
type(input_image)

In [None]:
input_image.shape

### <font color="blue"> Feed the image to the model to extract embeddings</font>

In [None]:
with torch.no_grad():
    embeddings = dinov2_vits14(input_image)

Shape of the `embeddings`:

In [None]:
print(embeddings.shape)

Convert the embeddings to a NumPy array and print the first few values:

In [None]:
np_embeddings = embeddings[0].cpu().numpy()

In [None]:
type(np_embeddings)

In [None]:
np_embeddings.shape

In [None]:
np_embeddings[0:20]

In [None]:
print(f"")
print(f"Min value: {np_embeddings.min()}")
print(f"Max value: {np_embeddings.max()}")
print(f"STD value: {np_embeddings.std()}")
print(f"")

Plot the embeddings:

In [None]:
plt.imshow(np_embeddings.reshape(1, -1), aspect='auto')
plt.colorbar(label='Intensity')
plt.title('Embedded Image')
plt.xlabel('Index')

# <font color="red">Obtain and pre-process MNIST handwritten digit dataset</font>

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## <font color="blue">Load the data</font>

#### Create custom PyTorch dataset

- The `Dataset` class is the primary tool for handling data.
- It acts as an interface that allows users to define how their data is accessed from files, APIs, or even generated from scratch.
- It helps prepare data for training by abstracting the complexities of data loading.
- The three primary methods users need to implement when creating a custom dataset are:
   - `__init__()`: Load the data into memory.
   - `__len__()`: Define the total number of samples in our dataset.
   - `__getitem__()`: Retrieve a specific data sample by index.
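The sketch below shows the same three methods on a tiny in-memory dataset. To keep it runnable without PyTorch, it is a plain Python class rather than a subclass of `torch.utils.data.Dataset`; the class name `SquaresDataset` and its data are invented for illustration.

```python
class SquaresDataset:
    """Toy dataset: sample i is the pair (i, i squared)."""

    def __init__(self, n_samples, transform=None):
        # Load (here: generate) the data once, up front
        self.data = list(range(n_samples))
        self.transform = transform

    def __len__(self):
        # Total number of samples
        return len(self.data)

    def __getitem__(self, idx):
        # Return one (feature, label) pair, applying the optional transform
        x = self.data[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.data[idx] ** 2

ds = SquaresDataset(5, transform=float)
print(len(ds), ds[3])
```

Any object with this shape can be indexed sample by sample, which is what a data loader relies on when it assembles batches.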

In [None]:
class DigitDataset(Dataset):
    def __init__(self, sample_dataset, transform=None):
        self.data = list()
        self.labels = list()
        for idx in range(len(sample_dataset)):
            # Index once per sample to avoid decoding each image twice
            sample_image, sample_label = sample_dataset[idx]
            self.data.append(np.array(sample_image))
            self.labels.append(sample_label)
        self.transform = transform
        print(f"Dataset contains {len(sample_dataset)} samples.")

    def __len__(self):
        """
        Provide the total number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Generates one sample of data.
        """
        # Select sample associated with the provided index
        image = self.data[idx].reshape(28, 28).astype('uint8')
        
        # Convert to 3-channel RGB
        image = Image.fromarray(image).convert("RGB")  

        # If necessary perform the data transform
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

#### Define transformations

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224), antialias=True),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), 
                         std=(0.229, 0.224, 0.225)),
])

#### Get the MNIST dataset

In [None]:
train_raw_dataset = datasets.MNIST(root='data', train=True, download=True)

In [None]:
test_raw_dataset = datasets.MNIST(root='data', train=False, download=True)

#### Make the dataset ready for PyTorch

In [None]:
train_dataset = DigitDataset(train_raw_dataset, transform=transform)
test_dataset = DigitDataset(test_raw_dataset, transform=transform)

#### Extract a subset of the dataset

We do it to reduce the computational requirement. You may choose to skip this step if you want to use the entire dataset.

Set seeds for reproducibility:

In [None]:
np.random.seed(42)
torch.manual_seed(42)

Create random indices for sampling:

In [None]:
ntrain_data = 12800
nval_data = 2000
ntest_data = 2000

In [None]:
train_indices = np.random.choice(len(train_dataset), size=ntrain_data, replace=False)
test_indices = np.random.choice(len(test_dataset), size=ntest_data, replace=False)
new_list = list(set(range(len(test_dataset))) - set(test_indices))
val_indices = np.random.choice(new_list, size=nval_data, replace=False)

Create subsets:

In [None]:
train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(test_dataset, val_indices)
test_dataset = Subset(test_dataset, test_indices)
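
Because the validation indices are drawn from the test-set indices left over after sampling, the two subsets cannot share an image. The pure-Python sketch below mirrors that logic with `random.sample` on a small hypothetical pool (sizes chosen for illustration) and checks the property.

```python
import random

random.seed(42)

pool_size, n_test, n_val = 100, 20, 20   # illustrative sizes

test_idx = random.sample(range(pool_size), n_test)
# Sort the remainder so the draw does not depend on set iteration order
remaining = sorted(set(range(pool_size)) - set(test_idx))
val_idx = random.sample(remaining, n_val)

overlap = set(test_idx) & set(val_idx)
print(f"overlap between test and validation: {len(overlap)}")
```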

# <font color="red">Test Case 1: Use a linear model</font>

### Basic steps

- Load a pre-trained DINOv2 model using PyTorch Hub.
- Define a function to load an MNIST image and transform it into a format accepted by DINOv2.
- Compute DINOv2 embeddings for each MNIST image in your dataset.
- Train a classifier on the generated embeddings and their corresponding labels.
   - We can use a lightweight classifier on these embeddings, such as a Linear Support Vector Classification (SVC) model. In this notebook we use a single linear layer.
   - The trained classifier will then be able to accurately classify the MNIST digits based on the visual features extracted by DINOv2.
- Use the trained classifier to predict the digit for new MNIST images. 

## <font color="blue"> Define the model</font>

#### Head model

In [None]:
class LinearClassifierHead(nn.Module):
    def __init__(self, embed_dim, num_hidden_nodes, num_classes):
        super().__init__()
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        return self.head(x)

#### Model that combines DINOv2 and the head model

In [None]:
class CustomModel(nn.Module):
    def __init__(self, dinov2_model, num_hidden_nodes, num_classes):
        super().__init__()
        self.dinov2_vits14 = dinov2_model
        try:
            self.embed_dim = dinov2_model.embed_dim
        except AttributeError:
            self.embed_dim = dinov2_model.config.hidden_size
        # Replace the original head or add a new one
        # DINOv2 typically outputs a feature vector, so a linear layer is suitable
        self.image_classifier = LinearClassifierHead(self.embed_dim, num_hidden_nodes, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.dinov2_vits14(x)
        
        logits = self.image_classifier(features)
        return logits


## <font color="blue"> Create and train the model</font>

In [None]:
num_classes = 10

## <font color="green"> Set hyperparameters</font>

In [None]:
learning_rate = 0.001

In [None]:
batch_size = 64

In [None]:
max_epochs = 5

In [None]:
num_hidden_nodes=64

### <font color="green"> Load a pre-trained DINOv2 model</font>

In [None]:
dinov2_model = torch.hub.load('facebookresearch/dinov2', vit_name_small)

In [None]:
try:
    embed_dim = dinov2_model.embed_dim
except AttributeError:
    embed_dim = dinov2_model.config.hidden_size

In [None]:
print(f"DINOv2 embed_dim: \n\t {embed_dim}")

In [None]:
dinov2_model.to(device)

In [None]:
dinov2_model.eval()

### <font color="green">Freeze DINOv2 backbone parameters</font>

In [None]:
for param in dinov2_model.parameters():
    param.requires_grad = False
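
With the backbone frozen, only the classifier head is updated during training. As a back-of-the-envelope check, a linear layer mapping `embed_dim` features to `num_classes` outputs has `embed_dim * num_classes` weights plus `num_classes` biases. The sketch assumes `embed_dim = 384`, the value expected for ViT-S/14; the helper name `linear_param_count` is hypothetical.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Weights plus optional biases of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)

head_params = linear_param_count(384, 10)
print(head_params)
```

This count is tiny compared with the backbone, which is why training only the head is cheap.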

### <font color="green">Create model</font>

In [None]:
torch.manual_seed(1)

model = CustomModel(
    dinov2_model=dinov2_model, 
    num_hidden_nodes = num_hidden_nodes, 
    num_classes=num_classes
)

In [None]:
model = model.to(device)

In [None]:
print('\t Model information: \n')
print(model)

In [None]:
def print_trainable_parameters_per_layer(model):
    n = 20
    m = 10
    p = n+m+2
    print(f"{'-'*p}")
    print(f"{'Modules':<{n}}  {'Parameters':{m}}")
    print(f"{'-'*p}")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name:<{n}}  {param.numel():{m}}")

In [None]:
print_trainable_parameters_per_layer(model)

In [None]:
for param in model.image_classifier.parameters():
    print(param)

### <font color="green">Define the loss function and optimizer</font>

In [None]:
loss_function = nn.CrossEntropyLoss()

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
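
`nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) and applies log-softmax internally. The pure-Python sketch below reproduces that computation for one sample; the logits are made-up values and `cross_entropy` is an illustrative helper, not PyTorch's implementation.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax of the target class, computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

logits = [2.0, 0.5, -1.0]   # hypothetical scores for 3 classes
loss_correct = cross_entropy(logits, target=0)   # target is the top-scoring class
loss_wrong = cross_entropy(logits, target=2)     # target is the lowest-scoring class
print(f"{loss_correct:.3f} {loss_wrong:.3f}")
```

The loss is small when the target class has the highest score and grows as the score of the target class falls behind the others.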

### <font color="green"> Define DataLoader</font>

- We pass the dataset to our dataloader, and our `batch_size` hyperparameter as initialization arguments.
- This creates an iterable data loader, so we can easily iterate over each batch using a loop.
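The number of batches per epoch follows from the dataset size and `batch_size`. The sketch below does that arithmetic for the subset sizes used in this notebook, assuming the loader keeps the last, smaller batch (PyTorch's default, `drop_last=False`); the helper name `batches_per_epoch` is hypothetical.

```python
import math

def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a loader yields in one pass over the data."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

n_train_batches = batches_per_epoch(12800, 64)
n_test_batches = batches_per_epoch(2000, 64)
last_test_batch = 2000 - 64 * (n_test_batches - 1)
print(n_train_batches, n_test_batches, last_test_batch)
```

The training subset divides evenly into batches, while the last test and validation batches are smaller than `batch_size`.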

In [None]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

__Let us check some examples (by using `test_loader`):__

In [None]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [None]:
example_data.shape

In [None]:
example_targets.shape

In [None]:
def plot_sample_images(feature_data, label_data, ref_title="Ground Truth"):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.imshow(feature_data[i][0], cmap='gray', interpolation='none')
        plt.xticks([])
        plt.yticks([])
        plt.title(f"{ref_title}: {label_data[i]}")
        plt.tight_layout()

In [None]:
plot_sample_images(example_data, example_targets)

### <font color="green">Define functions to train per batch and per epoch </font>

In [None]:
def train_model_per_batch(data, target, mymodel, 
                          myloss_function, myoptimizer) -> float:
    #data, target = data.to(device), target.to(device)

    # Zero the gradients
    myoptimizer.zero_grad()

    # Perform forward pass
    feature = mymodel(data)

    # Compute loss
    loss = myloss_function(feature, target)

    # Perform backward pass
    loss.backward()

    # Perform optimization
    myoptimizer.step()
    
    return loss.item()

In [None]:
def train_model_per_epoch(epoch_idx, mymodel, myloss_function, myoptimizer, 
                          dataloader_train, train_losses, train_counter):
    # Put model in training mode
    mymodel.train()
    n_samples = len(dataloader_train.dataset)
    n_batches = len(dataloader_train)
    batch_size = dataloader_train.batch_size
    print(f"Inside train_model_per_epoch - Number of items: {n_samples}")
    for batch_idx, (data, target) in enumerate(dataloader_train):
        data, target = data.to(device), target.to(device)
        # Use the loss function and optimizer passed in, not the globals
        loss_val = train_model_per_batch(data, target, mymodel, 
                                         myloss_function, myoptimizer)
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch_idx} [{batch_idx*batch_size}/{n_samples}' 
                  f'({100.*batch_idx/n_batches:.0f}%)]\tLoss: {loss_val:.6f}')
            train_losses.append(loss_val)
            train_counter.append((batch_idx*batch_size) + ((epoch_idx-1)*n_samples))

### <font color="green">Define function to evaluate the model accuracy</font>

In [None]:
def compute_accuracy(model, dataloader, test_losses):
    """
    Compute the percentage of correct classifications and
    append the average loss to `test_losses`.
    """

    model = model.eval()

    n_items = len(dataloader.dataset)
    print(f"Inside compute_accuracy - Number of items: {n_items}")
    correct = 0
    test_loss = 0.0

    with torch.no_grad(): 
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)
            logits = model(data)
            # sum up batch loss
            test_loss += loss_function(logits, target).item()
            # get the index of the max log-probability
            _, pred = torch.max(logits.data, 1)
            correct += (pred == target).sum().item()

    test_loss /= len(dataloader)
    test_losses.append(test_loss)    
    perc = 100. * correct / n_items
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{n_items} ({perc:.0f}%)\n')

### <font color="green">Train and evaluate the model</font>

In [None]:
train_losses = list()
train_counter = list()
test_losses = list()
test_counter = [i*len(train_loader.dataset) for i in range(max_epochs+1)]

In [None]:
%%time

compute_accuracy(model, test_loader, test_losses)
for epoch_idx in range(1, max_epochs+1):
    train_model_per_epoch(epoch_idx, model, loss_function, optimizer,
                         train_loader, train_losses, train_counter)
    compute_accuracy(model, test_loader, test_losses)  

__Plot the losses__

In [None]:
def plot_losses(train_counter, train_losses, test_counter, test_losses):
    fig = plt.figure()
    plt.plot(train_counter, train_losses, color='blue')
    plt.scatter(test_counter, test_losses, color='red')
    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
    plt.xlabel('Number of training examples seen')
    plt.ylabel('Loss')

In [None]:
plot_losses(train_counter, train_losses, test_counter, test_losses)

# <font color="red">Test Case 2: PyTorch Lightning and a multi-layer NN model</font>

- In the previous example, we combined DINOv2 with a simple linear model.
- Here we use a multi-layer sequential model.
- We also use PyTorch Lightning, which is designed to automate and simplify the training and deployment of deep learning models.
   - It eliminates boilerplate code for training loops and complex setups, which is cumbersome for many developers, and allows users to focus on the core model and experiment logic.
   - It automates the training loop: it abstracts the code related to epoch and batch iteration, `optimizer.step()`, `loss.backward()`, `optimizer.zero_grad()`, and setting the model to `eval()` or `train()` mode.
   - It simplifies complex setups like multi-GPU and distributed training (e.g., DDP) with minimal code changes, making it easy to scale training from a single device to multiple GPUs or TPUs.
     

__Custom classifier head__

- This classifier is created for illustration only.
- We create a sequential network consisting of:
   - A fully-connected (Linear) layer with `num_hidden_nodes` nodes, followed by the `Tanh` activation function.
   - A Dropout layer with a `20%` dropout rate to prevent overfitting.
   - A second Linear layer, with `num_hidden_nodes` nodes, followed by the `Sigmoid` activation function.
   - Another Dropout layer, that removes `20%` of the nodes.
   - A final Linear layer, with `num_classes` nodes (matching the number of classes in the dataset), that outputs raw scores (logits). No Softmax layer is added because `nn.CrossEntropyLoss` applies log-softmax internally; class probabilities can be obtained with `softmax` at prediction time if needed.

In [None]:
class ImageClassifierNetwork(nn.Module):
    def __init__(self, input_size, num_hidden_nodes, num_classes):
        super(ImageClassifierNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.net = nn.Sequential(
            nn.Linear(input_size, num_hidden_nodes),
            #nn.ReLU(),
            nn.Tanh(),
            nn.Dropout(.2),
            
            nn.Linear(num_hidden_nodes, num_hidden_nodes),
            #nn.ReLU(),
            nn.Sigmoid(),
            nn.Dropout(.2),
            
            # Output raw logits: nn.CrossEntropyLoss applies log-softmax itself
            nn.Linear(num_hidden_nodes, num_classes)
        )

    def forward(self, x):
        x = self.flatten(x)
        output = self.net(x)
        return output

__Lightning model with DINOv2 and custom head__

When using `LightningModule`, the PyTorch code isn't abstracted; it’s organized into six sections:

- Initialization (`__init__` and `setup()` methods).
- Train loop (`training_step()` method).
- Validation loop (`validation_step()` method).
- Test loop (`test_step()` method).
- Prediction loop (`predict_step()` method).
- Optimizers and LR schedulers (`configure_optimizers()`).

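Each of the above methods is defined inside the class. Only `training_step()` and `configure_optimizers()` are required for training; the class below omits `setup()` and `predict_step()`.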

In [None]:
class MyLightningModule(pl.LightningModule):
    def __init__(self, dinov2_model, num_hidden_nodes, num_classes, learning_rate):
        super().__init__()
        self.dinov2_model = dinov2_model
        try:
            self.embed_dim = dinov2_model.embed_dim
        except AttributeError:
            self.embed_dim = dinov2_model.config.hidden_size
        print(f"embed_dim = {self.embed_dim}")
        self.image_classifier = ImageClassifierNetwork(self.embed_dim, num_hidden_nodes, 
                                                       num_classes)
        
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        with torch.no_grad():
            features = self.dinov2_model(x)
        
        logits = self.image_classifier(features)
        return logits

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)     
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(1) == labels).float().mean()
        self.log("train_acc", acc)
        # Log the loss at each training step and epoch, create a progress bar
        self.log("train_loss", loss, 
                 on_step=True, on_epoch=True, prog_bar=True, logger=True)
        #self.log('train_loss', loss)
        
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        #preds = torch.argmax(outputs, dim=1)
        acc = (outputs.argmax(1) == labels).float().mean()
        #acc = (preds == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(1) == labels).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.image_classifier.parameters(), lr=self.learning_rate)

__Create the model__

In [None]:
new_model = MyLightningModule(
    dinov2_model=dinov2_model, 
    num_hidden_nodes=num_hidden_nodes,
    num_classes=num_classes,
    learning_rate=learning_rate
)

__Set up the `Trainer`__

In [None]:
which_device = "gpu" if torch.cuda.is_available() else "cpu"

In [None]:
trainer = pl.Trainer(
    max_epochs=max_epochs, 
    accelerator=which_device, 
    devices="auto"
)

In [None]:
new_model = new_model.to(device)

__Train and validate the model__

In [None]:
%%time

trainer.fit(new_model, train_loader, val_loader)

__Test the model__

In [None]:
%%time

trainer.test(new_model, test_loader)

- We write here the `predict` function to determine the prediction.
- This is redundant and is meant for verification only.

In [None]:
def predict(mymodel, dataloader, device):
    mymodel = mymodel.to(device)
    mymodel.eval()
    all_preds = list()
    all_actuals = list()
    with torch.no_grad():
        for idx, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)
            outputs = mymodel(data)
            preds = torch.argmax(outputs, dim=1)
            # Move predictions back to CPU and convert to numpy
            all_preds.extend(preds.cpu().numpy())
            all_actuals.extend(target.cpu().numpy())  
    return all_actuals, all_preds

In [None]:
%%time

# Generate predictions
actuals, predictions = predict(new_model, test_loader, device)

# Create submission file
df = pd.DataFrame(
    {'ImageId': range(1, len(predictions) + 1), 
     'Actuals': actuals, 
     'Predictions': predictions}
)
df.to_csv('test_results.csv', index=False)

In [None]:
test_accuracy = len(df[df.Actuals == df.Predictions])/len(df)
print(test_accuracy)
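
The accuracy above is a single number. To see which digits are confused with which, the same table can be broken down per class. The sketch below uses a small made-up table (`df_demo`, a hypothetical stand-in for `df`) with the same `Actuals` and `Predictions` columns; the values are illustrative only, not results from the model.

```python
import pandas as pd

# Hypothetical stand-in for the results table built above (values are made up).
df_demo = pd.DataFrame({
    "Actuals":     [0, 1, 1, 2, 2, 2],
    "Predictions": [0, 1, 7, 2, 2, 3],
})

# Overall accuracy: fraction of rows where the prediction matches the label.
overall_accuracy = (df_demo.Actuals == df_demo.Predictions).mean()

# Per-digit accuracy: average the "correct" flag within each true class.
per_class_accuracy = (
    df_demo.assign(correct=df_demo.Actuals == df_demo.Predictions)
           .groupby("Actuals")["correct"]
           .mean()
)

# Confusion matrix: rows are true digits, columns are predicted digits.
confusion = pd.crosstab(df_demo.Actuals, df_demo.Predictions)

print(overall_accuracy)
print(per_class_accuracy)
print(confusion)
```

Applied to the real `df`, the off-diagonal entries of the confusion matrix show which digit pairs the classifier mixes up.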

# <font color="red">Test Case 3: Use DINOv2 embeddings as inputs of a NN</font>

- In the previous examples, the first `layer` of each model was the DINOv2 model itself.
- DINOv2 is used only in evaluation mode: its sole role is to compute the embeddings.
- The embeddings of every image are therefore recomputed at each epoch, even though the frozen model always produces the same embedding for a given input. This makes training, validation and testing time-consuming.
- Here we compute the DINOv2 embeddings once and use them as inputs to any neural network of our choice.


__Step 1__: Generate the embeddings

- Load DINOv2: Load a pre-trained DINOv2 model.
- Obtain the data: Load the MNIST dataset.
- Preprocess: Apply the necessary transformations to the MNIST images, such as resizing and normalization, to match the input requirements of the DINOv2 model.
- Extract embeddings: Feed the preprocessed MNIST images into the DINOv2 model. The output of the model, specifically the class token embedding or a global average-pooled representation of the patch embeddings, is the feature vector (embedding) for each MNIST image.
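
The two pooling choices named in the last bullet can be illustrated on a plain tensor. The sketch below does not call DINOv2: `tokens` is a random, hypothetical stand-in for the token sequence of a Vision Transformer, and it assumes the class token sits at position 0, followed by the patch tokens. The sizes (256 patches, 384 features) are chosen only to resemble a small ViT with 14x14 patches on a 224x224 image.

```python
import torch

# Hypothetical ViT output: 4 images, 1 class token + 256 patch tokens, 384 features.
# Assumption: the class token is stored first, followed by the patch tokens.
num_images, num_patches, feature_dim = 4, 256, 384
tokens = torch.randn(num_images, 1 + num_patches, feature_dim)

# Option A: use the class token as the image embedding.
cls_embedding = tokens[:, 0, :]             # shape (4, 384)

# Option B: average the patch tokens (global average pooling).
patch_tokens = tokens[:, 1:, :]             # shape (4, 256, 384)
mean_pooled = patch_tokens.mean(dim=1)      # shape (4, 384)

print(cls_embedding.shape, mean_pooled.shape)
```

Either option yields one fixed-size vector per image, which is what the classifier in Step 2 expects as input.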

__Step 2__: Train the PyTorch model

- Use the DINOv2 embeddings for the MNIST images as input to train a PyTorch classification model.
- The model can be: 
   - A Linear Classifier: A simple linear layer can be sufficient to classify the embeddings.
   - A Neural Network: Build a feedforward neural network on top of the embeddings for potentially better classification performance. 
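
As a rough illustration of the two options, the sketch below defines two small heads that map an embedding to class scores. `DemoLinearHead` and `DemoMLPHead` are hypothetical names introduced here, not the classes used elsewhere in this notebook; they only mimic a constructor taking `(embed_dim, num_hidden_nodes, num_classes)`, and the sizes are assumed values.

```python
import torch
from torch import nn


class DemoLinearHead(nn.Module):
    """Hypothetical linear classifier: one linear layer on top of the embedding."""

    def __init__(self, embed_dim, num_hidden_nodes, num_classes):
        super().__init__()
        # num_hidden_nodes is unused; it is accepted only to keep both signatures identical.
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


class DemoMLPHead(nn.Module):
    """Hypothetical feedforward network with a single hidden layer."""

    def __init__(self, embed_dim, num_hidden_nodes, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, num_hidden_nodes),
            nn.ReLU(),
            nn.Linear(num_hidden_nodes, num_classes),
        )

    def forward(self, x):
        return self.net(x)


# Assumed sizes: 384-dimensional embeddings, 128 hidden nodes, 10 digit classes.
fake_embeddings = torch.randn(8, 384)
linear_logits = DemoLinearHead(384, 128, 10)(fake_embeddings)
mlp_logits = DemoMLPHead(384, 128, 10)(fake_embeddings)
print(linear_logits.shape, mlp_logits.shape)
```

Both heads return one score per class for each embedding, so they can be swapped without changing the training code.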

## <font color="blue">Determine the embeddings</font>

In [None]:
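def create_embeddings_labels(dataset, dino_model, device):
    """Pass every image through the frozen model and collect embeddings and labels."""
    embeddings_obj = list()
    labels_obj = list()
    # Make sure layers such as dropout behave deterministically.
    dino_model.eval()
    with torch.no_grad():
        for i in range(len(dataset)):
            # Index the dataset once: every access re-applies the transforms.
            image, label = dataset[i]
            data = image.unsqueeze(0).to(device)

            output = dino_model(data)
            embeddings_obj.append(output.cpu())
            labels_obj.append(label)
    embeddings_obj = torch.cat(embeddings_obj, dim=0)
    labels_obj = torch.tensor(labels_obj)

    return embeddings_obj, labels_obj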
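
The function above sends one image at a time to the model. A possible variant, sketched below under stated assumptions, feeds the model whole batches through a `DataLoader`, which usually makes better use of a GPU. `create_embeddings_labels_batched` is a hypothetical name, and `toy_encoder` and `toy_dataset` are small stand-ins used so that the example runs without downloading DINOv2 or MNIST.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def create_embeddings_labels_batched(dataset, encoder, device, batch_size=256):
    """Hypothetical batched variant: embed a dataset batch by batch."""
    encoder = encoder.to(device).eval()
    # shuffle=False keeps embeddings in the same order as the dataset.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    all_embeddings, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            all_embeddings.append(encoder(images.to(device)).cpu())
            all_labels.append(labels)
    return torch.cat(all_embeddings, dim=0), torch.cat(all_labels, dim=0)


# Toy stand-ins: 10 random "images" and a tiny encoder producing 16 features.
toy_images = torch.randn(10, 3, 8, 8)
toy_labels = torch.arange(10)
toy_dataset = TensorDataset(toy_images, toy_labels)
toy_encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))

toy_embeddings, toy_out_labels = create_embeddings_labels_batched(
    toy_dataset, toy_encoder, torch.device("cpu"), batch_size=4
)
print(toy_embeddings.shape, toy_out_labels.shape)
```

The batch size is a trade-off between speed and memory; 256 is only a placeholder value.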

In [None]:
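try:
    # Attribute exposed by some DINOv2 model objects
    embed_dim = dinov2_model.embed_dim
except AttributeError:
    # Fall back to the size stored in the model configuration
    embed_dim = dinov2_model.config.hidden_size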

In [None]:
%%time

embeddings_train, labels_train = create_embeddings_labels(train_dataset, dinov2_model, device)

In [None]:
%%time

embeddings_test, labels_test = create_embeddings_labels(test_dataset, dinov2_model, device)

In [None]:
%%time

embeddings_val, labels_val = create_embeddings_labels(val_dataset, dinov2_model, device)
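
Since computing the embeddings is the expensive step, it may be worth caching them on disk so that a later session can skip it. The sketch below shows one way to do this with `torch.save` and `torch.load`; the tensors, the file name `embeddings_demo.pt` and the temporary directory are hypothetical placeholders, not part of the notebook.

```python
import os
import tempfile

import torch

# Hypothetical stand-ins for the tensors returned by create_embeddings_labels.
emb_demo = torch.randn(5, 384)
lab_demo = torch.tensor([0, 1, 2, 3, 4])

with tempfile.TemporaryDirectory() as tmp_dir:
    cache_path = os.path.join(tmp_dir, "embeddings_demo.pt")

    # Store embeddings and labels together in one file.
    torch.save({"embeddings": emb_demo, "labels": lab_demo}, cache_path)

    # Later: reload them instead of running the model again.
    cached = torch.load(cache_path)

print(cached["embeddings"].shape, cached["labels"].shape)
```

In practice the file would live in a persistent location, with one file per split (training, validation, test).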

## <font color="blue">Create the dataloaders</font>

In [None]:
class MyDataset():
    '''
    Custom 'Dataset' object for the precomputed embeddings and their labels.
    Must implement these functions: __init__, __len__, and __getitem__.
    '''
    def __init__(self, X, y):
        self.features = X
        self.labels = y

    def __getitem__(self, index):
        x = self.features[index]
        y = self.labels[index]
        return x, y

    def __len__(self):
        return len(self.features)


In [None]:
def instantiate_data(Xdata, ydata, batch_size=64, shuffle=False):
    dataset = MyDataset(Xdata, ydata)
    dataloader = DataLoader(dataset=dataset, 
                            batch_size=batch_size, 
                            shuffle=shuffle)
    return dataloader

In [None]:
train_dataloader = instantiate_data(embeddings_train, labels_train, 
                                    batch_size=batch_size, shuffle=True)

test_dataloader = instantiate_data(embeddings_test, labels_test, 
                                   batch_size=batch_size, shuffle=False)

val_dataloader = instantiate_data(embeddings_val, labels_val, 
                                  batch_size=batch_size, shuffle=False)
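
When the features and labels are already tensors, as they are here, PyTorch's built-in `TensorDataset` offers an alternative to writing a custom dataset class. The sketch below is a minimal illustration with random data; `X_demo`, `y_demo` and the sizes are assumptions made for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical precomputed embeddings (100 vectors of size 384) and digit labels.
X_demo = torch.randn(100, 384)
y_demo = torch.randint(0, 10, (100,))

# TensorDataset pairs the i-th row of each tensor, like MyDataset does.
demo_dataset = TensorDataset(X_demo, y_demo)
demo_loader = DataLoader(demo_dataset, batch_size=32, shuffle=True)

# Each batch is a (features, labels) pair, ready for the Lightning module.
x_batch, y_batch = next(iter(demo_loader))
print(len(demo_loader), x_batch.shape, y_batch.shape)
```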

## <font color="blue">Model creation</font>

In [None]:
class GenericLightningModule(pl.LightningModule):
    def __init__(self, nn_model, embed_dim, num_hidden_nodes, num_classes, learning_rate):
        super().__init__()
        self.image_classifier = nn_model(embed_dim, num_hidden_nodes, num_classes)
        
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        logits = self.image_classifier(x)
        return logits

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)     
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(1) == labels).float().mean()
        self.log("train_acc", acc)
        self.log("train_loss", loss, 
                 on_step=True, on_epoch=True, prog_bar=True, logger=True) 
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(1) == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = (outputs.argmax(1) == labels).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.image_classifier.parameters(), lr=self.learning_rate)
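
For comparison, the core of what the `Trainer` automates for a module like this one can be written as a plain PyTorch loop. The sketch below is only an approximation on synthetic data: it leaves out device placement, logging, validation and checkpointing, which Lightning also handles. The data, the sizes and the `classifier` head are hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for precomputed embeddings: 256 vectors of size 32,
# with two classes decided by the sign of the first feature.
features = torch.randn(256, 32)
targets = (features[:, 0] > 0).long()
loader = DataLoader(TensorDataset(features, targets), batch_size=64, shuffle=True)

classifier = nn.Linear(32, 2)                     # hypothetical classification head
criterion = nn.CrossEntropyLoss()                 # same loss as in the module above
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-2)

# Loss on the whole dataset before training, for reference.
with torch.no_grad():
    initial_loss = criterion(classifier(features), targets).item()

for epoch in range(20):
    classifier.train()
    for x, y in loader:                           # plays the role of training_step
        optimizer.zero_grad()
        loss = criterion(classifier(x), y)
        loss.backward()
        optimizer.step()

# Evaluate once training is finished.
classifier.eval()
with torch.no_grad():
    logits = classifier(features)
    final_loss = criterion(logits, targets).item()
    final_acc = (logits.argmax(1) == targets).float().mean().item()

print(initial_loss, final_loss, final_acc)
```

With Lightning, the loop body lives in `training_step` and the optimizer is supplied by `configure_optimizers`; the `Trainer` runs the iteration.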

### <font color="green">Option 1: Linear model</font>

In [None]:
linear_model = GenericLightningModule(
    nn_model=LinearClassifierHead, 
    embed_dim=embed_dim,
    num_hidden_nodes=num_hidden_nodes,
    num_classes=num_classes,
    learning_rate=learning_rate
)

In [None]:
which_device = "gpu" if torch.cuda.is_available() else "cpu"

In [None]:
linear_trainer = pl.Trainer(
    max_epochs=max_epochs, 
    accelerator=which_device, 
    devices="auto"
)

In [None]:
linear_model = linear_model.to(device)

In [None]:
%%time

linear_trainer.fit(linear_model, train_dataloader, val_dataloader)

In [None]:
%%time

linear_trainer.test(linear_model, test_dataloader)

### <font color="green">Option 2: Multi-layer model</font>

In [None]:
nn_model = GenericLightningModule(
    nn_model=ImageClassifierNetwork, 
    embed_dim=embed_dim,
    num_hidden_nodes=num_hidden_nodes,
    num_classes=num_classes,
    learning_rate=learning_rate
)

In [None]:
nn_trainer = pl.Trainer(
    max_epochs=max_epochs, 
    accelerator=which_device, 
    devices="auto"
)

In [None]:
nn_model = nn_model.to(device)

In [None]:
%%time

nn_trainer.fit(nn_model, train_dataloader, val_dataloader)

In [None]:
%%time

nn_trainer.test(nn_model, test_dataloader)

# <font color="red"> References</font>

- [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/pdf/2304.07193) by Maxime Oquab et al.
- [DINOv2 by Meta: A Self-Supervised foundational vision model](https://learnopencv.com/dinov2-self-supervised-vision-transformer/) by Bhomik Sharma, April 2025.
- [01.Meta-DinoV2-Getting Started](https://www.kaggle.com/code/shravankumar147/01-meta-dinov2-getting-started)
- [DINOv2](https://huggingface.co/docs/transformers/en/model_doc/dinov2) from huggingface.co
- [Building the DINO model from Scratch with PyTorch: Self-Supervised Vision Transformer](https://medium.com/thedeephub/self-supervised-vision-transformer-implementing-the-dino-model-from-scratch-with-pytorch-62203911bcc9) by Shubh Mishra
- [Deploying DINOv2 to A Rest API Endpoint for Image Classification | Modelbit](https://colab.research.google.com/github/write-with-neurl/modelbit-09/blob/main/notebook/Deploying_DINOv2_for_Image_Classification_with_Modelbit.ipynb#scrollTo=q06RxQlCzQnG)
- [DINOv2: Self-supervised Learning Model Explained](https://encord.com/blog/dinov2-self-supervised-learning-explained/) Encord Blog, November 2024.
- [How to Classify Images with DINOv2](https://blog.roboflow.com/how-to-classify-images-with-dinov2/) by James Gallagher, May 2023.
- [DinoV2 Fine-Tuning Tutorial: How to Maximize Accuracy for Computer Vision Tasks](https://kili-technology.com/data-labeling/computer-vision/dinov2-fine-tuning-tutorial-maximizing-accuracy-for-computer-vision-tasks) by Asmaa Mirkhan
- [PyTorch Lightning: A Comprehensive Hands-On Tutorial](https://www.datacamp.com/tutorial/pytorch-lightning-tutorial) by Bex Tuychiev