# **Setup**

In [None]:
# Authenticate with Weights & Biases. Only needed when useWandB is enabled in the
# config, since every wandb.init / wandb.finish call in this notebook is guarded by it.
import wandb
wandb.login()

In [5]:
import torch
from torch.utils.data import DataLoader
#from torchsummary import summary
import gc
import h5py
from UNet import UNet, Encoder, ContrastiveEncoder, ResNetBlock, ResidualBlock
from LITSDataset import LITSBinaryDataset, LITSContDataset, LITSMultiClassDataset
import LossFunctions
import TrainingEval
from tqdm import tqdm
import os

In [6]:
# Hyperparameters and training modifications.
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
# (To force CPU execution: device = torch.device("cpu"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

modelName = "StandardUNet10"
modelFile = "UsedModels/" + modelName
# Alternative config used for evaluation runs: "evalConfig.txt"
configFile = "ContrastiveModels/contFineTuneConfig.txt"

# Placeholder defaults; they are expected to be replaced by the values parsed
# from configFile in the next step.
startEpoch = useWandB = batchSize = learnRate = epochs = 0
startDim = epochsToDouble = progressive = epochsToSave = 0
cosineAnnealing = cosineRestartEpochs = 0

# Hyperparameter name -> current value; handed to TrainingEval.ParseConfig.
varDict = dict(
    startEpoch=startEpoch,
    useWandB=useWandB,
    batchSize=batchSize,
    learnRate=learnRate,
    epochs=epochs,
    startDim=startDim,
    epochsToDouble=epochsToDouble,
    progressive=progressive,
    epochsToSave=epochsToSave,
    cosineAnnealing=cosineAnnealing,
    cosineRestartEpochs=cosineRestartEpochs,
)

# Replace the placeholder defaults with the values stored in configFile.
# The return value is not used, so ParseConfig presumably updates varDict in place — TODO confirm.
TrainingEval.ParseConfig(configFile, varDict)

# Publish every hyperparameter as a notebook-level variable. Whole-number floats
# are converted to int so they can be used as batch sizes / epoch counts.
# Only floats are tested with is_integer(): int.is_integer() does not exist before
# Python 3.12 (a key left at its int default would crash), and strings have no such method.
# globals() is written explicitly; at notebook top level it is the same namespace
# locals() returns, but writing through locals() is documented as unsupported.
for key, value in varDict.items():
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    globals()[key] = value

In [7]:
# Standard binary splits: each HDF5 file becomes a LITSBinaryDataset, then a DataLoader.
# Only the training loader shuffles; validation and test keep a fixed order.
trainDataset, validationDataset, testDataset = (
    LITSBinaryDataset(path)
    for path in (
        "Datasets/StandardDatasets/FullTrainDataset.hdf5",
        "Datasets/StandardDatasets/ValidationDataset.hdf5",
        "Datasets/StandardDatasets/TestDataset.hdf5",
    )
)

trainIter = DataLoader(trainDataset, shuffle=True, batch_size=batchSize)
validationIter, testIter = (
    DataLoader(split, batch_size=batchSize) for split in (validationDataset, testDataset)
)

print("Datasets loaded")

Datasets loaded


In [None]:
# Swap only the test split for the multi-class test set; the full multi-class
# train/validation loading lives in the "Standard Training" section.
# Close the previously loaded test dataset first so its file handle is not leaked
# when the name is rebound.
_previousTestDataset = globals().get("testDataset")
if _previousTestDataset is not None and hasattr(_previousTestDataset, "closeFile"):
    _previousTestDataset.closeFile()
del _previousTestDataset

testDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTestingDataset.hdf5")
testIter = DataLoader(testDataset, batch_size=batchSize)

In [None]:
# Evaluate on the external TotalSegmentator test set, one sample per batch
# (batch_size=1 rather than the configured batchSize).
# Close the previously loaded test dataset first so its file handle is not leaked
# when the name is rebound.
_previousTestDataset = globals().get("testDataset")
if _previousTestDataset is not None and hasattr(_previousTestDataset, "closeFile"):
    _previousTestDataset.closeFile()
