# **Setup**

In [None]:
# Weights & Biases client; the training cells below call wandb.init()/wandb.finish()
# whenever the `useWandB` config flag is truthy.
import wandb
# Log this kernel session in to W&B.
# NOTE(review): this runs unconditionally, before `useWandB` is even parsed from the
# config file — confirm it should also run when W&B logging is disabled.
wandb.login()

In [1]:
import torch
from torch.utils.data import DataLoader
#from torchsummary import summary
import gc
import h5py
from UNet import UNet, Encoder, ContrastiveEncoder, ResNetBlock, ResidualBlock
from LITSDataset import LITSBinaryDataset, LITSContDataset, LITSMultiClassDataset
import LossFunctions
import TrainingEval
from tqdm import tqdm
import os

In [2]:
# Run configuration: compute device, output locations and tunable hyperparameters.
# CUDA is used when available, otherwise the CPU
# (to force CPU, assign torch.device("cpu") instead).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

modelName = "ResUNetTest"
modelFile = "UsedModels/" + modelName
configFile = "SPIE_Models/standardConfigRun2.txt"

# Placeholder defaults. Each of these names is also a key of varDict, which is
# handed to TrainingEval.ParseConfig together with the config file path.
startEpoch = useWandB = batchSize = learnRate = 0
epochs = startDim = epochsToDouble = progressive = 0
epochsToSave = cosineAnnealing = cosineRestartEpochs = 0

varDict = dict(
    startEpoch=startEpoch,
    useWandB=useWandB,
    batchSize=batchSize,
    learnRate=learnRate,
    epochs=epochs,
    startDim=startDim,
    epochsToDouble=epochsToDouble,
    progressive=progressive,
    epochsToSave=epochsToSave,
    cosineAnnealing=cosineAnnealing,
    cosineRestartEpochs=cosineRestartEpochs,
)

# Presumably fills varDict in place with the values read from configFile;
# the return value is not used here — confirm against TrainingEval.ParseConfig.
TrainingEval.ParseConfig(configFile, varDict)

# Publish every parsed configuration value as a notebook-level variable so the
# cells below can refer to batchSize, learnRate, epochs, ... by name.
#
# * Whole-valued numbers (e.g. 16.0) are narrowed to int so they can be used as
#   batch sizes / epoch counts; all other values are stored unchanged.
# * Values that have no `is_integer` method (plain ints before Python 3.12,
#   strings, ...) are stored unchanged instead of raising AttributeError.
# * globals() is used instead of locals(): the two are the same mapping at
#   notebook top level, but writes through locals() are not guaranteed to take
#   effect anywhere else, so globals() states the intent explicitly.
for key, value in varDict.items():
    isWhole = getattr(value, "is_integer", None)
    if callable(isWhole) and isWhole():
        value = int(value)
    globals()[key] = value

In [None]:
#Load Datasets
# Open the train / validation / test splits of the binary segmentation data.
trainDataset, validationDataset, testDataset = [
    LITSBinaryDataset(hdf5Path)
    for hdf5Path in (
        "Datasets/StandardDatasets/FullTrainDataset.hdf5",
        "Datasets/StandardDatasets/ValidationDataset.hdf5",
        "Datasets/StandardDatasets/TestDataset.hdf5",
    )
]

# Only the training loader shuffles; validation and test keep dataset order.
trainIter = DataLoader(trainDataset, shuffle=True, batch_size=batchSize)
validationIter = DataLoader(validationDataset, batch_size=batchSize)
testIter = DataLoader(testDataset, batch_size=batchSize)

print("Datasets loaded")

# **Standard Training**

In [None]:
#Load model
# Path of a saved state dict to warm-start from; leave empty to start from
# freshly initialised weights.
initModel = ""

# lossFuncs and weights are parallel two-element lists of lists; here only the
# first group is populated: dice_loss with weight 1 and dice_score with weight 0.
# (Presumably the first group is for segmentation outputs, the second for
# classification outputs, and weight 0 means "report only" — confirm in
# TrainingEval.train.)
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

