In [None]:
from datetime import datetime
import os
import json
import yaml
from pathlib import Path
from types import SimpleNamespace
import argparse

import torch
from torchvision import transforms

import numpy as np
import pandas as pd
import torch

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torchvision import transforms as tfms
import torchvision.transforms as T

from typing import Sequence, Tuple, Any, Dict, List, Optional, Union
import importlib

import numpy as np
from sklearn.metrics import top_k_accuracy_score

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Path to the FungiTastic dataset root (user home is expanded and the path resolved).
data_path = Path('~/datasets/fungiclef2025/').expanduser().resolve()
# Alternative location (presumably for Kaggle runs — confirm before enabling):
# data_path = '/kaggle/input/fungi-clef-2025/'

In [None]:
class FungiTastic(torch.nn.Module):
    """
    Dataset class for the FewShot subset of the Danish Fungi dataset (size 300, closed-set).

    This dataset loader supports training, validation, and testing splits, and provides
    convenient access to images, class IDs, and file paths. It also supports optional
    image transformations and precomputed per-image embeddings stored in ``self.df``.
    """

    SPLIT2STR = {'train': 'Train', 'val': 'Val', 'test': 'Test'}

    def __init__(self, root: str, split: str = 'val', transform=None):
        """
        Initializes the FungiTastic dataset.

        Args:
            root (str): The root directory of the dataset.
            split (str, optional): The dataset split to use. Must be one of {'train', 'val', 'test'}.
                Defaults to 'val'.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        super().__init__()
        self.split = split
        self.transform = transform
        self.df = self._get_df(root, split)

        assert "image_path" in self.df
        if self.split != 'test':
            # Label bookkeeping only exists for labelled splits.
            assert "category_id" in self.df
            self.n_classes = len(self.df['category_id'].unique())
            self.category_id2label = {
                k: v[0] for k, v in self.df.groupby('category_id')['species'].unique().to_dict().items()
            }
            self.label2category_id = {
                v: k for k, v in self.category_id2label.items()
            }

    def add_embeddings(self, embeddings: pd.DataFrame):
        """
        Updates the dataset instance with new embeddings.

        The metadata frame is left-joined with ``embeddings`` on ``filename``; the
        'transformation' and 'embedding' columns come from ``embeddings``.

        Args:
            embeddings (pd.DataFrame): A DataFrame containing 'filename', 'transformation',
                                      and 'embedding' columns.

        Raises:
            AssertionError: If a required column is missing, or if any row with
                transformation == "original" has no embedding after the merge.
        """
        assert isinstance(embeddings, pd.DataFrame), "Embeddings must be a pandas DataFrame."
        assert "filename" in embeddings.columns, "Embeddings DataFrame must have a 'filename' column."
        assert "embedding" in embeddings.columns, "Embeddings DataFrame must have an 'embedding' column."
        assert "transformation" in embeddings.columns, "Embeddings DataFrame must have a 'transformation' column."

        # Join on filename only: 'transformation' is contributed by `embeddings`.
        self.df = pd.merge(self.df, embeddings, on=["filename"], how="left")

        # Make sure we have embeddings for at least the original images
        assert not self.df[self.df["transformation"] == "original"]["embedding"].isna().any(), \
            "Missing embeddings for some original images"

    def get_embeddings_for_class(self, id):
        """
        Returns the 'embedding' column restricted to rows whose category_id equals ``id``.

        Rows are selected with a boolean mask, so the result is correct for any
        index on ``self.df`` (index labels are never used as positions).
        """
        return self.df.loc[self.df['category_id'] == id, 'embedding']

    @staticmethod
    def _get_df(data_path: str, split: str) -> pd.DataFrame:
        """
        Loads the dataset metadata as a pandas DataFrame.

        Args:
            data_path (str): The root directory where the dataset is stored.
            split (str): The dataset split to load. Must be one of {'train', 'val', 'test'}.

        Returns:
            pd.DataFrame: A DataFrame containing metadata and file paths for the split.
        """
        df_path = os.path.join(
            data_path,
            "metadata",
            "FungiTastic-FewShot",
            f"FungiTastic-FewShot-{FungiTastic.SPLIT2STR[split]}.csv"
        )
        df = pd.read_csv(df_path)
        # TODO: switch from the '500p' folder to full-size images if the embedder can handle them
        df["image_path"] = df.filename.apply(
            lambda x: os.path.join(data_path, "FungiTastic-FewShot", split, '500p', x)
        )
        return df

    def __getitem__(self, idx: int):
        """
        Retrieves a single data sample by position.

        Args:
            idx (int): Positional index of the sample to retrieve.

        Returns:
            tuple: (image, category_id, file_path, emb) where
                - image is the PIL image, or the transform's output if a transform is set,
                - category_id is None for the 'test' split,
                - emb is a float32 tensor if ``self.df`` has an 'embedding' column, else None.
        """
        # Stored paths omit the 'images/' directory level; insert it here.
        file_path = self.df["image_path"].iloc[idx].replace('FungiTastic-FewShot', 'images/FungiTastic-FewShot')

        if self.split != 'test':
            category_id = self.df["category_id"].iloc[idx]
        else:
            category_id = None

        image = Image.open(file_path)

        if self.transform:
            image = self.transform(image)

        if "embedding" in self.df.columns:
            emb = torch.tensor(self.df["embedding"].iloc[idx], dtype=torch.float32).squeeze()
        else:
            emb = None  # No embeddings available

        return image, category_id, file_path, emb

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        return len(self.df)

    def get_class_id(self, idx: int) -> int:
        """
        Returns the class ID of a specific sample.
        """
        return self.df["category_id"].iloc[idx]

    def show_sample(self, idx: int) -> None:
        """
        Displays a sample image along with its class name and index.
        """
        image, category_id, _, _ = self.__getitem__(idx)
        class_name = self.category_id2label[category_id]

        plt.imshow(image)
        plt.title(f"Class: {class_name}; id: {idx}")
        plt.axis('off')
        plt.show()

    def get_category_idxs(self, category_id: int) -> List[int]:
        """
        Retrieves the index labels of all rows for a given category ID.
        """
        return self.df[self.df.category_id == category_id].index.tolist()

In [None]:
### Load the datasets

# One FungiTastic instance per split, none of them applying an image transform.
train_dataset, val_dataset, test_dataset = [
    FungiTastic(root=data_path, split=split_name, transform=None)
    for split_name in ('train', 'val', 'test')
]

# train_dataset.df.head(5)

In [None]:
# test_dataset.df.image_path.to_numpy()[0]

In [None]:
# test_dataset.df.head(20)

## Loading, saving, computing embeddings

In [None]:
# Experiment tag embedded in the filename of every cached artifact (CSV, .npy, config JSON).
exp_name = "multimodel_cache_Dinov2L_SAMH"

In [None]:
from pathlib import Path
import json
    
def save_artifacts(exp_name, train_dataset, val_dataset, test_dataset, config, overwrite=False):
    """
    Persist the dataframes, embeddings, embedder dimensions and config of an experiment.

    Files are written to the current working directory and are all suffixed with
    ``exp_name``: per-split CSVs, per-split embedding ``.npy`` files, the
    ``numpy_embed_dims_*.npy`` file and ``config_*.json``.

    Args:
        exp_name: Tag used in every artifact filename.
        train_dataset, val_dataset, test_dataset: Objects exposing a ``df`` DataFrame with
            an 'embedding' column; the test frame must also have an 'emb_dims' column.
        config: JSON-like dict. numpy arrays / numpy scalars inside it are converted to
            plain Python values on write.
        overwrite: When False, refuse to run if the embed-dims file already exists.

    Raises:
        FileExistsError: If artifacts exist and ``overwrite`` is False.
    """
    def _json_default(value):
        # The reload cell stores a numpy array under config["emb_dims"];
        # json cannot serialise numpy types without help.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    dims_file = Path(f"numpy_embed_dims_{exp_name}.npy")
    if dims_file.exists() and not overwrite:
        raise FileExistsError("overwrite is False and artifacts exist.")

    # The embedder dimensions are taken from the first test row.
    embed_dims = test_dataset.df.emb_dims.iloc[0]
    np.save(dims_file, embed_dims)

    splits = (("train", train_dataset), ("val", val_dataset), ("test", test_dataset))
    for split_name, dataset in splits:
        dataset.df.to_csv(f"{split_name}_df_{exp_name}.csv", index=None)
        # Embeddings are stored separately because the CSV only keeps their string repr.
        np.save(f"{split_name}_numpy_embedding_{exp_name}.npy", dataset.df.embedding.to_numpy())

    with open(f"config_{exp_name}.json", "w") as f:
        json.dump(config, f, sort_keys=True, indent=4, default=_json_default)

def load_artifacts(exp_name):
    """
    Load the cached train/val/test dataframes of an experiment from the working directory.

    For each split the CSV is read and three columns are (re)populated from the
    ``.npy`` side files:

    * ``embedding`` - the saved embedding objects (the CSV only holds their string form),
    * ``emb_dims``  - the embedder dimensions, under the column name that
      ``save_artifacts`` and the ensemble loop read,
    * ``embed_dims`` - the same values, kept under the name this function used before.

    Note: the ``.npy`` files are read with ``allow_pickle=True``; only load artifacts
    you produced yourself.

    Args:
        exp_name: Tag used in the artifact filenames.

    Returns:
        tuple: (train_df, val_df, test_df)
    """
    embed_dims = np.load(f"numpy_embed_dims_{exp_name}.npy", allow_pickle=True)

    def _load_split(split_name):
        split_df = pd.read_csv(f"{split_name}_df_{exp_name}.csv")
        # Every row carries the same dims object; fill an object array element-wise
        # so the value is stored as-is instead of being broadcast.
        per_row_dims = np.empty(len(split_df), dtype=object)
        for row in range(len(split_df)):
            per_row_dims[row] = embed_dims
        split_df["embed_dims"] = per_row_dims
        split_df["emb_dims"] = per_row_dims
        split_df["embedding"] = np.load(
            f"{split_name}_numpy_embedding_{exp_name}.npy", allow_pickle=True
        )
        return split_df

    return _load_split("train"), _load_split("val"), _load_split("test")

In [None]:
# Rebuild the three split datasets, then swap in the cached dataframes (with embeddings).
train_dataset, val_dataset, test_dataset = [
    FungiTastic(root=data_path, split=split_name, transform=None)
    for split_name in ('train', 'val', 'test')
]
train_dataset.df, val_dataset.df, test_dataset.df = load_artifacts(exp_name)

# Snapshots that reset_dfs() copies back to undo later filtering / slicing.
train_dataset.df_bak = train_dataset.df.copy()
val_dataset.df_bak = val_dataset.df.copy()
test_dataset.df_bak = test_dataset.df.copy()

# Experiment config, with the embedder dimensions replaced by the values from the .npy file.
with open(f"config_{exp_name}.json", 'r') as file:
    config = json.load(file)
embed_dims = np.load(f"numpy_embed_dims_{exp_name}.npy", allow_pickle=True)
config["emb_dims"] = embed_dims

## Manipulating the datasets (reload to clear manipulations)

In [None]:
def reset_dfs(dataset_list):
    """Restore each dataset's ``df`` from a fresh copy of its ``df_bak`` snapshot."""
    for ds in dataset_list:
        pristine = ds.df_bak.copy()
        ds.df = pristine

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import cross_entropy, normalize, softmax
import copy
import math

