In [7]:
import os
import torch
from pylab import *
import numpy as np
import torch.nn as nn
import SimpleITK as sitk
from numpy import ndarray
from scipy import ndimage
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms

from model import VNet
from data import SegThorDataset

In [8]:
def tensor_to_numpy(tensor):
    """Convert a channels-first torch tensor to a squeezed, channels-last numpy array.

    The tensor is detached from the autograd graph and copied to host memory.
    Axis 1 (treated as the channel axis) is moved to the end, and every
    size-1 axis is then dropped. For an (N, C, D, H, W) tensor this yields
    (N, D, H, W, C) before squeezing, i.e. (D, H, W, C) when N == 1, or
    (D, H, W) when C == 1 as well.

    Args:
        tensor: torch tensor with at least 2 dimensions; axis 1 is the
            channel axis.

    Returns:
        numpy.ndarray with the channel axis last and singleton axes removed.
        Note that squeezing also drops any spatial axis whose size is 1.
    """
    # detach() so tensors that require grad can be converted too; .numpy()
    # raises on them otherwise.
    t_numpy = tensor.detach().cpu().numpy()
    # Same as np.transpose(t, [0, 2, 3, 4, 1]) for 5-D input, but also works
    # for 4-D (N, C, H, W) and other ranks.
    t_numpy = np.moveaxis(t_numpy, 1, -1)
    return np.squeeze(t_numpy)

In [9]:
def _threshold_to_labels(scores, threshold=0.5):
    """Collapse a channels-last score volume into a single integer label volume.

    A voxel takes label ``j`` when its channel-``j`` score is strictly greater
    than ``threshold``. If several channels pass, the highest channel index
    wins, because channels are visited in ascending order. Voxels where no
    channel passes stay 0.

    Args:
        scores: array of shape (X, Y, Z, C), channels last.
        threshold: a score must be strictly greater than this to count.

    Returns:
        (labels, counts):
            labels: uint8 array of shape (X, Y, Z).
            counts: float array of length C. ``counts[j]`` is the number of
                voxels whose channel-``j`` score passed the threshold, counted
                before later channels overwrite the label.
    """
    n_channels = scores.shape[-1]
    labels = np.zeros(scores.shape[:-1], dtype=np.uint8)
    counts = np.zeros(n_channels)
    for channel in range(n_channels):
        mask = scores[..., channel] > threshold
        labels[mask] = channel
        counts[channel] = np.count_nonzero(mask)
    return labels, counts


def _resize_labels(labels, target_shape):
    """Nearest-neighbour resample a label volume to exactly ``target_shape``.

    Returns:
        (resized, factor):
            resized: float32 array of shape ``target_shape``.
            factor: the per-axis zoom factor that was applied.
    """
    factor = np.divide(target_shape, labels.shape)
    # order=0 copies existing label values; no interpolation between class ids.
    resized = ndimage.zoom(labels.astype(np.float32), factor, order=0, mode='nearest')
    assert resized.shape == tuple(int(s) for s in target_shape), (resized.shape, target_shape)
    return resized, factor


def test(data_dir="data", model_path="models/model.pt", vol_size=(128, 128, 128),
         output_template="file_{}.nii.gz", threshold=0.5):
    """Run the saved model over the SegThor test split and write one label volume per case.

    For each test volume:
      1. The network scores are thresholded into a label volume.
      2. The labels are resampled back to the original volume size with
         nearest-neighbour interpolation.
      3. The result is written as a UInt8 image named
         ``output_template.format(batch_idx)``.

    Args:
        data_dir: dataset root passed to ``SegThorDataset``.
        model_path: path of the pickled model object loaded with ``torch.load``.
        vol_size: volume size passed to ``SegThorDataset`` for the network input.
        output_template: ``str.format`` pattern for output file names. ``{}``
            receives the batch index, so volumes do not overwrite each other.
        threshold: a channel score must be strictly greater than this for the
            voxel to take that label.
    """
    test_set = SegThorDataset(data_dir, phase='test', vol_size=list(vol_size))
    # batch_size must stay 1: tensor_to_numpy squeezes the batch axis, and the
    # size/filename handling below assumes one volume per batch.
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # files. Recent PyTorch releases default to weights_only=True, which rejects a
    # whole pickled module. There, pass weights_only=False or switch to a state_dict.
    model = torch.load(model_path, map_location=device)
    model.eval()
    model.to(device)

    with torch.no_grad():
        for batch_idx, (images, _, size) in enumerate(test_loader):
            images = images.to(device, dtype=torch.float)
            scores = tensor_to_numpy(model(images))

            original_size = size.numpy().reshape(-1)
            print("=========================================================================")
            print("size of input volume: ", original_size)

            # active_vol[j]: voxels of THIS volume whose channel-j score passed
            # the threshold.
            labels, active_vol = _threshold_to_labels(scores, threshold)
            print("Active volume =", active_vol)

            resize_vol, factor = _resize_labels(labels, original_size)
            print("Resize factor = ", factor)

            filename = output_template.format(batch_idx)
            # GetImageFromArray reads the array as (z, y, x), so GetSize()
            # reports (x, y, z).
            # NOTE(review): spacing/origin/direction of the source scan are not
            # copied, so the file carries SimpleITK defaults. Confirm whether
            # downstream evaluation needs the original geometry.
            predicted_volume = sitk.GetImageFromArray(resize_vol, isVector=False)
            print("Size of segmented volume: ", predicted_volume.GetSize())
            sitk.WriteImage(sitk.Cast(predicted_volume, sitk.sitkUInt8), filename, True)


In [10]:
if __name__ == "__main__":
    # A notebook kernel always runs with __name__ == "__main__", so executing
    # this cell launches inference over the whole test split.
    test()

100%|██████████| 7/7 [00:05<00:00,  1.20it/s]


size of input volume:  [206 512 512]
Active volume = [0. 0. 0. 0. 0.]
Resize factor =  [1.609375 4.       4.      ]
Size of segmented volume:  (512, 512, 206)
size of input volume:  [213 512 512]
Active volume = [0. 0. 0. 0. 0.]
Resize factor =  [1.6640625 4.        4.       ]
Size of segmented volume:  (512, 512, 213)
size of input volume:  [150 512 512]
Active volume = [0. 0. 0. 0. 0.]
Resize factor =  [1.171875 4.       4.      ]
Size of segmented volume:  (512, 512, 150)
size of input volume:  [176 512 512]
Active volume = [0. 0. 0. 0. 0.]
Resize factor =  [1.375 4.    4.   ]
Size of segmented volume:  (512, 512, 176)
size of input volume:  [166 512 512]


KeyboardInterrupt: 