In [2]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from utils.roi_measures import mad, ssim
from utils.extraction import extract_patches
from utils.reconstruction import perform_voting

from model3d import Multimodel

In [3]:
# Model configuration: one T1 input modality mapped to one generated output
# ('Gen'), trained on cubic patches.
curr_patch_shape = (16,) * 3
patch_shape = curr_patch_shape
scale = 1
latent_dim = 32

input_modalities = ['T1']
output_modalities = ['Gen']
output_weights = {'Gen' : 1.0, 'concat' : 0.5}
channels = [2]
to_process = [True] * 2

# NOTE(review): the recorded output of this cell ends in
# "ValueError: total size of new array must be unchanged"; the cause is not
# visible here — check channels / to_process against Multimodel's expectations.
a_model = Multimodel(
    input_modalities,
    output_modalities,
    output_weights,
    latent_dim,
    channels,
    patch_shape,
    to_process,
    scale,
)
a_model.build()

Latent dimensions: 16
Fuse latent representations using max
making output: Tensor("enc_T1_act9_1/add:0", shape=(?, 16, 16, 16, 16), dtype=float32) Tensor("dec_Gen_2/dec_Gen_act9/add:0", shape=(?, 1, 16, 16, 16), dtype=float32) em_0_dec_Gen
making output: Tensor("enc_T1_act9_1/add:0", shape=(?, 16, 16, 16, 16), dtype=float32) Tensor("dec_Gen_2/dec_Gen_act9/add:0", shape=(?, 1, 16, 16, 16), dtype=float32) em_1_dec_Gen
Skipping embedding distance outputs for unimodal model
all outputs:  [u'em_0_dec_Gen_1/add:0', u'em_1_dec_Gen_1/add:0']
output dict:  {'em_1_dec_Gen': <function adhoc_loss at 0x7f5c61791de8>, 'em_0_dec_Gen': <function adhoc_loss at 0x7f5c61791de8>}
loss weights:  {'em_1_dec_Gen': 1.0, 'em_0_dec_Gen': 1.0}


ValueError: total size of new array must be unchanged

In [7]:
import os
import nibabel as nib

from medical_data import cdr_info, nwbv_info, diff_info

file_general_pattern = 'OAS2_0{0:03}_MR{1}_{3}_OAS2_0{0:03}_MR{2}'
dataset_location = 'datasets/OASIS/OASIS2/REG/{}/{}.nii.gz'

step = (16, 16, 16)
threshold = np.int32(0.30 * np.prod(curr_patch_shape[:]))
seg_train = np.empty((0, 1, ) + curr_patch_shape)
ref_train = np.empty((0, 2, ) + curr_patch_shape)
out_train = np.empty((0, 1, ) + curr_patch_shape)
for i in range(1, 100) :
    print '{} :'.format(i),
    
    for j in range(1, 5) :
        diff = 0
        for k in range(j+1, 5) :
            ref_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, j, k, 'halfwayto'))
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, j, k, 'halfwayto') + '_brain_seg')
            mov_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, k, j, 'halfwayto'))
            mov_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, k, j, 'halfwayto') + '_seg')
            
            if not (os.path.exists(ref_filename) and os.path.exists(mov_filename)) :
                continue
            
            diff = diff + diff_info['OAS2_0{:03}_MR{}'.format(i, k)]
            
            if diff < 0.01 :
                continue

            ######################################################################################
            volume_init = nib.load(ref_filename).get_data()
            volume_init = volume_init / volume_init.max()
            
            mask_patches = extract_patches(volume_init != 0, curr_patch_shape, step)

            useful_patches = np.sum(mask_patches, axis=(1, 2, 3)) > threshold
            
            del mask_patches
            
            ref_patches = extract_patches(volume_init, curr_patch_shape, step)
            ref_patches = ref_patches[useful_patches].reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            volume_init = nib.load(ref_prob_filename).get_data() == 1
            
            ref_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
            ref_prob_patches = ref_prob_patches[useful_patches].reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            ref_train = np.vstack((np.hstack((ref_patches, ref_prob_patches)), ref_train)).astype('float32') ##
            del ref_patches, ref_prob_patches

            ######################################################################################
            volume_init = nib.load(mov_filename).get_data()
            volume_init = volume_init / volume_init.max()
            
            mov_patches = extract_patches(volume_init, curr_patch_shape, step)
            mov_patches = mov_patches[useful_patches].reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            volume_init = nib.load(mov_prob_filename).get_data() == 1
            
            mov_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
            mov_prob_patches = mov_prob_patches[useful_patches].reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            seg_train = np.vstack((mov_prob_patches, seg_train)).astype('float32') ##
            out_train = np.vstack((mov_patches, out_train)).astype('float32')
            del mov_patches, mov_prob_patches
            ######################################################################################
            print '{}->{}'.format(j, k),
    print

 1 : 1->2