del _previousTestDataset

testDataset = LITSBinaryDataset("Datasets/TotalSegmentator/TotalSegmentatorTesting.hdf5")
testIter = DataLoader(testDataset, batch_size=1)

# **Standard Training**

In [None]:
# Build a fresh single-output (n_class=1) UNet, optionally warm-started from a checkpoint.
initModel = ""  # path to a saved state_dict; leave empty to train from scratch

# Loss tables: [segmentation group, classification group], with a parallel weights table.
# dice_score has weight 0, so it is presumably only reported, not optimised.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

segmenter = UNet(device=device, n_class=1)

if initModel:
    # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
    segmenter.load_state_dict(torch.load(initModel, map_location=device))

segmenter = segmenter.to(device)

print("Intialized standard UNet model")

In [None]:
# Train the standard UNet built above; logs to Weights & Biases when useWandB is set.
if useWandB:
    wandb.init(project="LiverSegmentation",
            name=modelName,
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })

try:
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)
finally:
    # Close the W&B run even when training raises or is interrupted, so a failed
    # run is not left open for the next wandb.init call.
    if useWandB:
        wandb.finish()

In [None]:
# Multi-class (n_class=2) variant: load the multi-class splits, then train ten
# independently initialised UNets saved as MultiClassUNet0 ... MultiClassUNet9.
trainDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTrainingDataset.hdf5")
validationDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassValidationDataset.hdf5")
testDataset = LITSMultiClassDataset("Datasets/MultiClass/MultiClassTestingDataset.hdf5")

trainIter = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)
validationIter = DataLoader(validationDataset, batch_size=batchSize)
testIter = DataLoader(testDataset, batch_size=batchSize)

initModel = ""  # optional warm-start checkpoint shared by every run; "" trains from scratch

# dice_score has weight 0, so it is presumably only reported, not optimised.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

for modelNum in range(10):
    # NOTE(review): saved to the working directory, unlike the UsedModels/ checkpoints elsewhere — confirm intended.
    modelFile = "MultiClassUNet" + str(modelNum)
    segmenter = UNet(device=device, n_class=2)

    if initModel:
        # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
        segmenter.load_state_dict(torch.load(initModel, map_location=device))

    segmenter = segmenter.to(device)

    print("Intialized standard UNet model")

    if useWandB:
        # Name each run after its checkpoint; using the notebook-level modelName made
        # all ten runs share one name.
        wandb.init(project="LiverSegmentation",
                name=modelFile,
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    try:
        TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
            cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)
    finally:
        # Close the W&B run even if training raises, so the next iteration starts cleanly.
        if useWandB:
            wandb.finish()

# **Pre-Training**

In [None]:
# Classification pre-training of the UNet encoder: the focal loss is optimised
# (weight 1); accuracy and F1 carry weight 0 and are presumably only reported.
# weight0/weight1 are presumably the per-class weights for classes 0 and 1 — TODO confirm.
focal = LossFunctions.FocalLoss(weight0=0.1, weight1=0.9, gamma=2)

lossFuncs = [[], [focal, LossFunctions.accuracy, LossFunctions.f1]]
weights = [[], [1, 0, 0]]

initEncoder = "Run 2/Progressive Encoders/ProgEncoder10"  # warm-start checkpoint; "" trains from scratch
encoderFile = "UsedModels/Encoder1"  # checkpoint path handed to TrainingEval.train

encoder = Encoder(1)

if initEncoder:
    # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
    encoder.load_state_dict(torch.load(initEncoder, map_location=device))

encoder = encoder.to(device)

In [None]:
# Pre-train the encoder alone on the classification objective set up above and
# print whatever TrainingEval.train returns.
gc.collect()  # release garbage left by earlier cells before a long training run

if useWandB:
    wandb.init(project="PreTrainedEncoder",
            name="UNetEncoder",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble,
            })

