# **Imports/Installs**

In [1]:
# Colab bootstrap: install condacolab, then the deep-learning dependencies.
!pip install -q condacolab
import condacolab
# NOTE: this cell's recorded output shows "Restarting kernel..." after the conda
# install, so the cell may need to be re-run / followed by a fresh execution.
condacolab.install()
# torch and torchvision are installed unpinned; only d2l is pinned to a version.
!pip install torch
!pip install torchvision
!pip install d2l==1.0.0b0

⏬ Downloading https://github.com/conda-forge/miniforge/releases/download/23.1.0-1/Mambaforge-23.1.0-1-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:15
🔁 Restarting kernel...
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting networkx
  Downloading networkx-3.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m69.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cusparse-cu11==11.7.4.91
  Downloading nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl (173.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m173.2/173.2 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?2

In [None]:
# Install/upgrade the Weights & Biases client quietly, then authenticate.
# Only needed when `useWandB` is enabled in the setup cell.
!pip install wandb -qU
import wandb
wandb.login()

[0m

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mamoseley018[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Mount Google Drive at /content/drive/ so datasets and checkpoints can be read/written.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [1]:
import os
import torch
import torchvision
from torchvision import transforms
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchsummary import summary
from d2l import torch as d2l
from pathlib import Path
import sys
import nibabel as nib
import numpy as np
import gc
import skimage
import h5py
import math
import PIL
from PIL import Image

# **Loss Functions**

In [2]:
def dice_loss(pred, target, smooth = 1.):
    """Soft Dice loss plus a hard (rounded) Dice score.

    Both tensors are reduced over dimensions 2 and 3, so they need at least
    four dimensions (assumed to be (batch, channel, height, width) -- TODO confirm
    against the data loader).

    Args:
        pred: predicted values; rounded with ``torch.round`` for the hard score.
        target: ground-truth mask; values are clamped into [0, 1] before use.
        smooth: additive smoothing term applied to numerator and denominator.

    Returns:
        Tuple ``(loss, dice)``: the mean of ``1 - soft Dice`` and the mean Dice
        coefficient computed on the rounded predictions.
    """
    def _spatial_sum(t):
        #Collapses dims 2 and 3 one after the other
        return t.sum(dim=2).sum(dim=2)

    def _dice(p, mask, mask_total):
        overlap = _spatial_sum(p.mul(mask))
        return (2. * overlap + smooth) / (_spatial_sum(p) + mask_total + smooth)

    mask = torch.clamp(target, min=0, max=1).contiguous()
    probs = pred.contiguous()
    mask_total = _spatial_sum(mask)

    #Differentiable loss on the raw predictions
    soft_loss = 1 - _dice(probs, mask, mask_total)

    #Reported metric on the thresholded predictions
    hard_dice = _dice(torch.round(probs), mask, mask_total)

    return soft_loss.mean(), hard_dice.mean()

In [3]:
class FocalLoss(nn.Module):
    """Class-weighted binary focal loss that also reports rounded accuracy.

    Per element the loss is ``-w * |y - p| ** gamma * log(p_correct)`` where
    ``p_correct`` is ``p`` for label 1 and ``1 - p`` otherwise, and ``w`` is
    ``weight1`` / ``weight0`` respectively. The result is averaged over all
    elements.

    Args:
        weight0: weight applied to elements whose label is not 1.
        weight1: weight applied to elements whose label is 1.
        gamma: focusing exponent; 0 reduces to weighted cross-entropy.
        eps: probabilities are clamped to ``[eps, 1 - eps]`` inside the log so
            a saturated prediction cannot produce an infinite loss.
    """
    def __init__(self, weight0=1, weight1=1, gamma=0, eps=1e-7):
        super().__init__()

        self.weight0 = weight0
        self.weight1 = weight1
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target):
        """Return ``(loss, accuracy)`` for predicted probabilities and labels.

        Args:
            input: predicted probabilities, one per sample; flattened, so both
                ``(N,)`` and ``(N, 1)`` are accepted.
            target: labels with the same number of elements as ``input``.

        Returns:
            ``loss`` as a 0-dim tensor (differentiable) and ``accuracy`` as a
            Python float: the fraction of rounded predictions equal to the label.

        Raises:
            ValueError: if ``input`` and ``target`` hold different element counts.
        """
        #Flatten so (N, 1) predictions line up with (N,) labels without broadcasting
        probs = input.reshape(-1)
        labels = target.reshape(-1)

        if probs.numel() != labels.numel():
            raise ValueError(f"input has {probs.numel()} elements but target has {labels.numel()}")

        #Clamp only inside the log; the focal factor keeps the raw probability
        safeProbs = probs.clamp(self.eps, 1 - self.eps)

        positiveTerm = torch.log(safeProbs) * (abs(1 - probs) ** self.gamma) * self.weight1
        negativeTerm = torch.log(1 - safeProbs) * (abs(0 - probs) ** self.gamma) * self.weight0

        #Takes negative average loss over each element of input
        loss = -torch.where(labels == 1, positiveTerm, negativeTerm).sum() / probs.numel()

        accuratePreds = (torch.round(probs) == labels).sum().item()

        return loss, accuratePreds / probs.numel()

In [4]:
class BalancedCELoss(nn.Module):
    """Class-weighted binary cross-entropy that also reports rounded accuracy.

    Per element the loss is ``-w * log(p_correct)`` where ``p_correct`` is ``p``
    for label 1 and ``1 - p`` otherwise, and ``w`` is ``weight1`` / ``weight0``
    respectively. The result is averaged over all elements.

    Args:
        weight0: weight applied to elements whose label is not 1.
        weight1: weight applied to elements whose label is 1.
        eps: probabilities are clamped to ``[eps, 1 - eps]`` inside the log so
            a saturated prediction cannot produce an infinite loss.
    """
    def __init__(self, weight0=1, weight1=1, eps=1e-7):
        super().__init__()

        self.weight0 = weight0
        self.weight1 = weight1
        self.eps = eps

    def forward(self, input, target):
        """Return ``(loss, accuracy)`` for predicted probabilities and labels.

        Args:
            input: predicted probabilities, one per sample; flattened, so both
                ``(N,)`` and ``(N, 1)`` are accepted.
            target: labels with the same number of elements as ``input``.

        Returns:
            ``loss`` as a 0-dim tensor (differentiable) and ``accuracy`` as a
            Python float: the fraction of rounded predictions equal to the label.

        Raises:
            ValueError: if ``input`` and ``target`` hold different element counts.
        """
        #Flatten so (N, 1) predictions line up with (N,) labels without broadcasting
        probs = input.reshape(-1)
        labels = target.reshape(-1)

        if probs.numel() != labels.numel():
            raise ValueError(f"input has {probs.numel()} elements but target has {labels.numel()}")

        #Keeps log() finite when a prediction is exactly 0 or 1
        safeProbs = probs.clamp(self.eps, 1 - self.eps)

        positiveTerm = torch.log(safeProbs) * self.weight1
        negativeTerm = torch.log(1 - safeProbs) * self.weight0

        #Takes negative average loss over each element of input
        loss = -torch.where(labels == 1, positiveTerm, negativeTerm).sum() / probs.numel()

        accuratePreds = (torch.round(probs) == labels).sum().item()

        return loss, accuratePreds / probs.numel()

# **Data Handling**

In [5]:
class LITSBinaryDataset(Dataset):
    """Map-style dataset backed by one HDF5 file.

    Sample ``i`` is read from the group ``"Slice<i>"``, which must hold the
    datasets ``"Slice"`` and ``"Segmentation"`` and the attribute ``"ImageLabel"``.
    The number of samples is the number of top-level keys in the file.
    """
    def __init__(self, fileName):
        super().__init__()

        #Keeps a file pointer open throughout use
        self.file = h5py.File(fileName, 'r')

        #Precalculates length to reduce training computations
        self.length = len(self.file)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``[slice, segmentation, label]`` as tensors for sample ``idx``.

        The slice and segmentation gain a leading dimension of size 1; the label
        has its leading dimension squeezed away. Negative indices count from the
        end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``. Raising IndexError
                (rather than the KeyError of a missing group) lets plain
                ``for sample in dataset`` iteration stop at the end.
        """
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index {idx} is out of range for dataset of length {self.length}")

        group = self.file["Slice" + str(idx)]
        data = group["Slice"]
        segmentation = group["Segmentation"]
        label = group.attrs.get("ImageLabel")

        #Returns list containing slice data, segmentation mask and image label
        return [
            torch.Tensor(data[...]).unsqueeze(0),
            torch.Tensor(segmentation[...]).unsqueeze(0),
            torch.Tensor(label).squeeze(0),
        ]

    def closeFile(self):
        #Closes file once dataset is no longer being used
        #Do not use class instance after this function is called
        self.file.close()

# **Network**

In [6]:
class convBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU, then dropout.

    Args:
        inChannels: channels of the input tensor.
        outChannels: channels produced by both convolutions.
        strides: stride of the first convolution (the second always uses 1).
        dropout: dropout probability applied to the block output.

    Note:
        A single BatchNorm2d instance (``bn1``) is applied after both
        convolutions. This is kept as-is so existing checkpoints keep loading.
    """
    def __init__(self, inChannels, outChannels, strides, dropout) -> None:
        super().__init__()

        batchNorm = True
        layerMean = 1.5
        layerDev = 0.1

        #Uses 2 convolutional layers for each block
        self.conv1 = nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1)

        #Initializes convolutional layers using hyperparameters for mean and standard deviation
        nn.init.normal_(self.conv1.weight, mean=layerMean, std=layerDev)
        nn.init.normal_(self.conv2.weight, mean=layerMean, std=layerDev)

        #Dropout has no parameters, so it needs no device placement; this also
        #removes the dependency on a notebook-global `device` being defined first
        self.dropout = nn.Dropout(dropout)

        self.bn1 = nn.BatchNorm2d(outChannels) if batchNorm else None

    def forward(self, X):
        Y = self.conv1(X)
        if self.bn1 is not None:
            Y = self.bn1(Y)
        Y = F.relu(Y)

        Y = self.conv2(Y)
        if self.bn1 is not None:
            Y = self.bn1(Y)
        Y = F.relu(Y)

        Y = self.dropout(Y)

        return Y

In [7]:
class DecoderBlock(nn.Module):
    """One decoder stage: transposed conv, concatenate the skip tensor, then a convBlock."""
    def __init__(self, inChannels, outChannels, strides, dropout) -> None:
        super().__init__()

        #2x2 transposed convolution with stride 2 and no padding
        self.convTrans = nn.ConvTranspose2d(inChannels, outChannels, 2, stride=2, padding=0)
        #Runs on the concatenated tensor, hence inChannels as its input width
        self.conv = convBlock(inChannels, outChannels, strides, dropout)

    def forward(self, X, skipConn):
        upsampled = self.convTrans(X)
        #Joins the upsampled features with the skip connection along the channel dim
        merged = torch.cat((upsampled, skipConn), dim=1)
        return self.conv(merged)

In [8]:
class DecoderNetwork(nn.Module):
    """Four DecoderBlocks (256 -> 128 -> 64 -> 32 -> 16 channels), a 1x1 conv and a sigmoid."""
    def __init__(self, strides, dropout, device) -> None:
        super().__init__()

        self.device = device

        self.block1 = DecoderBlock(256, 128, strides, dropout)
        self.block2 = DecoderBlock(128, 64, strides, dropout)
        self.block3 = DecoderBlock(64, 32, strides, dropout)
        self.block4 = DecoderBlock(32, 16, strides, dropout)

        #1x1 convolution down to a single output channel
        self.endBlock = nn.Conv2d(16, 1, kernel_size=1, padding=0, stride=1)

        self.sigm = nn.Sigmoid()

    def forward(self, X, skipConn):
        """Decode ``X`` using the skip tensors in ``skipConn``, taken from last to first."""
        stages = (self.block1, self.block2, self.block3, self.block4)

        y = X
        #Stage k consumes skipConn[-k]: the last skip entry feeds the first stage
        for offset, stage in enumerate(stages, start=1):
            y = stage(y, skipConn[-offset])

        return self.sigm(self.endBlock(y))

In [9]:
class EncoderNetwork(nn.Module):
    """Five convBlocks (1 -> 16 -> 32 -> 64 -> 128 -> 256 channels) with a classification head.

    ``forward`` returns ``(classification, skipConnections, bottleneck)``.
    """
    def __init__(self, strides, dropout, device) -> None:
        super().__init__()

        self.device = device

        #Encoder stages; every stage except the last is followed by a 2x2 max pool in forward
        self.block1 = convBlock(1, 16, strides, dropout)
        self.block2 = convBlock(16, 32, strides, dropout)
        self.block3 = convBlock(32, 64, strides, dropout)
        self.block4 = convBlock(64, 128, strides, dropout)
        self.block5 = convBlock(128, 256, strides, dropout)

        self.pool = nn.MaxPool2d(2)

        #Classification branch: global average pool -> flatten -> linear -> sigmoid
        #TODO (from original notes): try without Sequential, following the MultiMix code
        head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        self.classification = head.to(device)

    def forward(self, X):
        y = X.to(self.device)

        skipConnections = []

        #Each of the first four stages records its output before pooling
        for stage in (self.block1, self.block2, self.block3, self.block4):
            y = stage(y)
            skipConnections.append(y)
            y = self.pool(y)

        #Last stage is not pooled and not stored as a skip connection
        y = self.block5(y)

        return self.classification(y), skipConnections, y

In [10]:
class SegmentationNetwork(nn.Module):
    """Chains an encoder and a decoder into one segmentation model.

    The encoder must return a 3-tuple whose second and third items are the skip
    connections and the tensor to decode; the first item is discarded.
    ``dropout`` is accepted but not used by this class.
    """
    def __init__(self, encoder, decoder, dropout, device) -> None:
        super().__init__()

        self.device = device
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, X):
        #First encoder output is ignored here
        _unused, skips, bottleneck = self.encoder(X)
        return self.decoder(bottleneck, skips)

