In [3]:
import os

config = {}
config["weights_file"] = os.getcwd() + '/model/weight'
config["patch_size"] = (64, 64, 64)  # switch to None to train on the whole image
config["patch_gap"] = 16
config["batch_size"] = 2
config["kfold"] = 5

config["input_shape"] = (1, None, None, None)
config["depth"] = 4 # depth of layers for V/Unet
config["n_base_filters"] = 32
config["pool_size"] = (2, 2, 2)  # pool size for the max pooling operations
config["deconvolution"] = True  # if False, will use upsampling instead of deconvolution

config["patience"] = 10  # learning rate will be reduced after this many epochs if the validation loss is not improving
config["early_stop"] = 10  # training will be stopped after this many epochs without the validation loss improving
config["initial_learning_rate"] = 0.00001
config["learning_rate_drop"] = 0.5  # factor by which the learning rate will be reduced
config["n_epochs"] = 10

In [10]:
from model.recon import *
from model.model import *
from model.data import *

# weight_path = ['/model/weight/spring2019/fold0_all_patch_weights-05-0.40.hdf5',
#                '/model/weight/spring2019/fold0_weights-03-0.39.hdf5',
#                '/model/weight/spring2019/fold_binary0weights-05-0.06.hdf5',
#               ]
# weight_name = ['all',
#                'normal',
#                'binary',
#               ]

weight_path = ['/model/weight/fold0_weights-03-0.38.hdf5',
              ]
weight_name = ['dice',
              ]

d = Data()
d.load_data(config["patch_size"])
# set up valid index
train_num, valid_num = d.prekfold(config["patch_size"], config["patch_gap"], config["batch_size"], config["kfold"])

