## Imports

In [None]:
import torch
import torch.utils.data as Data
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import transforms as T, models
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from transformers import get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from torchsummary import summary
import seaborn as sns
from scipy.special import softmax
from functools import partial
from datetime import datetime
from torch_lr_finder import LRFinder

import warnings
from IPython.display import display
warnings.filterwarnings("ignore")

from src.plant_pathology.leaf_dataset import LeafDataset
from src.plant_pathology.model_loops import training, validation, testing
from src.plant_pathology.models import get_resnet, get_densenet, get_effecientnet
from src.plant_pathology.visualizations import show_saliency_maps, create_class_visualization
from src.plant_pathology.loss import LabelSmoothingCrossEntropy
from src.plant_pathology.metrics import comp_metric, healthy_roc_auc, multiple_diseases_roc_auc, scab_roc_auc, rust_roc_auc
from src.plant_pathology.onecyclelr import OneCycleLR

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Set to False to force CPU execution even when a GPU is present.
USE_GPU = True

# Floating-point precision used for tensors in this notebook.
dtype = torch.float32

# CUDA is selected only when it is both requested and actually available.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Release unused memory held by the CUDA caching allocator.
    torch.cuda.empty_cache()

print('using device:', device)

## Reading/Processing the Data

In [None]:
# Directory that holds every competition image.
IMAGE_PATH = Path('./plant-pathology-2020-fgvc7/images')


def image_path(file_stem):
    """Return the path of the ``.jpg`` file named ``file_stem`` inside ``IMAGE_PATH``."""
    return IMAGE_PATH.joinpath(f'{file_stem}.jpg')

In [None]:
# Load the competition metadata tables.
train_df = pd.read_csv('./plant-pathology-2020-fgvc7/train.csv')
test_df = pd.read_csv('./plant-pathology-2020-fgvc7/test.csv')

# Map every image id to its file path; the result is kept both as a
# standalone Series and as the ``img_file`` column of the frame.
train_paths = train_df['image_id'].apply(image_path)
train_df['img_file'] = train_paths
test_paths = test_df['image_id'].apply(image_path)
test_df['img_file'] = test_paths

# The four label columns of the training table.
train_labels = train_df[[
    'healthy',
    'multiple_diseases',
    'rust',
    'scab',
]]


In [None]:
# Hold out 20% of the training rows for validation, stratified on the label
# columns and seeded so the split is reproducible.
train_paths, valid_paths, train_labels, valid_labels = train_test_split(
    train_paths,
    train_labels,
    test_size=0.2,
    random_state=23,
    stratify=train_labels,
)

# Renumber every split from 0, in place, discarding the original row indices.
for split in (train_paths, train_labels, valid_paths, valid_labels):
    split.reset_index(drop=True, inplace=True)

### Visualize Data

In [None]:
import cv2
import imutils
import matplotlib
# Use a larger default font size (22) for all matplotlib text.
matplotlib.rcParams.update({'font.size': 22})

In [None]:
def _load_rgb(path):
    """Read the image at ``path`` with OpenCV and return it in RGB channel order.

    ``cv2.imread`` produces BGR data and returns ``None`` (rather than raising)
    when the file cannot be read. Both are handled here because the images are
    later passed to ``imshow``, which interprets the channels as RGB.

    Raises:
        FileNotFoundError: if OpenCV could not read the file.
    """
    image = cv2.imread(str(path))
    if image is None:
        raise FileNotFoundError(f'could not read image: {path}')
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


# Example rows for each class; the class of each row is taken from the
# variable names -- NOTE(review): confirm against the label columns.
scab = _load_rgb(train_df.iloc[0]['img_file'])
multi = _load_rgb(train_df.iloc[1]['img_file'])
rust = _load_rgb(train_df.iloc[1819]['img_file'])
healthy = _load_rgb(train_df.iloc[4]['img_file'])

# 6x6 averaging kernel. Dividing by the number of taps makes the weights sum
# to 1, so filtering blurs the image without changing its brightness.
kernel = np.ones((6, 6), np.float32)
kernel /= kernel.size

