In [1]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from utils.roi_measures import mad, ssim
from utils.extraction import extract_patches
from utils.reconstruction import perform_voting

from wnet import generate_wnet_model

Using TensorFlow backend.


In [21]:
# Hyper-parameters handed to generate_wnet_model for the W-Net itself.
wparams = {
    'input_channels': 2,
    'output_channels': 1,
    'latent_channels': 16,
    'scale': [0.5, 0.5, 0.5],
    'use_combined_loss': True,
    'patch_shape': (32, 32, 32),
    # Alternative weighting kept for reference: [1, 1.0/np.sqrt(32**4), 1]
    'loss_weights': [1, 1, 1.1],
}

# Segmentation-related settings; both filenames are left empty in this run.
segparams = {
    'seg_model_filename': '',
    'seg_model_params_filename': '',
    'segmentation_classes': 4,
}

model = generate_wnet_model(wparams, segparams)

In [22]:
# Print the layer-by-layer description (output shapes, parameter counts,
# connections) of the model as a quick sanity check of the architecture.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_15 (InputLayer)           (None, 1, 32, 32, 32 0                                            
__________________________________________________________________________________________________
input_14 (InputLayer)           (None, 1, 32, 32, 32 0                                            
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 1, 32, 32, 32 0           input_15[0][0]                   
__________________________________________________________________________________________________
multiply_5 (Multiply)           (None, 1, 32, 32, 32 0           input_14[0][0]                   
                                                                 lambda_3[0][0]                   
__________

In [23]:
import os
import nibabel as nib

file_general_pattern = 'ADNI_{0:03}_MR{1}_to_ADNI_{0:03}_MR{2}_siena'
dataset_location = '/mnt/harddisk/datasets/ADNI/REG/{}/{}.nii.gz'

step = (16, 16, 16)
threshold = np.int32(0.30 * np.prod(wparams['patch_shape'][:]))
seg_train = np.empty((0, 1, ) + wparams['patch_shape'])
ref_train = np.empty((0, 1, ) + wparams['patch_shape'])
out_train = np.empty((0, 1, ) + wparams['patch_shape'])
for i in range(37, 153) :
    print '{} :'.format(i),
    
    for j in range(0, 11) :
        for k in range(j+1, 11) :
            ref_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'B')
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'B_seg')
            mov_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'A_reg')
            mov_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'A_reg_seg')
            
            if not (os.path.exists(ref_filename) and os.path.exists(mov_filename)) :
                continue
            
            ######################################################################################
            volume_init = nib.load(ref_filename).get_data()
            
            mask_patches = extract_patches(volume_init != 0, wparams['patch_shape'], step)

            useful_patches = np.sum(mask_patches, axis=(1, 2, 3)) > threshold
            
            del mask_patches
            
            ref_patches = extract_patches(volume_init, wparams['patch_shape'], step)
            ref_patches = ref_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            ######################################################################################
            ref_train = np.vstack((ref_patches, ref_train)).astype('float32')
            del ref_patches

            ######################################################################################
            volume_init = nib.load(mov_filename).get_data()
            
            mov_patches = extract_patches(volume_init, wparams['patch_shape'], step)
            mov_patches = mov_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            ######################################################################################
            volume_init = nib.load(mov_prob_filename).get_data() == 1
            
            mov_prob_patches = extract_patches(volume_init, wparams['patch_shape'], step)
            mov_prob_patches = mov_prob_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            ######################################################################################
            seg_train = np.vstack((mov_prob_patches, seg_train)).astype('float32') ##
            out_train = np.vstack((mov_patches, out_train)).astype('float32')
            del mov_patches, mov_prob_patches
            ######################################################################################
            print '{}->{}'.format(j, k),
    print

