In [1]:
import gc
import os
import cv2
import sys
import json
import tqdm
import time
import timm
import torch
import random
import sklearn.metrics

from PIL import Image
from pathlib import Path
from functools import partial
from contextlib import contextmanager

import numpy as np
import scipy as sp
import pandas as pd
import torch.nn as nn

from torch.optim import Adam, SGD
from scipy.special import softmax
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

os.environ["CUDA_VISIBLE_DEVICES"]="1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Loading and Parsing Metadata

### Loading metadata files

In [3]:
train_metadata_genera = pd.read_csv("/Datasets/SvampeAtlas-14.12.2020/metadata/DanishFungi2020-Mini_train_metadata_DEV.csv")
test_metadata_genera = pd.read_csv("/Datasets/SvampeAtlas-14.12.2020/metadata/DanishFungi2020-Mini_test_metadata_DEV.csv")

In [4]:
print(len(train_metadata_genera), len(test_metadata_genera))

32753 3640


In [5]:
metadata = pd.concat([train_metadata_genera, test_metadata_genera])
len(metadata)

36393

In [7]:
test_metadata_genera.Habitat = test_metadata_genera.Habitat.replace(np.nan, 'unknown', regex=True)
test_metadata_genera.Substrate = test_metadata_genera.Substrate.replace(np.nan, 'unknown', regex=True)
# test_metadata.month = test_metadata.month.replace(np.nan, 'unknown', regex=True)

In [8]:
train_metadata_genera.Habitat = train_metadata_genera.Habitat.replace(np.nan, 'unknown', regex=True)
train_metadata_genera.Substrate = train_metadata_genera.Substrate.replace(np.nan, 'unknown', regex=True)
# metadata.month = metadata.month.replace(np.nan, 0, regex=True)

In [9]:
metadata.Substrate.unique()

array(['dead wood (including bark)', 'soil', 'bark', 'cones',
       'leaf or needle litter', 'wood', 'mosses', 'bark of living trees',
       nan, 'stems of herbs, grass etc', 'wood chips or mulch', 'catkins',
       'other substrate', 'dead stems of herbs, grass etc', 'peat mosses',
       'fungi', 'faeces', 'wood and roots of living trees',
       'living stems of herbs, grass etc', 'fruits', 'fire spot',
       'lichens', 'living leaves', 'building stone (e.g. bricks)',
       'liverworts'], dtype=object)

In [10]:
class_to_genus = np.zeros(len(train_metadata_genera['class_id'].unique()))
for species in train_metadata_genera['class_id'].unique():
    class_to_genus[species] = train_metadata_genera[train_metadata_genera['class_id'] == species]['genus_id'].unique()[0]

### Extracting Species distribution

In [11]:
class_priors = np.zeros(len(train_metadata_genera['class_id'].unique()))
for species in train_metadata_genera['class_id'].unique():
    class_priors[species] = len(train_metadata_genera[train_metadata_genera['class_id'] == species])

class_priors = class_priors/sum(class_priors)

### Extracting species-month distribution

In [12]:
month_distributions = {}

for _, observation in tqdm.tqdm(train_metadata_genera.iterrows(), total=len(train_metadata_genera)):
    month = str(observation.month)
    class_id = observation.class_id
    if month not in month_distributions:        
        month_distributions[month] = np.zeros(len(train_metadata_genera['class_id'].unique()))
    else:
        month_distributions[month][class_id] += 1

for key, value in month_distributions.items():
    month_distributions[key] = value / sum(value)

100%|██████████| 32753/32753 [00:02<00:00, 13737.60it/s]


### Extracting species-habitat distribution

In [13]:
habitat_distributions = {}

for _, observation in tqdm.tqdm(train_metadata_genera.iterrows(), total=len(train_metadata_genera)):
    habitat = observation.Habitat
    class_id = observation.class_id
    if habitat not in habitat_distributions:        
        habitat_distributions[habitat] = np.zeros(len(train_metadata_genera['class_id'].unique()))
    else:
        habitat_distributions[habitat][class_id] += 1

for key, value in habitat_distributions.items():
    habitat_distributions[key] = value / sum(value)

100%|██████████| 32753/32753 [00:02<00:00, 13774.56it/s]
  habitat_distributions[key] = value / sum(value)


### Extracting species-substrate distribution

In [14]:
substrate_distributions = {}

for _, observation in tqdm.tqdm(train_metadata_genera.iterrows(), total=len(train_metadata_genera)):
    substrate = observation.Substrate
    class_id = observation.class_id
    if substrate not in substrate_distributions:        
        substrate_distributions[substrate] = np.zeros(len(train_metadata_genera['class_id'].unique()))
    else:
        substrate_distributions[substrate][class_id] += 1

