In [1]:
import os

import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
from scipy import ndimage
from skimage.measure import label, regionprops
from skimage.morphology import disk, remove_small_objects
from tqdm import tqdm

from dataset.fracnet_dataset import FracNetInferenceDataset
from dataset import transforms as tsfm
from model.unet import UNet



In [2]:
# Input volume and preprocessing. os.path.join keeps the path portable; a
# hard-coded backslash-separated literal only resolves on Windows.
image_path = os.path.join(".", "data", "RibFrac501-image.nii.gz")

# NOTE(review): the MinMaxNorm upper bound (500) differs from the Window upper
# bound (1000), unlike the later cells that use MinMaxNorm(-200, 1000).
# The `pred.astype('int8')` cast further down truncates the normalized values,
# so this bound affects which voxels end up non-zero there — confirm intended.
transforms = [
    tsfm.Window(-200, 1000),
    tsfm.MinMaxNorm(-200, 500)
]
batch_size = 16
num_workers = 4

# USE_MODEL = True runs the trained UNet on each patch (needs CUDA and the
# weights file). With False, the preprocessed input patches themselves are
# stitched back into a volume, which exercises the patching logic without a model.
USE_MODEL = False
MODEL_WEIGHTS_PATH = "model_weights.pth"

if USE_MODEL:
    model = UNet(1, 1, first_out_channels=16)
    model.eval()
    model_weights = torch.load(MODEL_WEIGHTS_PATH)
    model.load_state_dict(model_weights)
    model = nn.DataParallel(model).cuda()

dataset = FracNetInferenceDataset(image_path, transforms=transforms)
dataloader = FracNetInferenceDataset.get_dataloader(dataset,
    batch_size, num_workers)

# Accumulator with the same shape as the input volume (float64 zeros).
pred = np.zeros(dataloader.dataset.image.shape)
crop_size = dataloader.dataset.crop_size
with torch.no_grad():
    for images, centers in tqdm(dataloader, desc="patches"):
        if USE_MODEL:
            output = model(images.cuda()).sigmoid().cpu().numpy().squeeze(axis=1)
        else:
            output = images.numpy().squeeze(axis=1)

        for patch, (center_x, center_y, center_z) in zip(output, centers):
            # Cubic window of side crop_size centred on the patch center.
            region = (
                slice(center_x - crop_size // 2, center_x + crop_size // 2),
                slice(center_y - crop_size // 2, center_y + crop_size // 2),
                slice(center_z - crop_size // 2, center_z + crop_size // 2),
            )
            cur_pred_patch = pred[region]
            # Voxels that already hold a positive value (presumably written by an
            # earlier overlapping patch) are averaged pairwise with the new patch.
            # This is a running pairwise average, not a mean over all covering patches.
            # Voxels still <= 0 simply take the new patch value.
            pred[region] = np.where(
                cur_pred_patch > 0,
                np.mean((patch, cur_pred_patch), axis=0),
                patch,
            )


In [4]:
import SimpleITK as sitk
import skimage

# Binarize the stitched volume and dilate it with a ball kernel.
# NOTE(review): astype('int8') truncates toward zero rather than applying an
# explicit threshold, so which voxels become 1 depends on the normalization
# range used to build `pred` — confirm this is the intended binarization.
mask_image = sitk.GetImageFromArray(pred.astype('int8'))

# SimpleITK kernel radii are given in image (x, y, z) order. GetImageFromArray
# maps numpy axes (a0, a1, a2) to image (z, y, x), so the radius of 3 below
# applies along the LAST numpy axis of `pred`.
# NOTE(review): BinaryDilate grows voxels equal to its foregroundValue
# (1 by default in SimpleITK) — confirm the int8 mask uses 1 as foreground.
DILATION_RADIUS = (3, 1, 1)
dilated = sitk.BinaryDilate(mask_image, DILATION_RADIUS, sitk.sitkBall)
im = sitk.GetArrayFromImage(dilated)

# Other post-processing variants were tried in this cell and are not applied:
# closing, erosion, hole filling, opening, and keeping only the 24 largest
# connected components of `pred` masked by the cleaned mask.

In [5]:
# Wrap the stitched volume with the input volume's affine (dataset.image_affine).
# To save the dilated mask from the post-processing cell instead, pass `im`.
pred_image = nib.Nifti1Image(pred, dataset.image_affine)
save_path = os.path.join(".", "data", "tmp")
# Create the output directory if needed so saving works on a fresh checkout.
os.makedirs(save_path, exist_ok=True)
# os.path.join avoids the invalid "\R" escape of a backslash-concatenated literal.
nib.save(pred_image, os.path.join(save_path, "RibFrac501-pred-test2.nii.gz"))


In [None]:
# Rebuild the dataset and dataloader for the same input volume, this time
# normalizing over the full window range (-200..1000) instead of -200..500.
transforms = [tsfm.Window(-200, 1000), tsfm.MinMaxNorm(-200, 1000)]
dataset = FracNetInferenceDataset(image_path, transforms=transforms)
dataloader = FracNetInferenceDataset.get_dataloader(
    dataset, batch_size, num_workers)

In [None]:
from itertools import product
from dataset import transforms as tsfm

# Reload the input volume by hand to inspect how patch centers are laid out.
# Reuse `image_path` rather than a second hard-coded literal; a literal such as
# ".\data\..." relies on invalid escape sequences (\d, \R).
nifti_image = nib.load(image_path)
image_affine = nifti_image.affine
image = nifti_image.get_fdata().astype(np.int16)
crop_size = 64

# Per axis: centers every crop_size // 2 voxels, dropping the first (0) and the
# last value of the range, plus one final center at dim - crop_size // 2 so the
# last patch ends exactly at the far edge.
dim_coords = [list(range(0, dim, crop_size // 2))[1:-1]
              + [dim - crop_size // 2] for dim in image.shape]
centers = list(product(*dim_coords))

# A dimension smaller than a crop would give a negative start index, which
# numpy slicing silently wraps instead of raising.
assert all(c - crop_size // 2 >= 0 for coords in dim_coords for c in coords), \
    f"volume shape {image.shape} has an axis smaller than crop_size={crop_size}"

transforms = [
    tsfm.Window(-200, 1000),
    tsfm.MinMaxNorm(-200, 1000)
]