try:
    print(TrainingEval.train(encoder, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, encoderFile, epochsToSave, useWandB=useWandB,
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive, encoder=True))
finally:
    # Close the W&B run even when training raises or is interrupted.
    if useWandB:
        wandb.finish()

In [None]:
# Wrap the pre-trained encoder in a full UNet for segmentation fine-tuning.
# device is passed explicitly, as every other UNet construction in this notebook does.
segmenter = UNet(device, encoder=encoder)

# dice_score has weight 0, so it is presumably only reported, not optimised.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

In [None]:
# Fine-tune the UNet built around the pre-trained encoder.
gc.collect()  # release garbage left by earlier cells before a long training run

if useWandB:
    wandb.init(project="LiverSegmentationPreTraining",
            name="NoWeights",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })

try:
    # progressive is forced to 0 here regardless of the config value.
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)
finally:
    # Close the W&B run even when training raises or is interrupted.
    if useWandB:
        wandb.finish()

# **Joint Training**

In [None]:
# Joint training: for each of ten encoder checkpoints, fine-tune a UNet on a weighted
# sum of a segmentation loss (dice_loss, 0.6) and a classification loss (focal, 0.4).
# dice_score and accuracy carry weight 0 and are presumably only reported.
for i in tqdm(range(10)):
    # NOTE(review): saved to the working directory, unlike the UsedModels/ checkpoints elsewhere — confirm intended.
    modelFile = "ContJoint" + str(i)
    initEncoder = "ContrastiveModels/Encoders/ContrastiveEncoder" + str(i)

    encoder = Encoder()
    # The checkpoint path points at ContrastiveEncoder weights, but they are loaded into
    # the plain Encoder with strict=False, which skips mismatched parameters silently.
    # Report any skipped keys so a mismatch is visible.
    loadResult = encoder.load_state_dict(torch.load(initEncoder, map_location=device), strict=False)
    if loadResult.missing_keys or loadResult.unexpected_keys:
        print(f"{initEncoder}: {len(loadResult.missing_keys)} missing and "
              f"{len(loadResult.unexpected_keys)} unexpected keys ignored")

    segmenter = UNet(device, encoder=encoder)

    if useWandB:
        # Name each run after its checkpoint instead of the placeholder "Weights:".
        wandb.init(project="LiverSegmentationJointTraining",
                name=modelFile,
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

    lossFuncs = [[LossFunctions.dice_score, LossFunctions.dice_loss], [LossFunctions.accuracy, classLossFunc]]
    weights = [[0, 0.6], [0, 0.4]]

    try:
        TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
            cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive)
    finally:
        # Close the W&B run even if training raises, so the next iteration starts cleanly.
        if useWandB:
            wandb.finish()

# **Contrastive Pre-Training**

In [None]:
# Scan-based "RandContrastive" splits for contrastive pre-training, each wrapped in a
# DataLoader. Only the training loader shuffles.
contTrainDataset, contValDataset, contTestDataset = (
    LITSContDataset(path)
    for path in (
        "Datasets/ContrastiveDatasets/ScanBased/RandContrastiveTrainDataset.hdf5",
        "Datasets/ContrastiveDatasets/ScanBased/RandContrastiveValidationDataset.hdf5",
        "Datasets/ContrastiveDatasets/ScanBased/RandContrastiveTestDataset.hdf5",
    )
)

contTrainIter = DataLoader(contTrainDataset, shuffle=True, batch_size=batchSize)
contValidationIter, contTestIter = (
    DataLoader(split, batch_size=batchSize) for split in (contValDataset, contTestDataset)
)

