In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import os
import nrrd
import torch
import torch.nn as nn

from PIL import Image

from torch.utils.data import Dataset, DataLoader

In [None]:
# Ask PyTorch whether a CUDA device is usable; as the last expression of the cell the boolean is displayed.
torch.cuda.is_available()

In [None]:
class RasterDataset(Dataset):
    """Dataset of NRRD volumes whose labels are encoded in the file names.

    Every entry of ``data_dir`` is expected to be named like
    ``<prefix>_<pid>_<age>_<tbv>_<suffix>.<ext>``: the name (without extension)
    is split on ``_``, the first and last fields are discarded and the three
    remaining ones become the ``pid``, ``age`` and ``tbv`` labels.

    Attributes:
        data_dir: Directory the volumes are read from.
        labels: DataFrame with columns ``pid`` (str), ``age`` (float),
            ``tbv`` (float) and ``filename`` (str), one row per file.

    Raises:
        ValueError: If a file name does not yield exactly three label fields.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir

        raw_labels = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping (and therefore any index-based split) stable.
        for file in sorted(os.listdir(data_dir)):
            fields = os.path.splitext(file)[0].split("_")[1:-1]
            if len(fields) != 3:
                # Fail fast with the offending name instead of building a
                # misaligned label table.
                raise ValueError(
                    f"Cannot parse pid/age/tbv from file name {file!r} in {data_dir!r}"
                )
            raw_labels.append((*fields, file))

        self.labels = pd.DataFrame(data=raw_labels, columns=["pid", "age", "tbv", "filename"])
        self.labels[["age", "tbv"]] = self.labels[["age", "tbv"]].astype(float)

    def __len__(self):
        """Return the number of files found in ``data_dir``."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load one volume and return it with its labels.

        Args:
            idx: Row position in ``labels``; a tensor index is converted with
                ``tolist()`` first.

        Returns:
            Dict with keys ``pid``, ``age``, ``tbv`` and ``raster``. The raster
            is standardized over the whole volume (zero mean, unit std).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        entry = self.labels.iloc[idx]
        raster = nrrd.read(os.path.join(self.data_dir, entry["filename"]))[0]

        # Standardize the data. A constant volume has zero std; dividing would
        # fill the raster with NaN, so in that case only the mean is removed.
        centered = raster - np.mean(raster)
        spread = np.std(raster)
        raster = centered / spread if spread > 0 else centered

        return {"pid": entry["pid"], "age": entry["age"], "tbv": entry["tbv"], "raster": raster}

class RasterNet(nn.Module):
    """3D convolutional regressor producing one scalar per input volume.

    Four conv stages (Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d) feed a
    three-layer fully connected head. The head takes 2048 flattened features,
    i.e. 256 channels x 2x2x2 voxels, which is what a single-channel
    128x128x128 volume is reduced to by the four stages.
    """

    def __init__(self):
        super().__init__()

        # Coarse stages: a big kernel with stride 2, because small details are uninteresting.
        # 1 x 128^3 -> 32 x 32^3
        self.layer1 = self._stage(1, 32, kernel_size=5, stride=2)
        # 32 x 32^3 -> 64 x 8^3
        self.layer2 = self._stage(32, 64, kernel_size=5, stride=2)

        # Fine stages: a smaller kernel to capture the more general, more interesting structure.
        # 64 x 8^3 -> 128 x 4^3
        self.layer3 = self._stage(64, 128, kernel_size=3, stride=1)
        # 128 x 4^3 -> 256 x 2^3
        self.layer4 = self._stage(128, 256, kernel_size=3, stride=1)

        # Regression head: 2048 flattened features -> 1 output value.
        head = [
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1),
        ]
        self.linear = nn.Sequential(*head)

    @staticmethod
    def _stage(in_channels, out_channels, kernel_size, stride):
        """Build one Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d(2) stage.

        Padding is ``kernel_size // 2`` (2 for the 5-wide kernels, 1 for the
        3-wide ones); the pooling halves every spatial dimension.
        """
        return nn.Sequential(
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=kernel_size // 2,
            ),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Run the four conv stages, flatten per sample and apply the linear head."""
        features = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        return self.linear(features.flatten(start_dim=1))

class EarlyStopper:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the current loss. An epoch counts as
    an improvement when the loss is less than or equal to the best value seen
    so far; any other epoch (including a NaN loss) increments a counter. Once
    the counter reaches ``patience`` the stopper latches ``early_stop`` to True.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
    """

    def __init__(self, patience=10):
        self.patience = patience
        self.counter = 0          # consecutive epochs without improvement
        self.best_loss = None     # lowest loss seen so far (None before the first valid loss)
        self.early_stop = False   # stays True once patience has been exhausted

    def __call__(self, loss):
        """Record ``loss`` for this epoch and return whether training should stop."""
        # NaN is the only value that differs from itself. Without this check a
        # NaN loss fails every comparison, is stored as the "best" value and
        # early stopping can never trigger again.
        is_nan = loss != loss
        improved = (not is_nan) and (self.best_loss is None or loss <= self.best_loss)

        if improved:
            self.best_loss = loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop

In [None]:
# Build the dataset from the file names found in the data directory.
dataset = RasterDataset(data_dir="../aug_dataset")

# 80/20 random split into a training part and a held-out part used for validation.
# NOTE(review): the split is made per file. If the directory holds several
# augmented volumes per pid, the same patient can land in both parts — confirm
# this is intended.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Evaluation data is not shuffled: the validation loss is averaged per batch, so
# with a partial last batch a reshuffle would change the metric from epoch to
# epoch even for an unchanged model and add noise to early stopping.
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Pull a single training batch up front; it is not used further in this cell.
batch = next(iter(train_dataloader))

model = RasterNet()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Upper bound on epochs; in practice the early stopper ends training.
num_epochs = 10000
train_loss_list = []
val_loss_list = []

early_stopper = EarlyStopper(patience=15)

# Train until the validation loss has not improved for `patience` epochs
# (or num_epochs is reached). The regression target is the "tbv" label.
for epoch in range(num_epochs):
    # ---- training pass ----
    training_loss = 0.
    model.train()

    for i, data in enumerate(train_dataloader):
        # Rasters are collated as (batch, D, H, W); Conv3d needs a channel axis at position 1.
        rasters = data["raster"].float().unsqueeze(1)
        tbvs = data["tbv"].float()

        optimizer.zero_grad()

        # Drop only the trailing output axis: (batch, 1) -> (batch,). A bare
        # squeeze() would also remove the batch axis when the last batch holds
        # a single sample, leaving a 0-d prediction against a (1,) target.
        predictions = model(rasters).squeeze(1)
        loss = criterion(predictions, tbvs)

        loss.backward()
        optimizer.step()

        training_loss += loss.item()

        if (i+1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}")

    # ---- validation pass (no gradient tracking, BatchNorm in eval mode) ----
    validation_loss = 0.
    model.eval()
    with torch.no_grad():
        for data in test_dataloader:
            rasters = data["raster"].float().unsqueeze(1)
            tbvs = data["tbv"].float()

            predictions = model(rasters).squeeze(1)
            loss = criterion(predictions, tbvs)

            validation_loss += loss.item()

    # Per-epoch losses are the mean of the per-batch losses.
    train_loss = training_loss/len(train_dataloader)
    val_loss = validation_loss/len(test_dataloader)
    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
    # Monitor the same averaged value that is logged and plotted.
    if early_stopper(val_loss):
        break

# Learning curves: one point per completed epoch.
plt.plot(train_loss_list, label="Training Loss", linewidth=3)
plt.plot(val_loss_list, label="Validation Loss", linewidth=3)
# legend() without arguments uses the label= given to each line. Passing the
# two strings positionally would be read as (handles, labels), not as two labels.
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

# Persist the weights of the model as it is after the last executed epoch.
torch.save(model.state_dict(), "simple_regression.pt")
