In [1]:
import os

config = {}
config["weights_file"] = os.getcwd() + '/model/weight'
config["patch_size"] = (64, 64, 64)  # switch to None to train on the whole image
config["patch_gap"] = 16
config["batch_size"] = 10
config["kfold"] = 5

config["input_shape"] = (1, None, None, None)
config["depth"] = 4 # depth of layers for V/Unet
config["n_base_filters"] = 32
config["pool_size"] = (2, 2, 2)  # pool size for the max pooling operations
config["deconvolution"] = True  # if False, will use upsampling instead of deconvolution

config["patience"] = 10  # learning rate will be reduced after this many epochs if the validation loss is not improving
config["early_stop"] = 10  # training will be stopped after this many epochs without the validation loss improving
config["initial_learning_rate"] = 0.00001
config["learning_rate_drop"] = 0.5  # factor by which the learning rate will be reduced
config["n_epochs"] = 1

In [2]:
from model.data import *
from model.model import *

d = Data()
d.load_data(config["patch_size"])

# prepare data for training
train_num, valid_num = d.prekfold(config["patch_size"], config["patch_gap"], config["batch_size"], config["kfold"])

Using TensorFlow backend.


In [51]:
model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"],
                              depth=config["depth"],
                              n_base_filters=config["n_base_filters"])

weight_path = '/model/weight/ref/weights-02-0.02-0429-binary-patch.hdf5'
model.load_weights(os.getcwd() + weight_path) 


In [63]:
import numpy as np
import nibabel as nib
import h5py

class Reconstruct:
    def __init__(self, ind, shape, patch_size, to_weight):
        # find its original image: d.data[str(shape)][ind][0]
        # find its target image: d.data[str(shape)][ind][1]
        self.ind = ind
        self.shape = shape
        self.patch_size = patch_size
        # weight the patch before merging or not
        self.to_weight = to_weight
        
        self.data = np.zeros(shape)
        self.image = np.zeros(shape)
        self.target = np.zeros(shape)
        self.count = np.zeros(shape, dtype=np.float32)
        
#         construct softmax map for distance from the boundary

        if self.to_weight is False:
            self.dist_map = np.ones(patch_size)
        else:
            self.dist_map = np.zeros(patch_size)
            mini = 0
            minj = 0
            mink = 0
            for i in range(patch_size[0]):
                mini = min(i+1, patch_size[0]-i)
                for j in range(patch_size[1]):
                    minj = min(j+1, patch_size[1]-j)
                    for k in range(patch_size[2]):
                        mink = min(k+1, patch_size[2]-k)
    #                     print(i, j, k, mini, minj, mink)
                        self.dist_map[i, j, k] = min(mini, minj, mink)
    #         print(self.dist_map)
            self.dist_map = np.exp(self.dist_map)/np.sum(np.exp(self.dist_map))
    
#             self.dist_map = np.zeros(patch_size)
#             center = (np.array(patch_size)-1) / 2
#             center_dist = np.linalg.norm(center)
#             for i in range(patch_size[0]):
#                 for j in range(patch_size[1]):
#                     for k in range(patch_size[2]):
#     #                     print([i, j, k], np.array([i, j, k]) - center)
#                         self.dist_map[i, j, k] = center_dist - np.linalg.norm(np.array([i, j, k]) - center)
#     #         print(self.dist_map)
#             self.dist_map[self.dist_map < 0] = 0
#             self.dist_map = np.exp(self.dist_map)/np.sum(np.exp(self.dist_map))
#     #         print(self.dist_map)

        
    def add(self, patch, index):
        patch = patch * self.dist_map
        # get patch data
        patch_index = np.zeros(self.shape, dtype=np.bool)
        patch_index[...,
                    index[0]:index[0]+patch.shape[-3],
                    index[1]:index[1]+patch.shape[-2],
                    index[2]:index[2]+patch.shape[-1]] = True
        patch_data = np.zeros(self.shape)
        patch_data[patch_index] = patch.flatten()
        
        # store patch data in self.data
        new_data_index = np.logical_and(patch_index, np.logical_not(self.count > 0))
        self.data[new_data_index] = patch_data[new_data_index]
        
        # average overlapped region
        averaged_data_index = np.logical_and(patch_index, self.count > 0)
        if np.any(averaged_data_index):
            self.data[averaged_data_index] = (self.data[averaged_data_index] * self.count[averaged_data_index] + 
                                              patch_data[averaged_data_index]) / (self.count[averaged_data_index] + 1)