# **Multi Mix Code**

In [11]:
def double_conv(in_channels, out_channels):
    """Build a Conv-InstanceNorm-LeakyReLU pair of layers followed by dropout.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both 3x3 convolutions.

    Returns:
        An ``nn.Sequential`` mapping ``in_channels`` to ``out_channels``; padding 1
        with a 3x3 kernel keeps the spatial size.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        #Normalizes the second conv's output, so it must be sized by out_channels
        #(was in_channels). The default InstanceNorm2d holds no parameters or
        #buffers, so checkpoints are unaffected.
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Dropout(0.25)
    )

class MultiMix(nn.Module):
    """Joint segmentation/classification model built from an Encoder and a Decoder.

    Args:
        n_class: number of output channels of the segmentation head; forwarded
            to the Decoder (previously accepted but ignored).
        encoder: optional pre-built encoder module to reuse (for example a
            pre-trained one); a fresh ``Encoder`` is created when omitted.
    """

    def __init__(self, n_class = 1, encoder=None):
        super().__init__()

        if encoder is None:
            self.encoder = Encoder(1)
        else:
            self.encoder = encoder

        self.decoder = Decoder(n_class)

    def forward(self, x):
        """Return ``(outSeg, outC)``: the decoder output and the encoder's first output."""
        outC, conv5, conv4, conv3, conv2, conv1 = self.encoder(x)
        outSeg = self.decoder(x, conv5, conv4, conv3, conv2, conv1)

        return outSeg, outC

    def freezeEncoder(self):
        """Disable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = False

class Encoder(nn.Module):
    """Five double_conv stages with max pooling plus a sigmoid classification head.

    ``n_class`` is accepted but not used: the linear head always has one output.
    """

    def __init__(self, n_class = 1):
        super().__init__()

        #Contracting path: 1 -> 16 -> 32 -> 64 -> 128 -> 256 channels
        self.dconv_down1 = double_conv(1, 16)
        self.dconv_down2 = double_conv(16, 32)
        self.dconv_down3 = double_conv(32, 64)
        self.dconv_down4 = double_conv(64, 128)
        self.dconv_down5 = double_conv(128, 256)

        #Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(256, 1)
        self.sigm = nn.Sigmoid()

        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``(class_prob, conv5, conv4, conv3, conv2, conv1)``."""
        stages = (
            self.dconv_down1,
            self.dconv_down2,
            self.dconv_down3,
            self.dconv_down4,
            self.dconv_down5,
        )

        features = []
        pooled = x
        #Every stage, including the last, is followed by a 2x2 max pool
        for stage in stages:
            stageOut = stage(pooled)
            features.append(stageOut)
            pooled = self.maxpool(stageOut)

        conv1, conv2, conv3, conv4, conv5 = features

        #Global average pool, flatten to (batch, features), then linear + sigmoid
        squeezed = self.avgpool(pooled)
        squeezed = squeezed.view(squeezed.size(0), -1)
        outC = self.fc(squeezed)

        return self.sigm(outC), conv5, conv4, conv3, conv2, conv1

