In [1]:
import argparse
import importlib
import os
import sys
import time

import nibabel as nib
import numpy as np
import SimpleITK as sitk
import skimage
import skimage.measure  # explicit submodule import; older scikit-image releases do not load it via `import skimage`
import torch
from tqdm import tqdm

def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
    if (y.is_cuda):
        return new_y.cuda()
    return new_y

def pc_normalize(pc):
    """Center a point cloud at the origin and scale it into the unit ball.

    Args:
        pc: 2-D array, one point per row.

    Returns:
        A new array: ``pc`` minus its centroid, divided by the largest row norm
        of the centered points. If all points coincide with the centroid (e.g. a
        single point), the centered cloud is returned unscaled instead of
        dividing by zero and producing NaNs. The input is not modified.
    """
    centroid = np.mean(pc, axis=0)
    centered = pc - centroid
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    # Guard the degenerate case: radius == 0 would turn every coordinate into NaN.
    if radius > 0:
        centered = centered / radius
    return centered

def preprocess_bone(nii_path, tmp_path, threshold=200):
    """Extract the voxel coordinates of bright (bone) voxels and save them as .npy.

    Args:
        nii_path: path of the input NIfTI volume.
        tmp_path: destination ``.npy`` file for the coordinates.
        threshold: voxels with intensity ``>= threshold`` are kept (default 200,
            the value previously hard-coded).

    Returns:
        ``(K, ndim)`` integer array of voxel indices (as produced by
        ``np.argwhere``); the same array is written to ``tmp_path``.
    """
    volume = nib.load(nii_path).get_fdata()
    # Threshold straight into a boolean mask. The old two-step in-place rewrite
    # (>= 200 -> 1, then != 1 -> 0) also kept voxels whose original value was exactly 1.
    bone_mask = volume >= threshold
    coords = np.argwhere(bone_mask)
    np.save(tmp_path, coords)
    return coords

In [2]:
# Restrict the kernel to GPU 0 and point at the trained experiment directory.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
experiment_dir = './log/part_seg/c2_a'

# Network head sizes: num_part is passed to get_model and num_classes sizes the
# one-hot category input, so both must match the checkpoint loaded below.
# NOTE(review): 16 categories / 50 part labels — confirm against the training config.
num_part = 50
num_classes = 16

In [3]:
import models.pointnet2_part_seg_msg as MODEL

# Build the part-segmentation network (num_part outputs, no normal channels) on the GPU.
classifier = MODEL.get_model(num_part, normal_channel=False).cuda()

# Restore the best checkpoint's weights; the call's return value is the cell's displayed output.
ckpt_file = f"{experiment_dir}/checkpoints/best_model.pth"
checkpoint = torch.load(ckpt_file)
classifier.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [3]:
nii_path = "./data/RibSeg/test/RibFrac421-image.nii.gz"
tmp_nppath = "./data/tmp/"
save_path = "./data/PredictDir/RibFrac421-pred.nii.gz"

# The inference cell reads tmp_nppath + "image-data.npy", which only preprocess_bone
# writes. Build it once (plus the scratch directory) so Restart & Run All works on a
# fresh checkout; later runs reuse the cached file.
# NOTE: the cache is keyed only by filename — delete it after changing nii_path.
os.makedirs(tmp_nppath, exist_ok=True)
if not os.path.exists(tmp_nppath + "image-data.npy"):
    preprocess_bone(nii_path, tmp_nppath + "image-data.npy")

In [6]:
# Sample bone voxels, normalise them and predict one label per sampled point.
NUM_POINT = 30000  # points fed to the network per scan

