# **Setup**

In [1]:
#Only necessary if logging performance data to WandB (Weights & Biases)
import wandb

#Keeping the login result as the cell's last expression displays it (True in the recorded output)
loginSucceeded = wandb.login()
loginSucceeded

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mamoseley018[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import torch
from torch.utils.data import DataLoader
import gc
import h5py
from UNet import UNet, Encoder, ContrastiveEncoder, ResidualBlock, double_conv, ResUNet
from LITSDataset import LITSBinaryDataset, LITSContDatasetPolyCL, LITSContDatasetSimCLR, LITSMultiClassDataset
import LossFunctions
import TrainingEval
from tqdm import tqdm
import os

In [3]:
#Uses the GPU when one is available; the commented-out second line forces the CPU and is only meant for
#cases where the GPU is unavailable or causing serious problems (not recommended)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#device = torch.device("cpu")

modelName = "ResUNetTest1"
#NOTE(review): the training loops save under "../UsedModels/" while this path has no "../" prefix - confirm which is intended
modelFile = "UsedModels/" + modelName

#Specify the file location of the config file, this contains all hyperparameters for training the model
#Everything below loads the data from the config file
configFile = "../testConfig.txt"

#Default value for every hyperparameter. TrainingEval.ParseConfig presumably overwrites these entries in place
#with the values read from configFile (its return value is not used) - TODO confirm against ParseConfig
varDict = {
    "startEpoch": 0,
    "useWandB": 0,
    "batchSize": 0,
    "learnRate": 0,
    "epochs": 0,
    "startDim": 0,
    "epochsToDouble": 0,
    "progressive": 0,
    "epochsToSave": 0,
    "cosineAnnealing": 0,
    "cosineRestartEpochs": 0,
}

TrainingEval.ParseConfig(configFile, varDict)

#Publishes every entry as a notebook-level variable (e.g. batchSize) for the later cells.
#Whole-number floats such as 16.0 become ints so they can be used as sizes/counts; everything else is kept as-is.
#The isinstance check matters because int only gained is_integer() in Python 3.12, so the untouched int
#defaults would otherwise raise AttributeError on older interpreters.
for key, value in varDict.items():
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    globals()[key] = value

In [None]:
#Loads the standard binary datasets and wraps each split in a DataLoader
#(only the training split is shuffled; the test split is iterated one sample at a time)
trainDataset = LITSBinaryDataset("../Datasets/StandardDatasets/FullTrainDataset.hdf5")
trainIter = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)

validationDataset = LITSBinaryDataset("../Datasets/StandardDatasets/ValidationDataset.hdf5")
validationIter = DataLoader(validationDataset, batch_size=batchSize)

testDataset = LITSBinaryDataset("../Datasets/StandardDatasets/TestDataset.hdf5")
testIter = DataLoader(testDataset, batch_size=1)

In [4]:
#Loads the full LiTS binary datasets; these reuse the same variable names, so they replace whatever
#datasets/loaders an earlier dataset cell created
trainDataset = LITSBinaryDataset("../Datasets/FullLiTS/FullLiTSTrainingDataset.hdf5")
trainIter = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)

validationDataset = LITSBinaryDataset("../Datasets/FullLiTS/FullLiTSValidationDataset.hdf5")
validationIter = DataLoader(validationDataset, batch_size=batchSize)

#Test split is iterated one sample at a time
testDataset = LITSBinaryDataset("../Datasets/FullLiTS/FullLiTSTestingDataset.hdf5")
testIter = DataLoader(testDataset, batch_size=1)

In [None]:
#Loads the multiclass datasets; these reuse the same variable names, so they replace whatever
#datasets/loaders an earlier dataset cell created
#NOTE(review): unlike the binary cells these paths have no "../" prefix - confirm the working directory they assume
trainDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTrainingDataset.hdf5")
trainIter = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)

validationDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassValidationDataset.hdf5")
validationIter = DataLoader(validationDataset, batch_size=batchSize)

#Test split is iterated one sample at a time
testDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTestingDataset.hdf5")
testIter = DataLoader(testDataset, batch_size=1)

# **Standard Training**

In [None]:
#Trains five freshly constructed ResUNet baselines; each repetition gets its own checkpoint path and WandB run name

