In [None]:
import gc
import os
import cv2
import sys
import json
import tqdm
import time
import timm
import torch
import random
import sklearn.metrics

from PIL import Image
from pathlib import Path
from functools import partial
from contextlib import contextmanager

import numpy as np
import scipy as sp
import pandas as pd
import torch.nn as nn

from torch.optim import Adam, SGD
from scipy.special import softmax
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from albumentations import Compose, Normalize, Resize
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Shell out to `nvidia-smi` to display the GPU status of this machine.
!nvidia-smi

## Loading and Parsing Metadata

### Loading metadata files

In [None]:
# Read the DF20 train and test metadata tables (one row per image).
train_metadata, test_metadata = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../../../resources/DF20/DanishFungi2020_train_metadata_DEV.csv",
        "../../../resources/DF20/DanishFungi2020_test_metadata_DEV.csv",
    )
)

In [None]:
# Root directory of the local DF20 image copy and the directory marker that
# precedes the relative image path in the metadata CSV.
DF20_IMAGE_ROOT = '/local/nahouby/Datasets/DF20/'
SOURCE_DIR_MARKER = '/SvampeAtlas-14.12.2020/'


def _localize_image_path(path):
    """Re-root a metadata image path onto DF20_IMAGE_ROOT with a '.JPG' suffix.

    Everything up to and including SOURCE_DIR_MARKER is replaced by
    DF20_IMAGE_ROOT, then everything from the first '.' in the resulting
    path onwards is replaced by '.JPG'.

    Paths that already start with DF20_IMAGE_ROOT are not prefixed again, so
    re-running this cell leaves the column unchanged instead of stacking the
    root directory a second time.
    """
    if not path.startswith(DF20_IMAGE_ROOT):
        path = DF20_IMAGE_ROOT + path.split(SOURCE_DIR_MARKER)[-1]
    return path.split('.')[0] + '.JPG'


# Only the column itself is needed, so map over it instead of a row-wise apply.
test_metadata['image_path'] = test_metadata['image_path'].map(_localize_image_path)

In [None]:
# Number of rows in the train and test metadata tables.
print(*(len(split) for split in (train_metadata, test_metadata)))

In [None]:
# Stack train and test rows into one table; the class/genus lookup and the
# class-frequency statistics below are computed over this combined table.
metadata = pd.concat((train_metadata, test_metadata), axis=0)
len(metadata)

In [None]:
# Missing Habitat / Substrate values become the explicit category 'unknown'
# so they can be used as dictionary keys later on. `month` is not filled here.
for column in ('Habitat', 'Substrate'):
    test_metadata[column] = test_metadata[column].fillna('unknown')

In [None]:
# Missing Habitat / Substrate values become the explicit category 'unknown'
# so they can be used as dictionary keys later on. `month` is not filled here.
for column in ('Habitat', 'Substrate'):
    metadata[column] = metadata[column].fillna('unknown')

In [None]:
# Distinct substrate categories present in the combined metadata.
metadata['Substrate'].unique()

In [None]:
# Lookup array: class_to_genus[class_id] -> genus_id of the first metadata row
# with that class_id. class_id values are used directly as array indices.
# A single drop_duplicates pass replaces one full-table scan per class.
first_rows_per_class = metadata.drop_duplicates(subset='class_id')
class_to_genus = np.zeros(len(first_rows_per_class))
class_to_genus[first_rows_per_class['class_id'].to_numpy()] = first_rows_per_class['genus_id'].to_numpy()

### Extracting Species distribution

In [None]:
# Relative frequency of every class in the combined metadata, indexed by
# class_id. value_counts() counts all classes in one pass instead of scanning
# the whole table once per class.
class_counts = metadata['class_id'].value_counts()
class_priors = np.zeros(len(class_counts))
class_priors[class_counts.index.to_numpy()] = class_counts.to_numpy()

class_priors = class_priors / class_priors.sum()

### Extracting species-month distribution

In [None]:
# month_distributions[month_key] is the class distribution observed in that
# month, indexed by class_id. Keys are str(float(month)), the same form the
# month-weighting cell uses for its lookup.
n_classes = len(metadata['class_id'].unique())
month_distributions = {}

for month, class_id in tqdm.tqdm(zip(metadata['month'], metadata['class_id']), total=len(metadata)):
    month_key = str(float(month))
    if month_key not in month_distributions:
        month_distributions[month_key] = np.zeros(n_classes)
    # Count every observation, including the one that created the entry
    # (previously the first observation of each key was silently dropped).
    month_distributions[month_key][class_id] += 1

# Turn the per-month counts into probability distributions.
month_distributions = {key: counts / counts.sum() for key, counts in month_distributions.items()}

### Extracting species-habitat distribution

In [None]:
# habitat_distributions[habitat] is the class distribution observed in that
# habitat, indexed by class_id.
n_classes = len(metadata['class_id'].unique())
habitat_distributions = {}

