In [1]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from utils.roi_measures import mad, ssim
from utils.extraction import extract_patches
from utils.reconstruction import perform_voting

from wnet import generate_wnet_model

Using TensorFlow backend.


In [10]:
# Hyper-parameters of the W-Net generator, passed to generate_wnet_model below.
# loss_weights has three entries -- presumably one per model output (the fit cell
# supplies three targets); TODO confirm against generate_wnet_model.
wparams = dict(
    input_channels=2,
    output_channels=1,
    latent_channels=16,
    scale=[0.5, 0.5, 0.5],
    use_combined_loss=True,
    patch_shape=(32, 32, 32),
    # Earlier experiment used: [1, 1.0/np.sqrt(32**4), 1]
    loss_weights=[1, 1, 1.1],
)

# Settings for the segmentation branch. Empty filenames presumably mean no pre-trained
# segmentation model is loaded -- TODO confirm in generate_wnet_model.
segparams = dict(
    seg_model_filename='',
    seg_model_params_filename='',
    segmentation_classes=4,
)

model = generate_wnet_model(wparams, segparams)

In [11]:
# Print the layer-by-layer architecture and parameter counts of the generated W-Net.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_10 (InputLayer)           (None, 1, 32, 32, 32 0                                            
__________________________________________________________________________________________________
input_9 (InputLayer)            (None, 1, 32, 32, 32 0                                            
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 1, 32, 32, 32 0           input_10[0][0]                   
__________________________________________________________________________________________________
multiply_3 (Multiply)           (None, 1, 32, 32, 32 0           input_9[0][0]                    
                                                                 lambda_2[0][0]                   
__________

In [12]:
import os
import nibabel as nib

from medical_data import diff_info

ref_pattern = 'OAS2_{0:04}_MR{1}_halfwayto_OAS2_{0:04}_MR{2}'
file_general_pattern = 'OAS2_{0:04}_MR{1}_to_OAS2_{0:04}_MR{2}'
dataset_location = '/mnt/harddisk/datasets/OASIS/REG/{}/{}.nii.gz'

step = (16, 16, 16)
threshold = np.int32(0.30 * np.prod(wparams['patch_shape'][:]))
seg_train = np.empty((0, 1, ) + wparams['patch_shape'])
ref_train = np.empty((0, 1, ) + wparams['patch_shape'])
out_train = np.empty((0, 1, ) + wparams['patch_shape'])
for i in range(100, 190) :
    print '{} :'.format(i),
    
    for j in range(1, 5) :
        diff = 0
        for k in range(j+1, 5) :
            ref_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, j, k))
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, j, k) + '_brain_seg')
            mov_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, k, j))
            mov_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, k, j) + '_seg')
            
            if not (os.path.exists(ref_filename) and os.path.exists(mov_filename)) :
                continue
            
            diff = diff + diff_info['OAS2_0{:03}_MR{}'.format(i, k)]
            
            if diff < 0.02:
                continue

            ######################################################################################
            volume_init = nib.load(ref_filename).get_data()
            
            mask_patches = extract_patches(volume_init != 0, wparams['patch_shape'], step)

            useful_patches = np.sum(mask_patches, axis=(1, 2, 3)) > threshold
            
            del mask_patches
            
            ref_patches = extract_patches(volume_init, wparams['patch_shape'], step)
            ref_patches = ref_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            ######################################################################################
            ref_train = np.vstack((ref_patches, ref_train)).astype('float32')
            del ref_patches

            ######################################################################################
            volume_init = nib.load(mov_filename).get_data()
            
            mov_patches = extract_patches(volume_init, wparams['patch_shape'], step)
            mov_patches = mov_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            ######################################################################################
            volume_init = nib.load(mov_prob_filename).get_data() == 1
            
            mov_prob_patches = extract_patches(volume_init, wparams['patch_shape'], step)
            mov_prob_patches = mov_prob_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            ######################################################################################
            seg_train = np.vstack((mov_prob_patches, seg_train)).astype('float32') ##
            out_train = np.vstack((mov_patches, out_train)).astype('float32')
            del mov_patches, mov_prob_patches
            ######################################################################################
            print '{}->{} ({:.2f})'.format(j, k, diff),
    print

