### Imports and Jupyter setup

In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import tqdm
import torch
import wandb
import numpy as np
import pandas as pd
import torch.nn as nn

from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from timm.scheduler import CosineLRScheduler
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

# Restrict the process to the GPU with index 0 via the CUDA_VISIBLE_DEVICES variable.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Do not truncate the column list when DataFrames are displayed.
pd.set_option('display.max_columns', None)
# Cell output: the selected device.
device

### Custom Imports

In [None]:
from fgvc.utils.datasets import TrainDataset
from fgvc.utils.augmentations import test_transforms
# from fgvc.utils.utils import timer, init_logger, , 

from fgvc.utils.utils import timer, init_logger, seed_everything, getModel

In [None]:
!nvidia-smi

### Load Dataset Metadata

In [None]:
# All metadata CSVs live in the same directory relative to this notebook.
METADATA_DIR = "../../metadata"

# Train/validation splits of PlantCLEF 2018.
train_metadata = pd.read_csv(f"{METADATA_DIR}/PlantCLEF2018_train_metadata.csv")
val_metadata = pd.read_csv(f"{METADATA_DIR}/PlantCLEF2018_val_metadata.csv")

# Test sets of the 2017 and 2018 challenge editions.
PlantCLEF2017_test = pd.read_csv(f"{METADATA_DIR}/PlantCLEF2017_test_metadata.csv")
PlantCLEF2018_test = pd.read_csv(f"{METADATA_DIR}/PlantCLEF2018_test_metadata.csv")

# Report the size of each test set.
for split_name, split_frame in (
    ("PlantCLEF2017_test", PlantCLEF2017_test),
    ("PlantCLEF2018_test", PlantCLEF2018_test),
):
    print(f'Number of samples in {split_name}: {len(split_frame)}')

In [None]:
def _to_local_path(path):
    """Rewrite the relative dataset prefixes of *path* to the local absolute root."""
    # The longer prefix must be handled first: it contains the shorter one as a suffix.
    for relative_root in ('../../../nahouby/Datasets/PlantCLEF/', '../../nahouby/Datasets/PlantCLEF/'):
        path = path.replace(relative_root, '/local/nahouby/Datasets/PlantCLEF/')
    return path


val_metadata['image_path'] = val_metadata['image_path'].apply(_to_local_path)

### Training Parameters

In [None]:
# Guidance from the training setup: choose batch_size and accumulation_steps so that
# their product is 64.
# NOTE(review): the values below multiply to 128 (32 * 4), not 64 - confirm which is intended.

# Run description. The class and sample counts are derived from the loaded metadata frames.
config = {"augmentations": 'light-random_crop',
           "optimizer": 'SGD',
           "scheduler": 'cyclic_cosine',
           "image_size": (224, 224),
           "random_seed": 777,
           "number_of_classes": len(train_metadata['class_id'].unique()),
           "architecture": 'tf_efficientnetv2_s_in21k',
           "batch_size": 32,
           "accumulation_steps": 4,
           "epochs": 100,
           "learning_rate": 0.01,
           "dataset": 'PlantCLEF2018',
           "loss": 'CrossEntropyLoss',
           "training_samples": len(train_metadata),
           "valid_samples": len(val_metadata),
           "workers": 12,
           }

# Run identifier assembled from architecture, optimizer, scheduler and augmentation names.
RUN_NAME = f"{config['architecture']}-{config['optimizer']}-{config['scheduler']}-{config['augmentations']}"

### Fix Seeds

In [None]:
# Pass the configured seed (config['random_seed']) to the project's seeding helper.
seed_everything(config['random_seed'])

### Init Model

In [None]:
# %%
# Build the architecture with the dataset's number of classes and read the
# normalisation statistics from the model's default configuration.
model = getModel(config['architecture'], config['number_of_classes'], pretrained=True)
model_mean = list(model.default_cfg['mean'])
model_std = list(model.default_cfg['std'])

# Load the fine-tuned weights. map_location='cpu' lets a checkpoint that was saved
# from a GPU be restored on a CPU-only machine as well; the model is moved to
# `device` in a later cell.
model.load_state_dict(
    torch.load(
        './tf_efficientnetv2_s_in21k-SGD-cyclic_cosine-light-random_crop-100E.pth',
        map_location='cpu',
    )
)

