<a href="https://colab.research.google.com/github/adubowski/redi-xai/blob/main/classifier/train_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train a VGG16 Skin Cancer Classifier
The approach here is based on code from https://github.com/laura-rieger/deep-explanation-penalization/tree/master/isic-skin-cancer/ISIC. In particular, the code in the cells "Functions for training" and "Functions for evaluation" is taken almost directly from Rieger's code.

## Libraries, arguments and setup

In [None]:
from google.colab import drive

import torch
import sys
import numpy as np
import pickle as pkl
from os.path import join as oj
from datetime import datetime
import torch.optim as optim
import os
from torch.utils.data import TensorDataset, ConcatDataset
import argparse
from PIL import Image
from tqdm import tqdm
from torch import nn
from numpy.random import randint
import torchvision.models as models
import time
import copy
import gc
import json

### Mount Google Drive and create paths for directories

In [None]:
# Mount Google Drive so the project folder below is reachable from the Colab VM.
drive.mount("/content/drive")
# Root folder of the project on Drive; the model and data paths below are built from it.
dir_path = "/content/drive/MyDrive/redi-detecting-cheating"

Mounted at /content/drive


In [None]:
# Checkpoint locations: the classifier folder and this run's per-epoch checkpoint subfolder.
model_path = oj(dir_path, "models", "initial_classifier")
model_training_path = oj(model_path, "training_224")

# Pre-processed 224x224 images, one folder per class.
data_path = oj(dir_path, "data")
cancer_path, not_cancer_path = (
    oj(data_path, "processed", folder) for folder in ("cancer_224", "no_cancer_224")
)

#### Arguments for training

In [None]:
# Standard ImageNet per-channel (RGB) mean and std used to normalise inputs for
# torchvision's pretrained models; CancerDataset standardises every image with these.
mean = np.asarray([0.485, 0.456, 0.406])
std = np.asarray([0.229, 0.224, 0.225])

# Training settings. parse_args(args=[]) parses an empty list, so inside the notebook
# the defaults below are exactly what gets used. Help strings use %(default)s so the
# documented defaults always match the real ones.
parser = argparse.ArgumentParser(description='ISIC Lesion Classification')
parser.add_argument('--batch_size', type=int, default=16, metavar='N',
                    help='input batch size for training (default: %(default)s)')

parser.add_argument('--epochs', type=int, default=5, metavar='N',
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.00001, metavar='LR',
                    help='learning rate needs to be extremely small, otherwise loss nans (default: %(default)s)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: %(default)s)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: %(default)s)')
parser.add_argument('--regularizer_rate', type=float, default=0.0, metavar='N',
                    help='hyperparameter for CDEP weight - higher means more regularization (default: %(default)s)')
args = parser.parse_args(args=[])

regularizer_rate = args.regularizer_rate

num_epochs = args.epochs

# An integer index denotes a CUDA device: this is cuda:0, so a GPU runtime is required.
device = torch.device(0)

# The trailing semicolon suppresses the notebook's echo of the returned Generator.
torch.manual_seed(args.seed);
model = models.vgg16(pretrained=True)

# Replace the final layer with a 2-way output (index 0 = not cancer, 1 = cancer).
model.classifier[-1] = nn.Linear(4096, 2)
model = model.to(device)
# Only the classifier's parameters are handed to the optimizer, so only they are updated.
params_to_update = model.classifier.parameters()

#### Clean up the image directories
- Remove empty images
- Remove files that appear in a new (duplicated) folder but not in the original folder, so both contain the same images.
- Ensure image sizes are all 224x224

In [None]:
def clean_up_empty_files(path, min_size=100):
    """Delete (near-)empty files from the directory *path*.

    Every regular file smaller than *min_size* bytes is removed and its name
    printed. Subdirectories are left untouched.

    Args:
        path: directory to scan (not recursive).
        min_size: size threshold in bytes; files strictly smaller are removed.
            Defaults to 100, the previously hard-coded value.
    """
    for fname in tqdm(os.listdir(path)):
        fpath = oj(path, fname)
        if os.path.isfile(fpath) and os.path.getsize(fpath) < min_size:
            os.remove(fpath)
            # Report the file name (not its position in the listing) so deletions are traceable.
            print("File " + fname + " deleted!")