#Specifies the loss functions and weights to use during training process
#Example: 
#   lossFuncs = [[segmentationLossFunc1, segmentationLossFunc2], [classificationLossFunc1, classificationLossFunc2]]
#   weights = [[0.25, 0.5], [0.1, 0.9]]
#Allows for easily changing the loss function and enables joint training on segmentation and classification
#Loss functions given weights of 0 are printed every epoch but are not included in the loss calculation
#These are plain functions and identical for every repetition, so they are defined once outside the loop
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.binary_pixel_ce, LossFunctions.dice_score], []]
weights = [[0.5, 0.5, 0], []]

for i in range(5):
    runName = "ResUNetBaseline" + str(i)
    modelFile = "../UsedModels/" + runName

    segmenter = ResUNet(num_classes=1).to(device)

    #Loads model from file if using a pretrained version
    #map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    initModel = ""
    if initModel != "":
        segmenter.load_state_dict(torch.load(initModel, map_location=device))

    #If the config file specifies using WandB, begins the run; it is named per repetition
    #(previously every run reused the constant modelName, making the five runs indistinguishable)
    if useWandB:
        wandb.init(project="EMBCBaseline",
                   name=runName,
                   config={
                       "BatchSize": batchSize,
                       "LearnRate": learnRate,
                       "Epochs": epochs,
                       "StartDimension": startDim,
                       "EpochsToDouble": epochsToDouble
                   })

    #finish() runs even if training raises, so a crashed repetition does not leave its WandB run open
    try:
        TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
            cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)
    finally:
        if useWandB:
            wandb.finish()

# **Classification Pre-Training**

In [None]:
#Specifies the loss functions and weights to use during training process
#Example: 
#   lossFuncs = [[segmentationLossFunc1, segmentationLossFunc2], [classificationLossFunc1, classificationLossFunc2]]
#   weights = [[0.25, 0.5], [0.1, 0.9]]
#Allows for easily changing the loss function and enables joint training on segmentation and classification
#Loss functions given weights of 0 are printed every epoch but are not included in the loss calculation
#Here only the classification focal loss carries weight; accuracy and F1 are reported only
focal = LossFunctions.FocalLoss(weight0=0.1, weight1=0.9, gamma=2)
lossFuncs = [[], [focal, LossFunctions.accuracy, LossFunctions.f1]]
weights = [[], [1, 0, 0]]

#Saves the encoder model to a separate file from modelFile
#NOTE(review): the training loops save under "../UsedModels/" while this path has no "../" - confirm the intended directory
encoderFile = "UsedModels/Encoder1"

#Creates UNet encoder with classification branch
encoder = Encoder(1).to(device)

#Loads encoder model if one already exists
#map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
initEncoder = ""
if initEncoder != "":
    encoder.load_state_dict(torch.load(initEncoder, map_location=device))

gc.collect()

#Starts WandB run if that is being used
if useWandB:
    wandb.init(project="PreTrainedEncoder",
               name="UNetEncoder",
               config={
                   "BatchSize": batchSize,
                   "LearnRate": learnRate,
                   "Epochs": epochs,
                   "StartDimension": startDim,
                   "EpochsToDouble": epochsToDouble,
               })

#finish() runs even if training raises, so a failed run is not left open in WandB
try:
    print(TrainingEval.train(encoder, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, encoderFile, epochsToSave, useWandB=useWandB,
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive, encoder=True))
finally:
    if useWandB:
        wandb.finish()

In [None]:
#Creates and loads encoder file from initEncoder, uses that when creating the full UNet model
encoder = Encoder(1).to(device)

#Path to the pre-trained encoder checkpoint; leaving it empty builds the UNet around a freshly constructed (untrained) encoder
#map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
initEncoder = ""
if initEncoder != "":
    encoder.load_state_dict(torch.load(initEncoder, map_location=device))

#The whole UNet (decoder included, not just the encoder) is moved to the training device, as the other cells do for their models
segmenter = UNet(encoder=encoder).to(device)

#After loading encoder, model is trained in the same way as standard models
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

gc.collect()

if useWandB:
    wandb.init(project="LiverSegmentationPreTraining",
               name="NoWeights",
               config={
                   "BatchSize": batchSize,
                   "LearnRate": learnRate,
                   "Epochs": epochs,
                   "StartDimension": startDim,
                   "EpochsToDouble": epochsToDouble
               })

#finish() runs even if training raises, so a failed run is not left open in WandB
try:
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)
finally:
    if useWandB:
        wandb.finish()