2 : 1->2 1->3 2->3
3 :
4 : 1->2
5 : 1->3 2->3
6 :
7 : 1->3 1->4 3->4
8 : 1->2
9 : 1->2
10 : 1->2
11 :
12 : 1->2 1->3 2->3
13 : 1->3 2->3
14 :
15 :
16 : 1->2
17 : 1->3 1->4
18 : 1->4 3->4
19 :
20 : 1->2 1->3 2->3
21 : 1->2
22 : 1->2
23 : 1->2
24 :
25 :
26 : 1->2
27 : 1->3 1->4 2->3 2->4
28 : 1->2
29 :
30 :
31 : 1->3 2->3
32 :
33 :
34 : 1->4 2->4 3->4
35 :
36 : 1->3 1->4 3->4
37 : 1->2 1->3 1->4 2->3 2->4
38 :
39 : 1->2
40 : 1->3 2->3
41 : 1->3 2->3
42 : 1->2
43 : 1->2
44 : 1->2 1->3
45 : 1->2
46 :
47 :
48 : 1->2 1->3 1->4 2->3 2->4 3->4
49 :
50 : 1->2
51 : 1->2 1->3 2->3
52 : 1->2
53 : 1->2
54 : 1->2
55 : 1->2
56 : 1->2
57 : 1->3 2->3
58 : 1->3 2->3
59 :
60 : 1->2
61 : 1->2 1->3 2->3
62 : 1->2 1->3 2->3
63 : 1->2
64 : 1->3 2->3
65 :
66 :
67 : 1->2 1->3 1->4 2->3 2->4
68 : 1->2
69 :
70 : 1->2 1->3 1->4 2->3 2->4 3->4
71 : 1->2
72 :
73 : 1->2 1->3 1->4 2->3 2->4 3->4
74 :
75 :
76 : 1->2 1->3
77 : 1->2
78 : 1->3 2->3
79 : 1->2 1->3 2->3
80 : 1->2 1->3 2->3
81 : 1->2
82 :
83 :
84 

In [15]:
def adhoc_loss(y_true, y_pred) :
    """Masked MAE plus a small (0.01-weighted) cross-entropy term.

    MAE term: ``y_pred`` is multiplied by a mask of the voxels where
    ``y_true`` is non-zero, so voxels where the target is zero contribute no
    error.

    Cross-entropy term: ``categorical_crossentropy`` between ``S`` applied to
    the target and to the prediction, both standardised with
    ``train_mean`` / ``train_std``, averaged over axis 1.

    NOTE(review): ``K``, ``mae``, ``categorical_crossentropy``, ``S``,
    ``train_mean`` and ``train_std`` are not defined in the visible cells —
    presumably the Keras backend/losses, a softmax, and training-set
    statistics; confirm they are in scope before this cell runs.
    """
    def standardised(tensor):
        return S((tensor - train_mean) / train_std)

    target_mask = K.cast(K.not_equal(y_true, 0), 'float32')
    masked_error = mae(y_true, y_pred * target_mask)

    crossentropy = categorical_crossentropy(standardised(y_true), standardised(y_pred))
    crossentropy_term = K.mean(crossentropy, axis=1)

    return masked_error + 0.01 * crossentropy_term

