# Longitudinal atrophy generation model training

## Load libraries and utilities

In [None]:
import os
import numpy as np
import nibabel as nib

from utils.roi_measures import mad, ssim
from utils.extraction import extract_patches
from utils.reconstruction import perform_voting
from utils.match_intensities import match_intensities

from wnet import generate_wnet_model

## Define cascaded U-Net parameters

In [None]:
# Hyper-parameters passed to generate_wnet_model (cascaded U-Net).
# 'patch_shape' is also reused below to extract training/testing patches, and
# 'loss_weights' has one entry per model output (five targets are fed to fit()).
wparams = {
    'input_channels': 2,
    'output_channels': 1,
    'latent_channels': 16,
    'scale': [0.5, 0.5],
    'use_combined_loss': True,
    'patch_shape': (32, 32, 32),
    'loss_weights': [1, 1, 1, 1.5, 3],
}

model = generate_wnet_model(wparams)
model.summary()

## Loading training data
In this particular example, we use data from OASIS-2 that has previously been registered to MNI space and skull-stripped. The file-name patterns used in this example are:
- `db_location`: file-name pattern of the input volumes
- `db_seg_location`: file-name pattern of the segmentation masks
- `db_prob_location`: file-name pattern of the segmentation probability maps (the second `{}` is the class index, 0–2)

In [None]:
fname_pattern = 'OAS2_{0:04}_MR{1}'
db_location = 'datasets/OASIS/REG_MNI/{}Warped.nii.gz'
db_seg_location = 'datasets/OASIS/REG_MNI/{}Warped_seg.nii.gz'
db_prob_location = 'datasets/OASIS/REG_MNI/{}Warped_pve_{}.nii.gz'

# Step to extract patches
step = (16, 16, 16)


def _load_volume(filename):
    """Load a NIfTI volume as an array with a leading singleton axis prepended.

    ``np.asanyarray(img.dataobj)`` replaces the deprecated ``img.get_data()``
    (removed in nibabel 5.0) and, like it, keeps the stored dtype.
    """
    data = np.asanyarray(nib.load(filename).dataobj)
    return data.reshape((1, ) + data.shape)


def _useful_patches(volume, useful):
    """Extract patches from ``volume`` and keep those selected by the boolean ``useful``."""
    patches = extract_patches(volume, (1, ) + wparams['patch_shape'], (1, ) + step)
    return patches[useful].reshape((-1, 1, ) + wparams['patch_shape']).astype('float32')


def _stack(parts):
    """Concatenate per-pair patch arrays with the most recently loaded pair first.

    The reversed order reproduces the original prepend-style accumulation; it
    matters because Keras ``validation_split`` takes the *last* samples.
    """
    if not parts:
        return np.empty((0, 1, ) + wparams['patch_shape'], dtype='float32')
    return np.concatenate(parts[::-1])


# Minimum number of non-zero (brain) voxels for a patch to be kept: 30% of it.
threshold = np.int32(0.30 * np.prod(wparams['patch_shape']))

# Patches are gathered per (subject, ref session, moving session) pair and
# concatenated once after the loop; stacking onto the growing arrays on every
# pair re-copied everything collected so far (quadratic in dataset size).
vol_parts = []
prob_parts = ([], [], [])  # one list per probability map (class index 0, 1, 2)
mov_parts = []
for i in range(1, 100):
    for j in range(1, 5):
        # Only pairs where the moving session k comes after the reference session j.
        for k in range(j + 1, 5):
            ref_filename = db_location.format(fname_pattern.format(i, j))
            mov_filename = db_location.format(fname_pattern.format(i, k))

            if not (os.path.exists(ref_filename) and os.path.exists(mov_filename)):
                continue

            ref_volume = _load_volume(ref_filename)

            # Select patch positions with enough non-zero reference voxels; the
            # same selection is applied to every other volume of this pair.
            mask_patches = extract_patches(
                ref_volume != 0, (1, ) + wparams['patch_shape'], (1, ) + step)
            useful_patches = np.sum(mask_patches, axis=(1, 2, 3, 4)) > threshold
            del mask_patches

            vol_parts.append(_useful_patches(ref_volume, useful_patches))

            # Probability maps of the *moving* session, classes 0..2.
            for class_idx, parts in enumerate(prob_parts):
                prob_filename = db_prob_location.format(fname_pattern.format(i, k), class_idx)
                parts.append(_useful_patches(_load_volume(prob_filename), useful_patches))

            mov_parts.append(_useful_patches(_load_volume(mov_filename), useful_patches))

in_train_1 = _stack(vol_parts)
in_train_2 = _stack(prob_parts[0])
in_train_3 = _stack(prob_parts[1])
in_train_4 = _stack(prob_parts[2])
out_train = _stack(mov_parts)
del vol_parts, prob_parts, mov_parts

## Training step

In [None]:
from keras.callbacks import EarlyStopping, ModelCheckpoint

patience = 10

# Only the weights with the best validation loss are written to this file.
checkpoint_path = 'models/model.h5'
# Make sure the target directory exists before training starts, so writing the
# checkpoint cannot fail after the first (possibly long) epoch.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

stopper = EarlyStopping(patience=patience)
checkpointer = ModelCheckpoint(checkpoint_path, save_best_only=True, save_weights_only=True)

