In [1]:
import math
import numpy as np
import torch
import pkbar
import sys
from tqdm import tqdm
from torchsummary import summary
from torch.nn.functional import one_hot
from unet3d.config import *
from unet3d.unet3d_vgg16 import UNet3D_VGG16
from utils.Visualization import ImageSliceViewer3D
import nrrd
from patchify import patchify, unpatchify
from unet3d.transforms import val_transform
# Reload edited project modules automatically so changes to imported .py files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Runtime configuration for the inference run.
torch.manual_seed(0)  # fixed seed so any stochastic op is repeatable
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True  # let cuDNN auto-tune kernels (speeds up fixed-size inputs)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [8]:
# Build the UNet3D_VGG16 network and restore the trained weights for inference.
model_path = 'checkpoints/test_epoch149_17sep.pth'
model = UNet3D_VGG16(
    in_channels=IN_CHANNELS,
    num_classes=NUM_CLASSES,
    use_softmax_end=True,
).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different GPU index (or on a GPU machine) still loads correctly here.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
# Switch to evaluation mode; as the last expression this also displays the
# module structure in the cell output.
model.eval()

UNet3D_VGG16(
  (encoder_block1): Conv3DBlock_2conv(
    (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (relu): ReLU()
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (pooling): MaxPool3d(kernel_size=(2, 2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encoder_block2): Conv3DBlock_2conv(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (relu): ReLU()
    (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (pooling): MaxPool3d(kernel_size=(2, 2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encoder_block3): Conv3DBlock_3conv(
    (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (relu): ReLU()
    (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (pooling): MaxPool3d(kernel_size=(2, 2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  

In [9]:
# Test on 1 scan: load the volume and its ground-truth segmentation, pad both
# to a whole number of patches, and cut them into non-overlapping patches.
case_dir = DATASET_PATH + '/SAIAD 15'
scan, _ = nrrd.read(case_dir + '/scan.nrrd')
segm, _ = nrrd.read(case_dir + '/segm.nrrd')


def _pad_to_patch_multiple(volume, patch_size=PATCH_SIZE):
    """Zero-pad `volume` so every axis is a multiple of the patch extent.

    The padding is split as evenly as possible between both sides of each
    axis (the extra voxel, if any, goes at the end). Axes that are already a
    multiple of the patch extent are left untouched.
    """
    pad_width = []
    for extent, patch_extent in zip(volume.shape, patch_size):
        deficit = -extent % patch_extent
        before = deficit // 2
        pad_width.append((before, deficit - before))
    return np.pad(volume, pad_width, constant_values=0)


def _to_patches(volume, patch_size=PATCH_SIZE):
    """Split `volume` into non-overlapping patches, flattened to (N, *patch_size)."""
    return patchify(volume, patch_size, step=patch_size).reshape(-1, *patch_size)


# Padding is derived from PATCH_SIZE instead of being hard-coded, so scans
# with other dimensions are handled too (patchify silently drops remainders).
scan = _pad_to_patch_multiple(scan)
segm = _pad_to_patch_multiple(segm)

scan_patches = torch.tensor(_to_patches(scan)).float()
scan_patches = torch.unsqueeze(scan_patches, 1)  # add channel dimension

# One-hot encode the labels and move the class axis next to the batch axis.
segm_patches = torch.tensor(_to_patches(segm)).to(torch.int64)
segm_patches = one_hot(segm_patches, num_classes=NUM_CLASSES).permute(0, 4, 1, 2, 3).float()


print([scan_patches.shape, scan_patches.dtype])
print([segm_patches.shape, segm_patches.dtype])


[torch.Size([72, 1, 96, 96, 96]), torch.float32]
[torch.Size([72, 5, 96, 96, 96]), torch.float32]


In [24]:
## Test ##
# Run the model over every scan patch in mini-batches and collect the
# per-class outputs. float32 buffers match the model's dtype and halve the
# memory of the default float64 arrays.
num_patches = scan_patches.shape[0]
pred_patches = np.zeros((num_patches, NUM_CLASSES, *PATCH_SIZE), dtype=np.float32)

with torch.no_grad():
    for i in tqdm(range(0, num_patches, TEST_BATCH_SIZE)):
        # Clamp the last batch so a patch count that is not a multiple of
        # TEST_BATCH_SIZE does not index past the end of scan_patches.
        stop = min(i + TEST_BATCH_SIZE, num_patches)
        batch = np.zeros((stop - i, 1, *PATCH_SIZE), dtype=np.float32)

        for j in range(stop - i):
            # val_transform is applied per patch on a dict sample.
            patch = {'name': 'patches', 'patch_scan': torch.Tensor(scan_patches[i + j])}
            batch[j] = val_transform(patch)['patch_scan'].float()

        # Send the batch to the same device the model was moved to.
        pred = model(torch.from_numpy(batch).to(device))
        pred_patches[i:stop] = pred.detach().cpu().numpy()

100%|█████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:09<00:00,  1.01s/it]


In [25]:
# Unpatchify predictions
# Collapse the per-class outputs to a label map per patch (argmax over the
# class axis). np.asarray avoids copying the large prediction buffer.
pred_patches_reshape = np.asarray(pred_patches).reshape(scan_patches.shape[0], NUM_CLASSES, *PATCH_SIZE)
pred_patches_reshape = np.argmax(pred_patches_reshape, axis=1)
print(pred_patches_reshape.shape)

# Number of patches along each axis, derived from the padded volume instead
# of being hard-coded, so other scan sizes / patch sizes work unchanged.
patch_grid = tuple(extent // patch_extent for extent, patch_extent in zip(segm.shape, PATCH_SIZE))
pred_patches_reshape = pred_patches_reshape.reshape(*patch_grid, *PATCH_SIZE)

# Stitch the patch grid back into a single volume with the padded scan's shape.
pred_unpatchified = unpatchify(pred_patches_reshape, segm.shape)
print(pred_unpatchified.shape)


(72, 96, 96, 96)
(576, 576, 192)


In [29]:
# Interactive slice viewer for comparing the padded ground-truth segmentation
# with the stitched prediction (both have segm.shape).
# NOTE(review): assumes the viewer takes two volumes of identical shape — confirm against utils.Visualization.
ImageSliceViewer3D(segm,pred_unpatchified)

interactive(children=(RadioButtons(description='Slice plane selection:', options=('x-y', 'y-z', 'z-x'), style=…

<utils.Visualization.ImageSliceViewer3D at 0x14916b021f70>

In [7]:
# Per-class Dice score between a saved ground-truth segmentation and a saved
# prediction for the same case.
# NOTE(review): star import kept because the exported names cannot be verified
# from here; prefer an explicit import in the top import cell.
from unet3d.dice import *

pred_dir = '../SAIAD-project/Data/Predicted Segms OriginSpacing/SAIAD 15'
truth_segm, _ = nrrd.read(pred_dir + '/truth_segm.nrrd')
pred_segm, _ = nrrd.read(pred_dir + '/pred_segm.nrrd')

# Use the configured class count rather than a magic number.
# NOTE(review): third argument is presumably the number of classes — confirm in unet3d.dice.
dice_coef_torch_multiclass(torch.Tensor(truth_segm), torch.Tensor(pred_segm), NUM_CLASSES)


array([0.98926663, 0.8680606 , 0.645984  , 0.7853374 , 0.85968685],
      dtype=float32)