100 : 1->2 (0.02) 1->3 (0.02)
101 : 1->2 (0.02) 1->3 (0.03)
102 :
103 : 1->2 (0.03) 1->3 (0.04)
104 :
105 :
106 : 1->2 (0.04)
107 :
108 :
109 :
110 :
111 : 1->2 (0.03)
112 :
113 : 1->2 (0.02)
114 : 1->2 (0.02)
115 :
116 :
117 : 1->3 (0.02) 1->4 (0.03) 2->3 (0.02) 2->4 (0.03)
118 :
119 :
120 :
121 : 1->2 (0.02)
122 :
123 :
124 :
125 :
126 : 1->3 (0.02)
127 : 1->4 (0.03) 2->4 (0.02) 3->4 (0.02)
128 : 1->2 (0.02)
129 : 1->3 (0.02) 2->3 (0.02)
130 :
131 :
132 :
133 : 1->3 (0.04)
134 :
135 : 1->2 (0.02)
136 :
137 :
138 : 1->2 (0.02)
139 : 1->2 (0.03)
140 : 1->2 (0.02) 1->3 (0.02)
141 : 1->2 (0.02)
142 :
143 : 1->3 (0.02) 2->3 (0.02)
144 :
145 : 1->2 (0.03)
146 :
147 : 1->4 (0.03) 2->4 (0.03) 3->4 (0.02)
148 :
149 :
150 :
151 :
152 :
153 :
154 : 1->2 (0.04)
155 :
156 :
157 : 1->2 (0.07)
158 : 1->2 (0.02)
159 :
160 :
161 :
162 :
163 :
164 :
165 :
166 :
167 :
168 :
169 :
170 :
171 : 1->3 (0.02) 2->3 (0.02)
172 :
173 :
174 :
175 :
176 : 1->3 (0.03) 2->3 (0.02)
177 :
178 : 1->3 (0.02)
179 : 1->2

In [13]:
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Early stopping and best-only checkpointing both rely on Keras' default monitor (val_loss).
patience = 10
stopper = EarlyStopping(patience=patience)
checkpointer = ModelCheckpoint(
    'models/ag_mseloss_o2o1.h5',
    save_best_only=True,
    save_weights_only=True,
)

# Number of training patches (not used further in this cell).
N = len(ref_train)

# Three targets: moving image outside the label-1 mask, inside it, and in full.
# validation_split holds out the LAST 30% of the arrays (Keras splits before shuffling).
model.fit(
    x=[ref_train, seg_train],
    y=[out_train * (1 - seg_train), out_train * seg_train, out_train],
    validation_split=0.3,
    epochs=100,
    batch_size=32,
    callbacks=[checkpointer, stopper],
)

Train on 15439 samples, validate on 6618 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100


Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100


<keras.callbacks.History at 0x7ff322ced190>

In [14]:
# Restore the best weights written during training by ModelCheckpoint(save_best_only=True).
model.load_weights('models/ag_mseloss_o2o1.h5')

In [15]:
from keras.models import Model
import os
import nibabel as nib
import SimpleITK as sitk

ref_pattern = 'OAS2_{0:04}_MR{1}_halfwayto_OAS2_{0:04}_MR{2}'
file_general_pattern = 'OAS2_{0:04}_MR{1}_to_OAS2_{0:04}_MR{2}'
dataset_location = '/mnt/harddisk/datasets/OASIS/REG/{}/{}.nii.gz'

