# Params

In [None]:
import os
# Number GPUs by PCI bus ID and expose only GPU "6" to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "6",
})

In [None]:
# device = 'cpu'
# Map the label character found in each file name to a class index:
# '1' = benign -> 0; '3'/'4'/'5' = Gleason 3/4/5 -> 1/2/3 (there are no '2' labels);
# '0' -> 4 and '6' -> 5 (meaning of these two labels not documented here - TODO confirm).
map_vpc = {'1': 0, '3': 1, '4': 2, '5': 3, '0': 4, '6': 5}
num_classes = 6
batch_size = 32
magnification = 40
stains = ['HnE']
fold = 'fold1'

# Slide-level cross-validation splits: slide numbers used for train / val / test.
if fold == 'fold1':
    train_slides_vpc = [2, 5, 6, 7]
    val_slides_vpc = [3]
    test_slides_vpc = [1]
elif fold == 'fold2':
    train_slides_vpc = [1, 3, 6, 7]
    val_slides_vpc = [2]
    test_slides_vpc = [5]
elif fold == 'fold3':
    train_slides_vpc = [1, 2, 3, 5]
    val_slides_vpc = [7]
    test_slides_vpc = [6]
else:
    # Fail here instead of with a NameError on the slide lists much later.
    raise ValueError('Please choose the correct fold! - {"fold1", "fold2", "fold3"}, got ' + repr(fold))

# Checkpoints are saved as <model_results_path>_<epoch>_<val_acc>.
# model_results_path = 'model/' + str(magnification) + '/' + fold + '/aug_model'
model_results_path = 'model_VPC_Zurich_Colorado/' + str(magnification) + '/' + fold + '/256_aug_model'


# path_VPC = '../data/DeepPath_BlockImages/'
path_VPC = 'VPC_staintools/'


# Import

In [None]:
import os
import pandas as pd
from skimage import io
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from IPython import display
from sklearn.metrics import roc_auc_score
import numpy as np
import staintools
import random
import torchvision.transforms.functional as TF
import cv2 as cv

# Utils

In [None]:
def dict_to_device(orig, device):
    """Return a new dict with every value moved to *device* via ``.to``.

    Keys and their order are preserved; the input dict is not modified.
    Typically used to move a batch dict between the CPU and the GPU.
    """
    return {key: value.to(device) for key, value in orig.items()}

def plotImage(img, ax=plt):
    """Draw *img* on *ax* (defaults to the pyplot state machine).

    The image is converted with torchvision's ``ToPILImage``, so it may be a
    tensor or ndarray that ``ToPILImage`` accepts (e.g. a (C, H, W) tensor).
    """
    ax.imshow(torchvision.transforms.ToPILImage()(img))
    
class MyRotationTransform:
    """Rotate an image by an angle picked uniformly from a fixed set.

    Args:
        angles: candidate rotation angles in degrees, forwarded to ``TF.rotate``.
    """

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        chosen_angle = random.choice(self.angles)
        rotated = TF.rotate(x, chosen_angle)
        return rotated


# Right-angle rotations only: for square inputs nothing is cropped or padded.
rotation_transform = MyRotationTransform(angles=[0, 90, 180, 270])
    
# Training-time augmentation pipeline. The flips run on a PIL image; the
# rotation and colour jitter run on the tensor produced by ToTensor.
_augmentation_steps = [
    transforms.ToPILImage(),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    rotation_transform,
    # NOTE: GaussianBlur requires an odd kernel size, so a size of 20 would be rejected.
    # transforms.GaussianBlur(20, sigma=(0,0.1)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
]
AUGMENTED_TRANSFORM = transforms.Compose(_augmentation_steps)

# Dataset