segmenter = UNet(n_class=1, block=ResNetBlock)

#Loads model from file if using a pretrained version
if initModel != "":
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU can still be loaded when running on CPU.
    segmenter.load_state_dict(torch.load(initModel, map_location=device))

segmenter = segmenter.to(device)

print("Intialized standard UNet model")

In [None]:
#Train Model
# Open a W&B run first when logging is enabled, recording the key hyperparameters.
if useWandB:
    wandb.init(
        project="LiverSegmentation",
        name=modelName,
        config=dict(
            BatchSize=batchSize,
            LearnRate=learnRate,
            Epochs=epochs,
            StartDimension=startDim,
            EpochsToDouble=epochsToDouble,
        ),
    )

# Train the segmenter with the configured schedule; the return value is discarded.
TrainingEval.train(
    segmenter,
    lossFuncs,
    weights,
    trainIter,
    validationIter,
    epochs,
    startEpoch,
    learnRate,
    device,
    startDim,
    epochsToDouble,
    modelFile,
    epochsToSave,
    useWandB=useWandB,
    cosineAnnealing=cosineAnnealing,
    restartEpochs=cosineRestartEpochs,
    progressive=progressive,
)

# Close the W&B run opened above.
if useWandB:
    wandb.finish()

In [3]:
# Multi-class experiment: train ten independent UNets (n_class=2) on the
# multi-class datasets.
#
# NOTE: this cell rebinds trainDataset / validationDataset / testDataset and
# their loaders, so every later cell that uses trainIter / validationIter /
# testIter works on the multi-class data. Datasets previously bound to these
# names are no longer reachable by name for closing.
trainDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTrainingDataset.hdf5")
validationDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassValidationDataset.hdf5")
testDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTestingDataset.hdf5")

trainIter = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)
validationIter = DataLoader(validationDataset, batch_size=batchSize)
testIter = DataLoader(testDataset, batch_size=batchSize)

#Load model
# Path of a saved state dict to warm-start every model from; empty = train from scratch.
initModel = ""

# Only the first loss group is used: dice_loss with weight 1, dice_score with weight 0.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

for modelNum in range(10):
    # NOTE(review): unlike the other cells, this path has no "UsedModels/"
    # prefix — confirm where these checkpoints are meant to be written.
    modelFile = "MultiClassUNet" + str(modelNum)
    segmenter = UNet(device=device, n_class=2)

    #Loads model from file if using a pretrained version
    if initModel != "":
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        segmenter.load_state_dict(torch.load(initModel, map_location=device))

    segmenter = segmenter.to(device)

    print("Intialized standard UNet model")

    #Train Model

    if useWandB:
        # Name the run after this iteration's model. The notebook-level
        # modelName is not updated in this loop, so using it would give all
        # ten runs the same name.
        wandb.init(project="LiverSegmentation",
                name=modelFile,
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
        cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)

    if useWandB:
        wandb.finish()

Intialized standard UNet model
Training on cuda
Epoch 0:
Train Loss: 0.8805961608886719 Validation Loss: 0.9427634477615356 dice_loss: tensor(0.9428, device='cuda:0') dice_score: tensor(0.0652, device='cuda:0') 
Epoch 1:
Train Loss: 0.8169339299201965 Validation Loss: 0.9166449308395386 dice_loss: tensor(0.9166, device='cuda:0') dice_score: tensor(0.0875, device='cuda:0') 
Epoch 2:
Train Loss: 0.7926157116889954 Validation Loss: 0.9088678956031799 dice_loss: tensor(0.9089, device='cuda:0') dice_score: tensor(0.0934, device='cuda:0') 
Epoch 3:
Train Loss: 0.7684995532035828 Validation Loss: 0.8824055790901184 dice_loss: tensor(0.8824, device='cuda:0') dice_score: tensor(0.5234, device='cuda:0') 
Epoch 4:
Train Loss: 0.7456583380699158 Validation Loss: 0.5872103571891785 dice_loss: tensor(0.5872, device='cuda:0') dice_score: tensor(0.5048, device='cuda:0') 
Epoch 5:
Train Loss: 0.5270226001739502 Validation Loss: 0.3593154549598694 dice_loss: tensor(0.3593, device='cuda:0') dice_score: t