mad_results = {}
ssim_results = {}
curr_patch_shape = (32, 32, 32)
step = (32, 32, 32)
for i in range(1, 100) :
    print '{} :'.format(i),
    for j in range(1, 5) :
        for k in range(j+1, 5) :
            ref_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, j, k))
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, j, k) + '_brain_seg')
            mov_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, k, j))
            mov_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, k, j) + '_seg')
            
            print ref_prob_filename
            
            if not (os.path.exists(ref_filename) and os.path.exists(mov_filename)) :
                continue
                
            ######################################################################################
            volume_init = nib.load(ref_filename).get_data()

            ref_patches = extract_patches(volume_init, curr_patch_shape, step)
            ref_patches = ref_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            volume_init = nib.load(mov_prob_filename).get_data() == 1
            
            mov_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
            mov_prob_patches = mov_prob_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            print '{}->{}'.format(j, k),

            pred = model.predict(
                [ref_patches, mov_prob_patches], verbose=1, batch_size=64)[2]
            pred = pred.reshape((-1, ) + curr_patch_shape)

            volume = perform_voting(pred, curr_patch_shape, volume_init.shape, step)

            volume_data = nib.load(mov_filename)

            volume = np.multiply(volume_data.get_data() != 0, volume)
            
            save_filename = 'mse_results/{}_{}_to_{}.nii.gz'.format(i, j, k)
            save_filename_cor = 'mse_results/{}_{}_to_{}_cor.nii.gz'.format(i, j, k)

            nib.save(nib.Nifti1Image(volume, volume_data.affine), save_filename)

            res = sitk.ReadImage(save_filename)

            caster = sitk.CastImageFilter()
            caster.SetOutputPixelType(res.GetPixelID())

            orig = caster.Execute(sitk.ReadImage(mov_filename))
            seg = caster.Execute(sitk.ReadImage(mov_prob_filename))

            thresholder = sitk.BinaryThresholdImageFilter()
            enhanced_vol = caster.Execute(
                thresholder.Execute(seg, -1, -1, 1, 0))

            pairs = [(1, 1), (2, 3)]
            for (a, b) in pairs :                
                thresholder = sitk.BinaryThresholdImageFilter()
                mask = thresholder.Execute(seg, a, b, 1, 0)

                masker = sitk.MaskImageFilter()
                ref_masked = masker.Execute(orig, mask)
                moving_masked = masker.Execute(res, mask)

                matcher = sitk.HistogramMatchingImageFilter()
                matcher.SetNumberOfHistogramLevels(2048)
                matcher.SetNumberOfMatchPoints(15)
                matcher.SetThresholdAtMeanIntensity(True)
                partial_result = matcher.Execute(moving_masked, ref_masked)

                adder = sitk.AddImageFilter()
                enhanced_vol = adder.Execute(enhanced_vol, partial_result)

            masker = sitk.MaskImageFilter()
            mask = thresholder.Execute(seg, 0, 0, 0, 1)
            enhanced_vol = masker.Execute(enhanced_vol, mask)
            sitk.WriteImage(enhanced_vol, save_filename_cor)

            volume = nib.load(save_filename_cor).get_data()
            act_vol = np.float64(nib.load(mov_filename).get_data())
            mad_movgen = mad(volume, act_vol)
            ssim_movgen = ssim(volume, act_vol)

            mad_results['{} {}->{}'.format(i, j, k)] = mad_movgen
            ssim_results['{} {}->{}'.format(i, j, k)] = ssim_movgen
            print '{}-{}: {} - {}'.format(i, k, mad_movgen, ssim_movgen)

1 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0001_MR1_to_OAS2_0001_MR2/OAS2_0001_MR1_halfwayto_OAS2_0001_MR2_brain_seg.nii.gz
 1-2: 0.0206712573647 - 0.968303241213
/mnt/harddisk/datasets/OASIS/REG/OAS2_0001_MR1_to_OAS2_0001_MR3/OAS2_0001_MR1_halfwayto_OAS2_0001_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0001_MR1_to_OAS2_0001_MR4/OAS2_0001_MR1_halfwayto_OAS2_0001_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0001_MR2_to_OAS2_0001_MR3/OAS2_0001_MR2_halfwayto_OAS2_0001_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0001_MR2_to_OAS2_0001_MR4/OAS2_0001_MR2_halfwayto_OAS2_0001_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0001_MR3_to_OAS2_0001_MR4/OAS2_0001_MR3_halfwayto_OAS2_0001_MR4_brain_seg.nii.gz
2 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0002_MR1_to_OAS2_0002_MR2/OAS2_0002_MR1_halfwayto_OAS2_0002_MR2_brain_seg.nii.gz
 2-2: 0.0339983009138 - 0.930725916554
/mnt/harddisk/datasets/OASIS/REG/OAS2_0002_MR1_to_OAS2_0002_MR3/OAS2_0002_MR1_hal

 12-2: 0.0242604736912 - 0.968466510186
/mnt/harddisk/datasets/OASIS/REG/OAS2_0012_MR1_to_OAS2_0012_MR3/OAS2_0012_MR1_halfwayto_OAS2_0012_MR3_brain_seg.nii.gz
 12-3: 0.0258634676039 - 0.962588985635
/mnt/harddisk/datasets/OASIS/REG/OAS2_0012_MR1_to_OAS2_0012_MR4/OAS2_0012_MR1_halfwayto_OAS2_0012_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0012_MR2_to_OAS2_0012_MR3/OAS2_0012_MR2_halfwayto_OAS2_0012_MR3_brain_seg.nii.gz
 12-3: 0.0246864283748 - 0.967981876724