In [None]:
# Evaluation pipeline: centre-crop transform normalised with the model's statistics.
crop_augmentations = test_transforms(
    data='center_crop',
    image_size=config['image_size'],
    mean=model_mean,
    std=model_std,
)

# One dataset per evaluated split, all sharing the same transform.
PlantCLEF2017_test_dataset_crop = TrainDataset(PlantCLEF2017_test, transform=crop_augmentations)
PlantCLEF2018_test_dataset_crop = TrainDataset(PlantCLEF2018_test, transform=crop_augmentations)
val_dataset = TrainDataset(val_metadata, transform=crop_augmentations)

# Every loader uses the same settings; shuffle stays off so that batch order
# follows the row order of the metadata frames.
loader_kwargs = {
    'batch_size': config['batch_size'],
    'shuffle': False,
    'num_workers': config['workers'],
}

PlantCLEF2017_test_loader_crop = DataLoader(PlantCLEF2017_test_dataset_crop, **loader_kwargs)
PlantCLEF2018_test_loader_crop = DataLoader(PlantCLEF2018_test_dataset_crop, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

In [None]:
# Move the model to the selected device, then switch it to evaluation mode.
model = model.to(device)
model.eval()

print('Model Loaded and set to Eval mode.')

In [None]:
from fgvc.utils.performance import test_loop_develop

### PlantCLEF 2017

In [None]:
# Evaluate the model on the PlantCLEF 2017 test split; the result is indexed by
# metric name (e.g. 'acc') when it is reported below.
performance_2017 = test_loop_develop(PlantCLEF2017_test, PlantCLEF2017_test_loader_crop, model, device)

In [None]:
# Report the PlantCLEF 2017 metrics. `performance_2017` is computed by the preceding
# cell, so the evaluation is not repeated here (this matches the 2018 report cell).
for metric_label, metric_key in (
    ('Accuracy', 'acc'),
    ('Obs. Accuracy (max logit)', 'max_logits_acc'),
    ('Obs. Accuracy (mean logits)', 'mean_logits_acc'),
    ('Obs. Accuracy (max softmax)', 'max_softmax_acc'),
    ('Obs. Accuracy (mean softmax)', 'mean_softmax_acc'),
):
    print(f'{metric_label}:', performance_2017[metric_key])

### PlantCLEF 2018

In [None]:
# Evaluate the model on the PlantCLEF 2018 test split; the result is indexed by
# metric name (e.g. 'acc') when it is reported below.
performance_2018 = test_loop_develop(PlantCLEF2018_test, PlantCLEF2018_test_loader_crop, model, device)

In [None]:
# Report the PlantCLEF 2018 metrics, one line per metric.
for metric_label, metric_key in (
    ('Accuracy', 'acc'),
    ('Obs. Accuracy (max logit)', 'max_logits_acc'),
    ('Obs. Accuracy (mean logits)', 'mean_logits_acc'),
    ('Obs. Accuracy (max softmax)', 'max_softmax_acc'),
    ('Obs. Accuracy (mean softmax)', 'mean_softmax_acc'),
):
    print(f'{metric_label}:', performance_2018[metric_key])

### Validation Performance

In [None]:
from fgvc.utils.performance import test_loop_insights

In [None]:
# Evaluate on the validation split. The helper returns a metrics mapping and a
# metadata frame that replaces `val_metadata`.
# NOTE(review): later cells read 'preds', 'max_softmax' and 'max_logits' from
# val_metadata - presumably added by this helper; confirm.
performance_val, val_metadata = test_loop_insights(val_metadata, val_loader, model, device)

In [None]:
# Report the validation metrics, one line per metric.
for metric_label, metric_key in (
    ('Accuracy', 'acc'),
    ('Obs. Accuracy (max logit)', 'max_logits_acc'),
    ('Obs. Accuracy (mean logits)', 'mean_logits_acc'),
    ('Obs. Accuracy (max softmax)', 'max_softmax_acc'),
    ('Obs. Accuracy (mean softmax)', 'mean_softmax_acc'),
):
    print(f'{metric_label}:', performance_val[metric_key])

In [None]:
# Replace every missing value in the frame (all columns) with the string 'unknown', in place.
val_metadata.fillna('unknown', inplace=True)
# Cell output: the distinct values of the Content column.
val_metadata.Content.unique()

In [None]:
# NOTE(review): `tmp` is not assigned at top level anywhere above this cell; it is only
# bound by the per-class loop cells further down, so this cell depends on
# out-of-order execution.
np.argmax(tmp.softmax)

In [None]:
def threshold_analysis(test_metadata, performance_threshold: float = 0.70, performance_step: float = 0.05):
    """Find, per class, the lowest max-softmax threshold that reaches a target accuracy.

    For every ``class_id`` in ``test_metadata`` the candidate thresholds
    ``np.arange(0.0, 1.0, performance_step)`` are tried in increasing order. The
    first one at which the samples of that class with ``max_softmax >= threshold``
    reach an accuracy of at least ``performance_threshold`` is stored for the
    class. A class for which no candidate qualifies gets the threshold 1.0.

    Args:
        test_metadata: DataFrame with ``class_id``, ``preds`` and ``max_softmax`` columns.
        performance_threshold: Minimum accuracy a class must reach.
        performance_step: Spacing between consecutive candidate thresholds.

    Returns:
        ``(class_tresholds, fraction)``: the per-class threshold dict and the share
        of all rows that pass the threshold of a class that found one. An empty
        ``test_metadata`` yields ``({}, 0.0)``.
    """
    # Nothing to analyse; also avoids dividing by zero at the end.
    if len(test_metadata) == 0:
        return {}, 0.0

    class_tresholds = {}
    classified_documents = 0

    candidate_thresholds = np.arange(0.0, 1.0, performance_step)
    class_ids = sorted(test_metadata.class_id.unique())

    for class_id in tqdm.tqdm(class_ids, total=len(class_ids)):
        # The class subset does not depend on the threshold, so select it once.
        class_metadata = test_metadata[test_metadata.class_id == class_id]

        for threshold in candidate_thresholds:
            tmp = class_metadata[class_metadata['max_softmax'] >= threshold]
            if len(tmp) == 0:
                continue

            vanilla_accuracy = accuracy_score(tmp['class_id'], tmp['preds'])
            if performance_threshold <= vanilla_accuracy:
                class_tresholds[class_id] = threshold
                classified_documents += len(tmp)
                break
        else:
            # No candidate reached the requested accuracy for this class.
            class_tresholds[class_id] = 1.0

    return class_tresholds, classified_documents / len(test_metadata)

In [None]:
# Per-class softmax thresholds on the validation split: 80% target accuracy, 0.05 grid.
class_tresholds, fraction = threshold_analysis(
    val_metadata,
    performance_threshold=0.8,
    performance_step=0.05,
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Histogram of the per-class thresholds (20 bins). The dict view is turned into a
# list first because a view is not a sequence and is not accepted as histogram
# input by every matplotlib version.
plt.hist(list(class_tresholds.values()), bins=20)

In [None]:
# Run the model over the PlantCLEF 2017 test loader and collect, per image,
# the arg-max class (`preds`) and the raw output vector (`preds_raw`).
preds = np.zeros((len(PlantCLEF2017_test)))
preds_raw = []
wrong_paths = []

# Write position in `preds`. It is advanced by the actual batch length, so a
# shorter final batch still lands in the right slots (i * len(images) is only a
# correct offset while every batch has the same length).
offset = 0

for images, _, _ in tqdm.tqdm(PlantCLEF2017_test_loader_crop, total=len(PlantCLEF2017_test_loader_crop)):

    images = images.to(device)

    with torch.no_grad():
        y_preds = model(images)

    batch_len = len(images)
    preds[offset:offset + batch_len] = y_preds.argmax(1).to('cpu').numpy()
    preds_raw.extend(y_preds.to('cpu').numpy())
    offset += batch_len

In [None]:
from scipy.special import softmax

In [None]:
# Store the raw outputs and arg-max predictions next to the metadata rows.
PlantCLEF2017_test['logits'] = preds_raw
PlantCLEF2017_test['preds'] = preds
# Per-image softmax vector and its largest entry.
PlantCLEF2017_test['softmax'] = list(map(softmax, PlantCLEF2017_test['logits']))
PlantCLEF2017_test['max_softmax'] = list(map(np.max, PlantCLEF2017_test['softmax']))

In [None]:
# Image-level accuracy: each row's predicted class (cast to int32) against its class_id.
accuracy_score(PlantCLEF2017_test['class_id'], PlantCLEF2017_test['preds'].astype('int32'))

In [None]:
# Observation-level prediction: within each ObservationId, sum the softmax vectors
# of its images and take the arg-max. An image whose confidence is below the
# threshold of its predicted class contributes a uniform vector instead.
PlantCLEF2017_test['observation_mean'] = None

ObservationIds = PlantCLEF2017_test.ObservationId.unique()

num_classes = config['number_of_classes']
uniform_probs = np.ones(num_classes) / num_classes

# A single groupby pass replaces re-filtering the whole frame for every observation.
observations = PlantCLEF2017_test.groupby('ObservationId', sort=False)

for obs_id, obs_images in tqdm.tqdm(observations, total=observations.ngroups):
    votes = [
        # A predicted class without a threshold entry falls back to 1.0, the value
        # threshold_analysis assigns to classes with no qualifying threshold.
        probs if confidence >= class_tresholds.get(pred, 1.0) else uniform_probs
        for probs, confidence, pred in zip(
            obs_images['softmax'], obs_images['max_softmax'], obs_images['preds']
        )
    ]
    max_index = np.argmax(np.sum(votes, axis=0))
    # Write the observation's class to all of its rows at once.
    PlantCLEF2017_test.loc[obs_images.index, 'observation_mean'] = max_index

In [None]:
# One row per observation: the first row of each ObservationId is kept.
test_metadata_obs = PlantCLEF2017_test.drop_duplicates(subset=['ObservationId'])

In [None]:
# Observation-level accuracy: aggregated prediction against class_id, one row per observation.
accuracy_score(test_metadata_obs['class_id'], test_metadata_obs['observation_mean'].astype('int32'))

In [None]:
# Image-level accuracy again (same expression as the earlier image-level cell), for comparison
# with the observation-level score.
accuracy_score(PlantCLEF2017_test['class_id'], PlantCLEF2017_test['preds'].astype('int32'))

In [None]:
PlantCLEF2017_test

In [None]:
# Share of validation rows contained in `selected_predictions`.
# NOTE(review): `selected_predictions` is first assigned in a later cell of this
# notebook, so this cell depends on out-of-order execution.
len(selected_predictions) / len(val_metadata)

In [None]:
# Accuracy over the rows in `selected_predictions`.
# NOTE(review): `selected_predictions` is first assigned in a later cell of this
# notebook, so this cell depends on out-of-order execution.
vanilla_accuracy = accuracy_score(selected_predictions['class_id'], selected_predictions['preds'])

In [None]:
vanilla_accuracy

In [None]:
def threshold_analysis_logits(test_metadata, performance_threshold: float = 0.70, num_steps: int = 20):
    """Find, per class, the lowest max-logit threshold that reaches a target accuracy.

    The candidate thresholds are ``num_steps`` evenly spaced values starting at
    ``min(max_logits) - 1`` and ending before ``max(max_logits) + 1``. For every
    ``class_id`` they are tried in increasing order; the first one at which the
    samples of that class with ``max_logits >= threshold`` reach an accuracy of
    at least ``performance_threshold`` is stored for the class.

    A class for which no candidate qualifies gets ``max(max_logits) + 1``, a value
    above every observed logit in ``test_metadata``, so none of its rows pass.
    (A fixed 1.0 is not an upper bound in logit space.)

    Args:
        test_metadata: DataFrame with ``class_id``, ``preds`` and ``max_logits`` columns.
        performance_threshold: Minimum accuracy a class must reach.
        num_steps: Number of candidate thresholds.

    Returns:
        ``(class_tresholds, fraction)``: the per-class threshold dict and the share
        of all rows that pass the threshold of a class that found one. An empty
        ``test_metadata`` yields ``({}, 0.0)``.
    """
    # Nothing to analyse; min()/max() and the final division would fail on empty input.
    if len(test_metadata) == 0:
        return {}, 0.0

    class_tresholds = {}
    classified_documents = 0

    min_logit = min(test_metadata['max_logits']) - 1
    max_logit = max(test_metadata['max_logits']) + 1

    # linspace fixes the number of candidates exactly; a float step in arange can
    # produce one element more or less because of rounding.
    candidate_thresholds = np.linspace(min_logit, max_logit, num_steps, endpoint=False)

    for class_id in sorted(test_metadata.class_id.unique()):
        # The class subset does not depend on the threshold, so select it once.
        class_metadata = test_metadata[test_metadata.class_id == class_id]

        for threshold in candidate_thresholds:
            tmp = class_metadata[class_metadata['max_logits'] >= threshold]
            if len(tmp) == 0:
                continue

            vanilla_accuracy = accuracy_score(tmp['class_id'], tmp['preds'])
            if performance_threshold <= vanilla_accuracy:
                class_tresholds[class_id] = threshold
                classified_documents += len(tmp)
                break
        else:
            # No candidate reached the requested accuracy for this class.
            class_tresholds[class_id] = max_logit

    return class_tresholds, classified_documents / len(test_metadata)

In [None]:
# Per-class logit thresholds on the validation split: 70% target accuracy, 20 candidates.
class_tresholds, fraction = threshold_analysis_logits(
    val_metadata,
    performance_threshold=0.7,
    num_steps=20,
)

In [None]:
# Histogram of the per-class thresholds (20 bins). The dict view is turned into a
# list first because a view is not a sequence and is not accepted as histogram
# input by every matplotlib version.
plt.hist(list(class_tresholds.values()), bins=20)

In [None]:
# For every class, keep the validation rows whose max logit reaches that class's
# threshold, then measure accuracy on the union of the kept rows.
class_ids = sorted(val_metadata.class_id.unique())
class_fractions = []

for class_id in tqdm.tqdm(class_ids, total=len(class_ids)):
    class_metadata = val_metadata.loc[val_metadata['class_id'] == class_id]
    passes_threshold = class_metadata['max_logits'] >= class_tresholds[class_id]
    tmp = class_metadata.loc[passes_threshold]
    class_fractions.append(tmp)

selected_predictions = pd.concat(class_fractions)
selected_predictions = selected_predictions.reset_index().drop(columns=['index', 'Unnamed: 0'])
vanilla_accuracy = accuracy_score(
    selected_predictions['class_id'],
    selected_predictions['preds'],
)

In [None]:
vanilla_accuracy

In [None]:
# Per-class softmax thresholds for a 10% target accuracy (default threshold grid).
class_tresholds, fraction = threshold_analysis(val_metadata, performance_threshold=0.1)

In [None]:
# Histogram of the per-class thresholds (20 bins). The dict view is turned into a
# list first because a view is not a sequence and is not accepted as histogram
# input by every matplotlib version.
plt.hist(list(class_tresholds.values()), bins=20)

In [None]:
# Sweep the target accuracy and record, for each target, the share of validation
# rows that pass their class threshold and the accuracy on the selected rows.
performance_step = 0.1

fractions = []
accuracies = []

performance_thresholds = np.arange(0.0, 1.0, performance_step)
class_ids = sorted(val_metadata.class_id.unique())

# total is the real number of sweep points (an int) rather than the float 1 / step.
for performance_threshold in tqdm.tqdm(performance_thresholds, total=len(performance_thresholds)):

    class_tresholds, fraction = threshold_analysis(val_metadata, performance_threshold, performance_step)

    class_fractions = []

    for class_id in class_ids:

        class_metadata = val_metadata[val_metadata.class_id == class_id]
        # threshold_analysis derives its thresholds from the 'max_softmax' column,
        # so the selection has to compare against that same column.
        tmp = class_metadata[class_metadata['max_softmax'] >= class_tresholds[class_id]]
        class_fractions.append(tmp)

    selected_predictions = pd.concat(class_fractions).reset_index().drop(columns=['index', 'Unnamed: 0'])
    vanilla_accuracy = accuracy_score(selected_predictions['class_id'], selected_predictions['preds'])

    fractions.append(fraction)
    accuracies.append(vanilla_accuracy)

In [None]:
accuracies

In [None]:
fractions

In [None]:
# Plot the classified fraction against the achieved accuracy and save the figure as a PDF.
ax = plt.gca()
ax.plot(accuracies, fractions, '-', linewidth=1, markersize=2)
ax.set_ylabel('Fraction of Classified documents')
ax.set_xlabel('Overall Accuracy.')
ax.set_xlim(0.55, 1.0)
ax.set_ylim(0.4, 1.0)
plt.tight_layout()
plt.savefig('accuracy_to_num.pdf', dpi=200)