# **Pre-Training**

In [None]:
# Encoder pre-training setup.
# Only the second loss group is populated: focal loss with weight 1, plus
# accuracy and f1 with weight 0 (presumably reported only — confirm in
# TrainingEval.train).
focal = LossFunctions.FocalLoss(weight0=0.1, weight1=0.9, gamma=2)

lossFuncs = [[], [focal, LossFunctions.accuracy, LossFunctions.f1]]
weights = [[], [1, 0, 0]]

# initEncoder: saved state dict to start from (empty string = fresh weights).
# encoderFile: path passed to TrainingEval.train in the following cell.
initEncoder = "Run 2/Progressive Encoders/ProgEncoder10"
encoderFile = "UsedModels/Encoder1"

encoder = Encoder(1)

if initEncoder != "":
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU can still be loaded when running on CPU.
    encoder.load_state_dict(torch.load(initEncoder, map_location=device))

encoder = encoder.to(device)

In [None]:
#Train Model
# Run a garbage-collection pass before starting the encoder training run.
gc.collect()

if useWandB:
    wandb.init(
        project="PreTrainedEncoder",
        name="UNetEncoder",
        config=dict(
            BatchSize=batchSize,
            LearnRate=learnRate,
            Epochs=epochs,
            StartDimension=startDim,
            EpochsToDouble=epochsToDouble,
        ),
    )

# encoder=True is passed because the model being trained is the bare encoder;
# whatever train() returns is printed.
print(
    TrainingEval.train(
        encoder,
        lossFuncs,
        weights,
        trainIter,
        validationIter,
        epochs,
        startEpoch,
        learnRate,
        device,
        startDim,
        epochsToDouble,
        encoderFile,
        epochsToSave,
        useWandB=useWandB,
        cosineAnnealing=cosineAnnealing,
        restartEpochs=cosineRestartEpochs,
        progressive=progressive,
        encoder=True,
    )
)

if useWandB:
    wandb.finish()

In [None]:
# Build a segmentation UNet around the `encoder` object currently bound in the kernel.
# NOTE(review): unlike the standard-training cell, the model is not moved to `device`
# here and no device argument is passed — confirm UNet / TrainingEval.train handle placement.
segmenter = UNet(encoder=encoder)

# Switch back to the segmentation loss layout: only the first group is populated,
# with dice_loss weighted 1 and dice_score weighted 0.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

In [None]:
#Train Model
# Run a garbage-collection pass before starting the segmentation training run.
gc.collect()

if useWandB:
    wandb.init(
        project="LiverSegmentationPreTraining",
        name="NoWeights",
        config=dict(
            BatchSize=batchSize,
            LearnRate=learnRate,
            Epochs=epochs,
            StartDimension=startDim,
            EpochsToDouble=epochsToDouble,
        ),
    )

# progressive is passed as a literal 0 here, regardless of the configured value.
TrainingEval.train(
    segmenter,
    lossFuncs,
    weights,
    trainIter,
    validationIter,
    epochs,
    startEpoch,
    learnRate,
    device,
    startDim,
    epochsToDouble,
    modelFile,
    epochsToSave,
    useWandB=useWandB,
    cosineAnnealing=cosineAnnealing,
    restartEpochs=cosineRestartEpochs,
    progressive=0,
)

if useWandB:
    wandb.finish()

# **Joint Training**

