In [1]:
import gc
import os
import cv2
import sys
import json
import tqdm
import time
import timm
import torch
import random
import sklearn.metrics

from PIL import Image
from pathlib import Path
from functools import partial
from contextlib import contextmanager

import numpy as np
import scipy as sp
import pandas as pd
import torch.nn as nn

from torch.optim import Adam, SGD
from scipy.special import softmax
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

os.environ["CUDA_VISIBLE_DEVICES"]="2"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Loading and Parsing Metadata

### Loading metadata files

In [25]:
train_metadata_genera = pd.read_csv("/Datasets/SvampeAtlas-14.12.2020/metadata/DanishFungi2020-Mini_train_metadata_DEV.csv")
test_metadata_genera = pd.read_csv("/Datasets/SvampeAtlas-14.12.2020/metadata/DanishFungi2020-Mini_test_metadata_DEV.csv")

In [26]:
print(len(train_metadata_genera), len(test_metadata_genera))

32753 3640


In [27]:
metadata = pd.concat([train_metadata, test_metadata])
len(metadata)

295938

In [28]:
test_metadata_genera.Habitat = test_metadata_genera.Habitat.replace(np.nan, 'unknown', regex=True)
test_metadata_genera.Substrate = test_metadata_genera.Substrate.replace(np.nan, 'unknown', regex=True)
# test_metadata.month = test_metadata.month.replace(np.nan, 'unknown', regex=True)

In [39]:
train_metadata_genera.Habitat = train_metadata_genera.Habitat.replace(np.nan, 'unknown', regex=True)
train_metadata_genera.Substrate = train_metadata_genera.Substrate.replace(np.nan, 'unknown', regex=True)
# metadata.month = metadata.month.replace(np.nan, 0, regex=True)

In [40]:
metadata.Substrate.unique()

array(['bark of living trees', 'dead wood (including bark)', 'soil',
       'stone', 'leaf or needle litter', 'wood', 'living leaves', 'bark',
       'siliceous stone', 'cones', 'peat mosses',
       'stems of herbs, grass etc', 'building stone (e.g. bricks)',
       'dead stems of herbs, grass etc', 'wood chips or mulch', 'unknown',
       'wood and roots of living trees',
       'living stems of herbs, grass etc', 'faeces', 'mosses', 'fungi',
       'insects', 'fruits', 'fire spot', 'catkins', 'lichens',
       'other substrate', 'calcareous stone', 'living flowers',
       'remains of vertebrates (e.g. feathers and fur)', 'liverworts',
       'mycetozoans'], dtype=object)

In [41]:
class_to_genus = np.zeros(len(train_metadata_genera['class_id'].unique()))
for species in train_metadata_genera['class_id'].unique():
    class_to_genus[species] = train_metadata_genera[train_metadata_genera['class_id'] == species]['genus_id'].unique()[0]

### Extracting Species distribution

In [42]:
class_priors = np.zeros(len(train_metadata_genera['class_id'].unique()))
for species in train_metadata_genera['class_id'].unique():
    class_priors[species] = len(train_metadata_genera[train_metadata_genera['class_id'] == species])

class_priors = class_priors/sum(class_priors)

### Extracting species-month distribution

In [43]:
month_distributions = {}

for _, observation in tqdm.tqdm(train_metadata_genera.iterrows(), total=len(train_metadata_genera)):
    month = str(observation.month)
    class_id = observation.class_id
    if month not in month_distributions:        
        month_distributions[month] = np.zeros(len(train_metadata_genera['class_id'].unique()))
    else:
        month_distributions[month][class_id] += 1

for key, value in month_distributions.items():
    month_distributions[key] = value / sum(value)

100%|██████████| 32753/32753 [00:04<00:00, 7303.11it/s]


### Extracting species-habitat distribution

In [44]:
habitat_distributions = {}

for _, observation in tqdm.tqdm(train_metadata_genera.iterrows(), total=len(train_metadata_genera)):
    habitat = observation.Habitat
    class_id = observation.class_id
    if habitat not in habitat_distributions:        
        habitat_distributions[habitat] = np.zeros(len(train_metadata_genera['class_id'].unique()))
    else:
        habitat_distributions[habitat][class_id] += 1

for key, value in habitat_distributions.items():
    habitat_distributions[key] = value / sum(value)

100%|██████████| 32753/32753 [00:03<00:00, 8288.71it/s]
  habitat_distributions[key] = value / sum(value)


### Extracting species-substrate distribution

In [45]:
substrate_distributions = {}

for _, observation in tqdm.tqdm(train_metadata_genera.iterrows(), total=len(train_metadata_genera)):
    substrate = observation.Substrate
    class_id = observation.class_id
    if substrate not in substrate_distributions:        
        substrate_distributions[substrate] = np.zeros(len(train_metadata_genera['class_id'].unique()))
    else:
        substrate_distributions[substrate][class_id] += 1

