In [1]:
import os
import timm
import torch

import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score
# from fgvc.utils.utils import set_random_seed

SEED = 777
# set_random_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Device: {device}')

Device: mps


In [3]:
%load_ext autoreload
%autoreload 2

## Using metadata

In [7]:
print(os.getcwd())

/Users/jeremycui/Documents/UCB_MIDS/DATASCI207/fungitastic-classification-datasci207-Fall-2025/inference


In [8]:
train_df = pd.read_csv("../dataset/FungiTastic/metadata/FungiTastic-Mini/FungiTastic-Mini-Train.csv")
val_df = pd.read_csv("../dataset/FungiTastic/metadata/FungiTastic-Mini/FungiTastic-Mini-ClosedSet-Val.csv")
test_df = pd.read_csv("../dataset/FungiTastic/metadata/FungiTastic-Mini/FungiTastic-Mini-ClosedSet-Test.csv")

In [9]:
test_df = test_df[test_df.habitat != "masonry"]

In [10]:
full_test = pd.read_csv("./FungiTastic-Mini-test-RAW.csv")

observation_2_class_id = dict(zip(full_test.observationID, full_test.class_id)) 
test_df["category_id"] = test_df.observationID.apply(lambda obsID: observation_2_class_id[obsID])

FileNotFoundError: [Errno 2] No such file or directory: './FungiTastic-Mini-test-RAW.csv'

In [9]:
TRAIN_IMAGE_DIR = "/Users/lukaspicek/Downloads/images/FungiTastic-Mini/train/500p"
train_df["image_path"] = train_df.filename.apply(
    lambda filename: os.path.join(TRAIN_IMAGE_DIR, filename))

VAL_IMAGE_DIR = "/Users/lukaspicek/Downloads/images/FungiTastic-Mini/val/500p"

val_df["image_path"] = val_df.filename.apply(
    lambda filename: os.path.join(VAL_IMAGE_DIR, filename))

TEST_IMAGE_DIR = "/Users/lukaspicek/Downloads/images/FungiTastic-Mini/test/500p"

test_df["image_path"] = test_df.filename.apply(
    lambda filename: os.path.join(TEST_IMAGE_DIR, filename))

In [10]:
all_metadata = pd.concat([train_df, val_df, test_df])
len(all_metadata), len(train_df), len(val_df), len(test_df)

(66990, 46842, 9412, 10736)

In [11]:
train_df

