In [9]:
import math
import numpy as np
import torch
import pkbar
import sys
from tqdm import tqdm
from torchsummary import summary
from torch.nn.functional import one_hot
from unet3d.config import *
from unet3d.tester import Tester
from unet3d.dice import dice_coef_torch_multiclass
from unet3d.transforms import *
from unet3d.unet3d_vgg16 import UNet3D_VGG16
from unet3d.transforms import val_transform
from utils.Visualization import ImageSliceViewer3D
import nrrd
from patchify import patchify, unpatchify
import os 

%load_ext autoreload
%autoreload 2
%matplotlib inline

torch.manual_seed(0)
torch.backends.cudnn.benchmark = True # Speeds up stuff
torch.backends.cudnn.enabled = True
device = torch.device('cuda')

model_path = 'checkpoints/saiad1and13_epoch149_23sep.pth'
model_path = os.path.realpath(model_path)
print(model_path)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/Work/Users/acharneca/Vessel-Segmentation-pytorch/checkpoints/saiad1and13_epoch149_23sep.pth


In [10]:
# Build a Tester for patient 'SAIAD 1' from the checkpoint above, then call
# read_test_patient_data_and_pad() and display the resulting scan shape.
tester = Tester(model_path, test_patient='SAIAD 1')

tester.read_test_patient_data_and_pad()
tester.scan.shape

(528, 528, 144)

In [11]:
# Split the tester's scan into a grid of patches; the displayed shape is
# (grid_x, grid_y, grid_z, patch_x, patch_y, patch_z).
scan_patchified = tester.patchify_scan()
scan_patchified.shape

(10, 10, 2, 96, 96, 96)

In [12]:
# Take the patch at grid index (5, 5, 1) and add a leading size-1 axis
# (presumably the channel axis expected by test_transform — confirm).
patch = np.array([scan_patchified[5, 5, 1]])

# Feed the same array under every key test_transform reads, so each
# (presumably augmentation-specific) output can be inspected below.
patch_transform = test_transform({
    key: patch
    for key in ('patch_scan', 'patch_scan_flipped', 'patch_scan_noise', 'patch_scan_contrast')
})

In [13]:
# View the 'patch_scan' output together with the 'patch_scan_contrast' output.
ImageSliceViewer3D(
    patch_transform['patch_scan'][0],
    patch_transform['patch_scan_contrast'][0],
)

