# Embedding workflow using DINOv2

This notebook focuses on the **Feature Extraction** pipeline. 

We use the fine-tuned model **ViTD2PC24All** ([DINOv2](https://dinov2.metademolab.com/)) to extract high-dimensional embeddings from the single-label training images and the multi-label test images.

We'll **visualize**, **tile**, and **process** these embeddings to support patch-wise multi-label inference using PyTorch and Faiss.

![diagram](../images/pytorch-webinar-diagram.png)

In [1]:
%load_ext autoreload
%autoreload 2

## Load the Parquet file from disk and visualize the images

In [11]:
import pandas as pd
from rich import print as pprint

pd.options.display.precision = 2
pd.options.display.max_rows = 10
pd.options.display.max_columns = 25

root_dir = "/teamspace/studios/this_studio/plantclef-vision/data/plantclef2025"
dataset_dir = "/teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/competition-metadata/PlantCLEF2025_test_images/PlantCLEF2025_test_images"
hf_dataset_dir = "/teamspace/studios/this_studio/plantclef-vision/data/parquet/plantclef2025/full_test/HF_dataset"

In [None]:
dir(pd.options.display)

['chop_threshold',
 'colheader_justify',
 'date_dayfirst',
 'date_yearfirst',
 'encoding',
 'expand_frame_repr',
 'float_format',
 'html',
 'large_repr',
 'max_categories',
 'max_columns',
 'max_colwidth',
 'max_dir_items',
 'max_info_columns',
 'max_info_rows',
 'max_rows',
 'max_seq_items',
 'memory_usage',
 'min_rows',
 'multi_sparse',
 'notebook_repr_html',
 'pprint_nest_depth',
 'precision',
 'show_dimensions',
 'unicode',
 'width']

In [62]:
from plantclef.datasets.preprocessing.hf.train_val_test_subsets_to_hf import (
    Config,
    get_dict_transform,
)
# from plantclef.datasets import preprocessing

In [38]:
cfg = Config()

{'shortest_edge': 588}


In [39]:
cfg.metadata_path

'/teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/competition-metadata/PlantCLEF2024_single_plant_training_metadata.csv'

In [6]:
cfg.show()

metadata = cfg.load_metadata()
class2idx = cfg.load_class_index(mode="class2idx")

metadata = cfg.encode_target_col(metadata, class2idx=class2idx)

[Cache Found] previously preprocessed metadata cache, loading from cache file and skipping preprocessing


In [27]:
keep_cols = [
    "image_path",
    "label_idx",
    "image_name",
    "organ",
    "species_id",
    "obs_id",
    "author",
    "altitude",
    "latitude",
    "longitude",
    "species",
    "genus",
    "family",
    "learn_tag",
]

metadata = metadata[keep_cols]

In [47]:
from datasets import Dataset as HFDataset, DatasetDict as HFDatasetDict
from datasets import Image

train_df = metadata[metadata["learn_tag"] == "train"]
val_df = metadata[metadata["learn_tag"] == "val"]
test_df = metadata[metadata["learn_tag"] == "test"]

train_ds = HFDataset.from_pandas(train_df)
val_ds = HFDataset.from_pandas(val_df)
test_ds = HFDataset.from_pandas(test_df)


dataset = HFDatasetDict({"train": train_ds, "val": val_ds, "test": test_ds})
dataset = dataset.cast_column(cfg.x_col, Image())
dataset

DatasetDict({
    train: Dataset({
        features: ['image_path', 'label_idx', 'image_name', 'organ', 'species_id', 'obs_id', 'author', 'altitude', 'latitude', 'longitude', 'species', 'genus', 'family', 'learn_tag', '__index_level_0__'],
        num_rows: 1308899
    })
    val: Dataset({
        features: ['image_path', 'label_idx', 'image_name', 'organ', 'species_id', 'obs_id', 'author', 'altitude', 'latitude', 'longitude', 'species', 'genus', 'family', 'learn_tag', '__index_level_0__'],
        num_rows: 51194
    })
    test: Dataset({
        features: ['image_path', 'label_idx', 'image_name', 'organ', 'species_id', 'obs_id', 'author', 'altitude', 'latitude', 'longitude', 'species', 'genus', 'family', 'learn_tag', '__index_level_0__'],
        num_rows: 47940
    })
})

True

In [64]:
# def get_dict_transform(transform_kwargs = {}, input_columns=None) -> Callable:

#     tx = get_transforms(**transform_kwargs)
#     def func(data, *args, **kwargs):
#         if (input_columns is not None) and isinstance(input_columns, str):
#             data = data[input_columns]
#             return {input_columns: tx(data)}
#         return tx(data)

#     return func

tx = get_dict_transform(
    transform_kwargs={"image_size": {"shortest_edge": 716}}, input_columns=cfg.x_col
)
tx

<function plantclef.datasets.preprocessing.hf.train_val_test_subsets_to_hf.get_dict_transform.<locals>.func(data, *args, **kwargs)>

In [None]:
dataset["train"][0][cfg.x_col]

In [65]:
dataset = dataset.map(tx, input_columns=cfg.x_col, num_proc=4)
dataset

Map:   0%|          | 0/1308899 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [46]:
for subset, ds in dataset.items():
    print(subset, ds.features)

train {'image_path': Value(dtype='string', id=None), 'label_idx': Value(dtype='int64', id=None), 'image_name': Value(dtype='string', id=None), 'organ': Value(dtype='string', id=None), 'species_id': Value(dtype='int64', id=None), 'obs_id': Value(dtype='int64', id=None), 'author': Value(dtype='string', id=None), 'altitude': Value(dtype='float64', id=None), 'latitude': Value(dtype='float64', id=None), 'longitude': Value(dtype='float64', id=None), 'species': Value(dtype='string', id=None), 'genus': Value(dtype='string', id=None), 'family': Value(dtype='string', id=None), 'learn_tag': Value(dtype='string', id=None), '__index_level_0__': Value(dtype='int64', id=None), 'image': Image(mode=None, decode=True, id=None)}
val {'image_path': Value(dtype='string', id=None), 'label_idx': Value(dtype='int64', id=None), 'image_name': Value(dtype='string', id=None), 'organ': Value(dtype='string', id=None), 'species_id': Value(dtype='int64', id=None), 'obs_id': Value(dtype='int64', id=None), 'author': Va

In [29]:
val_ds.features

{'image_path': Value(dtype='string', id=None),
 'label_idx': Value(dtype='int64', id=None),
 'image_name': Value(dtype='string', id=None),
 'organ': Value(dtype='string', id=None),
 'species_id': Value(dtype='int64', id=None),
 'obs_id': Value(dtype='int64', id=None),
 'author': Value(dtype='string', id=None),
 'altitude': Value(dtype='float64', id=None),
 'latitude': Value(dtype='float64', id=None),
 'longitude': Value(dtype='float64', id=None),
 'species': Value(dtype='string', id=None),
 'genus': Value(dtype='string', id=None),
 'family': Value(dtype='string', id=None),
 'learn_tag': Value(dtype='string', id=None),
 '__index_level_0__': Value(dtype='int64', id=None)}

In [31]:
# val_ds.take(10)["__index_level_0__"]

[479, 489, 490, 491, 492, 493, 495, 496, 497, 498]

In [None]:
# ds = HFDataset.from_dict({"image": image_paths, "file_path": image_paths})
# ds = ds.cast_column("image", Image())
# ds = ds.cast_column("file_path", Value("string"))

In [21]:
metadata.info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1408033 entries, 0 to 1408032
Data columns (total 22 columns):
 #   Column            Non-Null Count    Dtype   
---  ------            --------------    -----   
 0   image_name        1408033 non-null  string  
 1   organ             1408033 non-null  category
 2   species_id        1408033 non-null  int64   
 3   obs_id            1408033 non-null  Int64   
 4   license           1408033 non-null  category
 5   partner           115338 non-null   category
 6   author            1405895 non-null  string  
 7   altitude          705322 non-null   Float64 
 8   latitude          705425 non-null   Float64 
 9   longitude         705424 non-null   Float64 
 10  gbif_species_id   1406725 non-null  float64 
 11  species           1408033 non-null  category
 12  genus             1408033 non-null  category
 13  family            1408033 non-null  category
 14  dataset           1408033 non-null  category
 15  publisher         1357307 non-nu

In [23]:
metadata[keep_cols].info(verbose=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1408033 entries, 0 to 1408032
Data columns (total 14 columns):
 #   Column      Non-Null Count    Dtype   
---  ------      --------------    -----   
 0   image_path  1408033 non-null  string  
 1   label_idx   1408033 non-null  int64   
 2   image_name  1408033 non-null  string  
 3   organ       1408033 non-null  category
 4   species_id  1408033 non-null  int64   
 5   obs_id      1408033 non-null  Int64   
 6   author      1405895 non-null  string  
 7   altitude    705322 non-null   Float64 
 8   latitude    705425 non-null   Float64 
 9   longitude   705424 non-null   Float64 
 10  species     1408033 non-null  category
 11  genus       1408033 non-null  category
 12  family      1408033 non-null  category
 13  learn_tag   1408033 non-null  category
dtypes: Float64(3), Int64(1), category(5), int64(2), string(3)
memory usage: 113.2 MB


In [18]:
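# Count distinct species under both ID schemes.
metadata["gbif_species_id"].nunique()
metadata["species_id"].nunique()

# Count missing IDs. Only the last expression in a cell is displayed,
# so the output below is the number of missing `species_id` values.
metadata["gbif_species_id"].isna().sum()
metadata["species_id"].isna().sum()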

0

In [14]:
metadata.head(3)

Unnamed: 0,image_name,organ,species_id,obs_id,license,partner,author,altitude,latitude,longitude,gbif_species_id,species,genus,family,dataset,publisher,references,url,learn_tag,image_backup_url,image_path,label_idx
0,59feabe1c98f06e7f819f73c8246bd8f1a89556b.jpg,leaf,1396710,1008726402,cc-by-sa,,Gulyás Bálint,205.93,47.59,19.36,5280000.0,Taxus baccata L.,Taxus,Taxaceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/59feabe1c98f06...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...,/teamspace/studios/this_studio/plantclef-visio...,4826
1,dc273995a89827437d447f29a52ccac86f65476e.jpg,leaf,1396710,1008724195,cc-by-sa,,vadim sigaud,323.75,47.91,7.2,5280000.0,Taxus baccata L.,Taxus,Taxaceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/dc273995a89827...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...,/teamspace/studios/this_studio/plantclef-visio...,4826
2,416235e7023a4bd1513edf036b6097efc693a304.jpg,leaf,1396710,1008721908,cc-by-sa,,fil escande,101.32,48.83,2.35,5280000.0,Taxus baccata L.,Taxus,Taxaceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/416235e7023a4b...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...,/teamspace/studios/this_studio/plantclef-visio...,4826


In [12]:
metadata.describe(include="all")

Unnamed: 0,image_name,organ,species_id,obs_id,license,partner,author,altitude,latitude,longitude,gbif_species_id,species,genus,family,dataset,publisher,references,url,learn_tag,image_backup_url,image_path,label_idx
count,1408033,1408033,1.41e+06,1408033.0,1408033,115338,1405895,705322.0,705425.0,705424.0,1.41e+06,1408033,1408033,1408033,1408033,1357307,1357275,1408033,1408033,1408033,1408033,1.41e+06
unique,1408033,7,,,12,5,178333,,,,,7806,1446,181,2,5,1215736,1408033,3,1408033,1408033,
top,394af6a92ff308ae70cf4d62737d9f6fdb2cf96b.jpg,flower,,,cc-by-sa,tela,Tela Botanica − Liliane Roubaudi,,,,,Styphnolobium japonicum (L.) Schott,Carex,Asteraceae,plantnet,plantnet,https://identify.plantnet.org/fr/k-southwester...,https://bs.plantnet.org/image/o/394af6a92ff308...,train,https://lab.plantnet.org/LifeCLEF/PlantCLEF202...,/teamspace/studios/this_studio/plantclef-visio...,
freq,1,389251,,,1099727,111316,12294,,,,,823,21383,176707,1102483,1102483,54,1,1308899,1,1,
mean,,,1.41e+06,1619469443.27,,,,6597.68,43.29,1.11,4.63e+06,,,,,,,,,,,3.27e+03
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
min,,,1.36e+06,891083349.0,,,,-2274.0,-60.58,-178.17,2.65e+06,,,,,,,,,,,0.00e+00
25%,,,1.36e+06,1008019063.0,,,,73.0,42.35,-0.58,2.99e+06,,,,,,,,,,,1.56e+03
50%,,,1.39e+06,1014189664.0,,,,199.0,45.03,4.38,3.64e+06,,,,,,,,,,,3.08e+03
75%,,,1.40e+06,1019956407.0,,,,575.0,48.71,8.96,5.41e+06,,,,,,,,,,,4.87e+03


## Running torch_pipeline with HFPlantDataset

In [6]:
from plantclef.embed.workflow import Config
from plantclef.embed.utils import print_dir_size
import os

cfg = Config()
pprint(cfg)



In [14]:
import csv
import pandas as pd


df = pd.read_csv(cfg.test_submission_path)

df = df.assign(quadrat_id=df["quadrat_id"].apply(lambda x: os.path.splitext(x)[0]))

df.to_csv(cfg.test_submission_path, sep=",", index=False, quoting=csv.QUOTE_ALL)
df

Unnamed: 0,quadrat_id,species_ids
0,2024-CEV3-20240602,"[1654010, 1395063, 1392662, 1414387, 1743646]"
1,CBN-PdlC-A1-20130807,"[1744569, 1361917, 1356350, 1418612, 1361129]"
2,CBN-PdlC-A1-20130903,"[1744569, 1392608, 1361382, 1361068, 1361971]"
3,CBN-PdlC-A1-20140721,"[1529289, 1374758, 1402995, 1741880, 1362066]"
4,CBN-PdlC-A1-20140811,"[1361281, 1418612, 1356350, 1392608, 1722440]"
...,...,...
2100,RNNB-8-5-20240118,"[1361437, 1655199, 1357049, 1722441, 1414356]"
2101,RNNB-8-6-20240118,"[1655199, 1363434, 1359297, 1357962, 1361703]"
2102,RNNB-8-7-20240118,"[1359297, 1356521, 1363553, 1357358, 1362711]"
2103,RNNB-8-8-20240118,"[1359650, 1396330, 1743962, 1357962, 1388788]"


In [5]:
print_dir_size(cfg.test_embeddings_path)

Analyzing disk usage of directory: /teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/embeddings/full_test/test_grid_3x3_embeddings
Directory Disk Usage: 543M	/teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/embeddings/full_test/test_grid_3x3_embeddings
2025-05-08 08:42:53


In [22]:
# top_1 = []
# top_2 = []
# top_3 = []
# top_4 = []
# top_5 = []

# for i, row in df.iterrows():
#     top_1.append(row["logits"][0])
#     top_2.append(row["logits"][1])
#     top_3.append(row["logits"][2])
#     top_4.append(row["logits"][3])
#     top_5.append(row["logits"][4])

#     print(i)
#     # pprint(row)

#     if i >= 5:
#         break

# print(f"top_1: {top_1}")
# print(f"top_2: {top_2}")
# print(f"top_3: {top_3}")
# print(f"top_4: {top_4}")
# print(f"top_5: {top_5}")
# top_species_ids = [s_id for s_id, _ in [*top_1, *top_2, *top_3, *top_4, *top_5]]

  df.apply(select_top_k_unique_logits, top_k=top_k).rename("logits").reset_index()


Unnamed: 0,image_name,logits
0,2024-CEV3-20240602.jpg,"[(1654010, 0.44266772270202637), (1395063, 0.3..."
1,CBN-PdlC-A1-20130807.jpg,"[(1744569, 0.2301855832338333), (1361917, 0.22..."
2,CBN-PdlC-A1-20130903.jpg,"[(1744569, 0.16917195916175842), (1392608, 0.1..."
3,CBN-PdlC-A1-20140721.jpg,"[(1529289, 0.14910352230072021), (1374758, 0.1..."
4,CBN-PdlC-A1-20140811.jpg,"[(1361281, 0.12936192750930786), (1418612, 0.1..."
...,...,...
2100,RNNB-8-5-20240118.jpg,"[(1361437, 0.7179210782051086), (1655199, 0.52..."
2101,RNNB-8-6-20240118.jpg,"[(1655199, 0.37736761569976807), (1363434, 0.2..."
2102,RNNB-8-7-20240118.jpg,"[(1359297, 0.30361855030059814), (1356521, 0.2..."
2103,RNNB-8-8-20240118.jpg,"[(1359650, 0.3005388379096985), (1396330, 0.28..."


## Explore embeddings

### Get embeddings and logits from model.predict_step

### Get image names from HFDataset -> Create a pandas DataFrame to match image names to logits + embeddings

# Misc below

### Extracting embeddings from single-label training images

We extract embeddings from a small subset of training images to validate our pipeline.  
The training images are not tiled: each full image is passed through the model to produce a 768-dimensional ViT embedding.

### Embedding test images with tiling (3x3)


Since the test images are high-resolution and contain multiple plant species, we split them into a 3x3 grid of tiles.
- We **extract embeddings** and **top-*K* logits** from each tile using the ViT model.  
- This **patch-wise representation** is critical for enabling multi-label classification.
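
The 3x3 split described above can be sketched as a list of crop boxes. This is a minimal illustration, not the notebook's actual tiling code: `grid_boxes` is a hypothetical helper, and the image size is made up. The boxes use the `(left, upper, right, lower)` convention that PIL's `Image.crop` accepts.

```python
from typing import List, Tuple

# (left, upper, right, lower), the box convention used by PIL's Image.crop
Box = Tuple[int, int, int, int]


def grid_boxes(width: int, height: int, grid_size: int = 3) -> List[Box]:
    """Return crop boxes for a grid_size x grid_size tiling, row by row.

    Hypothetical helper for illustration only.
    """
    # Proportional integer edges: tile sizes differ by at most one pixel
    # and together the tiles cover the whole image.
    xs = [width * i // grid_size for i in range(grid_size + 1)]
    ys = [height * j // grid_size for j in range(grid_size + 1)]
    return [
        (xs[col], ys[row], xs[col + 1], ys[row + 1])
        for row in range(grid_size)
        for col in range(grid_size)
    ]


boxes = grid_boxes(width=3000, height=2000)  # made-up image size
print(len(boxes))  # 9
print(boxes[0])    # (0, 0, 1000, 666)
print(boxes[-1])   # (2000, 1333, 3000, 2000)
```

Each box could then be passed to `image.crop(box)` to obtain one tile, which is embedded independently of the others.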

### Analyzing classifier logits per tile

For each tile, we look at the **top predicted species** and their associated scores (stored in the `logits` column).  
This helps interpret how confident the model is in identifying species in each patch.
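
One way to turn per-tile predictions into an image-level list is to keep the best score for each species across tiles and then take the top *K*. The sketch below assumes that max-score aggregation rule. The notebook's own `select_top_k_unique_logits` is not shown here and may work differently, and the function name, species IDs, and scores below are all illustrative.

```python
from typing import Dict, List, Tuple

# (species_id, score) pairs predicted for a single tile
TilePredictions = List[Tuple[int, float]]


def top_k_unique_species(
    tiles: List[TilePredictions], top_k: int = 5
) -> List[Tuple[int, float]]:
    """Keep the best score per species across all tiles, then return the top_k.

    Assumption: aggregation by maximum score; other rules (mean, sum) are possible.
    """
    best: Dict[int, float] = {}
    for tile in tiles:
        for species_id, score in tile:
            if score > best.get(species_id, float("-inf")):
                best[species_id] = score
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


# Made-up predictions for three tiles of one image.
tiles = [
    [(101, 0.44), (102, 0.31)],
    [(102, 0.52), (103, 0.10)],
    [(101, 0.20), (104, 0.25)],
]
result = top_k_unique_species(tiles, top_k=3)
print(result)  # [(102, 0.52), (101, 0.44), (104, 0.25)]
```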

### Embedding the entire test set with tiling

We scale up our embedding pipeline to process the full test dataset using **3x3 tiling**.  
This prepares the data for the downstream tasks of efficient **nearest neighbor search** and **multi-label prediction** at the tile level.

### Saving test embeddings and logits to Parquet

We serialize the full test embeddings into partitioned Parquet files for later use in inference pipelines.  
The logits are stored as JSON strings for flexibility.
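
Storing logits as JSON strings could look roughly like the round trip below. It is a sketch under assumptions: the helper names are hypothetical, the values are made up, and the actual column layout of the Parquet files is not shown in this notebook. Note that JSON has no tuple type, so pairs come back as lists unless they are converted on load.

```python
import json
from typing import List, Tuple


def logits_to_json(logits: List[Tuple[int, float]]) -> str:
    """Serialize (species_id, score) pairs to a JSON string (hypothetical helper)."""
    return json.dumps([[species_id, score] for species_id, score in logits])


def logits_from_json(payload: str) -> List[Tuple[int, float]]:
    """Parse the JSON string back into (species_id, score) tuples."""
    return [
        (int(species_id), float(score))
        for species_id, score in json.loads(payload)
    ]


row_logits = [(101, 0.44), (102, 0.31)]  # illustrative values
payload = logits_to_json(row_logits)
print(payload)  # [[101, 0.44], [102, 0.31]]

restored = logits_from_json(payload)
print(restored == row_logits)  # True
```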

## Embedding the full training set (no tiling)

We repeat the embedding process on the **full training dataset**, this time *without tiling*.  
This lets us use the embeddings directly, for example as features for **transfer learning**, or index them in a Faiss-based nearest neighbor retrieval system.
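
To make the retrieval idea concrete, here is a brute-force cosine-similarity search in plain Python. It stands in for Faiss purely for illustration: a Faiss index would replace this loop at the scale of the full training set. The function names are hypothetical, and the 3-dimensional vectors and labels are toy values rather than real 768-dimensional embeddings.

```python
import math
from typing import List, Sequence, Tuple


def l2_normalize(vector: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


def nearest_neighbors(
    query: Sequence[float],
    index_vectors: Sequence[Sequence[float]],
    labels: Sequence[int],
    k: int = 2,
) -> List[Tuple[int, float]]:
    """Return the k (label, cosine similarity) pairs most similar to the query."""
    q = l2_normalize(query)
    scored = []
    for vector, label in zip(index_vectors, labels):
        v = l2_normalize(vector)
        similarity = sum(a * b for a, b in zip(q, v))
        scored.append((label, similarity))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy "training embeddings" with their species labels.
train_embeddings = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.9, 0.1, 0.0]]
train_labels = [101, 102, 101]

# Toy embedding of one test tile.
tile_embedding = [1.0, 0.05, 0.0]
neighbors = nearest_neighbors(tile_embedding, train_embeddings, train_labels, k=2)
print([label for label, _ in neighbors])  # [101, 101]
```

The labels of the retrieved neighbors can then be used as species votes for the tile.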

### Saving the training embeddings to Parquet

Finally, we save the full training embeddings in partitioned Parquet format to support fast, distributed retrieval during inference.

### Embeddings ready for downstream use

We now have rich ViT embeddings for both train and test datasets, ready for use in:
- Multi-label classification
- Retrieval-based inference
- Nearest neighbor search