In [None]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from utils.roi_measures import mad, ssim
from utils.extraction import extract_patches
from utils.reconstruction import perform_voting

from wnet import generate_wnet_model

In [None]:
def match_intensities(ref, mov, histogram_levels=256, match_points=15) :
    """Histogram-match the intensities of `mov` to those of `ref`.

    Parameters
    ----------
    ref : SimpleITK.Image
        Reference image whose intensity distribution is the target.
    mov : SimpleITK.Image
        Image whose intensities are remapped.
    histogram_levels : int, optional
        Number of histogram bins (default 256, the previously hard-coded value).
    match_points : int, optional
        Number of match points used by the filter (default 15).

    Returns
    -------
    SimpleITK.Image
        `mov` with its intensities remapped towards those of `ref`.
    """
    # Imported locally: elsewhere in this notebook SimpleITK is only imported in a later
    # cell, so without this the function raised NameError when called any earlier.
    import SimpleITK as sitk

    matcher = sitk.HistogramMatchingImageFilter()
    matcher.SetNumberOfHistogramLevels(histogram_levels)
    matcher.SetNumberOfMatchPoints(match_points)
    # Only voxels above each image's mean intensity contribute to the histograms.
    matcher.SetThresholdAtMeanIntensity(True)
    # SimpleITK signature is Execute(image, referenceImage).
    return matcher.Execute(mov, ref)

In [None]:
# Architecture / training configuration for the W-Net.
# NOTE(review): model.fit and model.predict below feed four input arrays while
# 'input_channels' is 2 — confirm how generate_wnet_model interprets this value.
wparams = {
    'input_channels': 2,
    'output_channels': 1,
    'latent_channels': 16,
    'scale': [0.5, 0.5],
    'use_combined_loss': True,
    'patch_shape': (32, 32, 32),
    # One weight per model output (five outputs are trained in model.fit below).
    'loss_weights': [1, 1, 1, 1.5, 3],
}

# Segmentation-network settings; empty filenames are passed through unchanged.
segparams = {
    'seg_model_filename': '',
    'seg_model_params_filename': '',
    'segmentation_classes': 4,
}

model = generate_wnet_model(wparams, segparams)

In [None]:
# Print a layer-by-layer summary of the W-Net built in the previous cell.
model.summary()

In [None]:
import os
import nibabel as nib

# fname_pattern = 'OAS2_{0:04}_MR{1}'
# db_location = '/mnt/harddisk/datasets/OASIS/REG_MNI/{}Warped.nii.gz'
# db_seg_location = '/mnt/harddisk/datasets/OASIS/REG_MNI/{}Warped_seg.nii.gz'
# db_prob_location = '/mnt/harddisk/datasets/OASIS/REG_MNI/{}Warped_pve_{}.nii.gz'

fname_pattern = 'ADNI_{0:03}_MR{1}'
db_location = '/mnt/harddisk/datasets/ADNI/REG_MNI/{}Warped.nii.gz'
db_seg_location = '/mnt/harddisk/datasets/ADNI/REG_MNI/{}Warped_seg.nii.gz'
db_prob_location = '/mnt/harddisk/datasets/ADNI/REG_MNI/{}Warped_pve_{}.nii.gz'

