# **Setup**

In [None]:
import wandb
# Log in to Weights & Biases for this session. This call is unconditional: it
# runs even when useWandB (loaded from the config file in a later cell) is 0.
wandb.login()

In [None]:
import torch
from torch.utils.data import DataLoader
#from torchsummary import summary
import gc
import h5py
from UNet import UNet, Encoder, ContrastiveEncoder, ResNetBlock, ResidualBlock, double_conv
from LITSDataset import LITSBinaryDataset, LITSContDataset, LITSMultiClassDataset
import LossFunctions
import TrainingEval
from tqdm import tqdm
import os

In [None]:
# Hyperparameters and training modifications.
# Chooses the compute device, names the model/config files, then loads the
# numeric training settings from configFile into same-named notebook globals.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#device = torch.device("cpu")

modelName = "TotalSegmentatorTest3"
modelFile = "UsedModels/" + modelName
configFile = "testConfig.txt"
#configFile = "evalConfig.txt"
#configFile = "ContrastiveModels/contFineTuneConfig.txt"

# Placeholder defaults. Each name is re-assigned from varDict once the config
# file has been parsed.
startEpoch = 0
useWandB = 0
batchSize = 0
learnRate = 0
epochs = 0
startDim = 0
epochsToDouble = 0
progressive = 0
epochsToSave = 0
cosineAnnealing = 0
cosineRestartEpochs = 0

varDict = {
    "startEpoch": startEpoch,
    "useWandB": useWandB,
    "batchSize": batchSize,
    "learnRate": learnRate,
    "epochs": epochs,
    "startDim": startDim,
    "epochsToDouble": epochsToDouble,
    "progressive": progressive,
    "epochsToSave": epochsToSave,
    "cosineAnnealing": cosineAnnealing,
    "cosineRestartEpochs": cosineRestartEpochs,
}

# The return value is ignored, so ParseConfig presumably fills varDict in
# place -- TODO confirm against TrainingEval.ParseConfig.
TrainingEval.ParseConfig(configFile, varDict)

# Copy the parsed values back into the notebook namespace. globals() is the
# documented writable namespace (writing through locals() is discouraged).
# Only floats are asked .is_integer(): a plain int (e.g. a key the config left
# at its default) has no such method before Python 3.12 and is kept as-is.
for key, value in varDict.items():
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    globals()[key] = value

In [None]:
# Open the three StandardDatasets HDF5 splits as binary datasets and wrap each
# one in a DataLoader. Only the training loader shuffles.
trainDataset = LITSBinaryDataset("Datasets/StandardDatasets/FullTrainDataset.hdf5")
validationDataset = LITSBinaryDataset("Datasets/StandardDatasets/ValidationDataset.hdf5")
testDataset = LITSBinaryDataset("Datasets/StandardDatasets/TestDataset.hdf5")

trainIter = DataLoader(dataset=trainDataset, shuffle=True, batch_size=batchSize)
validationIter = DataLoader(dataset=validationDataset, batch_size=batchSize)
testIter = DataLoader(dataset=testDataset, batch_size=batchSize)

print("Datasets loaded")

In [None]:
# Alternative to the standard dataset cell: rebinds the same dataset/loader
# names to the MultiClass HDF5 splits. Only the training loader shuffles.
trainDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTrainingDataset.hdf5")
validationDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassValidationDataset.hdf5")
testDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTestingDataset.hdf5")

trainIter = DataLoader(dataset=trainDataset, shuffle=True, batch_size=batchSize)
validationIter = DataLoader(dataset=validationDataset, batch_size=batchSize)
testIter = DataLoader(dataset=testDataset, batch_size=batchSize)

