<a href="https://colab.research.google.com/github/adubowski/redi-xai/blob/main/classifier/train_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

import torch
import sys
import numpy as np
import pickle as pkl
from os.path import join as oj
from datetime import datetime
import torch.optim as optim
import os
from torch.utils.data import TensorDataset, ConcatDataset
import argparse
from PIL import Image
from tqdm import tqdm
from torch import nn
from numpy.random import randint
import torchvision.models as models
import time
import copy
import gc
import json

### Mount Google Drive and create paths for directories

In [None]:
# Mount Google Drive inside the Colab VM and point dir_path at the project folder,
# which holds config.json plus the data and model folders named in it.
mount_point = "/content/drive"
drive.mount(mount_point)
dir_path = "/content/drive/MyDrive/redi-detecting-cheating"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Folder names are read from config.json in the project folder on Drive.
with open(oj(dir_path, 'config.json')) as json_file:
    data = json.load(json_file)

model_path = oj(dir_path, data["model_folder"], "initial_classifier")
# Per-epoch classifier checkpoints are listed and written here by train_model.
model_training_path = oj(model_path, "training")
# Create it up front: os.listdir / torch.save on a missing directory would fail during training.
os.makedirs(model_training_path, exist_ok=True)
data_path = oj(dir_path, data["data_folder"])
# seg_path  = oj(data_path, "patch-segmentation")
not_cancer_path = oj(data_path, "processed/no_cancer")
cancer_path = oj(data_path, "processed/cancer")

#### Arguments for training

In [None]:
# ImageNet normalisation statistics (what the pretrained VGG16 weights expect); used by CancerDataset.
mean = np.asarray([0.485, 0.456, 0.406])
std = np.asarray([0.229, 0.224, 0.225])

# Training settings. parse_args(args=[]) means the defaults below are what the notebook uses.
parser = argparse.ArgumentParser(description='ISIC Lesion Classification')
parser.add_argument('--batch_size', type=int, default=16, metavar='N',
                    help='input batch size for training (default: 16)')
parser.add_argument('--epochs', type=int, default=5, metavar='N',
                    help='number of epochs to train (default: 5)')
parser.add_argument('--lr', type=float, default=0.00001, metavar='LR',
                    help='learning rate needs to be extremely small, otherwise loss nans (default: 0.00001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--regularizer_rate', type=float, default=0.0, metavar='N',
                    help='hyperparameter for CDEP weight - higher means more regularization')
args = parser.parse_args(args=[])

regularizer_rate = args.regularizer_rate

num_epochs = args.epochs

# Use the first GPU when available, otherwise fall back to CPU instead of failing.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(args.seed);
model = models.vgg16(pretrained=True)

# Replace the final 1000-way ImageNet layer with a 2-way (not cancer / cancer) output.
model.classifier[-1] = nn.Linear(4096, 2)
# Only the classifier is optimised and checkpointed, so freeze the convolutional features;
# this skips computing weight gradients for them that would never be used.
for param in model.features.parameters():
    param.requires_grad = False
model = model.to(device)
# A list (not a one-shot generator) so re-running the optimizer cell still sees the parameters.
params_to_update = list(model.classifier.parameters())

#### Remove empty (smaller than 100 bytes) image files from the directories (currently disabled).

In [None]:
# def clean_up_path(path):
#     list_files= os.listdir(path)
#     num_files = len(list_files)
#     for i in tqdm(range(num_files)):
#         if os.path.getsize(oj(path, list_files[i])) < 100:
#             os.remove(oj(path, list_files[i]))
#             print("File " + str(i) + "deleted!")

# clean_up_path(cancer_path)
# clean_up_path(not_cancer_path)   

#### Reading in the image files (unused).