In [33]:
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Keep only the best weights on disk, and stop once validation loss has not
# improved for `patience` epochs.
patience = 3
checkpointer = ModelCheckpoint('models/ag_o1o2_tv.h5', save_best_only=True, save_weights_only=True)
stopper = EarlyStopping(patience=patience)

N = len(ref_train)
# Input: reference intensity channel stacked with the moving segmentation.
# Both model outputs are trained against the same target volume patches.
a_model.model.fit(
    x=[np.hstack((ref_train[:, 0:1], seg_train))],
    y=[out_train] * 2,
    validation_split=0.3,
    epochs=40,
    callbacks=[checkpointer, stopper],
)

Train on 38573 samples, validate on 16532 samples
Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40


<keras.callbacks.History at 0x7f0d062c6fd0>

In [34]:
# Restore the weights checkpointed to this path during training.
a_model.model.load_weights(filepath='models/ag_o1o2_tv.h5')

In [39]:
from keras.models import Model
import os
import nibabel as nib
import SimpleITK as sitk

file_general_pattern = 'OAS2_0{0:03}_MR{1}_{3}_OAS2_0{0:03}_MR{2}'
dataset_location = 'datasets/OASIS/OASIS2/REG/{}/{}.nii.gz'

mad_results = {}
ssim_results = {}
curr_patch_shape = (32, 32, 32)
step = (16, 16, 16)
for i in range(100, 190) :
    print '{} :'.format(i),
    for j in range(1, 5) :
        for k in range(j+1, 5) :
            ref_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, j, k, 'halfwayto'))
            ref_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, j, k, 'halfwayto') + '_brain_seg')
            mov_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, k, j, 'halfwayto'))
            mov_prob_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, k, j, 'halfwayto') + '_seg')
            
            if not (os.path.exists(ref_filename) and os.path.exists(mov_filename)) :
                continue

            ######################################################################################
            volume_init = nib.load(ref_filename).get_data()
            volume_init = volume_init / volume_init.max()
            
            ref_patches = extract_patches(volume_init, curr_patch_shape, step)
            ref_patches = ref_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            volume_init = nib.load(ref_prob_filename).get_data() == 1
            
            ref_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
            ref_prob_patches = ref_prob_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            ref_test = np.hstack((ref_patches, ref_prob_patches)).astype('float32') ##
            del ref_patches, ref_prob_patches
            ######################################################################################
            volume_init = nib.load(mov_filename).get_data()
            
            mov_patches = extract_patches(volume_init, curr_patch_shape, step)
            mov_patches = mov_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            volume_init = nib.load(mov_prob_filename).get_data() == 1
            
            mov_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
            mov_prob_patches = mov_prob_patches.reshape((-1, 1, ) + curr_patch_shape)
            ######################################################################################
            seg_test = mov_prob_patches.astype('float32') ##
            del mov_patches, mov_prob_patches
            ######################################################################################
            print '{}->{}'.format(j, k),
    
            pred = a_model.model.predict(
                [np.hstack((ref_test[:, 0:1], seg_test))], verbose=1)[1]
            pred = pred.reshape((-1, ) + curr_patch_shape)
            
            volume = perform_voting(pred, curr_patch_shape, volume_init.shape, step)
    
            volume = volume * nib.load(mov_filename).get_data().max()
            volume = np.multiply(nib.load(mov_filename).get_data() != 0, volume)

            volume_data = nib.load(mov_filename)
            
            nib.save(nib.Nifti1Image(volume, volume_data.affine),
                     'tv_results/{}_{}_to_{}.nii.gz'.format(i, j, k))

            mov_seg_filename = dataset_location.format(
                file_general_pattern.format(i, j, k, 'to'),
                file_general_pattern.format(i, k, j, 'halfwayto') + '_seg')

            res = sitk.ReadImage('tv_results/{}_{}_to_{}.nii.gz'.format(i, j, k))
            
            caster = sitk.CastImageFilter()
            caster.SetOutputPixelType(res.GetPixelID())
            
            orig = caster.Execute(sitk.ReadImage(mov_filename))
            seg = caster.Execute(sitk.ReadImage(mov_seg_filename))
            
            thresholder = sitk.BinaryThresholdImageFilter()
            enhanced_vol = caster.Execute(
                thresholder.Execute(seg, -1, -1, 1, 0))

            pairs = [(1, 3)]
            for (a, b) in pairs :                
                thresholder = sitk.BinaryThresholdImageFilter()
                mask = thresholder.Execute(seg, a, b, 1, 0)

                masker = sitk.MaskImageFilter()
                ref_masked = masker.Execute(orig, mask)
                moving_masked = masker.Execute(res, mask)

                matcher = sitk.HistogramMatchingImageFilter()
                matcher.SetNumberOfHistogramLevels(1024)
                matcher.SetNumberOfMatchPoints(15)