def clean_up_duplicates(path1, path2):
    """Delete every file in *path1* whose name does not also appear in *path2*.

    Afterwards *path1* only contains names that are also present in *path2*.
    Each deleted file's name is printed.

    Args:
        path1: directory to prune (e.g. a new copy of an image folder).
        path2: reference directory whose file names are kept.
    """
    # A set makes each membership test O(1) instead of scanning a list every time.
    reference_names = set(os.listdir(path2))
    extra_files = [f for f in os.listdir(path1) if f not in reference_names]
    for fname in tqdm(extra_files):
        os.remove(oj(path1, fname))
        print("File " + fname + " deleted!")

def check_img_sizes(path, expected_size=(224, 224)):
    """Print the name of every image in *path* whose size differs from *expected_size*.

    Args:
        path: directory of images to check (not recursive).
        expected_size: ``(width, height)`` in pixels; defaults to 224x224.
    """
    expected_width, expected_height = expected_size
    for fname in tqdm(os.listdir(path)):
        # The context manager closes each file; previously one handle leaked per image.
        with Image.open(oj(path, fname)) as im:
            if im.width != expected_width or im.height != expected_height:
                print(fname)

In [None]:
# clean_up_empty_files(cancer_path)
# clean_up_empty_files(not_cancer_path)   

# newpath = oj(data_path, "no_cancer_224_inpainted")
# oldpath = oj(data_path, "processed", "no_cancer_224")
# clean_up_duplicates(newpath, oldpath)

# check_img_sizes(not_cancer_path)

#### Torch dataset class

In [None]:
class CancerDataset(torch.utils.data.Dataset):
    """Binary skin-lesion image dataset yielding ``(image_tensor, label)`` pairs.

    Two ways to construct it:

    * ``path`` and ``is_cancer``: every file in the directory ``path`` is an
      image of the single class ``is_cancer`` (this notebook uses 1 = cancer,
      0 = not cancer).
    * ``data_files`` and ``labels``: parallel lists of full image file paths
      and their per-image labels.

    Each image is converted to RGB, scaled to [0, 1], standardised with the
    module-level ``mean``/``std`` and returned as a float32 channels-first
    ``(C, H, W)`` tensor.

    Raises:
        ValueError: if the arguments match neither construction mode, or
            ``data_files`` and ``labels`` differ in length.
    """

    def __init__(self, path: str = None, is_cancer: int = None, data_files=None, labels=None):
        if path:
            if is_cancer is None:
                raise ValueError("is_cancer must be supplied together with path")
            self.path = path
            self.data_files = os.listdir(self.path)
            self.is_cancer = is_cancer
        else:
            if data_files is None or labels is None:
                raise ValueError("data_files and labels must both be supplied when path is not given")
            if len(data_files) != len(labels):
                raise ValueError("data_files and labels differ in length ({} vs {})".format(
                    len(data_files), len(labels)))
            # data_files holds full paths; oj('', p) == p in __getitem__.
            self.path = ''
            self.data_files = data_files
            self.labels = labels
            self.is_cancer = None

    def __getitem__(self, i):
        # The context manager closes the file even if decoding fails.
        with Image.open(oj(self.path, self.data_files[i])) as img:
            # convert("RGB") guarantees 3 channels so the per-channel mean/std broadcast;
            # for images that are already RGB it only makes a copy.
            img_array = np.asarray(img.convert("RGB")) / 255.0
        img_array -= mean[None, None, :]
        img_array /= std[None, None, :]
        # HWC -> CHW, the channels-first layout PyTorch conv layers use.
        torch_img = torch.from_numpy(img_array.transpose(2, 0, 1)).float()
        # Use the dataset-wide class if one was given, otherwise the per-file label.
        is_cancer = self.is_cancer if self.is_cancer is not None else self.labels[i]
        return (torch_img, is_cancer)

    def __len__(self):
        return len(self.data_files)