# **Joint Training**

In [None]:
#Trains five freshly constructed ResUNets with both segmentation and classification losses; each repetition
#gets its own checkpoint path and WandB run name
for i in range(5):
    modelFile = "../UsedModels/JointTrainedResUNet" + str(i)

    #Specifies the loss functions and their weights to use during training process (both segmentation and classification)
    #Example: 
    #   lossFuncs = [[segmentationLossFunc1, segmentationLossFunc2], [classificationLossFunc1, classificationLossFunc2]]
    #   weights = [[0.25, 0.5], [0.1, 0.9]]
    #Allows for easily changing the loss function and enables joint training on segmentation and classification
    #Loss functions given weights of 0 are printed every epoch but are not included in the loss calculation
    classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)
    lossFuncs = [[LossFunctions.dice_score, LossFunctions.dice_loss], [LossFunctions.accuracy, classLossFunc]]
    weights = [[0, 0.6], [0, 0.4]]

    segmenter = ResUNet(num_classes=1).to(device)

    #Loads model from file if using a pretrained version
    #map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    initModel = ""
    if initModel != "":
        segmenter.load_state_dict(torch.load(initModel, map_location=device))

    #Uses WandB to log run data if specified by the config file
    if useWandB:
        wandb.init(project="EMBCJointTraining",
                   name="JointResUNet" + str(i),
                   config={
                       "BatchSize": batchSize,
                       "LearnRate": learnRate,
                       "Epochs": epochs,
                       "StartDimension": startDim,
                       "EpochsToDouble": epochsToDouble
                   })

    #finish() runs even if training raises, so a crashed repetition does not leave its WandB run open
    try:
        TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
            cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)
    finally:
        if useWandB:
            wandb.finish()

# **Contrastive Pre-Training**

In [None]:
#Loads the contrastive pre-training dataset; the loader reshuffles it on every pass
#NOTE(review): the PolyCL dataset class is built from a file named SimCLRTrainingDataset.hdf5 and the imported
#LITSContDatasetSimCLR is never used - confirm this is the intended dataset for both contrastive cells
contTrainDataset = LITSContDatasetPolyCL("Datasets/ContrastiveDatasets/SimCLRTrainingDataset.hdf5")
contTrainIter = DataLoader(dataset=contTrainDataset, shuffle=True, batch_size=batchSize)

In [None]:
#Creates an encoder with a projection head for contrastive learning, placed on the training device
#(as the contrastive fine-tuning cell does)
encoder = ContrastiveEncoder().to(device)

#Saves the contrastive encoder to its own file instead of modelFile. modelFile's value depends on whichever
#cell ran last (e.g. a training loop leaves it pointing at a segmenter checkpoint), which this run would overwrite
contEncoderFile = "UsedModels/SimCLREncoder"

if useWandB:
    wandb.init(project="LITSEncoderContrastive",
               name="SimCLR",
               config={
                   "BatchSize": batchSize,
                   "LearnRate": learnRate,
                   "Epochs": epochs,
                   "StartDimension": startDim,
                   "EpochsToDouble": epochsToDouble
               })

#Temperature is set to the reciprocal of the batch size (requires a non-zero batchSize from the config)
lossFunc = LossFunctions.ContrastiveLossSimCLR(temp=(1 / batchSize), device=device)

#Uses specific SimCLR training function because of the differences between it and the PolyCL pre-training process
#finish() runs even if training raises, so a failed run is not left open in WandB
try:
    TrainingEval.simCLRTrain(encoder, lossFunc, contTrainIter, epochs, startEpoch, learnRate, device, contEncoderFile, epochsToSave, useWandB=useWandB, cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs)
finally:
    if useWandB:
        wandb.finish()

In [None]:
#Creates an encoder with a projection head for contrastive learning, placed on the training device
#(as the contrastive fine-tuning cell does)
encoder = ContrastiveEncoder().to(device)

#Saves the contrastive encoder to its own file instead of modelFile. modelFile's value depends on whichever
#cell ran last (e.g. a training loop leaves it pointing at a segmenter checkpoint), which this run would overwrite
contEncoderFile = "UsedModels/PolyCLEncoder"

if useWandB:
    wandb.init(project="LITSEncoderContrastive",
               name="PolyCL",
               config={
                   "BatchSize": batchSize,
                   "LearnRate": learnRate,
                   "Epochs": epochs,
                   "StartDimension": startDim,
                   "EpochsToDouble": epochsToDouble
               })