In [None]:
# Contrastive pre-training of ten ContrastiveEncoder instances, saved as
# UsedModels/RandContrastiveEncoder0 ... 9.
# Note: this cell rebinds the notebook-level modelName and modelFile.
for i in tqdm(range(10)):
    modelName = "RandContrastiveEncoder" + str(i)
    modelFile = "UsedModels/" + modelName
    encoder = ContrastiveEncoder()

    # Cosine-similarity contrastive loss with temperature 1 / batchSize. It is built
    # before wandb.init so a failure here cannot leave a run open.
    #lossFunc = LossFunctions.ContrastiveLossEuclidean
    lossFunc = LossFunctions.ContrastiveLossCosine(temp=(1 / batchSize))

    if useWandB:
        # Name each run after its encoder instead of the placeholder "Weights:".
        wandb.init(project="LITSEncoderContrastive",
                name=modelName,
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    try:
        TrainingEval.contrastiveTrain(encoder, lossFunc, contTrainIter, contValidationIter, epochs, startEpoch, learnRate, device, modelFile, epochsToSave, useWandB=useWandB,
            cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, isDist=False)
    finally:
        # Close the W&B run even if training raises, so the next iteration starts cleanly.
        if useWandB:
            wandb.finish()

    print("Model: " + str(i) + " finished")

In [None]:
# Build a UNet around a single contrastive-test encoder checkpoint for fine-tuning.
initEncoder = "ContrastiveTest/ContrastiveTest1Encoder"

encoder = ContrastiveEncoder()
# strict=False: parameters absent from the checkpoint keep their initial values.
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
encoder.load_state_dict(torch.load(initEncoder, map_location=device), strict=False)

segmenter = UNet(device, encoder=encoder)
# To train only the decoder, uncomment: segmenter.freezeEncoder()

# dice_score has weight 0, so it is presumably only reported, not optimised.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

In [9]:
# Fine-tune binary (n_class=1) UNets whose encoders start from the unsupervised
# RandContrastiveEncoder checkpoints; only encoders 5 and 8 are run here.
for i in tqdm([5, 8]):
    initEncoder = "ContrastiveModels/Unsupervised/Encoders/RandContrastiveEncoder" + str(i)
    modelFile = "UsedModels/RandContFineTune" + str(i)

    print(modelFile)

    encoder = ContrastiveEncoder()
    # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
    encoder.load_state_dict(torch.load(initEncoder, map_location=device), strict=False)

    segmenter = UNet(device, n_class=1 , encoder=encoder)
    # To train only the decoder, uncomment: segmenter.freezeEncoder()

    # dice_score has weight 0, so it is presumably only reported, not optimised.
    lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
    weights = [[1, 0], []]

    gc.collect()  # release garbage from the previous run before training

    if useWandB:
        # Keep the "NoWeights" label but make each run's name unique.
        wandb.init(project="LiverSegmentationPreTraining",
                name="NoWeights-" + os.path.basename(modelFile),
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    try:
        # progressive is forced to 0 here regardless of the config value.
        TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
            cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)
    finally:
        # Close the W&B run even if training raises, so the next iteration starts cleanly.
        if useWandB:
            wandb.finish()

  0%|          | 0/2 [00:00<?, ?it/s]

UsedModels/RandContFineTune5
Training on cuda
Epoch 0:
Train Loss: 0.8849958181381226 Validation Loss: 0.9758964776992798 dice_loss: tensor(0.9759, device='cuda:0') dice_score: tensor(0.0711, device='cuda:0') 
Epoch 1:
Train Loss: 0.826005220413208 Validation Loss: 0.96075040102005 dice_loss: tensor(0.9608, device='cuda:0') dice_score: tensor(0.0777, device='cuda:0') 
Epoch 2:
Train Loss: 0.7761727571487427 Validation Loss: 0.9325522780418396 dice_loss: tensor(0.9326, device='cuda:0') dice_score: tensor(0.1660, device='cuda:0') 
Epoch 3:
Train Loss: 0.7367934584617615 Validation Loss: 0.9049058556556702 dice_loss: tensor(0.9049, device='cuda:0') dice_score: tensor(0.1931, device='cuda:0') 
Epoch 4:
Train Loss: 0.7085496783256531 Validation Loss: 0.8832738995552063 dice_loss: tensor(0.8833, device='cuda:0') dice_score: tensor(0.2010, device='cuda:0') 
Epoch 5:
Train Loss: 0.6881203651428223 Validation Loss: 0.8756838440895081 dice_loss: tensor(0.8757, device='cuda:0') dice_score: tensor

 50%|█████     | 1/2 [32:07<32:07, 1927.76s/it]

Epoch 99:
Train Loss: 0.500690758228302 Validation Loss: 0.24902032315731049 dice_loss: tensor(0.2490, device='cuda:0') dice_score: tensor(0.7510, device='cuda:0') 
UsedModels/RandContFineTune8
Training on cuda
Epoch 0:
Train Loss: 0.8926879167556763 Validation Loss: 0.9816311001777649 dice_loss: tensor(0.9816, device='cuda:0') dice_score: tensor(0.0242, device='cuda:0') 
Epoch 1:
Train Loss: 0.8585518598556519 Validation Loss: 0.9746121168136597 dice_loss: tensor(0.9746, device='cuda:0') dice_score: tensor(0.0594, device='cuda:0') 
Epoch 2:
Train Loss: 0.8273155689239502 Validation Loss: 0.9670119881629944 dice_loss: tensor(0.9670, device='cuda:0') dice_score: tensor(0.0621, device='cuda:0') 
Epoch 3:
Train Loss: 0.7933850884437561 Validation Loss: 0.9526105523109436 dice_loss: tensor(0.9526, device='cuda:0') dice_score: tensor(0.1110, device='cuda:0') 
Epoch 4:
Train Loss: 0.7647485136985779 Validation Loss: 0.9409874677658081 dice_loss: tensor(0.9410, device='cuda:0') dice_score: te

100%|██████████| 2/2 [1:03:52<00:00, 1916.20s/it]

Epoch 99:
Train Loss: 0.39651235938072205 Validation Loss: 0.09555500000715256 dice_loss: tensor(0.0956, device='cuda:0') dice_score: tensor(0.9568, device='cuda:0') 





In [None]:
# Fine-tune two-class (n_class=2) UNets from each of the ten RandContrastiveEncoder checkpoints.
# trainIter/validationIter are rebound by several cells; n_class=2 needs the
# multi-class loaders, so fail fast instead of starting ten long runs on the wrong data.
assert isinstance(trainIter.dataset, LITSMultiClassDataset) and isinstance(validationIter.dataset, LITSMultiClassDataset), \
    "trainIter/validationIter must wrap LITSMultiClassDataset; run the multi-class dataset cell first"

for i in tqdm(range(10)):
    initEncoder = "ContrastiveModels/Unsupervised/Encoders/RandContrastiveEncoder" + str(i)
    modelFile = "UsedModels/RandContFineTuneMultiClass" + str(i)

    print(modelFile)

    encoder = ContrastiveEncoder()
    # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
    encoder.load_state_dict(torch.load(initEncoder, map_location=device), strict=False)

    segmenter = UNet(device, n_class=2, encoder=encoder)
    # To train only the decoder, uncomment: segmenter.freezeEncoder()

    # dice_score has weight 0, so it is presumably only reported, not optimised.
    lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
    weights = [[1, 0], []]

    gc.collect()  # release garbage from the previous run before training

    if useWandB:
        # Keep the "NoWeights" label but make each run's name unique.
        wandb.init(project="LiverSegmentationPreTraining",
                name="NoWeights-" + os.path.basename(modelFile),
                config={
                    "BatchSize":batchSize,
                    "LearnRate":learnRate,
                    "Epochs":epochs,
                    "StartDimension":startDim,
                    "EpochsToDouble":epochsToDouble
                })

    try:
        # progressive is forced to 0 here regardless of the config value.
        TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB,
            cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)
    finally:
        # Close the W&B run even if training raises, so the next iteration starts cleanly.
        if useWandB:
            wandb.finish()