#### Functions for training

In [None]:
def gradient_sum(im, target, model, crit, device='cuda'):
    """Return the total absolute input-gradient of the loss as a scalar tensor.

    Computes d crit(model(im), target) / d im, sums it over dim 1 (the channel
    axis if ``im`` is an NCHW batch -- assumed, confirm for other callers),
    takes the absolute value and sums the rest. ``create_graph=True`` keeps
    the graph so the returned value can itself be backpropagated.

    Side effect: sets ``im.requires_grad = True``. All tensors are assumed to
    already be on the same device; ``device`` is unused.
    """
    im.requires_grad = True
    loss = crit(model(im), target)
    (input_grad,) = torch.autograd.grad(loss, im, create_graph=True)
    return input_grad.sum(dim=1).abs().sum()

def _load_latest_classifier(model, checkpoint_dir):
    """Load the most recently modified ``*.pt`` file in *checkpoint_dir* into ``model.classifier``.

    Returns True if a checkpoint was loaded, False if the directory is missing
    or contains no ``.pt`` files.
    """
    if not os.path.isdir(checkpoint_dir):
        return False
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
    if not checkpoints:
        return False
    latest = max(checkpoints, key=lambda f: os.path.getmtime(oj(checkpoint_dir, f)))
    model.classifier.load_state_dict(torch.load(oj(checkpoint_dir, latest)))
    return True


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, resume_training=False):
    """Train *model* on ``dataloaders['train']``, checkpointing its classifier after every epoch.

    There is no validation phase; per-epoch loss, accuracy and penalty loss are
    training-set averages.

    Module-level globals read: ``device`` (where batches are moved),
    ``regularizer_rate`` (weight of the optional ``gradient_sum`` penalty;
    0 disables it), ``dataset_sizes`` (used to average per-epoch statistics)
    and ``model_training_path`` (checkpoint directory, created if missing).

    Args:
        model: network with a ``classifier`` submodule.
        dataloaders: dict with a ``'train'`` DataLoader yielding ``(inputs, labels)``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the parameters to train.
        num_epochs: number of passes over the training data.
        resume_training: if True, first load the most recently modified
            ``.pt`` file in ``model_training_path`` into ``model.classifier``.

    Returns:
        The trained model (the same object, trained in place).
    """
    since = time.time()
    phase = 'train'
    best_loss = float('inf')

    if resume_training and _load_latest_classifier(model, model_training_path):
        print("Model loaded!")
    # torch.save below fails if the checkpoint directory does not exist yet.
    os.makedirs(model_training_path, exist_ok=True)

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        # No optimizer.step() here: stepping before zero_grad() would re-apply the
        # previous epoch's last-batch gradients a second time.
        model.train()
        running_loss = 0.0
        running_loss_cd = 0.0
        running_corrects = 0

        for inputs, labels in tqdm(dataloaders[phase]):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                if regularizer_rate != 0:
                    # Separate optimisation step on the weighted input-gradient
                    # penalty, taken before the ordinary loss step below.
                    inputs.requires_grad = True
                    add_loss = gradient_sum(inputs, labels, model, criterion)
                    if add_loss != 0:
                        (regularizer_rate * add_loss).backward()
                        optimizer.step()
                    optimizer.zero_grad()
                    running_loss_cd += add_loss.item() * inputs.size(0)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()

        n_samples = dataset_sizes[phase]
        epoch_loss = running_loss / n_samples
        epoch_cd_loss = running_loss_cd / n_samples
        epoch_acc = running_corrects / n_samples
        best_loss = min(best_loss, epoch_loss)

        print('{} Loss: {:.4f} Acc: {:.4f} CD Loss : {:.4f}'.format(
            phase, epoch_loss, epoch_acc, epoch_cd_loss))

        # Checkpoint only the classifier submodule, named by timestamp.
        checkpoint_name = datetime.now().strftime("%Y%m%d%H%M%S") + ".pt"
        torch.save(model.classifier.state_dict(), oj(model_training_path, checkpoint_name))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60)
    )
    # No validation is performed, so report the best epoch-average training loss.
    print('Best train loss: {:.4f}'.format(best_loss))
    return model