Unnamed: 0,eventDate,year,month,day,habitat,countryCode,scientificName,kingdom,phylum,class,...,region,district,filename,category_id,metaSubstrate,poisonous,elevation,landcover,biogeographicalRegion,image_path
0,2021-02-01,2021,2.0,1.0,Mixed woodland (with coniferous and deciduous ...,DK,Mycena tintinnabulum (Paulet) Quél.,Fungi,Basidiomycota,Agaricomycetes,...,Sjælland,Næstved,0-3032614317.JPG,119,wood,0,35.0,5.0,continental,/Users/lukaspicek/Downloads/images/FungiTastic...
1,2021-02-01,2021,2.0,1.0,Mixed woodland (with coniferous and deciduous ...,DK,Mycena tintinnabulum (Paulet) Quél.,Fungi,Basidiomycota,Agaricomycetes,...,Sjælland,Næstved,1-3032614317.JPG,119,wood,0,35.0,5.0,continental,/Users/lukaspicek/Downloads/images/FungiTastic...
2,2008-09-01,2008,9.0,1.0,Deciduous woodland,DK,Russula cyanoxantha (Schaeff.) Fr.,Fungi,Basidiomycota,Agaricomycetes,...,Midtjylland,Århus,0-3036761318.JPG,144,jord,0,6.0,10.0,continental,/Users/lukaspicek/Downloads/images/FungiTastic...
3,2008-09-01,2008,9.0,1.0,Deciduous woodland,DK,Russula cyanoxantha (Schaeff.) Fr.,Fungi,Basidiomycota,Agaricomycetes,...,Midtjylland,Århus,1-3036761318.JPG,144,jord,0,6.0,10.0,continental,/Users/lukaspicek/Downloads/images/FungiTastic...
4,2008-09-01,2008,9.0,1.0,Deciduous woodland,DK,Russula cyanoxantha (Schaeff.) Fr.,Fungi,Basidiomycota,Agaricomycetes,...,Midtjylland,Århus,2-3036761318.JPG,144,jord,0,6.0,10.0,continental,/Users/lukaspicek/Downloads/images/FungiTastic...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46837,2021-09-29,2021,9.0,29.0,Unmanaged deciduous woodland,RU,Mycena rosea Gramberg,Fungi,Basidiomycota,Agaricomycetes,...,Kaluga,Maloyaroslavetskiy rayon,1-3429074356.JPG,109,jord,1,,,boreal,/Users/lukaspicek/Downloads/images/FungiTastic...
46838,2021-09-29,2021,9.0,29.0,Unmanaged deciduous woodland,RU,Mycena rosea Gramberg,Fungi,Basidiomycota,Agaricomycetes,...,Kaluga,Maloyaroslavetskiy rayon,2-3429074356.JPG,109,jord,1,,,boreal,/Users/lukaspicek/Downloads/images/FungiTastic...
46839,2021-09-27,2021,9.0,27.0,Deciduous woodland,DK,Russula ochroleuca Fr.,Fungi,Basidiomycota,Agaricomycetes,...,Hovedstaden,Halsnæs,0-4100099773.JPG,179,jord,0,0.0,17.0,continental,/Users/lukaspicek/Downloads/images/FungiTastic...
46840,2021-09-27,2021,9.0,27.0,Deciduous woodland,DK,Russula ochroleuca Fr.,Fungi,Basidiomycota,Agaricomycetes,...,Hovedstaden,Halsnæs,1-4100099773.JPG,179,jord,0,0.0,17.0,continental,/Users/lukaspicek/Downloads/images/FungiTastic...


In [12]:
from sklearn import preprocessing

label_encoders = {}
columns_to_be_encoded = ["habitat", "substrate", "biogeographicalRegion", "metaSubstrate"]

for column_name in columns_to_be_encoded:
    le = preprocessing.LabelEncoder()
    le.fit(all_metadata[column_name].unique())
    label_encoders = {column_name: le}

    train_df[column_name] = le.transform(train_df[column_name]).astype(np.int64)
    val_df[column_name] = le.transform(val_df[column_name]).astype(np.int64)
    test_df[column_name] = le.transform(test_df[column_name]).astype(np.int64)

In [29]:
d

(30, 29, 29)

In [13]:
metadata = pd.concat([train_df])
test_df = test_df
# metadata = metadata.drop_duplicates(subset="observationID").reset_index(drop=True)

# Calculating prios

In [14]:
TARGET_FEATURE = "category_id"

metadata = metadata.drop_duplicates(subset="observationID")

cls_counts = metadata.groupby(TARGET_FEATURE).size()
class_distribution = cls_counts / len(metadata)
sum(class_distribution)

1.0000000000000002

## Calculate Distributions of Selected Features

In [15]:
from utils.matadata_processing import get_target_to_feature_conditional_distributions

SELECTED_FEATURES = ["month", "habitat", "substrate", "biogeographicalRegion", "metaSubstrate"]

# test_df = test_df[~test_df[SELECTED_FEATURES].isna().any(axis=1)]

metadata_distributions = {}
for feature in SELECTED_FEATURES:
    metadata_distributions[feature] = get_target_to_feature_conditional_distributions(
        metadata,
        feature,
        TARGET_FEATURE,
        add_to_missing=False
    )

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.


# Predictions

## 1. Loading model from HuggingFace hub ⏳

In [16]:
N_CLASSES = len(metadata[TARGET_FEATURE].unique())
IMAGE_SIZE = [384, 384]

MODEL_NAME = "BVRA/tf_efficientnet_b3.in1k_ft_df24m_384"

USE_CALIBRATION = True
USE_OBSERVATION_PREDS = False

model = timm.create_model(f"hf-hub:{MODEL_NAME}", pretrained=True)
model = model.eval()

# model_mean = list(model.default_cfg['mean'])
# model_std = list(model.default_cfg['std'])
# print(model_mean, model_std)
model_mean = [0.5, 0.5, 0.5]
model_std = [0.5, 0.5, 0.5]

print(model_mean, model_std)

model.to(device)
model.eval()
print(f"Done. {device}")

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (BVRA/tf_efficientnet_b3.in1k_ft_df24m_384)
  return torch.load(cached_file, map_location='cpu')


[0.5, 0.5, 0.5] [0.5, 0.5, 0.5]
Done. mps


In [17]:
from fgvc.special.calibration import ModelWithTemperature, get_temperature

if USE_CALIBRATION:
    model = ModelWithTemperature(model)
    model.to(device)

## 2. Prepare Dataloader

In [18]:
from utils.DanishFungiDataset import DanishFungiDataset, get_transforms

test_dataset = DanishFungiDataset(
    test_df,
    image_path_feature='image_path',
    target_feature=TARGET_FEATURE,
    extra_features=[*SELECTED_FEATURES, "observationID"], 
    transform=get_transforms(model_mean, model_std, IMAGE_SIZE)
)

In [19]:
batch_size = 64

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

## Inference with pre-trained model

In [20]:
from utils.matadata_processing import predict_with_features

preds, preds_raw, GT_lbls, seen_features = predict_with_features(model, test_loader, device)

vanilla_f1 = f1_score(test_df[TARGET_FEATURE], preds, average='macro')
vanilla_accuracy = accuracy_score(test_df[TARGET_FEATURE], preds)
print('Vanilla:', vanilla_f1, vanilla_accuracy)
# vanilla_recall_3 = top_k_accuracy_score(test_df[TARGET_FEATURE], preds_raw, k=3)
# vanilla_recall_5 = top_k_accuracy_score(test_df[TARGET_FEATURE], preds_raw, k=5)
# vanilla_recall_10 = top_k_accuracy_score(test_df[TARGET_FEATURE], preds_raw, k=10)

# print('Vanilla:', vanilla_f1, vanilla_accuracy, vanilla_recall_3, vanilla_recall_5, vanilla_recall_10)

  0%|          | 0/168 [00:00<?, ?it/s]INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable 

Vanilla: 0.4054764361151385 0.6752980625931445





### Average image predictions to get an observation prediction

In [21]:
seen_observation_ids = np.array(seen_features["observationID"])
unique_observation_ids = np.unique(seen_observation_ids)

preds_raw_np = np.array(preds_raw)

obs_preds_raw = np.zeros((len(test_df), N_CLASSES))
obs_preds = np.zeros((len(test_df)))

for unique_observation_id in unique_observation_ids:
    same_observation_indexes = np.where(seen_observation_ids == unique_observation_id)
    
    observation_predictions = preds_raw_np[same_observation_indexes]
    _obs_preds = np.average(observation_predictions, axis=0)
    obs_preds_raw[same_observation_indexes] = _obs_preds
    obs_preds[same_observation_indexes] = _obs_preds.argmax()
    
obs_f1 = f1_score(test_df[TARGET_FEATURE], obs_preds, average='macro')
obs_accuracy = accuracy_score(test_df[TARGET_FEATURE], obs_preds)
# obs_recall_3 = top_k_accuracy_score(test_df[TARGET_FEATURE], obs_preds_raw, k=3)

# print('ObservationID:', obs_f1, obs_accuracy, obs_recall_3)
print('ObservationID:', obs_f1, obs_accuracy)

if USE_OBSERVATION_PREDS:
    vanilla_f1 = obs_f1
    vanilla_accuracy = obs_accuracy
    # vanilla_recall_3 = obs_recall_3
    preds_raw = obs_preds_raw


ObservationID: 0.5339519283417108 0.7779433681073026


## Weighting by each Selected Feature

In [25]:
from utils.matadata_processing  import weight_predictions_by_feature_distribution

def post_process_selected_features(metadata_distributions, class_distribution, raw_predictions, ground_truth_labels):
    feature_prior_ratios = {}
    metrics_by_features = {}
    for feature in SELECTED_FEATURES:
        metadata_distribution = metadata_distributions[feature]
        seen_feature_values = seen_features[feature]

        weighted_predictions, weighted_predictions_raw, feature_prior_ratio = weight_predictions_by_feature_distribution(
            target_to_feature_conditional_distributions=metadata_distribution,
            target_distribution=class_distribution,
            ground_truth_labels=ground_truth_labels,
            raw_predictions=raw_predictions,
            ground_truth_feature_categories=seen_feature_values
        )
        feature_prior_ratios[feature] = feature_prior_ratio

        f1 = f1_score(test_df[TARGET_FEATURE], weighted_predictions, average='macro')
        accuracy = accuracy_score(test_df[TARGET_FEATURE], weighted_predictions)
        # recall_3 = top_k_accuracy_score(test_df[TARGET_FEATURE], weighted_predictions_raw, k=3)
        metrics_by_features[feature] = {
            "f1": f1,
            "accuracy": accuracy,
            # "recall_3": recall_3
        }
        print(f'{feature}:', f1, accuracy)
        # print(f'{feature}:', f1, accuracy, recall_3)

        # print(f'{feature} dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100))
        print(f'{feature} dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2))

    
    return feature_prior_ratios, metrics_by_features
        
feature_prior_ratios, metrics_by_features = post_process_selected_features(
    metadata_distributions=metadata_distributions,
    class_distribution=class_distribution,
    raw_predictions=preds_raw,
    ground_truth_labels=GT_lbls
)

100%|██████████| 10736/10736 [00:04<00:00, 2638.03it/s]


month: 0.42403487109197435 0.6831222056631893
month dif: 0.019 0.78


100%|██████████| 10736/10736 [00:04<00:00, 2424.55it/s]


habitat: 0.4248839336406116 0.6769746646795827
habitat dif: 0.019 0.17


100%|██████████| 10736/10736 [00:03<00:00, 2777.73it/s]


substrate: 0.42067831323161364 0.6820044709388972
substrate dif: 0.015 0.67


100%|██████████| 10736/10736 [00:04<00:00, 2385.05it/s]


biogeographicalRegion: 0.404321706657905 0.6748323397913562
biogeographicalRegion dif: -0.001 -0.05


100%|██████████| 10736/10736 [00:03<00:00, 2823.68it/s]


metaSubstrate: 0.41766837453289574 0.680327868852459
metaSubstrate dif: 0.012 0.5


## Weighting by Combinations of Selected Features

In [26]:
from itertools import combinations
from utils.matadata_processing import weight_predictions_combined_feature_priors


def post_process_prior_combinations(raw_predictions, feature_prior_ratios):
    metrics_by_combination = {}
    all_combinations_selected_features = []
    for num_features in range(2, len(SELECTED_FEATURES) + 1):
        all_combinations_selected_features.extend(combinations(SELECTED_FEATURES, num_features))
    
    for combination in all_combinations_selected_features:

        selected_feature_prior_ratios = [feature_prior_ratios[feature] for feature in combination]

        merged_predictions, merged_predictions_raw = weight_predictions_combined_feature_priors(
            raw_predictions=raw_predictions,
            feature_prior_ratios=selected_feature_prior_ratios
        )

        f1 = f1_score(test_df[TARGET_FEATURE], merged_predictions, average='macro')
        accuracy = accuracy_score(test_df[TARGET_FEATURE], merged_predictions)
        # recall_3 = top_k_accuracy_score(test_df[TARGET_FEATURE], merged_predictions_raw, k=3)
        
        combination_name = " + ".join(combination)
        
        metrics_by_combination[combination_name] = {
            "f1": f1,
            "accuracy": accuracy,
            # "recall_3": recall_3
        }
        print(combination_name)
        # print("F1, Acc, Recall3: ", f1, accuracy, recall_3)
        print("F1, Acc, Recall3: ", f1, accuracy)

        # print("Diff: ", np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))
        print("Diff: ", np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2))
    
    return metrics_by_combination
        
metrics_by_combination = post_process_prior_combinations(
    raw_predictions=preds_raw,
    feature_prior_ratios=feature_prior_ratios
)

month + habitat
F1, Acc, Recall3:  0.43831895801958404 0.6844262295081968
Diff:  0.033 0.91
month + substrate
F1, Acc, Recall3:  0.435706260862435 0.6902943368107303
Diff:  0.03 1.5
month + biogeographicalRegion
F1, Acc, Recall3:  0.4220751786799952 0.6823770491803278
Diff:  0.017 0.71
month + metaSubstrate
F1, Acc, Recall3:  0.4320342879633957 0.6890834575260805
Diff:  0.027 1.38


  return bound(*args, **kwds)


habitat + substrate
F1, Acc, Recall3:  0.4317493713360047 0.6797690014903129
Diff:  0.026 0.45
habitat + biogeographicalRegion
F1, Acc, Recall3:  0.4218073309903457 0.6743666169895678
Diff:  0.016 -0.09


  return bound(*args, **kwds)


habitat + metaSubstrate
F1, Acc, Recall3:  0.4310226072142507 0.6795827123695977
Diff:  0.026 0.43
substrate + biogeographicalRegion
F1, Acc, Recall3:  0.4166203602383838 0.6798621460506706
Diff:  0.011 0.46
substrate + metaSubstrate
F1, Acc, Recall3:  0.4241563057692613 0.6827496274217586
Diff:  0.019 0.75
biogeographicalRegion + metaSubstrate
F1, Acc, Recall3:  0.41361204796438555 0.6778129657228018
Diff:  0.008 0.25


  return bound(*args, **kwds)


month + habitat + substrate
F1, Acc, Recall3:  0.4442441045261595 0.6848919523099851
Diff:  0.039 0.96
month + habitat + biogeographicalRegion
F1, Acc, Recall3:  0.4316607121666508 0.6814456035767511
Diff:  0.026 0.61


  return bound(*args, **kwds)


month + habitat + metaSubstrate
F1, Acc, Recall3:  0.4461986054121738 0.6848919523099851
Diff:  0.041 0.96
month + substrate + biogeographicalRegion
F1, Acc, Recall3:  0.4303243108922367 0.6861028315946349
Diff:  0.025 1.08
month + substrate + metaSubstrate
F1, Acc, Recall3:  0.4332700773899358 0.688338301043219
Diff:  0.028 1.3
month + biogeographicalRegion + metaSubstrate
F1, Acc, Recall3:  0.42428283521348037 0.6847988077496274
Diff:  0.019 0.95


  return bound(*args, **kwds)


habitat + substrate + biogeographicalRegion
F1, Acc, Recall3:  0.4232336780430707 0.6761363636363636
Diff:  0.018 0.08


  return bound(*args, **kwds)


habitat + substrate + metaSubstrate
F1, Acc, Recall3:  0.43381961235218214 0.6791169895678092
Diff:  0.028 0.38


  return bound(*args, **kwds)


habitat + biogeographicalRegion + metaSubstrate
F1, Acc, Recall3:  0.42408575265979576 0.676322652757079
Diff:  0.019 0.1
substrate + biogeographicalRegion + metaSubstrate
F1, Acc, Recall3:  0.4168406868094324 0.6795827123695977
Diff:  0.011 0.43


  return bound(*args, **kwds)


month + habitat + substrate + biogeographicalRegion
F1, Acc, Recall3:  0.43809148527370084 0.6809798807749627
Diff:  0.033 0.57


  return bound(*args, **kwds)


month + habitat + substrate + metaSubstrate
F1, Acc, Recall3:  0.4431125593781695 0.6847056631892697
Diff:  0.038 0.94


  return bound(*args, **kwds)


month + habitat + biogeographicalRegion + metaSubstrate
F1, Acc, Recall3:  0.4383144649165884 0.6812593144560357
Diff:  0.033 0.6
month + substrate + biogeographicalRegion + metaSubstrate
F1, Acc, Recall3:  0.4270881456343786 0.6861028315946349
Diff:  0.022 1.08


  return bound(*args, **kwds)


habitat + substrate + biogeographicalRegion + metaSubstrate
F1, Acc, Recall3:  0.42334378301373676 0.675763785394933
Diff:  0.018 0.05


  return bound(*args, **kwds)


month + habitat + substrate + biogeographicalRegion + metaSubstrate
F1, Acc, Recall3:  0.4349311092923811 0.6805141579731744
Diff:  0.029 0.52


In [29]:
# results = {
#     "Vanilla":       {'f1': vanilla_f1, 'accuracy': vanilla_accuracy, 'recall_3': vanilla_recall_3},
#     "ObservationID": {'f1': obs_f1, 'accuracy': obs_accuracy, 'recall_3': obs_recall_3}
# }
results = {
    "Vanilla":       {'f1': vanilla_f1, 'accuracy': vanilla_accuracy},
    "ObservationID": {'f1': obs_f1, 'accuracy': obs_accuracy}
}
results.update(metrics_by_features)
results.update(metrics_by_combination)

results_df = pd.DataFrame(results).transpose()
results_df = results_df[['accuracy', 'f1']]
results_df.head(50)

Unnamed: 0,accuracy,f1
Vanilla,0.675298,0.405476
ObservationID,0.777943,0.533952
month,0.683122,0.424035
habitat,0.676975,0.424884
substrate,0.682004,0.420678
biogeographicalRegion,0.674832,0.404322
metaSubstrate,0.680328,0.417668
month + habitat,0.684426,0.438319
month + substrate,0.690294,0.435706
month + biogeographicalRegion,0.682377,0.422075


In [30]:
(results_df - results_df.iloc[0, :]) * 100

Unnamed: 0,accuracy,f1
Vanilla,0.0,0.0
ObservationID,10.264531,12.847549
month,0.782414,1.855843
habitat,0.16766,1.94075
substrate,0.670641,1.520188
biogeographicalRegion,-0.046572,-0.115473
metaSubstrate,0.502981,1.219194
month + habitat,0.912817,3.284252
month + substrate,1.499627,3.022982
month + biogeographicalRegion,0.707899,1.659874


In [31]:
(results_df - results_df.iloc[0, :]) * 100

Unnamed: 0,accuracy,f1
Vanilla,0.0,0.0
ObservationID,10.264531,12.847549
month,0.782414,1.855843
habitat,0.16766,1.94075
substrate,0.670641,1.520188
biogeographicalRegion,-0.046572,-0.115473
metaSubstrate,0.502981,1.219194
month + habitat,0.912817,3.284252
month + substrate,1.499627,3.022982
month + biogeographicalRegion,0.707899,1.659874


In [25]:
(results_df - results_df.iloc[0, :]) * 100

Unnamed: 0,accuracy,f1
Vanilla,0.0,0.0
ObservationID,11.570336,12.564606
month,1.168721,1.599317
biogeographicalRegion,0.38249,-0.250434
metaSubstrate,0.871228,1.698877
month + biogeographicalRegion,1.39184,1.515206
month + metaSubstrate,1.689333,2.751466
biogeographicalRegion + metaSubstrate,1.083723,1.479839
month + biogeographicalRegion + metaSubstrate,1.923077,3.340545


In [21]:
results_df.iloc[0, :]

accuracy    0.810242
f1          0.560898
Name: Vanilla, dtype: float64

In [22]:
from utils.matadata_processing import late_metadata_fusion

late_metadata_fusion(
    metadata,
    model,
    test_loader,
    device,
    TARGET_FEATURE,
    SELECTED_FEATURES
)

  0%|          | 0/148 [00:00<?, ?it/s]INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.14 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.14 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.14 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.14 (you have 1.4.13). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable 

ValueError: Number of classes in 'y_true' (193) not equal to the number of classes in 'y_score' (215).You can provide a list of all known classes by assigning it to the `labels` parameter.