In [1]:
#installations (uncomment and run once; %pip installs into this kernel's environment)
#%pip install SimpleITK

In [None]:
#imports
import os
import torch
import wandb
import pandas as pd
import torch.nn as nn
import sys
#this path can be specified (if importing)
#sys.path.append()
from PIL import Image
import torchvision.transforms as transforms
#import matplotlib.pyplot as plt
from torch import nn
import sklearn
import numpy as np
import SimpleITK as sitk
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import random
import torchvision
#import livelossplot
#from livelossplot import PlotLosses
#import scipy
#import sklearn
#from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
#from sklearn.utils import shuffle
#from sklearn.metrics import confusion_matrix
#import seaborn as sns
#from ctviewer import CTViewer
#from sklearn.metrics import RocCurveDisplay
import datetime
from datetime import datetime
import json
#check if the gpu machine is available
if torch.cuda.is_available():
  device = 'cuda'
  gpu = torch.cuda.get_device_name(0)
  print('Device: ', gpu)
else:
  device = 'cpu'
  gpu = None
  print('Device', device)

#Weights and Biases login
#the API key (wandb.ai/settings) is read from the environment instead of being hardcoded in the notebook,
#so it never ends up in a committed/shared notebook; export WANDB_API_KEY before starting the kernel
key = os.environ.get('WANDB_API_KEY', '')
if key:
  wandb.login(key=key)

In [None]:
#functions