#Temperature is set to the reciprocal of the batch size (requires a non-zero batchSize from the config)
lossFunc = LossFunctions.ContrastiveLossCosine(temp=(1 / batchSize))

#Uses PolyCL contrastive training function
#finish() runs even if training raises, so a failed run is not left open in WandB
try:
    TrainingEval.contrastiveTrain(encoder, lossFunc, contTrainIter, epochs, startEpoch, learnRate, device, contEncoderFile, epochsToSave, useWandB=useWandB, cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs)
finally:
    if useWandB:
        wandb.finish()

In [None]:
#Creates and loads encoder file from initEncoder, uses that when creating the full UNet model
encoder = ContrastiveEncoder().to(device)

#Path to the contrastively pre-trained encoder checkpoint; leaving it empty builds the UNet around a freshly
#constructed (untrained) encoder. map_location lets a GPU-saved checkpoint load on a CPU-only machine
initEncoder = ""
if initEncoder != "":
    encoder.load_state_dict(torch.load(initEncoder, map_location=device))

#The whole UNet (decoder included, not just the encoder) is moved to the training device, as the other cells do for their models
segmenter = UNet(encoder=encoder).to(device)

#After loading encoder, model is trained in the same way as standard models
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

gc.collect()

if useWandB:
    wandb.init(project="LiverSegmentationPreTraining",
               name="NoWeights",
               config={
                   "BatchSize": batchSize,
                   "LearnRate": learnRate,
                   "Epochs": epochs,
                   "StartDimension": startDim,
                   "EpochsToDouble": epochsToDouble
               })

#finish() runs even if training raises, so a failed run is not left open in WandB
try:
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)
finally:
    if useWandB:
        wandb.finish()

# **Evaluation/Ending**

In [None]:
#Evaluates every checkpoint file in each directory listed here on the test set and prints each metric
dirs = ["FullLiTSTesting/"]

#Can evaluate on multiple loss functions, listed here
#Specified in the same way as loss functions for training
#Segmentation loss functions are listed first, classification loss functions are listed after
lossFuncs = [[LossFunctions.dice_score, LossFunctions.hausdorff], []]

#If the model is an encoder and we are only evaluating on classification, an encoder is loaded, otherwise the full UNet is loaded
classification = False

#Loops through each directory and each model; loop names are chosen so they do not shadow the builtin dir()
#or clobber the notebook-level modelName/modelFile used by the training cells
for modelDir in dirs:
    #Sorted so results print in a stable order (os.listdir order is arbitrary); sub-directories are skipped
    for checkpointName in sorted(os.listdir(modelDir)):
        checkpointPath = os.path.join(modelDir, checkpointName)
        if not os.path.isfile(checkpointPath):
            continue

        print(f"Model: {checkpointName}")

        if classification:
            net = Encoder().to(device)
        else:
            #NOTE(review): this UNet call differs from UNet(encoder=...) in the pre-training cells, and the
            #training loops produce ResUNet checkpoints - confirm this matches the models being evaluated
            net = UNet(device, n_class=1, multiTask=False).to(device)

        #map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine
        stateDict = torch.load(checkpointPath, map_location=device)
        #strict=False tolerates mismatched keys; report them so a partially loaded (partly untrained) network is noticed
        incompatible = net.load_state_dict(stateDict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(f"Warning: {len(incompatible.missing_keys)} missing and "
                  f"{len(incompatible.unexpected_keys)} unexpected state_dict keys")
        net.eval()

        #encoder= is presumably the encoder-only flag, as in the train(..., encoder=True) call - pass the boolean,
        #not the `encoder` model object that an earlier cell may have left in the namespace
        losses = TrainingEval.evaluate(net, testIter, lossFuncs, device=device, encoder=classification)

        #Prints each function name with its evaluated value; plain functions report their own name, callable
        #objects without __name__ (e.g. a FocalLoss instance) report their class name
        logParts = []
        for funcGroup, valueGroup in zip(lossFuncs, losses):
            for lossFunc, value in zip(funcGroup, valueGroup):
                funcName = getattr(lossFunc, "__name__", type(lossFunc).__name__)
                logParts.append(f"{funcName}: {value!s}")

        print(" ".join(logParts))