/mnt/harddisk/datasets/OASIS/REG/OAS2_0012_MR2_to_OAS2_0012_MR4/OAS2_0012_MR2_halfwayto_OAS2_0012_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0012_MR3_to_OAS2_0012_MR4/OAS2_0012_MR3_halfwayto_OAS2_0012_MR4_brain_seg.nii.gz
13 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0013_MR1_to_OAS2_0013_MR2/OAS2_0013_MR1_halfwayto_OAS2_0013_MR2_brain_seg.nii.gz
 13-2: 0.0206110368418 - 0.951692033462
/mnt/harddisk/datasets/OASIS/REG/OAS2_0013_MR1_to_OAS2_0013_MR3/OAS2_0013_MR1_halfwayto_OAS2_0013_MR3_brain_seg.nii.gz
 1

 22-2: 0.0205649303726 - 0.966784593595
/mnt/harddisk/datasets/OASIS/REG/OAS2_0022_MR1_to_OAS2_0022_MR3/OAS2_0022_MR1_halfwayto_OAS2_0022_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0022_MR1_to_OAS2_0022_MR4/OAS2_0022_MR1_halfwayto_OAS2_0022_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0022_MR2_to_OAS2_0022_MR3/OAS2_0022_MR2_halfwayto_OAS2_0022_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0022_MR2_to_OAS2_0022_MR4/OAS2_0022_MR2_halfwayto_OAS2_0022_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0022_MR3_to_OAS2_0022_MR4/OAS2_0022_MR3_halfwayto_OAS2_0022_MR4_brain_seg.nii.gz
23 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0023_MR1_to_OAS2_0023_MR2/OAS2_0023_MR1_halfwayto_OAS2_0023_MR2_brain_seg.nii.gz
 23-2: 0.0208091503712 - 0.954250944077
/mnt/harddisk/datasets/OASIS/REG/OAS2_0023_MR1_to_OAS2_0023_MR3/OAS2_0023_MR1_halfwayto_OAS2_0023_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0023_MR1_to_OAS2_0023_MR4/OAS2_0023_MR1_half

 31-3: 0.0227756114109 - 0.967693304813
/mnt/harddisk/datasets/OASIS/REG/OAS2_0031_MR2_to_OAS2_0031_MR4/OAS2_0031_MR2_halfwayto_OAS2_0031_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0031_MR3_to_OAS2_0031_MR4/OAS2_0031_MR3_halfwayto_OAS2_0031_MR4_brain_seg.nii.gz
32 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0032_MR1_to_OAS2_0032_MR2/OAS2_0032_MR1_halfwayto_OAS2_0032_MR2_brain_seg.nii.gz
 32-2: 0.018575268598 - 0.8369806013
/mnt/harddisk/datasets/OASIS/REG/OAS2_0032_MR1_to_OAS2_0032_MR3/OAS2_0032_MR1_halfwayto_OAS2_0032_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0032_MR1_to_OAS2_0032_MR4/OAS2_0032_MR1_halfwayto_OAS2_0032_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0032_MR2_to_OAS2_0032_MR3/OAS2_0032_MR2_halfwayto_OAS2_0032_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0032_MR2_to_OAS2_0032_MR4/OAS2_0032_MR2_halfwayto_OAS2_0032_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0032_MR3_to_OAS2_0032_MR4/OAS2_0032_MR3_halfway

 40-3: 0.0264869749149 - 0.940914469181
/mnt/harddisk/datasets/OASIS/REG/OAS2_0040_MR2_to_OAS2_0040_MR4/OAS2_0040_MR2_halfwayto_OAS2_0040_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0040_MR3_to_OAS2_0040_MR4/OAS2_0040_MR3_halfwayto_OAS2_0040_MR4_brain_seg.nii.gz
41 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0041_MR1_to_OAS2_0041_MR2/OAS2_0041_MR1_halfwayto_OAS2_0041_MR2_brain_seg.nii.gz
 41-2: 0.0214510615641 - 0.972303219192
/mnt/harddisk/datasets/OASIS/REG/OAS2_0041_MR1_to_OAS2_0041_MR3/OAS2_0041_MR1_halfwayto_OAS2_0041_MR3_brain_seg.nii.gz
 41-3: 0.0279985658775 - 0.964250224367
/mnt/harddisk/datasets/OASIS/REG/OAS2_0041_MR1_to_OAS2_0041_MR4/OAS2_0041_MR1_halfwayto_OAS2_0041_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0041_MR2_to_OAS2_0041_MR3/OAS2_0041_MR2_halfwayto_OAS2_0041_MR3_brain_seg.nii.gz
 41-3: 0.0240517339893 - 0.971224602666
