In [1]:
import sys
sys.path.append("../")

In [2]:
import copy
import math
import os
import os.path as osp
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import cross_entropy, normalize, softmax

import torch.nn.functional as F
from pytorch_metric_learning.losses import NTXentLoss

import faiss

from tqdm import tqdm

from src.data.catalog import FungiTasticCatalog, CUBCatalog, Catalog
from src.data.datasets.image_dataset import ImageDataset

from src.config import EMBEDDING_KEY_NAME, FILENAME_KEY_NAME, TRANSFORMATION_KEY_NAME, TARGET_KEY_NAME

# Turn on cuDNN's benchmark mode globally.
# NOTE(review): every model defined in this notebook is built from Linear/LayerNorm layers,
# so this flag likely has no measurable effect here — confirm before relying on it.
torch.backends.cudnn.benchmark = True

In [3]:
%load_ext autoreload
%autoreload 2

# Define parameters

In [4]:
# NOTE(review): absolute, machine-specific dataset location — adjust for your environment.
DATASET_ROOT = "/media/marek/disk/datasets/MPV"
# Root directory of the saved embedding files (sub-directories are joined onto it below);
# created up front if it does not exist yet.
OUTPUT_DIR = "../logs/embeddings"
os.makedirs(OUTPUT_DIR, exist_ok=True)


### Load precomputed embeddings

In [5]:
def load_df_with_embeddings(embedding_path: str):
    """Read one saved embedding file and return its contents as a DataFrame.

    The file is expected to be a ``torch.save`` payload whose
    ``"embeddings_data"`` entry maps the embedding, transformation, target and
    filename keys (see ``src.config``) to equally long sequences.

    Args:
        embedding_path: Path of the ``.pth`` file to read.

    Returns:
        A DataFrame with one row per stored embedding and the four columns
        named by the ``*_KEY_NAME`` constants.
    """
    print(f"Loading embeddings from {embedding_path}")
    # weights_only=False unpickles arbitrary Python objects — only load trusted files.
    payload = torch.load(embedding_path, map_location="cpu", weights_only=False)["embeddings_data"]

    # The embedding container is split into a list so every row holds its own array.
    columns = {EMBEDDING_KEY_NAME: list(payload[EMBEDDING_KEY_NAME])}
    for key in (TRANSFORMATION_KEY_NAME, TARGET_KEY_NAME, FILENAME_KEY_NAME):
        columns[key] = payload[key]

    return pd.DataFrame(columns)


In [8]:
output_sub_dir = osp.join(OUTPUT_DIR, "MPV-FewShot")

# File-name stems of the saved embeddings, one per backbone. Each stem is
# prefixed with "train_", "val_" and "test_" when loading.
load_file_names = [
    "facebook-dinov3-vit7b16-pretrain-lvd1689m_512",
    # "BAAI-EVA-CLIP-18B_224",
    "hf-hub_BVRA-swin_base_patch4_window12_384.in1k_ft_fungitastic_384_224",
    "hf-hub_BVRA-vit_base_patch16_224.in1k_ft_fungitastic_224_224",
    # "vit_pe_core_gigantic_patch14_448_448"
]

# Short names used as dictionary keys / column suffixes; must stay aligned
# one-to-one with `load_file_names`.
shortcuts = [
    "dinov3",
    # "evaclip",
    "bvra-swin",
    "bvra-vit",
    # "pe_core"
]

# zip() would silently truncate if an entry is (un)commented in only one list.
assert len(shortcuts) == len(load_file_names), "shortcuts and load_file_names must have the same length"

embedding_dfs = {}
for shortcut, load_file_name in zip(shortcuts, load_file_names):
    try:
        train_emb = load_df_with_embeddings(embedding_path=osp.join(output_sub_dir, f"train_{load_file_name}.pth"))
        val_emb = load_df_with_embeddings(embedding_path=osp.join(output_sub_dir, f"val_{load_file_name}.pth"))
        test_emb = load_df_with_embeddings(embedding_path=osp.join(output_sub_dir, f"test_{load_file_name}.pth"))
    except Exception as e:
        # Loading stays best-effort, but say which backbone is being skipped and why.
        print(f"Skipping '{shortcut}': could not load embeddings ({e!r})")
        continue

    embedding_df = pd.concat([train_emb, val_emb, test_emb])
    embedding_dfs[shortcut] = embedding_df

if not embedding_dfs:
    raise RuntimeError(f"No embeddings could be loaded from {output_sub_dir}")

# Later cells index `embedding_dfs` by every entry of `shortcuts`; keep both
# lists restricted to the backbones that actually loaded so they cannot KeyError.
_loaded_pairs = [(sc, name) for sc, name in zip(shortcuts, load_file_names) if sc in embedding_dfs]
shortcuts = [sc for sc, _ in _loaded_pairs]
load_file_names = [name for _, name in _loaded_pairs]

# Per-backbone embedding width, read from the first stored vector of each frame.
embedding_dimensions = {key: df[EMBEDDING_KEY_NAME].iloc[0].shape[0] for key, df in embedding_dfs.items()}