types = [healthy, multi, rust, scab]

In [None]:
# Row labels (one per class image) and column labels (one per transformation).
y_labels = ['Healthy', 'Multi', 'Rust', 'Scab']
x_labels = ['Normal', 'Horizontal Flip', 'Vertical Flip', 'Rotated 25', 'Filtered']

fig, axs = plt.subplots(4, 5)

for row_axes, image, y_label in zip(axs, types, y_labels):
    # The five variants shown for every image, in column order.
    variants = (
        image,
        cv2.flip(image, 1),
        cv2.flip(image, 0),
        imutils.rotate(image, 25),
        cv2.filter2D(image, -1, kernel),
    )
    for ax, variant in zip(row_axes, variants):
        ax.imshow(variant)
    row_axes[0].set(ylabel=y_label)

# Column titles are written only under the bottom row.
for ax, x_label in zip(axs[-1], x_labels):
    ax.set(xlabel=x_label)

# Hide tick marks on every panel.
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])

fig.set_size_inches(18.5, 10.5)
fig.savefig('example.jpg')

### Initialization

In [None]:
# Training hyper-parameters.
BATCH_SIZE = 20
NUM_EPOCHS = 30

# Number of rows in the training and validation splits.
TRAIN_SIZE, VALID_SIZE = len(train_labels), len(valid_labels)

In [None]:
# Settings shared by all three loaders.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)

# Training data: the only loader that shuffles.
train_dataset = LeafDataset(train_paths, train_labels)
trainloader = Data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Validation data, built with train=False and read in a fixed order.
valid_dataset = LeafDataset(valid_paths, valid_labels, train=False)
validloader = Data.DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

# Test data: no labels are passed, and the order is fixed.
test_dataset = LeafDataset(test_paths, train=False, test=True)
testloader = Data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

### Metrics

In [None]:
def modified_accuracy_score(labels, preds):
    """Return the accuracy of ``preds`` against ``labels``.

    ``preds`` holds one score per class along axis 1; the highest-scoring
    class of each row is taken as the predicted label before scoring.
    """
    return accuracy_score(labels, np.argmax(preds, axis=1))


# Metric callables handed to the training and validation loops.
acc_fns = [
    modified_accuracy_score,
    healthy_roc_auc,
    multiple_diseases_roc_auc,
    rust_roc_auc,
    scab_roc_auc,
    comp_metric,
]

#### Please only run one of the net sections before running the training loop

## Get models
Model selections:
- densenet
- resnet
- inception
- efficientnet

### DenseNet

In [None]:
densenet = get_densenet(train_labels)
model = densenet

# Freeze the first half of the parameter tensors (in parameters() order);
# only the second half stays trainable.
params = list(model.parameters())
num_params = len(params)
for param in params[:num_params // 2]:
    param.requires_grad = False

### ResNet

In [None]:
# Build the pretrained ResNet and make it the model used by the cells below.
model = resnet34 = get_resnet(train_labels, pretrained=True)

### Inception

In [None]:
# aux_logits=False removes the auxiliary classifier, so the forward pass
# returns a plain logits tensor rather than an (output, aux) pair in training mode.
inception = models.inception_v3(pretrained=True, aux_logits=False)
# Replace the ImageNet classification head with one output per label column.
inception.fc = nn.Linear(inception.fc.in_features, train_labels.shape[1])
# Make this the model used by the optimizer and training cells, as the other
# model sections do.
model = inception

### EfficientNet

In [None]:
# Build the pretrained EfficientNet and make it the model used by the cells below.
model = effecient_net = get_effecientnet(train_labels, pretrained=True)

### Optimizer and Loss

In [None]:
# Loss used for both the LR range test and training.
loss_fn = LabelSmoothingCrossEntropy()

# SGD with momentum and weight decay over the selected model's parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=8e-4,
    momentum=0.8,
    weight_decay=1e-4,
)

### LR Finder

In [None]:
# Run the range test on the device selected at the top of the notebook, so the
# cell also works on machines without CUDA.
lr_finder = LRFinder(model, optimizer, loss_fn, device=device)

