This notebook is meant to run on WSL2, hence the directml usage.

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import os
import nrrd
import torch
import torch.nn as nn

from PIL import Image
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from torch.utils.data import Dataset, DataLoader
import torch_directml as tdml

print(f'Available devices: {tdml.device_count()}')
print(f'Current device: {tdml.device()}')
dml = tdml.device()

  from .autonotebook import tqdm as notebook_tqdm


Available devices: 1
Current device: privateuseone:0


In [2]:
class RasterDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir

        raw_labels = [(*(os.path.splitext(file)[0].split("_")[1:-1]), file) for file in os.listdir(data_dir)]

        self.labels = pd.DataFrame(data=raw_labels, columns=["pid", "age", "tbv", "filename"])
        self.labels[["age", "tbv"]] = self.labels[["age", "tbv"]].astype(float)

        self.age_minmax = MinMaxScaler()
        self.tbv_minmax = MinMaxScaler()
        self.age_std = StandardScaler()
        self.tbv_std = StandardScaler()

        self.labels[["age"]] = self.age_minmax.fit_transform(self.labels[["age"]])
        self.labels[["tbv"]] = self.tbv_minmax.fit_transform(self.labels[["tbv"]])
        self.labels[["age"]] = self.age_std.fit_transform(self.labels[["age"]])
        self.labels[["tbv"]] = self.tbv_std.fit_transform(self.labels[["tbv"]])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        entry = self.labels.iloc[idx]
        raster = nrrd.read(os.path.join(self.data_dir, entry["filename"]))[0]
        raster = (raster - np.mean(raster)) / np.std(raster) # Standardize the data

        return {"pid": entry["pid"], "age": entry["age"], "tbv": entry["tbv"], "raster": raster}

class RasterNet(nn.Module):
    def __init__(self):
        super(RasterNet, self).__init__()

        # First layer is triplanar as used in S3PNet:
        # https://www.sciencedirect.com/science/article/pii/S1077314219301791

        self.trpl = nn.Sequential( # 128x128x128 -> 32x40x40
            nn.Conv2d(128, 32, kernel_size=9, stride=3, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        self.conv1 = nn.Sequential( # 96x40x40 -> 192x10x10
            nn.Conv2d(96, 192, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv2 = nn.Sequential( # 192x10x10 -> 384x5x5
            nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv3 = nn.Sequential( # 384x5x5 -> 768x3x3
            nn.Conv2d(384, 768, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv4 = nn.Sequential( # 768x3x3 -> 1536x1x1
            nn.Conv2d(768, 1536, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1536),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.linear = nn.Sequential(
            nn.Linear(1536, 768),
            nn.ReLU(),
            nn.Linear(768, 1)
        )

    def forward(self, base):
        base = base.squeeze().float().to(dml)

        xy = self.trpl(base)
        yz = self.trpl(torch.transpose(base, 2, 1))
        xz = self.trpl(torch.transpose(base, 1, 3))

        x = self.conv1(torch.cat((xy, yz, xz), dim=1))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return self.linear(x.squeeze())

class EarlyStopper:
    def __init__(self, patience=10):
        self.patience = patience
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        elif loss > self.best_loss:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = loss
            self.counter = 0

        return self.early_stop

In [3]:
dataset = RasterDataset(data_dir="../aug_dataset", pin_memory=True, non_blocking=True)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

batch = next(iter(train_dataloader))

model = RasterNet().to(dml)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10000
train_loss_list = []
val_loss_list = []

early_stopper = EarlyStopper(patience=20)

for epoch in range(num_epochs):
    training_loss = 0.
    model.train()

    for i, data in enumerate(train_dataloader):
        rasters = data["raster"].float().unsqueeze(1).to(dml)
        tbvs = data["tbv"].float().to(dml)

        optimizer.zero_grad()

        rasters = rasters

        predictions = model(rasters).squeeze()
        loss = criterion(predictions, tbvs)

        loss.backward()
        optimizer.step()

        training_loss += loss.item()

        if (i+1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}              ", end="\r")

    validation_loss = 0.
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            rasters = data["raster"].float().unsqueeze(1).to(dml)
            tbvs = data["tbv"].float().to(dml)

            predictions = model(rasters).squeeze()
            loss = criterion(predictions, tbvs)

            validation_loss += loss.item()

    train_loss = training_loss/len(train_dataloader)
    val_loss = validation_loss/len(test_dataloader)
    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}               ")
    if early_stopper(validation_loss):
        break

plt.plot(train_loss_list, label="Training Loss", linewidth=3)
plt.plot(val_loss_list, label="Validation Loss", linewidth=3)
plt.legend("Training Loss", "Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

torch.save(model.state_dict(), "last_run_weights.pt")

Epoch [1/10000], Step [10/141], Loss: 1.6317
Epoch [1/10000], Step [20/141], Loss: 1.4944
Epoch [1/10000], Step [30/141], Loss: 0.8318
Epoch [1/10000], Step [40/141], Loss: 1.2108
Epoch [1/10000], Step [50/141], Loss: 0.9296
Epoch [1/10000], Step [60/141], Loss: 1.7470
Epoch [1/10000], Step [70/141], Loss: 0.8275
Epoch [1/10000], Step [80/141], Loss: 1.2905
Epoch [1/10000], Step [90/141], Loss: 1.2129
Epoch [1/10000], Step [100/141], Loss: 1.4862
Epoch [1/10000], Step [110/141], Loss: 1.8568
Epoch [1/10000], Step [120/141], Loss: 1.3904
Epoch [1/10000], Step [130/141], Loss: 1.1440
Epoch [1/10000], Step [140/141], Loss: 0.5251
Epoch [1/10000], Training Loss: 1.7479, Validation Loss: 1.3192
Epoch [2/10000], Step [10/141], Loss: 1.3566


KeyboardInterrupt: 