In [None]:
# def load_seg(path, orig_path):
#     list_files= os.listdir(orig_path)
#     num_files = min([2000, len(list_files)])
#     imgs_np = np.zeros((num_files,  299, 299), dtype = np.bool)
#     for i in tqdm(range(num_files)):
#         if os.path.isfile(oj(path,  list_files[i])):
#             img = Image.open(oj(path, list_files[i]))
#             imgs_np[i] = np.asarray(img)[:,:,0] > 100
#             img.close()
#     return imgs_np

# cancer_set = load_folder(cancer_path)
# cancer_set -= mean[None, None, :]
# cancer_set /= std[None, None, :]

# cancer_targets = np.ones((cancer_set.shape[0])).astype(np.bool)

# cancer_dataset = TensorDataset(torch.from_numpy(
#     cancer_set.swapaxes(1,3).swapaxes(2,2)).float(), 
#     torch.from_numpy(cancer_targets), 
#     torch.from_numpy(np.zeros((len(cancer_set), 299, 299), dtype = np.bool))
# )
# del cancer_set

# gc.collect()

# not_cancer_set = load_folder(not_cancer_path)
# not_cancer_set -= mean[None, None, :]
# not_cancer_set /= std[None, None, :]
# seg_set = load_seg(seg_path, not_cancer_path)

# not_cancer_targets = np.zeros((not_cancer_set.shape[0])).astype(np.bool)

# not_cancer_dataset = TensorDataset(torch.from_numpy(not_cancer_set.swapaxes(1,3).swapaxes(2,3)).float(), torch.from_numpy(not_cancer_targets),torch.from_numpy(seg_set))

# del not_cancer_set
# del seg_set

# gc.collect()

#### Torch dataset class

In [None]:
class CancerDataset(torch.utils.data.Dataset):
    """Folder of images that all share one class label.

    Each item is scaled to [0, 1], normalised with the notebook-level ``mean``/``std``
    arrays (defined in the arguments cell) and returned as a float32 CHW tensor.

    Args:
        path: directory containing the images (one image per regular file).
        is_cancer: label returned for every item (this notebook uses 1 = cancer, 0 = no cancer).
    """

    def __init__(self, path: str, is_cancer: int):
        self.path = path
        # Sorted, so the file order does not depend on the filesystem's listing order.
        # The seeded random_split and the file lists from extract_filenames index into it.
        # Sub-directories and other non-files are skipped.
        with os.scandir(self.path) as entries:
            self.data_files = sorted(entry.name for entry in entries if entry.is_file())
        self.is_cancer = is_cancer

    def __getitem__(self, i):
        # The context manager closes the file even if decoding fails. convert("RGB") makes
        # grayscale / RGBA / palette images match the 3-channel normalisation below.
        with Image.open(oj(self.path, self.data_files[i])) as img:
            img_array = np.asarray(img.convert("RGB"), dtype=np.float64) / 255.0
        img_array -= mean[None, None, :]
        img_array /= std[None, None, :]
        # HWC -> CHW, the layout torchvision models consume.
        torch_img = torch.from_numpy(img_array.transpose(2, 0, 1)).float()
        return (torch_img, self.is_cancer)

    def __len__(self):
        return len(self.data_files)

#### Combine the datasets and split them into train/validation/test sets

In [None]:
cancer_dataset = CancerDataset(cancer_path, 1)
not_cancer_dataset = CancerDataset(not_cancer_path, 0)
complete_dataset = ConcatDataset((cancer_dataset, not_cancer_dataset))

# 80 / 10 / 10 split; the test set absorbs any rounding remainder.
num_total = len(complete_dataset)
num_train = int(0.8 * num_total)
num_val = int(0.1 * num_total)
num_test = num_total - num_train - num_val

# A dedicated seeded generator keeps the split reproducible without re-seeding the global RNG.
# Re-seeding it would silently override args.seed for everything that follows, e.g. DataLoader shuffling.
split_generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(
    complete_dataset, [num_train, num_test, num_val], generator=split_generator)

