In [6]:
import os
import timm
import torch

import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score
from fgvc.utils.utils import set_random_seed
from fgvc.utils.utils import set_cuda_device

SEED = 777
set_random_seed(SEED)

In [7]:
if torch.backends.mps.is_available():
    device = torch.device("mps")

if torch.cuda.is_available():
    device = torch.device("cuda")
    
print(f'Device: {device}')

Device: mps


In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Using metadata

In [9]:
train_df = pd.read_csv("../metadata/DanishFungi2024-Mini-train.csv")
test_df = pd.read_csv("../metadata/DanishFungi2024-Mini-pubtest.csv")

TARGET_FEATURE = "class_id"

In [10]:
train_df

Unnamed: 0,observationID,year,month,day,countryCode,locality,taxonID,scientificName,kingdom,phylum,...,Substrate,rightsHolder,Latitude,Longitude,CoorUncert,Habitat,image_path,class_id,MetaSubstrate,poisonous
0,2238472345,2016.0,8.0,21.0,DK,Blåbjerg,63728.0,"Russula fragilis Fr., 1838",Fungi,Basidiomycota,...,soil,Tom Smidth,55.742985,8.250188,50.0,Mixed woodland (with coniferous and deciduous ...,DF20/2238472345-167057.JPG,130,jord,0
1,2238584938,2018.0,12.0,21.0,DK,Povlsker,17325.0,Mycena vitilis (Fr.) Quél.,Fungi,Basidiomycota,...,dead wood (including bark),Jan Riis-Hansen,55.022503,15.072464,15.0,hedgerow,DF20/2238584938-38711.JPG,103,wood,0
2,2238573692,2018.0,10.0,12.0,DK,Ullerup Skov,19966.0,"Russula faginea Romagn., 1967",Fungi,Basidiomycota,...,soil,Anne Storgaard,55.967314,11.882787,50.0,Deciduous woodland,DF20/2238573692-333898.JPG,125,jord,0
3,2238331684,2013.0,9.0,17.0,DK,Langholt,17250.0,Mycena leptocephala (Pers.) Gillet,Fungi,Basidiomycota,...,,Erik Arnfred Thomsen,57.110430,10.079720,10.0,Bog woodland,DF20/2238331684-86169.JPG,81,wood,0
4,2238357655,2013.0,11.0,13.0,DK,Dragør,10107.0,Agaricus sylvaticus Schaeff.,Fungi,Basidiomycota,...,soil,Grith Styrck Carlsen,55.591850,12.666590,25.0,park/churchyard,DF20/2238357655-161487.JPG,19,jord,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
32687,2238365543,2014.0,2.0,10.0,DK,Sebberkloster Skov,45233.0,Mycena tintinnabulum (Batsch) Quél.,Fungi,Basidiomycota,...,,Per Taudal Poulsen,56.946753,9.527035,50.0,Unmanaged deciduous woodland,DF20/2238365543-13468.JPG,102,wood,0
32688,2868486474,2020.0,10.0,2.0,DK,Slagelse Lystskov,17215.0,Mycena crocata (Schrad.) P.Kumm.,Fungi,Basidiomycota,...,dead wood (including bark),Tom Kristensen,55.393250,11.406092,75.0,Deciduous woodland,DF20/2868486474-139525.JPG,70,wood,0
32689,2874310495,2020.0,10.0,10.0,DK,Slæbæk Skov,20024.0,Russula ochroleuca (Pers.) Fr.,Fungi,Basidiomycota,...,soil,Ann Liza Storm,55.109195,10.570891,16.0,Deciduous woodland,DF20/2874310495-214712.JPG,149,jord,0
32690,2238581119,2018.0,11.0,13.0,DK,Bøtø Plantage,17242.0,Mycena inclinata (Fr.) Quél.,Fungi,Basidiomycota,...,dead wood (including bark),Jørgen Mikkelsen,54.637863,11.959178,15.0,Mixed woodland (with coniferous and deciduous ...,DF20/2238581119-112313.JPG,79,wood,0


In [11]:
IMAGE_DIR = "/Users/lukaspicek/Downloads/DF20M"
train_df["image_path"] = train_df.image_path.apply(
    lambda path: os.path.join(IMAGE_DIR, os.path.basename(path)))

test_df["image_path"] = test_df.image_path.apply(
    lambda path: os.path.join(IMAGE_DIR, os.path.basename(path)))

In [12]:
from sklearn import preprocessing

label_encoders = {}
columns_to_be_encoded = ["Habitat", "Substrate"]

for column_name in columns_to_be_encoded:
    le = preprocessing.LabelEncoder()
    label_encoders = {column_name: le}
    
    train_df[column_name] = le.fit_transform(train_df[column_name]).astype(np.int64)
    test_df[column_name] = le.fit_transform(test_df[column_name]).astype(np.int64)

In [13]:
metadata = pd.concat([train_df, test_df])
len(metadata)

36393

# Calculating prios

In [14]:
cls_counts = metadata.groupby(TARGET_FEATURE).size()
class_distribution = cls_counts / len(metadata)
sum(class_distribution)

1.0

## Calculate Distributions of Selected Features

In [15]:
from utils.matadata_processing import get_target_to_feature_conditional_distributions

SELECTED_FEATURES = ["Habitat", "month", "Substrate"]

# test_df = test_df[~test_df[SELECTED_FEATURES].isna().any(axis=1)]

metadata_distributions = {}
for feature in SELECTED_FEATURES:
    metadata_distributions[feature] = get_target_to_feature_conditional_distributions(
        metadata,
        feature,
        TARGET_FEATURE,
        add_to_missing=False
    )

# Predictions

## 1. Loading model from HuggingFace hub ⏳

In [16]:
N_CLASSES = len(metadata[TARGET_FEATURE].unique())
IMAGE_SIZE = [224, 224]

MODEL_NAME = "BVRA/vit_base_patch16_224.ft_df20m_224"
USE_CALIBRATION = True
USE_OBSERVATION_PREDS = True

model = timm.create_model(f"hf-hub:{MODEL_NAME}", pretrained=True)
model = model.eval()

# model_mean = list(model.default_cfg['mean'])
# model_std = list(model.default_cfg['std'])
# print(model_mean, model_std)
model_mean = [0.5, 0.5, 0.5]
model_std = [0.5, 0.5, 0.5]
print(model_mean, model_std)

model.to(device)
model.eval()
print(f"Done. {device}")

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (BVRA/vit_base_patch16_224.ft_df20m_224)


[0.5, 0.5, 0.5] [0.5, 0.5, 0.5]
Done. mps


In [17]:
from fgvc.special.calibration import ModelWithTemperature, get_temperature

if USE_CALIBRATION:
    model = ModelWithTemperature(model)
    model.to(device)

## 2. Prepare Dataloader

In [18]:
from utils.DanishFungiDataset import DanishFungiDataset, get_transforms

test_dataset = DanishFungiDataset(
    test_df,
    image_path_feature='image_path',
    target_feature=TARGET_FEATURE,
    extra_features=[*SELECTED_FEATURES, "observationID"], 
    transform=get_transforms(model_mean, model_std, IMAGE_SIZE)
)

In [19]:
batch_size = 64

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

## Inference with pre-trained model

In [20]:
from utils.matadata_processing import predict_with_features

preds, preds_raw, GT_lbls, seen_features = predict_with_features(model, test_loader, device)

vanilla_f1 = f1_score(test_df[TARGET_FEATURE], preds, average='macro')
vanilla_accuracy = accuracy_score(test_df[TARGET_FEATURE], preds)
vanilla_recall_3 = top_k_accuracy_score(test_df[TARGET_FEATURE], preds_raw, k=3)
vanilla_recall_5 = top_k_accuracy_score(test_df[TARGET_FEATURE], preds_raw, k=5)
vanilla_recall_10 = top_k_accuracy_score(test_df[TARGET_FEATURE], preds_raw, k=10)

print('Vanilla:', vanilla_f1, vanilla_accuracy, vanilla_recall_3, vanilla_recall_5, vanilla_recall_10)

100%|██████████| 58/58 [01:44<00:00,  1.81s/it]

Vanilla: 0.5292422354129023 0.6573898946230748 0.8230208051877871 0.8794920291813023 0.9305593082950554





### Average image predictions to get an observation prediction

In [21]:
seen_observation_ids = np.array(seen_features["observationID"])
unique_observation_ids = np.unique(seen_observation_ids)

preds_raw_np = np.array(preds_raw)

obs_preds_raw = np.zeros((len(test_df), N_CLASSES))
obs_preds = np.zeros((len(test_df)))

for unique_observation_id in unique_observation_ids:
    same_observation_indexes = np.where(seen_observation_ids == unique_observation_id)
    
    observation_predictions = preds_raw_np[same_observation_indexes]
    _obs_preds = np.average(observation_predictions, axis=0)
    obs_preds_raw[same_observation_indexes] = _obs_preds
    obs_preds[same_observation_indexes] = _obs_preds.argmax()
    
obs_f1 = f1_score(test_df[TARGET_FEATURE], obs_preds, average='macro')
obs_accuracy = accuracy_score(test_df[TARGET_FEATURE], obs_preds)
obs_recall_3 = top_k_accuracy_score(test_df[TARGET_FEATURE], obs_preds_raw, k=3)

print('ObservationID:', obs_f1, obs_accuracy, obs_recall_3)

if USE_OBSERVATION_PREDS:
    vanilla_f1 = obs_f1
    vanilla_accuracy = obs_accuracy
    vanilla_recall_3 = obs_recall_3
    preds_raw = obs_preds_raw


ObservationID: 0.6086140553984372 0.7506079437989732 0.9021885976763037


## Weighting by each Selected Feature

In [22]:
from utils.matadata_processing  import weight_predictions_by_feature_distribution

def post_process_selected_features(metadata_distributions, class_distribution, raw_predictions, ground_truth_labels):
    feature_prior_ratios = {}
    metrics_by_features = {}
    for feature in SELECTED_FEATURES:
        metadata_distribution = metadata_distributions[feature]
        seen_feature_values = seen_features[feature]

        weighted_predictions, weighted_predictions_raw, feature_prior_ratio = weight_predictions_by_feature_distribution(
            target_to_feature_conditional_distributions=metadata_distribution,
            target_distribution=class_distribution,
            ground_truth_labels=ground_truth_labels,
            raw_predictions=raw_predictions,
            ground_truth_feature_categories=seen_feature_values
        )
        feature_prior_ratios[feature] = feature_prior_ratio

        f1 = f1_score(test_df[TARGET_FEATURE], weighted_predictions, average='macro')
        accuracy = accuracy_score(test_df[TARGET_FEATURE], weighted_predictions)
        recall_3 = top_k_accuracy_score(test_df[TARGET_FEATURE], weighted_predictions_raw, k=3)
        metrics_by_features[feature] = {
            "f1": f1,
            "accuracy": accuracy,
            "recall_3": recall_3
        }
        print(f'{feature}:', f1, accuracy, recall_3)
        print(f'{feature} dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100))
    
    return feature_prior_ratios, metrics_by_features
        
feature_prior_ratios, metrics_by_features = post_process_selected_features(
    metadata_distributions=metadata_distributions,
    class_distribution=class_distribution,
    raw_predictions=preds_raw,
    ground_truth_labels=GT_lbls
)

100%|██████████| 3701/3701 [00:00<00:00, 4454.85it/s]


Habitat: 0.6625571390081367 0.7803296406376655 0.9181302350716023
Habitat dif: 0.054 2.97 2.0


100%|██████████| 3701/3701 [00:00<00:00, 4658.41it/s]


month: 0.6426806251412378 0.7654687922183194 0.910294514995947
month dif: 0.034 1.49 1.0


100%|██████████| 3701/3701 [00:00<00:00, 5254.90it/s]

Substrate: 0.6446187291832475 0.7668197784382599 0.9140772764117806
Substrate dif: 0.036 1.62 1.0





## Weighting by Combinations of Selected Features

In [23]:
from itertools import combinations
from utils.matadata_processing import weight_predictions_combined_feature_priors


def post_process_prior_combinations(raw_predictions, feature_prior_ratios):
    metrics_by_combination = {}
    all_combinations_selected_features = []
    for num_features in range(2, len(SELECTED_FEATURES) + 1):
        all_combinations_selected_features.extend(combinations(SELECTED_FEATURES, num_features))
    
    for combination in all_combinations_selected_features:

        selected_feature_prior_ratios = [feature_prior_ratios[feature] for feature in combination]

        merged_predictions, merged_predictions_raw = weight_predictions_combined_feature_priors(
            raw_predictions=raw_predictions,
            feature_prior_ratios=selected_feature_prior_ratios
        )

        f1 = f1_score(test_df[TARGET_FEATURE], merged_predictions, average='macro')
        accuracy = accuracy_score(test_df[TARGET_FEATURE], merged_predictions)
        recall_3 = top_k_accuracy_score(test_df[TARGET_FEATURE], merged_predictions_raw, k=3)
        
        combination_name = " + ".join(combination)
        
        metrics_by_combination[combination_name] = {
            "f1": f1,
            "accuracy": accuracy,
            "recall_3": recall_3
        }
        print(combination_name)
        print("F1, Acc, Recall3: ", f1, accuracy, recall_3)
        print("Diff: ", np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))
    
    return metrics_by_combination
        
metrics_by_combination = post_process_prior_combinations(
    raw_predictions=preds_raw,
    feature_prior_ratios=feature_prior_ratios
)

Habitat + month
F1, Acc, Recall3:  0.680415621648917 0.7846527965414752 0.9265063496352337
Diff:  0.072 3.4 2.43
Habitat + Substrate
F1, Acc, Recall3:  0.6773465792189034 0.7857335855174277 0.9254255606592813
Diff:  0.069 3.51 2.32
month + Substrate
F1, Acc, Recall3:  0.6669382585375561 0.7781680626857606 0.9197514185355309
Diff:  0.058 2.76 1.76
Habitat + month + Substrate
F1, Acc, Recall3:  0.6943800393064463 0.7941097000810592 0.9292083220751148
Diff:  0.086 4.35 2.7


In [39]:
results

{'Vanilla': {'f1': 0.6086140553984372,
  'accuracy': 0.7506079437989732,
  'recall_3': 0.9021885976763037},
 'ObservationID': {'f1': -0.6086140553984372,
  'accuracy': -0.7506079437989732,
  'recall_3': -0.9021885976763037},
 'Habitat': {'f1': -0.5546709717887376,
  'accuracy': -0.7208862469602809,
  'recall_3': -0.886246960281005},
 'month': {'f1': -0.5745474856556365,
  'accuracy': -0.735747095379627,
  'recall_3': -0.8940826803566603},
 'Substrate': {'f1': -0.5726093816136268,
  'accuracy': -0.7343961091596866,
  'recall_3': -0.8902999189408267},
 'Habitat + month': {'f1': -0.5368124891479573,
  'accuracy': -0.7165630910564712,
  'recall_3': -0.8778708457173736},
 'Habitat + Substrate': {'f1': -0.5398815315779709,
  'accuracy': -0.7154823020805188,
  'recall_3': -0.878951634693326},
 'month + Substrate': {'f1': -0.5502898522593183,
  'accuracy': -0.7230478249121859,
  'recall_3': -0.8846257768170764},
 'Habitat + month + Substrate': {'f1': -0.522848071490428,
  'accuracy': -0.707106

In [24]:
results = {
    "Vanilla":       {'f1': vanilla_f1, 'accuracy': vanilla_accuracy, 'recall_3': vanilla_recall_3},
    "ObservationID": {'f1': obs_f1, 'accuracy': obs_accuracy, 'recall_3': obs_recall_3}
}
results.update(metrics_by_features)
results.update(metrics_by_combination)

results_df = pd.DataFrame(results).transpose()
results_df = results_df[['accuracy', 'recall_3', 'f1']]
results_df.head(50)

Unnamed: 0,accuracy,recall_3,f1
Vanilla,0.750608,0.902189,0.608614
ObservationID,0.750608,0.902189,0.608614
Habitat,0.78033,0.91813,0.662557
month,0.765469,0.910295,0.642681
Substrate,0.76682,0.914077,0.644619
Habitat + month,0.784653,0.926506,0.680416
Habitat + Substrate,0.785734,0.925426,0.677347
month + Substrate,0.778168,0.919751,0.666938
Habitat + month + Substrate,0.79411,0.929208,0.69438


In [29]:
results_df - results_df.iloc[0, 1:]

Unnamed: 0,accuracy,f1,recall_3
Vanilla,,0.0,0.0
ObservationID,,0.0,0.0
Habitat,,0.053943,0.015942
month,,0.034067,0.008106
Substrate,,0.036005,0.011889
Habitat + month,,0.071802,0.024318
Habitat + Substrate,,0.068733,0.023237
month + Substrate,,0.058324,0.017563
Habitat + month + Substrate,,0.085766,0.02702


In [31]:
from utils.matadata_processing import late_metadata_fusion

late_metadata_fusion(
    metadata,
    model,
    test_loader,
    device,
    TARGET_FEATURE,
    SELECTED_FEATURES
)

100%|██████████| 58/58 [01:44<00:00,  1.81s/it]
100%|██████████| 3701/3701 [00:01<00:00, 3253.38it/s]
100%|██████████| 3701/3701 [00:01<00:00, 3461.52it/s]
100%|██████████| 3701/3701 [00:01<00:00, 2568.20it/s]


(0.5816617242993047, 0.6860308024858146, 0.8589570386382059)
(0.5506635504240965, 0.6681977843825992, 0.835179681167252)
(0.5661870872783966, 0.6808970548500405, 0.8430154012429073)
(0.5941866430385047, 0.6952175087814104, 0.8657119697379086)
(0.6085063955885723, 0.7052148068089705, 0.8697649283977303)
(0.5840983177671484, 0.6900837611456363, 0.8522021075385031)
(0.6175595124449261, 0.7119697379086734, 0.876249662253445)


{'Habitat': {'predictions': [130,
   119,
   128,
   124,
   33,
   45,
   149,
   114,
   72,
   86,
   60,
   156,
   122,
   70,
   119,
   116,
   128,
   0,
   76,
   70,
   30,
   20,
   99,
   78,
   159,
   103,
   115,
   146,
   129,
   149,
   107,
   78,
   56,
   30,
   34,
   30,
   180,
   91,
   35,
   149,
   176,
   33,
   131,
   178,
   116,
   94,
   12,
   180,
   32,
   22,
   94,
   119,
   180,
   76,
   76,
   51,
   170,
   47,
   20,
   1,
   27,
   47,
   85,
   27,
   102,
   76,
   173,
   47,
   47,
   40,
   144,
   21,
   25,
   107,
   51,
   147,
   175,
   119,
   30,
   119,
   72,
   144,
   173,
   47,
   34,
   40,
   7,
   35,
   28,
   124,
   81,
   30,
   40,
   49,
   85,
   85,
   76,
   147,
   34,
   42,
   77,
   94,
   37,
   1,
   45,
   180,
   47,
   153,
   74,
   42,
   102,
   37,
   119,
   77,
   36,
   149,
   11,
   129,
   33,
   0,
   114,
   78,
   28,
   178,
   76,
   141,
   129,
   159,
   70,
   45,
   101,
   149,
  