# **Setup**

In [None]:
# Authenticate this kernel with Weights & Biases.
# NOTE: this runs unconditionally; the `useWandB` switch is only read from the
# config file in a later cell, so login is attempted even when logging is disabled.
import wandb
wandb.login()

In [1]:
import torch
from torch.utils.data import DataLoader
#from torchsummary import summary
import gc
import h5py
from UNet import UNet, Encoder
from LITSDataset import LITSBinaryDataset
import LossFunctions
import TrainingEval

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Hyperparameters and training modifications
#
# Picks the compute device, names the model/checkpoint files and loads the
# training hyperparameters from `configFile` into notebook-level variables.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#device = torch.device("cpu")

modelName = "MultiTaskW=0.9-0.1,T=0"
modelFile = "UsedModels/" + modelName
# Forward slashes work on every platform (including Windows) and avoid the
# invalid "\J" / "\j" escape sequences that backslashes produce in a normal string.
configFile = "Run 2/Joint Training Models/jointConfigRun2.txt"

# Defaults used for any key that the config file does not provide.
startEpoch = 0
useWandB = 0
batchSize = 0
learnRate = 0
epochs = 0
startDim = 0
epochsToDouble = 0
progressive = 0
epochsToSave = 0
cosineAnnealing = 0
cosineRestartEpochs = 0

varDict = {
    "startEpoch": startEpoch,
    "useWandB": useWandB,
    "batchSize": batchSize,
    "learnRate": learnRate,
    "epochs": epochs,
    "startDim": startDim,
    "epochsToDouble": epochsToDouble,
    "progressive": progressive,
    "epochsToSave": epochsToSave,
    "cosineAnnealing": cosineAnnealing,
    "cosineRestartEpochs": cosineRestartEpochs,
}

# Presumably overwrites entries of varDict in place with the values found in
# configFile (its return value is not used here) - TODO confirm against TrainingEval.
TrainingEval.ParseConfig(configFile, varDict)

# Publish every setting as a notebook-level variable. Whole-number floats
# (e.g. 16.0) become ints so they can be used as sizes/counts. Values that are
# not floats (untouched int defaults, strings, ...) are passed through unchanged
# instead of crashing on a missing `.is_integer()` method.
for key, value in varDict.items():
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    globals()[key] = value

In [3]:
# Load Datasets
# Open the train / validation / test HDF5 splits (in that order), then wrap each
# one in a DataLoader using the configured batch size.
trainDataset, validationDataset, testDataset = map(
    LITSBinaryDataset,
    (
        "Datasets/FullTrainDataset.hdf5",
        "Datasets/ValidationDataset.hdf5",
        "Datasets/TestDataset.hdf5",
    ),
)

# Only the training split is shuffled.
trainIter = DataLoader(dataset=trainDataset, shuffle=True, batch_size=batchSize)
validationIter = DataLoader(dataset=validationDataset, batch_size=batchSize)
testIter = DataLoader(dataset=testDataset, batch_size=batchSize)

print("Datasets loaded")

Datasets loaded


# **Standard Training**

In [None]:
# Load model
# Builds the standard UNet segmenter, optionally restoring weights from a file.
initModel = ""  # path to a saved state dict; leave empty to start from fresh weights

# Two parallel nested lists: lossFuncs[i][j] is paired with weights[i][j].
# Presumably index 0 holds segmentation terms and index 1 classification terms,
# and a weight of 0 marks a metric that is only reported - TODO confirm in TrainingEval.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

segmenter = UNet(1)

# Loads model from file if using a pretrained version.
# map_location lets a checkpoint saved on GPU be restored when running on CPU.
if initModel:
    segmenter.load_state_dict(torch.load(initModel, map_location=device))

segmenter = segmenter.to(device)

print("Intialized standard UNet model")

In [None]:
# Train Model
# Trains the standard segmenter, optionally logging the run to Weights & Biases.

if useWandB:
    wandb.init(project="LiverSegmentation",
            name=modelName,
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })

try:
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB, 
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)
except BaseException:
    # Training crashed or was interrupted: still close the W&B run (marked as
    # failed) so it does not stay open in the kernel, then re-raise.
    if useWandB:
        wandb.finish(exit_code=1)
    raise
else:
    if useWandB:
        wandb.finish()

# **Pre-Training**

