### Imports and Jupyter setup

In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import tqdm
import torch
import wandb
import numpy as np
import pandas as pd
import torch.nn as nn

from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from timm.scheduler import CosineLRScheduler
from sklearn.metrics import f1_score, accuracy_score, top_k_accuracy_score

# Expose only GPU 0 to CUDA, pick the compute device (falling back to CPU when
# no GPU is visible), and let pandas display every column of wide frames.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pd.set_option("display.max_columns", None)
device  # echo the selected device as the cell output

### Custom Imports

In [None]:
from fgvc.utils.datasets import TrainDataset
from fgvc.utils.augmentations import test_transforms
# from fgvc.utils.utils import timer, init_logger, , 

from fgvc.utils.utils import timer, init_logger, seed_everything, getModel

In [None]:
# IPython shell escape: print GPU status/memory before the model is loaded.
!nvidia-smi

### Load Dataset Metadata

In [None]:
# Train/val metadata: in this section they are only read to fill `config`
# (number of classes, sample counts).
train_metadata = pd.read_csv("../../metadata/PlantCLEF2018_train_metadata.csv")
val_metadata = pd.read_csv("../../metadata/PlantCLEF2018_val_metadata.csv")

# Test sets that the model is evaluated on below.
PlantCLEF2017_test = pd.read_csv("../../metadata/PlantCLEF2017_test_metadata.csv")
PlantCLEF2018_test = pd.read_csv("../../metadata/PlantCLEF2018_test_metadata.csv")

# Label fixed: the second line previously reported the 2018 count as "2017".
print(f'Number of samples in PlantCLEF2017_test: {len(PlantCLEF2017_test)}')
print(f'Number of samples in PlantCLEF2018_test: {len(PlantCLEF2018_test)}')

In [None]:
# Rewrite image paths recorded on the original machine so they point at this
# machine's data root. One literal (non-regex) vectorized replace per frame is
# enough; the previous duplicated second pass was redundant. Unlike the per-row
# lambda, .str.replace leaves missing (NaN) paths as NaN instead of raising.
for _test_df in (PlantCLEF2017_test, PlantCLEF2018_test):
    _test_df['image_path'] = _test_df['image_path'].str.replace(
        '/local/nahouby/Datasets/PlantCLEF/', '/Data-10T/PlantCLEF/', regex=False
    )
del _test_df

### Training Parameters

In [None]:
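# Effective training batch size = batch_size * accumulation_steps.
# NOTE(review): an earlier note required this product to be 64, but the values below give 32 * 4 = 128 — confirm which setting the loaded checkpoint was trained with.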

# Run configuration. The training-time hyperparameters are kept so that the run
# name (and the checkpoint it identifies) can be rebuilt. In this notebook the
# architecture, class count, image size, batch size and worker count are what
# drive model construction and test-time loading.
config = dict(
    augmentations='light-random_crop',
    optimizer='SGD',
    scheduler='cyclic_cosine',
    image_size=(224, 224),
    random_seed=777,
    number_of_classes=len(train_metadata['class_id'].unique()),
    architecture='vit_base_patch32_224',
    batch_size=32,
    accumulation_steps=4,
    epochs=100,
    learning_rate=0.01,
    dataset='PlantCLEF2018',
    loss='CrossEntropyLoss',
    training_samples=len(train_metadata),
    valid_samples=len(val_metadata),
    workers=12,
)

# Run name = "<architecture>-<optimizer>-<scheduler>-<augmentations>".
RUN_NAME = "-".join(
    config[key] for key in ("architecture", "optimizer", "scheduler", "augmentations")
)

### Fix Seeds & Log Setup

In [None]:
# Per-run log file named after RUN_NAME; then seed all RNGs via the project helper.
LOG_FILE = RUN_NAME + '.log'
LOGGER = init_logger(LOG_FILE)
seed_everything(config['random_seed'])

### Init Model

In [None]:
# Build the backbone with a head sized for this dataset's classes, then replace
# its weights with the fine-tuned checkpoint evaluated below.
# NOTE(review): pretrained=True presumably fetches ImageNet weights that are
# immediately overwritten by load_state_dict — confirm getModel's behavior.
model = getModel(config['architecture'], config['number_of_classes'], pretrained=True)

# Normalization statistics the backbone declares; passed to the test transforms.
model_mean = list(model.default_cfg['mean'])
model_std = list(model.default_cfg['std'])

# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only host; the inference cells move the model to `device` afterwards.
model.load_state_dict(
    torch.load('./vit_base_patch32_224-SGD-cyclic_cosine-light-random_crop-100E.pth', map_location='cpu')
)

In [None]:
# Test-time preprocessing in two variants ('vanilla' and 'center_crop'), both
# normalized with the backbone's own mean/std.
_transform_kwargs = dict(image_size=config['image_size'], mean=model_mean, std=model_std)
vanilla_augmentations = test_transforms(data='vanilla', **_transform_kwargs)
crop_augmentations = test_transforms(data='center_crop', **_transform_kwargs)

# One dataset per (test year, transform) pair. The inference loops unpack
# (images, labels, paths) from these.
PlantCLEF2017_test_dataset_vanilla = TrainDataset(PlantCLEF2017_test, transform=vanilla_augmentations)
PlantCLEF2017_test_dataset_crop = TrainDataset(PlantCLEF2017_test, transform=crop_augmentations)

PlantCLEF2018_test_dataset_vanilla = TrainDataset(PlantCLEF2018_test, transform=vanilla_augmentations)
PlantCLEF2018_test_dataset_crop = TrainDataset(PlantCLEF2018_test, transform=crop_augmentations)

# Deterministic order (shuffle=False) so predictions line up with metadata rows.
_loader_kwargs = dict(batch_size=config['batch_size'], shuffle=False, num_workers=config['workers'])
PlantCLEF2017_test_loader_vanilla = DataLoader(PlantCLEF2017_test_dataset_vanilla, **_loader_kwargs)
PlantCLEF2017_test_loader_crop = DataLoader(PlantCLEF2017_test_dataset_crop, **_loader_kwargs)
PlantCLEF2018_test_loader_vanilla = DataLoader(PlantCLEF2018_test_dataset_vanilla, **_loader_kwargs)
PlantCLEF2018_test_loader_crop = DataLoader(PlantCLEF2018_test_dataset_crop, **_loader_kwargs)

### Vanilla 2017

In [None]:
# Run the model over the 2017 test set with the 'vanilla' transforms. Collects
# argmax predictions, ground-truth labels and raw logits in loader order
# (the loader does not shuffle).
timecek = time.time()  # NOTE(review): start time is not read in this cell

model.to(device)
model.eval()

preds = np.zeros((len(PlantCLEF2017_test)))
preds_raw = []
criterion = nn.CrossEntropyLoss()  # NOTE(review): unused in this cell
all_labels = np.zeros((len(PlantCLEF2017_test)))
wrong_paths = []  # NOTE(review): never filled in this cell

# Write position into preds/all_labels. Advancing by the actual batch length is
# required: indexing by i * len(images) misplaces the final, smaller batch
# (overwriting earlier rows and leaving the tail as zeros).
offset = 0
for images, labels, paths in tqdm.tqdm(PlantCLEF2017_test_loader_vanilla, total=len(PlantCLEF2017_test_loader_vanilla)):
    batch_len = len(images)
    images = images.to(device)

    with torch.no_grad():
        y_preds = model(images).to('cpu')

    preds[offset: offset + batch_len] = y_preds.argmax(1).numpy()
    all_labels[offset: offset + batch_len] = labels.cpu().numpy()  # labels never need the GPU
    preds_raw.extend(y_preds.numpy())
    offset += batch_len

# Every metadata row must have received exactly one prediction.
assert offset == len(PlantCLEF2017_test), (offset, len(PlantCLEF2017_test))

In [None]:
# Attach per-image logits and argmax predictions to the metadata, then score
# the predictions against the ground-truth class ids.
PlantCLEF2017_test['vanilla'] = preds_raw
PlantCLEF2017_test['vanilla_preds'] = preds

vanila_accuracy = accuracy_score(
    PlantCLEF2017_test['class_id'],
    PlantCLEF2017_test['vanilla_preds'],
)
print('Vanila Accuracy:', vanila_accuracy)

### Crop 2017

In [None]:
# Run the model over the 2017 test set with the 'center_crop' transforms.
# Collects argmax predictions, ground-truth labels and raw logits in loader
# order (the loader does not shuffle).
timecek = time.time()  # NOTE(review): start time is not read in this cell

model.to(device)
model.eval()

preds = np.zeros((len(PlantCLEF2017_test)))
preds_raw = []
criterion = nn.CrossEntropyLoss()  # NOTE(review): unused in this cell
all_labels = np.zeros((len(PlantCLEF2017_test)))
wrong_paths = []  # NOTE(review): never filled in this cell

# Write position into preds/all_labels. Advancing by the actual batch length is
# required: indexing by i * len(images) misplaces the final, smaller batch
# (overwriting earlier rows and leaving the tail as zeros).
offset = 0
for images, labels, paths in tqdm.tqdm(PlantCLEF2017_test_loader_crop, total=len(PlantCLEF2017_test_loader_crop)):
    batch_len = len(images)
    images = images.to(device)

    with torch.no_grad():
        y_preds = model(images).to('cpu')

    preds[offset: offset + batch_len] = y_preds.argmax(1).numpy()
    all_labels[offset: offset + batch_len] = labels.cpu().numpy()  # labels never need the GPU
    preds_raw.extend(y_preds.numpy())
    offset += batch_len

# Every metadata row must have received exactly one prediction.
assert offset == len(PlantCLEF2017_test), (offset, len(PlantCLEF2017_test))

In [None]:
# Store the center-crop logits and argmax predictions next to the vanilla ones
# and score them against the ground-truth class ids.
PlantCLEF2017_test['crop'] = preds_raw
PlantCLEF2017_test['crop_preds'] = preds
crop_accuracy = accuracy_score(PlantCLEF2017_test['class_id'], PlantCLEF2017_test['crop_preds'])

# Label fixed: this cell reports the center-crop result, not the vanilla one.
print('Crop Accuracy:', crop_accuracy)

In [None]:
# Summary of the 2017 results: each accuracy as a percentage rounded to 2 decimals.
for _label, _acc in (('Vanila Accuracy:', vanila_accuracy), ('Crop Accuracy:', crop_accuracy)):
    print(_label, np.round(_acc * 100, 2))