#Dataset
class CustomImageDataset(Dataset):
    """Map-style dataset yielding (image, label) pairs from a dataframe of image paths.

    Each row of `df` holds a path to an image file read with SimpleITK (column
    `col_image`) and a label (column `col_label`, returned unchanged).

    Args:
        df: dataframe with one row per sample.
        col_image: name of the column with image file paths.
        col_label: name of the column with labels.
        aug: if True, __getitem__ applies random flips, rotation and additive noise.
        shuffle: if True, the rows are shuffled once at construction (random_state=42).
    """

    # candidate rotation angles in degrees (multiples of 30, excluding 0)
    ROTATION_ANGLES = [30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330]

    def __init__(self, df, col_image, col_label, aug = False, shuffle = False):
        self.df = df
        if shuffle:
            # drop=True: do not carry the old index along as an extra 'index' column
            self.df = self.df.sample(frac = 1, random_state = 42).reset_index(drop = True)
        self.aug = aug
        self.col_image = col_image
        self.col_label = col_label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (image, label) for row `idx`; the image gets a leading channel axis of size 1."""
        row = self.df.iloc[idx]

        # read image (SimpleITK -> numpy) and prepend a channel axis
        image = torch.Tensor(np.expand_dims(sitk.GetArrayFromImage(sitk.ReadImage(row[self.col_image])), axis = 0))
        label = row[self.col_label]

        if self.aug:
            # each augmentation below is an independent coin flip (p = 0.5)
            if random.random() > 0.5:
                # horizontal flip
                image = torchvision.transforms.functional.hflip(image)
            if random.random() > 0.5:
                # vertical flip
                image = torchvision.transforms.functional.vflip(image)
            # rotation: previously its `if` was commented out and the call sat (by indentation)
            # inside the vflip branch, so it only ran when vflip did; it now has its own coin flip.
            # NOTE(review): rotate fills exposed corners with 0 by default; if the normalized
            # background is not 0 (the affine idea below used fill = -1), pass `fill` - TODO confirm
            if random.random() > 0.5:
                image = torchvision.transforms.functional.rotate(image, random.choice(self.ROTATION_ANGLES))
            # additive Gaussian noise: random mean in [-0.9, 0.9], random std in [0.01, 0.1]
            if random.random() > 0.5:
                noise = torch.normal(random.uniform(-0.9, 0.9), random.uniform(0.01, 0.1), image.shape)
                image = image + noise

            # alternate rotations (torch.rot90 over random axis pairs) don't work due to shape
            #if random.random() > 0.5:
            #  ls = random.sample([1,2,3], 2)
            #  k = random.randint(1, 3)
            #  image = torch.rot90(image, k, ls)

            # affine (translate + scale) may be more aggressive; inputs must stay valid.
            # higher scaling zooms in (AAA usually in center) and may remove the size cue.
            #if random.random() > 0.5:
            #  x = random.choice([5, 10, 15, -5, -10, -15])
            #  y = random.choice([5, 10, 15, -5, -10, -15])
            #  scaling = random.choice([0.9, 1.1, 1.5, 1.9])
            #  image = torchvision.transforms.functional.affine(image, angle = 0, translate = (x, y), scale = scaling, shear = 0, fill = -1,
            #  interpolation = torchvision.transforms.InterpolationMode.BILINEAR)

        return image, label

#save json file
def save_params(hyper_params, save_path):
  """Write `hyper_params` to `save_path` as JSON, replacing any existing file."""
  # encode before opening, so a non-serializable value raises without truncating the file
  payload = json.dumps(hyper_params)
  with open(save_path, mode = 'w') as handle:
    handle.write(payload)

#load the train params back in
def load_params(fpath):
  """Read the JSON file at `fpath` and return the decoded object."""
  with open(fpath) as handle:
    return json.load(handle)

#get the optimizer
def get_optimizer(model_config, model):
  """Create the optimizer named by model_config['optimizer'] for `model`'s parameters.

  Supported names: 'AdamW', 'Adam', 'NAdam', 'RAdam'. All are built with
  lr = model_config['init_lr'] and weight_decay = model_config['weight_decay'].

  Raises:
    ValueError: if the optimizer name is not supported (previously this surfaced as an
      UnboundLocalError on `optimizer`).
  """
  optimizers = {
    'AdamW': torch.optim.AdamW,
    'Adam': torch.optim.Adam,
    'NAdam': torch.optim.NAdam,
    'RAdam': torch.optim.RAdam,
  }
  name = model_config['optimizer']
  if name not in optimizers:
    raise ValueError(f"Unsupported optimizer '{name}'; expected one of {sorted(optimizers)}")
  return optimizers[name](model.parameters(), lr = model_config['init_lr'], weight_decay = model_config['weight_decay'])

#get the model
def get_model(model_config):
  """Instantiate the network for this run, passing model_config['img_size'] to `Model`.

  NOTE(review): `Model` is not defined in this notebook; it presumably comes from a module
  made importable via the commented-out sys.path.append in the imports cell - TODO confirm.
  Under a wandb sweep, 'img_size' may arrive as a list rather than a tuple - verify `Model`
  accepts that.
  """
  return Model(model_config['img_size'])

#model saving policy
def save_model(model_config, model):
  """Save `model` and `model_config` into the folder model_config['save_folder'].

  Writes either the state_dict as 'model_weights.pth' (when model_config['save_weights_only']
  is truthy) or the whole pickled module as 'model.pth', then the config as
  'model_config.json' via save_params. Paths are joined with os.path.join, so the folder
  no longer has to end with a path separator.

  Side effect: switches `model` to eval mode and leaves it there; callers that keep
  training must call model.train() again.
  """
  model.eval()
  folder = model_config['save_folder']
  if model_config['save_weights_only']:
    torch.save(model.state_dict(), os.path.join(folder, 'model_weights.pth'))
  else:
    torch.save(model, os.path.join(folder, 'model.pth'))
  # the config is saved next to the weights so the run can be reloaded later
  save_params(model_config, os.path.join(folder, 'model_config.json'))

#update previously saved config only
def update_config_stopearly(save_path):
  """Mark the config already saved in folder `save_path` as stopped early.

  Loads '<save_path>/model_config.json', sets ['early_stopping']['stopped_early'] = True
  and writes it back to the same file. The path is joined with os.path.join, so
  `save_path` no longer has to end with a path separator.
  """
  config_path = os.path.join(save_path, 'model_config.json')
  model_config = load_params(config_path)
  model_config['early_stopping']['stopped_early'] = True
  save_params(model_config, config_path)

#new saving policy
def new_saving_policy(early_stop, best_model, model_config, model, epoch):
  """Decide whether to save the model this epoch and whether training should stop.

  Args:
    early_stop: True when the early stopper signals that training should stop.
    best_model: True when the current epoch has the best validation score so far.
    model_config: run config dict. Its 'early_stopping' entries ('stopped_early',
      'best_model', 'model_criteria') are updated in place *before* any save, so the
      saved config reflects them.
    model: model handed to save_model.
    epoch: current epoch number (used only by the periodic-saving branch).

  Returns:
    (model_config, exit_training): the (mutated) config and whether to stop training.

  Policy:
    - early stop and a model was already saved ('model_criteria' truthy): only the config
      file on disk is flagged 'stopped_early'; the in-memory config is not updated.
    - early stop and no model saved yet: flag 'stopped_early' and save the *current* model.
    - no early stop, save_best_model False: save every 'save_after_n_epochs' epochs.
    - no early stop, save_best_model True: once 'epochs_trained' reaches
      'save_after_n_epochs', save whenever best_model is True; at exactly
      'save_after_n_epochs', save even when it is not the best.

  NOTE(review): 'best_model' in the config is set to True here but never reset to False.
  """
  #early stop requested
  if early_stop:
    #exit training
    exit_training = True
    #has a model already been saved?
    if model_config['early_stopping']['model_criteria']:
      #only update the saved config on disk
      update_config_stopearly(model_config['save_folder'])
    #no model saved yet
    else:
      #flag the early stop so the saved config records it
      model_config['early_stopping']['stopped_early'] = True
      #save the current model and config
      save_model(model_config, model)
  #no early stop
  else:
    #keep training
    exit_training = False
    #record that training has not stopped early
    model_config['early_stopping']['stopped_early'] = False
    #periodic saving: every save_after_n_epochs epochs
    if model_config['save_best_model'] == False:
      #epoch must be a nonzero multiple of save_after_n_epochs
      if (epoch % model_config['save_after_n_epochs'] == 0) and (epoch != 0):
        #indicate the model was saved (set before saving so the saved config shows it)
        model_config['early_stopping']['model_criteria'] = True
        #save
        save_model(model_config, model)
    else:
      #best-model saving, only after the warm-up of save_after_n_epochs epochs
      if model_config['epochs_trained'] >= model_config['save_after_n_epochs']:
        #current model has the best validation score so far
        if best_model:
          #log that it is the best model
          model_config['early_stopping']['best_model'] = True
          #indicate the model was saved
          model_config['early_stopping']['model_criteria'] = True
          #then save
          save_model(model_config, model)
        #not the best, but this is the first epoch eligible for saving: save anyway so a model exists on disk
        if (best_model == False) and (model_config['epochs_trained'] == model_config['save_after_n_epochs']):
          #indicate the model was saved
          model_config['early_stopping']['model_criteria'] = True
          #save the model and config
          save_model(model_config, model)

  #return
  return model_config, exit_training

#class earlystopping
class EarlyStopping:
  """Track the validation loss and signal when training should stop.

  Early stopping triggers once the loss has failed to improve by more than `delta`
  for `patience` consecutive calls.

  Args:
    model_config: dict whose 'early_stopping' entry provides 'patience' and 'delta'.
    verbose: if True, report the patience counter through `trace_func` each time it
      increases (previously the counter was reported regardless of this flag).
    trace_func: callable used for reporting (default: print).
  """

  def __init__(self, model_config, verbose = True, trace_func = print):
    self.patience = model_config['early_stopping']['patience']
    self.delta = model_config['early_stopping']['delta']
    self.verbose = verbose
    self.trace_func = trace_func
    self.counter = 0
    self.best_score = None
    self.best_model = False
    self.early_stop = False

  def __call__(self, val_loss):
    """Record `val_loss` and return (early_stop, best_model).

    early_stop stays True once triggered. best_model is True when this call set a new
    best score (including the very first call).
    """
    #higher score is better
    score = -val_loss
    if self.best_score is None:
      #first observation: it is the best seen so far
      self.best_score = score
      self.best_model = True
    elif score < self.best_score + self.delta:
      #no sufficient improvement: count towards patience
      self.counter += 1
      if self.verbose:
        self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
      self.best_model = False
      if self.counter >= self.patience:
        self.early_stop = True
    else:
      #new best score: reset the patience counter
      self.best_score = score
      self.counter = 0
      self.best_model = True
    return self.early_stop, self.best_model

#get the loss fn
def get_loss(model_config, device, df = None):
  """Build the training criterion described by `model_config`.

  Only cross-entropy (model_config['loss'] == 'CE') is supported. model_config['weighted']
  selects the class weighting:
    True   -> weights taken from model_config['weights']
    False  -> unweighted
    'auto' -> weights computed from `df` by auto_weights (requires `df`)
  Class weights are moved to `device`.

  Raises:
    ValueError: for an unsupported loss, an unknown 'weighted' value, or 'auto' without
      `df` (previously these surfaced as UnboundLocalError / errors inside auto_weights).
  """
  if model_config['loss'] != 'CE':
    raise ValueError(f"Unsupported loss '{model_config['loss']}'; only 'CE' is implemented")
  weighted = model_config['weighted']
  #check 'auto' first; the string compares unequal to both True and False
  if weighted == 'auto':
    if df is None:
      raise ValueError("weighted = 'auto' requires the training dataframe (df)")
    weights = auto_weights(model_config, df)
  elif weighted == True:
    weights = model_config['weights']
  elif weighted == False:
    weights = None
  else:
    raise ValueError(f"Unsupported 'weighted' value {weighted!r}; expected True, False or 'auto'")
  if weights is None:
    return nn.CrossEntropyLoss()
  return nn.CrossEntropyLoss(weight = torch.as_tensor(weights, dtype = torch.float32).to(device))

#get scheduler
def get_scheduler(model_config, optimizer):
  """Build the learning-rate scheduler described by model_config['scheduler'].

  Only 'plateau' (torch ReduceLROnPlateau) is implemented; mode, factor, patience,
  threshold, threshold_mode, cooldown, min_lr and eps are read from the scheduler dict.
  NOTE: ReduceLROnPlateau ignores any LR change smaller than `eps`, so `eps` must be
  far below the learning rate or the scheduler never reduces it.

  Raises:
    ValueError: if the scheduler description is not supported (previously an
      UnboundLocalError on `scheduler`).
  """
  sched_cfg = model_config['scheduler']
  if sched_cfg['description'] != 'plateau':
    raise ValueError(f"Unsupported scheduler '{sched_cfg['description']}'; only 'plateau' is implemented")
  return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = sched_cfg['mode'],
                                                    factor = sched_cfg['factor'],
                                                    patience = sched_cfg['patience'],
                                                    threshold = sched_cfg['threshold'],
                                                    threshold_mode = sched_cfg['threshold_mode'],
                                                    cooldown = sched_cfg['cooldown'],
                                                    min_lr = sched_cfg['min_lr'],
                                                    eps = sched_cfg['eps'],
                                                    verbose = True)

#get sampler
def get_sampler(df, model_config):
  """Build a WeightedRandomSampler that draws every class with roughly equal probability.

  Each sample is weighted by total_count / count_of_its_class, where the label column is
  model_config['col_label']. One epoch draws len(df) samples (with replacement, the
  WeightedRandomSampler default).

  Fixes: the label column was hardcoded to 'Annotation_Label' for the sample count, and
  class weights were looked up by using the label value as a list index, which breaks for
  labels that are not exactly 0..K-1. Weights are now mapped by label value.
  """
  labels = df[model_config['col_label']]
  classes, counts = np.unique(labels, return_counts = True)
  total = counts.sum()
  class_weight = {cls: total / n for cls, n in zip(classes, counts)}
  example_weights = [class_weight[label] for label in labels]
  return torch.utils.data.WeightedRandomSampler(example_weights, len(df))

#get loss weights automatically
def auto_weights(model_config, df):
  """Compute 'balanced' class weights from the label column of `df`.

  Same formula as scikit-learn's compute_class_weight(class_weight='balanced'):
  weight_c = n_samples / (n_classes * count_c), ordered by the sorted unique labels.
  Labels are no longer cast to int8, which silently wrapped values >= 128 and could
  reorder the weights relative to the label order.

  Args:
    model_config: dict providing 'col_label', the label column name.
    df: dataframe holding the labels (the training split in main).

  Returns:
    float32 tensor with one weight per unique label (on CPU; the caller moves it).

  Raises:
    ValueError: if `df` has no rows.
  """
  labels = df[model_config['col_label']].to_numpy()
  if labels.size == 0:
    raise ValueError('cannot compute class weights from an empty dataframe')
  classes, counts = np.unique(labels, return_counts = True)
  weights = len(labels) / (len(classes) * counts.astype(np.float64))
  return torch.tensor(weights, dtype = torch.float)

In [None]:
#main script

#one pass over a data loader: training when an optimizer is given, evaluation otherwise
def _run_epoch(model, loader, criterion, optimizer = None):
  """Run `model` once over every batch of `loader`.

  With an `optimizer` the model is put in train mode and updated after each batch;
  without one it is put in eval mode and gradients are disabled.

  Returns:
    (mean_loss, accuracy): mean of the per-batch values returned by `criterion`, and the
    fraction of samples whose argmax over dim 1 of the model output equals the label.

  Raises:
    ValueError: if `loader` yields no batches.
  """
  if len(loader) == 0:
    raise ValueError('data loader is empty; check the KFold split / dataframe')
  training = optimizer is not None
  model.train(training)
  total_loss = 0.0
  n_correct = 0
  n_seen = 0
  with torch.set_grad_enabled(training):
    for x, y_true in loader:
      y_true = y_true.to(device)
      if training:
        optimizer.zero_grad()
      y_pred = model(x.to(device))
      #criterion is expected to return the batch mean
      loss = criterion(y_pred, y_true)
      if training:
        loss.backward()
        optimizer.step()
      total_loss += loss.item()
      #assumes class scores along dim 1, as CrossEntropyLoss expects
      n_correct += (y_pred.argmax(dim = 1) == y_true).sum().item()
      n_seen += len(y_true)
      #drop batch references so empty_cache can hand the GPU memory back
      x, y_true, y_pred, loss = (None, None, None, None)
      if device == 'cuda':
        torch.cuda.empty_cache()
  return total_loss / len(loader), n_correct / max(n_seen, 1)

def main(config = None):
  """Train and validate one model for a single wandb (sweep) run.

  Called by wandb.agent with no arguments; `config` is forwarded to wandb.init. Reads the
  module-level `df` (loaded in the launch cell) and `device`. The run config is copied into
  a mutable dict (`model_config`), updated with per-epoch metrics, logged to wandb every
  epoch, and saved with the model according to new_saving_policy.

  Changes vs. the earlier version:
    - training batches are reshuffled every epoch (they were never shuffled without a sampler)
    - train_acc / val_acc are real classification accuracies, not 1 - loss
    - validation never uses the weighted random sampler, so every validation sample is
      evaluated exactly once and the validation loss is deterministic
    - the run folder is created with os.makedirs (missing parents are created)
  """
  #release GPU memory cached by a previous run in this process
  if device == 'cuda':
    torch.cuda.empty_cache()

  #name the model
  model_name = datetime.now().strftime('3D-Model-classification-%Y-%m-%d-%H-%M-%S')

  #init a new wandb run (config = sweep_config)
  with wandb.init(config = config, name = model_name):
    #wandb.config is locked; copy it into a plain dict that can be updated
    model_config = dict(wandb.config)
    model_config['model'] = model_name
    model_config['save_folder'] = model_config['save_folder'] + model_config['model'] + '/'
    os.makedirs(model_config['save_folder'], exist_ok = True)

    #split by KFold: the validation folds vs. everything else
    in_val = df['KFold'].isin(model_config['val_kfolds'])
    df_train = df[~in_val]
    df_val = df[in_val]

    #training data (augmentation as configured)
    dset_train = CustomImageDataset(df_train,
                                    col_image = model_config['col_image'], col_label = model_config['col_label'],
                                    aug = model_config['aug'], shuffle = False)
    train_sampler = get_sampler(df_train, model_config) if model_config['sampler'] else None
    #reshuffle every epoch unless the sampler already randomizes the order
    train_loader = DataLoader(dset_train, sampler = train_sampler, shuffle = train_sampler is None,
                              batch_size = model_config['batch_size'])

    #validation data: no augmentation and no sampler, every sample seen exactly once
    dset_val = CustomImageDataset(df_val,
                                  col_image = model_config['col_image'], col_label = model_config['col_label'],
                                  aug = False, shuffle = False)
    val_loader = DataLoader(dset_val, batch_size = model_config['batch_size'])

    #model, loss (weights may be derived from the training data), optimizer, scheduler
    model = get_model(model_config)
    model.to(device)
    criterion = get_loss(model_config, device, df_train)
    optimizer = get_optimizer(model_config, model)
    scheduler = get_scheduler(model_config, optimizer)
    #early stopping (save time during the sweep)
    early_stopper = EarlyStopping(model_config)

    #per-epoch history
    log_train_loss, log_train_acc, log_val_loss, log_val_acc = [], [], [], []

    #epochs are numbered from epochs_trained + 1 up to and including epochs
    for epoch in range(model_config['epochs_trained'] + 1, model_config['epochs'] + 1):
      train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
      val_loss, val_acc = _run_epoch(model, val_loader, criterion)

      #NOTE: the plateau scheduler follows the training loss, while early stopping and
      #model selection follow the validation loss
      scheduler.step(train_loss)

      log_train_loss.append(train_loss)
      log_train_acc.append(train_acc)
      log_val_loss.append(val_loss)
      log_val_acc.append(val_acc)

      #latest values and full history (not logged after an early stop)
      model_config.update({
        'train_loss': train_loss, 'train_acc': train_acc,
        'val_loss': val_loss, 'val_acc': val_acc,
        'log_train_loss': log_train_loss, 'log_train_acc': log_train_acc,
        'log_val_loss': log_val_loss, 'log_val_acc': log_val_acc,
      })
      model_config['epochs_trained'] = epoch

      print('Epoch {0} of {1}: Train Loss {2:.2g} & Acc {3:.2g} v Val Loss {4:.2g} and Acc {5:.2g}'.format(epoch, model_config['epochs'],
                                                                                                           train_loss, train_acc, val_loss, val_acc))

      wandb.log(model_config)

      #early stopping follows the validation loss; the saving policy decides what is written and whether to exit
      early_stop, best_model = early_stopper(val_loss)
      model_config, exit_training = new_saving_policy(early_stop, best_model, model_config, model, epoch)

      if exit_training:
        print('Early Stop: Exit Training')
        break

    #drop references so GPU memory can be released before the next sweep run
    model, criterion, optimizer, scheduler, early_stopper = (None, None, None, None, None)
    dset_train, train_loader, dset_val, val_loader = (None, None, None, None)
    if device == 'cuda':
      torch.cuda.empty_cache()

In [None]:
#parameters in wandb format
sweep_config = {
    #sweep name (set when the sweep is launched)
    'name': None,
    #sweep method
    'method': 'grid',
    #metric the sweep reports on (main logs it every epoch)
    'metric': {
        'name': 'val_acc',
        'goal': 'maximize',
    },
    #every run parameter; wandb expects all components to be listed here
    'parameters': {
        #description
        'description': {'value': 'Model which classifies a Medical Image'},
        #project in wandb
        'project':{'value': 'Classification'},
        'model': {'value': None}, #placeholder, replaced by the run name in main
        #documentation
        'data_path': {'value': ''}, #the input pickle with filepaths (read with pd.read_pickle)
        #path to save the results of the sweep
        'save_folder': {'value': ''},
        'col_image': {'value': 'Norm-256-256-256'}, #image-path column of the dataframe
        'col_label': {'value': 'Annotation_Label'}, #label column of the dataframe
        'device': {'value': device},
        'val_kfolds': {'value': [4, 5]}, #kfolds for validation
        #model specific params
        'img_size': {'values': [(256, 256, 256)]},
        #training params
        'aug': {'values': [True, False]},
        'batch_size': {'values': [4]},
        'init_lr': {'values': [1e-4, 1e-6]},
        'epochs': {'values': [100]}, #max epochs to train
        'epochs_trained': {'value': 0}, #updated by the training script
        'save_after_n_epochs': {'value': 5}, #meaning depends on save_best_model
        'weight_decay': {'values': [1e-2, 1e-4]},
        'optimizer': {'values': ['AdamW']},
        'scheduler': {'values': [
            {'description': 'plateau',
             'mode': 'min',
             'factor': 5e-1,
             'patience': 10,
             'threshold': 1e-3,
             'threshold_mode': 'rel',
             'cooldown': 0,
             'min_lr': 0,
             #ReduceLROnPlateau ignores LR changes smaller than eps; the previous eps = 1
             #exceeded every possible reduction (lr <= 1e-4), so the LR was never reduced
             'eps': 1e-8}]},
        'loss': {'values': ['CE']},
        'weighted': {'values': ['auto']},
        'weights': {'values': [(0.3, 0.5, 0.2)]}, #manual class weights, used only when weighted is True
        'sampler': {'values':[False]},
        #saving
        'save_weights_only': {'value': True},
        'save_best_model': {'value': True},
        #early stopping
        'early_stopping': {'value':
            {'patience': 10,
            'delta': 1e-4,
            'stopped_early': None, #indicate if stopped early
            'best_model': None, #indicate if best model (if save best model)
            'model_criteria': False}
        },
        #log the model loss and acc
        'log_train_loss': {'value': None},
        'log_train_acc': {'value': None},
        'log_val_loss': {'value': None},
        'log_val_acc': {'value': None},
        #updating performance in WandB
        'train_loss': {'value': None},
        'train_acc': {'value': None},
        'val_loss': {'value': None},
        'val_acc': {'value': None}
    }
}


In [None]:
#main script: load the data and launch the wandb sweep

if __name__ == '__main__':
  import copy

  #NOTE(security): pd.read_pickle can execute arbitrary code from the file; only load pickles you trust
  #`df` is a module-level global that main() reads in every sweep run
  df = pd.read_pickle(sweep_config['parameters']['data_path']['value'])

  #work on a copy: mutating the global sweep_config made re-runs of this cell nest folder names
  sweep_run_config = copy.deepcopy(sweep_config)
  sweep_run_config['name'] = datetime.now().strftime('3D-Model-sweep-%Y-%m-%d-%H-%M-%S')
  sweep_folder = sweep_run_config['parameters']['save_folder']['value'] + sweep_run_config['name'] + '/'
  sweep_run_config['parameters']['save_folder']['value'] = sweep_folder
  #makedirs also creates missing parent folders and tolerates an existing one
  os.makedirs(sweep_folder, exist_ok = True)
  #save the sweep config in the sweep folder
  save_params(sweep_run_config, sweep_folder + 'sweep_config.json')

  #register the sweep in the project and run every configuration through main()
  sweep_id = wandb.sweep(sweep_run_config, project = sweep_run_config['parameters']['project']['value'])
  wandb.agent(sweep_id, main)
  wandb.finish()


In [None]:
#delete extra wandb files (can be run after)
#import shutil
#shutil.rmtree('wandb')