# **Evaluation/Ending**

In [10]:
# Evaluate saved checkpoints on the current test loader. testIter is rebound by
# several cells, so whichever of them ran last determines the test set.
# NOTE(review): the saved output of this cell is
# "TypeError: 'list' object cannot be interpreted as an integer"; the failing call is not visible here.
classification = False  # True: evaluate a bare encoder on classification metrics instead of a UNet
lossFuncs = [[LossFunctions.dice_score, LossFunctions.hausdorff], []]
# With classification metrics: [[dice_score, hausdorff], [accuracy, f1]]

for q in [4, 5, 8, 9]:
    # NOTE(review): the label says "BaselineTotalSeg" but the checkpoint is a
    # RandContFineTune model — confirm the label is intended.
    modelName = "BaselineTotalSeg" + str(q)
    modelFile = "UsedModels/RandContFineTune" + str(q) + "BestLoss"

    if classification:
        # NOTE(review): reuses whatever `encoder` an earlier cell left in the namespace;
        # construct it explicitly here before enabling this branch.
        net = encoder
    else:
        net = UNet(device, n_class=1, multiTask=False).to(device)
    # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)

    print(f"Model: {modelName}")

    losses = TrainingEval.evaluate(net, testIter, lossFuncs, device=device, encoder=classification)

    # One "name: value" entry per metric. Plain functions are labelled by __name__;
    # callables without one (e.g. loss-class instances) fall back to their class name.
    logStr = ""
    for i, arr in enumerate(losses):
        for j, val in enumerate(arr):
            fn = lossFuncs[i][j]
            logStr += getattr(fn, "__name__", type(fn).__name__) + ": " + str(val) + " "

    print(logStr)