In [None]:
# Joint training: for each of ten saved contrastive encoders, build a UNet on
# top of it and train with a weighted mix of dice loss (0.6) and focal loss (0.4).
for i in tqdm(range(10)):
    modelFile = "ContJoint" + str(i)
    initEncoder = "ContrastiveModels/Encoders/ContrastiveEncoder" + str(i)

    encoder = Encoder()
    # strict=False tolerates missing/unexpected keys between checkpoint and model.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    encoder.load_state_dict(torch.load(initEncoder, map_location=device), strict=False)

    segmenter = UNet(device, encoder=encoder)

    if useWandB:
        # NOTE(review): every iteration logs under the same run name "Weights:".
        wandb.init(project="LiverSegmentationJointTraining",
                name="Weights:",
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

    # Both groups are populated; dice_score and accuracy carry weight 0.
    lossFuncs = [[LossFunctions.dice_score, LossFunctions.dice_loss], [LossFunctions.accuracy, classLossFunc]]
    weights = [[0, 0.6], [0, 0.4]]

    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
        cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)

    if useWandB:
        wandb.finish()

# **Contrastive Pre-Training**

In [None]:
# Open the train / validation / test splits of the contrastive dataset.
contTrainDataset, contValDataset, contTestDataset = [
    LITSContDataset(hdf5Path)
    for hdf5Path in (
        "Datasets/ContrastiveDatasets/ScanBased/RandContrastiveTrainDataset.hdf5",
        "Datasets/ContrastiveDatasets/ScanBased/RandContrastiveValidationDataset.hdf5",
        "Datasets/ContrastiveDatasets/ScanBased/RandContrastiveTestDataset.hdf5",
    )
]

# Only the training loader shuffles; validation and test keep dataset order.
contTrainIter = DataLoader(contTrainDataset, shuffle=True, batch_size=batchSize)
contValidationIter = DataLoader(contValDataset, batch_size=batchSize)
contTestIter = DataLoader(contTestDataset, batch_size=batchSize)

In [None]:
# Contrastive pre-training: train ten encoders, one checkpoint path per index.
for i in tqdm(range(10)):
    modelName = f"RandContrastiveEncoder{i}"
    modelFile = f"UsedModels/{modelName}"
    encoder = ContrastiveEncoder()

    if useWandB:
        wandb.init(
            project="LITSEncoderContrastive",
            name="Weights:",
            config=dict(
                BatchSize=batchSize,
                LearnRate=learnRate,
                Epochs=epochs,
                StartDimension=startDim,
                EpochsToDouble=epochsToDouble,
            ),
        )

    # Cosine contrastive loss with temperature 1 / batchSize.
    # (LossFunctions.ContrastiveLossEuclidean is the alternative previously tried here.)
    lossFunc = LossFunctions.ContrastiveLossCosine(temp=(1 / batchSize))

    TrainingEval.contrastiveTrain(
        encoder,
        lossFunc,
        contTrainIter,
        contValidationIter,
        epochs,
        startEpoch,
        learnRate,
        device,
        modelFile,
        epochsToSave,
        useWandB=useWandB,
        cosineAnnealing=cosineAnnealing,
        restartEpochs=cosineRestartEpochs,
        isDist=False,
    )

    if useWandB:
        wandb.finish()

    print(f"Model: {i} finished")

In [None]:
# Build a segmentation UNet around a single contrastively pre-trained encoder.
initEncoder = "ContrastiveTest/ContrastiveTest1Encoder"

encoder = ContrastiveEncoder()
# strict=False tolerates missing/unexpected keys between checkpoint and model.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
encoder.load_state_dict(torch.load(initEncoder, map_location=device), strict=False)

segmenter = UNet(device, encoder=encoder)
#segmenter.freezeEncoder()

# Only the first loss group is used: dice_loss with weight 1, dice_score with weight 0.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

