# Import

In [None]:
# system
import sys
import os 
import time

# data
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2 as cv
from sklearn.metrics import accuracy_score, confusion_matrix, plot_confusion_matrix

# deep learning 
import torch.backends.cudnn as cudnn
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader 

# custom helpers
import trainer
import metrics
import process
import models

# for formatting
from pprint import pprint
import warnings
warnings.filterwarnings('ignore')

# Notebook Toggles

In [None]:
# When True, every image is loaded and the distribution of image sizes is printed and plotted.
COMPUTE_DISTRIBUTION = False
# When True, a torchvision ResNet-34 is built; otherwise the custom models.BasicCNN is used.
MODEL_ZOO = False

## Hyperparameters 

In [None]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
epochs = 30              # iterations of the training loop
learning_rate = 3e-4     # AdamW learning rate
weight_decay = 1e-5      # AdamW weight decay
step_period = 15         # StepLR step size
lr_decay = 0.95          # StepLR multiplicative decay factor
num_workers = 4          # DataLoader worker count
batch_size = 128         # DataLoader batch size

# Data

Data is provided by "The ISIC 2020 Challenge". The dataset is annotated for binary classification of skin lesions for melanoma detection. More information may be found at "https://challenge2020.isic-archive.com/"

Data is sourced from 2000 patients and includes 33,126 dermoscopic training images.

## Loading and Statistics

In [None]:
def toPath(root, image_id):
    """Return the path of the ``.npy`` file for ``image_id`` under ``root``."""
    filename = image_id + '.npy'
    return os.path.join(root, filename)

def toLabel(key, mapping):
    """Look up ``key`` in ``mapping`` and return the value stored for it.

    A missing key raises whatever ``mapping[key]`` raises (``KeyError`` for a
    plain dict).
    """
    label = mapping[key]
    return label

### HAM

In [None]:
# paths
ham_path = '/usr/local/faststorage/ezimmer/data/'

# parse data
metadata = pd.read_csv('./data/HAM/metadata.csv')
headers = metadata.head()  # preview rows only; printed in the summary cell

# verify data
# Each distinct value of the 'dx' column gets an integer index, in sorted order.
label_mapping = {label : idx for idx, label in enumerate(sorted(np.unique(metadata['dx'])))}

# One entry per image id, holding the image path and the integer label.
# The id and diagnosis columns are walked together in a single pass instead of
# filtering the whole DataFrame once per id (which is quadratic in the number
# of rows). setdefault keeps the first row seen for a repeated id.
data_file = {}
for ID, dx in zip(metadata['image_id'], metadata['dx']):
    data_file.setdefault(ID, {'image' : toPath(ham_path, ID),
                              'label' : toLabel(dx, label_mapping)})

# get class distribution
# Every class index starts at zero so classes always appear in the counts.
class_counts = {idx : 0 for idx in range(len(label_mapping))}
for entry in data_file.values():
    class_counts[entry['label']] += 1

ncls = len(class_counts)

In [None]:
# Summary of the parsed metadata: preview rows, class count, label mapping
# and per-class sample counts.
print("HAM 10000 Metadata")
print('------------------------------------------------------------------')
print(headers, '\n')
for heading, value in (("Number of Classes", ncls), ('Unique Labels', label_mapping)):
    print(heading)
    print(value, '\n')
print("Class Balance")
print(class_counts)

In [None]:
if COMPUTE_DISTRIBUTION:
    # Collect the height and width of every image listed in data_file.
    ys, xs = [], []
    for entry in data_file.values():
        path = entry['image']
        # toPath produces '.npy' paths; cv.imread cannot decode those and
        # returns None, so NumPy files are loaded with np.load instead.
        img = np.load(path) if path.endswith('.npy') else cv.imread(path)
        if img is None:
            raise OSError("Could not read image: " + path)
        # assumes arrays are stored height-first, (H, W[, C]) — TODO confirm
        y, x = img.shape[:2]
        ys.append(y)
        xs.append(x)

    print("Dataset Image Size Distribution")
    print("Num Patients", len(ys), len(xs))
    print("Unique Values", np.unique(ys), np.unique(xs))
    # Plot all collected sizes, not just the last (y, x) pair from the loop.
    plt.scatter(ys, xs)
    plt.title("Image dimensions")
    plt.show()

## Experiment, Dataset, Dataloader

Since the dataset is imbalanced, the training and validation partitions are built per class and then wrapped in an oversampler; the test partition is kept as a single set.

In [None]:
# Partition the labelled files: per-class 'train'/'validation' entries plus a
# single 'test' entry.
partitioned_data = process.generateExperiment(data_file, ncls)

# One dataset per class for training and validation, one dataset for testing.
class_ids = range(ncls)
train_sets = [process.SkinSet(partitioned_data[c]['train']) for c in class_ids]
val_sets = [process.SkinSet(partitioned_data[c]['validation']) for c in class_ids]
test_set = process.SkinSet(partitioned_data['test'])

# The per-class training and validation datasets are combined by the oversampler.
train_set = process.Oversampler(train_sets)
val_set = process.Oversampler(val_sets)

# Loaders share batch size and worker count; only the training loader
# shuffles and pins memory.
loader_options = {'batch_size': batch_size, 'num_workers': num_workers}
train_loader = DataLoader(train_set, pin_memory=True, shuffle=True, **loader_options)
val_loader = DataLoader(val_set, **loader_options)
test_loader = DataLoader(test_set, **loader_options)

# Model

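Randomly initialized network: a torchvision ResNet (ResNet-50/18/34; the cell below builds ResNet-34) when `MODEL_ZOO` is `True`, otherwise the custom `BasicCNN`.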

In [None]:
def init_weights(m):
    """Apply Kaiming-normal initialization to convolutional and linear weights.

    Written to be passed to ``model.apply``, which calls it once per submodule.

    Args:
        m: The module to inspect. If it is an ``nn.Conv2d`` or ``nn.Linear``
            (including subclasses), its ``weight`` is re-initialized in place.
            Any other module is left untouched, and biases are never modified.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)

if MODEL_ZOO:

    def generateModel(n_cls, init=None, device=None):
        """Build a randomly initialized ResNet-34 with an ``n_cls``-way output layer.

        Args:
            n_cls: Number of output classes for the final fully connected layer.
            init: Optional callable passed to ``model.apply`` to initialize weights.
            device: Optional device the model is moved to.

        Returns:
            The constructed model.
        """
        model = torchvision.models.resnet34(pretrained=False)
        # Replace the classification head. Assigning `fc.out_features` only
        # changes an attribute and leaves the weight matrix, and therefore
        # the number of outputs, unchanged.
        model.fc = nn.Linear(model.fc.in_features, n_cls)

        if init is not None:
            model.apply(init)

        if device is not None:
            model = model.to(device)

        return model

    model = generateModel(ncls, init_weights, device)

else:
    # Positional configuration for models.BasicCNN.
    # NOTE(review): the trailing 7 is presumably the number of output classes —
    # confirm it matches ncls.
    basic_config = ([3, 32, 64, 128, 256], 1, True, 0.15, 7)
    model = models.BasicCNN(*basic_config).to(device)
    model.apply(init_weights)

# Optimization and Criteria

Optimizer: AdamW for super convergence and fast training

Scheduler: Step LR decay 

In [None]:
# AdamW over all model parameters.
optimizer = optim.AdamW(model.parameters(), weight_decay=weight_decay, lr=learning_rate)
# The learning rate is multiplied by lr_decay every step_period scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_period, gamma=lr_decay)
# Cross-entropy loss, moved to the training device.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Metrics and Aggregation

In [None]:
# One aggregator per data split.
train_stats = metrics.Aggregator()
val_stats = metrics.Aggregator()
test_stats = metrics.Aggregator()

# Register 'loss' and then 'acc' (scored with accuracy_score) on the training
# and validation aggregators; test_stats gets no stats registered here.
for stat_name, extra_args in (('loss', ()), ('acc', (accuracy_score,))):
    for aggregator in (train_stats, val_stats):
        aggregator.addStat(stat_name, *extra_args)

# Augmentations

In [None]:
# Instantiate the project's Transformer helper with its default arguments.
# NOTE(review): `augmentations` is not referenced by any other cell shown in
# this notebook — confirm whether it should be passed to the datasets or trainer.
augmentations = process.Transformer()

# Training 

In [None]:
import copy

# Start below any reachable accuracy so the first epoch always records a
# checkpoint and best_model is never left as None after training.
best_stat = float('-inf')
best_model = None
for epoch in range(epochs):

    t = time.time()

    # train
    preds, labels, t_loss = trainer.train(model, criterion, optimizer, scheduler, train_loader, device)
    train_stats.logStat('loss', (t_loss,))
    train_stats.logStat('acc', (labels, preds))

    # validate
    # Use trainer.evaluate rather than trainer.train so the validation split is
    # only scored, not trained on.
    # NOTE(review): assumes trainer.evaluate returns (preds, labels, loss) for
    # these arguments, as in its call on the test loader — confirm.
    preds, labels, v_loss = trainer.evaluate(model, criterion, optimizer, scheduler, val_loader, device)
    val_stats.logStat('loss', (v_loss,))
    val_stats.logStat('acc', (labels, preds))

    val_acc = val_stats.getStats('acc')[-1]
    if val_acc > best_stat:
        best_stat = val_acc
        # state_dict() hands back references to the live parameter tensors;
        # deep-copy so later epochs cannot overwrite the best snapshot.
        best_model = copy.deepcopy(model.state_dict())

    t = time.time() - t

    print("Epoch:", epoch+1)
    print("--------------------------------")
    print("Time:", t)
    print("Training Loss:       ", t_loss)
    print("Validation Loss:     ", v_loss)
    print("Training Accuracy:   ", train_stats.getStats('acc')[-1])
    print("Validation Accuracy: ", val_acc, '\n')


# Loss and accuracy curves over epochs for the training and validation splits.
metrics.Plotter.plot(train_stats.getStats('loss'), val_stats.getStats('loss'), 'Epochs', 'Loss', 'Losses')
metrics.Plotter.plot(train_stats.getStats('acc'), val_stats.getStats('acc'), 'Epochs', 'Accuracy', 'Accuracies')

# Testing

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Plot a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: Square confusion matrix; rows are true labels, columns predictions.
        classes: Tick labels for both axes, in matrix index order.
        normalize: When True, each row is divided by its sum before plotting,
            so both the colors and the annotations show per-class rates.
            Rows that sum to zero stay at zero.
        title: Title of the plot.
        cmap: Matplotlib colormap used for the heat map.
    """
    # Imported here because the file's module-level imports do not include itertools.
    import itertools

    cm = np.asarray(cm)
    if normalize:
        # Normalize before drawing so the image and the text annotations agree.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cells darker than half the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Normalized rates are shown with two decimals; raw counts as they are.
        cell_text = format(cm[i, j], '.2f') if normalize else cm[i, j]
        plt.text(j, i, cell_text,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [None]:
# Restore the best validation checkpoint when one was recorded; otherwise the
# model is evaluated with its current weights.
if best_model is not None:
    model.load_state_dict(best_model)
preds, labels, _ = trainer.evaluate(model, criterion, optimizer, scheduler, test_loader, device)
test_acc = accuracy_score(labels, preds)
# Build the matrix from the test labels and predictions. The label list is
# fixed so the matrix has one row/column per class even if a class is absent
# from the test set.
conf_mat = confusion_matrix(labels, preds, labels=sorted(label_mapping.values()))
print("Test Accuracy: ", test_acc)
# label_mapping assigns indices in sorted-name order, so the sorted names line
# up with the matrix indices.
plot_confusion_matrix(conf_mat, sorted(label_mapping))