datasets = {'train': train_dataset, 'test': test_dataset, 'val': val_dataset}
dataset_sizes = {x: len(datasets[x]) for x in ['train', 'test', 'val']}
# Only the training loader needs shuffling; evaluation order does not change the metrics.
dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=args.batch_size,
                                              shuffle=(x == 'train'), num_workers=2)
               for x in ['train', 'test', 'val']}

In [None]:
def list_to_file(li, filename):
  """Write every item of ``li`` to ``filename`` as text, one item per line.

  The file is created or overwritten. Items are rendered with ``"%s"`` formatting.
  """
  with open(filename, 'w') as out_file:
    out_file.writelines("%s\n" % entry for entry in li)

def extract_filenames(train_subset, val_subset, test_subset):
  """Map the indices of three subsets of one ConcatDataset back to image file paths.

  All three subsets must wrap the same ConcatDataset (as produced by random_split).
  Each child dataset must expose ``path`` and ``data_files`` like CancerDataset. The
  ConcatDataset's global index i refers to the i-th file when the children's file lists
  are concatenated in order.

  Returns:
    (train_files, val_files, test_files): lists of full file paths, in subset-index order.

  Raises:
    ValueError: if the subsets do not share the same underlying dataset.
  """
  concat = train_subset.dataset
  if val_subset.dataset is not concat or test_subset.dataset is not concat:
    raise ValueError("train, val and test subsets must come from the same dataset")

  # Concatenate every child dataset's files in ConcatDataset order (not just the first two),
  # so global indices line up no matter how many class folders are combined.
  filepaths = [
      oj(child.path, file_name)
      for child in concat.datasets
      for file_name in child.data_files
  ]

  def to_paths(subset):
    return [filepaths[i] for i in subset.indices]

  return to_paths(train_subset), to_paths(val_subset), to_paths(test_subset)

# Call the function and get the full file paths.
# Resolve each split's indices to full image paths.
train_files, val_files, test_files = extract_filenames(train_dataset, val_dataset, test_dataset)

# Write one text file per split (train, val, test, in that order) under dir_path/models.
split_outputs = {
    'train_files.txt': train_files,
    'val_files.txt': val_files,
    'test_files.txt': test_files,
}
for out_name, out_paths in split_outputs.items():
    list_to_file(out_paths, oj(dir_path, 'models', out_name))

#### Weights for training
The classes are imbalanced, so the cross-entropy loss weights each class by the inverse of its frequency in the data (normalised so the two weights sum to 1).

In [None]:
# Inverse-frequency class weights, normalised to sum to 1, so each class carries the same
# total weight in the loss. Index order follows the CancerDataset labels: 0 = not cancer, 1 = cancer.
if len(cancer_dataset) == 0 or len(not_cancer_dataset) == 0:
    raise ValueError("Both the cancer and the no-cancer folders must contain images "
                     "to compute class weights.")
cancer_ratio = len(cancer_dataset) / len(complete_dataset)