In [None]:
# Alternative dataset cell: rebinds the dataset/loader names to the
# TotalSegmentator HDF5 files.
trainDataset = LITSBinaryDataset("Datasets/TotalSegmentator/TotalSegmentatorTrainingDataset.hdf5")
validationDataset = LITSBinaryDataset("Datasets/TotalSegmentator/TotalSegmentatorValidationDataset.hdf5")
# NOTE(review): the test dataset is opened from the *training* file, so any
# evaluation on testIter scores training data -- confirm whether a dedicated
# TotalSegmentator test split should be used here.
testDataset = LITSBinaryDataset("Datasets/TotalSegmentator/TotalSegmentatorTrainingDataset.hdf5")

trainIter = DataLoader(dataset=trainDataset, shuffle=True, batch_size=batchSize)
validationIter = DataLoader(dataset=validationDataset, batch_size=batchSize)
# Unlike the other dataset cells, the test loader here uses a batch size of 1.
testIter = DataLoader(dataset=testDataset, batch_size=1)

# **Standard Training**

In [None]:
# Build a single-class UNet and the loss/metric configuration used to train it.
# A non-empty initModel path loads that state dict into the model first.
initModel = ""

# lossFuncs and weights have the same nested shape and are paired by position.
# Presumably the first group is for segmentation and the second (empty here)
# for classification -- confirm against TrainingEval.train.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

segmenter = UNet(device=device, n_class=1)

if initModel:
    segmenter.load_state_dict(torch.load(initModel))

segmenter = segmenter.to(device)

print("Intialized standard UNet model")

In [None]:
# Train the segmenter built in the previous cell, optionally logging to W&B.

# NOTE: this hard-coded value replaces whatever learnRate the config file set.
learnRate = 0.5

if useWandB:
    wandb.init(
        project="LiverSegmentation",
        name=modelName,
        config=dict(
            BatchSize=batchSize,
            LearnRate=learnRate,
            Epochs=epochs,
            StartDimension=startDim,
            EpochsToDouble=epochsToDouble,
        ),
    )

TrainingEval.train(
    segmenter, lossFuncs, weights,
    trainIter, validationIter,
    epochs, startEpoch, learnRate, device,
    startDim, epochsToDouble,
    modelFile, epochsToSave,
    useWandB=useWandB,
    cosineAnnealing=cosineAnnealing,
    restartEpochs=cosineRestartEpochs,
    progressive=progressive,
)

if useWandB:
    wandb.finish()

In [None]:
# Train ten two-class UNet baselines in a row. Each run gets its own
# modelName/modelFile and, when enabled, its own W&B run.
for modelNum in range(10):
    modelName = f"MultiClassBasline{modelNum}"
    modelFile = f"UsedModels/MultiClassBasline{modelNum}"

    # A non-empty path here would load that state dict before training.
    initModel = ""

    lossFuncs = [[LossFunctions.weighted_dice_loss, LossFunctions.dice_score], []]
    weights = [[1, 0], []]

    segmenter = UNet(device=device, n_class=2)

    if initModel:
        segmenter.load_state_dict(torch.load(initModel))

    segmenter = segmenter.to(device)

    print("Intialized standard UNet model")

    if useWandB:
        wandb.init(
            project="MultiClassLiTS",
            name=modelName,
            config=dict(
                BatchSize=batchSize,
                LearnRate=learnRate,
                Epochs=epochs,
                StartDimension=startDim,
                EpochsToDouble=epochsToDouble,
            ),
        )

    TrainingEval.train(
        segmenter, lossFuncs, weights,
        trainIter, validationIter,
        epochs, startEpoch, learnRate, device,
        startDim, epochsToDouble,
        modelFile, epochsToSave,
        useWandB=useWandB,
        cosineAnnealing=cosineAnnealing,
        restartEpochs=cosineRestartEpochs,
        progressive=progressive,
    )

    if useWandB:
        wandb.finish()

In [None]:
# Train ten single-class UNet baselines ("BaselineTotalSeg0" .. "BaselineTotalSeg9")
# on whichever trainIter/validationIter are currently loaded.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

# A non-empty path here would load that state dict into every model first.
initModel = ""