In [5]:
# Fine-tune ten 2-class UNets, each initialised from one of the supervised
# contrastive encoders, using whatever trainIter / validationIter are currently bound.
for i in tqdm(range(10)):
    initEncoder = "ContrastiveModels/Supervised/Encoders/ContrastiveEncoder" + str(i)
    modelFile = "UsedModels/ContFineTuneMultiClass" + str(i)

    print(modelFile)

    encoder = ContrastiveEncoder()
    # strict=False tolerates missing/unexpected keys between checkpoint and model.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    encoder.load_state_dict(torch.load(initEncoder, map_location=device), strict=False)

    segmenter = UNet(device, n_class=2, encoder=encoder)
    #segmenter.freezeEncoder()

    # Only the first loss group is used: dice_loss with weight 1, dice_score with weight 0.
    lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
    weights = [[1, 0], []]

    #Train Model
    gc.collect()

    if useWandB:
        # NOTE(review): every iteration logs under the same run name "NoWeights".
        wandb.init(project="LiverSegmentationPreTraining",
                name="NoWeights",
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    # progressive is passed as a literal 0 here, regardless of the configured value.
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
        cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)

    if useWandB:
        wandb.finish()

  0%|          | 0/10 [00:00<?, ?it/s]

UsedModels/ContFineTuneMultiClass0
Training on cuda
Epoch 0:
Train Loss: 0.8802812099456787 Validation Loss: 0.9569616913795471 dice_loss: tensor(0.9570, device='cuda:0') dice_score: tensor(0.0480, device='cuda:0') 
Epoch 1:
Train Loss: 0.8154916763305664 Validation Loss: 0.9105811715126038 dice_loss: tensor(0.9106, device='cuda:0') dice_score: tensor(0.0926, device='cuda:0') 
Epoch 2:
Train Loss: 0.7911093235015869 Validation Loss: 0.9038286805152893 dice_loss: tensor(0.9038, device='cuda:0') dice_score: tensor(0.0978, device='cuda:0') 
Epoch 3:
Train Loss: 0.779010534286499 Validation Loss: 0.897773265838623 dice_loss: tensor(0.8978, device='cuda:0') dice_score: tensor(0.1029, device='cuda:0') 
Epoch 4:
Train Loss: 0.7678555250167847 Validation Loss: 0.8915776014328003 dice_loss: tensor(0.8916, device='cuda:0') dice_score: tensor(0.1093, device='cuda:0') 
Epoch 5:
Train Loss: 0.7560440897941589 Validation Loss: 0.8674129247665405 dice_loss: tensor(0.8674, device='cuda:0') dice_score:

In [None]:
# Fine-tune ten 2-class UNets, each initialised from one of the unsupervised
# ("Rand") contrastive encoders, using whatever trainIter / validationIter are currently bound.
for i in tqdm(range(10)):
    initEncoder = "ContrastiveModels/Unsupervised/Encoders/RandContrastiveEncoder" + str(i)
    modelFile = "UsedModels/RandContFineTuneMultiClass" + str(i)

    print(modelFile)

    encoder = ContrastiveEncoder()
    # strict=False tolerates missing/unexpected keys between checkpoint and model.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    encoder.load_state_dict(torch.load(initEncoder, map_location=device), strict=False)

    segmenter = UNet(device, n_class=2, encoder=encoder)
    #segmenter.freezeEncoder()

    # Only the first loss group is used: dice_loss with weight 1, dice_score with weight 0.
    lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
    weights = [[1, 0], []]

    #Train Model
    gc.collect()

    if useWandB:
        # NOTE(review): every iteration logs under the same run name "NoWeights".
        wandb.init(project="LiverSegmentationPreTraining",
                name="NoWeights",
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    # progressive is passed as a literal 0 here, regardless of the configured value.
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
        cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)

    if useWandB:
        wandb.finish()

# **Evaluation/Ending**

In [None]:
# Evaluate a 10 x 10 grid of saved checkpoints (k+1 training scans, repeat q)
# on testIter and print one line of metrics per checkpoint.

# Settings shared by every checkpoint; they do not depend on k or q, so they
# are built once instead of once per iteration.
classification = False
# Not referenced by the active lossFuncs below; kept so the name stays defined.
classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