not_cancer_ratio = 1 - cancer_ratio
cancer_weight = 1 / cancer_ratio
not_cancer_weight = 1 / not_cancer_ratio
weights = np.asarray([not_cancer_weight, cancer_weight])
weights /= weights.sum()
# CrossEntropyLoss expects float32 weights on the same device as the model outputs.
weights = torch.tensor(weights, dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(weight=weights)

optimizer_ft = optim.SGD(params_to_update, lr=args.lr, momentum=args.momentum)

#### Functions for training

In [None]:
def gradient_sum(im, target, model, crit, device='cuda'):
    '''Input-gradient penalty used as the optional regularizer in train_model.

    Differentiates crit(model(im), target) with respect to ``im``, sums that gradient over
    dim 1 (the channel axis for the NCHW batches built from CancerDataset), and returns the
    sum of absolute values as a scalar tensor. ``create_graph=True`` keeps the result
    differentiable, so it can itself be back-propagated into the model parameters.

    Side effect: sets ``im.requires_grad = True``. ``device`` is unused; the inputs are
    assumed to already live on the model's device.
    '''
    im.requires_grad = True
    objective = crit(model(im), target)
    (input_grad,) = torch.autograd.grad(objective, im, create_graph=True)
    channel_summed = input_grad.sum(dim=1)
    return channel_summed.abs().sum()

def _latest_checkpoint(folder):
    """Return the path of the most recently modified ``.pt`` file in ``folder``, or None if there is none."""
    if not os.path.isdir(folder):
        return None
    checkpoints = [oj(folder, f) for f in os.listdir(folder) if f.endswith('.pt')]
    if not checkpoints:
        return None
    return max(checkpoints, key=os.path.getmtime)


def _run_phase(model, loader, phase, criterion, optimizer):
    """Run one pass of ``phase`` ('train' or 'val') over ``loader``.

    In the 'train' phase the weights are updated. When the notebook-level
    ``regularizer_rate`` is non-zero, each batch first takes an optimizer step on the
    input-gradient penalty from ``gradient_sum``, then a step on the cross-entropy loss.
    Uses the notebook-level ``device``.

    Returns:
        (loss_sum, cd_loss_sum, num_correct): per-sample sums over the whole pass.
    """
    is_train = phase == 'train'
    model.train(is_train)  # training mode for 'train', evaluation mode otherwise

    loss_sum = 0.0
    cd_loss_sum = 0.0
    num_correct = 0
    for inputs, labels in tqdm(loader, desc=phase):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # Track autograd history only while training.
        with torch.set_grad_enabled(is_train):
            if is_train and regularizer_rate != 0:
                # gradient_sum marks the inputs as requiring grad itself.
                add_loss = gradient_sum(inputs, labels, model, criterion)
                if add_loss != 0:
                    (regularizer_rate * add_loss).backward()
                    optimizer.step()
                optimizer.zero_grad()
                cd_loss_sum += add_loss.item() * inputs.size(0)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        loss_sum += loss.item() * inputs.size(0)
        num_correct += torch.sum(preds == labels).item()
    return loss_sum, cd_loss_sum, num_correct


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, patience=3):
    """Train ``model`` with early stopping on the validation loss.

    Before the first epoch, the newest classifier checkpoint in ``model_training_path``
    (if any) is loaded so an interrupted run can be resumed. After every training phase
    the classifier weights are saved there under a timestamped name. Training stops once
    the validation loss has failed to improve for ``patience`` consecutive epochs. The
    model is returned with the weights from the epoch with the lowest validation loss.

    Args:
        model: network whose ``classifier`` submodule is checkpointed.
        dataloaders: dict with 'train' and 'val' DataLoaders.
        criterion: loss function, also used by the optional gradient regularizer.
        optimizer: optimizer over the parameters being trained.
        num_epochs: maximum number of epochs.
        patience: number of consecutive non-improving validation epochs before stopping.

    Returns:
        (model, hist_dict), where hist_dict holds per-epoch lists under 'val_acc_history',
        'val_loss_history', 'train_acc_history', 'train_loss_history' and 'train_cd_history'.
    """
    since = time.time()
    val_acc_history = []
    val_loss_history = []
    train_acc_history = []
    train_loss_history = []
    train_cd_history = []

    # Resume from the newest checkpoint once, before training starts.
    os.makedirs(model_training_path, exist_ok=True)
    checkpoint = _latest_checkpoint(model_training_path)
    if checkpoint is not None:
        model.classifier.load_state_dict(torch.load(checkpoint, map_location=device))
        print("Model loaded! Checkpoint:", checkpoint)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    cur_patience = 0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase.
        for phase in ['train', 'val']:
            loss_sum, cd_loss_sum, num_correct = _run_phase(
                model, dataloaders[phase], phase, criterion, optimizer)
            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = loss_sum / num_samples
            epoch_cd_loss = cd_loss_sum / num_samples
            epoch_acc = num_correct / num_samples

            print('{} Loss: {:.4f} Acc: {:.4f} CD Loss : {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_cd_loss))

            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_cd_history.append(epoch_cd_loss)
                train_acc_history.append(epoch_acc)
                # Checkpoint the classifier after every training phase so a disconnected run can resume.
                torch.save(model.classifier.state_dict(),
                           oj(model_training_path, datetime.now().strftime("%Y%m%d%H%M%S") + ".pt"))
            else:
                val_acc_history.append(epoch_acc)
                val_loss_history.append(epoch_loss)
                # Keep a copy of the best weights seen so far (lowest validation loss).
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    cur_patience = 0
                else:
                    cur_patience += 1

        if cur_patience >= patience:
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))

    hist_dict = {
        'val_acc_history': val_acc_history,
        'val_loss_history': val_loss_history,
        'train_acc_history': train_acc_history,
        'train_loss_history': train_loss_history,
        'train_cd_history': train_cd_history,
    }
    # Load the best model weights before returning.
    model.load_state_dict(best_model_wts)
    return model, hist_dict