for key, value in substrate_distributions.items():
    substrate_distributions[key] = value / sum(value)

100%|██████████| 32753/32753 [00:02<00:00, 14006.37it/s]
  substrate_distributions[key] = value / sum(value)


## Predicting with trained network

In [15]:
def seed_torch(seed=777):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

SEED = 777
seed_torch(SEED)

In [16]:
class TestDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):

        file_path = self.df['image_path'].values[idx]
        label = self.df['class_id'].values[idx]
        month = self.df['month'].values[idx]
        sub = self.df['Substrate'].values[idx]
        hab = self.df['Habitat'].values[idx]
        
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label, file_path, month, hab, sub


In [17]:
WIDTH, HEIGHT = 224, 224

def get_transforms():

    return Compose([Resize(WIDTH, HEIGHT),
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],),
                    ToTensorV2()])

In [None]:
NUM_CLASSES = 182


from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b3')

model._fc = nn.Linear(model._fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load('../../checkpoints/DF20M-EfficientNet-B3-224-GENERA_best_accuracy.pth'))

model.to(device)
model.eval()

In [19]:
#test_dataset_genera = TestDataset(test_metadata_genera, test_metadata_genera['class_id'], transform=get_transforms())
test_dataset = TestDataset(test_metadata_genera, transform=get_transforms())

In [20]:
batch_size = 32

#test_loader_genera = DataLoader(test_dataset_genera, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [21]:
avg_val_loss = 0.
preds = np.zeros((len(test_metadata_genera)))
GT_lbls = []
image_paths = []
preds_raw = []
criterion = nn.CrossEntropyLoss()
months = []
subs = []
habitats = []

for i, (images, labels, paths, M, H, S) in enumerate(tqdm.tqdm(test_loader, total=len(test_loader))):

    images = images.to(device)
    labels = labels.to(device)
    
    with torch.no_grad():
        y_preds = model(images)
    preds[i * batch_size: (i+1) * batch_size] = y_preds.argmax(1).to('cpu').numpy()
    GT_lbls.extend(labels.to('cpu').numpy())
    preds_raw.extend(y_preds.to('cpu').numpy())
    image_paths.extend(paths)
    months.extend(M)
    subs.extend(S)
    habitats.extend(H)

vanilla_f1 = f1_score(test_metadata_genera['class_id'], preds, average='macro')
vanilla_accuracy = accuracy_score(test_metadata_genera['class_id'], preds)
vanilla_recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], preds_raw, k=3)

print('Vanilla:', vanilla_f1, vanilla_accuracy, vanilla_recall_3)

100%|██████████| 114/114 [00:48<00:00,  2.33it/s]

Vanilla: 0.5367488844867935 0.6689560439560439 0.8348901098901099





### Weighting by Habitat

In [22]:
wrong_predictions_H = []
weighted_predictions_H = []
weighted_predictions_raw_H = []
prior_ratio_H = []

for lbl, preds, hab in tqdm.tqdm(zip(GT_lbls, preds_raw, habitats), total=len(GT_lbls)):
    
    habitat_dist = habitat_distributions[hab]
    preds = softmax(preds)
    
    p_habitat = (preds * habitat_dist) / sum(preds * habitat_dist)
    prior_ratio = p_habitat / class_priors
    max_index = np.argmax(prior_ratio * preds)        
    
    prior_ratio_H.append(prior_ratio)
    weighted_predictions_raw_H.append(prior_ratio * preds)
    weighted_predictions_H.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_H.append([lbl, hab])

