In [1]:
import os
import torch
import random
import logging
import numpy as np
import torch.nn as nn
import torchvision.transforms.v2 as tf

from tqdm import tqdm
from typing import Tuple, Union
from torch.nn.functional import normalize
from torchvision.datasets import Flowers102
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
from sklearn.neighbors import KNeighborsClassifier

from fungivision.wrapper import FUNGIWrapper
from fungivision.config import KLConfig, DINOConfig, SimCLRConfig

In [2]:
# Log INFO-level messages (e.g. the FUNGI wrapper's progress) as "[time:LEVEL]: message"
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s:%(levelname)s]: %(message)s",
)

## Utility Functions

We first define some utility functions in this section: a helper that seeds all random number generators, and the mean-per-class accuracy, which is the default evaluation metric for the Flowers102 dataset.

In [3]:
def seed_everything(seed: int):
    """
    Seeds the Python, NumPy and PyTorch random number generators and makes
    cuDNN deterministic, so that repeated runs produce the same results.

    Note that setting PYTHONHASHSEED at runtime does not change the hash
    randomization of the already-running interpreter; it only affects Python
    processes started afterwards from this environment.

    Args:
        seed (int): the seed applied to every generator
    """
    # Environment variable first: inherited by any Python process launched later
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python and NumPy global generators
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    # PyTorch generator, plus the generators of all CUDA devices
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [4]:
def mean_per_class_accuracy(preds: np.ndarray, targets: np.ndarray) -> float:
    """
    Calculates the mean per class accuracy: the accuracy (i.e. the recall) is
    computed separately for every class present in `targets`, and those
    per-class values are then averaged, so every class has the same weight
    regardless of how many samples it has. See the link below for more details:

    - https://stackoverflow.com/questions/39770376/scikit-learn-get-accuracy-scores-for-each-class

    Classes that occur only in `preds` (never in `targets`) are ignored: their
    accuracy would be 0/0, which would otherwise make the whole metric NaN.

    Args:
        preds (np.ndarray): the model predictions (any array-like, e.g. a CPU tensor)
        targets (np.ndarray): the ground truth targets, same length as `preds`

    Returns:
        float: the mean-per-class accuracy metric, in [0, 1]

    Raises:
        ValueError: if `preds` and `targets` differ in length, or are empty
    """
    preds = np.asarray(preds).reshape(-1)
    targets = np.asarray(targets).reshape(-1)

    if preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must have the same length, got {preds.shape[0]} and {targets.shape[0]}"
        )
    if targets.size == 0:
        raise ValueError("cannot compute the mean per-class accuracy of an empty set")

    # Accuracy restricted to the samples whose ground truth is `cls`
    per_class_accuracy = [
        np.mean(preds[targets == cls] == cls) for cls in np.unique(targets)
    ]

    return float(np.mean(per_class_accuracy))

## FUNGI

We first define functions to extract FUNGI features and standard model embeddings, and set the generic feature extraction parameters (batch size, dataset cache, ...). We then load a DINO (v1) ViT-B/16 model and extract FUNGI features for the KL, DINO and SimCLR objectives.

In [5]:
def _unzip_batch(batch):
    """
    Collate function that transposes a list of (image, target) samples into
    (all images, all targets) without stacking anything into tensors, so the
    images reach the FUNGI wrapper exactly as the dataset produced them.
    """
    return zip(*batch)