for key, value in substrate_distributions.items():
    substrate_distributions[key] = value / sum(value)

100%|██████████| 32753/32753 [00:04<00:00, 8167.25it/s]
  substrate_distributions[key] = value / sum(value)


## Predicting with trained network

In [None]:
def seed_torch(seed=777):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

SEED = 777
seed_torch(SEED)

In [10]:
class TestDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):

        file_path = self.df['image_path'].values[idx]
        label = self.df['class_id'].values[idx]
        month = self.df['month'].values[idx]
        sub = self.df['Substrate'].values[idx]
        hab = self.df['Habitat'].values[idx]
        
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label, file_path, month, hab, sub


In [11]:
WIDTH, HEIGHT = 224, 224

def get_transforms():

    return Compose([Resize(WIDTH, HEIGHT),
                    Normalize(mean = model_mean, std = model_std),
                    ToTensorV2()])

In [12]:
def getModel(architecture_name, target_size, pretrained = False):
    net = timm.create_model(architecture_name, pretrained=pretrained)
    net_cfg = net.default_cfg
    last_layer = net_cfg['classifier']
    num_ftrs = getattr(net, last_layer).in_features
    setattr(net, last_layer, nn.Linear(num_ftrs, target_size))
    return net

In [None]:
NUM_CLASSES = 182


MODEL_NAME = 'vit_base_patch16_224'
#MODEL_NAME = 'vit_large_patch16_384'
model = getModel(MODEL_NAME, NUM_CLASSES, pretrained=True)
model_mean = list(model.default_cfg['mean'])
model_std = list(model.default_cfg['std'])
#model = nn.DataParallel(model)

#READ STATE DICT!!!
model.load_state_dict(torch.load('../../checkpoints/DF20M-ViT_base_patch16_224-GENERA-100E.pth'))

model.to(device)
model.eval()

In [14]:
#test_dataset_genera = TestDataset(test_metadata_genera, test_metadata_genera['class_id'], transform=get_transforms())
test_dataset = TestDataset(test_metadata_genera, transform=get_transforms())

In [15]:
batch_size = 16

#test_loader_genera = DataLoader(test_dataset_genera, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [16]:
avg_val_loss = 0.
preds = np.zeros((len(test_metadata_genera)))
GT_lbls = []
image_paths = []
preds_raw = []
criterion = nn.CrossEntropyLoss()
months = []
subs = []
habitats = []

for i, (images, labels, paths, M, H, S) in enumerate(tqdm.tqdm(test_loader, total=len(test_loader))):

    images = images.to(device)
    labels = labels.to(device)
    
    with torch.no_grad():
        y_preds = model(images)
    preds[i * batch_size: (i+1) * batch_size] = y_preds.argmax(1).to('cpu').numpy()
    GT_lbls.extend(labels.to('cpu').numpy())
    preds_raw.extend(y_preds.to('cpu').numpy())
    image_paths.extend(paths)
    months.extend(M)
    subs.extend(S)
    habitats.extend(H)

vanilla_f1 = f1_score(test_metadata_genera['class_id'], preds, average='macro')
vanilla_accuracy = accuracy_score(test_metadata_genera['class_id'], preds)
vanilla_recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], preds_raw, k=3)

print('Vanilla:', vanilla_f1, vanilla_accuracy, vanilla_recall_3)

100%|██████████| 228/228 [01:36<00:00,  2.37it/s]

Vanilla: 0.586413701939184 0.6925824175824176 0.8653846153846154





### Weighting by Habitat

In [46]:
wrong_predictions_H = []
weighted_predictions_H = []
weighted_predictions_raw_H = []
prior_ratio_H = []

for lbl, preds, hab in tqdm.tqdm(zip(GT_lbls, preds_raw, habitats), total=len(GT_lbls)):
    
    habitat_dist = habitat_distributions[hab]
    preds = softmax(preds)
    
    p_habitat = (preds * habitat_dist) / sum(preds * habitat_dist)
    prior_ratio = p_habitat / class_priors
    max_index = np.argmax(prior_ratio * preds)        
    
    prior_ratio_H.append(prior_ratio)
    weighted_predictions_raw_H.append(prior_ratio * preds)
    weighted_predictions_H.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_H.append([lbl, hab])

