In [None]:
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize
from dotenv import load_dotenv
import os
from pathlib import Path
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [128]:
load_dotenv()
data_dir = Path(os.getenv("DATA_DIR"))

### U-NET Model

- Consider adding batchnorm to doubleConv

In [135]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = F.relu(self.conv1(x), inplace=True)
        x = F.relu(self.conv2(x), inplace=True)
        return x

- consider adding other encoder layer

In [136]:
class UNETModel(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Down
        self.down1 = DoubleConv(in_channels,64)
        self.down2 = DoubleConv(64,128)
        self.down3 = DoubleConv(128,256)
        self.bottleneck = DoubleConv(256,512)

        self.pool = nn.MaxPool2d(2)

        # Up
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.right1 = DoubleConv(512, 256) # 256 (up) + 256 (encoder3)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.right2 = DoubleConv(256, 128) # 128 (up) + 128 (encoder2)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.right3 = DoubleConv(128, 64) # 64 (up) + 64 (encoder1)

        # Out
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def forward(self, x):
        # Down
        encoder1 = self.down1(x)
        encoder2 = self.down2(self.pool(encoder1))
        encoder3 = self.down3(self.pool(encoder2))

        # Bottleneck
        bneck = self.bottleneck(self.pool(encoder3))
        
        # Up
        decoder1 = self.up1(bneck)
        decoder1 = self.right1(torch.cat([decoder1, encoder3], dim=1))
        decoder2 = self.up2(decoder1)
        decoder2 = self.right2(torch.cat([decoder2, encoder2], dim=1))
        decoder3 = self.up3(decoder2)
        decoder3 = self.right3(torch.cat([decoder3, encoder1], dim=1))

        # Out
        out = self.out(decoder3)
        return torch.sigmoid(out)


### Preparing Data

In [137]:
def load_image_and_mask(dir):
    image_path = next(dir.glob("*image*"))
    image = nib.load(image_path).get_fdata()
    mask_paths = sorted(dir.glob("*mask*"))
    combined_mask = np.zeros_like(image)
    for path in mask_paths:
        mask = nib.load(path).get_fdata()
        combined_mask = np.logical_or(combined_mask, mask)
    return image, combined_mask.astype(np.float32)

image, mask = load_image_and_mask(data_dir/"BCBM-RadioGenomics-5-0")

In [138]:
def largest_slice(mask):
    slice_sums = mask.sum(axis=(0, 1))
    z = int(np.argmax(slice_sums))
    return z

In [139]:
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF

class MRIDataset(Dataset):
    def __init__(self, dirs, transform=None):
        self.dirs = dirs
        self.transform = transform
    
    def __len__(self):
        return len(self.dirs)
    
    def __getitem__(self, idx):
        dir = self.dirs[idx]
        image, mask = load_image_and_mask(dir)
        z = largest_slice(mask)

        target_size = (256, 256)

        image_slice = torch.tensor(image[:,:,z], dtype=torch.float32).unsqueeze(0)
        image_slice = TF.resize(image_slice, target_size)
        mask_slice = torch.tensor(mask[:,:,z], dtype=torch.float32).unsqueeze(0)
        mask_slice = TF.resize(mask_slice, target_size)

        if self.transform:
            image_slice, mask_slice = self.transform(image_slice, mask_slice)
        
        return image_slice, mask_slice

In [140]:
from sklearn.model_selection import train_test_split

# Split data into train and test
dirs = sorted([p for p in data_dir.iterdir() if p.is_dir()])
train_dirs, test_dirs = train_test_split(dirs, test_size=0.2, random_state=12)
train_data = MRIDataset(train_dirs)
test_data = MRIDataset(test_dirs)

In [141]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
test_loader = DataLoader(test_data, batch_size=4, shuffle=False)

- Add dice loss to criterion

In [None]:
model = UNETModel().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 5
train_losses = [0] * epochs
test_losses = [0] * epochs
for i in range(epochs):
    model.train()
    train_loss_total = 0
    for b, (img_train, mask_train) in enumerate(train_loader):
        img_train = img_train.to(device)
        mask_train = mask_train.to(device)

        mask_pred = model(img_train)
        loss = criterion(mask_pred, mask_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_total += loss.item()
        print(f'Epoch: {i} Batch: {b} Loss: {loss.item()}')
    
    train_losses[i] = train_loss_total / len(train_loader)

    # testing
    model.eval()
    test_loss_total = 0
    with torch.no_grad():
        for b, (img_test, mask_test) in enumerate(test_loader):
            img_test = img_train.to(device)
            mask_test = mask_train.to(device)

            mask_pred = model(img_test)
            loss = criterion(mask_pred, mask_test)
            test_loss_total += loss.item()
            
    test_losses[i] = test_loss_total / len(test_loader)

Epoch: 0 Batch: 0 Loss: 1.037894606590271
Epoch: 0 Batch: 1 Loss: 0.9916205406188965
Epoch: 0 Batch: 2 Loss: 0.9600479006767273
Epoch: 0 Batch: 3 Loss: 0.9467707872390747
Epoch: 0 Batch: 4 Loss: 0.932676374912262
Epoch: 0 Batch: 5 Loss: 0.928701639175415
Epoch: 0 Batch: 6 Loss: 0.9420586228370667
Epoch: 0 Batch: 7 Loss: 0.9501872658729553
Epoch: 0 Batch: 8 Loss: 0.9351626038551331
Epoch: 0 Batch: 9 Loss: 0.9371138215065002
Epoch: 0 Batch: 10 Loss: 0.9313513040542603
Epoch: 0 Batch: 11 Loss: 0.9127479195594788
Epoch: 0 Batch: 12 Loss: 0.9254933595657349
Epoch: 0 Batch: 13 Loss: 0.9179235696792603
Epoch: 0 Batch: 14 Loss: 0.9290358424186707
Epoch: 0 Batch: 15 Loss: 0.8793569803237915
Epoch: 0 Batch: 16 Loss: 0.868272066116333
Epoch: 0 Batch: 17 Loss: 0.8632814288139343
Epoch: 0 Batch: 18 Loss: 0.8379778265953064
Epoch: 0 Batch: 19 Loss: 0.777062177658081
Epoch: 0 Batch: 20 Loss: 0.7357898354530334
Epoch: 0 Batch: 21 Loss: 0.702095091342926
Epoch: 0 Batch: 22 Loss: 0.6950366497039795