In [None]:
# Sweep the learning rate exponentially up to 100 over 100 iterations of the training loader.
lr_finder.range_test(trainloader, end_lr=100, num_iter=100, step_mode='exp')

In [None]:
# Plot the results recorded by the range test.
lr_finder.plot()

In [None]:
# Reset the finder after the range test.
# NOTE(review): presumably restores the model and optimizer to their pre-test state -- confirm against torch_lr_finder.
lr_finder.reset()

### Updated Optimizer and Scheduler

In [None]:
# The schedule moves the learning rate within [end_lr / factor, end_lr].
factor = 6
end_lr = 4 / 1000
optimizer = torch.optim.SGD(model.parameters(), lr=end_lr, momentum=0.8, weight_decay=1e-4)

# len(trainloader) is the real number of batches per epoch, including the last
# partial batch, so the total is an exact integer count of optimizer steps.
num_iters = len(trainloader) * NUM_EPOCHS
# NOTE(review): num_steps is half the total iteration count, as in the original
# cell -- confirm this is what OneCycleLR expects.
scheduler = OneCycleLR(optimizer, num_steps=num_iters // 2, lr_range=(end_lr / factor, end_lr))

### Training loop

In [None]:
# Histories filled in by the training loop.
train_loss, valid_loss = [], []
train_acc, val_acc = [], []
lrs = []

# Move the selected model to the chosen device (kept last so the cell still
# displays the model).
model.to(device)

In [None]:
import copy

# Holder for the best-scoring weights. It is cloned from the model being
# trained so that load_state_dict below matches whichever network was selected.
best_model = copy.deepcopy(model)
best_combined_roc_auc = float('-inf')

# Checkpoints are grouped by the class name of the trained model.
checkpoint = Path(f'model_checkpoints/{model.__class__.__name__}')
checkpoint.mkdir(parents=True, exist_ok=True)

# Row labels shared by the per-epoch train and validation summary tables.
metric_names = ['Healthy ROC_AUC', 'Multi ROC_AUC', 'Rust ROC_AUC', 'Scab ROC_AUC', 'Combined ROC_AUC']

for epoch in range(NUM_EPOCHS):

    # One pass over the training data, then one over the validation data;
    # each summary is displayed as a one-row table.
    tl, ta, lr = training(model, trainloader, optimizer, scheduler, loss_fn, acc_fns, device, TRAIN_SIZE)
    display(pd.DataFrame([epoch, tl, *ta], index=['Epoch', 'Train Loss', 'Train Accuracy', *metric_names]).T)
    vl, va, conf_mat = validation(model, validloader, loss_fn, acc_fns, confusion_matrix, device, VALID_SIZE)
    display(pd.DataFrame([epoch, vl, *va], index=['Epoch', 'Valid Loss', 'Valid Accuracy', *metric_names]).T)

    train_loss.append(tl)
    valid_loss.append(vl)
    train_acc.append(ta)
    val_acc.append(va)
    lrs.extend(lr)

    # The last metric is the combined ROC AUC; keep the weights that maximise it.
    if va[-1] > best_combined_roc_auc:
        best_combined_roc_auc = va[-1]
        best_model.load_state_dict(model.state_dict())

    # Periodic snapshot every 10th epoch (the file name uses the 0-based epoch index).
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), checkpoint/f'epoch_{epoch}_{loss_fn.__class__.__name__}.pt')

torch.save(best_model.state_dict(), checkpoint/f'best_{loss_fn.__class__.__name__}.pt')

In [None]:
# Convert the collected histories from lists to NumPy arrays.
train_acc, val_acc, lrs = (
    np.array(history) for history in (train_acc, val_acc, lrs)
)

### Save Results for future use

In [None]:
# Column labels for the result tables: the six metrics followed by the loss.
result_columns = [
    'Accuracy',
    'Healthy ROC_AUC',
    'Multi ROC_AUC',
    'Rust ROC_AUC',
    'Scab ROC_AUC',
    'Combined ROC_AUC',
    'Loss',
]