with torch.no_grad():
    data = np.load(tmp_nppath + "image-data.npy").astype(np.float32)
    points = data[:, 0:3]
    # Resample without replacement. If the scan has fewer candidate voxels than
    # NUM_POINT, use all of them instead of letting np.random.choice raise ValueError.
    n_sample = min(NUM_POINT, data.shape[0])
    choice = np.random.choice(data.shape[0], n_sample, replace=False)
    points = points[choice, :]
    # Keep the integer voxel coordinates so later cells can map labels back into the volume.
    np.save(tmp_nppath + "image-point.npy", points.astype('int32'))
    points[:, 0:3] = pc_normalize(points[:, 0:3])

    # One object-category id for the whole scan, one-hot encoded for the model.
    # NOTE(review): category 0 is hard-coded — confirm it matches training.
    label = np.array([0])

    points = np.expand_dims(points, 0)  # (1, N, 3)
    label = np.expand_dims(label, 0)    # (1, 1)

    points = torch.from_numpy(points).float().cuda()
    label = torch.from_numpy(label).long().cuda()
    points = points.transpose(2, 1)     # (1, 3, N)

    classifier = classifier.eval()

    # time.clock() was removed in Python 3.8; perf_counter is the replacement.
    # CUDA kernels run asynchronously, so synchronise around the timed region.
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    seg_pred, trans_feat = classifier(points, to_categorical(label, num_classes))
    torch.cuda.synchronize()
    time_cost = time.perf_counter() - t1

    seg_pred = seg_pred.contiguous().view(-1, num_part)
    pred_choice = seg_pred.argmax(dim=1).cpu().numpy()
    np.save(tmp_nppath + "image-label.npy", pred_choice.astype('int8'))
    print("time cost :", time_cost)



time cost : 1.8517037999999957




In [4]:
# Reload the CT volume; keep its affine so the prediction is saved in the same space.
s_i = nib.load(nii_path)
aff = s_i.affine
# Binarise: every non-zero voxel becomes 1, zero voxels stay 0.
# NOTE(review): with raw intensities almost every voxel is non-zero, so this mask
# may be close to all-ones — confirm this is the intended body mask.
s_i = (s_i.get_fdata() != 0).astype('int8')

# Sampled voxel coordinates and their predicted labels, written by the inference cell.
loc = np.load(tmp_nppath + "image-point.npy")
label = np.load(tmp_nppath + "image-label.npy")

In [5]:
# Scatter the sampled points back into the scan's voxel grid:
#   mask_rd  - 1 at every sampled voxel,
#   mask_res - the predicted label of each sampled voxel (0 elsewhere).
assert loc.shape[0] == label.shape[0], "point coordinates and predicted labels are misaligned"
mask_rd = np.zeros(s_i.shape)
mask_res = np.zeros(s_i.shape)
# One vectorised fancy-index assignment replaces the two per-point Python loops.
# The points were drawn with replace=False, so no voxel is written twice.
coord_idx = tuple(loc.T)
mask_rd[coord_idx] = 1
mask_res[coord_idx] = label



In [8]:
im = s_i * mask_res  # NOTE(review): `im` is reassigned by the connected-components cell before it is saved

In [6]:
# Grow the sparse per-point labels into solid regions with a radius-3 ball dilation.
# NOTE(review): BinaryDilate only dilates voxels equal to its foreground value
# (1 unless overridden), so other label values are not grown — confirm predictions are binary.
# (Closing, erosion, hole filling and opening were tried here as alternatives.)
mask_image = sitk.GetImageFromArray(mask_res.astype('int8'))
dilated = sitk.BinaryDilate(mask_image, (3, 3, 3), sitk.sitkBall)
im_to = sitk.GetArrayFromImage(dilated)

In [7]:
# Keep the NUM_RIBS largest face-connected components of the dilated prediction
# (restricted to the non-zero voxels of the scan) as the final rib mask.
NUM_RIBS = 24  # previously a bare literal; presumably one component per rib

res = np.multiply(s_i, im_to)
res1 = skimage.measure.label(res, connectivity=1)
rib_p = skimage.measure.regionprops(res1)
rib_p.sort(key=lambda region: region.area, reverse=True)

# np.in1d is deprecated since NumPy 2.0; np.isin preserves the input shape,
# so the flatten/reshape round-trip is no longer needed.
im = np.isin(res1, [region.label for region in rib_p[:NUM_RIBS]])
im = im.astype('int8')

In [8]:
# Write the final rib mask as NIfTI with the input scan's affine so it overlays the CT.
# nib.save does not create folders, so make sure the output directory exists first.
os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
new_image = nib.Nifti1Image(im, aff)
nib.save(new_image, save_path)

In [11]:
np.shape(im)  # voxel-grid dimensions of the saved mask

(512, 512, 325)