TypeError: 'list' object cannot be interpreted as an integer

In [None]:
# Evaluate a saved contrastive encoder on the contrastive test split.
modelName = ""
modelFile = "UsedModels/ContrastiveTest1BestLoss"

net = ContrastiveEncoder().to(device)
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)

# Instantiate the loss exactly as the contrastive training cell does; the bare class
# was passed before, which presumably made contrastiveEval call the constructor
# instead of the loss — TODO confirm against TrainingEval.contrastiveEval.
lossFunc = LossFunctions.ContrastiveLossCosine(temp=(1 / batchSize))

loss = TrainingEval.contrastiveEval(net, contTestIter, lossFunc, device=device, isDist=False)
print(type(lossFunc).__name__ + " " + str(loss))

In [None]:
# Export a pre-trained UNet's predicted masks for a single scan to an HDF5 file,
# one dataset per slice ("Slice0", "Slice1", ...).
modelFile = "Run 2/Standard Pre-Training/PretrainedUNet6"  # alternative: "Run 2/Standard Pre-Training/PretrainedUNet7"
dataset = LITSBinaryDataset("Datasets/Scan1Dataset.hdf5")
# Named scanIter rather than iter so the builtin iter() is not shadowed.
scanIter = DataLoader(dataset, batch_size=batchSize)

# device is passed like every other UNet construction in this notebook (was the literal 0).
net = UNet(device)
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)

net.to(device)

segmentationMask = TrainingEval.getMasks(net, scanIter, device=device)

masksFile = "PretrainMasksScan1"
# The context manager closes the output file even if a write fails.
with h5py.File(masksFile, "w") as wFile:
    for i, sliceMask in enumerate(segmentationMask):
        wFile.create_dataset("Slice" + str(i), data=sliceMask)

# The scan dataset is only needed for this export; release its file handle.
dataset.closeFile()

In [None]:
# Close the files held open by every dataset this notebook may have created.
# Names are looked up defensively: only some cells may have run in this session,
# and the contrastive datasets also need closing.
for _datasetName in ("trainDataset", "validationDataset", "testDataset",
                     "contTrainDataset", "contValDataset", "contTestDataset"):
    _openDataset = globals().get(_datasetName)
    if _openDataset is not None and hasattr(_openDataset, "closeFile"):
        _openDataset.closeFile()