#### Functions for evaluation

In [None]:
from sklearn.metrics import auc,average_precision_score, roc_curve,roc_auc_score,precision_recall_curve, f1_score

def get_output(model, dataset, batch_size=16, num_workers=2):
    """Run *model* over *dataset* and return its cancer probabilities and the true labels.

    Args:
        model: classifier producing 2-class logits (index 1 = cancer).
        dataset: dataset yielding ``(inputs, label)`` pairs.
        batch_size: evaluation batch size (default 16).
        num_workers: DataLoader worker processes (default 2).

    Returns:
        ``(probs, labels)``: 1-D numpy arrays in dataset order, where ``probs``
        is the softmax probability of class 1 and ``labels`` the ground truth.

    Note: leaves *model* in eval mode.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers)
    model = model.eval()
    # Send inputs to wherever the model lives instead of hard-coding .cuda(),
    # so evaluation also works for CPU models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    probs, targets = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            targets.append(labels.cpu().numpy())
            logits = model(inputs.to(device))
            probs.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
    return np.concatenate(probs, axis=0), np.concatenate(targets, axis=0)

def get_auc_f1(model, dataset, fname=None):
    """Return ``(roc_auc, best_f1)`` for *model* on *dataset*.

    Args:
        model: VGG-style network with a ``classifier`` submodule.
        dataset: dataset yielding ``(image_tensor, label)`` pairs.
        fname: optional checkpoint file to load first. If its keys include
            ``"classifier.0.weight"`` it is a full-model state dict (some
            models were saved that way); otherwise it holds only the
            ``model.classifier`` weights.

    The F1 score is the maximum over thresholds 0.1, 0.2, ..., 1.0, restricted
    to thresholds with predicted probabilities strictly on both sides.

    Raises:
        ValueError: if no threshold splits the predicted probabilities.
    """
    if fname is not None:
        with open(fname, 'rb') as f:
            weights = torch.load(f)
        if "classifier.0.weight" in weights:
            model.load_state_dict(weights)
        else:
            model.classifier.load_state_dict(weights)
    # Always evaluate the full network: model.classifier alone cannot take raw images.
    probs, labels = get_output(model, dataset)
    auc = roc_auc_score(labels, probs)
    f1_scores = [f1_score(labels, probs > threshold)
                 for threshold in np.linspace(0.1, 1, num=10)
                 if (probs > threshold).any() and (probs < threshold).any()]
    if not f1_scores:
        raise ValueError("no threshold in 0.1..1.0 splits the predicted probabilities; F1 is undefined")
    return auc, np.max(f1_scores)

## Initial Classifier Training

#### Combine datasets and split to train-test

In [None]:
# One dataset per class folder (label 1 = cancer, 0 = not cancer), concatenated:
# indices 0..len(cancer_dataset)-1 are cancer images, the rest not-cancer images.
cancer_dataset = CancerDataset(path=cancer_path, is_cancer=1)
not_cancer_dataset = CancerDataset(path=not_cancer_path, is_cancer=0)
complete_dataset = ConcatDataset((cancer_dataset, not_cancer_dataset))

# 80/20 train/test split.
num_total = len(complete_dataset)
num_train = int(0.8 * num_total)
num_test = num_total - num_train
# Reseed right before random_split so the split is reproducible for a given directory
# listing (CancerDataset indexes files in os.listdir order). Keep these two lines together.
torch.manual_seed(0);
train_dataset, test_dataset = torch.utils.data.random_split(complete_dataset, [num_train, num_test])
datasets = {'train' : train_dataset, 'test':test_dataset}
dataset_sizes = {'train' : len(train_dataset), 'test':len(test_dataset)}

# train_model iterates only the 'train' loader; evaluation (get_output) builds its own
# unshuffled loader, so the 'test' loader's shuffle setting does not affect metrics.
dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=args.batch_size,
                                             shuffle=True, num_workers=2)
              for x in ['train', 'test']}

##### Record the specific files in the training/test sets.

In [None]:
def list_to_file(li, filename):
    """Write each element of *li* to *filename* on its own line, overwriting the file.

    Elements are formatted with ``"%s"``, so non-string items are written via ``str``.
    """
    with open(filename, 'w') as out:
        out.writelines("%s\n" % entry for entry in li)

def extract_filenames(train_subset, test_subset):
    """Map the indices of two ``random_split`` subsets back to full image file paths.

    Both subsets must come from the same split, i.e. share ``.dataset``. That
    dataset is either a ``ConcatDataset`` of path-based datasets (any number of
    parts, in concatenation order) or a single dataset exposing ``path`` and
    ``data_files``.

    Returns:
        ``(train_files, test_files)``: lists of file paths in subset index order.
    """
    base = train_subset.dataset
    # ConcatDataset keeps its parts in .datasets; a plain dataset is a single part.
    parts = getattr(base, 'datasets', [base])
    # Global index i of the concatenated dataset corresponds to filepaths[i].
    filepaths = [oj(part.path, fname) for part in parts for fname in part.data_files]

    train_files = [filepaths[i] for i in train_subset.indices]
    test_files = [filepaths[i] for i in test_subset.indices]
    return train_files, test_files

In [None]:
# # Call the function and get the full file paths.
# train_files, test_files = extract_filenames(train_dataset, test_dataset)
# list_to_file(train_files, oj(dir_path, 'models', 'train_files.txt'))   # Write the training filepaths to a text file.
# list_to_file(test_files,  oj(dir_path, 'models', 'test_files.txt'))    # Write the testing filepaths to a text file.

#### Weights for training
Since the classes are imbalanced, we weight the loss function by inverse class frequency during training.

In [None]:
# Inverse-frequency class weights normalised to sum to 1;
# index 0 = not cancer, index 1 = cancer (the labels used for the datasets).
cancer_ratio = len(cancer_dataset)/len(complete_dataset)

not_cancer_ratio = 1 - cancer_ratio
cancer_weight = 1/cancer_ratio
not_cancer_weight = 1/ not_cancer_ratio
weights = np.asarray([not_cancer_weight, cancer_weight])
weights /= weights.sum()
# Build the float32 tensor CrossEntropyLoss expects directly on the training device.
weights = torch.tensor(weights, dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(weight=weights)

optimizer_ft = optim.SGD(params_to_update, lr=args.lr, momentum=args.momentum)

#### Train and save the model

In [None]:
# Train the classifier (the optimizer only holds model.classifier's parameters).
# train_model already saves model.classifier's weights after every epoch, which
# makes the manual save below redundant.
model = train_model(model, dataloaders, criterion, optimizer_ft, num_epochs=num_epochs, resume_training=False)
# pid = datetime.now().strftime('%Y%m%d%H%M%S') 
# torch.save(model.classifier.state_dict(),oj(dir_path, model_path, pid + ".pt"))

In [None]:
# Evaluate the initial classifier on the held-out test split (ROC-AUC and best thresholded F1).
# Note: this rebinds `auc`, shadowing the sklearn.metrics `auc` function imported earlier.
auc, f1 = get_auc_f1(model, test_dataset)
print("AUC: ", auc)
print("F1: ", f1)

In [None]:
results_file_path = oj(dir_path, "auc_f1_224_10.txt")
print(results_file_path)
# Persist the test metrics as two "name: value" lines, overwriting any previous file.
with open(results_file_path, 'w') as f:
    f.writelines(['AUC: ' + str(auc) + "\n", 'F1: ' + str(f1) + "\n"])

## Classifier Retraining
Train a classifier after inpainting the coloured patches in the training set.

In [None]:
# Save to a different folder so it is easy to access for testing.
# train_model writes per-epoch checkpoints here (and reads them when resuming),
# so create the folder up front in case it does not exist yet.
model_training_path = oj(model_path, "training_inpainted")
os.makedirs(model_training_path, exist_ok=True)

In [None]:
# Read the train/test file lists (one full path per line); the with-blocks close the files.
with open(oj(dir_path, "models", "train_files_used.txt"), 'rt') as f:
    train_files = f.read().splitlines()
with open(oj(dir_path, "models", "test_files_used.txt"), 'rt') as f:
    test_files = f.read().splitlines()

# Create the labels based on the directory the images are contained in: a path containing
# "no_cancer" is class 0, anything else class 1. This must use the ORIGINAL paths, because
# the inpainted directory below does not encode the class in its name.
train_labels = [0 if "no_cancer" in fpath else 1 for fpath in train_files]
test_labels = [0 if "no_cancer" in fpath else 1 for fpath in test_files]

# Replace the training images that have patches with their inpainted counterparts
# (matched by file name). A set makes each membership test O(1).
inpainted_train_dir = oj(dir_path, "data", "results_gmcnn",
                         "test_20210608-102851_inpaint_train_patches_gmcnn_s224x224_gc32", "inpainted")
inpainted_train_files = set(os.listdir(inpainted_train_dir))

train_files_adj = [oj(inpainted_train_dir, os.path.basename(fpath))
                   if os.path.basename(fpath) in inpainted_train_files
                   else fpath
                   for fpath in train_files]

#### Create Torch datasets and dataloaders

In [None]:
# Train on inpainted images where available; test using the original images
# rather than inpainted versions.
train_dataset = CancerDataset(data_files=train_files_adj, labels=train_labels)
test_dataset = CancerDataset(data_files=test_files, labels=test_labels)

datasets = {'train': train_dataset, 'test': test_dataset}
dataset_sizes = {phase: len(split) for phase, split in datasets.items()}

dataloaders = {
    phase: torch.utils.data.DataLoader(split, batch_size=args.batch_size,
                                       shuffle=True, num_workers=2)
    for phase, split in datasets.items()
}

#### Initialise new VGG model.

In [None]:
# Fresh pretrained VGG16 for retraining, configured the same way as the initial model.
# torch.device(0) is CUDA device 0.
device = torch.device(0)

# Reseed so the new final layer's random initialisation is reproducible.
torch.manual_seed(args.seed);
model = models.vgg16(pretrained=True)

# 2-way output layer: index 0 = not cancer, index 1 = cancer.
model.classifier[-1] = nn.Linear(4096, 2)
model = model.to(device)
# Only the classifier's parameters are passed to the optimizer later.
params_to_update = model.classifier.parameters()

#### Weights for training

In [None]:
# Inverse-frequency class weights from the retraining labels, normalised to sum to 1
# (index 0 = not cancer, index 1 = cancer).
cancer_ratio = sum(train_labels)/len(train_labels)

weights = np.asarray([1/(1-cancer_ratio), 1/cancer_ratio])
weights /= weights.sum()
# Build the float32 tensor CrossEntropyLoss expects directly on the training device.
weights = torch.tensor(weights, dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(weight=weights)

optimizer_ft = optim.SGD(params_to_update, lr=args.lr, momentum=args.momentum)

#### Train model

In [None]:
# Retrain for the same number of epochs as the initial classifier (args.epochs via
# num_epochs) instead of a separately hard-coded value. train_model saves
# model.classifier's weights after every epoch, so no manual save is needed.
model = train_model(model, dataloaders, criterion, optimizer_ft, num_epochs=num_epochs, resume_training=False)

#### Evaluate

In [None]:
# Evaluate the retrained model on the original (non-inpainted) test images.
auc, f1 = get_auc_f1(model, test_dataset)
print("AUC: ", auc)
print("F1: ", f1)

In [None]:
results_file_path = oj(dir_path, "auc_f1_inpainted_10.txt")
print(results_file_path)
# Persist the retrained model's test metrics as two "name: value" lines, overwriting any previous file.
with open(results_file_path, 'w') as f:
    f.writelines(['AUC: ' + str(auc) + "\n", 'F1: ' + str(f1) + "\n"])