interactive(children=(RadioButtons(description='Slice plane selection:', options=('x-y', 'y-z', 'z-x'), style=…

<utils.Visualization.ImageSliceViewer3D at 0x1468d78143a0>

In [14]:
# Full-scan inference through the Tester. with_transforms=True presumably enables
# test-time transforms — confirm in unet3d.tester. The result is a dict (indexed by key below).
temp = tester.predict(
    with_transforms=True,
    verbose=1,
)

Patchifying scan...
(10, 10, 2, 5, 96, 96, 96)
Predicting on patches...
(10, 10, 2, 96, 96, 96, 5)
Unpatchifying...
(528, 528, 144, 5)
(528, 528, 144)
Done.


In [18]:
# Pull one (96, 96, 96) volume out of the per-patch predictions and view it next to the
# sample patch.
# NOTE(review): the meaning of index [1, 3] is not documented here. A previous revision
# indexed [5, 5, 1, :, :, :, 0]. Confirm this volume really corresponds to the patch
# shown on the left.
vol = temp['single_patch_predictions'].detach().cpu().numpy()[1, 3]
print(vol.shape)
ImageSliceViewer3D(patch_transform['patch_scan'][0], vol)

(96, 96, 96)


interactive(children=(RadioButtons(description='Slice plane selection:', options=('x-y', 'y-z', 'z-x'), style=…

<utils.Visualization.ImageSliceViewer3D at 0x1468d705fac0>

In [19]:
# Interactive slice view of the tester's ground truth vs. prediction (per the method name).
tester.show_truth_vs_pred()

interactive(children=(RadioButtons(description='Slice plane selection:', options=('x-y', 'y-z', 'z-x'), style=…

In [3]:
# Load a checkpoint directly into a fresh UNet3D_VGG16 for the manual patch-wise test below.
# NOTE: this rebinds `model_path`, which the first cell set to a different checkpoint.
model_path = 'checkpoints/test_epoch149_19sep.pth'
model = UNet3D_VGG16(
    in_channels=IN_CHANNELS,
    num_classes=NUM_CLASSES,
    use_softmax_end=True,
).to(device)
# map_location remaps stored tensors onto `device`. Without it, loading fails if the
# checkpoint was saved on a device that is not available here.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

UNet3D_VGG16(
  (encoder_block1): Conv3DBlock_2conv(
    (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (relu): ReLU()
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (pooling): MaxPool3d(kernel_size=(2, 2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encoder_block2): Conv3DBlock_2conv(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (relu): ReLU()
    (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (pooling): MaxPool3d(kernel_size=(2, 2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encoder_block3): Conv3DBlock_3conv(
    (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (relu): ReLU()
    (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
    (pooling): MaxPool3d(kernel_size=(2, 2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  

In [4]:
# Test on 1 scan
# Load one patient's scan and segmentation, zero-pad x/y, and cut both into
# non-overlapping PATCH_SIZE patches (step == PATCH_SIZE).
manual_test_patient = 'SAIAD 15'
# For this patient the padded volume is 576 x 576, i.e. 512 + 2*32, which divides evenly by 96.
# NOTE(review): other patients may need different padding — confirm.
xy_pad = 32
pad_width = ((xy_pad, xy_pad), (xy_pad, xy_pad), (0, 0))

scan, _ = nrrd.read(os.path.join(DATASET_PATH, manual_test_patient, 'scan.nrrd'))
segm, _ = nrrd.read(os.path.join(DATASET_PATH, manual_test_patient, 'segm.nrrd'))
scan = np.pad(scan, pad_width, constant_values=0)
segm = np.pad(segm, pad_width, constant_values=0)

# Flatten the patch grid to (N, *PATCH_SIZE), then add a channel axis -> (N, 1, *PATCH_SIZE).
scan_patches = patchify(scan, PATCH_SIZE, step=PATCH_SIZE).reshape(-1, *PATCH_SIZE)
scan_patches = torch.unsqueeze(torch.tensor(scan_patches).float(), 1)

# one_hot appends the class axis last; permute moves it to axis 1 -> (N, NUM_CLASSES, *PATCH_SIZE).
segm_patches = patchify(segm, PATCH_SIZE, step=PATCH_SIZE).reshape(-1, *PATCH_SIZE)
segm_patches = one_hot(
    torch.tensor(segm_patches).to(torch.int64), num_classes=NUM_CLASSES
).permute(0, 4, 1, 2, 3).float()

print([scan_patches.shape, scan_patches.dtype])
print([segm_patches.shape, segm_patches.dtype])


[torch.Size([72, 1, 96, 96, 96]), torch.float32]
[torch.Size([72, 5, 96, 96, 96]), torch.float32]


In [5]:
## Test ##
# Run the model over every patch in mini-batches of TEST_BATCH_SIZE and collect the
# per-class outputs in pred_patches with shape (N, NUM_CLASSES, *PATCH_SIZE).
num_patches = scan_patches.shape[0]
pred_patches = np.zeros((num_patches, NUM_CLASSES, *PATCH_SIZE))

with torch.no_grad():
    for start in tqdm(range(0, num_patches, TEST_BATCH_SIZE)):
        # The last batch is shorter when num_patches is not a multiple of TEST_BATCH_SIZE.
        # Clamping avoids indexing past the end of scan_patches.
        stop = min(start + TEST_BATCH_SIZE, num_patches)
        batch = np.zeros((stop - start, 1, *PATCH_SIZE))

        for j, idx in enumerate(range(start, stop)):
            # A separate name avoids clobbering the notebook-level `patch` from an earlier cell.
            sample = {'name': 'patches', 'patch_scan': torch.Tensor(scan_patches[idx])}
            batch[j] = val_transform(sample)['patch_scan'].float()

        pred = model(torch.Tensor(batch).to(device))
        pred_patches[start:stop] = pred.cpu().numpy()

100%|█████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:14<00:00,  1.58s/it]


In [6]:
# Unpatchify predictions
# Collapse per-class scores to one label per voxel, then lay the patches back out on the
# grid that patchify produced from `scan`. With step == PATCH_SIZE, each axis holds
# (dim - size) // size + 1 patches.
patch_grid = tuple((dim - size) // size + 1 for dim, size in zip(scan.shape, PATCH_SIZE))

pred_patches_reshape = np.asarray(pred_patches).reshape(scan_patches.shape[0], NUM_CLASSES, *PATCH_SIZE)
pred_patches_reshape = np.argmax(pred_patches_reshape, axis=1)  # class axis -> label per voxel
print(pred_patches_reshape.shape)

assert pred_patches_reshape.shape[0] == int(np.prod(patch_grid)), 'patch count does not match the patch grid'
pred_patches_reshape = pred_patches_reshape.reshape(*patch_grid, *PATCH_SIZE)

pred_unpatchified = unpatchify(pred_patches_reshape, segm.shape)
print(pred_unpatchified.shape)


(72, 96, 96, 96)
(576, 576, 192)


In [7]:
# Padded ground-truth segmentation vs. the unpatchified predicted label volume.
ImageSliceViewer3D(
    segm,
    pred_unpatchified,
)

interactive(children=(RadioButtons(description='Slice plane selection:', options=('x-y', 'y-z', 'z-x'), style=…

<utils.Visualization.ImageSliceViewer3D at 0x1465277e9670>

In [9]:
from unet3d.dice import *
# Per-class Dice between the ground truth and the predicted labels. Both are label volumes
# (not one-hot), hence one_hot_encoded=False.
# NOTE(review): both volumes still include the zero-padded x/y border, so those voxels enter
# the class-0 score. Consider cropping the padding before scoring.
dice_coef_torch_multiclass(torch.Tensor(segm), torch.Tensor(pred_unpatchified), NUM_CLASSES, one_hot_encoded=False)


tensor([0.9708, 0.0099, 0.0897, 0.4562, 0.1189])