for modelNum in range(10):
    # Derive the run name per model. Previously modelName was left over from an
    # earlier cell, so all ten W&B runs were logged under the same stale name.
    modelName = "BaselineTotalSeg" + str(modelNum)
    modelFile = "UsedModels/" + modelName
    segmenter = UNet(device=device, n_class=1)
    #print(summary(net, (1, 256, 256)))

    # Loads model from file if using a pretrained version
    if initModel != "":
        segmenter.load_state_dict(torch.load(initModel))

    segmenter = segmenter.to(device)

    print("Intialized standard UNet model")

    # Train model
    if useWandB:
        wandb.init(project="LiverSegmentation",
                name=modelName,
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB, 
        cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)

    if useWandB:
        wandb.finish()

# **Pre-Training**

In [None]:
# Set up the stand-alone encoder for pre-training: focal loss is the only
# weighted term; accuracy and f1 carry weight 0.
focal = LossFunctions.FocalLoss(weight0=0.1, weight1=0.9, gamma=2)

# Here the first group is empty and the second holds the encoder's terms.
lossFuncs = [[], [focal, LossFunctions.accuracy, LossFunctions.f1]]
weights = [[], [1, 0, 0]]

initEncoder = "Run 2/Progressive Encoders/ProgEncoder10"
encoderFile = "UsedModels/Encoder1"

encoder = Encoder(1)

# A non-empty initEncoder path loads that state dict into the encoder.
if initEncoder:
    encoder.load_state_dict(torch.load(initEncoder))

encoder = encoder.to(device)

In [None]:
# Train the encoder on its own (encoder=True) and print whatever
# TrainingEval.train returns.
gc.collect()

if useWandB:
    wandb.init(
        project="PreTrainedEncoder",
        name="UNetEncoder",
        config=dict(
            BatchSize=batchSize,
            LearnRate=learnRate,
            Epochs=epochs,
            StartDimension=startDim,
            EpochsToDouble=epochsToDouble,
        ),
    )

print(
    TrainingEval.train(
        encoder, lossFuncs, weights,
        trainIter, validationIter,
        epochs, startEpoch, learnRate, device,
        startDim, epochsToDouble,
        encoderFile, epochsToSave,
        useWandB=useWandB,
        cosineAnnealing=cosineAnnealing,
        restartEpochs=cosineRestartEpochs,
        progressive=progressive,
        encoder=True,
    )
)

if useWandB:
    wandb.finish()

In [None]:
# Wrap the encoder from the pre-training cells in a UNet and switch back to the
# dice-based segmentation loss configuration.
# NOTE(review): every other UNet(...) call in this file passes a device as the
# first argument; this one does not -- confirm UNet has a usable default.
segmenter = UNet(encoder=encoder)

# Only dice_loss is weighted; dice_score carries weight 0.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

In [None]:
# Train the encoder-initialised segmenter. progressive is forced to 0 here,
# regardless of the value loaded from the config file.
gc.collect()

if useWandB:
    wandb.init(
        project="LiverSegmentationPreTraining",
        name="NoWeights",
        config=dict(
            BatchSize=batchSize,
            LearnRate=learnRate,
            Epochs=epochs,
            StartDimension=startDim,
            EpochsToDouble=epochsToDouble,
        ),
    )

TrainingEval.train(
    segmenter, lossFuncs, weights,
    trainIter, validationIter,
    epochs, startEpoch, learnRate, device,
    startDim, epochsToDouble,
    modelFile, epochsToSave,
    useWandB=useWandB,
    cosineAnnealing=cosineAnnealing,
    restartEpochs=cosineRestartEpochs,
    progressive=0,
)

if useWandB:
    wandb.finish()

# **Joint Training**