/mnt/harddisk/datasets/OASIS/REG/OAS2_0041_MR2_to_OAS2_0041_MR4/OAS2_0041_MR2_halfwayto_OAS2_0041_MR4_brain_seg.nii.gz
/m

 49-3: 0.0160117981326 - 0.967829833941
/mnt/harddisk/datasets/OASIS/REG/OAS2_0049_MR2_to_OAS2_0049_MR4/OAS2_0049_MR2_halfwayto_OAS2_0049_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0049_MR3_to_OAS2_0049_MR4/OAS2_0049_MR3_halfwayto_OAS2_0049_MR4_brain_seg.nii.gz
50 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0050_MR1_to_OAS2_0050_MR2/OAS2_0050_MR1_halfwayto_OAS2_0050_MR2_brain_seg.nii.gz
 50-2: 0.0194410189179 - 0.961407780071
/mnt/harddisk/datasets/OASIS/REG/OAS2_0050_MR1_to_OAS2_0050_MR3/OAS2_0050_MR1_halfwayto_OAS2_0050_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0050_MR1_to_OAS2_0050_MR4/OAS2_0050_MR1_halfwayto_OAS2_0050_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0050_MR2_to_OAS2_0050_MR3/OAS2_0050_MR2_halfwayto_OAS2_0050_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0050_MR2_to_OAS2_0050_MR4/OAS2_0050_MR2_halfwayto_OAS2_0050_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0050_MR3_to_OAS2_0050_MR4/OAS2_0050_MR3_half

 60-2: 0.0219650659104 - 0.961474989561
/mnt/harddisk/datasets/OASIS/REG/OAS2_0060_MR1_to_OAS2_0060_MR3/OAS2_0060_MR1_halfwayto_OAS2_0060_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0060_MR1_to_OAS2_0060_MR4/OAS2_0060_MR1_halfwayto_OAS2_0060_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0060_MR2_to_OAS2_0060_MR3/OAS2_0060_MR2_halfwayto_OAS2_0060_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0060_MR2_to_OAS2_0060_MR4/OAS2_0060_MR2_halfwayto_OAS2_0060_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0060_MR3_to_OAS2_0060_MR4/OAS2_0060_MR3_halfwayto_OAS2_0060_MR4_brain_seg.nii.gz
61 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0061_MR1_to_OAS2_0061_MR2/OAS2_0061_MR1_halfwayto_OAS2_0061_MR2_brain_seg.nii.gz
 61-2: 0.022914164325 - 0.96337566313
/mnt/harddisk/datasets/OASIS/REG/OAS2_0061_MR1_to_OAS2_0061_MR3/OAS2_0061_MR1_halfwayto_OAS2_0061_MR3_brain_seg.nii.gz
 61-3: 0.0257211736175 - 0.958967734593
/mnt/harddisk/datasets/OASIS/REG/OAS2_0061_M

 69-2: 0.0275334387742 - 0.970282623213
/mnt/harddisk/datasets/OASIS/REG/OAS2_0069_MR1_to_OAS2_0069_MR3/OAS2_0069_MR1_halfwayto_OAS2_0069_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0069_MR1_to_OAS2_0069_MR4/OAS2_0069_MR1_halfwayto_OAS2_0069_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0069_MR2_to_OAS2_0069_MR3/OAS2_0069_MR2_halfwayto_OAS2_0069_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0069_MR2_to_OAS2_0069_MR4/OAS2_0069_MR2_halfwayto_OAS2_0069_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0069_MR3_to_OAS2_0069_MR4/OAS2_0069_MR3_halfwayto_OAS2_0069_MR4_brain_seg.nii.gz
70 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0070_MR1_to_OAS2_0070_MR2/OAS2_0070_MR1_halfwayto_OAS2_0070_MR2_brain_seg.nii.gz
 70-2: 0.0251011729209 - 0.943989250606
/mnt/harddisk/datasets/OASIS/REG/OAS2_0070_MR1_to_OAS2_0070_MR3/OAS2_0070_MR1_halfwayto_OAS2_0070_MR3_brain_seg.nii.gz
 70-3: 0.02321099296 - 0.94953382523
/mnt/harddisk/datasets/OASIS/REG/OAS2_0070_MR

 78-2: 0.0199444311063 - 0.968452711248
