In [1]:
import gc
import os
import cv2
import sys
import json
import tqdm
import time
import timm
import torch
import random
import sklearn.metrics

from PIL import Image
from pathlib import Path
from functools import partial
from contextlib import contextmanager

import numpy as np
import scipy as sp
import pandas as pd
import torch.nn as nn

from torch.optim import Adam, SGD
from scipy.special import softmax
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

os.environ["CUDA_VISIBLE_DEVICES"]="1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Loading and Parsing Metadata

### Loading metadata files

In [3]:
train_metadata = pd.read_csv("/Datasets/SvampeAtlas-14.12.2020/metadata/DanishFungi2020_train_metadata_DEV.csv")
test_metadata = pd.read_csv("/Datasets/SvampeAtlas-14.12.2020/metadata/DanishFungi2020_test_metadata_DEV.csv")
test_metadata_genera = pd.read_csv("/Datasets/SvampeAtlas-14.12.2020/metadata/DanishFungi2020-Mini_test_metadata_DEV.csv")

In [4]:
print(len(train_metadata), len(test_metadata), len(test_metadata_genera))

266344 29594 3640


In [5]:
metadata = pd.concat([train_metadata, test_metadata])
len(metadata)

295938

In [6]:
test_metadata.Habitat = test_metadata.Habitat.replace(np.nan, 'unknown', regex=True)
test_metadata.Substrate = test_metadata.Substrate.replace(np.nan, 'unknown', regex=True)
# test_metadata.month = test_metadata.month.replace(np.nan, 'unknown', regex=True)

In [7]:
metadata.Habitat = metadata.Habitat.replace(np.nan, 'unknown', regex=True)
metadata.Substrate = metadata.Substrate.replace(np.nan, 'unknown', regex=True)
# metadata.month = metadata.month.replace(np.nan, 0, regex=True)

In [8]:
metadata.Substrate.unique()

array(['bark of living trees', 'dead wood (including bark)', 'soil',
       'stone', 'leaf or needle litter', 'wood', 'living leaves', 'bark',
       'siliceous stone', 'cones', 'peat mosses',
       'stems of herbs, grass etc', 'building stone (e.g. bricks)',
       'dead stems of herbs, grass etc', 'wood chips or mulch', 'unknown',
       'wood and roots of living trees',
       'living stems of herbs, grass etc', 'faeces', 'mosses', 'fungi',
       'insects', 'fruits', 'fire spot', 'catkins', 'lichens',
       'other substrate', 'calcareous stone', 'living flowers',
       'remains of vertebrates (e.g. feathers and fur)', 'liverworts',
       'mycetozoans'], dtype=object)

In [9]:
class_to_genus = np.zeros(len(metadata['class_id'].unique()))
for species in metadata['class_id'].unique():
    class_to_genus[species] = metadata[metadata['class_id'] == species]['genus_id'].unique()[0]

### Extracting Species distribution

In [10]:
class_priors = np.zeros(len(metadata['class_id'].unique()))
for species in metadata['class_id'].unique():
    class_priors[species] = len(metadata[metadata['class_id'] == species])

class_priors = class_priors/sum(class_priors)

### Extracting species-month distribution

In [11]:
month_distributions = {}

for _, observation in tqdm.tqdm(metadata.iterrows(), total=len(metadata)):
    month = str(observation.month)
    class_id = observation.class_id
    if month not in month_distributions:        
        month_distributions[month] = np.zeros(len(metadata['class_id'].unique()))
    else:
        month_distributions[month][class_id] += 1

for key, value in month_distributions.items():
    month_distributions[key] = value / sum(value)

100%|██████████| 295938/295938 [00:21<00:00, 13837.09it/s]


### Extracting species-habitat distribution

In [12]:
habitat_distributions = {}

for _, observation in tqdm.tqdm(metadata.iterrows(), total=len(metadata)):
    habitat = observation.Habitat
    class_id = observation.class_id
    if habitat not in habitat_distributions:        
        habitat_distributions[habitat] = np.zeros(len(metadata['class_id'].unique()))
    else:
        habitat_distributions[habitat][class_id] += 1

for key, value in habitat_distributions.items():
    habitat_distributions[key] = value / sum(value)

100%|██████████| 295938/295938 [00:21<00:00, 13903.63it/s]


### Extracting species-substrate distribution

In [13]:
substrate_distributions = {}

for _, observation in tqdm.tqdm(metadata.iterrows(), total=len(metadata)):
    substrate = observation.Substrate
    class_id = observation.class_id
    if substrate not in substrate_distributions:        
        substrate_distributions[substrate] = np.zeros(len(metadata['class_id'].unique()))
    else:
        substrate_distributions[substrate][class_id] += 1

for key, value in substrate_distributions.items():
    substrate_distributions[key] = value / sum(value)