In [None]:
# Joint training: for each of ten saved contrastive encoders, build a UNet
# around it and train with terms in both loss groups.
for i in tqdm(range(10)):
    # NOTE(review): unlike the other training cells, modelFile has no
    # "UsedModels/" prefix here -- confirm this is intended.
    modelFile = f"ContJoint{i}"
    initEncoder = f"ContrastiveModels/Encoders/ContrastiveEncoder{i}"

    encoder = Encoder()
    encoder.load_state_dict(torch.load(initEncoder), strict=False)

    segmenter = UNet(device, encoder=encoder)

    if useWandB:
        # Every run in this loop is logged under the same name, "Weights:".
        wandb.init(
            project="LiverSegmentationJointTraining",
            name="Weights:",
            config=dict(
                BatchSize=batchSize,
                LearnRate=learnRate,
                Epochs=epochs,
                StartDimension=startDim,
                EpochsToDouble=epochsToDouble,
            ),
        )

    classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

    # Weighted terms: dice_loss at 0.6 and the focal loss at 0.4. dice_score
    # and accuracy carry weight 0. The second group presumably belongs to a
    # classification output -- confirm against TrainingEval.train.
    lossFuncs = [
        [LossFunctions.dice_score, LossFunctions.dice_loss],
        [LossFunctions.accuracy, classLossFunc],
    ]
    weights = [[0, 0.6], [0, 0.4]]

    TrainingEval.train(
        segmenter, lossFuncs, weights,
        trainIter, validationIter,
        epochs, startEpoch, learnRate, device,
        startDim, epochsToDouble,
        modelFile, epochsToSave,
        useWandB=useWandB,
        cosineAnnealing=cosineAnnealing,
        restartEpochs=cosineRestartEpochs,
        progressive=progressive,
    )

    if useWandB:
        wandb.finish()

# **Contrastive Pre-Training**

In [None]:
# Open the contrastive training data (currently the SimCLR dataset) and wrap it
# in a shuffling DataLoader. The commented-out lines are the scan-based
# train/validation/test alternatives.
#contTrainDataset = LITSContDataset("Datasets/ContrastiveDatasets/ScanBased/RandContrastiveTrainDataset.hdf5")
contTrainDataset = LITSContDataset("Datasets/ContrastiveDatasets/SimCLRTrainingDataset.hdf5")
#contValDataset = LITSContDataset("Datasets/ContrastiveDatasets/ScanBased/RandContrastiveValidationDataset.hdf5")
#contTestDataset = LITSContDataset("Datasets/ContrastiveDatasets/ScanBased/RandContrastiveTestDataset.hdf5")

contTrainIter = DataLoader(contTrainDataset, batch_size=batchSize, shuffle=True)
# NOTE(review): contTestIter is disabled below, but a later evaluation cell
# passes contTestIter to TrainingEval.contrastiveEval -- re-enable it (and
# contTestDataset) before running that cell.
#contValidationIter = DataLoader(contValDataset, batch_size=batchSize)
#contTestIter = DataLoader(contTestDataset, batch_size=batchSize)

In [None]:
# Train ten SimCLR encoders ("SimCLREncoder0" .. "SimCLREncoder9") on contTrainIter.
for i in tqdm(range(10)):
    modelName = f"SimCLREncoder{i}"
    modelFile = f"UsedModels/{modelName}"
    encoder = ContrastiveEncoder(block=ResidualBlock)

    if useWandB:
        # Every run in this loop is logged under the same name, "Weights:".
        wandb.init(
            project="LITSEncoderContrastive",
            name="Weights:",
            config=dict(
                BatchSize=batchSize,
                LearnRate=learnRate,
                Epochs=epochs,
                StartDimension=startDim,
                EpochsToDouble=epochsToDouble,
            ),
        )

    # The temperature is tied to the batch size (1 / batchSize).
    # Alternative losses, currently disabled:
    #lossFunc = LossFunctions.ContrastiveLossEuclidean
    #lossFunc = LossFunctions.ContrastiveLossCosine(temp=(1 / batchSize))
    lossFunc = LossFunctions.ContrastiveLossSimCLR(temp=1 / batchSize, device=device)

    # Alternative trainer that also takes a validation loader, currently disabled:
    #TrainingEval.contrastiveTrain(encoder, lossFunc, contTrainIter, contValidationIter, epochs, startEpoch, learnRate, device, modelFile, epochsToSave, useWandB=useWandB,
    #    cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, isDist=False)

    TrainingEval.simCLRTrain(
        encoder, lossFunc, contTrainIter,
        epochs, startEpoch, learnRate, device,
        modelFile, epochsToSave,
        useWandB=useWandB,
        cosineAnnealing=cosineAnnealing,
        restartEpochs=cosineRestartEpochs,
    )

    if useWandB:
        wandb.finish()

    print(f"Model: {i} finished")