class Decoder(nn.Module):
    """Expanding path: bilinear upsample, concatenate the skip tensor, double_conv, four times.

    ``nonlocal_mode`` and ``attention_dsample`` are accepted but not used here,
    and the ``input`` argument of ``forward`` is ignored.
    """

    def __init__(self, n_class = 1, nonlocal_mode='concatenation', attention_dsample = (2,2)):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        #Input widths are (upsampled channels + skip channels)
        self.dconv_up4 = double_conv(256 + 128, 128)
        self.dconv_up3 = double_conv(128 + 64, 64)
        self.dconv_up2 = double_conv(64 + 32, 32)
        self.dconv_up1 = double_conv(32 + 16, 16)
        self.conv_last = nn.Conv2d(16, n_class, 1)

        #Not called in forward; kept because it is part of the module's state_dict
        self.conv_last_saliency = nn.Conv2d(17, n_class, 1)

        self.sigm = nn.Sigmoid()

    def forward(self, input, conv5, conv4, conv3, conv2, conv1):
        """Decode ``conv5`` with skips ``conv4`` .. ``conv1`` and return sigmoid outputs."""
        fusionPlan = (
            (self.dconv_up4, conv4),
            (self.dconv_up3, conv3),
            (self.dconv_up2, conv2),
            (self.dconv_up1, conv1),
        )

        x = conv5
        for fuse, skip in fusionPlan:
            x = self.upsample(x)
            x = torch.cat([x, skip], dim=1)
            x = fuse(x)

        return self.sigm(self.conv_last(x))

