# DINOv2 Model Inference

## Introduction
This notebook serves as a baseline entry for the competition and as a guide to running inference with the provided trained models on quadrat test images.
The DINOv2-based plant identification model provided in the competition is applied to the entire image of each quadrat in the test set, after the image has been resized to 518x518, the input dimensions used to train the model.
This notebook can be used as a starting point for further development. Feel free to leave comments about errors or possible improvements.

### Import libraries

In [4]:
# Reading & writing CSV
import csv
# Numerical Computation
import numpy as np 
# Data Manipulation
import pandas as pd
# PyTorch Image Models: pre-trained computer vision models
import timm 
# PyTorch
import torch
# Manipulating Images
from PIL import Image
# Utilities for handling data loading
from torch.utils.data import DataLoader, Dataset
# A module for emitting log messages from the program (A professional way to track events, errors, and progress)
import logging
# Measure time
import time
# Interact with file system
import os

logging.basicConfig(
    level=logging.INFO,
    handlers=[logging.StreamHandler()])
_logger = logging.getLogger('inference')

In [5]:
# AverageMeter Class:
# This is a utility class used to keep track of a running average for a value, like loss or accuracy, during training or inference.
# It's a convenient way to monitor performance over time.

class AverageMeter:
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        

# TestDataset Class
# This class is a custom PyTorch Dataset specifically for the test images.
# It tells PyTorch how to find, load, and preprocess each image in the test set.
class TestDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.image_paths = [os.path.join(image_folder, f) for f in os.listdir(image_folder)]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Force 3-channel RGB so grayscale or RGBA files do not break the transform
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, image_path 
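
To illustrate how `AverageMeter` behaves, here is a minimal sketch. The class is repeated so the snippet runs on its own, and the timing values are made up for illustration. The `n` argument weights a value, for example when a measurement is a mean over `n` samples.

```python
class AverageMeter:
    """Same running-average helper as above, repeated so this snippet is self-contained."""
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Unweighted use: one update per batch (hypothetical timings in seconds)
meter = AverageMeter()
for seconds in (0.5, 0.25, 0.75):
    meter.update(seconds)
print(meter.val, meter.avg)  # most recent value, mean over all updates

# Weighted use: the first value counts twice, the second once
weighted = AverageMeter()
weighted.update(1.0, n=2)
weighted.update(4.0, n=1)
print(weighted.avg)
```

In the inference loop, `val` is the time taken by the latest batch and `avg` is the mean time per batch so far.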

Load species competition metadata

In [6]:
df_species_ids = pd.read_csv('/kaggle/input/plantclef-2025/species_ids.csv')

df_metadata = pd.read_csv('/kaggle/input/plantclef-2025/PlantCLEF2024_single_plant_training_metadata.csv', sep=';', dtype={'partner': str})
id_to_species = df_metadata[['species_id', 'species']].drop_duplicates().set_index('species_id')

df_metadata.head()

Unnamed: 0,image_name,organ,species_id,obs_id,license,partner,author,altitude,latitude,longitude,gbif_species_id,species,genus,family,dataset,publisher,references,url,learn_tag,image_backup_url
0,59feabe1c98f06e7f819f73c8246bd8f1a89556b.jpg,leaf,1396710,1008726402,cc-by-sa,,Gulyás Bálint,205.9261,47.59216,19.362895,5284517.0,Taxus baccata L.,Taxus,Taxaceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/59feabe1c98f06...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...
1,dc273995a89827437d447f29a52ccac86f65476e.jpg,leaf,1396710,1008724195,cc-by-sa,,vadim sigaud,323.752,47.906703,7.201746,5284517.0,Taxus baccata L.,Taxus,Taxaceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/dc273995a89827...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...
2,416235e7023a4bd1513edf036b6097efc693a304.jpg,leaf,1396710,1008721908,cc-by-sa,,fil escande,101.316,48.826774,2.352774,5284517.0,Taxus baccata L.,Taxus,Taxaceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/416235e7023a4b...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...
3,cbd18fade82c46a5c725f1f3d982174895158afc.jpg,leaf,1396710,1008699177,cc-by-sa,,Desiree Verver,5.107,52.190427,6.009677,5284517.0,Taxus baccata L.,Taxus,Taxaceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/cbd18fade82c46...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...
4,f82c8c6d570287ebed8407cefcfcb2a51eaaf56e.jpg,leaf,1396710,1008683100,cc-by-sa,,branebrane,165.339,45.794739,15.965862,5284517.0,Taxus baccata L.,Taxus,Taxaceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/f82c8c6d570287...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...


Import provided model with timm library:

In [7]:
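# Use the GPU when available; fall back to CPU so the notebook still runs without one
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')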
model = timm.create_model('vit_base_patch14_reg4_dinov2.lvd142m',
                          pretrained=False,
                          num_classes=len(df_species_ids),
                          checkpoint_path='/kaggle/input/dinov2_patch14_reg4_onlyclassifier_then_all/pytorch/default/3/model_best.pth.tar')
model = model.to(device)
model = model.eval()

Load model configuration settings

In [8]:
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

Set hyperparameters:
* batch_size: number of test images per batch
* top_k: maximum number of top-scoring classes kept for each image
* min_score: keep only classes with a score of at least min_score

In [9]:
batch_size = 32
top_k = 15
min_score = 0.01 
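
The interaction between `top_k` and `min_score` can be sketched in plain Python. The helper `select_species` and the probability values are hypothetical and only illustrate the selection rule. The notebook itself applies the same rule with `torch.topk` followed by a threshold mask.

```python
def select_species(probabilities, top_k, min_score):
    """Return (class_index, probability) pairs for the top_k most probable
    classes, keeping only those with probability >= min_score."""
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    return [(idx, p) for idx, p in ranked[:top_k] if p >= min_score]


toy_probabilities = [0.60, 0.005, 0.25, 0.10, 0.045]  # made-up values that sum to 1

# top_k is the binding limit here: only the two best classes survive
limited_by_top_k = select_species(toy_probabilities, top_k=2, min_score=0.01)

# min_score is the binding limit here: classes below 0.05 are dropped
limited_by_score = select_species(toy_probabilities, top_k=5, min_score=0.05)

print([idx for idx, _ in limited_by_top_k])
print([idx for idx, _ in limited_by_score])
```

A larger `top_k` with a low `min_score` returns more species per quadrat. Raising `min_score` trims low-confidence predictions regardless of `top_k`.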

Inference on Test Data

In [10]:

class_map = df_species_ids['species_id'].to_dict()
# Reuse the evaluation transform built from the model's data configuration above
dataset = TestDataset(image_folder='/kaggle/input/plantclef-2025/PlantCLEF2025_test_images/PlantCLEF2025_test_images/',
                      transform=transforms)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

image_predictions = {}

# Initialize batch time tracking
batch_time = AverageMeter()
end = time.time()

with torch.no_grad():
    for batch_idx, (images, image_paths) in enumerate(dataloader):
        images = images.to(device)
        outputs = model(images)  # Perform inference on the batch
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

        # Get the top-k values and their indices
        values, indices = torch.topk(probabilities, top_k, dim=1)
        
        # Filter based on the probability threshold
        values_np = values.cpu().numpy()
        indices_np = indices.cpu().numpy()
        
        for i in range(values_np.shape[0]):
            # Filtered class indices above the threshold
            filtered_indices = indices_np[i][values_np[i] >= min_score]
            
            # Convert class indices to class labels
            filtered_labels = [class_map.get(idx, 'Unknown') for idx in filtered_indices]

            # Get the image name without the extension
            image_name = os.path.splitext(os.path.basename(image_paths[i]))[0]

            image_predictions[image_name] = filtered_labels
        
        batch_time.update(time.time() - end)
        end = time.time()

        # Log info at specified frequency
        if batch_idx % 10 == 0:  # You can set your log frequency here
            _logger.info(f'Predict: [{batch_idx}/{len(dataloader)}] '
                         f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})')  
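
The loop above relies on `class_map` to turn the classifier's output indices into species IDs. This assumes that the row order of `species_ids.csv` matches the order of the model's output classes. The following minimal sketch shows that lookup with a plain dictionary. The species IDs are hypothetical placeholders, not values from the competition file.

```python
# Hypothetical species IDs standing in for the rows of species_ids.csv
species_ids = [1000001, 1000002, 1000003]

# Equivalent in spirit to df_species_ids['species_id'].to_dict():
# row position (= class index) -> species_id
class_map = dict(enumerate(species_ids))

# Index 7 is deliberately out of range to show the fallback
predicted_indices = [2, 0, 7]
labels = [class_map.get(idx, 'Unknown') for idx in predicted_indices]
print(labels)
```

An `'Unknown'` entry in a real run would suggest that the number of model outputs and the number of rows in the species file disagree.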

Submit predictions

In [11]:
df_run = pd.DataFrame(list(image_predictions.items()), columns=['quadrat_id', 'species_ids'])
df_run['species_ids'] = df_run['species_ids'].apply(str)
df_run.to_csv("submission.csv", sep=',', index=False, quoting=csv.QUOTE_ALL)
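
To make the submission layout concrete, here is a small sketch that writes the same two columns with the standard `csv` module instead of pandas, so it runs without dependencies. The quadrat names and species IDs are made up. The pandas call above is expected to produce the same quoting, but that equivalence is an assumption of this sketch.

```python
import csv
import io

# Hypothetical predictions: quadrat name -> list of predicted species IDs
image_predictions = {
    'quadrat_a': [1000001, 1000003],
    'quadrat_b': [],  # no class reached min_score
}

buffer = io.StringIO()
writer = csv.writer(buffer, quoting=csv.QUOTE_ALL, lineterminator='\n')
writer.writerow(['quadrat_id', 'species_ids'])
for quadrat_id, species_ids in image_predictions.items():
    # The list is stored as its string form, as in df_run['species_ids'].apply(str)
    writer.writerow([quadrat_id, str(species_ids)])

submission_text = buffer.getvalue()
print(submission_text)
```

Every field is wrapped in double quotes, so the commas inside the species list do not split the row into extra columns.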