BraTS Challenge Demo-Notebook

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from dataset_utils import slice_cube, split_cube
from visualization_utils import plot_confusion_matrix, plot_loss, animate_cube, get_positives_negatives_from_cm
from data_loading import get_train_test_iters
torch.manual_seed(42)
from Architectures.unet_3d import UNet3D
from Architectures.unet_3d_context import UNet3D_Mini
from Architectures.unet_2d import UNet2D
from train import train_model
from custom_losses import DiceLoss, FocalTverskyLoss

Data Loading

Create the model

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

cuda
NVIDIA GeForce RTX 3060 Laptop GPU


Training Loop

In [3]:
# hyperparameters
train_3d = True
add_context = False
compute_test_loss = True
epochs = 60
learning_rate = 0.1

# init model
if train_3d:
    if add_context:
        model = UNet3D_Mini(num_modalities=4, num_classes=4).to(device)
    else:
        model = UNet3D(num_modalities=4, num_classes=4).to(device)
else:
    model = UNet2D().to(device)
weights_path = os.path.join('..','Weights')
weights_filename = 'UNet3D_epoch14_loss0.073.h5'
#model.load_state_dict(torch.load(os.path.join(weights_path, weights_filename)))
# init optimizer and loss_fn
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [4]:
class SmallSegNet(nn.Module):
    
    def __init__(self, in_channels=4, num_classes=4, img_height=5, img_width=5):
        super(SmallSegNet, self).__init__()
        self.sm = nn.Softmax(dim=1)
        self.layers = nn.Sequential(
        
        # decrease x,y and increase channels
            # Conv
            nn.Conv3d(in_channels,8,2,stride=1,padding=0),
            # nn.Batchnorm
            nn.ReLU(),
        
        # change number of channels
            # 1x1 Conv
            nn.Conv3d(8,8,1,stride=1,padding=0),
            # nn.Batchnorm
            nn.ReLU(),
        
        # increase x,y and decrease channels
            # Transpose-Conv
            nn.ConvTranspose3d(8, num_classes, 2, stride=1, padding=0)
        )
    
    def forward(self, x):
        x = self.layers(x)
        # x = self.sm(x)
        # Skip Softmax because Torch CrossEntropyLoss
        # takes logits and already applies Softmax.
        return x

model = SmallSegNet(in_channels=4, num_classes=4, img_height=96, img_width=96).to(device)

In [5]:
batch_size = 1
dataset_path = os.path.join('..', 'Task01_BrainTumour', 'cropped_small')

In [6]:
# Training
print(f'training {model.__class__.__name__}:')
train_losses, test_losses = train_model(model, optim, loss_fn, epochs, device, dataset_path, batch_size, train_3d, add_context, compute_test_loss)

training SmallSegNet:
epoch 0: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 1: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 2: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 3: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 4: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 5: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 6: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 7: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 8: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 9: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 10: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 11: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 12: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 13: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 14: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 15: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 16: epoch_train_loss=3.419, epoch_test_loss=2.858
epoch 17: epoch_train_loss=3.419, ep

KeyboardInterrupt: 

Plot losses

In [None]:
print(f'Loss curve using {loss_fn}:')
plot_loss(train_losses, test_losses)

Load inference model

In [None]:
weights_path = os.path.join('..','Weights')
weights_filename = 'SmallSegNet_epoch26_loss149.591.h5'
if train_3d:
    if add_context:
        inference_model = UNet3D_Mini(num_modalities=4, num_classes=4).to(device)
    else:
        inference_model = UNet3D(num_modalities=4, num_classes=4).to(device)
else:
    inference_model = UNet2D().to(device)
inference_model = SmallSegNet(in_channels=4, num_classes=4, img_height=96, img_width=96).to(device)
inference_model.load_state_dict(torch.load(os.path.join(weights_path, weights_filename)))
inference_model.eval();

Predict

In [None]:
train_iter, test_iter = get_train_test_iters(dataset_path, batch_size=batch_size, shuffle=True, num_workers=0)
batch = train_iter.next()

In [None]:
%matplotlib inline
minicube_batch = split_cube(batch, add_context) # split cubes into minicubes
sm = nn.Softmax(dim=1)

for minicube_idx in range(1):
    x = minicube_batch['image'][None,minicube_idx,:,:,:,:].to(device)
    voxel_logits = model.forward(x).cpu()
    voxel_probs = sm(voxel_logits)
    voxel_probs[0, 0] -= 0.9
    
    probs, out = torch.max(voxel_probs, dim=1)
plt.imshow(out[0, 79])

In [None]:
%matplotlib notebook
ani = animate_cube(inference_model, batch, add_context, device, train_3d)
plt.show()
#del batch

In [None]:
sm = nn.Softmax(dim=1)
temp = split_cube(batch, False)['image'][None,0,:,:,:,:].to(device)
print(temp.shape)
voxel_logits = model.forward(temp).cpu()
voxel_probs = sm(voxel_logits)
probs, out = torch.max(voxel_probs, dim=1)

print(probs[0, 79, 70:75, 70:75].detach().numpy())

In [None]:
my_input = torch.tensor([[[100, -100, -100, -100],[-100, 100, -100, -100], [-100, -100, 100, -100],[-100, -100, -100,100]]]).float()
print(my_input.shape)
my_label = torch.tensor([[[0, 1, 2, 3]]])
print(my_label.shape)
print(loss_fn(my_input, my_label))

In [None]:
current_class = my_label.clone()
current_class[current_class != 3] = 10
current_class[current_class == 3] = 11
current_class = current_class-10
print(current_class)
print(dice_loss_one_image(torch.tensor([0, 0, 0, 100]), torch.tensor([[[0, 0, 0, 1]]])))

In [None]:
print(torch.sigmoid(torch.tensor([0, 0, 0, 100])))

# Confusion Matrix

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())

In [None]:
#train_iter, test_iter = get_train_test_iters(os.path.join('..', 'Task01_BrainTumour', 'cropped'), batch_size=batch_size, shuffle=False, num_workers=0)

#cf_matrix_visu, cf_matrix = plot_confusion_matrix(test_iter, model, train_3d, add_context, device=device)
#cf_matrix_visu.show()

In [None]:
#get_positives_negatives_from_cm(cf_matrix.to_numpy())