#                 matcher.SetThresholdAtMeanIntensity(True)
                partial_result = matcher.Execute(moving_masked, ref_masked)

                adder = sitk.AddImageFilter()
                enhanced_vol = adder.Execute(enhanced_vol, partial_result)

            masker = sitk.MaskImageFilter()
            mask = thresholder.Execute(seg, 0, 0, 0, 1)
            enhanced_vol = masker.Execute(enhanced_vol, mask)
            sitk.WriteImage(
                enhanced_vol,
                'tv_results/{}_{}_to_{}_cor.nii.gz'.format(i, j, k))

            volume = nib.load('tv_results/{}_{}_to_{}_cor.nii.gz'.format(i, j, k)).get_data()
            ref_vol = np.float64(nib.load(ref_filename).get_data())
            act_vol = np.float64(nib.load(mov_filename).get_data())
            mad_movgen = mad(volume, act_vol)
            ssim_movgen = ssim(volume, act_vol)

            mad_results['{} {}->{}'.format(i, j, k)] = mad_movgen
            ssim_results['{} {}->{}'.format(i, j, k)] = ssim_movgen
            print '{}-{}: {} - {}'.format(i, k, mad_movgen, ssim_movgen)

 100-2: 0.0200097411497 - 0.977986689094
 100-3: 0.0212650491053 - 0.976152976819
 100-3: 0.0194470355125 - 0.978128767289
 101-2: 0.0189666146154 - 0.961411501427
 101-3: 0.019062400455 - 0.973364544691
 101-3: 0.0180518605273 - 0.97784806666
 102-2: 0.0210240991108 - 0.966112270822
 102-3: 0.02269769686 - 0.968011486121
 102-3: 0.0252259324666 - 0.963367511471
 103-2: 0.0160669584339 - 0.974551874063
 103-3: 0.0188296731589 - 0.971807098422
 103-3: 0.0176035369855 - 0.975748912301
 104-2: 0.017465528467 - 0.951203221458
 105-2: 0.0228107183896 - 0.954374141619
 106-2: 0.019706747924 - 0.971398967362
 108-2: 0.0235164992111 - 0.956835473121
 109-2: 0.0187925570061 - 0.972160743732
 111-2: 0.0205593138572 - 0.934952137488
 112-2: 0.0243435947233 - 0.954769877719
 113-2: 0.0212247625676 - 0.968867685667
 114-2: 0.0225264842746 - 0.927072965505
 116-2: 0.0230040183698 - 0.941516848394
 117-2: 0.0224418450126 - 0.954307675648
 117-3: 0.0251091324711 - 0.962808343077
 117-4: 0.022768087068

 161-2: 0.0210229114058 - 0.961915592409
 161-3: 0.0176480724595 - 0.96361141514
 161-3: 0.0179693940335 - 0.93416397125
 162-2: 0.022178095184 - 0.951634376454
 164-2: 0.0201640037216 - 0.964424653291
 165-2: 0.0206433227992 - 0.954662311434
 169-2: 0.014911422401 - 0.977088831453
 171-2: 0.0193399513926 - 0.974226299445
 171-3: 0.0202267660185 - 0.971885593752
 171-3: 0.0209816323019 - 0.976024135233
 172-2: 0.0197940660459 - 0.970774579534
 174-2: 0.017953495684 - 0.965496856143
 174-3: 0.0201396594547 - 0.969047961606
 174-3: 0.0213966478666 - 0.967478847549
 175-2: 0.0234588735184 - 0.96743329961
 175-3: 0.022998609588 - 0.967089478984
 175-3: 0.0224665919349 - 0.958187428784
 176-2: 0.0222952502148 - 0.961552480605
 176-3: 0.0219364273478 - 0.954982727695
 176-3: 0.020295342527 - 0.963773773822
 177-2: 0.0206976514085 - 0.952633029027
 178-2: 0.0195837009691 - 0.955319030991
 178-3: 0.0181923286516 - 0.960201280665
 178-3: 0.0173997519779 - 0.968287136605
 179-2: 0.0174681746613 

