In [None]:
%load_ext autoreload
%autoreload 2

import os.path

import gc
gc.collect()

import torch

%load_ext tensorboard

# Release cached GPU memory and reset the peak-memory counters, but only when
# a CUDA device is actually present: this notebook also supports a CPU
# fallback, and the CUDA statistics call needs an initialised CUDA runtime.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# force=True keeps this cell re-runnable: without it, a second execution raises
# RuntimeError because the start method has already been set for the process.
torch.multiprocessing.set_start_method('spawn', force=True)

# Allow non-deterministic cuDNN kernels and enable the cuDNN autotuner.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import numpy as np
import nibabel as nib
from network.net import UNet3D
from scipy.ndimage import distance_transform_edt

import sys
# Fail fast if the kernel is not a Python 3 interpreter.
assert sys.version_info.major == 3, 'Not running on Python 3'

from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the
# last one.
InteractiveShell.ast_node_interactivity = "all"

import logging
# Route log records at INFO level and above to stdout.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [None]:
# --- Input / output locations -------------------------------------------
# NOTE(review): img_path is an absolute, machine-specific path; consider
# making it relative to a configurable data directory.
img_path = "/home/imag2/IMAG2_DL/KDCompression/Dataset/KING.nii.gz"
model_path = "trained_models/Vessels/KD.pt"
out_path = "test.nii.gz"

# --- Sliding-window geometry --------------------------------------------
# Patch extent along each of the three image axes, and half of it (used to
# turn a patch centre into slice bounds).
patch_size = np.array([96, 48, 96])
half_size = patch_size // 2

# Cut-off applied to the merged sigmoid output before the mask is saved.
threshold = 0.2

# --- Compute device -----------------------------------------------------
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))

In [None]:
model = UNet3D()

# map_location remaps the checkpoint's tensors onto the device selected above,
# so a checkpoint written from a GPU session still loads on the CPU fallback.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

model.to(device, non_blocking=True)
# Inference only: switch the module to evaluation mode.
model.eval()