# **Training**

In [12]:
def evaluate_accuracy(net, testIter, lossFunc, classification=True, device=None):
    """Average a loss function's accuracy and loss over an evaluation iterator.

    Args:
        net: model to evaluate; switched to eval mode. If it returns a tuple,
            only the first element is scored.
        testIter: sized iterable of ``(X, y1, y2)`` batches.
        lossFunc: callable returning ``(loss, accuracy)`` for ``(prediction, target)``.
        classification: score against ``y2`` when True, against ``y1`` otherwise.
        device: device the batch tensors are moved to.

    Returns:
        ``(mean_accuracy, mean_loss)`` averaged over the number of batches.
    """
    net.eval()

    totalAccuracy = 0.0
    totalLoss = 0.0

    with torch.no_grad():
        for X, y1, y2 in testIter:
            X = X.to(device)
            y1 = y1.to(device)
            y2 = y2.to(device)

            yhat = net(X)

            if isinstance(yhat, tuple):
                yhat = yhat[0]

            loss, accuracy = lossFunc(yhat, y2 if classification else y1)

            #Predictions are deliberately not stored: this function returns no
            #mask, and keeping every batch as Python lists only cost memory
            totalAccuracy += float(accuracy)
            totalLoss += float(loss)

    numBatches = len(testIter)
    return totalAccuracy / numBatches, totalLoss / numBatches

In [13]:
def final_eval_get_mask(net, testIter, lossFunc1, lossFunc2, device=None):
    """Evaluate both heads of a two-output model and collect rounded masks.

    Args:
        net: model returning ``(segmentation, classification)``; switched to eval mode.
        testIter: sized iterable of ``(X, y1, y2)`` batches.
        lossFunc1: callable returning ``(loss, accuracy)`` for the segmentation
            output against ``y1``.
        lossFunc2: callable returning ``(loss, accuracy)`` for the classification
            output against ``y2``.
        device: device the batch tensors are moved to.

    Returns:
        ``(segmentAcc, segmentLoss, classAcc, classLoss, segmentationMask)``.
        The four metrics are averaged over the number of batches.
        ``segmentationMask`` holds one nested list per sample, in iteration
        order, with a leading channel dimension of size 1 removed.
    """
    net.eval()

    segmentAccTotal = 0.0
    segmentLossTotal = 0.0
    classAccTotal = 0.0
    classLossTotal = 0.0

    segmentationMask = []

    with torch.no_grad():
        for X, y1, y2 in testIter:
            X = X.to(device)
            y1 = y1.to(device)
            y2 = y2.to(device)

            segment, classPred = net(X)

            loss1, acc1 = lossFunc1(segment, y1)
            loss2, acc2 = lossFunc2(classPred, y2)

            segmentAccTotal += float(acc1)
            segmentLossTotal += float(loss1)
            classAccTotal += float(acc2)
            classLossTotal += float(loss2)

            #One entry per sample. The previous squeeze(0).squeeze(0) only removed
            #the batch and channel dims when the batch size was 1, so larger
            #batches were stored as a single entry per batch.
            for sampleMask in torch.round(segment):
                segmentationMask.append(sampleMask.squeeze(0).tolist())

    numBatches = len(testIter)
    return (segmentAccTotal / numBatches, segmentLossTotal / numBatches,
            classAccTotal / numBatches, classLossTotal / numBatches, segmentationMask)