f1 = f1_score(test_metadata_genera['class_id'], weighted_predictions_H, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], weighted_predictions_H)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], weighted_predictions_raw_H, k=3)
print('Habitat:', f1, accuracy, recall_3)
print('Habitat dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 6617.22it/s]


Habitat: 0.6152773343556828 0.7096153846153846 0.8763736263736264
Habitat dif: 0.028863632416498808 0.017032967032967083 0.01098901098901095


### Weighting by Substrate

In [47]:
wrong_predictions_S = []
weighted_predictions_S = []
weighted_predictions_raw_S = []
prior_ratio_S = []

for lbl, preds, sub in tqdm.tqdm(zip(GT_lbls, preds_raw, subs), total=len(GT_lbls)):

    substrate_dist = substrate_distributions[sub]
    preds = softmax(preds)
    
    p_substrate = (preds * substrate_dist) / sum(preds * substrate_dist)
    prior_ratio = p_substrate / class_priors
    max_index = np.argmax(prior_ratio * preds)     
    
    prior_ratio_S.append(prior_ratio)
    weighted_predictions_raw_S.append(prior_ratio * preds)
    weighted_predictions_S.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_S.append([lbl, sub])
        
f1 = f1_score(test_metadata_genera['class_id'], weighted_predictions_S, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], weighted_predictions_S)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], weighted_predictions_raw_S, k=3)
print('Substrate:', f1, accuracy, recall_3)
print('Substrate dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 6572.74it/s]

Substrate: 0.5971487376198034 0.7002747252747252 0.8673076923076923
Substrate dif: 0.010735035680619398 0.007692307692307665 0.0019230769230769162





### Weighting by Month

In [48]:
wrong_predictions_M = []
weighted_predictions_M = []
weighted_predictions_raw_M = []
prior_ratio_M = []

for lbl, preds, month in tqdm.tqdm(zip(GT_lbls, preds_raw, months), total=len(GT_lbls)):
    
    month_dist = month_distributions[str(float(month))]
    preds = softmax(preds)
    
    p_month = (preds * month_dist) / sum(preds * month_dist)
    prior_ratio = p_month / class_priors        
    max_index = np.argmax(prior_ratio * preds)     
    
    prior_ratio_M.append(prior_ratio)
    weighted_predictions_raw_M.append(prior_ratio * preds)
    weighted_predictions_M.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_M.append([lbl, month])

f1 = f1_score(test_metadata_genera['class_id'], weighted_predictions_M, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], weighted_predictions_M)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], weighted_predictions_raw_M, k=3)
print('Month:', f1, accuracy, recall_3)
print('Month dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 5532.44it/s]

Month: 0.600374301036281 0.701098901098901 0.8722527472527473
Month dif: 0.013960599097097015 0.008516483516483486 0.006868131868131844





### Weighting by Month and Substrate

In [49]:
from scipy.special import softmax

merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
        
    m_pred = (preds * m * s) / sum(preds * m * s)
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)
    merged_predictions.append(max_index)
    
f1 = f1_score(test_metadata_genera['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=3)
print('M+S:' , f1, accuracy, recall_3)
print('M+S dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 7405.34it/s]


M+S: 0.6063674679875247 0.7054945054945055 0.8733516483516484
M+S dif: 0.01995376604834076 0.012912087912087977 0.007967032967032939


### Weighting by Month and Habitat

In [50]:
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
    
    m_pred = (preds * m * h) / sum((preds * m * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata_genera['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=3)
print('M+H:', f1, accuracy, recall_3)    
print('M+H dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 7791.59it/s]


M+H: 0.6290117827056498 0.720054945054945 0.8854395604395604
M+H dif: 0.04259808076646587 0.027472527472527486 0.020054945054944984


### Weighting by Substrate and Habitat

In [51]:
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
    
    m_pred = (preds * s * h) / sum((preds * s * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata_genera['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=3)
print('S+H:' , f1, accuracy, recall_3)
print('S+H dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 7550.01it/s]


S+H: 0.6234002590810134 0.7145604395604396 0.8777472527472527
S+H dif: 0.03698655714182941 0.02197802197802201 0.012362637362637319


### Weighting by Month, Substrate and Habitat

In [52]:
from scipy.special import softmax

wrong_predictions_all = []
merged_predictions = []
merged_predictions_raw = []

wrong_predictions_all_genus = []
merged_predictions_genus = []

for lbl, img_path, o, m, s, h in tqdm.tqdm(zip(GT_lbls, image_paths, preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
 
    m_pred = (preds * m * s * h) / sum((preds * m * s * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)
    
    merged_predictions_genus.append(class_to_genus[max_index])
    
    if lbl != max_index:
        wrong_predictions_all.append([lbl, max_index, img_path])
    
        if class_to_genus[lbl] != class_to_genus[max_index]:
            wrong_predictions_all_genus.append([lbl, max_index, img_path])
            
f1 = f1_score(test_metadata_genera['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=3)
recall_5 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=5)
recall_10 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=10)

print('All:', f1, accuracy, recall_3, recall_5, recall_10)
print('All dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 7386.86it/s]


All: 0.6330054459258191 0.7214285714285714 0.8818681318681318 0.9197802197802197 0.9505494505494505
All dif: 0.046591743986635126 0.028846153846153855 0.016483516483516425


In [53]:
f1 = f1_score(test_metadata_genera['genus_id'], merged_predictions_genus, average='macro')
accuracy = accuracy_score(test_metadata_genera['genus_id'], merged_predictions_genus)
print('All:', f1, accuracy)

All: 0.9439925038189818 0.9543956043956044
