# **Imports/Installs**

In [None]:
#Colab environment bootstrap: installs condacolab, then torch, torchvision and d2l with pip
!pip install -q condacolab
import condacolab
#The recorded output of this cell shows the kernel being restarted after this call
condacolab.install()
#NOTE(review): torch and torchvision are installed unpinned, so the resolved versions depend on when the cell is run
!pip install torch
!pip install torchvision
!pip install d2l==1.0.0b0

⏬ Downloading https://github.com/jaimergp/miniforge/releases/latest/download/Mambaforge-colab-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:15
🔁 Restarting kernel...
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch
  Downloading torch-2.0.0-cp39-cp39-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.7.99
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.3/849.3 kB[0m [31m62.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jinja2
  Downloading Jinja2-3.1.2-py3-none-any.whl (133 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.1/133.1 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollec

In [None]:
#Installs/upgrades the Weights & Biases client and logs in; only needed when useWandB is enabled
!pip install wandb -qU
import wandb
wandb.login()

[0m

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mamoseley018[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
#Mounts Google Drive at /content/drive/ (Colab only); datasets are read from and checkpoints written to this mount
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
import os
import torch
import torchvision
from torchvision import transforms
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchsummary import summary
from d2l import torch as d2l
from pathlib import Path
import sys
import nibabel as nib
import numpy as np
import gc
import skimage
import h5py
import math
import PIL
from PIL import Image

# **Loss Functions**

In [None]:
class WeightedDiceLoss(nn.Module):
    """Class-weighted Dice loss for binary segmentation maps.

    For every sample a Dice score is computed for the background class (0) and
    the foreground class (1); the scores are averaged over the batch and the
    loss is ``1 - (weight0 * dice0 + weight1 * dice1)``.

    The Dice scores are computed from the predicted probabilities ("soft"
    Dice) so the loss is differentiable. For inputs that are exactly 0 or 1
    this is identical to counting true/false positives on rounded predictions.

    Args:
        weight0: Weight of the class-0 Dice score.
        weight1: Weight of the class-1 Dice score.
        eps: Lower bound for the Dice denominator; a class that is absent from
            both prediction and target contributes a score of 0.

    ``forward`` returns ``(loss, accuracy)`` where ``loss`` is a 0-d tensor and
    ``accuracy`` is the fraction of pixels whose rounded prediction equals the
    target (a Python float).

    Assumes ``target`` only contains the values 0 and 1 - TODO confirm against
    the dataset.
    """

    def __init__(self, weight0=1, weight1=1, eps=1e-8):
        super().__init__()

        self.weight0 = weight0
        self.weight1 = weight1
        self.eps = eps

    def _dice(self, pred, truth):
        #Per-sample Dice score for flattened (batch, pixels) tensors
        intersection = (pred * truth).sum(dim=1)
        denominator = pred.sum(dim=1) + truth.sum(dim=1)

        #When the denominator is 0 the intersection is 0 too, so the score is 0
        return (2 * intersection) / denominator.clamp_min(self.eps)

    def forward(self, input, target):
        probs = input

        #Accepts (batch, 1, H, W) network output as well as (batch, H, W)
        if probs.dim() == 4 and probs.size(1) == 1:
            probs = probs.squeeze(1)
        if target.dim() == 4 and target.size(1) == 1:
            target = target.squeeze(1)

        #Sample i of the input is compared with sample i of the target;
        #target samples beyond the input's batch size are ignored
        target = target[: probs.size(0)].to(device=probs.device, dtype=probs.dtype)

        #If only the images were resized (progressive training), brings the
        #target to the prediction's resolution without inventing new label values
        if target.dim() == 3 and tuple(target.shape[-2:]) != tuple(probs.shape[-2:]):
            target = F.interpolate(target.unsqueeze(1), size=tuple(probs.shape[-2:]), mode="nearest").squeeze(1)

        batch = probs.size(0)
        flatProbs = probs.reshape(batch, -1)
        flatTarget = target.reshape(batch, -1)

        dice1 = self._dice(flatProbs, flatTarget)
        dice0 = self._dice(1 - flatProbs, 1 - flatTarget)

        loss = 1 - ((self.weight0 * dice0.mean()) + (self.weight1 * dice1.mean()))

        #Accuracy uses hard (rounded) predictions and never needs gradients
        with torch.no_grad():
            accuracy = (torch.round(flatProbs) == flatTarget).float().mean().item()

        return loss, accuracy

In [None]:
class FocalLoss(nn.Module):
    """Class-weighted focal loss for binary classification probabilities.

    Per sample, with ``p`` the predicted probability of the correct label:
    ``loss = -ln(p) * (1 - p) ** gamma * class_weight``; the result is averaged
    over the batch. With ``gamma=0`` this reduces to weighted cross entropy.

    Args:
        weight0: Weight applied to samples whose label is not 1.
        weight1: Weight applied to samples whose label is 1.
        gamma: Focusing exponent.
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` so a saturated
            sigmoid cannot produce ``log(0)`` (infinite loss / NaN gradients).

    ``forward`` expects ``input`` to hold one probability per sample (shape
    ``(batch,)`` or ``(batch, 1)``) and returns ``(loss, accuracy)``; accuracy
    is the fraction of rounded predictions equal to the label (Python float).
    """

    def __init__(self, weight0=1, weight1=1, gamma=0, eps=1e-7):
        super().__init__()

        self.weight0 = weight0
        self.weight1 = weight1
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target):
        #Aligns labels of shape (batch,) with predictions of shape (batch, 1)
        target = target.to(device=input.device).reshape(input.shape)

        probs = input.clamp(self.eps, 1 - self.eps)
        isPositive = target == 1

        #Predicted likelihood that the correct label is true
        probCorrect = torch.where(isPositive, probs, 1 - probs)
        classWeight = torch.where(isPositive,
                                  torch.full_like(probs, self.weight1),
                                  torch.full_like(probs, self.weight0))

        terms = torch.log(probCorrect) * ((1 - probCorrect) ** self.gamma) * classWeight

        #Summing over the batch dimension only keeps the per-sample shape of the result
        loss = -1 * terms.sum(dim=0) / len(input)

        with torch.no_grad():
            accuratePreds = (torch.round(input) == target).sum().item()

        return loss, accuratePreds / input.size()[0]

In [None]:
class BalancedCELoss(nn.Module):
    """Class-weighted binary cross entropy on predicted probabilities.

    Per sample the loss is ``-ln(p) * class_weight`` where ``p`` is the
    predicted probability of the correct label; the result is averaged over
    the batch.

    Args:
        weight0: Weight applied to samples whose label is not 1.
        weight1: Weight applied to samples whose label is 1.
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` so a saturated
            sigmoid cannot produce ``log(0)`` (infinite loss / NaN gradients).

    ``forward`` expects ``input`` to hold one probability per sample (shape
    ``(batch,)`` or ``(batch, 1)``) and returns ``(loss, accuracy)``; accuracy
    is the fraction of rounded predictions equal to the label (Python float).
    """

    def __init__(self, weight0=1, weight1=1, eps=1e-7):
        super().__init__()

        self.weight0 = weight0
        self.weight1 = weight1
        self.eps = eps

    def forward(self, input, target):
        #Aligns labels of shape (batch,) with predictions of shape (batch, 1)
        target = target.to(device=input.device).reshape(input.shape)

        probs = input.clamp(self.eps, 1 - self.eps)

        #Positive samples are scored on p, all others on (1 - p)
        terms = torch.where(target == 1,
                            torch.log(probs) * self.weight1,
                            torch.log(1 - probs) * self.weight0)

        #Summing over the batch dimension only keeps the per-sample shape of the result
        loss = -1 * terms.sum(dim=0) / len(input)

        with torch.no_grad():
            accuratePreds = (torch.round(input) == target).sum().item()

        return loss, accuratePreds / input.size()[0]

# **Data Handling**

In [None]:
class LITSBinaryDataset(Dataset):
    """Dataset backed by an HDF5 file with one group per slice.

    Sample ``idx`` is read from the group ``"Slice<idx>"``, which must contain
    the datasets ``"Slice"`` and ``"Segmentation"`` and the attribute
    ``"ImageLabel"``. The number of samples is the number of top-level keys in
    the file.

    The file handle stays open for the lifetime of the instance; call
    ``closeFile`` when the dataset is no longer needed.
    """

    def __init__(self, fileName):
        super().__init__()

        #Keeps a file pointer open throughout use
        self.file = h5py.File(fileName, 'r')

        #Precalculates length to reduce training computations
        self.length = len(self.file.keys())

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Returns ``[image, segmentation, label]`` as float32 tensors.

        ``image`` gets a leading channel dimension, ``segmentation`` keeps the
        stored shape and a one-element ``label`` is reduced to a 0-d tensor.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``; this is what
                Python's sequence iteration protocol expects at the end.
        """
        if not 0 <= idx < self.length:
            raise IndexError(f"index {idx} is out of range for a dataset of {self.length} slices")

        #Looks the group up once instead of once per field
        group = self.file["Slice" + str(idx)]

        image = torch.as_tensor(group["Slice"][...], dtype=torch.float32).unsqueeze(0)
        segmentation = torch.as_tensor(group["Segmentation"][...], dtype=torch.float32)

        #as_tensor handles both array-valued and scalar-valued attributes
        label = torch.as_tensor(group.attrs.get("ImageLabel"), dtype=torch.float32)
        if label.dim() > 0:
            label = label.squeeze(0)

        return [image, segmentation, label]

    def closeFile(self):
        #Closes file once dataset is no longer being used
        #Do not use class instance after this function is called
        self.file.close()

# **Network**

In [None]:
class convBlock(nn.Module):
    """Two 3x3 convolutions, each followed by optional batch norm and ReLU, then dropout.

    Args:
        inChannels: Channels of the input feature map.
        outChannels: Channels produced by both convolutions.
        batchNorm: If truthy, a BatchNorm2d layer is applied after each convolution.
        strides: Stride of the first convolution (the second always uses stride 1).
        layerMean: Mean of the normal distribution used to initialise the conv weights.
        layerDev: Standard deviation of that distribution.
        dropout: Dropout probability applied to the block output.
    """

    def __init__(self, inChannels, outChannels, batchNorm, strides, layerMean, layerDev, dropout) -> None:
        super().__init__()

        #Uses 2 convolutional layers for each block
        self.conv1 = nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1)

        #Initializes convolutional layers using hyperparameters for mean and standard deviation
        nn.init.normal_(self.conv1.weight, mean=layerMean, std=layerDev)
        nn.init.normal_(self.conv2.weight, mean=layerMean, std=layerDev)

        #Dropout has no parameters, so it needs no device placement; the block
        #follows its parent when the parent is moved with .to(device)
        self.dropout = nn.Dropout(dropout)

        #NOTE(review): one BatchNorm2d instance is shared by both convolutions,
        #so both share affine parameters and running statistics; kept as is so
        #existing checkpoints keep their parameter names - confirm this is intended
        if batchNorm:
            self.bn1 = nn.BatchNorm2d(outChannels)
        else:
            self.bn1 = None

    def forward(self, X):
        Y = self.conv1(X)
        if self.bn1 is not None:
            Y = self.bn1(Y)
        Y = F.relu(Y)

        Y = self.conv2(Y)
        if self.bn1 is not None:
            Y = self.bn1(Y)
        Y = F.relu(Y)

        return self.dropout(Y)

In [None]:
class EncoderBlock(nn.Module):
    """One encoder stage: a batch-normalised convBlock followed by 2x2 max pooling.

    ``forward`` returns a pair: the pooled feature map (input of the next
    stage) and the un-pooled feature map (used as a skip connection).
    """

    def __init__(self, inChannels, outChannels, strides, dropout) -> None:
        super().__init__()

        #Double convolution with batch norm; weights are drawn from a normal
        #distribution with mean 0 and standard deviation 0.025
        self.conv = convBlock(inChannels, outChannels, batchNorm=True, strides=strides,
                              layerMean=0, layerDev=0.025, dropout=dropout)

        #Stride defaults to the kernel size, so each spatial dimension is halved
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, X):
        features = self.conv(X)
        pooled = self.pool(features)

        #Max-pooled features feed the next encoder block, the un-pooled ones a skip connection
        return pooled, features

In [None]:
class DecoderBlock(nn.Module):
    """One decoder stage: 2x transposed-conv upsampling, skip concatenation, convBlock.

    Args:
        inChannels: Channels of the incoming (lower-resolution) feature map.
        outChannels: Channels produced by the upsampling and by the block.
        strides: Stride of the first convolution of the convBlock.
        dropout: Dropout probability of the convBlock.
        skipChannels: Channels of the skip connection; defaults to
            ``outChannels``, which matches an encoder whose stage at the same
            resolution produces ``outChannels`` feature maps.
    """

    def __init__(self, inChannels, outChannels, strides, dropout, skipChannels=None) -> None:
        super().__init__()

        if skipChannels is None:
            skipChannels = outChannels

        self.convTrans = nn.ConvTranspose2d(inChannels, outChannels, 2, stride=2, padding=0)

        #The convolution sees the upsampled features and the skip connection stacked along channels
        #NOTE(review): std 0.25 is ten times the 0.025 used by the encoder blocks - confirm this is intended
        self.conv = convBlock(outChannels + skipChannels, outChannels, True, strides, 0, 0.25, dropout)

    def forward(self, X, skipConn):
        Y = self.convTrans(X)

        #Skip connections are joined on the channel axis (dim=1); joining on
        #dim=0 would stack them as extra samples and multiply the batch size
        Y = torch.cat((Y, skipConn), dim=1)

        return self.conv(Y)

In [None]:
class DecoderNetwork(nn.Module):
    """Chain of DecoderBlocks that consumes encoder skip connections deepest-first.

    Args:
        channels: Channel sizes; block ``i`` maps ``channels[i]`` to ``channels[i + 1]``.
        strides: Accepted for interface compatibility; every block uses stride 1.
        dropout: Dropout probability of every block.
        device: Device the blocks are created on.

    The blocks are held in an ``nn.ModuleList`` so that they are part of
    ``parameters()`` (and therefore optimised), ``state_dict()`` (and therefore
    checkpointed) and follow ``train()`` / ``eval()`` / ``to()``. Checkpoints
    written while the blocks were kept in a plain list do not contain their
    weights.
    """

    def __init__(self, channels, strides, dropout, device) -> None:
        super().__init__()

        self.device = device
        self.blocks = nn.ModuleList(
            [DecoderBlock(channels[i], channels[i + 1], 1, dropout) for i in range(len(channels) - 1)]
        ).to(device)

    def forward(self, X, skipConnections):
        """Runs the blocks in order, pairing block ``i`` with ``skipConnections[-(i + 1)]``.

        Raises:
            ValueError: if the number of skip connections differs from the number of blocks.
        """
        if len(skipConnections) != len(self.blocks):
            raise ValueError(
                f"expected {len(self.blocks)} skip connections (one per decoder block), got {len(skipConnections)}"
            )

        y = X
        #The last skip connection has the lowest resolution, so it is used first
        for block, skip in zip(self.blocks, reversed(skipConnections)):
            y = block(y, skip)

        return y

In [None]:
class EncoderNetwork(nn.Module):
    """Stack of EncoderBlocks with a binary classification head.

    Args:
        channels: Output channels of each encoder block; the first block takes a 1-channel input.
        strides: Accepted for interface compatibility; every block uses stride 1.
        dropout: Dropout probability of every block.
        device: Device the layers are created on and inputs are moved to.

    The blocks are held in an ``nn.ModuleList`` so that they are part of
    ``parameters()`` (and therefore optimised), ``state_dict()`` (and therefore
    checkpointed) and follow ``train()`` / ``eval()`` / ``to()``. Checkpoints
    written while the blocks were kept in a plain list only contain the
    classification head.
    """

    def __init__(self, channels, strides, dropout, device) -> None:
        super().__init__()

        self.device = device

        #Creates encoder blocks w/ in and out channels specified by parameter;
        #each block takes the previous block's output channels, the first one takes 1
        inChannels = [1] + list(channels[:-1])
        self.blocks = nn.ModuleList(
            [EncoderBlock(inDim, outDim, 1, dropout) for inDim, outDim in zip(inChannels, channels)]
        ).to(device)

        #Creates classification branch as sequential
        #Try without using sequential, use each layer separately
        #Can use without Flatten, average pool does the same thing
        #Follow MultiMix code
        self.classification = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(channels[-1], 1), nn.Sigmoid()).to(device)

    def forward(self, X):
        """Returns ``(class probability, skip connections, final pooled feature map)``."""
        y = X.to(self.device)

        skipConnections = []

        for block in self.blocks:
            y, skip = block(y)
            skipConnections.append(skip)

        return self.classification(y), skipConnections, y

In [None]:
class SegmentationNetwork(nn.Module):
    """Encoder, bottleneck convBlock, decoder and a 1x1 convolution with sigmoid output.

    The encoder is expected to return ``(classification, skip connections,
    feature map)``; only the last two are used here.
    """

    def __init__(self, encoder, decoder, middleBlockInDim, middleBlockOutDim, endDim, dropout, device) -> None:
        super().__init__()

        self.device = device
        self.encoder = encoder
        self.decoder = decoder

        #Bottleneck between encoder and decoder: batch-normalised, stride 1, weights ~ N(0, 0.25^2)
        self.middleBlock = convBlock(middleBlockInDim, middleBlockOutDim, batchNorm=True, strides=1,
                                     layerMean=0, layerDev=0.25, dropout=dropout)

        #1x1 convolution collapsing the decoder output to a single channel
        self.endBlock = nn.Conv2d(endDim, 1, kernel_size=1)

    def forward(self, X):
        _classProb, skips, bottleneck = self.encoder(X)

        decoded = self.decoder(self.middleBlock(bottleneck), skips)

        #Sigmoid maps the single output channel to per-pixel probabilities
        return torch.sigmoid(self.endBlock(decoded))

# **Training**

In [None]:
def evaluate_accuracy(net, testIter, lossFunc, classification=True, device=None):
    """Evaluates ``net`` on ``testIter`` and returns ``(mean accuracy, mean loss)``.

    Args:
        net: Model to evaluate; it is switched to eval mode. It may return a
            tensor or a tuple/list whose first element is the prediction.
        testIter: Iterable of ``(X, y1, y2)`` batches, where ``y1`` is the
            segmentation target and ``y2`` the classification label.
        lossFunc: Callable returning a ``(loss, accuracy)`` pair for ``(prediction, target)``.
        classification: If True predictions are scored against ``y2``, otherwise against ``y1``.
        device: Device to run on; defaults to the device of the model's first parameter.

    Both returned values are averages over batches (not over samples).

    Raises:
        ValueError: if the model returns a different number of predictions than the batch has targets.
    """
    net.eval()
    if not device:
        device = next(iter(net.parameters())).device

    #Running sums of per-batch accuracy and loss
    accuracySum = 0.0
    lossSum = 0.0
    numBatches = 0

    with torch.no_grad():
        for X, y1, y2 in testIter:
            X = X.to(device)
            target = (y2 if classification else y1).to(device)

            output = net(X)

            #Only tuple/list outputs carry the prediction in element 0;
            #indexing a plain tensor would silently select the first sample
            yhat = output[0] if isinstance(output, (tuple, list)) else output

            #Drops the singleton channel axis so masks line up with (batch, H, W) targets
            if not classification and yhat.dim() == 4 and yhat.size(1) == 1:
                yhat = yhat.squeeze(1)

            if yhat.size(0) != target.size(0):
                raise ValueError(f"model returned {yhat.size(0)} predictions for a batch of {target.size(0)} targets")

            loss, accuracy = lossFunc(yhat, target)

            accuracySum += float(accuracy)
            lossSum += float(loss)
            numBatches += 1

    return accuracySum / numBatches, lossSum / numBatches

In [None]:
def train(net: nn.Module, trainIter, testIter, numEpochs, startEpoch, learnRate, batchSize, device: torch.device, startDim, epochsToDouble, modelFileName, epochsToSave, 
          useWandB=False, cosineAnnealing=True, restartEpochs=-1, progressive=False, lossFunc = nn.BCEWithLogitsLoss(), classification=True):
    """Trains ``net`` with Adam and validates it after every epoch.

    Args:
        net: Model to train; may return a tensor or a tuple/list whose first element is the prediction.
        trainIter: Iterable (with ``len``) of ``(X, y1, y2)`` batches; ``y1`` is the
            segmentation target and ``y2`` the classification label.
        testIter: Validation batches in the same format.
        numEpochs: Epoch index to stop at (exclusive).
        startEpoch: Epoch index to start from.
        learnRate: Adam learning rate.
        batchSize: Not used by this function.
        device: Device to train on.
        startDim: Initial image side length for progressive resizing.
        epochsToDouble: Number of epochs between changes of the image size.
        modelFileName: Path prefix for checkpoints; "Epoch<n>" and "BestLoss" are appended.
        epochsToSave: Number of epochs between periodic checkpoints.
        useWandB: If True, logs epoch metrics with ``wandb.log``.
        cosineAnnealing: If True, uses a cosine learning-rate schedule stepped every batch.
        restartEpochs: Negative for plain cosine annealing over ``numEpochs``,
            otherwise the warm-restart period in epochs.
        progressive: 0/False: images are used as is, 1: image size doubles every
            ``epochsToDouble`` epochs, 2: image size halves instead.
        lossFunc: Callable returning a ``(loss, accuracy)`` pair.
            NOTE(review): the default ``nn.BCEWithLogitsLoss()`` returns a single
            tensor and does not satisfy this contract - pass one of the loss
            classes defined in this notebook.
        classification: If True trains against ``y2``, otherwise against ``y1``.

    Raises:
        ValueError: if the model returns a different number of predictions than the batch has targets.
    """
    import warnings

    print(f"Training on {device}")
    
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learnRate)

    #Setting restartEpochs to a negative will use no warm restarts, otherwise will use warm restarts 
    scheduler = None
    if cosineAnnealing:
        if restartEpochs < 0:
            #The annealing period is this run's epoch count, not a notebook-level global
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=numEpochs)
        else:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, restartEpochs, T_mult=1)

    numBatches = len(trainIter)
    bestValLoss = float('inf')
    warnedDetachedLoss = False

    currDim = startDim
    for epoch in range(startEpoch, numEpochs):
        net.train()
        
        #Running sums of per-batch loss and accuracy
        lossSum = 0.0
        accuracySum = 0.0

        for i, (X, y1, y2) in enumerate(trainIter):
            optimizer.zero_grad()
            target = y2 if classification else y1

            if progressive > 0:
                #If using progressive learning, resamples image to the current dimension
                size = max(1, int(currDim))
                X = F.interpolate(X, size=size)

                #Segmentation masks must follow the image size; nearest-neighbour keeps label values intact
                if not classification and target.dim() == 3:
                    target = F.interpolate(target.unsqueeze(1).float(), size=size, mode="nearest").squeeze(1)

            X = X.to(device)
            target = target.to(device)
            
            output = net(X)

            #Only tuple/list outputs carry the prediction in element 0;
            #indexing a plain tensor would silently select the first sample
            yhat = output[0] if isinstance(output, (tuple, list)) else output

            #Drops the singleton channel axis so masks line up with (batch, H, W) targets
            if not classification and yhat.dim() == 4 and yhat.size(1) == 1:
                yhat = yhat.squeeze(1)

            yhat = yhat.to(device)

            if yhat.size(0) != target.size(0):
                raise ValueError(f"model returned {yhat.size(0)} predictions for a batch of {target.size(0)} targets")

            l, accuracy = lossFunc(yhat, target)

            if not l.requires_grad:
                #A loss without a graph cannot update the network; backward() is
                #kept from failing, but the situation is reported once
                if not warnedDetachedLoss:
                    warnings.warn("lossFunc returned a loss that is detached from the model; no gradients reach the network")
                    warnedDetachedLoss = True
                l.requires_grad_()

            l.backward()
            optimizer.step()

            if scheduler is not None:
                #Fractional epochs give the schedule per-batch resolution
                scheduler.step(epoch + i / numBatches)

            lossSum += float(l)
            accuracySum += float(accuracy)

        #Progressive learning
        if (epoch + 1) % epochsToDouble == 0 and progressive == 1:
            currDim *= 2
        #Reverse progressive learning
        elif (epoch + 1) % epochsToDouble == 0 and progressive == 2:
            currDim /= 2

        #Checkpoints model
        if (epoch + 1) % epochsToSave == 0:
            torch.save(net.state_dict(), modelFileName + "Epoch" + str(epoch))

        validationAcc, validationLoss = evaluate_accuracy(net, testIter, lossFunc, classification=classification, device=device)

        #Overwrites previous best model based on validation loss
        if validationLoss < bestValLoss:
            bestValLoss = validationLoss
            torch.save(net.state_dict(), modelFileName + "BestLoss")

        trainAcc = accuracySum / numBatches
        trainLoss = lossSum / numBatches

        print(f"Train Acc: {trainAcc} Validation Acc: {validationAcc} Train Loss: {trainLoss} Validation Loss: {validationLoss}")

        #Externally logs epoch info to WandB
        if useWandB:
            wandb.log({"Train Acc": trainAcc,
                    "Validation Acc": validationAcc,
                    "Train Loss": trainLoss,
                    "Validation Loss": validationLoss
                    })

# **Setup**

In [None]:
#Hyperparameters and training modifications
#Trains on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#device = torch.device("cpu")

#Checkpoint path prefix on the mounted Drive; train() appends "Epoch<n>" and "BestLoss" to it
modelName = "temp"
fileSaveName = "/content/drive/MyDrive/" + modelName

#Use if starting from a checkpoint
startEpoch = 0

#Set to True to log metrics to Weights & Biases
useWandB = False

batchSize = 6
learnRate = 0.001
epochs = 100
dropout = 0

#Progressive training parameters
#startDim is the initial image side length, epochsToDouble the number of epochs between size changes
startDim = 32
epochsToDouble = 25

#0: not progressive, 1: progressive, 2: reverse progressive
progressive = 1

#Checkpointing
#A checkpoint is written every epochsToSave epochs
epochsToSave = 10

#Learn rate scheduling parameters
#cosineRestartEpochs is passed to train() as restartEpochs; a negative value disables warm restarts
cosineAnnealing = True
cosineRestartEpochs = 10

#Available loss functions; the training cells below assign lossFunc themselves
#lossFunc = BalancedCELoss(weight0=1, weight1=1.5)
#lossFunc = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)
#lossFunc = nn.BCEWithLogitsLoss()
#lossFunc = WeightedDiceLoss(weight0=0.2, weight1=0.8)

In [None]:
#Load Datasets
#NOTE(review): these paths are relative while the checkpoint prefix is absolute (/content/drive/...);
#they only resolve when the working directory is the parent of the Drive mount - confirm
trainDataset = LITSBinaryDataset("drive/MyDrive/MachineLearningResearch/Datasets/FullTrainDataset.hdf5")
validationDataset = LITSBinaryDataset("drive/MyDrive/MachineLearningResearch/Datasets/ValidationDataset.hdf5")
testDataset = LITSBinaryDataset("drive/MyDrive/MachineLearningResearch/Datasets/TestDataset.hdf5")

#Only the training loader is shuffled; validation and test keep file order
trainIter = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)
validationIter = DataLoader(validationDataset, batch_size=batchSize)
testIter = DataLoader(testDataset, batch_size=batchSize)

print("Dataset loaded")

Dataset loaded


# **Joint Training**

In [None]:
#Load model
#Leave empty to train from scratch, or set to a checkpoint path to resume
jointTrainingFileName = ""

lossFunc = WeightedDiceLoss(weight0=0.2, weight1=0.8)

#4-level encoder (16..128 channels), 128->256 middle block, 4-block decoder (256..16 channels)
#and a final 1x1 convolution over the 16 decoder output channels
encoder = EncoderNetwork([16, 32, 64, 128], 1, dropout, device).to(device)
decoder = DecoderNetwork([256, 128, 64, 32, 16], 1, dropout, device).to(device)
segmenter = SegmentationNetwork(encoder, decoder, 128, 256, 16, dropout, device)
#print(summary(net, (1, 256, 256)))

#Loads model from file if using a pretrained version
if jointTrainingFileName != "":
    segmenter.load_state_dict(torch.load(jointTrainingFileName))

segmenter = segmenter.to(device)

print("Intialized joint training model")

In [None]:
#Train Model
#Trains encoder and decoder jointly on the segmentation target (classification=False)
gc.collect()

#Optional Weights & Biases run recording the main hyperparameters
if useWandB:
    wandb.init(project="LiverSegmentation",
            name=modelName,
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })

train(segmenter, trainIter, validationIter, epochs, startEpoch, learnRate, batchSize, device, startDim, epochsToDouble, fileSaveName, epochsToSave, useWandB=useWandB, 
      cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive, lossFunc=lossFunc, classification=False)

if useWandB:
    wandb.finish()

# **Pre-Training**

In [None]:
#Builds a 5-block classifier as a plain nn.Sequential with dropout 0.2 in every block
#NOTE(review): EncoderBlock.forward returns a (pooled, skip) tuple, which nn.Sequential would hand
#to the next block unchanged, so calling this model likely fails; `encoder` and `lossFunc` are also
#reassigned by the next cell - confirm whether this cell is still needed
lossFunc = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

block1 = EncoderBlock(1, 16, 1, 0.2)
block2 = EncoderBlock(16, 32, 1, 0.2)
block3 = EncoderBlock(32, 64, 1, 0.2)
block4 = EncoderBlock(64, 128, 1, 0.2)
block5 = EncoderBlock(128, 256, 1, 0.2)

encoder = nn.Sequential(block1, block2, block3, block4, block5, nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(256, 1), nn.Sigmoid())

print(encoder)

Sequential(
  (0): EncoderBlock(
    (conv): convBlock(
      (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Dropout(p=0.2, inplace=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): EncoderBlock(
    (conv): convBlock(
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Dropout(p=0.2, inplace=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (2): EncoderBlock(
    (conv): convBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1

In [None]:
#Classifier pre-training setup: focal loss and a 5-level encoder (16..256 channels)
lossFunc = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)

#Leave empty to train from scratch, or set to a checkpoint path to resume
encoderFileName = ""

encoder = EncoderNetwork([16, 32, 64, 128, 256], 1, dropout, device).to(device)

if encoderFileName != "":
    encoder.load_state_dict(torch.load(encoderFileName))

encoder = encoder.to(device)

print(encoder)

[EncoderBlock(
  (conv): convBlock(
    (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dropout): Dropout(p=0, inplace=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
), EncoderBlock(
  (conv): convBlock(
    (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dropout): Dropout(p=0, inplace=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
), EncoderBlock(
  (conv): convBlock(
    (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), 

In [None]:
#Train Model
#Pre-trains the encoder as a slice classifier on the image label (classification=True)
gc.collect()

#Optional Weights & Biases run; the run name encodes the dropout value
if useWandB:
    wandb.init(project="LiverClassifier2",
            name="Progressive" + str(dropout),
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble,
                "Dropout":dropout,
            })

train(encoder, trainIter, validationIter, epochs, startEpoch, learnRate, batchSize, device, startDim, epochsToDouble, fileSaveName, epochsToSave, useWandB=useWandB, 
      cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive, lossFunc=lossFunc, classification=True)

if useWandB:
    wandb.finish()

Training on cuda
Train Acc: 0.5743405275779381 Validation Acc: 0.25 Train Loss: 0.05551540767170971 Validation Loss: 0.06152976602315903
Train Acc: 0.7134292565947244 Validation Acc: 0.25 Train Loss: 0.03979905240787126 Validation Loss: 0.05894099064171314
Train Acc: 0.7517985611510797 Validation Acc: 0.25 Train Loss: 0.03745648845789029 Validation Loss: 0.06275923043489456
Train Acc: 0.7601918465227823 Validation Acc: 0.25 Train Loss: 0.034867453258216594 Validation Loss: 0.06636136949062348
Train Acc: 0.7871702637889691 Validation Acc: 0.25 Train Loss: 0.03307217086678733 Validation Loss: 0.05905896313488483
Train Acc: 0.8093525179856119 Validation Acc: 0.6683333333333333 Train Loss: 0.030552252525758508 Validation Loss: 0.059511965662240984
Train Acc: 0.8045563549160678 Validation Acc: 0.31 Train Loss: 0.029396616269184425 Validation Loss: 0.058966809734702114
Train Acc: 0.8123501199040768 Validation Acc: 0.25 Train Loss: 0.028873087458895427 Validation Loss: 0.058768841847777364
Tr

KeyboardInterrupt: ignored

In [None]:
#Builds the segmentation network on top of the pre-trained classifier encoder
lossFunc = WeightedDiceLoss(weight0=0.2, weight1=0.8)

#map_location lets a checkpoint saved on the GPU be restored on a CPU-only runtime
encoder.load_state_dict(torch.load(fileSaveName + "BestLoss", map_location=device))

#Derives the decoder layout from the encoder so both always have the same number of levels:
#the middle block doubles the deepest encoder width and the decoder mirrors the encoder widths
encoderChannels = [block.conv.conv2.out_channels for block in encoder.blocks]
bottleneckChannels = 2 * encoderChannels[-1]

decoder = DecoderNetwork([bottleneckChannels] + encoderChannels[::-1], 1, dropout, device).to(device)
segmenter = SegmentationNetwork(encoder, decoder, encoderChannels[-1], bottleneckChannels, encoderChannels[0], dropout, device)

segmenter = segmenter.to(device)

In [None]:
#Train Model
#Trains the segmenter built on the pre-trained encoder against the segmentation target (classification=False)
gc.collect()

#Optional Weights & Biases run recording the main hyperparameters
if useWandB:
    wandb.init(project="LiverSegmentation",
            name=modelName + "Segmenter",
            config={
                "BatchSize":batchSize,
                "LearnRate":learnRate,
                "Epochs":epochs,
                "StartDimension":startDim,
                "EpochsToDouble":epochsToDouble
            })

train(segmenter, trainIter, validationIter, epochs, startEpoch, learnRate, batchSize, device, startDim, epochsToDouble, fileSaveName, epochsToSave, useWandB=useWandB, 
      cosineAnnealing=cosineAnnealing, restartEpochs=cosineRestartEpochs, progressive=progressive, lossFunc=lossFunc, classification=False)

if useWandB:
    wandb.finish()

# **Evaluation/Ending**

In [None]:
#Selects what to evaluate: True scores the classifier (encoder), False the segmenter
classification = True

#The loss must match the evaluated output: whichever lossFunc the last training cell
#left behind may belong to the other task, so a matching one is created here
if classification:
    net = encoder
    evalLossFunc = FocalLoss(weight0=0.2, weight1=0.8, gamma=2)
else:
    net = segmenter
    evalLossFunc = WeightedDiceLoss(weight0=0.2, weight1=0.8)

#Evaluate Model
print(f"Model: {modelName}")

trainAcc, trainLoss = evaluate_accuracy(net, trainIter, evalLossFunc, classification=classification, device=device)
print(f"Train Accuracy: {trainAcc} Train Loss: {trainLoss}")

validationAcc, validationLoss = evaluate_accuracy(net, validationIter, evalLossFunc, classification=classification, device=device)
print(f"Validation Accuracy: {validationAcc} Validation Loss: {validationLoss}")

testAcc, testLoss = evaluate_accuracy(net, testIter, evalLossFunc, classification=classification, device=device)
print(f"Test Accuracy: {testAcc} Test Loss: {testLoss}")

In [None]:
#Releases the HDF5 file handles; the datasets must not be used after this cell
for openDataset in (trainDataset, validationDataset, testDataset):
    openDataset.closeFile()