#         self.count[patch_index] += 1
#         print(self.count[patch_index].shape, self.dist_map.shape)
        self.count[ index[0]:index[0]+patch.shape[-3],
                    index[1]:index[1]+patch.shape[-2],
                    index[2]:index[2]+patch.shape[-1]] += 1
        
    def store(self, name):
        with h5py.File("./model/h5df_data/reconstruct_" + name + "_" +  str(self.shape) + ".h5", 'w') as f:
            f.create_dataset("index", data=self.ind)
            f.create_dataset("shape", data=self.shape)
            f.create_dataset("data", data=self.data)
        nib.save(nib.Nifti1Image(self.data, np.eye(4)), "reconstruct_" + name + str(i.shape) + ".nii.gz")

In [64]:
fold_index = 0
recon = []
for i in d.valid_index:
    j = d.valid_index[i][fold_index]
    recons = Reconstruct(j, d.data[i][j][0].shape, config["patch_size"], True)
    orig = Reconstruct(j, d.data[i][j][0].shape, config["patch_size"], False)
    for ind in range(d.patch_index[i][j].shape[0]):
        index = d.patch_index[i][j][ind]
        image_i = np.expand_dims(d.data[i][j][0][
                         index[0]:index[0]+d.patch_size[0], 
                         index[1]:index[1]+d.patch_size[1], 
                         index[2]:index[2]+d.patch_size[2]], axis=0)
        result = model.predict([image_i[None, :]])
        recons.add(result, index)
        orig.add(result, index)
    recon.append(recons)
    recon.append(orig)
    break


In [65]:
show_image([recon[0].data, recon[1].data])