def extract_fungi_features(
    wrapper: FUNGIWrapper,
    dataset: Dataset,
    batch_size: int,
    num_workers: int = 18
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Extracts L2-normalized FUNGI features for every sample of `dataset`.

    Args:
        wrapper (FUNGIWrapper): the FUNGI feature extractor, on which setup()
            has already been called
        dataset (Dataset): dataset yielding (image, target) pairs; the images
            are handed to the wrapper as a tuple, without being collated
        batch_size (int): number of samples passed to the wrapper per call
        num_workers (int): number of DataLoader worker processes

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the features (float32, on the CPU,
        L2-normalized along the last dimension) and the targets, both in
        dataset order (the loader does not shuffle)
    """
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_unzip_batch,
    )

    feature_chunks = []
    target_chunks = []
    for batch_images, batch_labels in tqdm(loader):
        target_chunks.append(torch.tensor(batch_labels))
        feature_chunks.append(wrapper(batch_images).cpu().float())

    all_features = torch.cat(feature_chunks, dim=0)
    all_targets = torch.cat(target_chunks, dim=0)
    return normalize(all_features, dim=-1), all_targets

In [6]:
def extract_features(
    model: nn.Module,
    device: torch.device,
    dataset: Dataset,
    batch_size: int,
    num_workers: int = 18
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Extracts the model's forward outputs (embeddings) for every sample of `dataset`.

    The dataset is iterated in order (no shuffling), and batches are built by
    the DataLoader's default collation, so the dataset must yield
    (image tensor, target) pairs that can be stacked.

    Args:
        model (nn.Module): the backbone whose forward output is used as the embedding;
            presumably already on `device` — TODO confirm for other callers
        device (torch.device): device the image batches are moved to
        dataset (Dataset): dataset yielding (image tensor, target) pairs
        batch_size (int): number of samples per forward pass
        num_workers (int): number of DataLoader worker processes

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the float32 CPU embeddings and the
        targets, each concatenated along dim 0 in dataset order
    """
    import contextlib

    features, targets = [], []

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    # bfloat16 mixed precision only on CUDA; other devices run in full precision.
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    if device.type == "cuda":
        precision_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    else:
        precision_ctx = contextlib.nullcontext()

    with precision_ctx, torch.no_grad():
        for images, batch_targets in tqdm(data_loader):
            images = images.to(device)

            # Default collation already yields a tensor here; torch.as_tensor avoids
            # the copy-construct UserWarning that torch.tensor(<tensor>) emits.
            targets.append(torch.as_tensor(batch_targets))
            features.append(model(images).cpu().float())

    return torch.cat(features, dim=0), torch.cat(targets, dim=0)

In [7]:
# Seed every random number generator (and make cuDNN deterministic) before any
# model or data work, so the results below are reproducible
seed_everything(seed=128)

In [8]:
# Generic feature extraction parameters
batch_size = 16  # samples per DataLoader batch, for both FUNGI and embedding extraction
num_neighbors = 20  # k of the k-nearest neighbor classifier used in the evaluation
target_layer = "blocks.11.attn.proj"  # name of the model layer handed to FUNGIWrapper
cache_dir = "cache/flowers102"  # root directory of the Flowers102 download/cache

# Create the dataset cache directory up front (no-op if it already exists)
os.makedirs(cache_dir, exist_ok=True)

# Prefer the GPU when one is available, otherwise run everything on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# Load the DINO (v1) ViT-B/16 backbone from torch hub (reused from the local hub cache when present)
model = torch.hub.load(repo_or_dir="facebookresearch/dino:main", model="dino_vitb16")

Using cache found in /home/wsimoncini/.cache/torch/hub/facebookresearch_dino_main


In [10]:
# Train and test splits without any transform; these samples are what the
# FUNGI wrapper receives (both for setup and for feature extraction)
train_dataset, test_dataset = (
    Flowers102(root=cache_dir, split=split_name, download=True)
    for split_name in ("train", "test")
)

In [11]:
# Self-supervised objectives used by the FUNGI extractor. Each objective can be
# configured by passing arguments to its configuration object; see each config
# dataclass in src/fungivision/config for more details
extractor_configs = [
    KLConfig(),
    DINOConfig(),
    SimCLRConfig(num_patches=4, stride_scale=6),
]

# Wrap the model using the FUNGI feature extractor
fungi = FUNGIWrapper(
    model=model,
    target_layer=target_layer,
    device=device,
    use_fp16=True,
    extractor_configs=extractor_configs,
)

# setup() has to run before extracting FUNGI features: some objectives need
# supporting data to compute their loss, e.g. the SimCLR negative batch
fungi.setup(dataset=train_dataset)

[2024-07-10 18:39:51,646:INFO]: initializing FUNGI wrapper...
[2024-07-10 18:39:51,646:INFO]: estimating the model output dimensionality...
[2024-07-10 18:39:51,865:INFO]: generating the projection matrix...
[2024-07-10 18:39:55,580:INFO]: running setup for extractor KLGradientsExtractor
[2024-07-10 18:39:55,581:INFO]: running setup for extractor DINOGradientsExtractor
[2024-07-10 18:39:55,581:INFO]: running setup for extractor SimCLRGradientsExtractor
[2024-07-10 18:39:55,581:INFO]: computing the simclr negative batch
[2024-07-10 18:39:58,013:INFO]: encoding 3136 samples...
100%|██████████| 98/98 [00:07<00:00, 13.37it/s]


In [12]:
# FUNGI features for both splits. The targets returned alongside them are dropped:
# the evaluation reuses the targets from extract_features, which line up because
# both extractors iterate the same splits without shuffling.
fungi_train_features, fungi_test_features = (
    extract_fungi_features(wrapper=fungi, dataset=split_dataset, batch_size=batch_size)[0]
    for split_dataset in (train_dataset, test_dataset)
)

100%|██████████| 64/64 [01:40<00:00,  1.57s/it]
100%|██████████| 385/385 [10:04<00:00,  1.57s/it]


In [13]:
# Extract train and test DINO embeddings
#
# The DINO inference transform according to the original repo
# https://github.com/facebookresearch/dino/blob/main/eval_knn.py#L32
# (`interpolation=3` there is bicubic). ToTensor() is deprecated in
# transforms.v2; ToImage() + ToDtype(float32, scale=True) is its replacement
# and likewise yields float images scaled to [0, 1].
transform = tf.Compose([
    tf.Resize(256, interpolation=tf.InterpolationMode.BICUBIC),
    tf.CenterCrop(224),
    tf.ToImage(),
    tf.ToDtype(torch.float32, scale=True),
    tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Separate, transformed copies of the splits for embedding extraction. They get
# their own names so that `train_dataset` / `test_dataset` stay untransformed and
# re-running the FUNGI cells after this one still feeds the wrapper raw samples.
knn_train_dataset = Flowers102(root=cache_dir, split="train", download=True, transform=transform)
knn_test_dataset = Flowers102(root=cache_dir, split="test", download=True, transform=transform)

train_features, train_targets = extract_features(model=model, device=device, dataset=knn_train_dataset, batch_size=batch_size)
test_features, test_targets = extract_features(model=model, device=device, dataset=knn_test_dataset, batch_size=batch_size)

  targets.append(torch.tensor(batch_targets))
100%|██████████| 64/64 [00:02<00:00, 24.25it/s]
100%|██████████| 385/385 [00:07<00:00, 53.93it/s]


## Evaluation

We evaluate both the FUNGI features and the model embeddings with a k-nearest neighbor classifier, and report both the accuracy and the mean-per-class accuracy. For FUNGI, we evaluate both gradient-only features and gradients+embeddings features.

In [14]:
def eval(
    train_features: torch.Tensor,
    test_features: torch.Tensor,
    train_targets: torch.Tensor,
    test_targets: torch.Tensor,
    num_neighbors: int = 20,
    normalize: bool = True
):
    """
    Evaluates features with a k-nearest neighbor classifier and prints the test
    accuracy and the mean per-class accuracy, both in percent (rounded to 2 decimals).

    NOTE: this name shadows Python's builtin `eval` in the notebook namespace, and
    the `normalize` flag shadows the module-level `normalize` import inside this
    function, which is why `nn.functional.normalize` is spelled out below.

    Args:
        train_features (torch.Tensor): features the classifier is fitted on
        test_features (torch.Tensor): features of the samples to classify
        train_targets (torch.Tensor): labels of `train_features`
        test_targets (torch.Tensor): ground truth labels of `test_features`
        num_neighbors (int): number of neighbors k used by the classifier
        normalize (bool): whether to L2-normalize both feature sets along the
            last dimension before fitting and predicting
    """
    if normalize:
        test_features = nn.functional.normalize(test_features, dim=-1, p=2)
        train_features = nn.functional.normalize(train_features, dim=-1, p=2)

    knn_classifier = KNeighborsClassifier(
        n_neighbors=num_neighbors,
        n_jobs=-1
    ).fit(train_features, train_targets)

    predictions = knn_classifier.predict(test_features)

    correct_predictions = (predictions == np.array(test_targets)).sum()

    accuracy = correct_predictions / len(test_targets) * 100
    mean_per_class_acc = mean_per_class_accuracy(
        preds=predictions,
        targets=test_targets
    ) * 100

    print(f"the test accuracy was {round(accuracy, 2)}")
    print(f"the mean per-class accuracy was {round(mean_per_class_acc, 2)}")

In [15]:
# For the mixed features, the embeddings are L2-normalized before being
# concatenated with the FUNGI features (which extract_fungi_features already
# L2-normalizes), so both halves have unit norm.
mixed_train_features = torch.cat([nn.functional.normalize(train_features, dim=-1, p=2), fungi_train_features], dim=-1)
mixed_test_features = torch.cat([nn.functional.normalize(test_features, dim=-1, p=2), fungi_test_features], dim=-1)

# (title, train features, test features) for every feature type to compare;
# all of them share the targets returned by extract_features
feature_sets = [
    ("Embeddings", train_features, test_features),
    ("FUNGI gradient-only features", fungi_train_features, fungi_test_features),
    ("FUNGI gradient+embeddings features", mixed_train_features, mixed_test_features),
]

for set_index, (set_title, set_train, set_test) in enumerate(feature_sets):
    if set_index > 0:
        print("---" * 15)
    print(set_title)

    eval(
        train_features=set_train,
        test_features=set_test,
        train_targets=train_targets,
        test_targets=test_targets,
        num_neighbors=num_neighbors
    )

Embeddings
the test accuracy was 73.88
the mean per-class accuracy was 76.99
---------------------------------------------
FUNGI gradient-only features
the test accuracy was 77.85
the mean per-class accuracy was 80.94
---------------------------------------------
FUNGI gradient+embeddings features
the test accuracy was 78.11
the mean per-class accuracy was 81.38