Loading embeddings from ../logs/embeddings/MPV-FewShot/train_facebook-dinov3-vit7b16-pretrain-lvd1689m_512.pth
Loading embeddings from ../logs/embeddings/MPV-FewShot/val_facebook-dinov3-vit7b16-pretrain-lvd1689m_512.pth
Loading embeddings from ../logs/embeddings/MPV-FewShot/test_facebook-dinov3-vit7b16-pretrain-lvd1689m_512.pth
Loading embeddings from ../logs/embeddings/MPV-FewShot/train_hf-hub_BVRA-swin_base_patch4_window12_384.in1k_ft_fungitastic_384_224.pth
Loading embeddings from ../logs/embeddings/MPV-FewShot/val_hf-hub_BVRA-swin_base_patch4_window12_384.in1k_ft_fungitastic_384_224.pth
Loading embeddings from ../logs/embeddings/MPV-FewShot/test_hf-hub_BVRA-swin_base_patch4_window12_384.in1k_ft_fungitastic_384_224.pth
Loading embeddings from ../logs/embeddings/MPV-FewShot/train_hf-hub_BVRA-vit_base_patch16_224.in1k_ft_fungitastic_224_224.pth
Loading embeddings from ../logs/embeddings/MPV-FewShot/val_hf-hub_BVRA-vit_base_patch16_224.in1k_ft_fungitastic_224_224.pth
Loading embeddings

In [9]:
import gc

# 1. Give each backbone's embedding column a unique name so the per-backbone
#    frames can be merged side by side.
# NOTE(review): this indexes embedding_dfs by every entry of `shortcuts`; a backbone
# that failed to load in the previous cell raises KeyError here.
# NOTE(review): the literal column names "embedding", "filename" and "transform" presumably
# equal EMBEDDING_KEY_NAME / FILENAME_KEY_NAME / TRANSFORMATION_KEY_NAME — confirm in src.config.
for sc in shortcuts:
    embedding_dfs[sc] = embedding_dfs[sc].rename(columns={"embedding": f"embedding_{sc}"})

# 2. Start with the first DF
merged = embedding_dfs[shortcuts[0]][["filename", "transform", f"embedding_{shortcuts[0]}"]]

# 3. Merge the rest iteratively using only necessary columns
for sc in shortcuts[1:]:
    df = embedding_dfs[sc][["filename", "transform", f"embedding_{sc}"]]
    merged = pd.merge(
        merged,
        df,
        on=["filename", "transform"],
        how="inner"  # keep only matching filename + transform
    )
    # Optional: free memory of df after merge
    del df
    gc.collect()

# 4. Concatenate embeddings row-wise without apply
embedding_cols = [f"embedding_{sc}" for sc in shortcuts]

# Preallocate an array for all concatenated embeddings
# (width = sum of the per-backbone vector lengths, taken from each frame's first row)
concat_embs = np.empty((len(merged), sum(embedding_dfs[sc][f"embedding_{sc}"].iloc[0].shape[0] for sc in shortcuts)), dtype=np.float32)

# Fill the preallocated array one backbone at a time, in `shortcuts` order, so the
# column layout of the fused vector follows the order of `shortcuts`.
start = 0
for sc in shortcuts:
    col_embs = np.stack(merged[f"embedding_{sc}"].values).astype(np.float32)
    end = start + col_embs.shape[1]
    concat_embs[:, start:end] = col_embs
    start = end
    # free memory: the per-backbone column is no longer needed once copied
    merged.drop(columns=[f"embedding_{sc}"], inplace=True)
    gc.collect()

# One row-vector per DataFrame row; `merged` now holds filename, transform and the fused embedding.
merged["embedding"] = list(concat_embs)  # final concatenated embedding


In [10]:
# Build the few-shot FungiTastic catalog from the local dataset root (download disabled).
catalog = FungiTasticCatalog(
    dataset_root=DATASET_ROOT,
    dataset_variant="fewshot",
    dataset_size="720p",
    download=False,
    keep_zip=False,
)

# Attach the fused embeddings to the catalog.
# NOTE(review): validate="one_to_many" presumably permits several embedding rows
# (one per transform) per image — confirm against Catalog.add_embeddings.
catalog.add_embeddings(merged, validate="one_to_many")
# NOTE(review): the name `df` was also used (and deleted) as a loop temporary in the
# merge cell; from here on it is the catalog metadata frame.
df = catalog.get_metadata()
# Partition the metadata by its "split" column.
train_df = df[df["split"] == "train"]
val_df = df[df["split"] == "val"]
test_df = df[df["split"] == "test"]


In [11]:
# Drop the notebook-level references to the large intermediate frames and trigger a
# collection. Cells above cannot be re-run after this without reloading the embeddings.
del merged
del embedding_dfs
gc.collect()

0

# FungiCLEF 2025
https://ceur-ws.org/Vol-4038/paper_239.pdf

In [12]:
class EmbeddingDataset(Dataset):
    """Map-style dataset serving ``(embedding, label)`` tensor pairs from a DataFrame.

    The embedding and target columns (named by ``EMBEDDING_KEY_NAME`` and
    ``TARGET_KEY_NAME``) are extracted once at construction time.
    """

    def __init__(self, dataframe):
        self.df = dataframe
        self.labels = dataframe[TARGET_KEY_NAME].values
        self.embeddings = dataframe[EMBEDDING_KEY_NAME].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        raw_embedding, raw_label = self.embeddings[idx], self.labels[idx]
        # squeeze() removes any singleton dimensions of the stored embedding.
        return (
            torch.tensor(raw_embedding, dtype=torch.float32).squeeze(),
            torch.tensor(raw_label, dtype=torch.long),
        )