for i_weight in range(len(weight_path)):
    print("loading weight: ", weight_name[i_weight])
    model = unet_model_3d(input_shape=config["input_shape"],
                                  pool_size=config["pool_size"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  deconvolution=config["deconvolution"],
                                  depth=config["depth"],
                                  n_base_filters=config["n_base_filters"])
    model.load_weights(os.getcwd() + weight_path[i_weight]) 
    
    print(d.valid_index)
    fold_index = 0
    for i in d.valid_index:
        j = d.valid_index[i][fold_index]
        normal = Reconstruct(j, d.data[i][j][0].shape, config["patch_size"], False)
        image = Reconstruct(j, d.data[i][j][0].shape, config["patch_size"], False)
        target = Reconstruct(j, d.data[i][j][0].shape, config["patch_size"], False)
        for ind in range(d.patch_index[i][j].shape[0]):
            index = d.patch_index[i][j][ind]
            image_i = np.expand_dims(d.data[i][j][0][
                             index[0]:index[0]+d.patch_size[0], 
                             index[1]:index[1]+d.patch_size[1], 
                             index[2]:index[2]+d.patch_size[2]], axis=0)
            target_i = np.expand_dims(d.data[i][j][1][
                             index[0]:index[0]+d.patch_size[0], 
                             index[1]:index[1]+d.patch_size[1], 
                             index[2]:index[2]+d.patch_size[2]], axis=0)
            result = model.predict([image_i[None, :]])
            normal.add(result, index)
            image.add(image_i, index)
            target.add(target_i, index)
        dir_name = './model/h5df_data/recon/' + weight_name[i_weight] + '/'
        os.makedirs(os.path.dirname(dir_name), exist_ok=True)
        file_name = '/recon/' + weight_name[i_weight] + '/'+ str(d.data[i][j][0].shape)
        normal.store(file_name + "_uniform_output")
        image.store(file_name + "_input")
        target.store(file_name + "_target")

print("finish reconstructing image")


loading weight:  dice
{}
finish reconstructing image


In [27]:
table = {0: "target", 1: "uniform", 2: "weight"}

for i in total:
    print(i)
    for j in range(len(total[i])):
        print(table[j])
        for k in total[i][j]:
            print(total[i][j][k])


all
target
<HDF5 file "(192, 512, 512)_target.h5" (mode r)>
<HDF5 file "(320, 384, 384)_target.h5" (mode r)>
<HDF5 file "(128, 256, 256)_target.h5" (mode r)>
uniform
<HDF5 file "(192, 512, 512)_uniform_output.h5" (mode r)>
<HDF5 file "(128, 256, 256)_uniform_output.h5" (mode r)>
<HDF5 file "(320, 384, 384)_uniform_output.h5" (mode r)>
weight
<HDF5 file "(128, 256, 256)_weighted_output.h5" (mode r)>
<HDF5 file "(192, 512, 512)_weighted_output.h5" (mode r)>
<HDF5 file "(320, 384, 384)_weighted_output.h5" (mode r)>
normal
target
<HDF5 file "(192, 512, 512)_target.h5" (mode r)>
<HDF5 file "(320, 384, 384)_target.h5" (mode r)>
<HDF5 file "(128, 256, 256)_target.h5" (mode r)>
uniform
<HDF5 file "(192, 512, 512)_uniform_output.h5" (mode r)>
<HDF5 file "(128, 256, 256)_uniform_output.h5" (mode r)>
<HDF5 file "(320, 384, 384)_uniform_output.h5" (mode r)>
weight
<HDF5 file "(128, 256, 256)_weighted_output.h5" (mode r)>
<HDF5 file "(192, 512, 512)_weighted_output.h5" (mode r)>
<HDF5 file "(320, 3

In [27]:
for i in target:
    for j in range(len(target[i])):
        print(target[i][j])
        print(os.path.basename(target[i][j].filename))
    

<HDF5 file "(128, 256, 256)_target.h5" (mode r)>
(128, 256, 256)_target.h5


In [16]:
np.arange(0, 1.01, 0.01).shape
np.zeros(101).shape

(101,)

In [None]:
plt.plot(it, dice_thre)
plt.savefig()
print(it[np.argmax(dice_thre)], dice_thre[np.argmax(dice_thre)])

In [3]:
import os
import h5py
import numpy as np
from collections import defaultdict

def fetch_file():
    path = os.getcwd() + '/model/h5df_data/recon/'
    root, sub_dir, _ = next(os.walk(path))
    total = {}
#     uniform = defaultdict(list)
#     weight = defaultdict(list)
#     target = defaultdict(list)
    for sub in sub_dir:
        if "ipynb_checkpoints" in sub:
            continue
        _, _, sub_files = next(os.walk(root + sub))
        ar = []
        for file in sub_files:
#             print(file)
            if "nii.gz" in file and "threshold" in file:
                ar.append(path+sub+'/'+file)
        total[sub] = ar
    return total["multi"]

def dice(y_true, y_pred, smooth=1.):
    y_true_f = np.array(y_true).flatten()
    y_pred_f = np.array(y_pred).flatten()
    intersection = np.sum(y_true_f * y_pred_f)
    # tensorflow computation graph: will not configure print as one of the graph, unless using tf.Print()
    return (2.*intersection+smooth) / (np.sum(y_true_f)+np.sum(y_pred_f)+smooth)
# print(dice(merge_target, merge_result>0.02))
            
total = fetch_file()
print(total)




['/scratch/yl4217/MS-Lesion-Segmentation/model/h5df_data/recon/multi/(128, 256, 256)_uniform_output.h5_uniform_0.54_threshold.nii.gz', '/scratch/yl4217/MS-Lesion-Segmentation/model/h5df_data/recon/multi/(192, 512, 512)_uniform_output.h5_uniform_0.54_threshold.nii.gz', '/scratch/yl4217/MS-Lesion-Segmentation/model/h5df_data/recon/multi/(320, 384, 384)_uniform_output.h5_uniform_0.54_threshold.nii.gz']


In [4]:
import nibabel as nib

raw_data = defaultdict(list)
# raw_data = []
# raw_data[i][0]: image, raw_data[i][1]: target
for i in total:
    image = nib.load(i)
    nib.save(nib.Nifti1Image(image.get_fdata(), np.eye(4)), i)
#     for j in range(len(total[i])):
#         image = nib.load(total[i][j])
#         nib.save(nib.Nifti1Image(image.get_fdata(), np.eye(4)), total[i][j])
# #         raw_data[str(image.shape)].append(image.get_fdata())

In [36]:
print(raw_data)

defaultdict(<class 'list'>, {'(192, 512, 512)': [(192, 512, 512), (192, 512, 512)], '(320, 384, 384)': [(320, 384, 384), (320, 384, 384)], '(128, 256, 256)': [(128, 256, 256), (128, 256, 256)]})


In [24]:
def show_image(images):
    # show image with [None, None, : ,: ,:] dimension
    def show_frame(id):
        length = len(images)
        for i in range(length):
            ax = plt.subplot(1, length, i+1)
            if (i == 0):
                ax.set_title("Input")
            if (i == 1):
                ax.set_title("Target")
            if (i == 2):
                ax.set_title("Output")
            plt.imshow(images[i][id, :, :], cmap='gray')
    interact(show_frame, 
             id=widgets.IntSlider(min=0, max=images[0].shape[0]-1, step=1, value=images[0].shape[0]/2))
show_image(raw_data['(192, 512, 512)'])

interactive(children=(IntSlider(value=96, description='id', max=191), Output()), _dom_classes=('widget-interac…