In [None]:
# Set up the classification encoder for pre-training.
focal = LossFunctions.FocalLoss(weight0=0.1, weight1=0.9, gamma=2)

# Only the second (presumably classification) slot is populated; the focal loss is
# the single weighted term, accuracy and f1 carry weight 0.
lossFuncs = [[], [focal, LossFunctions.accuracy, LossFunctions.f1]]
weights = [[], [1, 0, 0]]

initEncoder = "Run 2/Progressive Encoders/ProgEncoder10"  # leave empty to start from fresh weights
encoderFile = "UsedModels/Encoder1"  # passed to TrainingEval.train as the model file

encoder = Encoder(1)

# map_location lets a checkpoint saved on GPU be restored when running on CPU.
if initEncoder:
    encoder.load_state_dict(torch.load(initEncoder, map_location=device))

encoder = encoder.to(device)

In [None]:
# Train Model
# Pre-trains the encoder (encoder=True) and prints whatever TrainingEval.train returns.
gc.collect()

if useWandB:
    wandb.init(project="PreTrainedEncoder",
            name="UNetEncoder",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble,
            })

try:
    print(TrainingEval.train(encoder, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, encoderFile, epochsToSave, useWandB=useWandB, 
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive, encoder=True))
except BaseException:
    # Training crashed or was interrupted: still close the W&B run (marked as
    # failed) so it does not stay open in the kernel, then re-raise.
    if useWandB:
        wandb.finish(exit_code=1)
    raise
else:
    if useWandB:
        wandb.finish()

In [None]:
# Build a UNet around the encoder object created in the pre-training cells
# (requires `encoder` to already exist in the kernel).
# NOTE(review): unlike the standard setup cell there is no explicit .to(device)
# here - confirm UNet / TrainingEval.train handle device placement.
segmenter = UNet(encoder=encoder)

# Segmentation-only terms: dice_loss has weight 1, dice_score has weight 0.
lossFuncs = [[LossFunctions.dice_loss, LossFunctions.dice_score], []]
weights = [[1, 0], []]

In [None]:
# Train Model
# Trains the segmenter built on the pre-trained encoder, optionally logging to W&B.
gc.collect()

if useWandB:
    wandb.init(project="LiverSegmentationPreTraining",
            name="NoWeights",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })

try:
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB, 
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)
except BaseException:
    # Training crashed or was interrupted: still close the W&B run (marked as
    # failed) so it does not stay open in the kernel, then re-raise.
    if useWandB:
        wandb.finish(exit_code=1)
    raise
else:
    if useWandB:
        wandb.finish()

# **Joint Training**

In [4]:
# Joint (multi-task) training: one UNet trained on segmentation and classification terms.
segmenter = UNet(device, multiTask=True, classThreshold=0, segmentThreshold=0)

classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

# lossFuncs[i][j] is paired with weights[i][j]: dice_loss (0.75) and the focal
# loss (0.25) are the weighted terms; dice_score and accuracy carry weight 0.
lossFuncs = [[LossFunctions.dice_score, LossFunctions.dice_loss], [LossFunctions.accuracy, classLossFunc]]
weights = [[0, 0.75], [0, 0.25]]

# The W&B run is opened only after the setup above succeeded, so a setup error
# cannot leave a run open.
if useWandB:
    wandb.init(project="LiverSegmentationJointTraining",
            name="Weights:",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })

try:
    TrainingEval.train(segmenter, lossFuncs, weights, trainIter, validationIter, epochs, startEpoch, learnRate, device, startDim, epochsToDouble, modelFile, epochsToSave, useWandB=useWandB, 
          cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0)
except BaseException:
    # Training crashed or was interrupted: still close the W&B run (marked as
    # failed) so it does not stay open in the kernel, then re-raise.
    if useWandB:
        wandb.finish(exit_code=1)
    raise
else:
    if useWandB:
        wandb.finish()