class EmbeddingDataset(Dataset):
    """
    Torch dataset over a dataframe with 'embedding' and 'category_id' columns.

    Each item is ``(embedding, label)``: the embedding converted to a squeezed
    float32 tensor and the category id as a long tensor.
    """

    def __init__(self, dataframe):
        self.df = dataframe
        # Cache the raw column arrays so item access does not go through pandas.
        self.labels = dataframe['category_id'].values
        self.embeddings = dataframe['embedding'].values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        raw_embedding = self.embeddings[idx]
        raw_label = self.labels[idx]
        return (
            torch.tensor(raw_embedding, dtype=torch.float32).squeeze(),
            torch.tensor(raw_label, dtype=torch.long),
        )

In [None]:
class ProjectionModel(torch.nn.Module):
    """
    MLP that projects (concatenated) embedder outputs onto an L2-normalised space.

    Without attention the input is L2-normalised and passed through the MLP. With
    ``use_attention=True`` the input is split into one slice per embedder (widths given
    by ``embedder_dims``), each slice is projected to ``attention_dim`` and the slices
    are combined with softmax-normalised learnable weights before the MLP.

    Args:
        input_dim: Width of the input embedding (used when ``use_attention`` is False).
        embedder_dims: Per-embedder widths; only read when ``use_attention`` is True.
        projection_dim: Width of the output embedding.
        use_layernorm: Add LayerNorm after the first hidden layer and after the output layer.
        use_dropout: Add dropout after the first hidden layer.
        dropout_rate: Dropout probability.
        internal_dim: Hidden width feeding the output layer.
        use_attention: Enable the weighted per-embedder combination.
        attention_dim: Common width the per-embedder slices are projected to.
        extra_layer: Insert an additional hidden layer (first hidden width becomes
            ``2 * internal_dim``).
    """

    def __init__(self, input_dim, embedder_dims, projection_dim=512, use_layernorm=False, use_dropout=False, dropout_rate=0.1, 
                 internal_dim=1024, use_attention=False, attention_dim=512, extra_layer=False):
        super().__init__()
        self.use_attention = use_attention
        self.num_embedders = len(embedder_dims) if use_attention else None
        # Plain ints so numpy integer dims work for both layer construction and slicing.
        self.embedder_dims = [int(d) for d in embedder_dims] if use_attention else None
        self.fixed_dim = attention_dim

        if use_attention:
            self.attn_weights = torch.nn.Parameter(torch.ones(self.num_embedders))
            # One projection per embedder so differently sized slices can be stacked.
            self.attn_projections = torch.nn.ModuleList([
                torch.nn.Sequential(torch.nn.Linear(emb_dim, self.fixed_dim), torch.nn.ReLU())
                for emb_dim in self.embedder_dims
            ])

        # Output width of the first hidden layer; doubled when the extra layer is used.
        first_hidden_dim = internal_dim * 2 if extra_layer else internal_dim
        layers = [
            torch.nn.Linear(self.fixed_dim if use_attention else input_dim, first_hidden_dim),
            torch.nn.ReLU()
        ]

        if use_layernorm:
            # Must match the first layer's output width, not internal_dim.
            layers.append(torch.nn.LayerNorm(first_hidden_dim))
        if use_dropout:
            layers.append(torch.nn.Dropout(dropout_rate))

        if extra_layer:
            layers.extend([torch.nn.Linear(internal_dim * 2, internal_dim), torch.nn.ReLU()])

        layers.append(torch.nn.Linear(internal_dim, projection_dim))

        if use_layernorm:
            layers.append(torch.nn.LayerNorm(projection_dim))

        self.projection = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` and return the result L2-normalised along the last dimension."""
        if self.use_attention:
            if x.dim() == 3 and x.shape[1] == 1:
                x = x.squeeze(1)  # Squeeze out the second dimension
            # Split the concatenated input back into per-embedder slices.
            start = 0
            embeddings = []
            for emb_dim, proj_layer in zip(self.embedder_dims, self.attn_projections):
                embedding = x[:, start:start + emb_dim]
                embeddings.append(proj_layer(embedding))
                start += emb_dim
            weights = softmax(self.attn_weights, dim=0)
            # Weighted sum of the projected slices.
            x = torch.stack([w * e for w, e in zip(weights, embeddings)], dim=0).sum(dim=0)
            x = self.projection(x)
        else:
            x = self.projection(normalize(x, p=2, dim=-1))
        return normalize(x, p=2, dim=-1)

class CosineClassifier(torch.nn.Module):
    """
    Cosine-similarity classification head used to train the ProjectionModel.

    Logits are ``scale`` times the cosine similarity between the input and one
    learnable weight vector per class. ``scale`` is stored as a plain attribute.
    """

    def __init__(self, embed_dim, num_classes, scale=10.0):
        super().__init__()
        self.scale = scale
        # Allocated with randn, then re-initialised in place with Kaiming-uniform.
        self.weight = torch.nn.Parameter(torch.randn(num_classes, embed_dim))
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        """Return scaled cosine-similarity logits for ``x`` against every class weight."""
        unit_x = normalize(x, p=2, dim=-1)
        unit_w = normalize(self.weight, p=2, dim=-1)
        cosine = torch.matmul(unit_x, unit_w.T)
        return self.scale * cosine

In [None]:
import torch
import copy
import torch.nn.functional as F
from pytorch_metric_learning.losses import NTXentLoss


def train(model, classifier, train_loader, val_loader, num_epochs=300, patience=5, lr=1e-5, device='cuda'):
    """Train ``model`` and ``classifier`` by delegating to ``_train_ce_infonce`` with its default loss weighting."""
    return _train_ce_infonce(
        model=model,
        classifier=classifier,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=num_epochs,
        patience=patience,
        lr=lr,
        device=device,
    )

class LearnableLossWeighting(torch.nn.Module):
    """
    Combines two losses with learnable log-variance weights.

    Each loss is scaled by ``exp(-log_sigma)`` and the log-sigmas are added as a
    penalty; the total is halved (formulation from Kendall et al., CVPR 2018).
    """

    def __init__(self):
        super().__init__()
        # Both log variances start at zero, i.e. each loss initially has weight 1.
        self.log_sigma_ce = torch.nn.Parameter(torch.tensor(0.0))
        self.log_sigma_triplet = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self, ce_loss, triplet_loss):
        """Return the weighted combination of ``ce_loss`` and ``triplet_loss``."""
        weighted_ce = torch.exp(-self.log_sigma_ce) * ce_loss
        weighted_triplet = torch.exp(-self.log_sigma_triplet) * triplet_loss
        combined = weighted_ce + weighted_triplet
        combined = combined + self.log_sigma_ce
        combined = combined + self.log_sigma_triplet
        return 0.5 * combined

def _train_ce_infonce(
    model,
    classifier,
    train_loader,
    val_loader,
    num_epochs,
    patience,
    lr,
    device,
    lambda_triplet=None
):
    model.to(device)
    classifier.to(device)
    if lambda_triplet == "learned" or lambda_triplet is None:
        loss_weighter = LearnableLossWeighting().to(device)
        optimizer = torch.optim.AdamW(
            list(model.parameters()) +
            list(classifier.parameters()) +
            list(loss_weighter.parameters()), 
            lr=lr, weight_decay=1e-4
        )
    elif isinstance(lambda_triplet, float) or isinstance(lambda_triplet, int):
        optimizer = torch.optim.AdamW(
            list(model.parameters()) + list(classifier.parameters()), 
            lr=lr, 
            weight_decay=1e-4
        )

    infonce_loss_func = NTXentLoss(temperature=0.07).to(device)

    best_val_loss = float('inf')
    best_model_state = None
    best_classifier_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        classifier.train()
        total_loss = 0.0
        ce_loss_total = 0.0
        triplet_loss_total = 0.0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            embeddings = model(x)
            logits = classifier(embeddings)

            # Cross-entropy
            ce_loss = F.cross_entropy(logits, y)

            infonce_loss = infonce_loss_func(embeddings, y)

            # Combined loss
            if lambda_triplet == "learned" or lambda_triplet is None:
                loss = loss_weighter(ce_loss, infonce_loss)
            elif isinstance(lambda_triplet, float) or isinstance(lambda_triplet, int):
                loss = ce_loss + lambda_triplet * infonce_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            ce_loss_total += ce_loss.item()
            triplet_loss_total += infonce_loss.item()

        # Validation
        model.eval()
        classifier.eval()
        val_loss = 0.0

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                embeddings = model(x)
                logits = classifier(embeddings)
                ce_loss = F.cross_entropy(logits, y)

                infonce_loss = infonce_loss_func(embeddings, y)

                if lambda_triplet == "learned" or lambda_triplet is None:
                    val_loss += (loss_weighter(ce_loss, infonce_loss)).item()
                elif isinstance(lambda_triplet, float) or isinstance(lambda_triplet, int):
                    val_loss += (ce_loss + lambda_triplet * infonce_loss).item()

        avg_val_loss = val_loss / len(val_loader)
        print(
            f"Epoch {epoch+1:3d} | "
            f"Train CE: {ce_loss_total:.4f} | "
            f"InfoNCE: {triplet_loss_total:.4f} | "
            f"Total: {total_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            best_classifier_state = copy.deepcopy(classifier.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break

    model.load_state_dict(best_model_state)
    classifier.load_state_dict(best_classifier_state)
    return model, classifier

## Fit projection network

## Top-5 Accuracy against validation set

In [None]:
class PrototypeClassifier(torch.nn.Module):
    """
    Nearest-prototype classifier in the projected embedding space.

    At construction every class's training embeddings are projected with
    ``projection_model`` and averaged into one prototype per class. Queries are
    projected the same way and scored by cosine similarity against the prototypes.

    Args:
        train_dataset: Object exposing ``n_classes`` and ``get_embeddings_for_class(cls)``;
            classes are enumerated as ``range(n_classes)``.
        projection_model: Module mapping raw embeddings to the projected space.
        device: Device used for the projection model and the prototypes.
    """

    def __init__(self, train_dataset, projection_model, device='cuda'):
        super().__init__()
        self.device = device
        self.train_dataset = train_dataset
        self.projection_model = projection_model.to(self.device)
        self.projection_model.eval()

        class_embeddings, _ = self._get_classifier_embeddings(train_dataset)
        print("class embeddings shape before projection:", class_embeddings[0].shape)
        # Prototypes are fixed, so do not build an autograd graph while projecting.
        with torch.no_grad():
            self.class_embeddings = [
                self.projection_model(class_embedding.to(device)) for class_embedding in class_embeddings
            ]
        print("class embeddings shape after projection:", self.class_embeddings[0].shape)

        self.class_prototypes = torch.nn.Parameter(self.get_mean_prototypes(self.class_embeddings), requires_grad=False)

        print("prototypes shape:", self.class_prototypes.shape)

    def _get_classifier_embeddings(self, dataset_train):
        """
        Collect one float32 tensor of stacked embeddings per class.

        Classes without any embedding get a single all-zero row whose width is taken
        from the non-empty classes (or from ``dataset_train.emb_dim`` if it exists).

        Returns:
            tuple: (list of per-class tensors, list of class ids that had no embeddings)

        Raises:
            ValueError: If an empty class exists and the embedding width cannot be determined.
        """
        class_embeddings = []
        empty_classes = []
        emb_dim = None
        for cls in range(dataset_train.n_classes):
            cls_embs = dataset_train.get_embeddings_for_class(cls)
            if len(cls_embs) == 0:
                empty_classes.append(cls)
                class_embeddings.append(None)  # filled in once the width is known
            else:
                # float32 matches the dtype used for query embeddings elsewhere.
                stacked = torch.tensor(np.vstack(cls_embs.values), dtype=torch.float32)
                if emb_dim is None:
                    emb_dim = stacked.shape[-1]
                class_embeddings.append(stacked)

        if empty_classes:
            if emb_dim is None:
                emb_dim = getattr(dataset_train, "emb_dim", None)
            if emb_dim is None:
                raise ValueError("Cannot infer the embedding width: no class has embeddings.")
            class_embeddings = [
                torch.zeros(1, emb_dim) if embs is None else embs for embs in class_embeddings
            ]
        return class_embeddings, empty_classes

    def get_mean_prototypes(self, embeddings):
        """Stack the mean embedding of every class into a (n_classes, dim) tensor."""
        return torch.stack([class_embs.mean(dim=0) for class_embs in embeddings])

    @torch.no_grad()
    def make_prediction(self, embeddings):
        """
        Score query embeddings against the class prototypes.

        Returns:
            CPU tensor of softmax-normalised cosine similarities, one row per query.
        """
        embeddings = embeddings.to(self.device)
        embeddings = self.projection_model(embeddings)
        # Add a singleton axis so each query broadcasts against all prototypes.
        if embeddings.dim() == 2 and embeddings.shape[1] != 1:
            embeddings = embeddings.unsqueeze(1)

        embeddings = normalize(embeddings, p=2, dim=-1)

        similarities = torch.nn.functional.cosine_similarity(embeddings, self.class_prototypes, dim=-1)
        probas = torch.nn.functional.softmax(similarities, dim=1).detach().cpu()
        return probas

In [None]:
# Sanity check: shape of the embedding stored under index label 0 of the validation frame.
val_dataset.df.embedding[0].shape

In [None]:
# Sanity check: (rows, columns) of the validation dataframe.
val_dataset.df.shape

In [None]:
def get_val_probas(train_dataset, eval_dataset, format_for_submission=False, batch_size=10, reduction="mean", projection_model=None):
    """
    Compute class probabilities per observation with a prototype classifier.

    All embeddings that share an ``observationID`` in ``eval_dataset.df`` are reduced
    to one vector (mean or median), which is then classified against prototypes built
    from ``train_dataset``.

    Args:
        train_dataset: Dataset used to build the class prototypes.
        eval_dataset: Object whose ``df`` has 'observationID' and 'embedding' columns.
        format_for_submission: Unused; kept so existing calls keep working.
        batch_size: Number of observations classified per batch.
        reduction: "mean" or "median".
        projection_model: Projection network; when None the notebook-level
            ``fitted_embedding_projection_model`` is used.

    Returns:
        dict: observationID -> 1-D probability tensor over the classes.

    Raises:
        ValueError: If ``reduction`` is not "mean" or "median".
    """
    reducers = {"mean": np.mean, "median": np.median}
    if reduction not in reducers:
        raise ValueError(f"reduction must be 'mean' or 'median'; got {reduction!r}")
    reduce_fn = reducers[reduction]

    if projection_model is None:
        projection_model = fitted_embedding_projection_model

    classifier = PrototypeClassifier(train_dataset, projection_model=projection_model, device='cpu')

    print(f"Class prototypes shape: {classifier.class_prototypes.shape}")

    proba_accumulator = {}

    # Process in batches of observations
    unique_observation_ids = eval_dataset.df["observationID"].unique()
    for i in range(0, len(unique_observation_ids), batch_size):
        batch_ids = unique_observation_ids[i:i+batch_size]
        batch_data = eval_dataset.df[eval_dataset.df["observationID"].isin(batch_ids)]

        batch_embeddings = []
        batch_observation_ids = []

        for observation_id, group in batch_data.groupby("observationID", sort=False):
            # Flatten each embedding to 1-D and stack to (n_rows, dim). Stacking keeps
            # the row axis even for a single-row observation, so the reduction below
            # always runs over rows and never over features.
            stacked = np.stack([np.asarray(emb).reshape(-1) for emb in group["embedding"]])
            batch_embeddings.append(reduce_fn(stacked, axis=0))
            batch_observation_ids.append(observation_id)

        # Convert to tensor and make predictions
        avg_embeddings = torch.tensor(np.array(batch_embeddings), dtype=torch.float32)
        probas = classifier.make_prediction(avg_embeddings)

        for observation_id, proba in zip(batch_observation_ids, probas):
            proba_accumulator[observation_id] = proba.clone()

        # Free memory
        del avg_embeddings, batch_embeddings, batch_data

    return proba_accumulator

In [None]:
@torch.no_grad()
def get_submission_from_summed_probas(eval_dataset, proba_accumulator, format_for_submission=False, k=5):
    """
    Turn accumulated per-observation probabilities into top-k predictions.

    Side effect: a 'preds' column (array of the top-k class indices for the row's
    observation) is added to ``eval_dataset.df``.

    Args:
        eval_dataset: Object whose ``df`` has an 'observationID' column.
        proba_accumulator: Mapping observationID -> 1-D probability (or score) tensor.
        format_for_submission: When True, keep only 'observationID' and 'preds' and
            render 'preds' as a space-separated string.
        k: Number of top classes to keep per observation.

    Returns:
        pd.DataFrame: One row per observationID (first occurrence kept).
    """
    # Indices of the k largest scores for every observation.
    final_predictions = {
        observation_id: torch.topk(proba, k=k).indices.numpy()
        for observation_id, proba in proba_accumulator.items()
    }

    # Map predictions back to eval dataset
    eval_dataset.df["preds"] = eval_dataset.df["observationID"].map(final_predictions)

    submission = eval_dataset.df.drop_duplicates(subset="observationID").copy()

    if format_for_submission:
        # Explicit copy: assigning into a column-sliced view would be chained assignment.
        submission = submission[["observationID", "preds"]].copy()
        submission['preds'] = submission['preds'].apply(lambda x: ' '.join(map(str, x)))

    return submission

In [None]:
def multi_hot_encode(data, num_classes):
    """
    Build a multi-hot matrix from per-instance category indices.

    Args:
        data: Sequence of sequences; each inner sequence lists the category indices
            present in one instance.
        num_classes: Number of columns of the result.

    Returns:
        Integer numpy array of shape (len(data), num_classes) with a 1 wherever the
        instance lists that category. Indices outside ``[0, num_classes)`` are ignored.
    """
    encoded = np.zeros((len(data), num_classes), dtype=int)
    for row, categories in zip(encoded, data):
        # `row` is a view into `encoded`, so writes land in the result matrix.
        for category in categories:
            if 0 <= category < num_classes:
                row[category] = 1
    return encoded

In [None]:
# Display the loaded experiment configuration.
config

In [None]:
# Embedder dimensions as stored under config["emb_dims"]; displayed for inspection.
embedder_dims = config["emb_dims"]
embedder_dims

In [None]:
# Width of each model's slice, keyed by model name (same order as in the config).
model_dims = dict(zip(config['models'], config['emb_dims']))

# Offset at which each model's slice starts: running total of the preceding widths.
start_indices = {}
cumulative_dim = 0
for model_name, dim in model_dims.items():
    start_indices[model_name] = cumulative_dim
    cumulative_dim = cumulative_dim + dim

In [None]:
# Display the start offset computed for each model.
start_indices

## Fit the projection model and evaluate the validation top5 accuracy

In [None]:
def slice_embedding(df, models):
    """
    Keep only the embedding slices that belong to ``models``.

    The 'embedding' column of ``df`` is overwritten in place with the concatenation
    of the selected models' slices (in the order given), and ``df`` is returned.
    Slice positions come from the notebook-level ``start_indices`` / ``model_dims``.

    Raises:
        ValueError: If any requested model is not present in ``model_dims``.
    """
    missing = [m for m in models if m not in model_dims]
    if missing:
        raise ValueError(f"Models not found in configuration: {missing}")

    # [start, end) bounds of every requested model's slice.
    keep_slices = [
        [start_indices[name], start_indices[name] + model_dims[name]]
        for name in models
    ]
    df["embedding"] = df["embedding"].apply(lambda emb: get_combined_embedding(emb, keep_slices))
    return df


def get_combined_embedding(emb, keep_slices):
    """Concatenate the ``[start, end)`` slices of ``emb``'s last axis, in the order given."""
    pieces = []
    for start, end in keep_slices:
        pieces.append(emb[..., start:end])
    return np.concatenate(pieces, axis=-1)

In [None]:
from collections import defaultdict

In [None]:
import numpy as np
import pandas as pd
import math

def split_probability(n_obs, max_prob=0.9, steepness=1.0, midpoint=5):
    """
    Logistic curve giving the probability of splitting a class with ``n_obs`` observations.

    The value is ``max_prob / 2`` at ``n_obs == midpoint`` and approaches ``max_prob``
    as ``n_obs`` grows; ``steepness`` controls how quickly.
    """
    exponent = -steepness * (n_obs - midpoint)
    return max_prob / (1 + math.exp(exponent))

def probabilistic_train_val_split(
    df: pd.DataFrame,
    class_col: str = "category_id",
    obs_col: str = "observationID",
    val_frac: float = 0.2,
    random_state: int | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    rng = np.random.default_rng(random_state)
    train_idx, val_idx = [], []

    for cls, g in df.groupby(class_col):
        obs_ids = g[obs_col].unique()
        n_obs = len(obs_ids)

        if n_obs == 1:
            train_idx.extend(g.index)
            continue

        p_split = split_probability(n_obs)

        if rng.random() > p_split:
            train_idx.extend(g.index)
            continue

        n_val_obs = max(1, int(round(n_obs * val_frac)))
        val_obs = rng.choice(obs_ids, size=n_val_obs, replace=False)

        val_mask = g[obs_col].isin(val_obs)
        val_idx.extend(g[val_mask].index)
        train_idx.extend(g[~val_mask].index)

    train_df = df.loc[train_idx].reset_index(drop=True)
    val_df = df.loc[val_idx].reset_index(drop=True)
    return train_df, val_df


In [None]:
# Illustrate how the split probability grows with n_obs: 0.45 at the default
# midpoint (5 observations), approaching the 0.9 cap for larger counts.

for i in range(45):
    print(i, split_probability(i))

In [None]:
best_individual_models = ['DINOv2Base@434', 'DINOv2Large@518', 'FungiTasticBEIT@384']
best_individual_augs = ['original', 'center_crop', 'top_left_crop', 'top_right_crop', 'bottom_left_crop', 'bottom_right_crop', 
                         'horizontal_flip', 'rot_90', 'rot_270', 'rot_15', 'rot_345']

# Flags read by the progress prints below and by the final-submission cell. They were
# previously used without being defined in the notebook, so a fresh
# Restart & Run All stopped with a NameError on the first loop iteration.
use_tim = False
layernorm = False

# Ensemble of five projection models that share the same embedders and augmentations.
# Each member is trained on its own probabilistic train/val split
# (val_frac=0.1, whole observations held out via observationID).
ensmbl_idx = 20
model_combos = [
    (['FungiTasticBEIT@384', 'DINOv2Base@434', 'DINOv2Large@518', 'SAMViTH@1024'], best_individual_augs, "ce_infonce")  # best mean combo
    for _ in range(5)
]

list_of_projection_models = []
list_of_proba_dicts = []
for models, augs, loss_type in tqdm(model_combos):
    
    print(models, augs, loss_type, use_tim, layernorm)
    
    # Start every ensemble member from the unmodified cached dataframes.
    reset_dfs([train_dataset, val_dataset, test_dataset])
    
    # create augs subset and slice out the model embeddings subset
    train_dataset.df = train_dataset.df[train_dataset.df.transformation.isin(augs)].reset_index(drop=True)
    train_dataset.df = slice_embedding(train_dataset.df, models)
    val_dataset.df = val_dataset.df[val_dataset.df.transformation.isin(augs)].reset_index(drop=True)
    val_dataset.df = slice_embedding(val_dataset.df, models)
    test_dataset.df = test_dataset.df[test_dataset.df.transformation.isin(augs)].reset_index(drop=True)
    test_dataset.df = slice_embedding(test_dataset.df, models)
    
    input_dim = train_dataset.df.embedding[0].shape[-1]
    print("input dim before projection", input_dim)
    num_classes = train_dataset.df.category_id.nunique()
    print("num classes", num_classes)
    projection_embedder_dimension = 768
    print("projection embedder dim", projection_embedder_dimension)
    embedder_dims = train_dataset.df["emb_dims"].values[0]
    print("embedder dims", embedder_dims)

    # The projection model is fitted on a fresh split of train + val combined.
    train_val_df = pd.concat([train_dataset.df, val_dataset.df], ignore_index=True)
    train_df, val_df = probabilistic_train_val_split(
        train_val_df,
        class_col="category_id",
        obs_col="observationID",
        val_frac=0.1,
    )
    
    train_loader = DataLoader(EmbeddingDataset(train_df), batch_size=64, shuffle=True)
    val_loader = DataLoader(EmbeddingDataset(val_df), batch_size=64)
    
    model = ProjectionModel(input_dim=input_dim, embedder_dims=embedder_dims, projection_dim=projection_embedder_dimension, 
                            use_layernorm=layernorm,
                            use_dropout=False, dropout_rate=0.25,
                            use_attention=False, attention_dim=512,
                            internal_dim=2048, extra_layer=False)
    classifier = CosineClassifier(embed_dim=projection_embedder_dimension, num_classes=num_classes)
    
    fitted_embedding_projection_model, fitted_classifier = train(model, classifier, train_loader, val_loader)

    probas_dict = get_val_probas(train_dataset, val_dataset, format_for_submission=False, batch_size=10,
                                 reduction="mean", projection_model=fitted_embedding_projection_model)
    
    list_of_proba_dicts.append(probas_dict)
    list_of_projection_models.append(fitted_embedding_projection_model)

# Sum the per-observation probabilities over all ensemble members.
summed_probas = defaultdict(lambda: None)
for model_probas in list_of_proba_dicts:
    for fname, proba in model_probas.items():
        if summed_probas[fname] is None:
            summed_probas[fname] = proba.clone()
        else:
            summed_probas[fname] += proba
            
submission = get_submission_from_summed_probas(val_dataset, summed_probas)

# Top-5 accuracy: the predicted indices are multi-hot encoded and used as scores.
labels = train_dataset.df.sort_values("category_id")["category_id"].unique()
y_true = submission["category_id"].to_numpy()
y_pred = np.array(submission["preds"].to_list())
y_pred = multi_hot_encode(y_pred, len(labels))
accuracy = top_k_accuracy_score(y_true, y_pred, k=5, labels=labels)
row = pd.DataFrame([{"models": models, "augs": augs, "accuracy": accuracy}])
display(row)

datetimestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
PRED_FILENAME = f"ensmbl_{ensmbl_idx}-{datetimestr}-{accuracy:.3f}.csv"
print(PRED_FILENAME)

In [None]:
# Free memory between the validation run and the final-submission run:
# collect unreferenced Python objects first, then ask PyTorch to release
# the unused GPU memory held by its caching allocator.
# NOTE(review): the import sits mid-notebook; consider moving it to the top import cell.
import gc
gc.collect()
torch.cuda.empty_cache()

## Average the embeddings per observationID and classify each observation from its averaged embedding

In [None]:
def _mean_embedding_per_observation(df, sort=False):
    """Average the embeddings of every observation in ``df``.

    Args:
        df: frame with an ``observationID`` column and an ``embedding`` column
            holding one array-like per row.
        sort: forwarded to ``DataFrame.groupby``; ``False`` keeps the order in
            which observations first appear in ``df``.

    Returns:
        ``(observation_ids, mean_embeddings)``: two parallel lists. Every mean
        embedding is a 1-D array, also for observations with a single row.
    """
    observation_ids, mean_embeddings = [], []
    for observation_id, group in df.groupby("observationID", sort=sort):
        # Flatten each embedding to 1-D before stacking so the result is always
        # (n_rows, dim). A bare ``squeeze()`` also drops the row axis when an
        # observation has a single row, and the mean then collapses to a scalar.
        stacked = np.stack([np.asarray(emb).reshape(-1) for emb in group["embedding"]])
        mean_embeddings.append(stacked.mean(axis=0))
        observation_ids.append(observation_id)
    return observation_ids, mean_embeddings


# zip() silently truncates to the shorter input, which would pair model combos
# with the wrong projection models if one of the lists is stale.
if len(model_combos) != len(list_of_projection_models):
    raise ValueError(
        f"Expected one fitted projection model per model combo, got "
        f"{len(list_of_projection_models)} models for {len(model_combos)} combos."
    )

list_of_proba_dicts_final_submission = []
for (models, augs, loss_type), projection_model in tqdm(zip(model_combos, list_of_projection_models), total=len(model_combos)):

    print(models, augs, loss_type, use_tim, layernorm)

    reset_dfs([train_dataset, val_dataset, test_dataset])

    # Keep only the requested augmentations and slice out the embeddings of the selected models.
    train_dataset.df = train_dataset.df[train_dataset.df.transformation.isin(augs)].reset_index(drop=True)
    train_dataset.df = slice_embedding(train_dataset.df, models)
    val_dataset.df = val_dataset.df[val_dataset.df.transformation.isin(augs)].reset_index(drop=True)
    val_dataset.df = slice_embedding(val_dataset.df, models)
    test_dataset.df = test_dataset.df[test_dataset.df.transformation.isin(augs)].reset_index(drop=True)
    test_dataset.df = slice_embedding(test_dataset.df, models)

    # Combine train and val so the final prototypes are built from every labelled row.
    train_val_dataset = FungiTastic(root=data_path, split='train', transform=None)
    train_val_dataset.df = pd.concat([train_dataset.df, val_dataset.df], ignore_index=True)

    # With use_tim the classifier is additionally given the per-observation mean
    # test embeddings as queries (sorted by observationID, as before).
    if use_tim:
        _, query_embedding_list = _mean_embedding_per_observation(test_dataset.df, sort=True)
        query_embeddings = torch.tensor(np.stack(query_embedding_list), dtype=torch.float32)
    else:
        query_embeddings = None

    # Build the prototype classifier from the combined train+val frame.
    classifier = PrototypeClassifier(train_val_dataset, projection_model=projection_model, device='cpu',
                                     use_tim=use_tim, query_embeddings=query_embeddings)

    # Debug print the prototype shape
    print(f"Class prototypes shape: {classifier.class_prototypes.shape}")

    # One averaged embedding per observation, in first-appearance order.
    # Both lists stay global because the sanity-check cells below inspect them.
    filenames_list, avg_embeddings_list = _mean_embedding_per_observation(test_dataset.df, sort=False)

    # Convert to tensor
    avg_embeddings = torch.tensor(np.array(avg_embeddings_list), dtype=torch.float32)
    print(f"Averaged embeddings shape: {avg_embeddings.shape}")

    # Make predictions using the averaged embeddings
    probas = classifier.make_prediction(avg_embeddings)

    # Map every observation id to its own copy of the predicted probabilities.
    probas_dict = {fname: proba.clone() for fname, proba in zip(filenames_list, probas)}

    list_of_proba_dicts_final_submission.append(probas_dict)

# Ensemble by summation: add up the probabilities of all models per observation.
summed_probas = defaultdict(lambda: None)
for model_probas in list_of_proba_dicts_final_submission:
    for fname, proba in model_probas.items():
        if summed_probas[fname] is None:
            summed_probas[fname] = proba.clone()
        else:
            summed_probas[fname] += proba

submission = get_submission_from_summed_probas(test_dataset, summed_probas, format_for_submission=True, k=10)

submission = submission.rename(columns={"observationID": "observationId", "preds": "predictions"})
submission['observationId'] = submission['observationId'].astype('Int64')

# PRED_FILENAME already starts with "ensmbl_" when produced by the validation cell;
# only add the prefix when it is missing so the file is not named "ensmbl_ensmbl_...".
submission_filename = PRED_FILENAME if PRED_FILENAME.startswith("ensmbl_") else f"ensmbl_{PRED_FILENAME}"
submission.to_csv(submission_filename, index=None)
print(f"saved submission file as {submission_filename}")

display(submission.head())

In [None]:
# filename = "ensmbl_5_2025-05-19-13-16-53-0.659.csv"
# submission.to_csv(filename, index=None)
# display(submission.head()) 

In [None]:
# @torch.no_grad()
# def get_submission_from_summed_probas(eval_dataset, proba_accumulator, format_for_submission=False, k=5):
#     # Final prediction dictionary
#     final_predictions = {}
    
#     # For each observation, get the top-k classes from the accumulated probabilities
#     for fname, proba in proba_accumulator.items():
#         top5 = torch.topk(proba, k=k)
#         final_predictions[fname] = top5.indices.numpy()  # or top5.values if needed too
    
#     # Map predictions back to eval dataset
#     eval_dataset.df["preds"] = eval_dataset.df["observationID"].map(final_predictions)
    
#     submission = eval_dataset.df.copy()
#     submission = submission.drop_duplicates(subset="observationID")
    
#     if format_for_submission:
#         submission = submission[["observationID", "preds"]]
#         submission['preds'] = submission['preds'].apply(lambda x: ' '.join(map(str, x)))

#     return submission

## Sanity checks

In [None]:
# Number of distinct image filenames in the combined train+val frame.
train_val_dataset.df["filename"].nunique()

In [None]:
# Number of distinct image filenames in the (filtered) training frame.
train_dataset.df["filename"].nunique()

In [None]:
# Number of distinct image filenames in the (filtered) test frame.
test_dataset.df["filename"].nunique()

In [None]:
# Shape of the embedding stored in the test row labelled 0.
test_dataset.df["embedding"][0].shape

In [None]:
# Augmentations still present in the training frame after filtering.
train_dataset.df["transformation"].unique()

In [None]:
# Augmentations present in the combined train+val frame.
train_val_dataset.df["transformation"].unique()

In [None]:
# Augmentations still present in the test frame after filtering.
test_dataset.df["transformation"].unique()

In [None]:
# Column names of the training frame.
train_dataset.df.columns

In [None]:
# Preview the first rows of the test frame.
test_dataset.df.head()

In [None]:
# Number of averaged embeddings built by the final-submission cell (last loop iteration);
# compare with the distinct observationID count of the test frame.
len(avg_embeddings_list)

In [None]:
# Number of distinct observations in the test frame.
test_dataset.df["observationID"].nunique()

In [None]:
# NOTE(review): `top_10` is not assigned in any nearby cell -- confirm it is still
# defined earlier in the notebook, otherwise this cell fails under Restart & Run All.
top_10.indices