# Embedding workflow using DINOv2

This notebook focuses on the **Feature Extraction** pipeline. 

We utilize the fine-tuned model **ViTD2PC24All** ([DINOv2](https://dinov2.metademolab.com/)) to extract high-dimensional embeddings from the single-label train images and multi-label test images.

We'll **visualize**, **tile**, and **process** these embeddings to support patch-wise multi-label inference using PyTorch and Faiss.

![diagram](../images/pytorch-webinar-diagram.png)

In [1]:
# !uv pip list | grep pydantic
# !uv pip install meerkat-ml
# !which pip

# !uv pip install pyspark -v

In [1]:
%load_ext autoreload
%autoreload 2

## Load the Parquet file from disk and visualize the images

In [2]:
import pandas as pd

pd.options.display.precision = 2
pd.options.display.max_rows = 10

root_dir = "/teamspace/studios/this_studio/plantclef-vision/data/plantclef2025"
dataset_dir = "/teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/competition-metadata/PlantCLEF2025_test_images/PlantCLEF2025_test_images"
hf_dataset_dir = "/teamspace/studios/this_studio/plantclef-vision/data/parquet/plantclef2025/full_test/HF_dataset"

In [3]:
# from plantclef.pytorch.data import HFPlantDataset
from torchvision import transforms
from typing import Callable
import torch

In [5]:
# def transform_dict(transforms: Callable, key: str) -> Callable:
#     """Apply transformation to a specific key in the dataset."""

#     def transform_fn(row):
#         row[key] = [transforms(image) for image in row[key]]
#         return row

#     return transform_fn


# def create_transform(image_size: int, key: Optional[str] = None) -> Callable:
#     """Create image transformation pipeline that maintains aspect ratio."""
#     transform_list = [
#         # transforms.ToPILImage(),
#         transforms.Resize(
#             image_size, max_size=image_size + 2
#         ),  # Maintains aspect ratio
#         transforms.CenterCrop(image_size),
#         transforms.ToTensor(),
#     ]
#     transform_list = transforms.Compose(transform_list)
#     if key is not None:
#         return transform_dict(transform_list, key)
#     return transform_list

## Running torch_pipeline with HFPlantDataset

In [4]:
from plantclef.embed.workflow import torch_pipeline, Config
from plantclef.embed.utils import print_dir_size
import os
from rich import print as pprint

cfg = Config()
pprint(cfg)



In [5]:
print_dir_size(cfg.test_embeddings_path)

Analyzing disk usage of directory: /teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/embeddings/full_test/test_grid_3x3_embeddings
Directory Disk Usage: 543M	/teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/embeddings/full_test/test_grid_3x3_embeddings
2025-05-08 08:42:53


In [6]:
from datasets import Dataset as HFDataset
import numpy as np


ds = HFDataset.load_from_disk(cfg.test_embeddings_path)


species_ids = [
    int(species_id) for species_id in sorted(list(set(ds[0]["logits"].keys())))
]
print(f"len(species_ids): {len(species_ids)}")

ds

len(species_ids): 2911


Dataset({
    features: ['image_name', 'embeddings', 'logits', 'tile'],
    num_rows: 18945
})

In [7]:
from typing import Tuple, List


def remove_NaN_values_from_dict(d: dict) -> dict:
    """Remove missing values (stored as None) from a dictionary."""
    return {k: v for k, v in d.items() if v is not None}


def sort_and_filter_dict(d: dict, top_k: int = 0) -> List[Tuple[str, float]]:
    """
    Sort a dictionary by value and drop missing entries.

    Takes a dictionary mapping str keys to float values, removes any keys
    whose value is missing (None), and returns the remaining key-value
    pairs as a list of tuples sorted by value in descending order.
    If top_k is positive, only the first top_k items are returned.

    Args:
        d (dict): The dictionary to sort and filter.
        top_k (int): The number of top items to return. Default is 0, which returns all items.
    Returns:
        List[Tuple[str, float]]: A sorted list of tuples (key, value) from the dictionary.

    * [TODO] -- Consider adding a threshold for the values to filter out low-confidence predictions.

    """
    # Remove NaN values
    d = remove_NaN_values_from_dict(d)

    # Sort the dictionary by values in descending order
    sorted_list = sorted(d.items(), key=lambda item: item[1], reverse=True)

    # If top_k is specified, return only the top_k items
    if top_k > 0:
        sorted_list = sorted_list[:top_k]

    return sorted_list


# def format_logits(
#     row: dict, key: Optional[str] = None
# ) -> Dict[str, List[Tuple[str, float]]]:
#     """
#     Format the logits dictionary to remove NaN values and sort by confidence.

#     Args:
#         row (dict): The dictionary containing logits.

#     Returns:
#         dict: A formatted dictionary with sorted logits.
#     """

#     # Sort the dictionary by values in descending order
#     if isinstance(key, str):
#         row = row[key]
#     else:
#         key = ""
#     logits = sort_and_filter_dict(row, top_k=5)

#     return {key: logits}


# row = remove_NaN_values_from_dict(ds[0]["logits"])

# row = ds[2100] #["logits"]
# format_logits(row)
# sort_and_filter_dict(row, top_k=5)


# sorted(row.items(), key=lambda x: x[1], reverse=True)

In [8]:
from more_itertools import flatten
import more_itertools as mit


def select_top_k_unique_logits(df: pd.DataFrame, top_k: int = 5) -> list:
    """
    Select the top k unique logits from the DataFrame.
    """
    assert df.shape == (9, 2)

    # img_name, g = next(iter(dfg))
    logits = sorted(flatten(df["logits"].to_list()), key=lambda x: x[1], reverse=True)
    logits_unique = list(mit.unique_everseen(logits, key=lambda x: x[0]))

    top_k_logits_unique = logits_unique[:top_k]

    return top_k_logits_unique


def groupby_image_select_top_k_unique_logits(
    df: pd.DataFrame, top_k: int = 5
) -> pd.DataFrame:
    """
    Group by image across all image tiles and select the top k unique logits.

    Args:
        df (pd.DataFrame): The DataFrame to process.
        top_k (int): The number of top items to include. Default is 5.
    Returns:
        pd.DataFrame: A DataFrame containing the selected logits.
    """

    if isinstance(df, pd.DataFrame):
        df = df.groupby("image_name")

    return (
        df.apply(select_top_k_unique_logits, top_k=top_k).rename("logits").reset_index()
    )

    # logits = sorted(flatten(df["logits"].to_list()), key=lambda x: x[1], reverse=True)
    # logits = list(mit.unique_everseen(logits, key=lambda x: x[0]))
    # return logits

In [52]:
ddf = pd.Series([0, 1, 2])
ddf
ddf.rename("col_name").reset_index()

0    0
1    1
2    2
dtype: int64

In [9]:
import pandas as pd

top_k = 5


def prepare_submission_csv(ds: HFDataset, top_k: int = 5) -> pd.DataFrame:
    """
    Prepare a submission CSV file from logits saved to disk as a Hugging Face dataset.

    Args:
        ds (HFDataset): The dataset to process.
            Expected columns are  ['image_name', 'embeddings', 'logits', 'tile'].
        top_k (int): The number of top items to include in the submission. Default is 5.

    Returns:
        df (pd.DataFrame): A DataFrame containing the formatted top_k species_id predictions.
            Expected columns are ['quadrat_id', 'species_ids'].


    """
    # Convert the ["image_name", "logits"] columns of the HF dataset into a pd.DataFrame
    df = ds.remove_columns(["embeddings", "tile"]).to_pandas()

    df = df.assign(
        logits=df.apply(
            lambda x: sort_and_filter_dict(  # Sort species IDs by confidence (high to low) and keep only the top_k
                x["logits"], top_k=top_k
            ),
            axis=1,
        )
    )

    df = groupby_image_select_top_k_unique_logits(df, top_k=top_k)
    df = df.rename(
        columns={
            "image_name": "quadrat_id",
            "logits": "species_ids",
        }
    )
    df = df.assign(
        species_ids=df.apply(
            lambda x: [  # Select only the species IDs
                species_id for species_id, _ in x["species_ids"]
            ],
            axis=1,
        )
    )

    return df

In [22]:
# df



| | image_name | logits |
|---|---|---|
| 0 | 2024-CEV3-20240602.jpg | [(1654010, 0.44266772270202637), (1395063, 0.3... |
| 1 | CBN-PdlC-A1-20130807.jpg | [(1744569, 0.2301855832338333), (1361917, 0.22... |
| 2 | CBN-PdlC-A1-20130903.jpg | [(1744569, 0.16917195916175842), (1392608, 0.1... |
| 3 | CBN-PdlC-A1-20140721.jpg | [(1529289, 0.14910352230072021), (1374758, 0.1... |
| 4 | CBN-PdlC-A1-20140811.jpg | [(1361281, 0.12936192750930786), (1418612, 0.1... |
| ... | ... | ... |
| 2100 | RNNB-8-5-20240118.jpg | [(1361437, 0.7179210782051086), (1655199, 0.52... |
| 2101 | RNNB-8-6-20240118.jpg | [(1655199, 0.37736761569976807), (1363434, 0.2... |
| 2102 | RNNB-8-7-20240118.jpg | [(1359297, 0.30361855030059814), (1356521, 0.2... |
| 2103 | RNNB-8-8-20240118.jpg | [(1359650, 0.3005388379096985), (1396330, 0.28... |


In [11]:
# top_1 = []
# top_2 = []
# top_3 = []
# top_4 = []
# top_5 = []

# for i, row in df.iterrows():
#     top_1.append(row["logits"][0])
#     top_2.append(row["logits"][1])
#     top_3.append(row["logits"][2])
#     top_4.append(row["logits"][3])
#     top_5.append(row["logits"][4])

#     print(i)
#     # pprint(row)

#     if i >= 5:
#         break

# print(f"top_1: {top_1}")
# print(f"top_2: {top_2}")
# print(f"top_3: {top_3}")
# print(f"top_4: {top_4}")
# print(f"top_5: {top_5}")

0
1
2
3
4
5
top_1: [('1654010', 0.44266772270202637), ('1744569', 0.2301855832338333), ('1744569', 0.16917195916175842), ('1529289', 0.14910352230072021), ('1361281', 0.12936192750930786), ('1744569', 0.1826046109199524)]
top_2: [('1395063', 0.3262457549571991), ('1361917', 0.22664865851402283), ('1392608', 0.15455718338489532), ('1374758', 0.12966282665729523), ('1418612', 0.11788784712553024), ('1358492', 0.14449474215507507)]
top_3: [('1392662', 0.3107529878616333), ('1356350', 0.22245322167873383), ('1361382', 0.11053255945444107), ('1402995', 0.08004958182573318), ('1356350', 0.08978088945150375), ('1361524', 0.12258566915988922)]
top_4: [('1414387', 0.2375417947769165), ('1418612', 0.17090369760990143), ('1361068', 0.07495987415313721), ('1741880', 0.07896480709314346), ('1392608', 0.08292622119188309), ('1397565', 0.10479555279016495)]
top_5: [('1743646', 0.17830075323581696), ('1361129', 0.1368836909532547), ('1361971', 0.06213787570595741), ('1362066', 0.072990283370018), ('17

In [12]:
# top_species_ids = [s_id for s_id, _ in [*top_1, *top_2, *top_3, *top_4, *top_5]]


# print(f"len(top_species_ids): {len(top_species_ids)}")
# print(f"len(set(top_species_ids)): {len(set(top_species_ids))}")

len(top_species_ids): 30
len(set(top_species_ids)): 25


In [64]:
# row["logits"][0]

('1654010', 0.44266772270202637)

In [None]:
# dfg = df.groupby("image_name")
# pred_df = prepare_submission_csv(ds, top_k=top_k)
# pred_df.shape

# pred_df.groupby("image_name").describe()
# for k, v in pred_df.groupby("image_name"):
#     print(k)
#     print(v)
#     break



In [102]:
# Compare manual nested list of lists unpacking vs. more_itertools.flatten

# logits = v["logits"].to_list()
# logits = [kth_logit for tile in logits for kth_logit in tile]
# iter_logits = list(flatten(v["logits"].to_list()))

# for i in range(len(logits)):
#     assert logits[i] == iter_logits[i], f"Mismatch at index {i}: {logits[i]} != {iter_logits[i]}"

34

In [107]:
from rich import print as pprint

# out = [("logits", "logits_unique")]
# for i in range(20):
#     out.append((logits[i], logits_unique[i]))

[1m[[0m
    [1m([0m[32m'logits'[0m, [32m'logits_unique'[0m[1m)[0m,
    [1m([0m[1m([0m[32m'1654010'[0m, [1;36m0.44266772270202637[0m[1m)[0m, [1m([0m[32m'1654010'[0m, [1;36m0.44266772270202637[0m[1m)[0m[1m)[0m,
    [1m([0m[1m([0m[32m'1395063'[0m, [1;36m0.3262457549571991[0m[1m)[0m, [1m([0m[32m'1395063'[0m, [1;36m0.3262457549571991[0m[1m)[0m[1m)[0m,
    [1m([0m[1m([0m[32m'1392662'[0m, [1;36m0.3107529878616333[0m[1m)[0m, [1m([0m[32m'1392662'[0m, [1;36m0.3107529878616333[0m[1m)[0m[1m)[0m,
    [1m([0m[1m([0m[32m'1392662'[0m, [1;36m0.293470174074173[0m[1m)[0m, [1m([0m[32m'1414387'[0m, [1;36m0.2375417947769165[0m[1m)[0m[1m)[0m,
    [1m([0m[1m([0m[32m'1414387'[0m, [1;36m0.2375417947769165[0m[1m)[0m, [1m([0m[32m'1743646'[0m, [1;36m0.17830075323581696[0m[1m)[0m[1m)[0m,
    [1m([0m[1m([0m[32m'1743646'[0m, [1;36m0.17830075323581696[0m[1m)[0m, [1m([0m[32m'1395117'[0m, [1;36m0.

In [11]:
def create_classification_dataframe(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    predictions: np.ndarray,
    similarities: np.ndarray,
) -> pd.DataFrame:
    """
    Creates a classification DataFrame with Faiss predictions, similarities, and resolved species IDs.

    :param train_df: Train DataFrame with image_name to species_id mapping
    :param test_df: Test DataFrame (contains image_name, data, embeddings, etc.)
    :param predictions: np.ndarray of shape (N, K) with predicted image names
    :param similarities: np.ndarray of shape (N, K) with similarity scores
    :return: Copy of test_df with added columns: predictions, similarities, pred_species_ids
    """
    cls_test_df = test_df.copy()
    cls_test_df["predictions"] = predictions.tolist()
    cls_test_df["similarities"] = similarities.tolist()
    # create lookup dictionary
    image_to_species = dict(zip(train_df["image_name"], train_df["species_id"]))
    # map preds to species_id
    species_ids = []
    for row in cls_test_df["predictions"]:
        row_species = [image_to_species.get(img_name, None) for img_name in row]
        species_ids.append(row_species)
    # add to DataFrame
    cls_test_df["pred_species_ids"] = species_ids
    return cls_test_df

2911

In [17]:
# len({k: v for k, v in ds[1]["logits"].items() if v is not None})
# {k: v for k, v in ds[1]["logits"].items() if v is not None}

{'1359714': 0.015930160880088806,
 '1392540': 0.06060715764760971,
 '1392662': 0.293470174074173,
 '1394523': 0.05156639218330383,
 '1628936': 0.025065019726753235}

In [7]:
from plantclef.pytorch.model import DINOv2LightningModel

top_k = 5
model = DINOv2LightningModel(top_k=top_k)
model.transform

Compose(
    Resize(size=518, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(518, 518))
    MaybeToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

## Explore embeddings

### Get embeddings and logits from model.predict_step

In [39]:
# grid_size = cfg.grid_size

# embeddings, logits = model.predict_step(
#             batch, batch_idx=0
#         )

# embeddings, logits = model.predict_grid_step(
#             batch, batch_idx=0
#         )

# print(embeddings.shape)
# print(len(logits))
# print(embeddings.shape)
# embeddings = embeddings.view(-1, grid_size**2, 768)
# print(embeddings.shape)
# embeddings = embeddings.view(-1, grid_size**2, 768)

# logits = [
#             logits[i : i + grid_size**2] for i in range(0, len(logits), grid_size**2)
#         ]
# print(embeddings.shape)
# print(len(logits))
# print(embeddings.shape)
# print(len(logits))
# print([l.keys() for l in logits])
# logits[0]

torch.Size([4, 9, 768])
4


In [37]:
# print(f"batch_size -- len(logits): {len(logits)}")
# print(f"grid_size**2 -- len(logits[0]): {len(logits[0])}")
# # logits_img0_tile0 = logits[0][0]
# print(f"top_k -- k = len(list(logits[0][0].keys())): {len(list(logits[0][0].keys()))}")

batch_size -- len(logits): 36
grid_size**2 -- len(logits[0]): 5


KeyError: 0

### Get image names from HFDataset -> Create a pandas DataFrame to match image names to logits + embeddings

In [None]:
# embeddings, logits = model.predict_grid_step(
#             batch, batch_idx=0
#         )

In [12]:
# def create_predictions_df(
#     ds: HFPlantDataset, embeddings: torch.Tensor, logits: list
# ) -> pd.DataFrame:
#     """
#     Accepts an HFPlantDataset and a set of embeddings and logits.

#     To be called after the model has been run on the full dataset in ds.

#     Returns a DataFrame with the following columns:
#         - image_name
#         - tile
#         - embeddings
#         - logits
#     The DataFrame is exploded to have one row per tile.

#     """

#     pred_df = pd.DataFrame({"image_name": ds.dataset["file_path"]})
#     pred_df["image_name"] = pred_df["image_name"].str.rsplit("/", n=1, expand=True)[1]

#     pred_df = pred_df.convert_dtypes()

#     pred_df = pred_df.assign(embeddings=embeddings.cpu().tolist(), logits=logits)
#     explode_df = pred_df.explode(["embeddings", "logits"], ignore_index=True)
#     explode_df = explode_df.assign(tile=explode_df.groupby("image_name").cumcount())

#     return explode_df


# pred_ds = HFDataset.from_pandas(explode_df)
# pred_ds.save_to_disk(test_embeddings_path)

In [21]:
# loaded_ds = Dataset.load_from_disk(test_embeddings_path)
# loaded_ds.features["logits"]

In [43]:
import json
import shutil
from pathlib import Path

import numpy as np


def write_embeddings_to_parquet(
    df: pd.DataFrame,
    folder_name: str,
    num_partitions: int = 20,
):
    # path to data
    root = Path().resolve().parents[0]
    data_path = f"{root}/data/embeddings"
    output_path = f"{data_path}/{folder_name}"

    # remove existing data if it exists to avoid duplication
    if Path(output_path).exists():
        shutil.rmtree(output_path, ignore_errors=True)

    # convert logits to json strings
    df["logits"] = df["logits"].apply(json.dumps)

    # assign partition numbers (0 to num_partitions-1)
    df_size = len(df)
    # np.repeat needs an integer repeat count, so cast the ceil result
    df["partition"] = np.repeat(
        np.arange(num_partitions), int(np.ceil(df_size / num_partitions))
    )[:df_size]

    # write to parquet using the new partition column
    df.to_parquet(output_path, partition_cols=["partition"], index=False)

    print(
        f"Embedding dataset written to: {output_path} with {num_partitions} partitions."
    )


# write data
# folder_name = f"test_grid_{GRID_SIZE}x{GRID_SIZE}_embeddings"
# write_embeddings_to_parquet(test_explode_df, folder_name, num_partitions=10)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4 entries, 0 to 3
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   file_path   4 non-null      string
 1   embeddings  4 non-null      object
 2   logits      4 non-null      object
dtypes: object(2), string(1)
memory usage: 224.0+ bytes


In [None]:
# ds.plot_image_tiles(idx=50)

# Misc below

In [52]:
# ds.dataset = ds.dataset.take(100)
# # extract embeddings
# embeddings, logits = torch_pipeline(
#     dataset=ds,  # .dataset.take(5),
#     batch_size=2,
#     use_grid=True,
#     cpu_count=1,
# )
# embeddings.shape
# grid_size = 3

# embeddings = embeddings.view(-1, grid_size**2, 768)
# embeddings.shape
# import matplotlib.pyplot as plt

# img = ds._get_image_tensor(0)

# plt.imshow(img.permute(1, 2, 0))

In [None]:
def center_crop(image: torch.Tensor) -> torch.Tensor:
    min_dim = min(image.shape[1:])
    return transforms.CenterCrop(min_dim)(image)

In [32]:
import torch

## Save huggingface test set to disk

In [20]:
# image_list = collect_image_filepaths(dataset_dir)

# ds = Dataset.from_dict({"image": image_list})
# ds = ds.cast_column("image", Image())

# ds.save_to_disk(hf_dataset_dir)

Collecting file paths in /teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/competition-metadata/PlantCLEF2025_test_images/PlantCLEF2025_test_images: 100%|██████████| 2105/2105 [00:00<00:00, 837905.47it/s]
Walking through dir /teamspace/studios/this_studio/plantclef-vision/data/plantclef2025/competition-metadata/PlantCLEF2025_test_images/PlantCLEF2025_test_images: 1it [00:00, 20.22it/s]


Saving the dataset (0/18 shards):   0%|          | 0/2105 [00:00<?, ? examples/s]

In [33]:
# ds_loaded = Dataset.load_from_disk(hf_dataset_dir)

Loading dataset from disk:   0%|          | 0/18 [00:00<?, ?it/s]

In [53]:
def create_transform(image_size: int) -> Callable:
    """Create image transformation pipeline that maintains aspect ratio."""
    transform_list = [
        # transforms.ToPILImage(),
        transforms.Resize(
            image_size, max_size=image_size + 2
        ),  # Maintains aspect ratio
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ]

    return transforms.Compose(transform_list)

In [44]:
# dataset = ds_loaded

# Misc below

In [None]:
import torch
from plantclef.config import get_device
import pandas as pd
from pathlib import Path

print(f"PyTorch Version: {torch.__version__}")
device = get_device()
print(f"Using device: {device}")


# Get list of stored files in cloud bucket
root = Path().resolve().parents[0]
print(root)
! date

In [1]:
import os
from pathlib import Path


test_parquet_output_dir = "/teamspace/studios/this_studio/plantclef-vision/data/parquet/plantclef2025/full_test"
os.makedirs(test_parquet_output_dir, exist_ok=True)

root = "/teamspace/studios/this_studio/plantclef-vision/data/plantclef2025"
test_image_dir = (
    root + "/competition-metadata/PlantCLEF2025_test_images/PlantCLEF2025_test_images"
)

### Extracting embeddings from single-label training images

We extract embeddings from a small subset of training images to validate our pipeline.  
We don't perform tiling on the train images (we use the full image) and extract 768-dimensional ViT embeddings.

In [None]:
limit_train_df = pd.DataFrame({})

# extract embeddings
embeddings, logits = torch_pipeline(
    limit_train_df,
    batch_size=2,
    use_grid=False,
    cpu_count=1,
)

In [None]:
# embeddings shape
embeddings.shape

In [None]:
# first embedding
embeddings[0][0][:100]  # showing first 100 values out of 768

In [None]:
# create embeddings dataframe
cols = ["image_name", "data", "species", "species_id"]
embeddings_df = limit_train_df[cols].copy()
embeddings_df["embeddings"] = embeddings.tolist()
embeddings_df.head(2)

In [None]:
from plantclef.plotting import plot_images_from_binary

embeddings_df = pd.DataFrame()
plot_images_from_binary(
    embeddings_df,
    data_col="data",
    label_col="species",
    grid_size=(1, 2),
    crop_square=True,
    figsize=(8, 4),
)

In [None]:
from plantclef.plotting import plot_embeddings

plot_embeddings(
    embeddings_df,
    data_col="embeddings",
    label_col="species",
    grid_size=(1, 2),
    figsize=(8, 4),
)

### Embedding test images with tiling (3x3)


Since the test images are high-resolution and contain multiple plant species, we split them into a 3x3 grid of tiles.
- We **extract embeddings** and **top-*K* logits** from each tile using the ViT model.  
- This **patch-wise representation** is critical for enabling multi-label classification.
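
To make the 3x3 split concrete, here is a minimal sketch of how tile boundaries could be computed for an image of a given size. The helper `grid_tile_boxes` is hypothetical and is not part of `plantclef`; the actual pipeline may crop or resize before tiling, so treat this only as an illustration of the geometry.

```python
from typing import List, Tuple

# (top, left, bottom, right), with bottom/right exclusive
Box = Tuple[int, int, int, int]


def grid_tile_boxes(height: int, width: int, grid_size: int = 3) -> List[Box]:
    """Return row-major tile boxes that cover a height x width image exactly."""
    if grid_size < 1:
        raise ValueError("grid_size must be >= 1")
    # Integer edges spread any remainder across tiles, so no pixel is dropped.
    row_edges = [i * height // grid_size for i in range(grid_size + 1)]
    col_edges = [j * width // grid_size for j in range(grid_size + 1)]
    return [
        (row_edges[i], col_edges[j], row_edges[i + 1], col_edges[j + 1])
        for i in range(grid_size)
        for j in range(grid_size)
    ]


boxes = grid_tile_boxes(518, 518, grid_size=3)
print(len(boxes))           # 9 tiles for a 3x3 grid
print(boxes[0], boxes[-1])  # top-left and bottom-right tiles
```

Each box can then be used to slice an image tensor of shape `(C, H, W)` as `image[:, top:bottom, left:right]` before it is passed to the model.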

In [None]:
# set params
USE_GRID = True
GRID_SIZE = 3  # 3x3 grid of tiles
CPU_COUNT = 1  # custom cpu_count
TOP_K = 5  # top-K logits for each tile


test_df = pd.DataFrame({})
test_image_df = pd.DataFrame({})

# select images from test set
image_names = ["CBN-Pyr-03-20230706.jpg", "CBN-can-E6-20230706.jpg"]
test_image_df = test_df[test_df["image_name"].isin(image_names)]

# get embeddings and logits
embeddings, logits = torch_pipeline(
    test_image_df,
    batch_size=2,
    use_grid=USE_GRID,
    grid_size=GRID_SIZE,
    cpu_count=CPU_COUNT,
    top_k=TOP_K,
)

In [None]:
# embeddings shape
embeddings.shape  # (2, 9, 768)

In [14]:
# create embeddings dataframe
def explode_embeddings_logits(
    df: pd.DataFrame,
    embeddings: torch.Tensor,
    logits: list,
    cols: list = ["image_name", "data"],
) -> pd.DataFrame:
    # create dataframe
    pred_df = df[cols].copy()
    pred_df["embeddings"] = embeddings.cpu().tolist()
    pred_df["logits"] = logits
    # explode embeddings
    explode_df = pred_df.explode(["embeddings", "logits"], ignore_index=True)
    # assign tile number for each image
    explode_df["tile"] = explode_df.groupby("image_name").cumcount()
    return explode_df

In [None]:
explode_df = explode_embeddings_logits(test_image_df, embeddings, logits)
explode_df.head(9)

In [None]:
from plantclef.plotting import plot_image_tiles

# show image tiles
plot_image_tiles(
    explode_df,
    data_col="data",
    grid_size=3,
)

In [None]:
from plantclef.plotting import plot_embed_tiles

plot_embed_tiles(
    explode_df,
    data_col="embeddings",
    grid_size=3,
    figsize=(15, 8),
)

In [None]:
# plot grid embeddings
plot_embeddings(
    explode_df,
    data_col="embeddings",
    label_col="tile",
    grid_size=(3, 3),
    figsize=(8, 8),
)

### Analyzing classifier logits per tile

For each tile, we look at the **top predicted species** and associated confidence scores (`logits`).  
This helps interpret how confident the model is in identifying species in each patch.
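
Per-tile scores eventually have to be combined into one prediction list per image. The sketch below shows one way to do that in plain Python: keep the best score seen for each species across tiles, then take the top *K* distinct species. The function name `top_k_unique_species` and the toy scores are assumptions for illustration; the result should match sorting all tile scores and keeping the first occurrence of each species, as the notebook's helper functions do.

```python
from typing import Dict, List, Optional, Tuple


def top_k_unique_species(
    tile_logits: List[Dict[str, Optional[float]]], top_k: int = 5
) -> List[Tuple[str, float]]:
    """Merge per-tile scores into the top_k distinct species for one image."""
    best: Dict[str, float] = {}
    for tile in tile_logits:
        for species_id, score in tile.items():
            if score is None:  # missing entries are stored as None
                continue
            # Keep the highest score seen for each species across tiles.
            if species_id not in best or score > best[species_id]:
                best[species_id] = score
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


# Toy scores for three tiles of one image (values are made up).
tiles = [
    {"1654010": 0.44, "1395063": 0.33, "1392662": None},
    {"1392662": 0.29, "1654010": 0.10},
    {"1414387": 0.24, "1395063": 0.05},
]
print(top_k_unique_species(tiles, top_k=3))
# [('1654010', 0.44), ('1395063', 0.33), ('1392662', 0.29)]
```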

In [None]:
print(f"Length logits: {len(logits)}")

In [None]:
# display logits of first tile
explode_df["logits"].iloc[0]

In [None]:
# display logits for each tile
for i in range(9):
    # use a separate name so the full `logits` list is not overwritten
    tile_logits = explode_df["logits"].iloc[i]
    logits_formatted = {k: round(v, 3) for k, v in tile_logits.items()}
    print(f"Tile {i+1}: {logits_formatted}")

### Embedding the entire test set with tiling

We scale up our embedding pipeline to process the full test dataset using **3x3 tiling**.  
This prepares the data for the downstream tasks of efficient **nearest neighbor search** and **multi-label prediction** at the tile level.
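
With 3x3 tiling, every image contributes nine tiles, so a flat list of per-tile outputs has to be regrouped per image. As a sanity check, the 2105 collected test files would yield 2105 × 9 = 18945 tile rows, which agrees with the `num_rows` of the saved dataset. The following is a minimal sketch of the regrouping step; `group_by_image` is a hypothetical helper, not a `plantclef` function.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_by_image(flat_items: Sequence[T], grid_size: int = 3) -> List[List[T]]:
    """Split a flat, image-major list of per-tile items into per-image chunks."""
    tiles_per_image = grid_size ** 2
    if len(flat_items) % tiles_per_image != 0:
        raise ValueError("number of items is not a multiple of grid_size**2")
    return [
        list(flat_items[start : start + tiles_per_image])
        for start in range(0, len(flat_items), tiles_per_image)
    ]


# 18 placeholder tile outputs correspond to 2 images with 9 tiles each.
flat_logits = [{"tile": i} for i in range(18)]
grouped = group_by_image(flat_logits, grid_size=3)
print(len(grouped), len(grouped[0]))  # 2 9
```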

In [None]:
import os

cpu_count = os.cpu_count()
print(f"CPU count: {cpu_count}")

In [None]:
# params
USE_GRID = True
GRID_SIZE = 3  # 3x3 grid of tiles
CPU_COUNT = 1  # custom cpu_count
TOP_K = 5  # top-K logits for each tile

# get embeddings and logits
test_embeddings, test_logits = torch_pipeline(
    test_df,
    batch_size=10,  # 10 images per batch
    use_grid=USE_GRID,
    grid_size=GRID_SIZE,
    cpu_count=CPU_COUNT,
    top_k=TOP_K,
)

In [None]:
print(test_embeddings.shape)
print(len(test_logits))

In [25]:
# explode full embeddings and logits
test_explode_df = explode_embeddings_logits(
    test_df,
    test_embeddings,
    test_logits,
)

In [None]:
print(test_explode_df.shape)
test_explode_df.head(9)

In [None]:
plot_embed_tiles(
    test_explode_df,
    data_col="embeddings",
    grid_size=3,
)

### Saving test embeddings and logits to Parquet

We serialize the full test embeddings into partitioned Parquet files for later use in inference pipelines.  
The logits are stored as JSON strings for flexibility.
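
The two ideas in this step, JSON-encoding the logits and assigning rows to contiguous partitions, can be illustrated without pandas. This is a sketch under the assumption that partitions should be near-equal contiguous blocks; `assign_partitions` is a hypothetical helper and the scores are toy values.

```python
import json
import math
from typing import List


def assign_partitions(num_rows: int, num_partitions: int) -> List[int]:
    """Assign contiguous, near-equal partition ids to rows 0..num_rows-1."""
    rows_per_partition = max(1, math.ceil(num_rows / num_partitions))
    return [row // rows_per_partition for row in range(num_rows)]


tile_logits = {"1392662": 0.293, "1392540": 0.061}
encoded = json.dumps(tile_logits)  # stored as a plain string column
decoded = json.loads(encoded)      # recovered as a dict at inference time

print(decoded == tile_logits)
print(assign_partitions(num_rows=7, num_partitions=3))  # [0, 0, 0, 1, 1, 1, 2]
```

Storing the dictionary as a string sidesteps schema issues with variable keys per row, at the cost of a `json.loads` call when the data is read back.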

In [None]:
# def write_embeddings_to_parquet(
#     df: pd.DataFrame,
#     folder_name: str,
#     num_partitions: int = 20,
# ):
#     # path to data
#     root = Path().resolve().parents[0]
#     data_path = f"{root}/data/embeddings"
#     output_path = f"{data_path}/{folder_name}"

#     # remove existing data if it exists to avoid duplication
#     if Path(output_path).exists():
#         shutil.rmtree(output_path, ignore_errors=True)

#     # convert logits to json strings
#     df["logits"] = df["logits"].apply(json.dumps)

#     # assign partition numbers (0 to num_partitions-1)
#     df_size = len(df)
#     df["partition"] = np.repeat(
#         np.arange(num_partitions), np.ceil(df_size / num_partitions)
#     )[:df_size]

#     # write to parquet using the new partition column
#     df.to_parquet(output_path, partition_cols=["partition"], index=False)

#     print(
#         f"Embedding dataset written to: {output_path} with {num_partitions} partitions."
#     )


# # write data
# folder_name = f"test_grid_{GRID_SIZE}x{GRID_SIZE}_embeddings"
# write_embeddings_to_parquet(test_explode_df, folder_name, num_partitions=10)

## Embedding the full training set (no tiling)

We repeat the embedding process on the **full training dataset**, this time *without tiling*.  
The resulting embeddings can be used directly or, as a **transfer learning** approach, indexed in a Faiss-based nearest neighbor retrieval system.
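
To show what nearest neighbor retrieval over embeddings amounts to, here is a brute-force cosine-similarity search in plain Python. It is only a stand-in for illustration: it does not use Faiss, the three-dimensional vectors are toy data rather than 768-dimensional ViT embeddings, and `nearest_neighbors` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def nearest_neighbors(
    query: Sequence[float], index: Sequence[Sequence[float]], k: int = 2
) -> List[Tuple[int, float]]:
    """Return (row, cosine similarity) for the k most similar index rows."""
    q = l2_normalize(query)
    scored = [
        (row, sum(a * b for a, b in zip(q, l2_normalize(vec))))
        for row, vec in enumerate(index)
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Toy "training" vectors and one query vector standing in for a tile embedding.
train_vectors = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.7, 0.7, 0.0]]
tile_vector = [0.9, 0.1, 0.0]
neighbors = nearest_neighbors(tile_vector, train_vectors, k=2)
print([row for row, _ in neighbors])  # [0, 2]
```

A dedicated index library does the same ranking far more efficiently on large collections; the brute-force version is useful mainly for checking results on a handful of vectors.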

In [None]:
# params
USE_GRID = False
CPU_COUNT = 1  # custom cpu_count
TOP_K = 5  # top-K logits for each image (no tiling)

train_df = pd.DataFrame({})

# get embeddings and logits
train_embeddings, train_logits = torch_pipeline(
    train_df,
    batch_size=20,  # 20 images per batch
    use_grid=USE_GRID,
    cpu_count=CPU_COUNT,
    top_k=TOP_K,
)

In [None]:
print(train_embeddings.shape)
print(len(train_logits))

In [31]:
# explode full embeddings and logits
train_explode_df = explode_embeddings_logits(
    train_df,
    train_embeddings,
    train_logits,
    cols=["image_name", "data", "species", "species_id"],
)

In [None]:
train_explode_df.head(5)

In [None]:
from plantclef.plotting import plot_single_image_embeddings

plot_single_image_embeddings(
    train_explode_df,
    num_images=2,
    figsize=(8, 10),
)

### Saving the training embeddings to Parquet

Finally, we save the full training embeddings in partitioned Parquet format to support fast, distributed retrieval during inference.

In [None]:
# write data
folder_name = "train_embeddings"
write_embeddings_to_parquet(train_explode_df, folder_name, num_partitions=20)

### Embeddings Ready for Downstream Use

We now have rich ViT embeddings for both train and test datasets, ready for use in:
- Multi-label classification
- Retrieval-based inference
- Nearest Neighbor Search
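
As one possible illustration of retrieval-based multi-label inference, the sketch below resolves retrieved training image names to species IDs and keeps species that receive enough votes across tiles. The voting rule, the `min_votes` threshold, the function name `vote_species`, and the file names are all assumptions made for this example; they are not the method used in this notebook.

```python
from collections import Counter
from typing import Dict, List


def vote_species(
    neighbor_names_per_tile: List[List[str]],
    image_to_species: Dict[str, int],
    min_votes: int = 2,
) -> List[int]:
    """Count species over all retrieved neighbors and keep frequent ones."""
    votes: Counter = Counter()
    for names in neighbor_names_per_tile:
        for name in names:
            species_id = image_to_species.get(name)
            if species_id is not None:  # skip names missing from the lookup
                votes[species_id] += 1
    return [sid for sid, count in votes.most_common() if count >= min_votes]


# Toy lookup from training image name to species ID (made-up values).
image_to_species = {"a.jpg": 101, "b.jpg": 101, "c.jpg": 202, "d.jpg": 303}
neighbors_per_tile = [
    ["a.jpg", "c.jpg"],
    ["b.jpg", "c.jpg"],
    ["d.jpg", "missing.jpg"],
]
predicted = vote_species(neighbors_per_tile, image_to_species, min_votes=2)
print(predicted)  # [101, 202]
```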

In [None]:
# path to data
data_path = f"{root}/data/embeddings"
# output_path = f"{data_path}/test_grid_3x3_embeddings"
output_path = f"{data_path}/train_embeddings"

train_emb_df = pd.read_parquet(output_path)
print(train_emb_df.shape)
train_emb_df.head(5)

In [None]:
output_path = f"{data_path}/test_grid_3x3_embeddings"
test_grid_df = pd.read_parquet(output_path)
print(test_grid_df.shape)
test_grid_df.head(5)