In [None]:
from datetime import datetime
import os
import json
import yaml
from pathlib import Path
from types import SimpleNamespace
import argparse

import torch
from torchvision import transforms

import numpy as np
import pandas as pd
import torch

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torchvision import transforms as tfms
import torchvision.transforms as T

from typing import Sequence, Tuple, Any, Dict, List, Optional, Union
import importlib

import numpy as np
from sklearn.metrics import top_k_accuracy_score

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Root directory of the FungiTastic dataset ("~" is expanded and the path made absolute).
data_path = (
    Path("~/datasets/fungiclef2025/")
    .expanduser()
    .resolve()
)
# data_path = '/kaggle/input/fungi-clef-2025/'

In [None]:
class FungiTastic(torch.utils.data.Dataset):
    """
    Map-style dataset over one split of the FungiTastic-FewShot metadata.

    The split's metadata CSV is loaded into ``self.df`` (one row per entry). Each item
    yields the image, its class id, the image path and, when an ``embedding`` column
    has been attached to ``self.df``, the precomputed embedding. An optional
    ``transform`` is applied to every loaded image.

    This is a data container, not a model, so it derives from
    ``torch.utils.data.Dataset`` rather than ``torch.nn.Module``.
    """

    # Maps the split argument to the capitalised suffix used in the metadata file names.
    SPLIT2STR = {'train': 'Train', 'val': 'Val', 'test': 'Test'}

    def __init__(self, root: str, split: str = 'val', transform=None):
        """
        Initializes the FungiTastic dataset.

        Args:
            root (str): The root directory of the dataset.
            split (str, optional): The dataset split to use. Must be one of {'train', 'val', 'test'}.
                Defaults to 'val'.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        super().__init__()
        self.split = split
        self.transform = transform
        self.df = self._get_df(root, split)

        assert "image_path" in self.df
        # Labels are only read for train/val; the test split gets no class mappings.
        if self.split != 'test':
            assert "category_id" in self.df
            self.n_classes = len(self.df['category_id'].unique())
            # category id -> first species name found for that id
            self.category_id2label = {
                k: v[0] for k, v in self.df.groupby('category_id')['species'].unique().to_dict().items()
            }
            self.label2category_id = {
                v: k for k, v in self.category_id2label.items()
            }

    def add_embeddings(self, embeddings: pd.DataFrame):
        """
        Attaches precomputed embeddings to ``self.df`` (replaces ``self.df`` with the merged frame).

        Args:
            embeddings (pd.DataFrame): A DataFrame containing 'filename', 'transformation',
                                      and 'embedding' columns.

        Raises:
            AssertionError: If a required column is missing, or if a row whose
                transformation is "original" has no embedding after the merge.
        """
        assert isinstance(embeddings, pd.DataFrame), "Embeddings must be a pandas DataFrame."
        assert "filename" in embeddings.columns, "Embeddings DataFrame must have a 'filename' column."
        assert "embedding" in embeddings.columns, "Embeddings DataFrame must have an 'embedding' column."
        assert "transformation" in embeddings.columns, "Embeddings DataFrame must have a 'transformation' column."

        # Left join on 'filename' only; the 'transformation' column read below is the one
        # brought in from `embeddings`. A filename with several transformations therefore
        # yields several rows in the merged frame.
        self.df = pd.merge(self.df, embeddings, on=["filename"], how="left")

        # Make sure we have embeddings for at least the original images
        assert not self.df[self.df["transformation"] == "original"]["embedding"].isna().any(), \
            "Missing embeddings for some original images"

    def get_embeddings_for_class(self, id):
        """
        Returns the 'embedding' entries of every row whose category_id equals ``id``.

        Args:
            id: The category id to select.

        Returns:
            pd.Series: The selected embeddings, keeping their original index labels.
        """
        # Select with a boolean mask through .loc: the previous code passed index *labels*
        # to .iloc, which is only correct when the index happens to be 0..n-1.
        return self.df.loc[self.df['category_id'] == id, 'embedding']

    @staticmethod
    def _get_df(data_path: str, split: str) -> pd.DataFrame:
        """
        Loads the dataset metadata as a pandas DataFrame.

        Args:
            data_path (str): The root directory where the dataset is stored.
            split (str): The dataset split to load. Must be one of {'train', 'val', 'test'}.

        Returns:
            pd.DataFrame: The metadata for the split plus an 'image_path' column.

        Raises:
            KeyError: If ``split`` is not a key of ``SPLIT2STR``.
        """
        df_path = os.path.join(
            data_path,
            "metadata",
            "FungiTastic-FewShot",
            f"FungiTastic-FewShot-{FungiTastic.SPLIT2STR[split]}.csv"
        )
        df = pd.read_csv(df_path)
        df["image_path"] = df.filename.apply(
            lambda x: os.path.join(data_path, "FungiTastic-FewShot", split, '500p', x)  # TODO: switch 500p to fullsize if a different embedder can handle it
        )
        return df

    def __getitem__(self, idx: int):
        """
        Retrieves a single data sample by position.

        Args:
            idx (int): Positional index of the sample to retrieve.

        Returns:
            tuple: ``(image, category_id, file_path, emb)`` where
                - image is the PIL image, or the output of ``transform`` when one is set;
                - category_id is None for the test split;
                - file_path is the path the image was opened from;
                - emb is a float32 tensor when ``self.df`` has an 'embedding' column, else None.
        """
        # Images are stored under an extra 'images/' directory relative to the path built in _get_df.
        file_path = self.df["image_path"].iloc[idx].replace('FungiTastic-FewShot', 'images/FungiTastic-FewShot')

        if self.split != 'test':
            category_id = self.df["category_id"].iloc[idx]
        else:
            category_id = None

        image = Image.open(file_path)

        if self.transform:
            image = self.transform(image)

        # Check if embeddings exist
        if "embedding" in self.df.columns:
            emb = torch.tensor(self.df.iloc[idx]['embedding'], dtype=torch.float32).squeeze()
        else:
            emb = None  # No embeddings available

        return image, category_id, file_path, emb

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        return len(self.df)

    def get_class_id(self, idx: int) -> int:
        """
        Returns the class ID of the sample at position ``idx``.
        """
        return self.df["category_id"].iloc[idx]

    def show_sample(self, idx: int) -> None:
        """
        Displays a sample image along with its class name and index.

        For the test split no label is available, so the class is shown as "unknown".
        """
        image, category_id, _, _ = self.__getitem__(idx)
        if category_id is None:
            # The test split has neither category ids nor a category_id2label mapping.
            class_name = "unknown"
        else:
            class_name = self.category_id2label[category_id]

        plt.imshow(image)
        plt.title(f"Class: {class_name}; id: {idx}")
        plt.axis('off')
        plt.show()

    def get_category_idxs(self, category_id: int) -> List[int]:
        """
        Retrieves the index labels of all rows with the given category ID.

        NOTE: these are labels of ``self.df.index``; they coincide with the positions
        expected by ``__getitem__`` only while the frame has a default 0..n-1 index.
        """
        return self.df[self.df.category_id == category_id].index.tolist()

In [None]:
### Load the datasets

# One FungiTastic instance per split, built in train/val/test order without an image transform.
train_dataset, val_dataset, test_dataset = (
    FungiTastic(root=data_path, split=split_name, transform=None)
    for split_name in ("train", "val", "test")
)

# train_dataset.df.head(5)

In [None]:
# test_dataset.df.image_path.to_numpy()[0]

In [None]:
# test_dataset.df.head(20)

## Loading, saving, computing embeddings

In [None]:
# Experiment name; it is interpolated into the artifact file names that load_artifacts reads.
exp_name = "multimodel_cache_fungiclef25"

In [None]:
from pathlib import Path
import json
    
def save_artifacts(exp_name, train_dataset, val_dataset, test_dataset, config, overwrite=False):
    """
    Write the frames, embeddings, embedding widths and config of an experiment to the
    current working directory.

    Files written (``<exp>`` is ``exp_name``):
        numpy_embed_dims_<exp>.npy, {train,val,test}_df_<exp>.csv,
        {train,val,test}_numpy_embedding_<exp>.npy, config_<exp>.json

    Args:
        exp_name (str): Suffix used in every artifact file name.
        train_dataset, val_dataset, test_dataset: Objects exposing a ``df`` DataFrame
            with an ``embedding`` column.
        config (dict): Configuration to store as JSON. NumPy arrays and scalars inside
            it are converted to plain Python values.
        overwrite (bool): When False, refuse to run if the embed-dims file already exists.

    Raises:
        FileExistsError: If artifacts exist and ``overwrite`` is False.
        TypeError: If ``config`` holds a value that cannot be represented in JSON.
    """
    dims_file = Path(f"numpy_embed_dims_{exp_name}.npy")
    # Only the embed-dims file is used as the "artifacts already exist" marker.
    if dims_file.exists() and not overwrite:
        raise FileExistsError("overwrite is False and artifacts exist.")

    def _to_builtin(value):
        # json cannot encode NumPy containers/scalars (e.g. the emb_dims array loaded from .npy).
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # Serialise the config up front so a bad config fails before any file is written.
    config_text = json.dumps(config, sort_keys=True, indent=4, default=_to_builtin)

    # The widths are read from the test frame. load_artifacts names the column
    # 'embed_dims', so fall back to it when 'emb_dims' is absent.
    test_df = test_dataset.df
    dims_column = "emb_dims" if "emb_dims" in test_df.columns else "embed_dims"
    embed_dims = test_df[dims_column].iloc[0]
    np.save(dims_file, embed_dims)

    splits = (("train", train_dataset), ("val", val_dataset), ("test", test_dataset))
    for split_name, dataset in splits:
        dataset.df.to_csv(f"{split_name}_df_{exp_name}.csv", index=None)
    for split_name, dataset in splits:
        # Object array of per-row embeddings; np.save stores it via pickle.
        np.save(f"{split_name}_numpy_embedding_{exp_name}.npy", dataset.df.embedding.to_numpy())

    with open(f"config_{exp_name}.json", "w") as f:
        f.write(config_text)

def load_artifacts(exp_name):
    """
    Load the frames written by ``save_artifacts`` for ``exp_name`` from the current
    working directory and re-attach the array-valued columns.

    The CSV files cannot hold arrays (they come back as strings), so the per-row
    embeddings and the embedding widths are restored from the ``.npy`` files:
    ``embedding`` and ``embed_dims`` are always set, and an ``emb_dims`` column that
    is present in the CSV is overwritten with the same widths.

    Args:
        exp_name (str): Suffix used in the artifact file names.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: (train_df, val_df, test_df).
    """
    # SECURITY NOTE: allow_pickle=True unpickles the file contents, which can execute
    # arbitrary code. Only load artifacts that come from a trusted source.
    embed_dims = np.load(f"numpy_embed_dims_{exp_name}.npy", allow_pickle=True)

    def _load_split(split_name):
        # Read one split's CSV and restore its array-valued columns.
        split_df = pd.read_csv(f"{split_name}_df_{exp_name}.csv")
        # Every row references the same widths array (no per-row apply needed).
        dims_per_row = [embed_dims] * len(split_df)
        split_df["embed_dims"] = dims_per_row
        if "emb_dims" in split_df.columns:
            # Written by save_artifacts as text; replace the stringified values.
            split_df["emb_dims"] = dims_per_row
        split_df["embedding"] = np.load(
            f"{split_name}_numpy_embedding_{exp_name}.npy", allow_pickle=True
        )
        return split_df

    return _load_split("train"), _load_split("val"), _load_split("test")

In [None]:
# Rebuild the three split datasets, then swap in the cached frames (with embeddings).
train_dataset, val_dataset, test_dataset = (
    FungiTastic(root=data_path, split=split_name, transform=None)
    for split_name in ("train", "val", "test")
)
train_dataset.df, val_dataset.df, test_dataset.df = load_artifacts(exp_name)

# Keep an untouched copy of each cached frame next to the working one.
train_dataset.df_bak = train_dataset.df.copy()
val_dataset.df_bak = val_dataset.df.copy()
test_dataset.df_bak = test_dataset.df.copy()

# The config JSON is complemented with the embedding widths stored in the .npy file.
embed_dims = np.load(f"numpy_embed_dims_{exp_name}.npy", allow_pickle=True)
with open(f"config_{exp_name}.json", "r") as file:
    config = json.load(file)
config["emb_dims"] = embed_dims

In [None]:
# Load the cache named "dinov2L_cache" the same way as the main cache.
dinov2L_exp_name = "dinov2L_cache"
dinov2L_train_dataset, dinov2L_val_dataset, dinov2L_test_dataset = (
    FungiTastic(root=data_path, split=split_name, transform=None)
    for split_name in ("train", "val", "test")
)
(
    dinov2L_train_dataset.df,
    dinov2L_val_dataset.df,
    dinov2L_test_dataset.df,
) = load_artifacts(dinov2L_exp_name)

# Its config gets the embedding widths from the matching .npy file.
dinov2L_embed_dims = np.load(f"numpy_embed_dims_{dinov2L_exp_name}.npy", allow_pickle=True)
with open(f"config_{dinov2L_exp_name}.json", "r") as file:
    dinov2L_config = json.load(file)
dinov2L_config["emb_dims"] = dinov2L_embed_dims

## Combine the datasets

In [None]:
# Display the model list stored in the loaded config.
config["models"]

In [None]:
# Display the full config of the main cache ("emb_dims" holds the array loaded from the .npy file).
config

In [None]:
# Display the config of the dinov2L cache for comparison.
dinov2L_config

In [None]:
# Embedding width of each model, keyed by model name (in config order).
model_dims = dict(zip(config["models"], config["emb_dims"]))

# Column offset at which each model's slice starts inside the concatenated embedding.
start_indices = {}
cumulative_dim = 0
for model_name, dim in model_dims.items():
    start_indices[model_name], cumulative_dim = cumulative_dim, cumulative_dim + dim

def merge_embeddings(df, dinoL_df, config, keep_dinov2b=False):
    """
    Append the embeddings of ``dinoL_df`` to the embeddings of ``df``, row by row.

    Rows are paired purely by position: row i of ``df`` is concatenated with row i of
    ``dinoL_df``. No key (e.g. filename) is compared, so both frames must already be
    in the same order.

    Args:
        df (pd.DataFrame): Frame whose ``embedding`` column is extended. It is modified
            in place and also returned.
        dinoL_df (pd.DataFrame): Frame providing the embeddings to append.
        config (dict): Describes the layout of ``df``'s embeddings: ``config["models"]``
            and ``config["emb_dims"]`` give, in order, each model and its width.
        keep_dinov2b (bool): When False, the columns of every model whose name starts
            with "DINO" are removed from ``df``'s embeddings before appending.

    Returns:
        pd.DataFrame: ``df`` with the ``embedding`` column replaced by the merged vectors.

    Raises:
        ValueError: If the two frames do not have the same number of embedding rows.
    """
    df_embedding = np.vstack(df["embedding"].to_numpy())

    if not keep_dinov2b:
        models = [model for model in config["models"] if not model.startswith("DINO")]
        print(f"keeping {models}")
        # Offsets are derived from the `config` argument (not from notebook globals),
        # so they stay correct after the config has been updated by an earlier merge.
        kept_blocks = []
        offset = 0
        for model, dim in zip(config["models"], config["emb_dims"]):
            dim = int(dim)
            if not model.startswith("DINO"):
                kept_blocks.append(df_embedding[:, offset:offset + dim])
            offset += dim
        if kept_blocks:
            df_embedding = np.concatenate(kept_blocks, axis=-1)
        else:
            # Nothing survives the filter: keep the rows, drop every column.
            df_embedding = df_embedding[:, :0]

    dinov2L_df_embedding = np.vstack(dinoL_df["embedding"].to_numpy())

    print(df_embedding.shape)
    print(dinov2L_df_embedding.shape)

    if df_embedding.shape[0] != dinov2L_df_embedding.shape[0]:
        raise ValueError(
            "Cannot merge embeddings row by row: "
            f"{df_embedding.shape[0]} rows vs {dinov2L_df_embedding.shape[0]} rows."
        )

    combined = np.concatenate([df_embedding, dinov2L_df_embedding], axis=-1)

    print(combined.shape)

    # The column is only written once everything succeeded, so a failure above
    # leaves `df` untouched.
    df["embedding"] = list(combined)

    return df

def get_combined_embedding(emb, keep_slices):
    """
    Keep only the given ranges of the last axis of ``emb`` and join them.

    Args:
        emb: Array-like supporting ``emb[..., start:end]`` indexing.
        keep_slices: Iterable of ``(start, end)`` pairs (end exclusive), in output order.

    Returns:
        np.ndarray: The selected pieces concatenated along the last axis.
    """
    pieces = []
    for lo, hi in keep_slices:
        pieces.append(emb[..., lo:hi])
    return np.concatenate(pieces, axis=-1)

def update_config(config, dinov2L_config, keep_dinov2b=False):
    """
    Make ``config`` describe the embeddings produced by ``merge_embeddings``.

    ``config["models"]`` and ``config["emb_dims"]`` are rewritten to list the models
    kept from ``config`` followed by the models of ``dinov2L_config``.

    Args:
        config (dict): Config of the base embeddings. Modified in place and returned.
        dinov2L_config (dict): Config of the appended embeddings; every model it lists
            is recorded, because the whole appended embedding is merged.
        keep_dinov2b (bool): When False, models of ``config`` whose name starts with
            "DINO" are dropped (mirrors the filter in ``merge_embeddings``).

    Returns:
        dict: The same ``config`` object, with ``emb_dims`` as a list of plain ``int``.

    Raises:
        IndexError: If ``dinov2L_config`` lists no models.
    """
    appended = list(zip(dinov2L_config["models"], dinov2L_config["emb_dims"]))
    if not appended:
        raise IndexError("dinov2L_config lists no models to append.")

    updated_models = []
    updated_emb_dims = []
    for model, emb_dim in zip(config["models"], config["emb_dims"]):
        if model.startswith("DINO") and not keep_dinov2b:
            continue
        updated_models.append(model)
        # Widths loaded from .npy are NumPy scalars; plain ints keep the config JSON-serialisable.
        updated_emb_dims.append(int(emb_dim))

    for model, emb_dim in appended:
        updated_models.append(model)
        updated_emb_dims.append(int(emb_dim))

    # NOTE: calling this twice with the same arguments appends the models twice.
    config["models"] = updated_models
    config["emb_dims"] = updated_emb_dims
    return config

In [None]:
# Display the per-model embedding widths computed from the config.
model_dims

In [None]:
# Shape of the stacked train embeddings of the main cache.
np.vstack(train_dataset.df.embedding.to_numpy()).shape

In [None]:
# Shape of the stacked train embeddings of the dinov2L cache.
np.vstack(dinov2L_train_dataset.df.embedding.to_numpy()).shape

In [None]:
# Dry run of the merge: shape after joining both stacked train matrices along the last axis
# (this only works when both have the same number of rows).
np.concatenate([np.vstack(train_dataset.df["embedding"].to_numpy()), np.vstack(dinov2L_train_dataset.df["embedding"].to_numpy())], axis=-1).shape

In [None]:
# Append the dinov2L embeddings to every split, keeping all existing models.
# NOTE: re-running this cell appends the dinov2L embeddings a second time.
train_dataset.df = merge_embeddings(
    df=train_dataset.df,
    dinoL_df=dinov2L_train_dataset.df,
    config=config,
    keep_dinov2b=True,
)
val_dataset.df = merge_embeddings(
    df=val_dataset.df,
    dinoL_df=dinov2L_val_dataset.df,
    config=config,
    keep_dinov2b=True,
)
test_dataset.df = merge_embeddings(
    df=test_dataset.df,
    dinoL_df=dinov2L_test_dataset.df,
    config=config,
    keep_dinov2b=True,
)

In [None]:
# Shape of the first train embedding after the merge.
train_dataset.df.embedding.to_numpy()[0].shape

In [None]:
# Shape of the first train embedding in the dinov2L cache, for comparison.
dinov2L_train_dataset.df.embedding.to_numpy()[0].shape

In [None]:
# Record the appended model in config["models"] / config["emb_dims"];
# update_config modifies `config` in place and returns it.
config = update_config(config, dinov2L_config, keep_dinov2b=True)

In [None]:
# Display the config after the dinov2L update.
config

In [None]:
# Store the updated widths on every row, under both column names used in this notebook
# ("embed_dims" is what load_artifacts creates, "emb_dims" is what save_artifacts reads).
train_dataset.df["embed_dims"] = [config["emb_dims"]] * len(train_dataset.df)
train_dataset.df["emb_dims"] = [config["emb_dims"]] * len(train_dataset.df)

val_dataset.df["embed_dims"] = [config["emb_dims"]] * len(val_dataset.df)
val_dataset.df["emb_dims"] = [config["emb_dims"]] * len(val_dataset.df)

test_dataset.df["embed_dims"] = [config["emb_dims"]] * len(test_dataset.df)
test_dataset.df["emb_dims"] = [config["emb_dims"]] * len(test_dataset.df)

In [None]:
# Refresh the backup copies so they reflect the merged frames.
train_dataset.df_bak = train_dataset.df.copy()
val_dataset.df_bak = val_dataset.df.copy()
test_dataset.df_bak = test_dataset.df.copy()

# Delete dinov2L from memory

In [None]:
# Drop the references to the dinov2L datasets now that their embeddings are merged.
del dinov2L_train_dataset
del dinov2L_val_dataset
del dinov2L_test_dataset

## Merge SAM-H

In [None]:
# Load the cache named "SAMH_cache" the same way as the other caches.
samh_exp_name = "SAMH_cache"
samh_train_dataset, samh_val_dataset, samh_test_dataset = (
    FungiTastic(root=data_path, split=split_name, transform=None)
    for split_name in ("train", "val", "test")
)
(
    samh_train_dataset.df,
    samh_val_dataset.df,
    samh_test_dataset.df,
) = load_artifacts(samh_exp_name)

# Its config gets the embedding widths from the matching .npy file.
samh_embed_dims = np.load(f"numpy_embed_dims_{samh_exp_name}.npy", allow_pickle=True)
with open(f"config_{samh_exp_name}.json", "r") as file:
    samh_config = json.load(file)
samh_config["emb_dims"] = samh_embed_dims

In [None]:
def merge_embeddings(df, dinoL_df, config, keep_dinov2b=False):
    """
    Append the embeddings of ``dinoL_df`` to the rows of ``df`` that share one of its
    transformations.

    Rows of ``df`` whose ``transformation`` value occurs in ``dinoL_df`` are extended;
    they are paired with the rows of ``dinoL_df`` purely by position (no filename
    comparison), so both must already be in the same order. All other rows keep
    their original, narrower embedding and are moved to the end of the result.

    Args:
        df (pd.DataFrame): Frame with ``transformation`` and ``embedding`` columns.
            It is not modified.
        dinoL_df (pd.DataFrame): Frame with ``transformation`` and ``embedding``
            columns providing the embeddings to append.
        config (dict): Describes the layout of ``df``'s embeddings: ``config["models"]``
            and ``config["emb_dims"]`` give, in order, each model and its width.
        keep_dinov2b (bool): When False, the columns of every model whose name starts
            with "DINO" are removed from the extended rows before appending.

    Returns:
        pd.DataFrame: Extended rows followed by the untouched rows, with a fresh
        0..n-1 index.

    Raises:
        ValueError: If the number of rows to extend differs from ``len(dinoL_df)``.
    """
    print("all transforms", df.shape)
    shared_transform = df["transformation"].isin(dinoL_df["transformation"].unique())
    extra_transforms_df = df[~shared_transform]
    # Explicit copy: the embedding column is assigned below, which must not happen on a slice.
    df = df[shared_transform].copy()
    print("transforms not in samh", extra_transforms_df.shape)
    print("transforms in samh", df.shape)

    df_embedding = np.vstack(df["embedding"].to_numpy())

    if not keep_dinov2b:
        models = [model for model in config["models"] if not model.startswith("DINO")]
        print(f"keeping {models}")
        # Offsets are derived from the `config` argument (not from notebook globals),
        # so they stay correct after the config has been updated by an earlier merge.
        kept_blocks = []
        offset = 0
        for model, dim in zip(config["models"], config["emb_dims"]):
            dim = int(dim)
            if not model.startswith("DINO"):
                kept_blocks.append(df_embedding[:, offset:offset + dim])
            offset += dim
        if kept_blocks:
            df_embedding = np.concatenate(kept_blocks, axis=-1)
        else:
            # Nothing survives the filter: keep the rows, drop every column.
            df_embedding = df_embedding[:, :0]

    dinov2L_df_embedding = np.vstack(dinoL_df["embedding"].to_numpy())

    print(df_embedding.shape)
    print(dinov2L_df_embedding.shape)

    if df_embedding.shape[0] != dinov2L_df_embedding.shape[0]:
        raise ValueError(
            "Cannot merge embeddings row by row: "
            f"{df_embedding.shape[0]} rows vs {dinov2L_df_embedding.shape[0]} rows."
        )

    combined = np.concatenate([df_embedding, dinov2L_df_embedding], axis=-1)

    print(combined.shape)

    df["embedding"] = list(combined)

    df = pd.concat([df, extra_transforms_df], ignore_index=True)
    print("merged", df.shape)

    return df

In [None]:
# Append the SAM-H embeddings to every split, keeping all existing models.
# NOTE: re-running this cell appends the SAM-H embeddings a second time.
train_dataset.df = merge_embeddings(
    df=train_dataset.df,
    dinoL_df=samh_train_dataset.df,
    config=config,
    keep_dinov2b=True,
)
val_dataset.df = merge_embeddings(
    df=val_dataset.df,
    dinoL_df=samh_val_dataset.df,
    config=config,
    keep_dinov2b=True,
)
test_dataset.df = merge_embeddings(
    df=test_dataset.df,
    dinoL_df=samh_test_dataset.df,
    config=config,
    keep_dinov2b=True,
)

In [None]:
# Record the SAM-H model in config["models"] / config["emb_dims"];
# update_config modifies `config` in place and returns it.
config = update_config(config, samh_config, keep_dinov2b=True)

In [None]:
# Display the config after the SAM-H update.
config

In [None]:
# Store the updated widths on every row, under both column names used in this notebook
# ("embed_dims" is what load_artifacts creates, "emb_dims" is what save_artifacts reads).
train_dataset.df["embed_dims"] = [config["emb_dims"]] * len(train_dataset.df)
train_dataset.df["emb_dims"] = [config["emb_dims"]] * len(train_dataset.df)

val_dataset.df["embed_dims"] = [config["emb_dims"]] * len(val_dataset.df)
val_dataset.df["emb_dims"] = [config["emb_dims"]] * len(val_dataset.df)

test_dataset.df["embed_dims"] = [config["emb_dims"]] * len(test_dataset.df)
test_dataset.df["emb_dims"] = [config["emb_dims"]] * len(test_dataset.df)

In [None]:
# Refresh the backup copies so they reflect the frames after the SAM-H merge.
train_dataset.df_bak = train_dataset.df.copy()
val_dataset.df_bak = val_dataset.df.copy()
test_dataset.df_bak = test_dataset.df.copy()

In [None]:
# Drop the references to the SAM-H datasets now that their embeddings are merged.
del samh_train_dataset
del samh_val_dataset
del samh_test_dataset

In [None]:
# Persist the merged cache. The embedding widths are written separately by save_artifacts
# (numpy_embed_dims_*.npy, taken from the test frame), so they are left out of the JSON
# config. The key in `config` is "emb_dims"; filtering only "embed_dims" let the NumPy
# values through to json.dump, which cannot encode them.
save_artifacts(
    "multimodel_cache_Dinov2L_SAMH",
    train_dataset,
    val_dataset,
    test_dataset,
    {k: v for k, v in config.items() if k not in ("embed_dims", "emb_dims")},
)