f1 = f1_score(test_metadata_genera['class_id'], weighted_predictions_H, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], weighted_predictions_H)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], weighted_predictions_raw_H, k=3)
print('Habitat:', f1, accuracy, recall_3)
print('Habitat dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 12736.48it/s]


Habitat: 0.5846332417443885 0.6881868131868132 0.8532967032967033
Habitat dif: 0.04788435725759499 0.019230769230769273 0.01840659340659334


### Weighting by Substrate

In [23]:
wrong_predictions_S = []
weighted_predictions_S = []
weighted_predictions_raw_S = []
prior_ratio_S = []

for lbl, preds, sub in tqdm.tqdm(zip(GT_lbls, preds_raw, subs), total=len(GT_lbls)):

    substrate_dist = substrate_distributions[sub]
    preds = softmax(preds)
    
    p_substrate = (preds * substrate_dist) / sum(preds * substrate_dist)
    prior_ratio = p_substrate / class_priors
    max_index = np.argmax(prior_ratio * preds)     
    
    prior_ratio_S.append(prior_ratio)
    weighted_predictions_raw_S.append(prior_ratio * preds)
    weighted_predictions_S.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_S.append([lbl, sub])
        
f1 = f1_score(test_metadata_genera['class_id'], weighted_predictions_S, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], weighted_predictions_S)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], weighted_predictions_raw_S, k=3)
print('Substrate:', f1, accuracy, recall_3)
print('Substrate dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 12306.43it/s]


Substrate: 0.5506785660491035 0.6755494505494506 0.8387362637362638
Substrate dif: 0.013929681562310003 0.006593406593406681 0.0038461538461538325


### Weighting by Month

In [24]:
wrong_predictions_M = []
weighted_predictions_M = []
weighted_predictions_raw_M = []
prior_ratio_M = []

for lbl, preds, month in tqdm.tqdm(zip(GT_lbls, preds_raw, months), total=len(GT_lbls)):
    
    month_dist = month_distributions[str(float(month))]
    preds = softmax(preds)
    
    p_month = (preds * month_dist) / sum(preds * month_dist)
    prior_ratio = p_month / class_priors        
    max_index = np.argmax(prior_ratio * preds)     
    
    prior_ratio_M.append(prior_ratio)
    weighted_predictions_raw_M.append(prior_ratio * preds)
    weighted_predictions_M.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_M.append([lbl, month])

f1 = f1_score(test_metadata_genera['class_id'], weighted_predictions_M, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], weighted_predictions_M)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], weighted_predictions_raw_M, k=3)
print('Month:', f1, accuracy, recall_3)
print('Month dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 12235.60it/s]


Month: 0.5681291828843478 0.6832417582417583 0.8467032967032967
Month dif: 0.031380298397554296 0.014285714285714346 0.011813186813186771


### Weighting by Month and Substrate

In [25]:
from scipy.special import softmax

merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
        
    m_pred = (preds * m * s) / sum(preds * m * s)
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)
    merged_predictions.append(max_index)
    
f1 = f1_score(test_metadata_genera['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=3)
print('M+S:' , f1, accuracy, recall_3)
print('M+S dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 12146.21it/s]


M+S: 0.5772118550629611 0.6887362637362637 0.8489010989010989
M+S dif: 0.040462970576167656 0.01978021978021982 0.014010989010988961


### Weighting by Month and Habitat

In [26]:
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
    
    m_pred = (preds * m * h) / sum((preds * m * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata_genera['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=3)
print('M+H:', f1, accuracy, recall_3)    
print('M+H dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 12294.01it/s]


M+H: 0.612410064050507 0.7057692307692308 0.8620879120879121
M+H dif: 0.07566117956371354 0.036813186813186904 0.027197802197802212


### Weighting by Substrate and Habitat

In [27]:
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
    
    m_pred = (preds * s * h) / sum((preds * s * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata_genera['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=3)
print('S+H:' , f1, accuracy, recall_3)
print('S+H dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 12827.93it/s]


S+H: 0.5927153578545372 0.6928571428571428 0.8546703296703296
S+H dif: 0.05596647336774374 0.023901098901098927 0.01978021978021971


### Weighting by Month, Substrate and Habitat

In [28]:
from scipy.special import softmax

wrong_predictions_all = []
merged_predictions = []
merged_predictions_raw = []

wrong_predictions_all_genus = []
merged_predictions_genus = []

for lbl, img_path, o, m, s, h in tqdm.tqdm(zip(GT_lbls, image_paths, preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
 
    m_pred = (preds * m * s * h) / sum((preds * m * s * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)
    
    merged_predictions_genus.append(class_to_genus[max_index])
    
    if lbl != max_index:
        wrong_predictions_all.append([lbl, max_index, img_path])
    
        if class_to_genus[lbl] != class_to_genus[max_index]:
            wrong_predictions_all_genus.append([lbl, max_index, img_path])
            
f1 = f1_score(test_metadata_genera['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata_genera['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=3)
recall_5 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=5)
recall_10 = top_k_accuracy_score(test_metadata_genera['class_id'], merged_predictions_raw, k=10)

print('All:', f1, accuracy, recall_3, recall_5, recall_10)
print('All dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 3640/3640 [00:00<00:00, 11460.59it/s]

All: 0.6099220252435588 0.7032967032967034 0.8596153846153847 0.9038461538461539 0.9436813186813187
All dif: 0.07317314075676529 0.03434065934065944 0.02472527472527475





In [29]:
f1 = f1_score(test_metadata_genera['genus_id'], merged_predictions_genus, average='macro')
accuracy = accuracy_score(test_metadata_genera['genus_id'], merged_predictions_genus)
print('All:', f1, accuracy)

All: 0.9332625307964019 0.9456043956043956
