In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import segmentation_models_pytorch as smp

import mlflow
import mlflow.pytorch

import numpy as np
import nibabel as nib
import sklearn.model_selection as model_selection
import pandas as pd



In [2]:
device = 'cuda:0'
BATCH_SIZE = 4
EPOCHS = 10
LEARNING_RATE = 0.0001
N_SLICES = 32
MODEL_FOLDER = 'resnet18'
NUM_CLASSES = 4
NUM_WORKERS = 4
CLASSES = ['Background', 'NonEnhancingTumour', 'Edema', 'EnhancingTumour']

In [3]:

# dataset to read nifti images and return 2d slices
class Dataset(torch.utils.data.Dataset):
    def __init__(self, df, num_slices=16):
        self.df = df
        self.num_slices = num_slices

    def __len__(self):
        return len(self.df)

    def get_img_label(self, idx):
        # get images
        img_flair = nib.load(self.df.iloc[idx]['flair']).get_fdata()
        img_t1 = nib.load(self.df.iloc[idx]['t1']).get_fdata()
        img_t1ce = nib.load(self.df.iloc[idx]['t1ce']).get_fdata()
        img_t2 = nib.load(self.df.iloc[idx]['t2']).get_fdata()

        # get label
        label = self.df.iloc[idx]['label']
        label = nib.load(label).get_fdata()
        return img_flair, img_t1, img_t1ce, img_t2, label

    def __getitem__(self, idx):

        img_flair, img_t1, img_t1ce, img_t2, label = self.get_img_label(idx)

        # normalize
        img_flair = (img_flair - img_flair.mean()) / img_flair.std()
        img_t1 = (img_t1 - img_t1.mean()) / img_t1.std()
        img_t1ce = (img_t1ce - img_t1ce.mean()) / img_t1ce.std()
        img_t2 = (img_t2 - img_t2.mean()) / img_t2.std()

        img = np.stack([img_flair, img_t1, img_t1ce, img_t2], axis=0)
        img = img.astype(np.float32)

        # get random slices
        slices = np.random.choice(img.shape[-1], self.num_slices, replace=False)
        img = img[..., slices]
        img = np.transpose(img, (3, 0, 1, 2))
        # pad to 256x256
        img = np.pad(img, ((0, 0), (0, 0), (0, 256 - img.shape[2]), (0, 256 - img.shape[3])), 'constant')

        label = label[None, ...]
        label = label[..., slices]
        label = np.transpose(label, (3, 0, 1, 2))
        label = np.pad(label, ((0, 0), (0, 0), (0, 256 - label.shape[2]), (0, 256 - label.shape[3])), 'constant')

        # change label 4 to 3 (there is no class 3 in the dataset)
        label[label == 4] = 3

        
        return torch.tensor(img), torch.tensor(label)

In [4]:
def collate_fn(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    return (torch.cat(data, 0), torch.cat(target, 0))

In [5]:
# model
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights=None, # no transfer learning
    in_channels=4,
    classes=NUM_CLASSES,
)
model = model.to(device)

In [6]:
def dice_loss(pred, target, smooth=1.0): # per class
    pred_softmax = F.softmax(pred, dim=1)
    target_onehot = F.one_hot(target.long().squeeze(), num_classes=NUM_CLASSES).permute(0, 3, 1, 2).float()
    intersection = (pred_softmax * target_onehot).sum(dim=(2, 3))
    dice = (2. * intersection + smooth) / (pred_softmax.sum(dim=(2, 3)) + target_onehot.sum(dim=(2, 3)) + smooth)
    return 1 - dice.mean()


In [10]:
def get_tp_fp_fn_tn(pred, target):
    tps = []
    fps = []
    fns = []
    tns = []
    pred_argmax = pred.argmax(dim=1)
    print(pred_argmax.shape, target.shape)
    for i in range(NUM_CLASSES):
        pred_i = pred_argmax == i
        target_i = target == i
        tp = (pred_i & target_i).sum().item()
        fp = (pred_i & ~target_i).sum().item()
        fn = (~pred_i & target_i).sum().item()
        tn = (~pred_i & ~target_i).sum().item()
        tps.append(tp)
        fps.append(fp)
        fns.append(fn)
        tns.append(tn)

    return tps, fps, fns, tns

In [11]:
# loss function, optimizer, scheduler
# use dice loss for multi-class segmentation
# criterion = dice_loss
criterion = nn.CrossEntropyLoss().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