In [None]:
class VPCDataset(Dataset):  # VPC might not be robust to augmentation
    """Image-patch dataset read from the per-slide folders of the VPC data.

    For every slide number ``k`` in ``slides`` the folder
    ``root_dir + 'slide00' + str(k) + '/'`` is listed. A file is kept when the
    ``'_'``-separated fields of its name match a requested stain (field 3) and
    magnification (field 4).

    Args:
        root_dir: dataset root. Paths are built by string concatenation, so it
            must end with a path separator.
        slides: slide numbers to include.
        map1: maps the label character of a file name (at position
            ``LABEL_CHAR_INDEX``) to a class index.
        magnifications: magnifications to keep (10, 20 or 40).
        stains: stain names to keep ('HnE', 'Ki67' or 'P63').
        augmentation: if True, every matching file is kept and the list is
            sorted. Each item then draws one of ``VARIANTS_PER_IMAGE``
            consecutive files at random. If False, only files whose name has
            exactly 6 fields are kept (presumably the un-augmented
            originals - confirm).
        transform: callable applied to the resized 256x256 image. If it is
            falsy, the raw resized array is returned.

    Attributes:
        img_files: kept file names.
        img_dirs: directory of each entry of ``img_files`` (same order).
        ratio: inverse class counts ``1 / (count + 1)``, normalised to sum to 1.
    """

    # Files per source patch when augmentation=True: the image itself plus 14
    # pre-computed augmentations, assumed adjacent once names are sorted.
    VARIANTS_PER_IMAGE = 15
    # Character position of the class label inside a file name (format specific).
    LABEL_CHAR_INDEX = 49

    def __init__(self, root_dir, slides, map1, magnifications, stains, augmentation=False, transform=transforms.ToTensor()):
        self.root_dir = root_dir
        self.transform = transform
        self.map = map1
        self.y = []  # not populated; kept for backward compatibility
        self.ratio = np.zeros(num_classes)  # one slot per class (module-level num_classes)
        self.augmentation = augmentation
        stain_map = {'HnE': 'stain001', 'Ki67': 'stain002', 'P63': 'stain003'}
        magnification_map = {10: 'scale001', 20: 'scale002', 40: 'scale003'}
        self.stain_name = [stain_map[stain] for stain in stains]
        self.magnification_name = [magnification_map[m] for m in magnifications]

        # Collect (file name, directory) pairs so that __getitem__ reopens each
        # file from the folder it was listed in. The folder is no longer
        # guessed from the first characters of the file name.
        entries = []
        for slide in slides:
            slide_path = root_dir + 'slide00' + str(slide) + '/'
            entries.extend((img_file, slide_path) for img_file in os.listdir(slide_path)
                           if self._keep_file(img_file))

        # Class-balancing weights: inverse counts (+1 avoids division by zero),
        # normalised to sum to 1.
        for img_file, _ in entries:
            self.ratio[self._label_of(img_file)] += 1
        self.ratio = 1 / (self.ratio + 1)
        self.ratio /= np.sum(self.ratio)

        # Sorting puts the VARIANTS_PER_IMAGE files of each patch next to each other.
        if self.augmentation:
            entries.sort()
        self.img_files = [img_file for img_file, _ in entries]
        self.img_dirs = [img_dir for _, img_dir in entries]

    def _keep_file(self, img_file):
        """Return True if *img_file* matches the requested stain and magnification."""
        fields = img_file.split('_')
        if len(fields) < 5:
            # Too few fields to carry stain/scale info (e.g. a stray system file).
            return False
        if fields[4] not in self.magnification_name or fields[3] not in self.stain_name:
            return False
        # Without augmentation, only plain (6-field) file names are used.
        return self.augmentation or len(fields) == 6

    def _label_of(self, img_file):
        """Return the class index encoded in *img_file*."""
        return int(self.map[img_file[self.LABEL_CHAR_INDEX]])

    def __len__(self):
        if self.augmentation:
            return len(self.img_files) // self.VARIANTS_PER_IMAGE
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load sample *idx* as ``{'img': image, 'label': class_index}``.

        With augmentation, *idx* selects a patch. One of its
        ``VARIANTS_PER_IMAGE`` files (the image itself or an augmentation) is
        drawn at random on every call.
        """
        index = idx
        if self.augmentation:
            index = idx * self.VARIANTS_PER_IMAGE + random.randint(0, self.VARIANTS_PER_IMAGE - 1)
        img_file = self.img_files[index]
        img = io.imread(self.img_dirs[index] + img_file)
        img = cv.resize(img, (256, 256), interpolation=cv.INTER_CUBIC)
        if self.transform:
            img = self.transform(img)
        # Always return a sample, even when no transform is configured.
        return {'img': img, 'label': self._label_of(img_file)}

In [None]:
# Training draws a random pre-computed variant of each patch on every access.
# Validation uses only the plain 6-field file names (presumably the
# un-augmented originals). This keeps the per-epoch validation accuracy, which
# is used to name checkpoints, deterministic.
dataset_train = VPCDataset(path_VPC, train_slides_vpc, map_vpc, [magnification], stains, augmentation=True, transform=AUGMENTED_TRANSFORM)
dataset_val = VPCDataset(path_VPC, val_slides_vpc, map_vpc, [magnification], stains, augmentation=False, transform=transforms.ToTensor())

In [None]:
# Both loaders share every option except shuffling: training batches are
# reshuffled each epoch, while the validation order stays fixed.
_loader_options = dict(batch_size=batch_size, num_workers=2, pin_memory=False)
train_loader = DataLoader(dataset=dataset_train, shuffle=True, **_loader_options)
val_loader = DataLoader(dataset=dataset_val, shuffle=False, **_loader_options)

In [None]:
# Draw one training batch to sanity-check the data pipeline.
iterator = iter(train_loader)
batch = next(iterator)
img = batch['img']   # the whole image batch
output = img[1]      # second sample, inspected in the following cells

In [None]:
# Value range and shape of the inspected sample. A bare torch.min(...) that is
# not the last expression of a cell is never shown, so print explicitly.
print('min =', torch.min(output).item(), 'max =', torch.max(output).item())
print(output.shape)

In [None]:
# matplotlib expects channels last, so reorder the (C, H, W) tensor to (H, W, C).
output_hwc = output.cpu().permute(1, 2, 0)
plt.imshow(output_hwc)

In [None]:
n_train_files = len(dataset_train.img_files)
print(n_train_files)       # number of kept files
print(n_train_files / 15)  # files divided by the 15 variants per patch
print(len(dataset_train))  # number of items the dataset exposes

In [None]:
n_val_files = len(dataset_val.img_files)
print(n_val_files)       # number of kept files
print(n_val_files / 15)  # files divided by the 15 variants per patch
print(len(dataset_val))  # number of items the dataset exposes

# Model

In [None]:
class NN(torch.nn.Module):
    """ResNet-18 with its final classifier replaced by a ``num_classes``-way linear layer.

    ``forward`` and ``prediction`` take and return dicts, so they compose with
    ``dict_to_device``. The image batch is read from key ``'img'`` and the
    result is stored under key ``'label'``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept for compatibility with the installed version.
        self.model = torchvision.models.resnet18(pretrained=True)
        # Read the head's input width from the backbone instead of hard-coding 512.
        self.model.fc = nn.Sequential(
            nn.Linear(in_features=self.model.fc.in_features, out_features=num_classes, bias=True),
        )
        print(self.model)

    def forward(self, dictionary):
        """Return the raw class logits of ``dictionary['img']`` under key ``'label'``."""
        return {'label': self.model(dictionary['img'])}

    @torch.no_grad()
    def prediction(self, dictionary):
        """Return the arg-max class index per sample under key ``'label'``.

        Runs without autograd: argmax is not differentiable, so recording a
        graph here would only waste memory.
        """
        return {'label': torch.argmax(self.forward(dictionary)['label'], dim=1)}