In [14]:
def joint_eval(net, testIter, classLossFunc, segmentLossFunc, device=None):
    """Average classification and segmentation metrics of a two-output model.

    Args:
        net: model whose output is indexed as ``[0]`` for segmentation and
            ``[1]`` for classification; switched to eval mode.
        testIter: sized iterable of ``(X, y1, y2)`` batches.
        classLossFunc: callable returning ``(loss, accuracy)`` for ``(output[1], y2)``.
        segmentLossFunc: callable returning ``(loss, accuracy)`` for ``(output[0], y1)``.
        device: device the batch tensors are moved to.

    Returns:
        ``(classAccuracy, classLoss, segmentAccuracy, segmentLoss)``, each
        averaged over the number of batches.
    """
    net.eval()

    #Class accuracy, class loss, segment accuracy, segment loss
    totals = [0.0, 0.0, 0.0, 0.0]

    with torch.no_grad():
        for X, y1, y2 in testIter:
            X = X.to(device)
            y1 = y1.to(device)
            y2 = y2.to(device)

            yhat = net(X)

            classLoss, classAccuracy = classLossFunc(yhat[1], y2)
            segmentLoss, segmentAccuracy = segmentLossFunc(yhat[0], y1)

            batchValues = (classAccuracy, classLoss, segmentAccuracy, segmentLoss)
            totals = [total + float(value) for total, value in zip(totals, batchValues)]

    numBatches = len(testIter)
    return totals[0] / numBatches, totals[1] / numBatches, totals[2] / numBatches, totals[3] / numBatches

In [15]:
def train(net: nn.Module, trainIter, testIter, numEpochs, startEpoch, learnRate, batchSize, device: torch.device, startDim, epochsToDouble, modelFileName, epochsToSave, 
          useWandB=False, cosineAnnealing=True, restartEpochs=-1, progressive=False, lossFunc = nn.BCEWithLogitsLoss(), classification=True):
    """Train a single-objective model with Adam, checkpointing and optional LR scheduling.

    Args:
        net: model to train. If it returns a tuple, only its first element is scored.
        trainIter: sized iterable of ``(X, y1, y2)`` training batches.
        testIter: validation iterable passed to ``evaluate_accuracy``.
        numEpochs: exclusive upper bound of the epoch range; also the cosine
            annealing period when warm restarts are disabled.
        startEpoch: first epoch index (non-zero when resuming).
        learnRate: Adam learning rate.
        batchSize: unused here; kept for call compatibility.
        device: device for the model and batches.
        startDim: initial image size when ``progressive`` is enabled.
        epochsToDouble: number of epochs between progressive size changes.
        modelFileName: checkpoint path prefix; ``"Epoch<n>"`` and ``"BestLoss"``
            are appended to it.
        epochsToSave: checkpoint interval in epochs.
        useWandB: log per-epoch metrics with ``wandb.log`` when True.
        cosineAnnealing: enable cosine learning-rate scheduling.
        restartEpochs: warm-restart period; a negative value disables restarts.
        progressive: 0/False off, 1 doubles the image size every
            ``epochsToDouble`` epochs, 2 halves it.
        lossFunc: callable returning ``(loss, accuracy)``. NOTE(review): the
            default ``nn.BCEWithLogitsLoss()`` returns a single tensor, which
            does not fit the two-value unpacking below -- always pass one of
            the notebook's loss functions.
        classification: score against ``y2`` when True, against ``y1`` otherwise.
    """
    print(f"Training on {device}")
    
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learnRate)

    #Setting restartEpochs to a negative will use no warm restarts, otherwise will use warm restarts 
    if cosineAnnealing:
        if restartEpochs < 0:
            #Uses the numEpochs argument; this previously read the notebook-global `epochs`
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=numEpochs)
        else:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, restartEpochs, T_mult=1)

    numBatches = len(trainIter)
    bestValLoss = float('inf')

    currDim = startDim
    for epoch in range(startEpoch, numEpochs):
        net.train()
        
        #Loss, accuracy
        metric = d2l.Accumulator(2)

        for i, (X, y1, y2) in enumerate(trainIter):
            optimizer.zero_grad()
            y1 = y1.to(device)
            y2 = y2.to(device)

            if progressive > 0:
                #If using progressive learning, downsamples image to the current dimension
                X = F.interpolate(X, size=int(currDim))

            X = X.to(device)
            
            yhat = net(X)

            if isinstance(yhat, tuple):
                yhat = yhat[0]

            if classification:
                l, accuracy = lossFunc(yhat, y2)
            else:
                l, accuracy = lossFunc(yhat, y1)

            l.backward()
            optimizer.step()

            if cosineAnnealing:
                #Fractional epoch so the learning rate changes every batch
                scheduler.step(epoch + i / numBatches)

            metric.add(l, accuracy)

        #Progressive learning
        if (epoch + 1) % epochsToDouble == 0 and progressive == 1:
            currDim *= 2
        #Reverse progressive learning
        elif (epoch + 1) % epochsToDouble == 0 and progressive == 2:
            currDim /= 2

        #Checkpoints model
        if (epoch + 1) % epochsToSave == 0:
            torch.save(net.state_dict(), modelFileName + "Epoch" + str(epoch))

        validationAcc, validationLoss = evaluate_accuracy(net, testIter, lossFunc, classification=classification, device=device)

        #Overwrites previous best model based on validation loss
        if validationLoss < bestValLoss:
            bestValLoss = validationLoss
            torch.save(net.state_dict(), modelFileName + "BestLoss")

        print(f"Epoch {epoch}:\nTrain Acc: {metric[1] / numBatches} Validation Acc: {validationAcc} Train Loss: {metric[0] / numBatches} Validation Loss: {validationLoss}")

        #Externally logs epoch info to WandB
        if useWandB:
            wandb.log({"Train Acc": metric[1] / numBatches,
                    "Validation Acc": validationAcc,
                    "Train Loss": metric[0] / numBatches,
                    "Validation Loss": validationLoss
                    })