In [3]:
from keras.models import Model
import os
import nibabel as nib
import SimpleITK as sitk

file_general_pattern = 'OAS2_0{0:03}_MR{1}_{3}_OAS2_0{0:03}_MR{2}'
dataset_location = 'datasets/OASIS/OASIS2/REG/{}/{}.nii.gz'

tuples = [(1, 2, 1, 2), (1, 2, 2, 1), (1, 3, 3, 1), (1, 4, 4, 1)]
mad_results = {}
ssim_results = {}
curr_patch_shape = (32, 32, 32)
step = (16, 16, 16)
for i in range(100, 190) :
#     print '{} :'.format(i),
    for (j, k, a, b) in tuples :
        ref_filename = dataset_location.format(
            file_general_pattern.format(i, j, k, 'to'),
            file_general_pattern.format(i, a, b, 'halfwayto'))
        ref_prob_filename = dataset_location.format(
            file_general_pattern.format(i, j, k, 'to'),
            file_general_pattern.format(i, a, b, 'halfwayto') + ('_brain_seg' if a < b else '_seg'))
        mov_filename = ref_filename
        mov_prob_filename = ref_prob_filename

        if not os.path.exists(ref_filename):
            continue

#         ######################################################################################
#         volume_init = nib.load(ref_filename).get_data()
#         volume_init = volume_init / volume_init.max()

#         ref_patches = extract_patches(volume_init, curr_patch_shape, step)
#         ref_patches = ref_patches.reshape((-1, 1, ) + curr_patch_shape)
#         ######################################################################################
#         volume_init = nib.load(ref_prob_filename).get_data() == 1

#         ref_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
#         ref_prob_patches = ref_prob_patches.reshape((-1, 1, ) + curr_patch_shape)
#         ######################################################################################
#         ref_test = np.hstack((ref_patches, ref_prob_patches)).astype('float32') ##
#         seg_test = ref_prob_patches.astype('float32') ##
#         del ref_patches, ref_prob_patches
#         ######################################################################################
#         print '{}->{}'.format(a, a),

#         pred = a_model.model.predict(
#             [np.hstack((ref_test[:, 0:1], seg_test))], verbose=1)[1]
#         pred = pred.reshape((-1, ) + curr_patch_shape)

#         volume = perform_voting(pred, curr_patch_shape, volume_init.shape, step)

#         volume = volume * nib.load(mov_filename).get_data().max()
#         volume = np.multiply(nib.load(mov_filename).get_data() != 0, volume)