37 : 0->5 0->8 1->2 1->4 1->7 2->4 2->7 4->7 5->8 6->9
38 :
39 :
40 :
41 :
42 :
43 :
44 :
45 :
46 :
47 :
48 :
49 :
50 :
51 : 0->4 0->5 1->2 1->3 2->3 4->5
52 : 0->2 0->3 2->3
53 : 0->1 0->2 1->2
54 :
55 :
56 :
57 :
58 :
59 :
60 :
61 :
62 :
63 :
64 :
65 :
66 :
67 : 1->2
68 : 1->3 1->4 3->4
69 :
70 : 1->2
71 :
72 :
73 :
74 :
75 :
76 :
77 :
78 :
79 :
80 :
81 :
82 : 0->1 0->2 0->3 0->5 1->2 1->3 1->5 2->3 2->5 3->5
83 :
84 :
85 :
86 :
87 :
88 :
89 :
90 :
91 :
92 :
93 :
94 : 0->3 2->4
95 :
96 :
97 :
98 : 0->3 0->6 1->2 1->4 1->5 1->7 1->8 2->4 2->5 2->7 2->8 3->6 4->5 4->7 4->8 5->7 5->8 7->8
99 :
100 :
101 :
102 :
103 :
104 :
105 :
106 :
107 :
108 :
109 :
110 :
111 :
112 :
113 :
114 : 0->1 0->2 1->2
115 :
116 : 0->3 0->5 3->5 4->6 4->7 4->8 4->9 6->7 6->8 6->9 7->8 7->9 8->9
117 :
118 :
119 :
120 :
121 :
122 :
123 : 0->1 0->2 0->3 0->4 1->2 1->3 1->4 2->3 2->4 3->4
124 :
125 :
126 : 0->3 0->4 0->5 1->2 1->6 2->6 3->4 3->5 4->5
127 : 0->1 0->5 1->5 2->7 3->4
128 : 0->4 1->6 2->3
129 :
130 :

In [24]:
from keras.callbacks import EarlyStopping, ModelCheckpoint

patience = 10

# Keep only the best weights on disk and stop once the monitored quantity
# has not improved for `patience` epochs.
checkpointer = ModelCheckpoint(
    'models/ag_mseloss_a2a1.h5',
    save_best_only=True,
    save_weights_only=True)
stopper = EarlyStopping(patience=patience)

N = len(ref_train)

# Targets, in order: moving image outside the label mask, moving image inside
# the label mask, and the complete moving image.
model.fit(
    [ref_train, seg_train],
    [out_train * (1 - seg_train), out_train * seg_train, out_train],
    batch_size=32,
    epochs=100,
    validation_split=0.3,
    callbacks=[checkpointer, stopper])

Train on 28483 samples, validate on 12208 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100


<keras.callbacks.History at 0x7f964d4c8450>

In [29]:
# Restore previously saved weights into the model before evaluation.
# NOTE(review): the training cell checkpoints to 'models/ag_mseloss_a2a1.h5',
# whereas this loads 'models/ag_mseloss_a1a2.h5' -- confirm that evaluating a
# different weight file is intended and not a transposed file name.
model.load_weights('models/ag_mseloss_a1a2.h5')

In [26]:
from keras.models import Model
import os
import nibabel as nib
import SimpleITK as sitk

file_general_pattern = 'ADNI_{0:03}_MR{1}_to_ADNI_{0:03}_MR{2}_siena'
dataset_location = '/mnt/harddisk/datasets/ADNI/REG/{}/{}.nii.gz'