# One row per epoch: metric columns from the history array plus the loss.
train_results = pd.DataFrame(train_acc).assign(Loss=train_loss)
train_results.columns = result_columns
train_results.index.name = 'Epoch'

val_results = pd.DataFrame(val_acc).assign(Loss=valid_loss)
val_results.columns = result_columns
val_results.index.name = 'Epoch'

In [None]:
# Name the folder after the trained model's class, the same name used in the
# file names below, so results always land under the architecture that was trained.
model_name = model.__class__.__name__
loss_name = loss_fn.__class__.__name__

results_folder = Path(f'model_results/{model_name}')
results_folder.mkdir(parents=True, exist_ok=True)

train_results.to_csv(results_folder/f'{model_name}_train_results_{loss_name}.csv')
val_results.to_csv(results_folder/f'{model_name}_valid_results_{loss_name}.csv')
np.savetxt(results_folder/f'{model_name}_lrs.csv', lrs, delimiter=',')

### Plots

In [None]:
# Name the plots folder after the trained model's class, matching the model
# name used in the plot file names.
plots_folder = Path(f'model_plots/{model.__class__.__name__}')
plots_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Learning rate at every recorded training iteration.
plt.figure(figsize=(14, 8))
# x and y are passed by keyword: current seaborn accepts only `data` positionally.
sns.lineplot(x=np.arange(len(lrs)), y=lrs)
plt.xlabel('# Iterations', fontsize=18)
plt.ylabel('Learning Rate', fontsize=18)
plt.title('Learning Rate vs # Iterations', fontsize=20)
plt.savefig(plots_folder/'learning_rate.jpg')

In [None]:
# Training and validation loss per epoch, with the y-axis fixed to [0, 1.5].
plt.figure(figsize=(14, 8))
plt.ylim(0, 1.5)
# x and y are passed by keyword: current seaborn accepts only `data` positionally.
sns.lineplot(x=np.arange(len(train_loss)), y=train_loss)
sns.lineplot(x=np.arange(len(valid_loss)), y=valid_loss)
plt.xlabel('Epoch', fontsize=18)
plt.ylabel('Loss', fontsize=18)
plt.legend(['Train','Val'], fontsize=16)
plt.title('Loss vs Epoch', fontsize=20)
plt.savefig(plots_folder/'loss.jpg')

In [None]:
# One figure per metric; column idx of the history arrays holds metric idx.
acc_names = ['Accuracy', 'Healthy ROC_AUC', 'Multi ROC_AUC', 'Rust ROC_AUC', 'Scab ROC_AUC', 'Combined ROC_AUC']
for idx, acc_name in enumerate(acc_names):
    train_curve = train_acc[:, idx]
    val_curve = val_acc[:, idx]

    plt.figure(figsize=(14, 8))
    # x and y are passed by keyword: current seaborn accepts only `data` positionally.
    sns.lineplot(x=np.arange(len(train_curve)), y=train_curve)
    sns.lineplot(x=np.arange(len(val_curve)), y=val_curve)
    plt.xlabel('Epoch', fontsize=18)
    plt.ylabel(acc_name, fontsize=18)
    plt.legend(['Train','Val'], fontsize=16)
    plt.title(f'{acc_name} vs Epoch', fontsize=20)
    plt.savefig(plots_folder/f'{model.__class__.__name__}_{acc_name}.jpg')

In [None]:
# Re-run validation with the best weights to get their metrics and confusion matrix.
_, va, conf_mat = validation(best_model, validloader, loss_fn, acc_fns, confusion_matrix, device, VALID_SIZE)
labels = ['Healthy', 'Multiple','Rust','Scab']
plt.figure(figsize=(15, 10))
# fmt='g' prints the cell values as plain numbers; the default '.2g' switches
# to scientific notation for values of 100 or more.
sns.heatmap(conf_mat, xticklabels=labels, yticklabels=labels, annot=True, fmt='g')
plt.title('Confusion Matrix', fontsize=20)
plt.savefig(plots_folder/'confusion.jpg')