100%|██████████| 295938/295938 [00:21<00:00, 13832.36it/s]


## Predicting with trained network

In [14]:
def seed_torch(seed=777):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

SEED = 777
seed_torch(SEED)

In [15]:
class TestDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):

        file_path = self.df['image_path'].values[idx]
        label = self.df['class_id'].values[idx]
        month = self.df['month'].values[idx]
        sub = self.df['Substrate'].values[idx]
        hab = self.df['Habitat'].values[idx]
        
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label, file_path, month, hab, sub


In [16]:
WIDTH, HEIGHT = 224, 224

def get_transforms():

    return Compose([Resize(WIDTH, HEIGHT),
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],),
                    ToTensorV2()])

In [None]:
NUM_CLASSES = 1604


from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')

model._fc = nn.Linear(model._fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load('../../checkpoints/DF20-EfficientNet-B0_best_accuracy.pth'))

model.to(device)
model.eval()

In [18]:
#test_dataset_genera = TestDataset(test_metadata_genera, test_metadata_genera['class_id'], transform=get_transforms())
test_dataset = TestDataset(test_metadata, transform=get_transforms())

In [19]:
batch_size = 64

#test_loader_genera = DataLoader(test_dataset_genera, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [20]:
avg_val_loss = 0.
preds = np.zeros((len(test_metadata)))
GT_lbls = []
image_paths = []
preds_raw = []
criterion = nn.CrossEntropyLoss()
months = []
subs = []
habitats = []

for i, (images, labels, paths, M, H, S) in enumerate(tqdm.tqdm(test_loader, total=len(test_loader))):

    images = images.to(device)
    labels = labels.to(device)
    
    with torch.no_grad():
        y_preds = model(images)
    preds[i * batch_size: (i+1) * batch_size] = y_preds.argmax(1).to('cpu').numpy()
    GT_lbls.extend(labels.to('cpu').numpy())
    preds_raw.extend(y_preds.to('cpu').numpy())
    image_paths.extend(paths)
    months.extend(M)
    subs.extend(S)
    habitats.extend(H)

vanilla_f1 = f1_score(test_metadata['class_id'], preds, average='macro')
vanilla_accuracy = accuracy_score(test_metadata['class_id'], preds)
vanilla_recall_3 = top_k_accuracy_score(test_metadata['class_id'], preds_raw, k=3)

print('Vanilla:', vanilla_f1, vanilla_accuracy, vanilla_recall_3)

100%|██████████| 463/463 [05:57<00:00,  1.30it/s]


Vanilla: 0.5684921240794554 0.6590525106440495 0.8224640129756031


### Weighting by Habitat

In [21]:
wrong_predictions_H = []
weighted_predictions_H = []
weighted_predictions_raw_H = []
prior_ratio_H = []

for lbl, preds, hab in tqdm.tqdm(zip(GT_lbls, preds_raw, habitats), total=len(GT_lbls)):
    
    habitat_dist = habitat_distributions[hab]
    preds = softmax(preds)
    
    p_habitat = (preds * habitat_dist) / sum(preds * habitat_dist)
    prior_ratio = p_habitat / class_priors
    max_index = np.argmax(prior_ratio * preds)        
    
    prior_ratio_H.append(prior_ratio)
    weighted_predictions_raw_H.append(prior_ratio * preds)
    weighted_predictions_H.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_H.append([lbl, hab])