#### Train and save the model

In [None]:
model, hist_dict = train_model(model, dataloaders, criterion, optimizer_ft, num_epochs=num_epochs)

# Timestamp run id shared by the saved classifier weights and the history pickle.
pid = datetime.now().strftime('%Y%m%d%H%M%S')
# model_path already starts with dir_path, so it is used directly.
torch.save(model.classifier.state_dict(), oj(model_path, pid + ".pt"))

# Store the run's hyper-parameters next to the training curves.
hist_dict['pid'] = pid
hist_dict['regularizer_rate'] = regularizer_rate
hist_dict['seed'] = args.seed
hist_dict['batch_size'] = args.batch_size
hist_dict['learning_rate'] = args.lr
hist_dict['momentum'] = args.momentum
# Context manager so the pickle is flushed and the file handle closed.
with open(oj(model_path, pid + '.pkl'), 'wb') as hist_file:
    pkl.dump(hist_dict, hist_file)

Epoch 0/4
----------



0it [00:00, ?it/s][A
1it [00:10, 10.57s/it][A
2it [00:11,  7.72s/it][A
3it [00:20,  8.15s/it][A
4it [00:23,  6.38s/it][A
5it [00:31,  7.12s/it][A
6it [00:32,  5.30s/it][A
7it [00:41,  6.22s/it][A
8it [00:42,  4.81s/it][A
9it [00:52,  6.35s/it][A
10it [00:53,  4.72s/it][A
11it [01:01,  5.77s/it][A
12it [01:03,  4.45s/it][A
13it [01:10,  5.40s/it][A
14it [01:12,  4.35s/it][A
15it [01:21,  5.66s/it][A
16it [01:22,  4.24s/it][A
17it [01:30,  5.43s/it][A
18it [01:32,  4.23s/it][A
19it [01:39,  5.09s/it][A
20it [01:41,  4.16s/it][A
21it [01:50,  5.77s/it][A
22it [01:51,  4.32s/it][A
23it [02:01,  5.96s/it][A
24it [02:02,  4.61s/it][A
25it [02:13,  6.51s/it][A
26it [02:14,  4.83s/it][A
27it [02:23,  6.07s/it][A
28it [02:24,  4.53s/it][A
29it [02:34,  6.00s/it][A
30it [02:34,  4.48s/it][A
31it [02:45,  6.25s/it][A
32it [02:46,  4.65s/it][A
33it [02:56,  6.39s/it][A
34it [02:57,  4.75s/it][A
35it [03:07,  6.22s/it][A
36it [03:08,  4.64s/it][A
37it [03:19,  

In [None]:
# NOTE(review): get_auc_f1 is not defined anywhere in this notebook; define or import it
# before running this cell.
auc, f1 = get_auc_f1(model, test_dataset)
# Text mode with "w" truncates any previous contents; one metric per line.
with open(oj(dir_path, "auc_f1_299.txt"), "w") as metrics_file:
    metrics_file.write("AUC: " + str(auc) + "\n")
    metrics_file.write("F1: " + str(f1) + "\n")