step = (16, 16, 16)
threshold = np.int32(0.30 * np.prod(wparams['patch_shape'][:]))
in_train_1 = np.empty((0, 1, ) + wparams['patch_shape'])
in_train_2 = np.empty((0, 1, ) + wparams['patch_shape'])
in_train_3 = np.empty((0, 1, ) + wparams['patch_shape'])
in_train_4 = np.empty((0, 1, ) + wparams['patch_shape'])
out_train = np.empty((0, 1, ) + wparams['patch_shape'])
for i in range(1, 200) :
    print '{} :'.format(i),
    
    for j in range(1, 5) :
        diff = 0
        for k in range(j+1, 5) :
            ref_filename = db_location.format(fname_pattern.format(i, j))
            ref_seg_filename = db_seg_location.format(fname_pattern.format(i, j))
            mov_filename = db_location.format(fname_pattern.format(i, k))
            mov_seg_filename = db_seg_location.format(fname_pattern.format(i, k))
            
            if not (os.path.exists(ref_filename) and os.path.exists(mov_filename)) :
                continue
            
            ref_seg_init = nib.load(ref_seg_filename).get_data() == 1
            mov_seg_init = nib.load(mov_seg_filename).get_data() == 1
                        
            delta_csf = np.float32(np.sum(np.not_equal(ref_seg_init, mov_seg_init))) / np.sum(ref_seg_init)
            
            diff = diff + delta_csf
            
            if diff > 0.80:
                continue
            
            ref_volume = nib.load(ref_filename).get_data()
            ref_volume = ref_volume.reshape((1, ) + ref_volume.shape)
            
            mask_init = ref_volume != 0
            mask_patches = extract_patches(mask_init, (1, ) + wparams['patch_shape'], (1, ) + step)
            useful_patches = np.sum(mask_patches, axis=(1, 2, 3, 4)) > threshold
            del mask_patches
            
            mov_prob_init_1 = nib.load(db_prob_location.format(fname_pattern.format(i, k), 0)).get_data()
            mov_prob_init_2 = nib.load(db_prob_location.format(fname_pattern.format(i, k), 1)).get_data()
            mov_prob_init_3 = nib.load(db_prob_location.format(fname_pattern.format(i, k), 2)).get_data()
            
            mov_prob_init_1 = mov_prob_init_1.reshape((1, ) + mov_prob_init_1.shape)
            mov_prob_init_2 = mov_prob_init_2.reshape((1, ) + mov_prob_init_2.shape)
            mov_prob_init_3 = mov_prob_init_3.reshape((1, ) + mov_prob_init_3.shape)
            
            vol_patches = extract_patches(ref_volume, (1, ) + wparams['patch_shape'], (1, ) + step)
            vol_patches = vol_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            
            prob_1_patches = extract_patches(mov_prob_init_1, (1, ) + wparams['patch_shape'], (1, ) + step)
            prob_1_patches = prob_1_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            
            prob_2_patches = extract_patches(mov_prob_init_2, (1, ) + wparams['patch_shape'], (1, ) + step)
            prob_2_patches = prob_2_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            
            prob_3_patches = extract_patches(mov_prob_init_3, (1, ) + wparams['patch_shape'], (1, ) + step)
            prob_3_patches = prob_3_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            
            in_train_1 = np.vstack((vol_patches, in_train_1)).astype('float32')
            in_train_2 = np.vstack((prob_1_patches, in_train_2)).astype('float32')
            in_train_3 = np.vstack((prob_2_patches, in_train_3)).astype('float32')
            in_train_4 = np.vstack((prob_3_patches, in_train_4)).astype('float32')
            
            del vol_patches, prob_1_patches, prob_2_patches, prob_3_patches

            mov_volume = nib.load(mov_filename).get_data()
            mov_volume = mov_volume.reshape((1, ) + mov_volume.shape)
            mov_patches = extract_patches(mov_volume, (1, ) + wparams['patch_shape'], (1, ) + step)
            mov_patches = mov_patches[useful_patches].reshape((-1, 1, ) + wparams['patch_shape'])
            
            out_train = np.vstack((mov_patches, out_train)).astype('float32')
            del mov_patches
            
            print '{}->{} ({:.2f})'.format(j, k, diff),
    print

In [None]:
from keras.callbacks import EarlyStopping, ModelCheckpoint

patience = 10

# Keep only the best weights seen so far (Keras' default monitor, val_loss) ...
checkpointer = ModelCheckpoint('models/ag_mseloss_a1a2_2probs.h5', save_best_only=True, save_weights_only=True)
# ... and stop once `patience` epochs pass without improvement.
stopper = EarlyStopping(patience=patience)

N = len(in_train_1)
# Targets: three copies of the moving volume restricted to voxels where the matching
# probability input is >= 0.8, then the unrestricted moving volume for the last two outputs.
model.fit(
    [in_train_1, in_train_2, in_train_3, in_train_4],
    [np.multiply(out_train, prob_input >= 0.8) for prob_input in (in_train_2, in_train_3, in_train_4)]
    + [out_train, out_train],
    validation_split=0.3, epochs=100, batch_size=32,
    callbacks=[checkpointer, stopper])

In [None]:
# NOTE(review): the training cell checkpoints to 'models/ag_mseloss_a1a2_2probs.h5', but this
# loads 'models/ag_mseloss_o2o1_2probs.h5', overwriting the weights just fitted. Confirm that
# evaluating a different experiment's weights is intended here.
model.load_weights('models/ag_mseloss_o2o1_2probs.h5')

In [None]:
from keras.models import Model
import os
import nibabel as nib
from scipy.signal import medfilt
import SimpleITK as sitk
from dipy.denoise.non_local_means import nlmeans_block, non_local_means
from dipy.denoise.noise_estimate import estimate_sigma
from scipy.ndimage.morphology import binary_erosion

path_to_results = '/mnt/harddisk/Experiments/AG/ce_results/outputs'