f1 = f1_score(test_metadata['class_id'], weighted_predictions_H, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], weighted_predictions_H)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], weighted_predictions_raw_H, k=3)
print('Habitat:', f1, accuracy, recall_3)
print('Habitat dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 29594/29594 [00:10<00:00, 2716.07it/s]


Habitat: 0.6121738396572888 0.6853078326687843 0.848077312968845
Habitat dif: 0.0436817155778334 0.026255322024734795 0.02561329999324191


### Weighting by Substrate

In [22]:
wrong_predictions_S = []
weighted_predictions_S = []
weighted_predictions_raw_S = []
prior_ratio_S = []

for lbl, preds, sub in tqdm.tqdm(zip(GT_lbls, preds_raw, subs), total=len(GT_lbls)):

    substrate_dist = substrate_distributions[sub]
    preds = softmax(preds)
    
    p_substrate = (preds * substrate_dist) / sum(preds * substrate_dist)
    prior_ratio = p_substrate / class_priors
    max_index = np.argmax(prior_ratio * preds)     
    
    prior_ratio_S.append(prior_ratio)
    weighted_predictions_raw_S.append(prior_ratio * preds)
    weighted_predictions_S.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_S.append([lbl, sub])
        
f1 = f1_score(test_metadata['class_id'], weighted_predictions_S, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], weighted_predictions_S)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], weighted_predictions_raw_S, k=3)
print('Substrate:', f1, accuracy, recall_3)
print('Substrate dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 29594/29594 [00:10<00:00, 2773.92it/s]


Substrate: 0.5994712645558894 0.6784821247550179 0.8401703047915118
Substrate dif: 0.03097914047643402 0.01942961411096844 0.01770629181590866


### Weighting by Month

In [23]:
wrong_predictions_M = []
weighted_predictions_M = []
weighted_predictions_raw_M = []
prior_ratio_M = []

for lbl, preds, month in tqdm.tqdm(zip(GT_lbls, preds_raw, months), total=len(GT_lbls)):
    
    month_dist = month_distributions[str(float(month))]
    preds = softmax(preds)
    
    p_month = (preds * month_dist) / sum(preds * month_dist)
    prior_ratio = p_month / class_priors        
    max_index = np.argmax(prior_ratio * preds)     
    
    prior_ratio_M.append(prior_ratio)
    weighted_predictions_raw_M.append(prior_ratio * preds)
    weighted_predictions_M.append(max_index)
    
    if lbl != max_index:
        wrong_predictions_M.append([lbl, month])

f1 = f1_score(test_metadata['class_id'], weighted_predictions_M, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], weighted_predictions_M)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], weighted_predictions_raw_M, k=3)
print('Month:', f1, accuracy, recall_3)
print('Month dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 29594/29594 [00:12<00:00, 2325.92it/s]


Month: 0.5909418742168571 0.6750692707981347 0.8390214232614719
Month dif: 0.02244975013740169 0.016016760154085263 0.016557410285868768


### Weighting by Month and Substrate

In [24]:
from scipy.special import softmax

merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
        
    m_pred = (preds * m * s) / sum(preds * m * s)
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)
    merged_predictions.append(max_index)
    
f1 = f1_score(test_metadata['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=3)
print('M+S:' , f1, accuracy, recall_3)
print('M+S dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 29594/29594 [00:13<00:00, 2239.21it/s]


M+S: 0.6223293544675633 0.6940933973102656 0.8564911806447253
M+S dif: 0.05383723038810795 0.03504088666621619 0.03402716766912217


### Weighting by Month and Habitat

In [25]:
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
    
    m_pred = (preds * m * h) / sum((preds * m * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=3)
print('M+H:', f1, accuracy, recall_3)    
print('M+H dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 29594/29594 [00:12<00:00, 2327.14it/s]


M+H: 0.6387567782394821 0.701358383456106 0.8627762384267081
M+H dif: 0.0702646541600267 0.042305872812056555 0.04031222545110502


### Weighting by Substrate and Habitat

In [26]:
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
    
    m_pred = (preds * s * h) / sum((preds * s * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=3)
print('S+H:' , f1, accuracy, recall_3)
print('S+H dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 29594/29594 [00:12<00:00, 2317.67it/s]


S+H: 0.6418046227775008 0.7031830776508752 0.8640264918564574
S+H dif: 0.07331249869804546 0.044130567006825716 0.04156247888085429


### Weighting by Month, Substrate and Habitat

In [27]:
from scipy.special import softmax

wrong_predictions_all = []
merged_predictions = []
merged_predictions_raw = []

wrong_predictions_all_genus = []
merged_predictions_genus = []

for lbl, img_path, o, m, s, h in tqdm.tqdm(zip(GT_lbls, image_paths, preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):
    
    preds = softmax(preds)
 
    m_pred = (preds * m * s * h) / sum((preds * m * s * h))
    max_index = np.argmax(m_pred)
    
    merged_predictions_raw.append(m_pred)    
    merged_predictions.append(max_index)
    
    merged_predictions_genus.append(class_to_genus[max_index])
    
    if lbl != max_index:
        wrong_predictions_all.append([lbl, max_index, img_path])
    
        if class_to_genus[lbl] != class_to_genus[max_index]:
            wrong_predictions_all_genus.append([lbl, max_index, img_path])
            
f1 = f1_score(test_metadata['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=3)
recall_5 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=5)
recall_10 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=10)

print('All:', f1, accuracy, recall_3, recall_5, recall_10)
print('All dif:', f1-vanilla_f1, accuracy-vanilla_accuracy, recall_3-vanilla_recall_3)

100%|██████████| 29594/29594 [00:10<00:00, 2700.81it/s]


All: 0.655772280031751 0.710988713928499 0.8724065688991012 0.9120091910522403 0.9463404744204906
All dif: 0.08728015595229566 0.05193620328444959 0.049942555923498055


In [28]:
f1 = f1_score(test_metadata['genus_id'], merged_predictions_genus, average='macro')
accuracy = accuracy_score(test_metadata['genus_id'], merged_predictions_genus)
print('All:', f1, accuracy)

All: 0.734368922922385 0.8061093464891532