In [None]:
# Build a UNet around one saved contrastive encoder and select the dice-based
# segmentation loss configuration.
initEncoder = "ContrastiveTest/ContrastiveTest1Encoder"

encoder = ContrastiveEncoder()
# strict=False: load_state_dict does not raise on missing or unexpected keys.
encoder.load_state_dict(torch.load(initEncoder), strict=False)

segmenter = UNet(device, encoder=encoder)
# Encoder freezing is disabled for this run (the fine-tune loops enable it).
#segmenter.freezeEncoder()

# Only dice_loss is weighted; dice_score carries weight 0.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

In [None]:
# Fine-tune ten single-class UNets, each built on one saved SimCLR encoder
# with the encoder frozen.
for i in tqdm(range(10)):
    initEncoder = f"SimCLR/Encoders/SimCLREncoder{i}"
    modelFile = f"UsedModels/SimCLRFineTune{i}"

    print(modelFile)

    encoder = ContrastiveEncoder(block=ResidualBlock)
    encoder.load_state_dict(torch.load(initEncoder), strict=False)

    segmenter = UNet(device, n_class=1, encoder=encoder, block=ResidualBlock)
    segmenter.freezeEncoder()

    lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
    weights = [[1, 0], []]

    # Collect garbage from the previous run before training the next model.
    gc.collect()

    if useWandB:
        # Every run in this loop is logged under the same name, "NoWeights".
        wandb.init(
            project="LiverSegmentationPreTraining",
            name="NoWeights",
            config=dict(
                BatchSize=batchSize,
                LearnRate=learnRate,
                Epochs=epochs,
                StartDimension=startDim,
                EpochsToDouble=epochsToDouble,
            ),
        )

    # progressive is forced to 0 here, regardless of the config value.
    TrainingEval.train(
        segmenter, lossFuncs, weights,
        trainIter, validationIter,
        epochs, startEpoch, learnRate, device,
        startDim, epochsToDouble,
        modelFile, epochsToSave,
        useWandB=useWandB,
        cosineAnnealing=cosineAnnealing,
        restartEpochs=cosineRestartEpochs,
        progressive=0,
    )

    if useWandB:
        wandb.finish()

In [None]:
# For every (model name, encoder location) pair, fine-tune ten two-class UNets
# on frozen SimCLR encoders. modelNames and initLocs are parallel lists.
modelNames = ["SimCLRMulticlass"]
initLocs = ["SimCLR/Encoders/SimCLREncoder"]

for j, modelName in enumerate(modelNames):
    for i in tqdm(range(10)):
        initEncoder = f"{initLocs[j]}{i}"
        modelFile = f"UsedModels/{modelName}{i}"

        print(modelFile)

        # The same block type is used for the encoder and the UNet.
        block = ResidualBlock

        encoder = ContrastiveEncoder(block=block)
        encoder.load_state_dict(torch.load(initEncoder), strict=False)

        segmenter = UNet(device, n_class=2, encoder=encoder, block=block)
        segmenter.freezeEncoder()

        lossFuncs = [[LossFunctions.weighted_dice_loss, LossFunctions.dice_score], []]
        weights = [[1, 0], []]

        # Collect garbage from the previous run before training the next model.
        gc.collect()

        if useWandB:
            wandb.init(
                project=modelName,
                name=f"{modelName}{i}",
                config=dict(
                    BatchSize=batchSize,
                    LearnRate=learnRate,
                    Epochs=epochs,
                    StartDimension=startDim,
                    EpochsToDouble=epochsToDouble,
                ),
            )

        # progressive is forced to 0 here, regardless of the config value.
        TrainingEval.train(
            segmenter, lossFuncs, weights,
            trainIter, validationIter,
            epochs, startEpoch, learnRate, device,
            startDim, epochsToDouble,
            modelFile, epochsToSave,
            useWandB=useWandB,
            cosineAnnealing=cosineAnnealing,
            restartEpochs=cosineRestartEpochs,
            progressive=0,
        )

        if useWandB:
            wandb.finish()