for habitat, class_id in tqdm.tqdm(zip(metadata['Habitat'], metadata['class_id']), total=len(metadata)):
    if habitat not in habitat_distributions:
        habitat_distributions[habitat] = np.zeros(n_classes)
    # Count every observation, including the one that created the entry
    # (previously the first observation of each key was silently dropped).
    habitat_distributions[habitat][class_id] += 1

# Turn the per-habitat counts into probability distributions.
habitat_distributions = {key: counts / counts.sum() for key, counts in habitat_distributions.items()}

### Extracting species-substrate distribution

In [None]:
# substrate_distributions[substrate] is the class distribution observed on
# that substrate, indexed by class_id.
n_classes = len(metadata['class_id'].unique())
substrate_distributions = {}

for substrate, class_id in tqdm.tqdm(zip(metadata['Substrate'], metadata['class_id']), total=len(metadata)):
    if substrate not in substrate_distributions:
        substrate_distributions[substrate] = np.zeros(n_classes)
    # Count every observation, including the one that created the entry
    # (previously the first observation of each key was silently dropped).
    substrate_distributions[substrate][class_id] += 1

# Turn the per-substrate counts into probability distributions.
substrate_distributions = {key: counts / counts.sum() for key, counts in substrate_distributions.items()}

## Predicting with trained network

In [None]:
def seed_torch(seed=777):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs with `seed`.

    Also exports the seed as PYTHONHASHSEED and switches cuDNN to
    deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


SEED = 777
seed_torch(seed=SEED)