# Targets for the five model outputs: the first three are the later-session
# patches restricted to voxels where the corresponding probability map is
# >= 0.8; the last two are the full later-session patches.
targets = [np.multiply(out_train, prob >= 0.8) for prob in (in_train_2, in_train_3, in_train_4)]
targets += [out_train, out_train]

model.fit(
    [in_train_1, in_train_2, in_train_3, in_train_4],
    targets,
    validation_split=0.3, epochs=100, batch_size=32,
    callbacks=[checkpointer, stopper])

## Load the model weights with the lowest validation loss

In [None]:
# Restore the checkpoint written during training by
# ModelCheckpoint('models/model.h5', save_best_only=True), i.e. the weights
# with the lowest validation loss (previously a different, unrelated file was loaded).
model.load_weights('models/model.h5')

## Model testing

In [None]:
# Directory where generated volumes are written.
path_to_results = 'outputs/'

# OASIS-2 training volume; named with the same 4-digit 'OAS2_{0:04}_MR{1}'
# pattern used to locate the training files (the former 'OAS2_003' did not match it).
ref_volume_train = 'datasets/OASIS/REG_MNI/OAS2_0003_MR1Warped.nii.gz'
fname_pattern = 'ADNI_{0:03}_MR{1}'
db_location = 'datasets/ADNI/REG_MNI/{}Warped.nii.gz'
# The testing loop reads reference volumes through `db_ref_location`, which was
# never defined (NameError). References follow the same '{}Warped.nii.gz' pattern.
db_ref_location = db_location
db_seg_location = 'datasets/ADNI/REG_MNI/{}Warped_seg.nii.gz'
db_prob_location = 'datasets/ADNI/REG_MNI/{}Warped_pve_{}.nii.gz'

# Must equal wparams['patch_shape'], which the loop uses to extract the patches
# that are then reshaped/voted with this shape.
curr_patch_shape = (32, 32, 32)
# Denser sampling than in training (8 vs 16) so each voxel receives more votes.
step = (8, 8, 8)

# NOTE(review): the loop below uses `sitk` (SimpleITK), which the imports cell
# does not import. It also builds `mov_filename` as
# db_location.format(<ref name>, <moving name>); db_location has a single
# placeholder, so that resolves to the *reference* file — confirm intent.
for i in range(1, 200) :
    for j in range(1, 5) :
        ref_filename = db_ref_location.format(fname_pattern.format(i, j))
        
        if not os.path.exists(ref_filename):
            continue
            
        ref_volume = nib.load(ref_filename).get_data()
        ref_volume = ref_volume.reshape((1, ) + ref_volume.shape)
        vol_patches = extract_patches(ref_volume, (1, ) + wparams['patch_shape'], (1, ) + step)
        
        for k in range(j+1, 5) :
            mov_filename = db_location.format(fname_pattern.format(i, j), fname_pattern.format(i, k))
            mov_prob_filename = db_seg_location.format(fname_pattern.format(i, j), fname_pattern.format(i, k), 0)

            if not os.path.exists(mov_filename):
                continue

            mov_prob_init_1 = nib.load(db_prob_location.format(fname_pattern.format(i, k), 0)).get_data()
            mov_prob_init_2 = nib.load(db_prob_location.format(fname_pattern.format(i, k), 1)).get_data()
            mov_prob_init_3 = nib.load(db_prob_location.format(fname_pattern.format(i, k), 2)).get_data()
            
            mov_prob_init_1 = mov_prob_init_1.reshape((1, ) + mov_prob_init_1.shape)
            mov_prob_init_2 = mov_prob_init_2.reshape((1, ) + mov_prob_init_2.shape)
            mov_prob_init_3 = mov_prob_init_3.reshape((1, ) + mov_prob_init_3.shape)

            prob_1_patches = extract_patches(mov_prob_init_1, (1, ) + wparams['patch_shape'], (1, ) + step)
            prob_2_patches = extract_patches(mov_prob_init_2, (1, ) + wparams['patch_shape'], (1, ) + step)
            prob_3_patches = extract_patches(mov_prob_init_3, (1, ) + wparams['patch_shape'], (1, ) + step)

            pred = model.predict(
                [vol_patches, prob_1_patches, prob_2_patches, prob_3_patches], verbose=1, batch_size=64)[4]
            pred = pred.reshape((-1, ) + curr_patch_shape)

            volume = perform_voting(pred, curr_patch_shape, ref_volume.shape[1:], step)

            volume_data = nib.load(mov_filename)

            nib.save(nib.Nifti1Image(volume, volume_data.affine),
                     '{}/{}_{}_to_{}.nii.gz'.format(path_to_results, i, j, k)

            res = sitk.ReadImage('{}/{}_{}_to_{}.nii.gz'.format(path_to_results, i, j, k))

            caster = sitk.CastImageFilter()
            caster.SetOutputPixelType(res.GetPixelID())

            orig = caster.Execute(sitk.ReadImage(mov_filename))
            seg = caster.Execute(sitk.ReadImage(mov_prob_filename))

            enhanced_vol = match_intensities(orig, res)

            sitk.WriteImage(
                enhanced_vol,
                '{}/{}_{}_to_{}_cor.nii.gz'.format(path_to_results, i, j, k)