/mnt/harddisk/datasets/OASIS/REG/OAS2_0078_MR1_to_OAS2_0078_MR3/OAS2_0078_MR1_halfwayto_OAS2_0078_MR3_brain_seg.nii.gz
 78-3: 0.0250487241079 - 0.951912101169
/mnt/harddisk/datasets/OASIS/REG/OAS2_0078_MR1_to_OAS2_0078_MR4/OAS2_0078_MR1_halfwayto_OAS2_0078_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0078_MR2_to_OAS2_0078_MR3/OAS2_0078_MR2_halfwayto_OAS2_0078_MR3_brain_seg.nii.gz
 78-3: 0.0241543414082 - 0.950356923813
/mnt/harddisk/datasets/OASIS/REG/OAS2_0078_MR2_to_OAS2_0078_MR4/OAS2_0078_MR2_halfwayto_OAS2_0078_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0078_MR3_to_OAS2_0078_MR4/OAS2_0078_MR3_halfwayto_OAS2_0078_MR4_brain_seg.nii.gz
79 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0079_MR1_to_OAS2_0079_MR2/OAS2_0079_MR1_halfwayto_OAS2_0079_MR2_brain_seg.nii.gz
 79-2: 0.025965551351 - 0.932708633349
/mnt/harddisk/datasets/OASIS/REG/OAS2_0079_MR1_to_OAS2_0079_MR3/OAS2_0079_MR1_halfwayto_OAS2_0079_MR3_brain_seg.nii.gz
 79

 88-2: 0.0199790727101 - 0.960303148932
/mnt/harddisk/datasets/OASIS/REG/OAS2_0088_MR1_to_OAS2_0088_MR3/OAS2_0088_MR1_halfwayto_OAS2_0088_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0088_MR1_to_OAS2_0088_MR4/OAS2_0088_MR1_halfwayto_OAS2_0088_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0088_MR2_to_OAS2_0088_MR3/OAS2_0088_MR2_halfwayto_OAS2_0088_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0088_MR2_to_OAS2_0088_MR4/OAS2_0088_MR2_halfwayto_OAS2_0088_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0088_MR3_to_OAS2_0088_MR4/OAS2_0088_MR3_halfwayto_OAS2_0088_MR4_brain_seg.nii.gz
89 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0089_MR1_to_OAS2_0089_MR2/OAS2_0089_MR1_halfwayto_OAS2_0089_MR2_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0089_MR1_to_OAS2_0089_MR3/OAS2_0089_MR1_halfwayto_OAS2_0089_MR3_brain_seg.nii.gz
 89-3: 0.0244910571947 - 0.961720712265
/mnt/harddisk/datasets/OASIS/REG/OAS2_0089_MR1_to_OAS2_0089_MR4/OAS2_0089_MR1_half

 98-2: 0.0180928180519 - 0.963703952921
/mnt/harddisk/datasets/OASIS/REG/OAS2_0098_MR1_to_OAS2_0098_MR3/OAS2_0098_MR1_halfwayto_OAS2_0098_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0098_MR1_to_OAS2_0098_MR4/OAS2_0098_MR1_halfwayto_OAS2_0098_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0098_MR2_to_OAS2_0098_MR3/OAS2_0098_MR2_halfwayto_OAS2_0098_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0098_MR2_to_OAS2_0098_MR4/OAS2_0098_MR2_halfwayto_OAS2_0098_MR4_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0098_MR3_to_OAS2_0098_MR4/OAS2_0098_MR3_halfwayto_OAS2_0098_MR4_brain_seg.nii.gz
99 : /mnt/harddisk/datasets/OASIS/REG/OAS2_0099_MR1_to_OAS2_0099_MR2/OAS2_0099_MR1_halfwayto_OAS2_0099_MR2_brain_seg.nii.gz
 99-2: 0.0257481221338 - 0.95964043116
/mnt/harddisk/datasets/OASIS/REG/OAS2_0099_MR1_to_OAS2_0099_MR3/OAS2_0099_MR1_halfwayto_OAS2_0099_MR3_brain_seg.nii.gz
/mnt/harddisk/datasets/OASIS/REG/OAS2_0099_MR1_to_OAS2_0099_MR4/OAS2_0099_MR1_halfw

In [16]:
from keras.models import Model
import os
import nibabel as nib
import SimpleITK as sitk

ref_pattern = 'OAS2_{0:04}_MR{1}_halfwayto_OAS2_{0:04}_MR{2}'
file_general_pattern = 'OAS2_{0:04}_MR{1}_to_OAS2_{0:04}_MR{2}'
dataset_location = '/mnt/harddisk/datasets/OASIS/REG/{}/{}.nii.gz'