In [None]:
# Reduced-data study: for each training-set size (1..10 scans), fine-tune ten
# single-class UNets, each on a different frozen SimCLR encoder.
for numScans in range(10):
    # The reduced dataset depends only on numScans, so it is opened once per
    # outer iteration instead of once per model (previously 100 HDF5 opens,
    # none of them closed). The dataset from the previous outer iteration is
    # closed first. The last one stays open and bound to trainDataset, as it
    # was before this change.
    if numScans > 0:
        trainDataset.closeFile()

    trainDataset = LITSBinaryDataset("Datasets/ReducedData/ReducedDataFineTuneDataset" + str(numScans + 1) + "Scans.hdf5")
    trainIter = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)

    for i in tqdm(range(10)):
        modelName = "SimCLRReducedFineTune" + str(numScans + 1) + "Scans" + str(i)
        initEncoder = "SimCLR/Encoders/SimCLREncoder" + str(i)
        modelFile = "UsedModels/" + modelName

        print(modelFile)

        # The same block type is used for the encoder and the UNet.
        block = ResidualBlock
        encoder = ContrastiveEncoder(block=block)

        encoder.load_state_dict(torch.load(initEncoder), strict=False)

        segmenter = UNet(device, n_class=1, encoder=encoder, block=block)

        segmenter.freezeEncoder()

        lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
        weights = [[1, 0], []]

        # Train model
        gc.collect()

        if useWandB:
            wandb.init(project="LiverSegmentationPreTraining",
                    name="NoWeights",
                    config={
                        "BatchSize":batchSize,
                        "LearnRate":learnRate,
                        "Epochs":epochs,
                        "StartDimension":startDim,
                        "EpochsToDouble":epochsToDouble
                    })

        # progressive is forced to 0 here, regardless of the config value.
        TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB, 
            cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)

        if useWandB:
            wandb.finish()

# **Evaluation/Ending**

In [None]:
# Evaluate every checkpoint found in the listed directories on testIter and
# print one "<metric name>: <value> " line per model.
dirs = ["MultiClassBaselines/"]

# These settings are identical for every checkpoint, so they are built once
# rather than once per model.
classification = False
# Not used by this cell's evaluation; kept because it is a notebook global.
classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

lossFuncs = [[LossFunctions.hausdorff], []]
#lossFuncs = [[LossFunctions.dice_score, LossFunctions.hausdorff], [LossFunctions.accuracy, LossFunctions.f1]]

# modelDir (not "dir") so the builtin dir() is not shadowed for the rest of
# the notebook session.
for modelDir in dirs:
    # os.listdir returns entries in arbitrary order; sort for a stable report.
    for modelName in sorted(os.listdir(modelDir)):
        #modelName = "BaselineTotalSeg" + str(q)
        modelFile = os.path.join(modelDir, modelName)

        if classification:
            # Reuses (and overwrites the weights of) the encoder from an earlier cell.
            net = encoder
        else:
            net = UNet(device, n_class=2, multiTask=False).to(device)

        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)

        # Evaluate model
        print(f"Model: {modelName}")

        losses = TrainingEval.evaluate(net, testIter, lossFuncs, device=device, encoder=classification)

        # Label each value with its metric: plain functions expose __name__,
        # while callable objects (e.g. a loss class instance) fall back to
        # their class name.
        logParts = []
        for funcGroup, lossGroup in zip(lossFuncs, losses):
            for func, val in zip(funcGroup, lossGroup):
                funcName = getattr(func, "__name__", type(func).__name__)
                logParts.append(f"{funcName}: {val!s} ")
        logStr = "".join(logParts)

        print(logStr)