In [16]:
def joint_train(net: nn.Module, trainIter, testIter, numEpochs, startEpoch, learnRate, batchSize, device: torch.device, modelFileName, epochsToSave, 
          useWandB=False, cosineAnnealing=True, restartEpochs=-1, classLossFunc = None, segmentLossFunc = None, weights = (0.5, 0.5)):
    """Jointly train the segmentation and classification heads of a two-output model.

    The optimized loss is ``weights[0] * segmentLoss + weights[1] * classLoss``.
    Validation and best-model selection use the segmentation loss only.

    Args:
        net: model whose output is indexed as ``[0]`` for segmentation and
            ``[1]`` for classification.
        trainIter: sized iterable of ``(X, y1, y2)`` training batches.
        testIter: validation iterable passed to ``evaluate_accuracy``.
        numEpochs: exclusive upper bound of the epoch range; also the cosine
            annealing period when warm restarts are disabled.
        startEpoch: first epoch index (non-zero when resuming).
        learnRate: Adam learning rate.
        batchSize: unused here; kept for call compatibility.
        device: device for the model and batches.
        modelFileName: checkpoint path prefix; ``"Epoch<n>"`` and ``"BestLoss"``
            are appended to it.
        epochsToSave: checkpoint interval in epochs.
        useWandB: log per-epoch metrics with ``wandb.log`` when True.
        cosineAnnealing: enable cosine learning-rate scheduling.
        restartEpochs: warm-restart period; a negative value disables restarts.
        classLossFunc: callable returning ``(loss, accuracy)`` for ``(output[1], y2)``.
        segmentLossFunc: callable returning ``(loss, accuracy)`` for ``(output[0], y1)``.
        weights: pair ``(segmentation weight, classification weight)``.
    """
    print(f"Training on {device}")
    
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learnRate)

    #Setting restartEpochs to a negative will use no warm restarts, otherwise will use warm restarts 
    if cosineAnnealing:
        if restartEpochs < 0:
            #Uses the numEpochs argument; this previously read the notebook-global `epochs`
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=numEpochs)
        else:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, restartEpochs, T_mult=1)

    numBatches = len(trainIter)
    bestValLoss = float('inf')

    #The progressive-resizing bookkeeping copied from train() was removed: it
    #read notebook globals (startDim, epochsToDouble, progressive) that are not
    #parameters of this function, and its result was never used here.
    for epoch in range(startEpoch, numEpochs):
        net.train()
        
        #Loss, accuracy
        metric = d2l.Accumulator(2)

        for i, (X, y1, y2) in enumerate(trainIter):
            optimizer.zero_grad()
            y1 = y1.to(device)
            y2 = y2.to(device)
            X = X.to(device)
            
            yhat = net(X)

            segmentLoss, segmentAcc = segmentLossFunc(yhat[0], y1)
            classLoss, classAcc = classLossFunc(yhat[1], y2)

            #Weighted sum of the two objectives
            l = (weights[0] * segmentLoss) + (weights[1] * classLoss)

            l.backward()
            optimizer.step()

            if cosineAnnealing:
                #Fractional epoch so the learning rate changes every batch
                scheduler.step(epoch + i / numBatches)

            #Tracks the combined loss and the segmentation accuracy only
            metric.add(l, segmentAcc)

        #Checkpoints model
        if (epoch + 1) % epochsToSave == 0:
            torch.save(net.state_dict(), modelFileName + "Epoch" + str(epoch))

        validationAcc, validationLoss = evaluate_accuracy(net, testIter, segmentLossFunc, classification=False, device=device)

        #Overwrites previous best model based on validation loss
        if validationLoss < bestValLoss:
            bestValLoss = validationLoss
            torch.save(net.state_dict(), modelFileName + "BestLoss")

        print(f"Epoch {epoch}:\nTrain Acc: {metric[1] / numBatches} Validation Acc: {validationAcc} Train Loss: {metric[0] / numBatches} Validation Loss: {validationLoss}")

        #Externally logs epoch info to WandB
        if useWandB:
            wandb.log({"Train Acc": metric[1] / numBatches,
                    "Validation Acc": validationAcc,
                    "Train Loss": metric[0] / numBatches,
                    "Validation Loss": validationLoss
                    })

# **Setup**

In [17]:
#Hyperparameters and training modifications
#Everything in this cell is read as a notebook global by the cells below
#Prefers the GPU when one is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#device = torch.device("cpu")

#Checkpoint path prefix; train()/joint_train() append "Epoch<n>" or "BestLoss" to it
modelName = "StandardUNet"
fileSaveName = "/content/drive/MyDrive/" + modelName

#Use if starting from a checkpoint
startEpoch = 0

#When True, training logs per-epoch metrics to Weights & Biases
useWandB = False

batchSize = 6
learnRate = 0.008
epochs = 100
dropout = 0.3