ref_volume_train = '/mnt/harddisk/datasets/OASIS/REG_MNI/OAS2_0103_MR1Warped.nii.gz'
fname_pattern = 'ADNI_{0:03}_MR{1}'
db_ref_location = '/mnt/harddisk/Experiments/AG/ce_results/inputs/{}Warped.nii.gz'
db_location = '/mnt/harddisk/Experiments/AG/ce_results/inputs/{}_to_{}_{}.nii.gz'
db_seg_location = '/mnt/harddisk/Experiments/AG/ce_results/inputs/{}_to_{}_{}_seg.nii.gz'
db_prob_location = '/mnt/harddisk/Experiments/AG/ce_results/inputs/{}_to_{}_{}_pve_{}.nii.gz'

curr_patch_shape = (32, 32, 32)
step = (8, 8, 8)
for i in range(95, 200) :
    print '{} :'.format(i),
    for j in range(1, 5) :
        ref_filename = db_ref_location.format(fname_pattern.format(i, j))
        
        if not os.path.exists(ref_filename):
            continue
            
        ref_volume = nib.load(ref_filename).get_data()
        ref_volume = ref_volume.reshape((1, ) + ref_volume.shape)
        vol_patches = extract_patches(ref_volume, (1, ) + wparams['patch_shape'], (1, ) + step)
        
        for k in range(j+1, 5) :
            for factor in [25, 50, 75, 100] :
                mov_filename = db_location.format(fname_pattern.format(i, j), fname_pattern.format(i, k), str(factor))
                mov_prob_filename = db_seg_location.format(fname_pattern.format(i, j), fname_pattern.format(i, k), str(factor), 0)

                if not os.path.exists(mov_filename):
                    continue

                mov_prob_init_1 = nib.load(db_prob_location.format(fname_pattern.format(i, j), fname_pattern.format(i, k), str(factor), 0)).get_data()
                mov_prob_init_2 = nib.load(db_prob_location.format(fname_pattern.format(i, j), fname_pattern.format(i, k), str(factor), 1)).get_data()
                mov_prob_init_3 = nib.load(db_prob_location.format(fname_pattern.format(i, j), fname_pattern.format(i, k), str(factor), 2)).get_data()
                
                mov_prob_init_1 = mov_prob_init_1.reshape((1, ) + mov_prob_init_1.shape)
                mov_prob_init_2 = mov_prob_init_2.reshape((1, ) + mov_prob_init_2.shape)
                mov_prob_init_3 = mov_prob_init_3.reshape((1, ) + mov_prob_init_3.shape)

                prob_1_patches = extract_patches(mov_prob_init_1, (1, ) + wparams['patch_shape'], (1, ) + step)
                prob_2_patches = extract_patches(mov_prob_init_2, (1, ) + wparams['patch_shape'], (1, ) + step)
                prob_3_patches = extract_patches(mov_prob_init_3, (1, ) + wparams['patch_shape'], (1, ) + step)

                pred = model.predict(
                    [vol_patches, prob_1_patches, prob_2_patches, prob_3_patches], verbose=1, batch_size=64)[4]
                pred = pred.reshape((-1, ) + curr_patch_shape)

                volume = perform_voting(pred, curr_patch_shape, ref_volume.shape[1:], step)

                icv_mask = nib.load(mov_filename).get_data() != 0
                volume = np.multiply(icv_mask, volume)
                csf_mask = mov_prob_init_1.reshape(volume.shape) == 1
                csf_mask = binary_erosion(csf_mask)
                volume = np.multiply(medfilt(volume, (3, 3, 3)), csf_mask) + np.multiply(volume, ~csf_mask)

                sigma_est = estimate_sigma(volume, disable_background_masking=False)

                volume = nlmeans_block(volume, mask=(volume != 0).astype('double'), patch_radius=1,
                                       block_radius=6, h=sigma_est[0], rician=True)

                volume_data = nib.load(mov_filename)

                nib.save(nib.Nifti1Image(volume, volume_data.affine),
                         '{}/{}_{}_to_{}_{}.nii.gz'.format(path_to_results, i, j, k, str(factor)))

                res = sitk.ReadImage('{}/{}_{}_to_{}_{}.nii.gz'.format(path_to_results, i, j, k, str(factor)))

                caster = sitk.CastImageFilter()
                caster.SetOutputPixelType(res.GetPixelID())

                orig = caster.Execute(sitk.ReadImage(mov_filename))
                seg = caster.Execute(sitk.ReadImage(mov_prob_filename))

                enhanced_vol = match_intensities(orig, res)

                sitk.WriteImage(
                    enhanced_vol,
                    '{}/{}_{}_to_{}_{}_cor.nii.gz'.format(path_to_results, i, j, k, str(factor)))
                print '{} - {}->{}'.format(i, j, k)