mad_results = {}
ssim_results = {}
curr_patch_shape = (32, 32, 32)
step = (16, 16, 16)
for i in range(1, 37) :
    print '{} :'.format(i),
    for j in range(0, 10) :
        for k in range(j+1, 10) :
            ref_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'B')
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'B_seg')
            mov_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'A_reg')
            mov_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'A_reg_seg')
            
            if not (os.path.exists(ref_filename) and os.path.exists(mov_filename)) :
                continue
                
            ######################################################################################
            volume_init = nib.load(ref_filename).get_data()

            ref_patches = extract_patches(volume_init, curr_patch_shape, step)
            ref_patches = ref_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            volume_init = nib.load(mov_prob_filename).get_data() == 1
            
            mov_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
            mov_prob_patches = mov_prob_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            print '{}->{}'.format(j, k),

            pred = model.predict(
                [ref_patches, mov_prob_patches], verbose=1, batch_size=64)[2]
            pred = pred.reshape((-1, ) + curr_patch_shape)

            volume = perform_voting(pred, curr_patch_shape, volume_init.shape, step)

            volume_data = nib.load(mov_filename)

            volume = np.multiply(volume_data.get_data() != 0, volume)
            
            save_filename = 'mse_results/{}_{}_to_{}.nii.gz'.format(i, j, k)
            save_filename_cor = 'mse_results/{}_{}_to_{}_cor.nii.gz'.format(i, j, k)

            nib.save(nib.Nifti1Image(volume, volume_data.affine), save_filename)

            res = sitk.ReadImage(save_filename)

            caster = sitk.CastImageFilter()
            caster.SetOutputPixelType(res.GetPixelID())

            orig = caster.Execute(sitk.ReadImage(mov_filename))
            seg = caster.Execute(sitk.ReadImage(mov_prob_filename))

            thresholder = sitk.BinaryThresholdImageFilter()
            enhanced_vol = caster.Execute(
                thresholder.Execute(seg, -1, -1, 1, 0))

            pairs = [(1, 1), (2, 3)]
            for (a, b) in pairs :                
                thresholder = sitk.BinaryThresholdImageFilter()
                mask = thresholder.Execute(seg, a, b, 1, 0)

                masker = sitk.MaskImageFilter()
                ref_masked = masker.Execute(orig, mask)
                moving_masked = masker.Execute(res, mask)

                matcher = sitk.HistogramMatchingImageFilter()
                matcher.SetNumberOfHistogramLevels(2048)
                matcher.SetNumberOfMatchPoints(15)
                matcher.SetThresholdAtMeanIntensity(True)
                partial_result = matcher.Execute(moving_masked, ref_masked)

                adder = sitk.AddImageFilter()
                enhanced_vol = adder.Execute(enhanced_vol, partial_result)

            masker = sitk.MaskImageFilter()
            mask = thresholder.Execute(seg, 0, 0, 0, 1)
            enhanced_vol = masker.Execute(enhanced_vol, mask)
            sitk.WriteImage(enhanced_vol, save_filename_cor)

            volume = nib.load(save_filename_cor).get_data()
            act_vol = np.float64(nib.load(mov_filename).get_data())
            mad_movgen = mad(volume, act_vol)
            ssim_movgen = ssim(volume, act_vol)

            mad_results['{} {}->{}'.format(i, j, k)] = mad_movgen
            ssim_results['{} {}->{}'.format(i, j, k)] = ssim_movgen
            print '{}-{}: {} - {}'.format(i, k, mad_movgen, ssim_movgen)

 2-1: 0.0145475146341 - 0.985838571335
 2-2: 0.0153455677969 - 0.986108127765
 2-2: 0.0148068729573 - 0.985431718501
 3-4: 0.0140218093699 - 0.979262580292
 3-5: 0.0119566472145 - 0.982351423457
 5-5: 0.0129650037475 - 0.978426716805
 5-7: 0.0138332290071 - 0.975878727649
 5-4: 0.0145509360546 - 0.976224238142
 5-8: 0.0169809946673 - 0.972765510018
 5-3: 0.0149660099116 - 0.983114582219
 5-8: 0.016147034477 - 0.96852585383
 5-7: 0.0152272029905 - 0.975328045882
 5-9: 0.0241196120047 - 0.968495935016
 6-4: 0.0173282360449 - 0.976503282002
 6-5: 0.0189247829164 - 0.97861914061
 6-7: 0.0205112687604 - 0.976136085555
 6-9: 0.0189532231588 - 0.976772163687
 6-3: 0.0145018636373 - 0.98403298281
 6-7: 0.0136611758706 - 0.984864783088
 6-9: 0.0126341805599 - 0.985355084821
 6-8: 0.0153647748308 - 0.978653817386
 6-9: 0.0112585252925 - 0.970608699492
 7-3: 0.0163886979169 - 0.980278549813
 7-4: 0.0163474855176 - 0.979370993047
 7-7: 0.0250081861894 - 0.96984995121
 7-8: 0.016593694933 - 0.98193