#Progressive training parameters
startDim = 32
epochsToDouble = 25

#0: not progressive, 1: progressive, 2: reverse progressive
progressive = 0

#Checkpointing
#Interval, in epochs, between periodic checkpoints
epochsToSave = 10

#Learn rate scheduling parameters
cosineAnnealing = True
#Passed as restartEpochs; a negative value disables warm restarts in train()
cosineRestartEpochs = 10

#Alternative loss functions kept for reference
#lossFunc = BalancedCELoss(weight0=1, weight1=1.5)
#lossFunc = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)
#lossFunc = nn.BCEWithLogitsLoss()
#NOTE(review): WeightedDiceLoss is not defined in the visible part of this notebook
#lossFunc = WeightedDiceLoss(weight0=0.2, weight1=0.8)

In [18]:
#Load Datasets
#NOTE(review): these paths are relative ("drive/MyDrive/...") while the checkpoint
#paths are absolute ("/content/drive/..."); they presumably only resolve when the
#working directory is /content -- confirm
trainDataset = LITSBinaryDataset("drive/MyDrive/MachineLearningResearch/Datasets/FullTrainDataset.hdf5")
validationDataset = LITSBinaryDataset("drive/MyDrive/MachineLearningResearch/Datasets/ValidationDataset.hdf5")
testDataset = LITSBinaryDataset("drive/MyDrive/MachineLearningResearch/Datasets/TestDataset.hdf5")

#Only the training loader shuffles; validation and test keep file order
trainIter = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)
validationIter = DataLoader(validationDataset, batch_size=batchSize)
testIter = DataLoader(testDataset, batch_size=batchSize)

print("Dataset loaded")

Dataset loaded


# **Standard Training**

In [None]:
#Load model
#Set to a checkpoint path to resume from saved weights; empty string starts fresh
jointTrainingFileName = ""

#lossFunc = WeightedDiceLoss(weight0=0.2, weight1=0.8)
#Segmentation-only training uses the Dice loss
lossFunc = dice_loss

#Earlier custom U-Net variant, replaced by MultiMix below
#encoder = EncoderNetwork(1, dropout, device).to(device)
#decoder = DecoderNetwork(1, dropout, device).to(device)
#segmenter = SegmentationNetwork(encoder, decoder, dropout, device)
segmenter = MultiMix()
#print(summary(net, (1, 256, 256)))

#Loads model from file if using a pretrained version
if jointTrainingFileName != "":
    segmenter.load_state_dict(torch.load(jointTrainingFileName))

segmenter = segmenter.to(device)

print("Intialized joint training model")

Intialized joint training model


In [None]:
#Train Model

#The wandb.init call below is disabled by wrapping it in a string literal.
#NOTE(review): with useWandB=True, train() would still call wandb.log without
#a run having been initialised by this cell.
"""
if useWandB:
    wandb.init(project="LiverSegmentation",
            name=modelName,
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })
"""

#Segmentation-only run: classification=False scores the model's first output against the masks
train(segmenter, trainIter, validationIter, epochs, startEpoch, learnRate, batchSize, device, startDim, epochsToDouble, fileSaveName, epochsToSave, useWandB=useWandB, 
      cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0, lossFunc=lossFunc, classification=False)

if useWandB:
    wandb.finish()

# **Pre-Training**

In [None]:
#Classification pre-training of the encoder uses the focal loss
lossFunc = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

#Set encoderFileName to resume from saved encoder weights; empty string starts fresh
encoderFileName = ""
#Checkpoint path prefix for the encoder; train() appends "Epoch<n>" or "BestLoss"
encoderFileSaveName = "/content/drive/MyDrive/MachineLearningResearch/ProgressiveEncoders/ProgressiveEncoder4"

#encoder = EncoderNetwork(1, dropout, device).to(device)
encoder = Encoder(1)

if encoderFileName != "":
    encoder.load_state_dict(torch.load(encoderFileName))

encoder = encoder.to(device)

In [None]:
#Train Model
gc.collect()

if useWandB:
    wandb.init(project="Pre-TrainedEncoder",
            name="ReverseProgressiveUNetEncoder",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble,
                "Dropout":dropout,
            })

#Encoder-only run: classification=True scores the encoder's first output against the image labels
train(encoder, trainIter, validationIter, epochs, startEpoch, learnRate, batchSize, device, startDim, epochsToDouble, encoderFileSaveName, epochsToSave, useWandB=useWandB, 
      cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive, lossFunc=lossFunc, classification=True)

if useWandB:
    wandb.finish()

In [None]:
#Switch back to the Dice loss for segmentation training
lossFunc = dice_loss

#NOTE(review): train() only writes encoderFileSaveName + "Epoch<n>" and
#encoderFileSaveName + "BestLoss"; the bare path loaded here is never written by
#this notebook, so this presumably should load the "BestLoss" file -- confirm
encoder.load_state_dict(torch.load(encoderFileSaveName))
#decoder = DecoderNetwork(1, dropout, device).to(device)
#segmenter = SegmentationNetwork(encoder, decoder, dropout, device)

#Reuses the pre-trained encoder inside a new MultiMix model
segmenter = MultiMix(encoder=encoder)

segmenter = segmenter.to(device)

In [None]:
#Train Model
gc.collect()

if useWandB:
    wandb.init(project="LiverSegmentationJointTraining",
            name="NoWeights",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })

#Segmentation training of the MultiMix model built on the pre-trained encoder.
#The encoder is not frozen here (freezeEncoder() is not called).
train(segmenter, trainIter, validationIter, epochs, startEpoch, learnRate, batchSize, device, startDim, epochsToDouble, fileSaveName, epochsToSave, useWandB=useWandB, 
      cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=0, lossFunc=lossFunc, classification=False)

if useWandB:
    wandb.finish()

# **Joint Training**

In [30]:
#Starts joint training from a freshly initialised model (no pre-trained encoder)
segmenter = MultiMix()

#The wandb.init call below is disabled by wrapping it in a string literal
"""
if useWandB:
    wandb.init(project="LiverSegmentationJointTraining",
            name="Weights:0.5,0.5#2",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })
"""

#Focal loss for the classification head; Dice loss is passed for segmentation below
classLossFunc = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

#Equal weighting of the segmentation and classification losses
joint_train(segmenter, trainIter, validationIter, epochs, startEpoch, learnRate, batchSize, device, fileSaveName, epochsToSave, useWandB=useWandB, 
      cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, segmentLossFunc=dice_loss, classLossFunc=classLossFunc, weights=[0.5, 0.5])

if useWandB:
    wandb.finish()

Training on cuda
Epoch 0:
Train Acc: 0.2664094201874649 Validation Acc: 0.14966454745066585 Train Loss: 0.41529309519117685 Validation Loss: 0.8589106948673725
Epoch 1:
Train Acc: 0.3604078584535109 Validation Acc: 0.1618151319056051 Train Loss: 0.3464082794920575 Validation Loss: 0.8398429675400257
Epoch 2:
Train Acc: 0.39480577577944015 Validation Acc: 0.2058338455390185 Train Loss: 0.3172740226062082 Validation Loss: 0.7970126536488533
Epoch 3:
Train Acc: 0.4264376039031337 Validation Acc: 0.510195766503457 Train Loss: 0.30079778807787155 Validation Loss: 0.7401450663805008
Epoch 4:
Train Acc: 0.691575181342179 Validation Acc: 0.933587880237028 Train Loss: 0.2857940602216789 Validation Loss: 0.4253993382304907
Epoch 5:
Train Acc: 0.8092614019678223 Validation Acc: 0.9005903452634811 Train Loss: 0.2730942838724783 Validation Loss: 0.16638283707201482
Epoch 6:
Train Acc: 0.8959559991205339 Validation Acc: 0.9446448827814311 Train Loss: 0.1664362837888783 Validation Loss: 0.06643860673

KeyboardInterrupt: ignored

# **Evaluation/Ending**

In [33]:
#Evaluates a saved model on the test set with both loss functions
modelName = "Pre-Trained then Joint Trained UNet"
classification = False
fileSaveName = "/content/drive/MyDrive/StandardUNetBestLossBestLossBestLoss"
segmentLossFunc = dice_loss
#Was misspelled `classossFunc`, so joint_eval below only worked when
#classLossFunc had leaked from the joint-training cell
classLossFunc = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

if classification:
    #NOTE(review): joint_eval indexes the model output as [0]=segmentation and
    #[1]=classification, which does not match the Encoder's output order --
    #confirm before enabling this branch
    net = encoder
    net.load_state_dict(torch.load(encoderFileSaveName + "BestLoss"))
else:
    net = segmenter
    net.load_state_dict(torch.load(fileSaveName))

#Evaluate Model
print(f"Model: {modelName}")

classAcc, classLoss, segmentAcc, segmentLoss = joint_eval(net, testIter, classLossFunc, segmentLossFunc, device=device)
print(f"Classification Accuracy: {classAcc} Classification Loss: {classLoss}")
print(f"Segmentation Dice: {segmentAcc} Segmentation Loss: {segmentLoss}")

Model: Pre-Trained then Joint Trained UNet
Classification Accuracy: 0.861111111111111 Classification Loss: 0.05684027758913792
Segmentation Dice: 0.8773960446825103 Segmentation Loss: 0.12693572284964225


In [None]:
#Evaluates a saved model on a single scan and writes its predicted masks to HDF5
fileSaveName = "drive/MyDrive/MachineLearningResearch/Standard UNets/StandardUNet3"
dataset = LITSBinaryDataset("drive/MyDrive/MachineLearningResearch/Datasets/Scan1Dataset.hdf5")
#Named scanIter (was `iter`) so the builtin iter() is not shadowed in the notebook
scanIter = DataLoader(dataset, batch_size=batchSize)
lossFunc1 = dice_loss
lossFunc2 = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

net = MultiMix()
net.load_state_dict(torch.load(fileSaveName))

net.to(device)

#Calls final_eval_get_mask; the previous name `final_eval` is not defined anywhere
segmentAcc, segmentLoss, classAcc, classLoss, segmentationMask = final_eval_get_mask(net, scanIter, lossFunc1, lossFunc2, device=device)
print(f"Test Accuracy: {segmentAcc} Test Loss: {segmentLoss}")
print(f"Test Accuracy: {classAcc} Test Loss: {classLoss}")
print(len(segmentationMask[0][0]))


fileName = "drive/MyDrive/Scan1SegmentationMap1.hdf5"

#Context manager closes the output file even if a write fails
with h5py.File(fileName, "w") as wFile:
    for i, maskSlice in enumerate(segmentationMask):
        wFile.create_dataset("Slice" + str(i), data=maskSlice)

In [None]:
#Close datasets
#Releases the HDF5 file handles held by each dataset; the dataset objects and
#the DataLoaders built on them must not be used after this cell runs
trainDataset.closeFile()
validationDataset.closeFile()
testDataset.closeFile()