In [12]:
# training function
def train(model, device, train_loader, optimizer, epoch, log_interval=10):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target.squeeze().long())
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx + 1, len(train_loader),
                100. * batch_idx / len(train_loader), loss.item()))


In [13]:
# testing function
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    tps = [0] * NUM_CLASSES
    fps = [0] * NUM_CLASSES
    fns = [0] * NUM_CLASSES
    tns = [0] * NUM_CLASSES
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target.squeeze().long()).item() # sum up batch loss
            tps_, fps_, fns_, tns_ = get_tp_fp_fn_tn(output, target.squeeze())
            tps = [x + y for x, y in zip(tps, tps_)]
            fps = [x + y for x, y in zip(fps, fps_)]
            fns = [x + y for x, y in zip(fns, fns_)]
            tns = [x + y for x, y in zip(tns, tns_)]

    test_loss /= len(test_loader.dataset)

    acc = [ (tp + tn) / (tp + tn + fp + fn + 1e-6) for tp, tn, fp, fn in zip(tps, tns, fps, fns)]
    precision = [ tp / (tp + fp + 1e-6) for tp, fp in zip(tps, fps)]
    recall = [ tp / (tp + fn + 1e-6) for tp, fn in zip(tps, fns)]
    f1 = [ 2 * (p * r) / (p + r + 1e-6) for p, r in zip(precision, recall)]

    print('Accuracy: ', {CLASSES[i]: acc[i] for i in range(NUM_CLASSES)})
    print('Precision: ', {CLASSES[i]: precision[i] for i in range(NUM_CLASSES)})
    print('Recall: ', {CLASSES[i]: recall[i] for i in range(NUM_CLASSES)})
    print('F1: ', {CLASSES[i]: f1[i] for i in range(NUM_CLASSES)})

In [14]:
folders = os.listdir('RSNA_ASNR_MICCAI_BraTS2021_TrainingData')
folders = [os.path.join('RSNA_ASNR_MICCAI_BraTS2021_TrainingData', f) for f in folders]
df = []
for folder in folders:
    files = os.listdir(folder)
    files = [os.path.join(folder, f) for f in files]
    df.append({
        'flair': [ f for f in files if 'flair' in f][0],
        't1': [ f for f in files if 't1' in f and 't1ce' not in f][0],
        't1ce': [ f for f in files if 't1ce' in f][0],
        't2': [ f for f in files if 't2' in f][0],
        'label': [ f for f in files if 'seg' in f][0]
    })
df = pd.DataFrame(df, columns=['flair', 't1', 't1ce', 't2', 'label'])

In [15]:
 # split data into train and test
train_df, test_df = model_selection.train_test_split(df, test_size=0.2, random_state=42)

# create dataset
train_dataset = Dataset(train_df, num_slices=N_SLICES)
test_dataset = Dataset(test_df, num_slices=N_SLICES)

# create data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, collate_fn=collate_fn, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, collate_fn=collate_fn, num_workers=NUM_WORKERS)

  

In [16]:
x, y = train_dataset[0]

In [17]:
x.shape

torch.Size([32, 4, 256, 256])

In [18]:
os.makedirs(MODEL_FOLDER, exist_ok=True)

In [19]:
test(model, device, test_loader)

torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256

In [20]:
# train model
for epoch in range(1, EPOCHS + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval=10)
    test(model, device, test_loader)
    # scheduler.step()
    # save model
    torch.save(model.state_dict(), f"{MODEL_FOLDER}/model_{epoch}.pth")

torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256, 256])
torch.Size([128, 256, 256]) torch.Size([128, 256

In [24]:
# convert model to onnx
dummy_input = torch.randn(1, 4, 256, 256, device='cpu')
model_cpu = model.cpu()
torch.onnx.export(model_cpu, dummy_input, f"{MODEL_FOLDER}/model.onnx", verbose=True, input_names=['input'], output_names=['output'])

graph(%input : Float(1, 4, 256, 256, strides=[262144, 65536, 256, 1], requires_grad=0, device=cpu),
      %segmentation_head.0.weight : Float(4, 16, 3, 3, strides=[144, 9, 3, 1], requires_grad=1, device=cpu),
      %segmentation_head.0.bias : Float(4, strides=[1], requires_grad=1, device=cpu),
      %305 : Float(64, 4, 7, 7, strides=[196, 49, 7, 1], requires_grad=0, device=cpu),
      %306 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %308 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %309 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %311 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %312 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %314 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %315 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %317 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      