In [30]:
from keras.models import Model
import os
import nibabel as nib
import SimpleITK as sitk

file_general_pattern = 'ADNI_{0:03}_MR{1}_to_ADNI_{0:03}_MR{2}_siena'
dataset_location = '/mnt/harddisk/datasets/ADNI/REG/{}/{}.nii.gz'

tuples = [(1, 2, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4)]
mad_results = {}
ssim_results = {}
curr_patch_shape = (32, 32, 32)
step = (16, 16, 16)
for i in range(37, 154) :
    print '{} :'.format(i),
    for j in range(0, 10) :
        for k in range(j+1, 10) :
            ref_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'A_reg')
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k), 'A_reg_seg')

            mov_filename = ref_filename
            mov_prob_filename = ref_prob_filename

            if not os.path.exists(ref_filename):
                continue

            ######################################################################################
            volume_init = nib.load(ref_filename).get_data()

            ref_patches = extract_patches(volume_init, curr_patch_shape, step)
            ref_patches = ref_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            volume_init = nib.load(ref_prob_filename).get_data() == 1

            ref_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
            ref_prob_patches = ref_prob_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            ref_test = np.hstack((ref_patches, ref_prob_patches)).astype('float32') ##
            seg_test = ref_prob_patches.astype('float32') ##
            del ref_patches, ref_prob_patches
            ######################################################################################
            print '{}->{}'.format(k, k)

            pred = model.predict(
                [ref_test[:, 0:1], seg_test], verbose=1, batch_size=64)[2]
            pred = pred.reshape((-1, ) + curr_patch_shape)

            volume = perform_voting(pred, curr_patch_shape, volume_init.shape, step)

            volume = np.multiply(nib.load(mov_filename).get_data() != 0, volume)

            volume_data = nib.load(mov_filename)

            nib.save(nib.Nifti1Image(volume, volume_data.affine),
                     'mse_results/{}_{}_to_{}.nii.gz'.format(i, k, k))

            res = sitk.ReadImage('mse_results/{}_{}_to_{}.nii.gz'.format(i, k, k))

            caster = sitk.CastImageFilter()
            caster.SetOutputPixelType(res.GetPixelID())

            orig = caster.Execute(sitk.ReadImage(mov_filename))
            seg = caster.Execute(sitk.ReadImage(mov_prob_filename))

            thresholder = sitk.BinaryThresholdImageFilter()
            enhanced_vol = caster.Execute(
                thresholder.Execute(seg, -1, -1, 1, 0))

            pairs = [(1, 1), (2, 3)]
            for (l, h) in pairs :                
                thresholder = sitk.BinaryThresholdImageFilter()
                mask = thresholder.Execute(seg, l, h, 1, 0)

                masker = sitk.MaskImageFilter()
                ref_masked = masker.Execute(orig, mask)
                moving_masked = masker.Execute(res, mask)

                matcher = sitk.HistogramMatchingImageFilter()
                matcher.SetNumberOfHistogramLevels(2048)
                matcher.SetNumberOfMatchPoints(100)
                matcher.SetThresholdAtMeanIntensity(True)
                partial_result = matcher.Execute(moving_masked, ref_masked)

                adder = sitk.AddImageFilter()
                enhanced_vol = adder.Execute(enhanced_vol, partial_result)

            masker = sitk.MaskImageFilter()
            mask = thresholder.Execute(seg, 0, 0, 0, 1)
            enhanced_vol = masker.Execute(enhanced_vol, mask)
            sitk.WriteImage(
                enhanced_vol,
                'mse_results/{}_{}_to_{}_cor.nii.gz'.format(i, k, k))

            volume = nib.load('mse_results/{}_{}_to_{}_cor.nii.gz'.format(i, k, k)).get_data()
            act_vol = np.float64(nib.load(mov_filename).get_data())
            mad_movgen = mad(volume, act_vol)
            ssim_movgen = ssim(volume, act_vol)

            mad_results['{} {}->{}'.format(i, k, k)] = mad_movgen
            ssim_results['{} {}->{}'.format(i, k, k)] = ssim_movgen
            print '{}-{}: {} - {}'.format(i, k, mad_movgen, ssim_movgen)