In [13]:
class ProjectionModel(torch.nn.Module):
    """MLP head that maps (fused) backbone embeddings to an L2-normalised vector.

    Two input modes are supported:

    * ``use_attention=False``: the input is L2-normalised as a whole and fed
      to the MLP, whose first layer has ``input_dim`` inputs.
    * ``use_attention=True``: the input is split into per-embedder chunks of
      widths ``embedder_dims``; each chunk goes through its own
      ``Linear + ReLU`` to ``attention_dim`` features, the results are mixed
      with softmax-normalised learnable scalar weights, and the mixture is fed
      to the MLP.

    Args:
        input_dim: Width of the concatenated input (used when
            ``use_attention`` is False).
        embedder_dims: Per-embedder widths, in concatenation order (used when
            ``use_attention`` is True).
        projection_dim: Width of the output embedding.
        use_layernorm: Add LayerNorm after the first hidden layer and after
            the output layer.
        use_dropout: Add dropout after the first hidden layer.
        dropout_rate: Dropout probability.
        internal_dim: Hidden width of the MLP.
        use_attention: Enable the per-embedder weighted fusion described above.
        attention_dim: Common width the chunks are projected to.
        extra_layer: Insert an additional hidden layer; the first hidden layer
            then has width ``2 * internal_dim``.
    """

    def __init__(self, input_dim, embedder_dims, projection_dim=512, use_layernorm=False, use_dropout=False, dropout_rate=0.1,
                 internal_dim=1024, use_attention=False, attention_dim=512, extra_layer=False):
        super().__init__()
        self.use_attention = use_attention
        self.num_embedders = len(embedder_dims) if use_attention else None
        self.embedder_dims = embedder_dims if use_attention else None
        self.fixed_dim = attention_dim

        if use_attention:
            self.attn_weights = torch.nn.Parameter(torch.ones(self.num_embedders))
            # Project every chunk to the same width so they can be stacked and mixed.
            self.attn_projections = torch.nn.ModuleList([
                torch.nn.Sequential(torch.nn.Linear(emb_dim, self.fixed_dim), torch.nn.ReLU()) for emb_dim in embedder_dims
            ])

        # Width produced by the first hidden layer; doubled when the extra layer is used.
        first_hidden_dim = internal_dim * 2 if extra_layer else internal_dim

        layers = [
            torch.nn.Linear(self.fixed_dim if use_attention else input_dim, first_hidden_dim),
            torch.nn.ReLU()
        ]

        if use_layernorm:
            # Must match the first layer's actual output width (it used to be
            # hard-wired to internal_dim, which broke extra_layer=True).
            layers.append(torch.nn.LayerNorm(first_hidden_dim))
        if use_dropout:
            layers.append(torch.nn.Dropout(dropout_rate))

        if extra_layer:
            layers.extend([torch.nn.Linear(internal_dim * 2, internal_dim), torch.nn.ReLU()])

        layers.append(torch.nn.Linear(internal_dim, projection_dim))

        if use_layernorm:
            layers.append(torch.nn.LayerNorm(projection_dim))

        self.projection = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` and return the result L2-normalised along the last dim."""
        if self.use_attention:
            if x.dim() == 3 and x.shape[1] == 1:
                x = x.squeeze(1)  # drop a singleton middle dimension
            # Slice the concatenated input back into one chunk per embedder.
            start = 0
            embeddings = []
            for emb_dim, proj_layer in zip(self.embedder_dims, self.attn_projections):
                embeddings.append(proj_layer(x[:, start:start + emb_dim]))
                start += emb_dim
            # One scalar weight per embedder, normalised to sum to 1.
            weights = softmax(self.attn_weights, dim=0)
            x = torch.stack([w * e for w, e in zip(weights, embeddings)], dim=0).sum(dim=0)
            x = self.projection(x)
        else:
            x = self.projection(normalize(x, p=2, dim=-1))
        return normalize(x, p=2, dim=-1)


class CosineClassifier(torch.nn.Module):
    """Cosine-similarity classification head used to train the projection model.

    Logits are ``scale * cos(x, w_c)`` for each class weight vector ``w_c``.

    Args:
        embed_dim: Width of the incoming embeddings.
        num_classes: Number of class weight vectors.
        scale: Constant multiplier applied to the cosine similarities.
    """

    def __init__(self, embed_dim, num_classes, scale=10.0):
        super().__init__()
        # Allocate without initialising: kaiming_uniform_ below overwrites every entry.
        self.weight = torch.nn.Parameter(torch.empty(num_classes, embed_dim))
        self.scale = scale  # fixed logit scale (a plain Python number, not trained)
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        # x: [B, D] -> logits: [B, num_classes]
        x = normalize(x, p=2, dim=-1)
        w = normalize(self.weight, p=2, dim=-1)
        return self.scale * torch.matmul(x, w.T)

In [14]:

class MultiEmbedderProjection(torch.nn.Module):
    """Fuse concatenated multi-backbone embeddings into one normalised vector.

    Each backbone chunk is mapped to a shared hidden width, the mapped chunks
    are mixed with softmax-normalised scalar gates, and the mixture is passed
    through a small MLP before L2 normalisation.

    Args:
        embed_dims: Width of each backbone chunk, in concatenation order.
        projection_dim: Width of the returned embedding.
        hidden_dim: Shared width of the per-backbone branches; falls back to
            ``projection_dim`` when not given.
    """

    def __init__(self, embed_dims, projection_dim=512, hidden_dim=None):
        super().__init__()
        nn = torch.nn

        if not hidden_dim:
            hidden_dim = projection_dim
        self.num_embedders = len(embed_dims)

        # One Linear -> GELU -> LayerNorm branch per backbone, all ending at hidden_dim.
        branches = []
        for in_dim in embed_dims:
            branches.append(
                nn.Sequential(
                    nn.Linear(in_dim, hidden_dim),
                    nn.GELU(),
                    nn.LayerNorm(hidden_dim),
                )
            )
        self.per_embedder = nn.ModuleList(branches)

        # Scalar gate per backbone; zeros give a uniform mixture after softmax.
        self.gates = nn.Parameter(torch.zeros(self.num_embedders))

        # Output head mapping the fused vector to projection_dim.
        self.project = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, projection_dim),
            nn.LayerNorm(projection_dim),
        )

    def forward(self, x):
        """
        x is concatenated embeddings of shape [B, sum(embed_dims)]
        """
        # Chunk widths are recovered from the first Linear of every branch.
        widths = [branch[0].in_features for branch in self.per_embedder]
        pieces = x.split(widths, dim=-1)

        gate_weights = F.softmax(self.gates, dim=0)

        # Gate-weighted sum of the branch outputs.
        fused = 0
        for gate, branch, piece in zip(gate_weights, self.per_embedder, pieces):
            fused = fused + gate * branch(piece)

        return F.normalize(self.project(fused), dim=-1)


class FewShotProjection(torch.nn.Module):
    """Lightweight fusion head: one bias-free linear adapter per backbone.

    Every backbone chunk is mapped to ``projection_dim`` by a bias-free Linear
    followed by LayerNorm; the adapted chunks are mixed with softmax-normalised
    scalar gates, layer-normalised once more and L2-normalised.

    Args:
        embed_dims: Width of each backbone chunk, in concatenation order.
        projection_dim: Width of the returned embedding.
    """

    def __init__(self, embed_dims, projection_dim=512):
        super().__init__()
        nn = torch.nn

        self.num_embedders = len(embed_dims)

        # Per-backbone adapter: Linear (no bias) -> LayerNorm.
        adapters = []
        for in_dim in embed_dims:
            adapters.append(
                nn.Sequential(
                    nn.Linear(in_dim, projection_dim, bias=False),
                    nn.LayerNorm(projection_dim),
                )
            )
        self.per_embedder = nn.ModuleList(adapters)

        # Scalar gate per backbone; zeros give a uniform mixture after softmax.
        self.gates = nn.Parameter(torch.zeros(self.num_embedders))

        # Single LayerNorm applied to the fused vector (there is no MLP head).
        self.norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        # Chunk widths are recovered from the Linear of every adapter.
        widths = [adapter[0].in_features for adapter in self.per_embedder]
        pieces = x.split(widths, dim=-1)

        gate_weights = F.softmax(self.gates, dim=0)

        # Gate-weighted sum of the adapted chunks.
        fused = 0
        for gate, adapter, piece in zip(gate_weights, self.per_embedder, pieces):
            fused = fused + gate * adapter(piece)

        return F.normalize(self.norm(fused), dim=-1)


In [15]:
from src.criterion.classification import SeesawLossWithLogits


def train(model, classifier, train_loader, val_loader, num_epochs=300, patience=5, lr=1e-5, device='cuda'):
    """Train ``model`` and ``classifier`` with the combined CE + InfoNCE routine.

    Thin entry point around ``_train_ce_infonce``; the loss-weighting mode is
    left at that function's default.

    Returns:
        Whatever ``_train_ce_infonce`` returns (the ``(model, classifier)`` pair).
    """
    return _train_ce_infonce(
        model=model,
        classifier=classifier,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=num_epochs,
        patience=patience,
        lr=lr,
        device=device,
    )

class LearnableLossWeighting(torch.nn.Module):
    """Combine two losses with trainable log-scale weights.

    Computes ``0.5 * (exp(-a) * ce + exp(-b) * aux + a + b)`` where ``a`` and
    ``b`` are scalar parameters initialised to zero. The source credits this
    form to Kendall et al., CVPR 2018.
    """

    def __init__(self):
        super().__init__()
        # Scalar log-scale per loss term, both starting at 0 (i.e. weight 1).
        self.log_sigma_ce = torch.nn.Parameter(torch.zeros(()))
        self.log_sigma_triplet = torch.nn.Parameter(torch.zeros(()))

    def forward(self, ce_loss, triplet_loss):
        ce_term = torch.exp(-self.log_sigma_ce) * ce_loss
        aux_term = torch.exp(-self.log_sigma_triplet) * triplet_loss
        # Adding the raw log-scales penalises shrinking both weights towards zero.
        combined = ce_term + aux_term + self.log_sigma_ce + self.log_sigma_triplet
        return 0.5 * combined

def _train_ce_infonce(
    model,
    classifier,
    train_loader,
    val_loader,
    num_epochs,
    patience,
    lr,
    device,
    lambda_triplet=None
):
    model.to(device)
    classifier.to(device)
    if lambda_triplet == "learned" or lambda_triplet is None:
        loss_weighter = LearnableLossWeighting().to(device)
        optimizer = torch.optim.AdamW(
            list(model.parameters()) +
            list(classifier.parameters()) +
            list(loss_weighter.parameters()),
            lr=lr, weight_decay=1e-4
        )
    elif isinstance(lambda_triplet, float) or isinstance(lambda_triplet, int):
        optimizer = torch.optim.AdamW(
            list(model.parameters()) + list(classifier.parameters()),
            lr=lr,
            weight_decay=1e-4
        )

    infonce_loss_func = NTXentLoss(temperature=0.07).to(device)

    best_val_loss = float('inf')
    best_model_state = None
    best_classifier_state = None
    patience_counter = 0

    pbar = tqdm(range(num_epochs), total=num_epochs)

    # loss_fn = SeesawLossWithLogits(df=train_loader.dataset.df).to(device)
    for epoch in pbar:
        model.train()
        classifier.train()
        total_loss = 0.0
        ce_loss_total = 0.0
        triplet_loss_total = 0.0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            embeddings = model(x)
            logits = classifier(embeddings)

            # Cross-entropy
            ce_loss = F.cross_entropy(logits, y)
            # ce_loss = loss_fn(logits, y)

            infonce_loss = infonce_loss_func(embeddings, y)


            # Combined loss
            if lambda_triplet == "learned" or lambda_triplet is None:
                loss = loss_weighter(ce_loss, infonce_loss)
            elif isinstance(lambda_triplet, float) or isinstance(lambda_triplet, int):
                loss = ce_loss + lambda_triplet * infonce_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            ce_loss_total += ce_loss.item()
            triplet_loss_total += infonce_loss.item()

        # Validation
        model.eval()
        classifier.eval()
        val_loss = 0.0

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                embeddings = model(x)
                logits = classifier(embeddings)
                ce_loss = F.cross_entropy(logits, y)

                infonce_loss = infonce_loss_func(embeddings, y)

                if lambda_triplet == "learned" or lambda_triplet is None:
                    val_loss += (loss_weighter(ce_loss, infonce_loss)).item()
                elif isinstance(lambda_triplet, float) or isinstance(lambda_triplet, int):
                    val_loss += (ce_loss + lambda_triplet * infonce_loss).item()

        avg_val_loss = val_loss / len(val_loader)
        pbar.set_description(
            f"Epoch {epoch+1:3d} | "
            f"Train CE: {ce_loss_total:.4f} | "
            f"InfoNCE: {triplet_loss_total:.4f} | "
            f"Total: {total_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            best_classifier_state = copy.deepcopy(classifier.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping.")
                break

    model.load_state_dict(best_model_state)
    classifier.load_state_dict(best_classifier_state)
    return model, classifier

## Fit projection network

## Top-5 Accuracy against validation set

In [16]:
def get_embeddings_for_class(dataset: ImageDataset, cls):
    """Return the embedding column for all rows of ``dataset`` labelled ``cls``.

    Args:
        dataset: Object whose ``get_dataset()`` returns a DataFrame containing
            the target and embedding columns.
        cls: Class label to select.

    Returns:
        A Series of embeddings (empty when the class has no rows).
    """
    df = dataset.get_dataset()
    # Select with a boolean mask: feeding index *labels* to the positional
    # .iloc picks the wrong rows whenever the index is not 0..n-1 (e.g. after
    # filtering or concatenating frames).
    return df.loc[df[TARGET_KEY_NAME] == cls, EMBEDDING_KEY_NAME]


class PrototypeClassifier(torch.nn.Module):
    """Nearest-prototype classifier on top of a fixed projection model.

    At construction every training embedding is projected with
    ``projection_model`` and the per-class mean of the projected embeddings is
    stored as that class's prototype. Predictions are a softmax over the dot
    products between normalised projected queries and the prototypes.

    Args:
        train_dataset: Dataset whose ``get_dataset()`` frame provides the
            target and embedding columns used to build prototypes.
        projection_model: Module applied to embeddings; put in eval mode and
            moved to ``device``.
        device: Device used for projection and prediction.
    """

    def __init__(self, train_dataset, projection_model, device='cuda'):
        super().__init__()
        self.device = device
        self.train_dataset = train_dataset
        self.projection_model = projection_model.to(self.device)
        self.projection_model.eval()

        self.n_classes, self.emb_dim = None, None

        class_embeddings, _ = self._get_classifier_embeddings(train_dataset)

        print("class embeddings shape before projection:", class_embeddings[0].shape)
        # Prototypes are constants: project without building autograd graphs,
        # which would otherwise be kept alive for every training embedding.
        with torch.no_grad():
            self.class_embeddings = [self.projection_model(class_embedding.to(device)) for class_embedding in class_embeddings]
        print("class embeddings shape after projection:", self.class_embeddings[0].shape)

        self.class_prototypes = torch.nn.Parameter(self.get_mean_prototypes(self.class_embeddings), requires_grad=False)

        print("prototypes shape:", self.class_prototypes.shape)

    def _get_classifier_embeddings(self, dataset_train: "ImageDataset"):
        """Collect one ``[n_samples, dim]`` tensor per class id ``0..n_classes-1``.

        Returns:
            ``(class_embeddings, empty_classes)`` where ``empty_classes`` lists
            the class ids without any training sample (their entry is a single
            all-zero row).

        Raises:
            ValueError: If the training frame has no rows.
        """
        class_embeddings = []
        empty_classes = []
        targets = dataset_train.get_dataset()[TARGET_KEY_NAME]
        if len(targets) == 0:
            raise ValueError("Cannot build prototypes from an empty training dataset")
        # Prototype row i must correspond to class id i, so size by the largest
        # id rather than the number of distinct ids (which would drop the
        # highest classes whenever some id is missing from the training data).
        self.n_classes = int(targets.max()) + 1
        for cls in range(self.n_classes):
            cls_embs = get_embeddings_for_class(dataset_train, cls)
            if len(cls_embs) == 0:
                # if no embeddings for class, use zeros
                empty_classes.append(cls)
                class_embeddings.append(torch.zeros(1, dataset_train.emb_dim))
            else:
                # float32 to match the projection model's parameters.
                class_embeddings.append(torch.tensor(np.vstack(cls_embs.values), dtype=torch.float32))
        return class_embeddings, empty_classes

    def get_mean_prototypes(self, embeddings):
        """Stack the per-class mean of each tensor in ``embeddings`` into ``[C, D]``."""
        return torch.stack([class_embs.mean(dim=0) for class_embs in embeddings])

    @torch.no_grad()
    def make_prediction(self, embeddings, batch_size=2048):
        """Return class probabilities of shape ``[N, C]`` (on CPU) for ``embeddings``.

        Args:
            embeddings: Tensor of raw (unprojected) embeddings, one per row.
            batch_size: Number of rows processed per forward pass.
        """
        probas_list = []

        for i in range(0, len(embeddings), batch_size):
            batch = embeddings[i:i+batch_size].to(self.device)
            batch = self.projection_model(batch)
            batch = F.normalize(batch, dim=-1)

            sims = batch @ self.class_prototypes.T     # [B, C]
            probas = F.softmax(sims, dim=1).cpu()

            probas_list.append(probas)

        if not probas_list:
            # torch.cat rejects an empty list; return an empty [0, C] result instead.
            return torch.empty(0, self.class_prototypes.shape[0])
        return torch.cat(probas_list, dim=0)


In [17]:
def get_val_probas(train_dataset, eval_dataset, format_for_submission=False, batch_size=10, reduction="mean", projection_model=None):
    """Predict class probabilities for every observation of ``eval_dataset``.

    All embeddings sharing an ``observationID`` are reduced to one vector,
    which is then classified by a ``PrototypeClassifier`` (on CPU) built from
    ``train_dataset``.

    Args:
        train_dataset: Dataset used to build the class prototypes.
        eval_dataset: Dataset whose ``df`` holds ``observationID`` and the
            embedding column.
        format_for_submission: Unused; kept for call compatibility.
        batch_size: Number of observations classified per call.
        reduction: ``"mean"`` or ``"median"`` across an observation's embeddings.
        projection_model: Projection module to use. When ``None`` the
            notebook-level ``fitted_embedding_projection_model`` is used, so
            the training cell must have been run first.

    Returns:
        Dict mapping each observation id to its probability tensor, in order
        of first appearance in ``eval_dataset.df``.

    Raises:
        ValueError: If ``reduction`` is not supported.
    """
    reducers = {"mean": np.mean, "median": np.median}
    if reduction not in reducers:
        # Previously an unknown value left the reduced vector undefined (or
        # silently reused the one from the preceding observation).
        raise ValueError(f"reduction must be one of {sorted(reducers)}, got {reduction!r}")
    reduce_fn = reducers[reduction]

    if projection_model is None:
        projection_model = fitted_embedding_projection_model

    classifier = PrototypeClassifier(train_dataset, projection_model=projection_model, device='cpu')

    print(f"Class prototypes shape: {classifier.class_prototypes.shape}")

    # Reduce each observation once; a single groupby avoids rescanning the
    # whole frame for every batch.
    observation_ids = []
    observation_embeddings = []
    for observation_id, group in eval_dataset.df.groupby("observationID", sort=False):
        # vstack always yields a 2-D [n_rows, dim] array, so reducing over
        # axis 0 is also correct for observations with a single row (squeezing
        # first would collapse those to a scalar).
        stacked = np.vstack(group[EMBEDDING_KEY_NAME].to_numpy())
        observation_embeddings.append(reduce_fn(stacked, axis=0))
        observation_ids.append(observation_id)

    proba_accumulator = {}
    for start in range(0, len(observation_ids), batch_size):
        batch_ids = observation_ids[start:start + batch_size]
        batch = torch.tensor(np.array(observation_embeddings[start:start + batch_size]), dtype=torch.float32)
        probas = classifier.make_prediction(batch)

        for observation_id, proba in zip(batch_ids, probas):
            proba_accumulator[observation_id] = proba.clone()

    return proba_accumulator

In [18]:
@torch.no_grad()
def get_submission_from_summed_probas(eval_dataset, proba_accumulator, format_for_submission=False, k=5):
    """Turn accumulated per-observation probabilities into a top-k prediction table.

    Side effect: writes a ``"preds"`` column into ``eval_dataset.df``.

    Args:
        eval_dataset: Object whose ``df`` has an ``observationID`` column.
        proba_accumulator: Mapping from observation id to a 1-D tensor of class
            scores.
        format_for_submission: When True, keep only ``observationID`` and
            ``preds`` and render ``preds`` as a space-separated string.
        k: Number of top classes to keep; capped at the number of classes.

    Returns:
        A DataFrame with one row per observation and its ``preds``.
    """
    final_predictions = {}

    for fname, proba in proba_accumulator.items():
        # topk raises when k exceeds the number of entries, so cap it.
        top_k = torch.topk(proba, k=min(k, proba.shape[-1]))
        final_predictions[fname] = top_k.indices.cpu().numpy()

    # Map predictions back to every row of the eval frame.
    eval_dataset.df["preds"] = eval_dataset.df["observationID"].map(final_predictions)

    # One row per observation (the first occurrence is kept).
    submission = eval_dataset.df.drop_duplicates(subset="observationID").copy()

    if format_for_submission:
        # Copy the column subset so the assignment below targets an independent frame.
        submission = submission[["observationID", "preds"]].copy()
        submission['preds'] = submission['preds'].apply(lambda x: ' '.join(map(str, x)))

    return submission

## Fit the projection model and evaluate the validation top5 accuracy

In [19]:
import math


def split_probability(n_obs, max_prob=0.9, steepness=1.0, midpoint=5):
    """Logistic curve giving the probability that a class is split into train/val.

    The value rises with ``n_obs``, equals ``max_prob / 2`` at
    ``n_obs == midpoint`` and approaches ``max_prob`` for large ``n_obs``.

    Args:
        n_obs: Number of observations of the class.
        max_prob: Upper asymptote of the curve.
        steepness: Slope factor of the logistic.
        midpoint: Observation count at which the curve reaches half of ``max_prob``.

    Returns:
        The probability as a float in ``[0, max_prob]``.
    """
    try:
        return max_prob / (1 + math.exp(-steepness * (n_obs - midpoint)))
    except OverflowError:
        # exp() overflowed, i.e. the denominator is effectively infinite.
        return 0.0

def probabilistic_train_val_split(
        df: pd.DataFrame,
        class_col: str = TARGET_KEY_NAME,
        obs_col: str = "observationID",
        val_frac: float = 0.2,
        random_state: int | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into train/validation frames, class by class and by observation.

    For every class, a random draw against ``split_probability(n_obs)`` decides
    whether the class contributes to validation at all. If it does, roughly
    ``val_frac`` of its observations (at least one, but never all of them) go
    to validation together with all their rows. Classes with a single
    observation always stay in training.

    Args:
        df: Frame holding ``class_col`` and ``obs_col``.
        class_col: Column with the class label.
        obs_col: Column identifying the observation a row belongs to.
        val_frac: Target fraction of a class's observations sent to validation.
        random_state: Seed for the random generator; ``None`` gives a different
            split on every call.

    Returns:
        ``(train_df, val_df)``, both with a fresh 0..n-1 index.
    """
    rng = np.random.default_rng(random_state)

    # Work on unique positional labels: with repeated index labels,
    # .loc[list_of_labels] below would return duplicated rows.
    df = df.reset_index(drop=True)

    train_idx, val_idx = [], []

    for cls, g in df.groupby(class_col):
        obs_ids = g[obs_col].unique()
        n_obs = len(obs_ids)

        # A single observation cannot be shared between the two sides.
        if n_obs <= 1:
            train_idx.extend(g.index)
            continue

        p_split = split_probability(n_obs)

        if rng.random() > p_split:
            train_idx.extend(g.index)
            continue

        # At least one validation observation, but always leave one for training.
        n_val_obs = min(n_obs - 1, max(1, int(round(n_obs * val_frac))))
        val_obs = rng.choice(obs_ids, size=n_val_obs, replace=False)

        val_mask = g[obs_col].isin(val_obs)
        val_idx.extend(g[val_mask].index)
        train_idx.extend(g[~val_mask].index)

    train_df = df.loc[train_idx].reset_index(drop=True)
    val_df = df.loc[val_idx].reset_index(drop=True)
    return train_df, val_df


In [None]:
# Illustrate how the split probability grows with n_obs under the default parameters:
# it is 0.45 at the midpoint (n_obs = 5) and approaches the cap of 0.9 for larger counts.
for i in range(45):
    print(i, split_probability(i))

In [20]:
from sklearn.metrics import top_k_accuracy_score


# Wrap the catalog splits as datasets; no image transform is passed.
TRANSFORM = None
train_dataset = ImageDataset(df=train_df, transform=TRANSFORM)
val_dataset = ImageDataset(df=val_df, transform=TRANSFORM)
test_dataset = ImageDataset(df=test_df, transform=TRANSFORM)

# Width of the fused embedding, read from the first training row.
input_dim = train_dataset.df[EMBEDDING_KEY_NAME].iloc[0].shape[-1]
print("input dim before projection", input_dim)
num_classes = train_dataset.df[TARGET_KEY_NAME].nunique()
print("num classes", num_classes)
projection_embedder_dimension = 1024 # 768
print("projection embedder dim", projection_embedder_dimension)
# Per-backbone widths; only consumed by ProjectionModel when use_attention=True.
embedder_dims = list(embedding_dimensions.values())
print("embedder dims", embedder_dims)

# Pool the catalog's train and val rows and re-split them per class / observation.
# NOTE(review): `train_df` and `val_df` are rebound here, so re-running this cell
# re-splits frames that were already produced by a previous run of the cell.
# NOTE(review): no random_state is passed, so the split differs on every run.
train_val_df = pd.concat([train_dataset.df, val_dataset.df], ignore_index=True)
train_df, val_df = probabilistic_train_val_split(
    train_val_df,
    class_col=TARGET_KEY_NAME,
    obs_col="observationID",
    # obs_col="image_path",
    val_frac=0.1,
)

batch_size = 128
train_loader = DataLoader(EmbeddingDataset(train_df), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(EmbeddingDataset(val_df), batch_size=batch_size)

# Plain MLP head on the normalised fused embedding: layernorm, dropout, attention
# fusion and the extra hidden layer are all switched off.
projection_model = ProjectionModel(
    input_dim=input_dim, embedder_dims=embedder_dims, projection_dim=projection_embedder_dimension,
   use_layernorm=False,
   use_dropout=False, dropout_rate=0.25,
   use_attention=False, attention_dim=512,
   internal_dim=2048, extra_layer=False
)

classifier = CosineClassifier(embed_dim=projection_embedder_dimension, num_classes=num_classes)

# train() keeps its defaults for patience (5), lr (1e-5) and device ('cuda').
# It returns the same module objects it was given, restored to the best validation epoch.
fitted_embedding_projection_model, fitted_classifier = train(projection_model, classifier, train_loader, val_loader, num_epochs=100)
# submission = get_val_predictions(train_dataset, val_dataset, query_aware_prototypes=False, use_tim=use_tim)

# NOTE(review): the disabled validation block below references `multi_hot_encode`,
# which is not defined in the visible part of this notebook.
# probas_dict = get_val_probas(train_dataset, val_dataset, format_for_submission=False, batch_size=10,
#                              reduction="mean", projection_model=fitted_embedding_projection_model)
#
# # Create a defaultdict to accumulate summed probabilities
# summed_probas = defaultdict(lambda: None)
# # Loop over multiple models' probability outputs
# for fname, proba in probas_dict.items():
#     if summed_probas[fname] is None:
#         summed_probas[fname] = proba.clone()
#     else:
#         summed_probas[fname] += proba
#
# submission = get_submission_from_summed_probas(val_dataset, summed_probas)
#
# labels = train_dataset.df.sort_values(TARGET_KEY_NAME)[TARGET_KEY_NAME].unique()
# y_true = submission[TARGET_KEY_NAME].to_numpy()
# y_pred = np.array(submission["preds"].to_list())
# y_pred = multi_hot_encode(y_pred, len(labels))
# accuracy = top_k_accuracy_score(y_true, y_pred, k=5, labels=labels)
# print(accuracy)


input dim before projection 5888
num classes 2427
projection embedder dim 1024
embedder dims [4096, 1024, 768]


Epoch  33 | Train CE: 280.7676 | InfoNCE: 3.3047 | Total: 30.0815 | Val Loss: 1.6953:  32%|███▏      | 32/100 [03:25<07:16,  6.41s/it]     

Early stopping.





In [None]:
# torch.save does not create missing parent directories.
os.makedirs("./out", exist_ok=True)
# Persist the weights of the trained projection model and of the cosine classifier.
torch.save(fitted_embedding_projection_model.state_dict(), "./out/projection_model.pth")
torch.save(fitted_classifier.state_dict(), "./out/tmp-classifier.pth")

In [21]:
import gc
# Collect unreachable Python objects, then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

## Average the embeddings over the observationID and use that average embedding to make each classification

In [22]:
list_of_proba_dicts_final_submission = []

# Prototypes for the final submission use every labelled row (train + val).
# ignore_index=True gives the pooled frame unique 0..n-1 labels; both inputs
# were re-indexed from 0 by the split, so a plain concat would repeat labels.
train_val_dataset = ImageDataset(df=pd.concat([train_df, val_df], ignore_index=True), transform=TRANSFORM)

# `projection_model` is the same object that train() fitted and returned as
# `fitted_embedding_projection_model`.
# NOTE(review): this rebinds `classifier`, which previously named the CosineClassifier.
classifier = PrototypeClassifier(train_val_dataset, projection_model=projection_model, device='cuda',)


# One averaged embedding per test observation, in order of first appearance.
avg_embeddings_list = []
filenames_list = []

for filename, group in test_dataset.df.groupby("observationID", sort=False):
    # vstack always yields a 2-D [n_rows, dim] array, so averaging over axis 0
    # is also correct for observations with a single row (squeezing first
    # would collapse those to a scalar).
    embeddings_array = np.vstack(group[EMBEDDING_KEY_NAME].to_numpy())
    avg_embedding = np.mean(embeddings_array, axis=0)

    avg_embeddings_list.append(avg_embedding)
    filenames_list.append(filename)

# Convert to tensor
avg_embeddings = torch.tensor(np.array(avg_embeddings_list), dtype=torch.float32)
print(f"Averaged embeddings shape: {avg_embeddings.shape}")
print(f"Averaged embeddings shape: {avg_embeddings.shape}")

class embeddings shape before projection: torch.Size([44, 5888])
class embeddings shape after projection: torch.Size([44, 1024])
prototypes shape: torch.Size([2427, 1024])
Averaged embeddings shape: torch.Size([999, 5888])


In [None]:
pred_file_name = "./out/all_base-projection-1024-out_trainval_views_proto.csv"
# to_csv does not create missing parent directories.
os.makedirs(osp.dirname(pred_file_name), exist_ok=True)

# Make predictions using the averaged embeddings
probas = classifier.make_prediction(avg_embeddings, batch_size=256)

# Create a mapping from observation ids to their probability vectors
probas_dict = {}
for fname, proba in zip(filenames_list, probas):
    probas_dict[fname] = proba.clone()

# NOTE: the list is created in the previous cell, so re-running only this cell
# appends the same model's probabilities again.
list_of_proba_dicts_final_submission.append(probas_dict)

# Create a defaultdict to accumulate summed probabilities
summed_probas = defaultdict(lambda: None)
# Loop over multiple models' probability outputs
for model_probas in list_of_proba_dicts_final_submission:
    for fname, proba in model_probas.items():
        if summed_probas[fname] is None:
            summed_probas[fname] = proba.clone()
        else:
            summed_probas[fname] += proba

# Keep the 10 highest-scoring classes per observation, rendered as a space-separated string.
submission = get_submission_from_summed_probas(test_dataset, summed_probas, format_for_submission=True, k=10)

submission = submission.rename(columns={"observationID": "observationId", "preds": "predictions"})
submission['observationId'] = submission['observationId'].astype('Int64')
submission.to_csv(pred_file_name, index=False)
print(f"saved submission file as {pred_file_name}")

display(submission.head())