tuples = [(1, 2, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4)]
mad_results = {}
ssim_results = {}
curr_patch_shape = (32, 32, 32)
step = (32, 32, 32)
for i in range(1, 100) :
    print '{} :'.format(i),
    for (j, k, a) in tuples :
        if k != a :
            ref_filename = dataset_location.format(
                    file_general_pattern.format(i, j, k), ref_pattern.format(i, j, k))
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, j, k) + '_brain_seg')
        else :
            ref_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, k, j))
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), ref_pattern.format(i, k, j) + '_seg')

        mov_filename = ref_filename
        mov_prob_filename = ref_prob_filename

        if not os.path.exists(ref_filename):
            continue

        ######################################################################################
        volume_init = nib.load(ref_filename).get_data()

        ref_patches = extract_patches(volume_init, curr_patch_shape, step)
        ref_patches = ref_patches.reshape((-1, 1, ) + curr_patch_shape)
        ######################################################################################
        volume_init = nib.load(ref_prob_filename).get_data() == 1

        ref_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
        ref_prob_patches = ref_prob_patches.reshape((-1, 1, ) + curr_patch_shape)
        ######################################################################################
        ref_test = np.hstack((ref_patches, ref_prob_patches)).astype('float32') ##
        seg_test = ref_prob_patches.astype('float32') ##
        del ref_patches, ref_prob_patches
        ######################################################################################
        print '{}->{}'.format(a, a)

        pred = model.predict(
            [ref_test[:, 0:1], seg_test], verbose=1, batch_size=64)[2]
        pred = pred.reshape((-1, ) + curr_patch_shape)

        volume = perform_voting(pred, curr_patch_shape, volume_init.shape, step)

        volume = np.multiply(nib.load(mov_filename).get_data() != 0, volume)

        volume_data = nib.load(mov_filename)

        nib.save(nib.Nifti1Image(volume, volume_data.affine),
                 'mse_results/{}_{}_to_{}.nii.gz'.format(i, a, a))
        
        res = sitk.ReadImage('mse_results/{}_{}_to_{}.nii.gz'.format(i, a, a))

        caster = sitk.CastImageFilter()
        caster.SetOutputPixelType(res.GetPixelID())

        orig = caster.Execute(sitk.ReadImage(mov_filename))
        seg = caster.Execute(sitk.ReadImage(mov_prob_filename))

        thresholder = sitk.BinaryThresholdImageFilter()
        enhanced_vol = caster.Execute(
            thresholder.Execute(seg, -1, -1, 1, 0))

        pairs = [(1, 1), (2, 3)]
        for (l, h) in pairs :                
            thresholder = sitk.BinaryThresholdImageFilter()
            mask = thresholder.Execute(seg, l, h, 1, 0)

            masker = sitk.MaskImageFilter()
            ref_masked = masker.Execute(orig, mask)
            moving_masked = masker.Execute(res, mask)

            matcher = sitk.HistogramMatchingImageFilter()
            matcher.SetNumberOfHistogramLevels(2048)
            matcher.SetNumberOfMatchPoints(100)
            matcher.SetThresholdAtMeanIntensity(True)
            partial_result = matcher.Execute(moving_masked, ref_masked)

            adder = sitk.AddImageFilter()
            enhanced_vol = adder.Execute(enhanced_vol, partial_result)

        masker = sitk.MaskImageFilter()
        mask = thresholder.Execute(seg, 0, 0, 0, 1)
        enhanced_vol = masker.Execute(enhanced_vol, mask)
        sitk.WriteImage(
            enhanced_vol,
            'mse_results/{}_{}_to_{}_cor.nii.gz'.format(i, a, a))

        volume = nib.load('mse_results/{}_{}_to_{}_cor.nii.gz'.format(i, a, a)).get_data()
        act_vol = np.float64(nib.load(mov_filename).get_data())
        mad_movgen = mad(volume, act_vol)
        ssim_movgen = ssim(volume, act_vol)

        mad_results['{} {}->{}'.format(i, a, a)] = mad_movgen
        ssim_results['{} {}->{}'.format(i, a, a)] = ssim_movgen
        print '{}-{}: {} - {}'.format(i, a, mad_movgen, ssim_movgen)

