In [3]:
import os
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from d2l import torch as d2l
from pathlib import Path
from torchmetrics.classification import BinaryF1Score
import sys
import nibabel as nib
import numpy as np
import gc

In [None]:
class convBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    The first convolution maps ``inChannels`` to ``outChannels`` with the given
    stride and can optionally be followed by batch normalisation. The second
    convolution keeps the channel count and uses the default stride of 1.
    Both convolution weights are initialised with Xavier-uniform.

    Args:
        inChannels: Number of channels in the input feature map.
        outChannels: Number of channels produced by both convolutions.
        batchNorm: If truthy, apply ``nn.BatchNorm2d`` after the first
            convolution (before its ReLU).
        strides: Stride of the first convolution.
    """

    def __init__(self, inChannels, outChannels, batchNorm, strides) -> None:
        super().__init__()

        self.conv1 = nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1)

        nn.init.xavier_uniform_(self.conv1.weight)
        nn.init.xavier_uniform_(self.conv2.weight)

        # ``None`` marks "no batch norm"; it avoids relying on the truthiness
        # of an nn.Module instance in forward().
        self.bn1 = nn.BatchNorm2d(outChannels) if batchNorm else None

    def forward(self, X):
        """Apply conv -> (batch norm) -> ReLU -> conv -> ReLU to ``X``.

        Returns:
            The activated output of the second convolution, on the same device
            and with the same dtype as the module's parameters.
        """
        Y = self.conv1(X)

        if self.bn1 is not None:
            Y = self.bn1(Y)

        Y = F.relu(Y)

        # Return the activation as-is. Re-wrapping it in the legacy
        # ``torch.Tensor(...)`` constructor adds nothing and, depending on the
        # PyTorch version, rejects CUDA or non-float32 tensors.
        return F.relu(self.conv2(Y))

In [None]:
class EncoderBlock(nn.Module):
    """Encoder stage: a batch-normalised ``convBlock`` followed by 2x2 max pooling.

    Args:
        inChannels: Number of channels in the input feature map.
        outChannels: Number of channels produced by the conv block.
        strides: Stride passed both to the conv block (its first convolution)
            and to the max-pooling layer.
    """

    def __init__(self, inChannels, outChannels, strides) -> None:
        super().__init__()

        self.conv = convBlock(inChannels, outChannels, True, strides)
        self.pool = nn.MaxPool2d(2, stride=strides)

    def forward(self, X):
        """Run the conv block on ``X`` and return the max-pooled result."""
        # Invoke the sub-module through __call__ (not .forward()) so that any
        # hooks registered on it are executed.
        Y = self.conv(X)

        return self.pool(Y)

In [None]:
class DecoderBlock(nn.Module):
    """Decoder stage: transposed convolution, skip concatenation, conv block.

    Args:
        inChannels: Number of channels in the incoming feature map.
        outChannels: Number of channels produced by the transposed convolution
            and by the final conv block.
        strides: Stride passed to the transposed convolution and to the conv
            block.
        skipChannels: Number of channels in the skip-connection tensor given to
            ``forward``. Defaults to ``outChannels`` when omitted.

    NOTE(review): ``strides`` is forwarded to the conv block as well, so a
    stride greater than 1 shrinks the map again after it has been enlarged, and
    ``padding=1`` on the transposed convolution trims its output. Both are kept
    as written; confirm they match the intended geometry.
    """

    def __init__(self, inChannels, outChannels, strides, skipChannels=None) -> None:
        super().__init__()

        if skipChannels is None:
            skipChannels = outChannels

        self.convTrans = nn.ConvTranspose2d(inChannels, outChannels, 2, stride=strides, padding=1)
        # The conv block receives the up-sampled map concatenated with the skip
        # features, so its input width is the sum of both channel counts.
        self.conv = convBlock(outChannels + skipChannels, outChannels, True, strides)

    def forward(self, X, skipFeatures):
        """Up-sample ``X``, append ``skipFeatures`` on the channel axis, convolve.

        ``skipFeatures`` must match the transposed-convolution output in every
        dimension except the channel dimension (dim 1).
        """
        Y = self.convTrans(X)
        # torch.cat takes a sequence of tensors; join the up-sampled map (not
        # the raw input) with the skip features along the channel dimension.
        Y = torch.cat((Y, skipFeatures), dim=1)
        return self.conv(Y)

In [None]:
def train(net: nn.Sequential, trainIter, numEpochs, learnRate, device: torch.device, lossFunc = nn.BCEWithLogitsLoss()):
    """Train ``net`` in place with plain SGD and print the mean loss per epoch.

    Args:
        net: Model to optimise; it is moved to ``device`` and put in train mode.
        trainIter: Sized iterable yielding ``(X, y)`` tensor batches.
        numEpochs: Number of full passes over ``trainIter``.
        learnRate: Learning rate for ``torch.optim.SGD``.
        device: Device on which the model and every batch are placed.
        lossFunc: Callable ``lossFunc(prediction, target)`` returning a scalar
            tensor. Defaults to ``nn.BCEWithLogitsLoss()``.

    Raises:
        ValueError: If ``trainIter`` contains no batches.
    """
    print(f"Training on {device}")

    numBatches = len(trainIter)
    if numBatches == 0:
        # Fail with a clear message instead of a ZeroDivisionError when the
        # epoch average is computed.
        raise ValueError("trainIter contains no batches; nothing to train on")

    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=learnRate)

    net.train()

    for epoch in range(numEpochs):
        sumLoss = 0.0

        for i, (X, y) in enumerate(trainIter):
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)

            yhat = net(X)

            # Give the prediction the target's layout when they hold the same
            # number of elements (e.g. (B, 1) logits vs. (B,) labels). A plain
            # squeeze(0) only lines up when the batch size is 1.
            if yhat.shape != y.shape and yhat.numel() == y.numel():
                yhat = yhat.reshape(y.shape)

            l = lossFunc(yhat, y)
            l.backward()
            optimizer.step()

            # Accumulate a Python float: summing the loss tensor itself would
            # keep autograd history alive for the whole epoch.
            sumLoss += l.item()

        print(f"Loss: {sumLoss / numBatches}")

In [None]:
# Build the slice-level classifier, load the volumes and segmentations
# slice by slice, and train the model on the resulting tensors.
gc.collect()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

batchSize = 1
learnRate = 0.01
epochs = 3

block1 = EncoderBlock(1, 64, 1)
block2 = EncoderBlock(64, 128, 1)
block3 = EncoderBlock(128, 256, 1)

# Three encoder stages, global average pooling, and a single-logit linear head.
net = nn.Sequential(block1, block2, block3, nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(256, 1))

print("Intialized model")

xDir = "Dataset/Volumes/"
yDir = "Dataset/Segmentations/"

# os.listdir returns entries in arbitrary order, so both listings are sorted to
# keep the i-th volume paired with the i-th segmentation.
# NOTE(review): this assumes the two directories use parallel file names that
# sort into matching order - confirm against the dataset layout.
xFiles = sorted(name for name in os.listdir(xDir) if os.path.isfile(os.path.join(xDir, name)))
yFiles = sorted(name for name in os.listdir(yDir) if os.path.isfile(os.path.join(yDir, name)))
numFiles = len(xFiles)

xData = []
yData = []

# Each slice along the third axis of a volume becomes one input sample.
for name in xFiles:
    ctScan = nib.load(os.path.join(xDir, name))
    data = ctScan.get_fdata()

    for plane in range(data.shape[2]):
        p = data[:,:,plane].astype(np.int16)
        xData.append(p)

# Each segmentation slice is reduced to one label: its maximum value capped at 1.
for name in yFiles:
    segmentation = nib.load(os.path.join(yDir, name))
    data = segmentation.get_fdata()

    for plane in range(data.shape[2]):
        p = data[:,:,plane].astype(np.int16)
        yData.append(min(np.amax(p), 1))

if len(xData) != len(yData):
    raise ValueError(
        f"Loaded {len(xData)} image slices but {len(yData)} labels; "
        "volumes and segmentations do not line up"
    )

# unsqueeze(1) inserts the single input-channel axis expected by the first block.
tensorX = torch.Tensor(np.array(xData)).unsqueeze(1)
tensorY = torch.Tensor(np.array(yData))

print(tensorX.shape)
print(tensorY.shape)

print("Finished loading data")

dataset = TensorDataset(tensorX, tensorY)
trainIter = DataLoader(dataset, batch_size=batchSize, shuffle=True)

train(net, trainIter, epochs, learnRate, device)