#         volume_data = nib.load(mov_filename)

#         nib.save(nib.Nifti1Image(volume, volume_data.affine),
#                  'tv_results/{}_{}_to_{}.nii.gz'.format(i, a, a))

#         mov_seg_filename = ref_prob_filename

#         res = sitk.ReadImage('tv_results/{}_{}_to_{}.nii.gz'.format(i, a, a))

#         caster = sitk.CastImageFilter()
#         caster.SetOutputPixelType(res.GetPixelID())

#         orig = caster.Execute(sitk.ReadImage(mov_filename))
#         seg = caster.Execute(sitk.ReadImage(mov_seg_filename))

#         thresholder = sitk.BinaryThresholdImageFilter()
#         enhanced_vol = caster.Execute(
#             thresholder.Execute(seg, -1, -1, 1, 0))

#         pairs = [(1, 3)]
#         for (l, h) in pairs :                
#             thresholder = sitk.BinaryThresholdImageFilter()
#             mask = thresholder.Execute(seg, l, h, 1, 0)

#             masker = sitk.MaskImageFilter()
#             ref_masked = masker.Execute(orig, mask)
#             moving_masked = masker.Execute(res, mask)

#             matcher = sitk.HistogramMatchingImageFilter()
#             matcher.SetNumberOfHistogramLevels(1024)
#             matcher.SetNumberOfMatchPoints(15)
#             partial_result = matcher.Execute(moving_masked, ref_masked)

#             adder = sitk.AddImageFilter()
#             enhanced_vol = adder.Execute(enhanced_vol, partial_result)

#         masker = sitk.MaskImageFilter()
#         mask = thresholder.Execute(seg, 0, 0, 0, 1)
#         enhanced_vol = masker.Execute(enhanced_vol, mask)
            
#         sitk.WriteImage(
#             enhanced_vol,
#             'tv_results/{}_{}_to_{}_cor.nii.gz'.format(i, a, a))

        volume = nib.load('tv_results/{}_{}_to_{}_cor.nii.gz'.format(i, a, a)).get_data()
        act_vol = np.float64(nib.load(mov_filename).get_data())
        mad_movgen = mad(volume, act_vol)
        ssim_movgen = ssim(volume, act_vol)

        mad_results['{} {}->{}'.format(i, j, k)] = mad_movgen
        ssim_results['{} {}->{}'.format(i, j, k)] = ssim_movgen
        print '{}-{}: {} - {}'.format(i, k, mad_movgen, ssim_movgen)

100-2: 0.0145007257649 - 0.985976906127
100-2: 0.0146274471815 - 0.985608970079
100-3: 0.0154512490589 - 0.985470795577
101-2: 0.0153865208513 - 0.977279987792
101-2: 0.013540480636 - 0.974919974063
101-3: 0.0145768764903 - 0.979943341421
102-2: 0.0140089732054 - 0.984398769897
102-2: 0.0147663126048 - 0.983835464031
102-3: 0.0154387141344 - 0.985374688646
103-2: 0.0156283765837 - 0.978212420416
103-2: 0.0126556433487 - 0.98369758887
103-3: 0.0146568471137 - 0.980904564397
104-2: 0.0170967246546 - 0.979455720557
104-2: 0.0138538221128 - 0.957921587992
105-2: 0.0172720657441 - 0.979325079086
105-2: 0.0139086916593 - 0.983695651988
106-2: 0.0141760005605 - 0.984895809708
106-2: 0.0142402803439 - 0.9822110088
108-2: 0.0144512988005 - 0.974153799469
108-2: 0.0167385178688 - 0.979985052914
109-2: 0.0135206196205 - 0.983216813292
109-2: 0.0144491399302 - 0.983633414773
111-2: 0.0170007052613 - 0.983544730758
111-2: 0.0140361397489 - 0.975870854697
112-2: 0.0155883416504 - 0.967183903204
112-

KeyboardInterrupt: 