In [None]:
class TestDataset(Dataset):
    """Dataset over a metadata DataFrame that yields images with their metadata.

    Args:
        df: DataFrame with the columns 'image_path', 'class_id', 'month',
            'Substrate' and 'Habitat'.
        transform: optional callable invoked as ``transform(image=image)``
            that returns a mapping with the transformed image under 'image'.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at positional index `idx`.

        Returns:
            Tuple ``(image, label, file_path, month, hab, sub)``; the image is
            converted from BGR to RGB and passed through `transform` if set.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        file_path = self.df['image_path'].values[idx]
        label = self.df['class_id'].values[idx]
        month = self.df['month'].values[idx]
        sub = self.df['Substrate'].values[idx]
        hab = self.df['Habitat'].values[idx]

        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label, file_path, month, hab, sub


In [None]:
WIDTH, HEIGHT = 384, 384

def get_transforms():
    """Build the test-time preprocessing pipeline.

    Resizes to HEIGHT x WIDTH, normalises with the module-level `model_mean` /
    `model_std` (read when this function is called, so they must be defined
    before the call) and converts the image to a tensor.
    """
    # Resize takes (height, width); pass by keyword so the order cannot be
    # swapped if WIDTH and HEIGHT ever differ.
    return Compose([Resize(height=HEIGHT, width=WIDTH),
                    Normalize(mean = model_mean, std = model_std),
                    ToTensorV2()])

def getModel(architecture_name, target_size, pretrained = False):
    """Create a timm model whose classifier layer has `target_size` outputs.

    The attribute named by ``default_cfg['classifier']`` is replaced with a
    new ``nn.Linear`` that keeps the original layer's input width.
    """
    backbone = timm.create_model(architecture_name, pretrained=pretrained)
    head_name = backbone.default_cfg['classifier']
    in_features = getattr(backbone, head_name).in_features
    setattr(backbone, head_name, nn.Linear(in_features, target_size))
    return backbone

In [None]:
N_CLASSES = len(train_metadata['class_id'].unique())

MODEL_NAME = 'vit_large_patch16_384'
CHECKPOINT_PATH = '../../../checkpoints/DF20-ViT_large_patch16_384-100E.pth'

# pretrained=False: every weight is overwritten by the (strict) checkpoint load
# below, so downloading the pretrained weights first would be wasted work.
model = getModel(MODEL_NAME, N_CLASSES, pretrained=False)
model_mean = list(model.default_cfg['mean'])
model_std = list(model.default_cfg['std'])

# map_location lets the checkpoint load on whatever device was selected above,
# including CPU-only hosts.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

model.to(device)
model.eval()
print("Done.")

In [None]:
BATCH_SIZE = 16

# Unshuffled loader: predictions come out in the same order as test_metadata.
test_dataset = TestDataset(df=test_metadata, transform=get_transforms())
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

In [None]:
avg_val_loss = 0.
criterion = nn.CrossEntropyLoss()

# Arg-max class per test image, filled batch by batch.
preds = np.zeros((len(test_metadata)))

# Per-sample records collected in loader order.
GT_lbls, image_paths, preds_raw = [], [], []
months, subs, habitats = [], [], []

for i, (images, labels, paths, M, H, S) in enumerate(tqdm.tqdm(test_loader, total=len(test_loader))):

    images, labels = images.to(device), labels.to(device)

    with torch.no_grad():
        y_preds = model(images)

    batch_start = i * BATCH_SIZE
    preds[batch_start: batch_start + BATCH_SIZE] = y_preds.argmax(1).cpu().numpy()

    GT_lbls.extend(labels.cpu().numpy())
    preds_raw.extend(y_preds.cpu().numpy())   # raw logits, one vector per image
    image_paths.extend(paths)
    months.extend(M)
    subs.extend(S)
    habitats.extend(H)

# Baseline metrics of the network without any metadata weighting.
test_labels = test_metadata['class_id']
vanilla_f1 = f1_score(test_labels, preds, average='macro')
vanilla_accuracy = accuracy_score(test_labels, preds)
vanilla_recall_3, vanilla_recall_5, vanilla_recall_10 = (
    top_k_accuracy_score(test_labels, preds_raw, k=k) for k in (3, 5, 10)
)

print('Vanilla:', vanilla_f1, vanilla_accuracy, vanilla_recall_3, vanilla_recall_5, vanilla_recall_10)

### Weighting by Habitat

In [None]:
wrong_predictions_H, weighted_predictions_H = [], []
weighted_predictions_raw_H, prior_ratio_H = [], []

for lbl, preds, hab in tqdm.tqdm(zip(GT_lbls, preds_raw, habitats), total=len(GT_lbls)):

    preds = softmax(preds)

    # Softmax output re-weighted by the class distribution of this habitat,
    # renormalised, then divided by the global class priors.
    habitat_joint = preds * habitat_distributions[hab]
    p_habitat = habitat_joint / sum(habitat_joint)
    prior_ratio = p_habitat / class_priors

    reweighted = prior_ratio * preds
    max_index = np.argmax(reweighted)

    prior_ratio_H.append(prior_ratio)
    weighted_predictions_raw_H.append(reweighted)
    weighted_predictions_H.append(max_index)

    # Keep the misclassified samples together with their habitat.
    if max_index != lbl:
        wrong_predictions_H.append([lbl, hab])

test_labels = test_metadata['class_id']
f1 = f1_score(test_labels, weighted_predictions_H, average='macro')
accuracy = accuracy_score(test_labels, weighted_predictions_H)
recall_3 = top_k_accuracy_score(test_labels, weighted_predictions_raw_H, k=3)
print('Habitat:', f1, accuracy, recall_3)
print('Habitat dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))

### Weighting by Substrate

In [None]:
wrong_predictions_S, weighted_predictions_S = [], []
weighted_predictions_raw_S, prior_ratio_S = [], []

for lbl, preds, sub in tqdm.tqdm(zip(GT_lbls, preds_raw, subs), total=len(GT_lbls)):

    preds = softmax(preds)

    # Softmax output re-weighted by the class distribution of this substrate,
    # renormalised, then divided by the global class priors.
    substrate_joint = preds * substrate_distributions[sub]
    p_substrate = substrate_joint / sum(substrate_joint)
    prior_ratio = p_substrate / class_priors

    reweighted = prior_ratio * preds
    max_index = np.argmax(reweighted)

    prior_ratio_S.append(prior_ratio)
    weighted_predictions_raw_S.append(reweighted)
    weighted_predictions_S.append(max_index)

    # Keep the misclassified samples together with their substrate.
    if max_index != lbl:
        wrong_predictions_S.append([lbl, sub])

test_labels = test_metadata['class_id']
f1 = f1_score(test_labels, weighted_predictions_S, average='macro')
accuracy = accuracy_score(test_labels, weighted_predictions_S)
recall_3 = top_k_accuracy_score(test_labels, weighted_predictions_raw_S, k=3)
print('Substrate:', f1, accuracy, recall_3)
print('Substrate dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))

### Weighting by Month

In [None]:
wrong_predictions_M, weighted_predictions_M = [], []
weighted_predictions_raw_M, prior_ratio_M = [], []

for lbl, preds, month in tqdm.tqdm(zip(GT_lbls, preds_raw, months), total=len(GT_lbls)):

    preds = softmax(preds)

    # Softmax output re-weighted by the class distribution of this month,
    # renormalised, then divided by the global class priors.
    # Distribution keys are the string form of the month as a float.
    month_joint = preds * month_distributions[str(float(month))]
    p_month = month_joint / sum(month_joint)
    prior_ratio = p_month / class_priors

    reweighted = prior_ratio * preds
    max_index = np.argmax(reweighted)

    prior_ratio_M.append(prior_ratio)
    weighted_predictions_raw_M.append(reweighted)
    weighted_predictions_M.append(max_index)

    # Keep the misclassified samples together with their month.
    if max_index != lbl:
        wrong_predictions_M.append([lbl, month])

test_labels = test_metadata['class_id']
f1 = f1_score(test_labels, weighted_predictions_M, average='macro')
accuracy = accuracy_score(test_labels, weighted_predictions_M)
recall_3 = top_k_accuracy_score(test_labels, weighted_predictions_raw_M, k=3)
print('Month:', f1, accuracy, recall_3)
print('Month dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))

### Weighting by Month and Substrate

In [None]:
# Combine the month and substrate prior ratios with the network output.
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):

    # Softmax of THIS sample's logits. Previously a stale `preds` vector left
    # over from an earlier cell was re-softmaxed on every iteration, so the
    # per-sample network output `o` never entered the computation here.
    probs = softmax(o)

    weighted = probs * m * s
    m_pred = weighted / weighted.sum()
    max_index = np.argmax(m_pred)

    merged_predictions_raw.append(m_pred)
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=3)
print('M+S:' , f1, accuracy, recall_3)
print('M+S dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))

### Weighting by Month and Habitat

In [None]:
# Combine the month and habitat prior ratios with the network output.
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):

    # Softmax of THIS sample's logits. Previously a stale `preds` vector left
    # over from an earlier cell was re-softmaxed on every iteration, so the
    # per-sample network output `o` never entered the computation here.
    probs = softmax(o)

    weighted = probs * m * h
    m_pred = weighted / weighted.sum()
    max_index = np.argmax(m_pred)

    merged_predictions_raw.append(m_pred)
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=3)
recall_5 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=5)
recall_10 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=10)

print('M+H:', f1, accuracy, recall_3, recall_5, recall_10)
print('M+H dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))

### Weighting by Substrate and Habitat

In [None]:
# Combine the substrate and habitat prior ratios with the network output.
merged_predictions = []
merged_predictions_raw = []

for o, m, s, h in tqdm.tqdm(zip(preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):

    # Softmax of THIS sample's logits. Previously a stale `preds` vector left
    # over from an earlier cell was re-softmaxed on every iteration, so the
    # per-sample network output `o` never entered the computation here.
    probs = softmax(o)

    weighted = probs * s * h
    m_pred = weighted / weighted.sum()
    max_index = np.argmax(m_pred)

    merged_predictions_raw.append(m_pred)
    merged_predictions.append(max_index)

f1 = f1_score(test_metadata['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=3)

print('S+H:' , f1, accuracy, recall_3)
print('S+H dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))

### Weighting by Month, Substrate and Habitat

In [None]:
# Combine the month, substrate and habitat prior ratios with the network
# output, and derive genus-level predictions from the chosen species.
wrong_predictions_all = []
merged_predictions = []
merged_predictions_raw = []

wrong_predictions_all_genus = []
merged_predictions_genus = []

for lbl, img_path, o, m, s, h in tqdm.tqdm(zip(GT_lbls, image_paths, preds_raw, prior_ratio_M, prior_ratio_S, prior_ratio_H), total=len(prior_ratio_M)):

    # Softmax of THIS sample's logits. Previously a stale `preds` vector left
    # over from an earlier cell was re-softmaxed on every iteration, so the
    # per-sample network output `o` never entered the computation here.
    probs = softmax(o)

    weighted = probs * m * s * h
    m_pred = weighted / weighted.sum()
    max_index = np.argmax(m_pred)

    merged_predictions_raw.append(m_pred)
    merged_predictions.append(max_index)

    # Genus of the predicted species.
    merged_predictions_genus.append(class_to_genus[max_index])

    if lbl != max_index:
        wrong_predictions_all.append([lbl, max_index, img_path])

        # Species errors that also land in the wrong genus.
        if class_to_genus[lbl] != class_to_genus[max_index]:
            wrong_predictions_all_genus.append([lbl, max_index, img_path])

f1 = f1_score(test_metadata['class_id'], merged_predictions, average='macro')
accuracy = accuracy_score(test_metadata['class_id'], merged_predictions)
recall_3 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=3)
recall_5 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=5)
recall_10 = top_k_accuracy_score(test_metadata['class_id'], merged_predictions_raw, k=10)

print('All:', f1, accuracy, recall_3, recall_5, recall_10)
print('All dif:', np.around(f1-vanilla_f1, 3), np.around((accuracy-vanilla_accuracy) * 100, 2), np.around((recall_3-vanilla_recall_3)*100, 2))

In [None]:
# Genus-level macro F1 and accuracy of the fully weighted predictions,
# reported as percentages rounded to two decimals.
f1 = f1_score(test_metadata['genus_id'], merged_predictions_genus, average='macro')
accuracy = accuracy_score(test_metadata['genus_id'], merged_predictions_genus)
# The closing parenthesis used to sit before the `2`, which rounded accuracy to
# zero decimals and printed a stray literal 2.
print('Genera lvl performance:', np.around(f1*100, 2), np.around(accuracy*100, 2))