model = NN(num_classes=num_classes).cuda()

# Training

## Training loop

In [None]:
# prepare plotting
fig = plt.figure(figsize=(20, 5), dpi=80, facecolor='w', edgecolor='k')
axes = fig.subplots(1, 3)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# scheduler.step() is called once per epoch; the LR is multiplied by 0.1 after 5, 15 and 25 steps.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 15, 25], gamma=0.1)

sm = nn.Softmax(dim=1)  # softmax over class logits; not used by the accuracy metric below

# Per-class loss weights: the training set's normalised inverse class counts.
weights = torch.FloatTensor(train_loader.dataset.ratio).to('cuda')

num_epochs = 41
losses = []    # training loss of every gradient step
val_accs = []  # validation accuracy after every epoch
val_acc = 0


def _report_progress(epoch, iteration, n_iterations):
    """Refresh the notebook figure and print a one-line progress summary."""
    display.clear_output(wait=True)
    display.display(plt.gcf())
    print("Epoch {}, iteration {} of {} ({} %), loss={}\nval_acc = {}".format(
        epoch, iteration, n_iterations, 100 * iteration // max(n_iterations, 1), losses[-1], val_accs))


def _draw_training_panels(batch_cpu, logits_cpu):
    """Redraw the sample image, the training-loss curve and the validation-accuracy curve."""
    for ax in axes:
        ax.cla()
    plotImage(batch_cpu['img'][0], ax=axes[0])
    axes[0].set_title('Input image with ground truth label = {}, and predicted = {}'.format(
        int(batch_cpu['label'][0]), int(torch.argmax(logits_cpu[0]))))

    # training loss on a log scale
    axes[1].plot(losses, label='loss')
    axes[1].set_yscale('log')
    axes[1].set_title('Training loss')
    axes[1].set_xlabel('number of gradient iterations')
    axes[1].legend()

    # validation accuracy per epoch
    axes[2].plot(val_accs, label='val_acc')
    axes[2].set_title('Validation Accuracy')
    axes[2].set_xlabel('number of epochs')
    axes[2].legend()


for epoch in range(num_epochs):
    ## Training
    model.train()
    n_train = len(train_loader)
    for i, batch_cpu in enumerate(train_loader):
        batch_gpu = dict_to_device(batch_cpu, 'cuda')
        pred = model(batch_gpu)

        # Class-weighted cross entropy: logits first, integer targets second.
        # reduction='none' followed by .mean() divides the weighted sum by the
        # batch size, i.e. the normalisation of soft-label CE with one-hot
        # targets. The default 'mean' with weights would instead divide by the
        # sum of the selected weights.
        loss = torch.nn.functional.cross_entropy(
            pred['label'], batch_gpu['label'], weight=weights, reduction='none').mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        ## plotting ##
        if i % 20 == 0:
            # Copy the logits to the CPU only when they are actually plotted.
            _draw_training_panels(batch_cpu, pred['label'].detach().cpu())
            _report_progress(epoch, i, n_train)

    ## Validation: accuracy over the whole validation set
    model.eval()
    n_val = len(val_loader.dataset)
    pred_val = np.zeros(n_val)
    label = np.zeros(n_val)
    offset = 0
    with torch.no_grad():  # evaluation needs no autograd graph
        for batch_cpu in val_loader:
            batch_gpu = dict_to_device(batch_cpu, 'cuda')
            n_batch = batch_cpu['label'].shape[0]
            pred_val[offset:offset + n_batch] = model.prediction(batch_gpu)['label'].cpu().numpy()
            label[offset:offset + n_batch] = batch_cpu['label'].numpy()
            offset += n_batch

    # (An AUC alternative would be roc_auc_score(label, softmax scores, multi_class='ovr').)
    val_acc = np.mean(label == pred_val)
    val_accs.append(val_acc)

    ## saving the model
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        # Both keys hold validation *accuracy*; 'val_AUC' is kept for existing loaders.
        'val_AUC': val_acc,
        'val_acc': val_acc,
    }, model_results_path + '_' + str(epoch) + '_' + str(val_acc))

    # End-of-epoch report: all training iterations of this epoch are done.
    _report_progress(epoch, n_train, n_train)

    ### Scheduler ###
    scheduler.step()

plt.close('all')