1 : 1->1
1-1: 0.0165569439424 - 0.978559562859
2->2
1-2: 0.016757745787 - 0.979367974128
2 : 1->1
2-1: 0.0414045747404 - 0.942364056125
2->2
2-2: 0.0193558523657 - 0.978311541727
3->3
2-3: 0.0253466302044 - 0.978658061398
3 : 4 : 1->1
4-1: 0.0274777182251 - 0.93604395135
2->2
4-2: 0.018050219702 - 0.971553112772
5 : 1->1
5-1: 0.0163551672494 - 0.974359098959
2->2
5-2: 0.0151620192686 - 0.975066591719
3->3
5-3: 0.0168428031751 - 0.973244348517
6 : 7 : 3->3
7-3: 0.0211032184944 - 0.968587666715
4->4
7-4: 0.022283305121 - 0.96705581392
8 : 1->1
8-1: 0.0186595673017 - 0.976565597121
2->2
8-2: 0.0166640023363 - 0.974983142996
9 : 1->1
9-1: 0.0158141889373 - 0.967975276576
2->2
9-2: 0.0162071799221 - 0.971666135436
10 : 1->1
10-1: 0.0169478422503 - 0.977264772681
2->2
10-2: 0.0182903461357 - 0.977009147367
11 : 12 : 1->1
12-1: 0.017299917265 - 0.977063604141
2->2
12-2: 0.0202103333907 - 0.977227727179
3->3
12-3: 0.020210582354 - 0.976941424726
13 : 1->1
13-1: 0.0185088475399 - 0.974387267224

2->2
43-2: 0.0184089330811 - 0.981583567348
44 : 1->1
44-1: 0.0196618386804 - 0.978428324354
2->2
44-2: 0.0179017187552 - 0.975094698378
3->3
44-3: 0.0180301760464 - 0.977113774601
45 : 1->1
45-1: 0.0168941688684 - 0.981202082366
2->2
45-2: 0.0174302859326 - 0.980299382952
46 : 1->1
46-1: 0.0163068333585 - 0.974088775574
2->2
46-2: 0.016750054255 - 0.977071786497
47 : 1->1
47-1: 0.0182703695399 - 0.980115495294
2->2
47-2: 0.0181982079006 - 0.978800651498
48 : 1->1
48-1: 0.0193550831527 - 0.976555446009
2->2
48-2: 0.0185420817452 - 0.977744936457
3->3
48-3: 0.0176169854868 - 0.974006127015
4->4
48-4: 0.017365428783 - 0.975918199759
49 : 1->1
49-1: 0.0169516988086 - 0.973559799211
2->2
49-2: 0.0148411939122 - 0.97916711243
3->3
49-3: 0.0138043548505 - 0.976026795799
50 : 1->1
50-1: 0.0153894248327 - 0.966509104496
2->2
50-2: 0.0155605734985 - 0.974852969067
51 : 1->1
51-1: 0.0180001370783 - 0.977259191235
2->2
51-2: 0.019695029497 - 0.976281757801
3->3
51-3: 0.0199490140005 - 0.975425830

79 : 1->1
79-1: 0.0315399033538 - 0.945433757368
2->2
79-2: 0.0139865072653 - 0.980534838979
3->3
79-3: 0.0193898382816 - 0.979941973793
80 : 1->1
80-1: 0.0279305551221 - 0.790067863148
2->2
80-2: 0.0201817416591 - 0.973742884995
3->3
80-3: 0.0176152841334 - 0.97299933676
81 : 1->1
81-1: 0.0179734952877 - 0.980984315225
2->2
81-2: 0.019301527018 - 0.978878124702
82 : 83 : 84 : 85 : 1->1
85-1: 0.0195518093195 - 0.976394613735
2->2
85-2: 0.018696998009 - 0.974336843714
86 : 1->1
86-1: 0.0198964408852 - 0.976318792397
2->2
86-2: 0.0180071093104 - 0.973845590727
87 : 1->1
87-1: 0.0207255970599 - 0.97427928038
2->2
87-2: 0.0249367643088 - 0.977029579232
88 : 1->1
88-1: 0.0159191923365 - 0.967262117726
2->2
88-2: 0.0167104453869 - 0.968716839929
89 : 3->3
89-3: 0.0193149187871 - 0.975110366976
90 : 1->1
90-1: 0.0192274302001 - 0.980253759975
2->2
90-2: 0.0252099714279 - 0.975516275387
3->3
90-3: 0.0221966937023 - 0.976819528911
91 : 1->1
91-1: 0.016142539511 - 0.946175228274
2->2
91-2: 0.015