Training on cuda
Epoch 0:
Train Loss: 0.5339672565460205 Validation Loss: 0.19579926133155823 dice_score: tensor(0.7880, device='cuda:0') dice_loss: tensor(0.2510, device='cuda:0') accuracy: 0.9366666666666668 FocalLoss: tensor(0.0303, device='cuda:0') 
Epoch 1:
Train Loss: 0.2571074068546295 Validation Loss: 0.17080852389335632 dice_score: tensor(0.8468, device='cuda:0') dice_loss: tensor(0.2193, device='cuda:0') accuracy: 0.95 FocalLoss: tensor(0.0252, device='cuda:0') 
Epoch 2:
Train Loss: 0.20737643539905548 Validation Loss: 0.6549866199493408 dice_score: tensor(0.2321, device='cuda:0') dice_loss: tensor(0.8509, device='cuda:0') accuracy: 0.3233333333333333 FocalLoss: tensor(0.0671, device='cuda:0') 
Epoch 3:
Train Loss: 0.1805156022310257 Validation Loss: 0.14314082264900208 dice_score: tensor(0.8839, device='cuda:0') dice_loss: tensor(0.1830, device='cuda:0') accuracy: 0.93 FocalLoss: tensor(0.0237, device='cuda:0') 
Epoch 4:
Train Loss: 0.15399819612503052 Validation Loss: 0.106

# **Evaluation/Ending**

In [5]:
# Evaluate a saved checkpoint on the test split and print one "name: value" entry
# per metric.
modelName = ""
classification = False  # True: evaluate the existing `encoder`; False: a fresh multi-task UNet
modelFile = "UsedModels/MultiTaskW=0.9-0.1,T=0BestLoss"
classLossFunc = LossFunctions.FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

lossFuncs = [[LossFunctions.dice_score, LossFunctions.hausdorff], [LossFunctions.accuracy, LossFunctions.f1]]

if classification:
    # Requires `encoder` from the pre-training cells to exist in the kernel.
    net = encoder
else:
    net = UNet(device, multiTask=True).to(device)

# map_location lets a checkpoint saved on GPU be restored when running on CPU.
# strict=False tolerates key mismatches, so report them instead of silently
# evaluating a partially (or not at all) initialised network.
loadResult = net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)
if loadResult.missing_keys or loadResult.unexpected_keys:
    print(f"Warning: checkpoint does not fully match the model "
          f"({len(loadResult.missing_keys)} missing, {len(loadResult.unexpected_keys)} unexpected keys)")

#Evaluate Model
print(f"Model: {modelName}")

losses = TrainingEval.evaluate(net, testIter, lossFuncs, device=device, encoder=classification)
logStr = ""
for i, arr in enumerate(losses):
    for j, val in enumerate(arr):
        metric = lossFuncs[i][j]
        # Plain functions are labelled by their own name; callable objects
        # (e.g. a FocalLoss instance) fall back to their class name.
        metricName = getattr(metric, "__name__", type(metric).__name__)
        # str(val) is used on purpose so tensors keep their full repr in the log.
        logStr += metricName + ": " + str(val) + " "

print(logStr)

Model: 
dice_score: tensor(0.8886, device='cuda:0') hausdorff: 13.689741821647782 accuracy: 0.9280303030303029 f1: 0.0 


In [None]:
# Run a saved segmenter over a single scan and store every predicted mask in an
# HDF5 file, one dataset per slice ("Slice0", "Slice1", ...).
#modelFile = "Run 2/Standard Pre-Training/PretrainedUNet7"
modelFile = "Run 2/Standard Pre-Training/PretrainedUNet6"
dataset = LITSBinaryDataset("Datasets/Scan1Dataset.hdf5")
# Named scanIter / maskSlice rather than iter / slice so the Python builtins are
# not shadowed for the rest of the kernel session.
scanIter = DataLoader(dataset, batch_size=batchSize)

net = UNet(0)
# map_location lets a checkpoint saved on GPU be restored when running on CPU.
net.load_state_dict(torch.load(modelFile, map_location=device), strict=False)

net.to(device)

masksFile = "PretrainMasksScan1"

try:
    segmentationMask = TrainingEval.getMasks(net, scanIter, device=device)

    # The context manager closes the output file even if writing a slice fails.
    with h5py.File(masksFile, "w") as wFile:
        for i, maskSlice in enumerate(segmentationMask):
            wFile.create_dataset("Slice" + str(i), data=maskSlice)
finally:
    # This scan dataset is not covered by the final "Close datasets" cell,
    # so release it here.
    dataset.closeFile()

In [None]:
# Close datasets
# Release the train, validation and test datasets, in that order.
for openDataset in (trainDataset, validationDataset, testDataset):
    openDataset.closeFile()