In [None]:
# Evaluate one saved checkpoint on testIter and print its metrics on one line.
#modelName = "BaselineTotalSeg" + str(q)
classification = False
modelFile = "UsedModels/TotalSegmentatorTest3BestLoss"
classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

lossFuncs = [[LossFunctions.dice_score, LossFunctions.hausdorff], []]
#lossFuncs = [[LossFunctions.dice_score, LossFunctions.hausdorff], [LossFunctions.accuracy, LossFunctions.f1]]

# NOTE(review): the UNet is built with n_class=2, while the standard training
# cell in this file builds its UNet with n_class=1 -- confirm the checkpoint
# loaded here matches.
net = encoder if classification else UNet(device, n_class=2, multiTask=False).to(device)
net.load_state_dict(torch.load(modelFile), strict=False)

# Prints whatever modelName currently holds, which is not derived from modelFile.
print(f"Model: {modelName}")

losses = TrainingEval.evaluate(net, testIter, lossFuncs, device=device, encoder=classification)

# Label each value: plain functions by their __name__, anything else by the
# name of its class.
logStr = ""
for i, arr in enumerate(losses):
    for j, val in enumerate(arr):
        func = lossFuncs[i][j]
        if str(type(func)) == "<class 'function'>":
            label = func.__name__
        else:
            label = type(func).__name__
        logStr += label + ": " + str(val) + " "

print(logStr)

In [None]:
# Evaluate a saved contrastive encoder checkpoint and print its loss.
modelName = ""
modelFile = "UsedModels/ContrastiveTest1BestLoss"

net = ContrastiveEncoder().to(device)
# strict=False: load_state_dict does not raise on missing or unexpected keys.
net.load_state_dict(torch.load(modelFile), strict=False)

# NOTE(review): this passes the ContrastiveLossCosine attribute itself, whereas
# a commented-out line in the contrastive training cell constructs it with
# temp=... -- confirm which form contrastiveEval expects.
lossFunc = LossFunctions.ContrastiveLossCosine

# NOTE(review): contTestIter is only defined in a commented-out line of the
# contrastive dataset cell, so this call raises NameError unless that line is
# re-enabled first.
loss = TrainingEval.contrastiveEval(net, contTestIter, lossFunc, device=device, isDist=False)
print(str(lossFunc) + " " + str(loss))

In [None]:
# Run a saved UNet over one scan's dataset and write each predicted mask to
# an HDF5 file as "Slice<i>".
#modelFile = "Run 2/Standard Pre-Training/PretrainedUNet7"
modelFile = "Run 2/Standard Pre-Training/PretrainedUNet6"
# NOTE(review): this dataset is never closed in the visible file -- consider
# calling dataset.closeFile() once the masks are no longer needed.
dataset = LITSBinaryDataset("Datasets/Scan1Dataset.hdf5")
# scanIter (not "iter") so the builtin iter() is not shadowed for the rest of
# the notebook session.
scanIter = DataLoader(dataset, batch_size=batchSize)

# NOTE(review): 0 is passed where the other cells pass `device` -- confirm.
net = UNet(0)
net.load_state_dict(torch.load(modelFile), strict=False)

net.to(device)

segmentationMask = TrainingEval.getMasks(net, scanIter, device=device)

masksFile = "PretrainMasksScan1"
# The context manager closes the output file even if writing a slice fails.
with h5py.File(masksFile, "w") as wFile:
    # maskSlice (not "slice") so the builtin slice is not shadowed.
    for i, maskSlice in enumerate(segmentationMask):
        wFile.create_dataset("Slice" + str(i), data=maskSlice)

In [None]:
# Close the files behind the currently bound train/validation/test datasets.
for openDataset in (trainDataset, validationDataset, testDataset):
    openDataset.closeFile()