37 : 5->5
37-5: 0.01352928562 - 0.985313290887
8->8
37-8: 0.0127462437503 - 0.984661949501
2->2
37-2: 0.0140288894377 - 0.984837180357
4->4
37-4: 0.012464295859 - 0.984948963595
7->7
37-7: 0.0126062284047 - 0.984930809736
4->4
37-4: 0.0118236779711 - 0.987019532243
7->7
37-7: 0.0116051849761 - 0.986996402002
7->7
37-7: 0.0133285747212 - 0.985684787061
8->8
37-8: 0.0121779989224 - 0.980697546609
9->9
37-9: 0.0115247677178 - 0.980618643395
38 : 39 : 40 : 41 : 42 : 43 : 44 : 45 : 46 : 47 : 48 : 49 : 50 : 51 : 4->4
51-4: 0.0129548876199 - 0.976527429346
5->5
51-5: 0.0119615156978 - 0.97780972396
2->2
51-2: 0.0108201807067 - 0.986900101764
3->3
51-3: 0.0120037151654 - 0.98622439647
3->3
51-3: 0.0114953723432 - 0.985439982741
5->5
51-5: 0.0113407356135 - 0.983110313242
52 : 2->2
52-2: 0.0102066161589 - 0.983794324605
3->3
52-3: 0.0109520147383 - 0.982623115173
3->3
52-3: 0.0115513011297 - 0.986023291379
53 : 1->1
53-1: 0.0100401642304 - 0.989514000805
2->2
53-2: 0.00920672040606 - 0.98981269

123-4: 0.0103545819228 - 0.988404688869
2->2
123-2: 0.0109861303103 - 0.987533147347
3->3
123-3: 0.0108204142809 - 0.987369509081
4->4
123-4: 0.0104440948346 - 0.987955094888
3->3
123-3: 0.0110473019815 - 0.987960168361
4->4
123-4: 0.0110524875859 - 0.988334040775
4->4
123-4: 0.0120172457374 - 0.986943162634
124 : 125 : 126 : 3->3
126-3: 0.0121284240228 - 0.984169028846
4->4
126-4: 0.0128808094416 - 0.984381128764
5->5
126-5: 0.0133401351588 - 0.983744008285
2->2
126-2: 0.0109222429006 - 0.979474162187
6->6
126-6: 0.0107323748188 - 0.979050528163
6->6
126-6: 0.0110643987173 - 0.978265039004
4->4
126-4: 0.0124493720096 - 0.984829166694
5->5
126-5: 0.0121068736461 - 0.984239867892
5->5
126-5: 0.0110002082177 - 0.980421779294
127 : 1->1
127-1: 0.0116126873785 - 0.974990623295
5->5
127-5: 0.0121804189061 - 0.976635933866
5->5
127-5: 0.0121170605723 - 0.979013478408
7->7
127-7: 0.00970834759088 - 0.986610834676
4->4
127-4: 0.0106000749592 - 0.986887327496
128 : 4->4
128-4: 0.0121993011348 -