lossFuncs = [[LossFunctions.dice_score, LossFunctions.hausdorff], []]
#lossFuncs = [[LossFunctions.dice_score, LossFunctions.hausdorff], [LossFunctions.accuracy, LossFunctions.f1]]

for k in range(10):
    print(str(k + 1) + " Scans")

    for q in range(10):
        # NOTE(review): the printed name ("ContReducedFineTune...") and the
        # loaded file ("RandContReducedFineTune...") differ — confirm which is intended.
        modelName = "ContReducedFineTune" + str(k + 1) + "Scans" + str(q)
        modelFile = "TempStorage/RandContReducedFineTune" + str(k + 1) + "Scans" + str(q) + "BestLoss"

        if classification:
            # Relies on an `encoder` object left in the kernel by an earlier cell.
            net = encoder
        else:
            net = UNet(device, multiTask=False).to(device)

        # strict=False tolerates missing/unexpected keys between checkpoint and model.
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)

        #Evaluate Model
        print(f"Model: {modelName}")

        losses = TrainingEval.evaluate(net, testIter, lossFuncs, device=device, encoder=classification)
        logStr = ""
        for i, arr in enumerate(losses):
            for j, val in enumerate(arr):
                metric = lossFuncs[i][j]
                # Plain functions report their own name; callable objects
                # (e.g. loss class instances) report their class name.
                metricName = getattr(metric, "__name__", type(metric).__name__)
                logStr += metricName + ": " + str(val) + " "

        print(logStr)

    print(" ")

In [None]:
# Evaluate one saved contrastive encoder on the contrastive test loader.
modelName = ""
modelFile = "UsedModels/ContrastiveTest1BestLoss"

net = ContrastiveEncoder().to(device)
# strict=False tolerates missing/unexpected keys between checkpoint and model.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)

# NOTE(review): the contrastive training cell passes an *instance*,
# ContrastiveLossCosine(temp=...), whereas this passes the class itself.
# Confirm that TrainingEval.contrastiveEval expects an uninstantiated class;
# otherwise construct it here with the temperature used for training.
lossFunc = LossFunctions.ContrastiveLossCosine

loss = TrainingEval.contrastiveEval(net, contTestIter, lossFunc, device=device, isDist=False)
print(str(lossFunc) + " " + str(loss))

In [None]:
# Run a saved UNet over the Scan1 dataset and store each predicted mask as its
# own HDF5 dataset ("Slice0", "Slice1", ...) in masksFile.
#modelFile = "Run 2/Standard Pre-Training/PretrainedUNet7"
modelFile = "Run 2/Standard Pre-Training/PretrainedUNet6"
# NOTE(review): `dataset` is never closed in this cell, while the final cell
# calls closeFile() on the train/validation/test datasets — confirm whether
# this one should be closed as well.
dataset = LITSBinaryDataset("Datasets/Scan1Dataset.hdf5")
# Named scanIter rather than `iter` so the builtin iter() is not shadowed for
# the rest of the kernel session.
scanIter = DataLoader(dataset, batch_size=batchSize)

# NOTE(review): other cells pass `device` as UNet's first argument; here it is 0.
net = UNet(0)
# strict=False tolerates missing/unexpected keys between checkpoint and model.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)

net.to(device)

segmentationMask = TrainingEval.getMasks(net, scanIter, device=device)

masksFile = "PretrainMasksScan1"
# The context manager closes the output file even if writing a slice fails.
with h5py.File(masksFile, "w") as wFile:
    # maskSlice rather than `slice`, to avoid shadowing the builtin slice type.
    for i, maskSlice in enumerate(segmentationMask):
        wFile.create_dataset("Slice" + str(i), data=maskSlice)

In [None]:
#Close datasets
# Close the files behind whichever datasets are currently bound to the three
# split names (later cells may have rebound them since they were first loaded).
for openSplit in (trainDataset, validationDataset, testDataset):
    openSplit.closeFile()