interactive(children=(IntSlider(value=64, description='id', max=127), Output()), _dom_classes=('widget-interac…

In [35]:
import matplotlib.pyplot as plt

size = 5
res = Reconstruct(0, (size, size, size), (size, size, size), False)
print(res.dist_map)

def show_image(images):
    # show image with [None, None, : ,: ,:] dimension
    def show_frame(id):
        length = len(images)
        for i in range(length):
            ax = plt.subplot(1, length, i+1)
            plt.imshow(images[i][id, :, :], cmap='gray')
    interact(show_frame, 
             id=widgets.IntSlider(min=0, max=images[0].shape[0]-1, step=1, value=images[0].shape[0]/2))
        
show_image([res.dist_map])
# print(2 * np.ones((3,3,3)) * res.dist_map)

[[[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]

 [[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]

 [[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]

 [[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]

 [[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]]


interactive(children=(IntSlider(value=2, description='id', max=4), Output()), _dom_classes=('widget-interact',…

In [None]:
def test_full_img2(model, device, test_data, save_folder, save_name, box_size=160, save_thresh = 0.75):
    model.eval()
    img_num = 1
    with torch.no_grad():
        
        img, sub_box = test_data

        predict_subbox = []

        test_whole_volumes = {}
        test_whole_labels = {}
        for i in range(1):
            test_whole_volumes[i] = img
            test_whole_labels[i] = img

        test_idx = list(range(1))
        test_sub_box = get_subbox_idx(test_idx,[(img, sub_box)])

        test_Mouse_dataset = Mouse_sub_volumes_test(test_whole_volumes,test_whole_labels,test_sub_box)
        test_dataloader = DataLoader(test_Mouse_dataset, batch_size=8, shuffle=False, num_workers=4)

        for i_batch, sample_batched in enumerate(test_dataloader):
            inputs, labels = sample_batched['image'], sample_batched['label'].numpy()
            outputs = model(inputs.to(device)).cpu().numpy()
            for i_subimg in range(outputs.shape[0]):
                predict_subbox.append(np.squeeze(outputs[i_subimg,...]))
                
        y_predict = np.zeros((np.shape(img)),np.float32)
        overlapping = np.zeros((np.shape(img)),np.float32)
        
        box_one = np.ones((box_size,box_size,box_size),np.float32) * 0.2
        box_one[20:140, 20:140, 20:140] = 1.0

        for i_subbox in range(len(sub_box)):
            x, y, z = sub_box[i_subbox]
            y_predict[x-box_size:x, y-box_size:y, z-box_size:z] += predict_subbox[i_subbox] * box_one
            overlapping[x-box_size:x, y-box_size:y, z-box_size:z] += box_one
            

        y_predict /= overlapping
        
        y_predict[y_predict >  save_thresh] = 1.0
        y_predict[y_predict <= save_thresh] = 0.0

        max_component = 1
        y_predict_component = measure.label(y_predict)
        component_num = np.unique(y_predict_component)
        for current_component in range(1,len(component_num)):
            if np.sum(y_predict_component == current_component) < np.sum(y_predict_component == max_component):
                y_predict[y_predict_component == current_component] = 0
            elif np.sum(y_predict_component == current_component) > np.sum(y_predict_component == max_component):
                y_predict[y_predict_component == max_component] = 0
                max_component = current_component

        y_predict = ndimage.binary_fill_holes(y_predict).astype(float)

        
        save_nii(img, y_predict, save_folder, save_name)

        #print('Img_num: {}, f-score: {}'.format(img_num, score))
        print('img {}, predict body pixel: {}'.format(img_num, np.sum(y_predict)))
        img_num += 1
        del y_predict
        #print('average score of {} images is {}'.format(test_img, score_sum/test_img))

In [1]:
import os
import h5py

def fetch_file():
    path = os.getcwd() + '/model/h5df_data/'
    _, _, files = next(os.walk(path))
    result = []
    for file in files:
        if "reconstruct_dice_softmax_circle_" in file:
            result.append(h5py.File(path+file, 'r'))
    return result
            
files = fetch_file()
print(files)

[<HDF5 file "reconstruct_dice_softmax_circle_(192, 512, 512).h5" (mode r)>, <HDF5 file "reconstruct_dice_softmax_circle_(320, 384, 384).h5" (mode r)>, <HDF5 file "reconstruct_dice_softmax_circle_(128, 256, 256).h5" (mode r)>]


In [2]:
images = []
for i in files:
    images.append(i["data"][:])

In [5]:
from model.data import *

print(np.array(images).shape)
d = Data()
d.show_image([images[0][None, None, :]])

(3,)


interactive(children=(IntSlider(value=96, description='id', max=191), Output()), _dom_classes=('widget-interac…

In [53]:
# store all files to nii
def store_nii(arr):
    for i in arr:
        nib.save(nib.Nifti1Image(i, np.eye(4)), "reconstruct" + str(i.shape) + ".nii.gz")
    
store_nii(images)

In [29]:
# import numpy as np
# import h5py

# class Reconstruct:
#     def __init__(self, ind, shape, patch_size):
#         # find its original image: d.data[str(shape)][ind][0]
#         # find its target image: d.data[str(shape)][ind][1]
#         self.ind = ind
#         self.shape = shape
#         self.patch_size = patch_size
#         self.data = np.zeros(shape)
#         self.count = np.zeros(shape, dtype=np.int)
        
# #         construct softmax map for distance from the boundary
#         self.dist_map = np.zeros(patch_size)
#         mini = 0
#         minj = 0
#         mink = 0
#         for i in range(patch_size[0]):
#             mini = min(i+1, patch_size[0]-i)
#             for j in range(patch_size[1]):
#                 minj = min(j+1, patch_size[1]-j)
#                 for k in range(patch_size[2]):
#                     mink = min(k+1, patch_size[2]-k)
# #                     print(i, j, k, mini, minj, mink)
#                     self.dist_map[i, j, k] = min(mini, minj, mink)
# #         print(self.dist_map)
#         # add a base weight to have a bit more weight on the margins
# #         self.dist_map += 32
#         self.dist_map = np.exp(self.dist_map)/np.sum(np.exp(self.dist_map))

# #         self.dist_map = np.zeros(patch_size)
# #         center = (np.array(patch_size)-1) / 2
# #         center_dist = np.linalg.norm(center)
# #         for i in range(patch_size[0]):
# #             for j in range(patch_size[1]):
# #                 for k in range(patch_size[2]):
# # #                     print([i, j, k], np.array([i, j, k]) - center)
# #                     self.dist_map[i, j, k] = center_dist - np.linalg.norm(np.array([i, j, k]) - center)
# # #         print(self.dist_map)
# #         self.dist_map[self.dist_map < 0] = 0
# #         self.dist_map = np.exp(self.dist_map)/np.sum(np.exp(self.dist_map))
# # #         print(self.dist_map)

        
#     def add(self, patch, index):
#         patch = patch * self.dist_map
#         # get patch data
#         patch_index = np.zeros(self.shape, dtype=np.bool)
#         patch_index[...,
#                     index[0]:index[0]+patch.shape[-3],
#                     index[1]:index[1]+patch.shape[-2],
#                     index[2]:index[2]+patch.shape[-1]] = True
#         patch_data = np.zeros(self.shape)
#         patch_data[patch_index] = patch.flatten()
        
#         # store patch data in self.data
#         new_data_index = np.logical_and(patch_index, np.logical_not(self.count > 0))
#         self.data[new_data_index] = patch_data[new_data_index]
        
#         # average overlapped region
#         averaged_data_index = np.logical_and(patch_index, self.count > 0)
#         if np.any(averaged_data_index):
#             self.data[averaged_data_index] = (self.data[averaged_data_index] * self.count[averaged_data_index] + 
#                                               patch_data[averaged_data_index]) / (self.count[averaged_data_index] + 1)
#         self.count[patch_index] += 1
        
#     def store(self, name):
#         with h5py.File("./model/h5df_data/reconstruct_" + name + "_" +  str(self.shape) + ".h5", 'w') as f:
#             f.create_dataset("index", data=self.ind)
#             f.create_dataset("shape", data=self.shape)
#             f.create_dataset("data", data=self.data)