Simple Segmentation Net — 3D U-Net brain-tumour segmentation on cropped BraTS (Task01_BrainTumour) volumes

In [1]:
import torch
import os
import torch.nn as nn
import matplotlib.pyplot as plt 
%matplotlib inline
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from data_loading import BraTS_Dataset
from dataset_utils import plot_batch, crop_batch, decrop_batch, split_cube, slice_cube
from data_loading import get_train_test_iters
torch.manual_seed(42)
from Architectures.unet_3d import UNet3D
from Architectures.unet_2d import UNet2D
from train import train_model
from custom_losses import get_loss, DiceLoss

Data Loading

In [2]:
# Build train/test iterators over the pre-cropped BraTS data, one sample per batch.
train_iter, test_iter = get_train_test_iters(
    '../Task01_BrainTumour/cropped',
    batch_size=1,
    shuffle=True,
    num_workers=0,
)

../Task01_BrainTumour/cropped\imagesTr
../Task01_BrainTumour/cropped\labelsTr
../Task01_BrainTumour/cropped\imagesTs
../Task01_BrainTumour/cropped\labelsTs


Create the model

In [3]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(device)
if use_cuda:
    print(torch.cuda.get_device_name(0))

cuda
NVIDIA GeForce RTX 3060 Laptop GPU


In [6]:
# 3D U-Net taking 4 input modalities and predicting 4 classes (img_height = img_width = 96).
model = UNet3D(
    num_modalities=4,
    num_classes=4,
    img_height=96,
    img_width=96,
).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = DiceLoss()

In [7]:
# NOTE(review): the recorded run of this cell failed inside the train_model call with
# AttributeError: 'Tensor' object has no attribute 'detatch' — presumably a typo for
# `detach` somewhere in train_model's call path (train.py / custom_losses.py); fix it there.
# The positional 2 and True are passed straight through — TODO confirm their meaning in train_model.
losses = train_model(
    model, optim, loss, 2, device, True, train_iter,
    steps_per_epoch=400,
)

AttributeError: 'Tensor' object has no attribute 'detatch'

Test & debug model

In [None]:
# Sample a minicube batch: split one volume from the training iterator into minicubes
# (8 minicubes per image). Use the built-in next(); the Python-2 style .next() alias
# may not exist on PyTorch DataLoader iterators.
minicube_batch = split_cube(next(train_iter))

In [None]:
# Smoke-test a forward pass on the first minicube. The input is moved to `device`
# (where the model lives), and the module is called directly rather than via
# .forward so any registered hooks run.
voxel_logits_batch = model(minicube_batch['image'][:1].to(device))
voxel_logits_batch.shape

Training loop utils

In [None]:
# Voxel-wise multi-class loss used by get_minicube_batch_loss and the Predict cell.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
def get_minicube_batch_loss(minicube_batch, step, device, net=None, loss_fn=None):
    """Compute the loss on the quarter of a minicube batch selected by ``step``.

    The batch (assumed to hold 8 minicubes, as produced by ``split_cube``) is consumed
    two minicubes at a time over four consecutive steps: ``step % 4 == k`` selects
    minicubes ``2k`` and ``2k + 1``, except the last group (k == 3), which takes every
    minicube from index 6 onward.

    Args:
        minicube_batch: dict with ``'image'`` and ``'label'`` tensors indexed by minicube
            along dim 0.
        step: training step counter; only ``step % 4`` is used.
        device: device the images and labels are moved to before the forward pass.
        net: model to evaluate; defaults to the notebook-global ``model``.
        loss_fn: loss taking ``(logits, long_labels)``; defaults to the notebook-global
            ``criterion``.

    Returns:
        Whatever ``loss_fn`` returns for the selected slice (a scalar loss tensor for
        ``nn.CrossEntropyLoss``).
    """
    # Fall back to the notebook globals lazily, so callers passing both never touch them.
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn

    group = step % 4
    start = 2 * group
    # The last group keeps everything from index 6 on, matching the original [6:] slice.
    stop = start + 2 if group < 3 else None

    images = minicube_batch['image'][start:stop].to(device)
    labels = minicube_batch['label'][start:stop].long().to(device)
    # Call the module (not .forward) so registered hooks run.
    voxel_logits_batch = net(images)
    return loss_fn(voxel_logits_batch, labels)

Training Loop

In [None]:
# Manual training loop over minicube batches. Relies on `model`, `optim` (created with
# the model), `train_iter`, `split_cube` and `get_minicube_batch_loss` from earlier cells.

# training settings
epochs = 4
steps_per_epoch = 400
checkpoint_dir = '../Weights'
losses = []

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(checkpoint_dir, exist_ok=True)
model.train()

# training loop
for epoch in range(epochs):
    for step in range(steps_per_epoch):
        # get_minicube_batch_loss picks a different part of the batch for each value of
        # step % 4, so a fresh volume is split every 4th step.
        if step % 4 == 0:
            # Built-in next(): the .next() alias may not exist on DataLoader iterators.
            minicube_batch = split_cube(next(train_iter))

        # Named batch_loss so the `loss` (DiceLoss) global from the model cell is not clobbered.
        batch_loss = get_minicube_batch_loss(minicube_batch, step, device=device)
        # Store a detached copy; keeping the live tensor keeps its autograd history reachable.
        losses.append(batch_loss.detach())

        if step % 20 == 0:
            loss_value = batch_loss.item()
            print(f'epoch {epoch}: step {step:3d}: loss={loss_value:3.3f}')
            # The file is a torch.save'd state_dict; the .h5 suffix is kept only so
            # checkpoint names match what the weight-loading cell expects.
            path = f'{checkpoint_dir}/weights_epoch{epoch}_step{step}_loss{loss_value:3.3f}.h5'
            torch.save(model.state_dict(), path)

        # backprop loss
        optim.zero_grad()
        batch_loss.backward()
        optim.step()

In [None]:
# Entries may be scalar tensors (possibly on the GPU) or plain numbers; float() handles both.
losses_array = [float(value) for value in losses]
fig, ax = plt.subplots()
ax.plot(losses_array)
ax.set(title='Training loss', xlabel='step', ylabel='loss');

In [None]:
# load weights
weights_filename = 'weights_epoch3_step300_loss0.026.h5'
inference_model = SmallBTSegNet(num_modalities=4, num_classes=4, img_height=96, img_width=96).to(device)
inference_model.load_state_dict(torch.load('../Weights/' + weights_filename))
_ = inference_model.eval()

Predict on a single minicube and compare with its label

In [None]:
# predict batch
voxel_logits_batch = inference_model.forward(minicube_batch['image'][None,4,:,:,:,:])

# get loss
loss = criterion(voxel_logits_batch, minicube_batch['label'][None,4,:,:,:].long().to(device))
print(f'loss: {loss.item():3.3f}')

sm = nn.Softmax(dim=1)
voxel_probs_batch = sm(voxel_logits_batch)
print(voxel_probs_batch.shape)

probs, out = torch.max(voxel_probs_batch, dim=1)

In [None]:
# Predicted class map: index 5 along dim 1 of `out` for the single predicted minicube.
fig, ax = plt.subplots()
ax.imshow(out[0, 5].cpu())
ax.set_title('prediction');

In [None]:
# Ground-truth labels of minicube 4 (the one fed to the Predict cell), index 5 along
# dim 0 of that label minicube.
fig, ax = plt.subplots()
ax.imshow(minicube_batch['label'][4, 5].cpu())
ax.set_title('label');