In [None]:
def preprocess_img(img, patch_size):
    """Z-score normalise a volume and zero-pad it to at least ``patch_size``.

    Parameters
    ----------
    img : array_like
        Image volume. It is never modified in place.
    patch_size : array_like of int
        Minimum extent required along each axis; must have one entry per
        axis of ``img``.

    Returns
    -------
    numpy.ndarray
        ``(img - mean) / std`` when the standard deviation is positive,
        otherwise ``img * 0``. Every axis shorter than the corresponding
        ``patch_size`` entry is padded with zeros, split between both sides
        (the extra voxel of an odd amount goes at the end).
    """
    img = np.asarray(img)
    patch_size = np.asarray(patch_size)

    mean = np.mean(img)
    std = np.std(img)
    if std > 0:
        img = (img - mean) / std
    else:
        # Out-of-place on purpose: an in-place ``*=`` would overwrite the
        # caller's array and fails outright for integer-typed volumes.
        img = img * 0.

    # Missing voxels per axis; zero where the image is already large enough.
    to_pad = np.maximum(patch_size - np.array(img.shape), 0).astype(int)
    if np.any(to_pad > 0):
        pad_width = [(p // 2, p - p // 2) for p in to_pad.tolist()]
        img = np.pad(img, pad_width, mode='constant', constant_values=0)

    return img

def load_img(f_name, patch_size, is_mask=False):
    """Read a NIfTI file and return the preprocessed volume with its affine.

    The image is reoriented with ``nib.as_closest_canonical``, its data is
    read as float32 and passed through ``preprocess_img``.

    Parameters
    ----------
    f_name : path of the image file handed to ``nib.load``.
    patch_size : forwarded unchanged to ``preprocess_img``.
    is_mask : accepted but not used by this function.
        NOTE(review): masks are therefore normalised like images; confirm
        whether that is intended.

    Returns
    -------
    tuple
        ``(preprocessed_volume, affine)`` where ``affine`` belongs to the
        reoriented image.
    """
    canonical = nib.as_closest_canonical(nib.load(f_name))
    voxels = canonical.get_fdata().astype(np.float32)
    return preprocess_img(voxels, patch_size), canonical.affine


def save_seg(f_name, segmentation, affine, dtype=np.uint8):
    """Write ``segmentation`` to ``f_name`` as a NIfTI-1 image.

    Parameters
    ----------
    f_name : destination path passed to ``nib.save``.
    segmentation : array to store; cast to ``dtype`` before writing.
    affine : affine matrix attached to the stored image.
    dtype : on-disk voxel type (``np.uint8`` by default).
    """
    voxels = segmentation.astype(dtype)
    nib.save(nib.Nifti1Image(voxels, affine), f_name)


def get_grid(img, patch_size):
    """Return the patch-centre coordinates along each axis.

    For every axis the centres are evenly spaced from ``patch_size // 2`` to
    ``img.shape[axis] - patch_size // 2`` (both inclusive), using
    ``2 * img.shape[axis] // patch_size[axis] + 1`` points.

    Parameters
    ----------
    img : numpy.ndarray
        Volume to be tiled; only its ``shape`` is read.
    patch_size : array_like of int
        Patch extent per axis.

    Returns
    -------
    tuple of numpy.ndarray
        One float array of centre coordinates per entry of ``patch_size``
        (``x, y, z`` for a 3-D patch).
    """
    patch_size = np.asarray(patch_size)
    # Derive the half extent from the argument instead of relying on a
    # notebook-level global, so the result always matches ``patch_size``.
    half = patch_size // 2
    return tuple(
        np.linspace(half[axis],
                    img.shape[axis] - half[axis],
                    2 * img.shape[axis] // patch_size[axis] + 1)
        for axis in range(len(patch_size))
    )

def _get_merge_fn(size, fn='flat', t=0.2):
    if fn == 'flat':
        return np.ones(size)
    elif fn == 'dt':
        _merge_fn = np.zeros(size)
        _merge_fn = np.sqrt(((np.argwhere(_merge_fn == 0) - size // 2) ** 2).sum(axis=1)).reshape(size)
        return np.maximum(1 - _merge_fn / np.amax(_merge_fn), t)
    elif fn == 'borders':
        _merge_fn = np.ones(size)
        _merge_fn = np.pad(_merge_fn, ((1, 1), (1, 1), (1, 1)))
        _merge_fn = distance_transform_edt(_merge_fn)
        _merge_fn = _merge_fn[1:-1, 1:-1, 1:-1]
        return np.maximum(_merge_fn / np.amax(_merge_fn), t)

In [None]:
# Sliding-window inference: predict every patch on the grid and blend the
# overlapping predictions as a running weighted average.
img, img_affine = load_img(img_path, patch_size)

grid = get_grid(img, patch_size)
merge_fn = _get_merge_fn(patch_size, fn='borders', t=0.)

x, y, z = grid
# seg holds the current weighted mean, weights the accumulated window weight.
# NOTE(review): if load_img padded the volume, seg has the padded shape while
# img_affine still describes the image as stored on disk - confirm whether the
# result should be cropped back before saving.
seg = np.zeros([img.shape[0], img.shape[1], img.shape[2]])
weights = np.zeros([img.shape[0], img.shape[1], img.shape[2]])

# Inference only: no_grad stops autograd from recording a graph for every
# patch, which would otherwise cost memory for nothing.
with torch.no_grad():
    # Patch centres truncated to integer voxel indices. int64 rather than int16
    # so large volumes cannot overflow the index type.
    for ind in np.array(np.meshgrid(x, y, z)).T.reshape(-1, 3).astype(np.int64):
        x_min, y_min, z_min = ind - half_size
        x_max, y_max, z_max = ind + half_size

        current_patch = img[x_min:x_max, y_min:y_max, z_min:z_max]
        current_seg = seg[x_min:x_max, y_min:y_max, z_min:z_max]
        current_weights = weights[x_min:x_max, y_min:y_max, z_min:z_max]

        # Move the last image axis to the front and add two leading singleton
        # axes before handing the patch to the network.
        current_patch = np.expand_dims(current_patch.transpose((2, 0, 1)), axis=0)
        current_patch = torch.from_numpy(np.expand_dims(current_patch, axis=0))

        # Sigmoid of the network output, with the axis permutation undone.
        prediction = np.squeeze(
            torch.sigmoid(model(current_patch.to(device, non_blocking=True))).cpu().numpy()
        ).transpose((1, 2, 0))
        # Running weighted mean of what is already stored and the new patch.
        prediction = (current_seg * current_weights + prediction * merge_fn) / (current_weights + merge_fn)

        seg[x_min:x_max, y_min:y_max, z_min:z_max] = prediction
        weights[x_min:x_max, y_min:y_max, z_min:z_max] = current_weights + merge_fn

save_seg(out_path, seg > threshold, img_affine)