In [None]:
# Display the metrics most recently returned by validation().
va

### Testing Performance

In [None]:
def get_testing_output(model, device, num_runs=5):
    """Return class probabilities for the test set, averaged over several runs.

    The model is moved to ``device`` and ``testing`` is run ``num_runs`` times
    on the module-level ``testloader``. Each raw output is turned into
    probabilities with a row-wise softmax, and the resulting frames are
    averaged element-wise.

    Args:
        model: network to evaluate.
        device: device the model is moved to and that is passed to ``testing``.
        num_runs: number of prediction passes to average (default 5).

    Returns:
        DataFrame with columns ``healthy``, ``multiple_diseases``, ``rust``
        and ``scab``.

    Raises:
        ValueError: if ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError('num_runs must be at least 1')

    model = model.to(device)
    # The submission expects probability scores for each class, under these names.
    columns = ['healthy', 'multiple_diseases', 'rust', 'scab']

    subs = []
    for _ in range(num_runs):
        out = testing(model, testloader, device)
        output = pd.DataFrame(softmax(out, 1), columns=columns)
        # NOTE(review): the first row of every prediction frame is discarded --
        # presumably `testing` returns a placeholder first row; confirm.
        output.drop(0, inplace=True)
        output.reset_index(drop=True, inplace=True)
        subs.append(output)

    # Divide by the actual number of runs instead of a repeated literal.
    return sum(subs) / len(subs)

### Model Ensembling

In [None]:
# Load the best DenseNet weights from the relative location the training loop
# writes to, instead of an absolute path tied to one user's home directory.
densenet_checkpoint = get_densenet(train_labels, model_path='model_checkpoints/DenseNet/best_LabelSmoothingCrossEntropy.pt')
# resnet_checkpoint = get_resnet(train_labels, model_path='/home/anishwalawalkar/plant-pathology/model_checkpoints/ResNet_best_smoothing.pt')

In [None]:
# Averaged test-set class probabilities from the DenseNet checkpoint.
sub_densenet = get_testing_output(densenet_checkpoint, device)
# sub_resnet = get_testing_output(resnet_checkpoint, device)

In [None]:
# Optional weighted blend with the ResNet predictions (currently disabled).
# submission = 0.25 * sub_resnet + 0.75 * sub_densenet
# Only the DenseNet predictions are used; this binds the same object as
# sub_densenet, not a copy.
submission = sub_densenet

In [None]:
# Attach the test image ids as a new column (pandas aligns the values on the row index).
submission['image_id'] = test_df['image_id']

In [None]:
# Write the submission file without the row index.
submission.to_csv('submission.csv', index=False)

### Saliency Maps

In [None]:
# Load the best DenseNet weights from the relative location the training loop
# writes to, instead of an absolute path tied to one user's home directory.
densenet_checkpoint = get_densenet(train_labels, model_path='model_checkpoints/DenseNet/best_LabelSmoothingCrossEntropy.pt')

In [None]:
# Class names in label order, and the integer target for each example image.
class_names = ['Healthy', 'Multi', 'Rust', 'Scab']
y = list(range(len(class_names)))

# Row positions in train_df of the image used for each class, in the same
# order as class_names.
sample_rows = (1817, 1, 1819, 0)
X = np.array([
    np.array(Image.open(str(train_df.iloc[row]['img_file'])))
    for row in sample_rows
])

In [None]:
# Show saliency maps of the DenseNet checkpoint for the example images X and their targets y.
show_saliency_maps(X, y, densenet_checkpoint, class_names, device)

### Class Visualization

In [None]:
# Load the best DenseNet weights from the relative location the training loop
# writes to, instead of an absolute path tied to one user's home directory.
densenet_checkpoint = get_densenet(train_labels, model_path='model_checkpoints/DenseNet/best_LabelSmoothingCrossEntropy.pt')

In [None]:
# Visualise class index y[3] with the DenseNet checkpoint loaded for this
# section, rather than whichever network the global `model` currently holds.
create_class_visualization(y[3], densenet_checkpoint, device, class_names)