In [None]:
import os
# Run from the project root: the cells below use paths relative to it
# ("data/...", "notebooks/...") and presumably rely on it for the `src` imports.
# Set the ENTRACK_ROOT environment variable to point at another checkout;
# without it the original author's location is used.
os.chdir(os.environ.get('ENTRACK_ROOT', '/local/home/mhoerold/entrack'))

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import yaml
import nibabel as nib
import copy
import tensorflow as tf

from src.baum_vagan.vagan.model_wrapper import VAGanWrapper
from src.baum_vagan.utils import ncc
from src.data.streaming.vagan_streaming import MRIImagePair, AgeFixedDeltaStream
from src.baum_vagan.utils import map_image_to_intensity_range
from src.data.streaming.mri_streaming import MRISingleStream

In [None]:
def load_wrapper(smt_label):
    """Build a VAGanWrapper from a stored run and load its weights.

    Args:
        smt_label: Name of the run folder below ``data/``. The folder must
            contain ``config.yaml`` and a ``logdir`` directory.

    Returns:
        The ``VAGanWrapper`` created from the parsed config, after
        ``wrapper.vagan.load_weights`` was called on the run's ``logdir``.
    """
    run_dir = os.path.join("data", smt_label)
    config_path = os.path.join(run_dir, "config.yaml")
    model_dir = os.path.join(run_dir, "logdir")

    # PyYAML >= 6 requires an explicit Loader, so a bare ``yaml.load(f)``
    # raises TypeError there. FullLoader is what PyYAML 5.1-5.4 used when no
    # Loader was given; releases older than 5.1 do not have it, so they keep
    # the original call.
    # NOTE(review): only load trusted config files this way. If the configs
    # hold plain YAML types only, ``yaml.safe_load`` would be the safer
    # choice -- confirm before switching.
    loader_cls = getattr(yaml, "FullLoader", None)
    with open(config_path, 'r') as f:
        if loader_cls is None:
            model_config = yaml.load(f)
        else:
            model_config = yaml.load(f, Loader=loader_cls)

    wrapper = VAGanWrapper(**model_config)
    wrapper.vagan.load_weights(model_dir)

    return wrapper

In [None]:
# Run folders (below data/) of the two trained models compared in this notebook.
HC_RUN_LABEL = '20180823-185855'
AD_RUN_LABEL = '20180823-185845'

hc_wrapper = load_wrapper(HC_RUN_LABEL)
ad_wrapper = load_wrapper(AD_RUN_LABEL)

In [None]:
def test_stream(bs):
    """Return the first element of the next ``testAD`` batch of size ``bs``."""
    return ad_wrapper.data.testAD.next_batch(bs)[0]

In [None]:
# Output folder (relative to the working directory) for the generated
# image stacks; it is created on first use.
dump_dir = 'notebooks/vagan_generated/hc_vs_ad'
if not os.path.exists(dump_dir):
    os.makedirs(dump_dir)

## Find train and validation patients

In [None]:
# Pool the train and validation pairs of both models ...
train_val_pairs = (
    ad_wrapper.data.train_pairs
    + ad_wrapper.data.val_pairs
    + hc_wrapper.data.train_pairs
    + hc_wrapper.data.val_pairs
)

# ... and collect the patient behind each pair. Only `fid1` is looked up.
train_patient_ids = {
    hc_wrapper.data.get_patient_id(seen_pair.fid1)
    for seen_pair in train_val_pairs
}

## Find all test patients

In [None]:
# Work on a deep copy so the AD wrapper's own data config is left unmodified.
single_config = copy.deepcopy(ad_wrapper.data.config)
# Override the diagnosis filter with the two labels this notebook compares.
single_config["use_diagnoses"] = ['healthy', 'health_ad']
# Stream built from the adjusted config; presumably it serves individual
# scans rather than image pairs -- TODO confirm against MRISingleStream.
single_stream = MRISingleStream(single_config)

In [None]:
# Patient ids of every file the stream knows about (not used further in this cell).
all_patient_ids = {
    single_stream.get_patient_id(fid) for fid in single_stream.all_file_ids
}

# Keep only files whose patient appears in no train/validation pair,
# split by diagnosis label.
hc_test_fids = set()
ad_test_fids = set()
for fid in single_stream.all_file_ids:
    if single_stream.get_patient_id(fid) in train_patient_ids:
        continue
    diag = single_stream.get_diagnose(fid)
    if diag == "healthy":
        hc_test_fids.add(fid)
    elif diag == "health_ad":
        ad_test_fids.add(fid)

# Sort before shuffling: the iteration order of a set depends on element
# hashes, which for strings are randomised per interpreter process. Shuffling
# `list(a_set)` with a fixed seed therefore does not give the same order after
# a kernel restart. Starting from a sorted list makes the seeded shuffle, and
# hence indices such as `ad_test_fids[2]`, reproducible.
hc_test_fids = sorted(hc_test_fids)
ad_test_fids = sorted(ad_test_fids)
np.random.seed(11)
np.random.shuffle(hc_test_fids)
np.random.shuffle(ad_test_fids)

## Make predictions

In [None]:
def iterate_model(model, img, n_steps):
    """Chain ``model.predict_mask`` for ``n_steps`` steps, adding each mask to the image.

    The input is stacked along its last axis as (image, all-ones channel,
    image) and wrapped into a batch of size one. On every step the predicted
    mask is added in place to the whole batch, channel 1 is reset to ones
    and channel 0 is recorded.

    Args:
        model: Object providing ``predict_mask(batch)``.
        img: Start image. The indexing below needs at least three dimensions
            (presumably shaped (H, W, 1) -- TODO confirm). It is not modified.
        n_steps: Number of chained prediction steps.

    Returns:
        Tuple ``(images, masks)`` of two lists with one entry per step: a
        squeezed copy of channel 0 after the step, and the squeezed mask the
        model returned for that step.
    """
    # All-ones array shaped like the image; the placeholder needs a second
    # channel that the generator does not use.
    ones_channel = img * 0 + 1.0
    stacked = np.concatenate((img, ones_channel, img), axis=-1)
    batch = np.array([stacked])  # batch of size 1

    step_images, step_masks = [], []
    for _step in range(n_steps):
        mask = model.predict_mask(batch)
        step_masks.append(np.squeeze(mask))

        # In-place add: the mask is applied to every channel of the batch,
        # so the ones channel has to be restored afterwards.
        batch += mask
        batch[:, :, :, 1] = ones_channel[:, :, 0]

        # Copy, because `batch` keeps changing on later steps.
        current = np.squeeze(batch[:, :, :, 0])
        step_images.append(np.copy(current))

    return step_images, step_masks

def compare_hc_ad_images_and_masks(x_t0, t0, hc_images, hc_masks, ad_images, ad_masks):
    """Draw a five-row figure comparing HC and AD generated sequences.

    Rows, top to bottom: HC images, HC difference maps, AD images,
    AD difference maps, and the per-step difference ``ad_image - hc_image``.
    Each image row starts with the input image ``x_t0``.

    Args:
        x_t0: Input image shown in the first column of both image rows.
        t0: Value printed as the age in the title of the input image.
        hc_images: Sequence of generated images for the HC row.
        hc_masks: Sequence of difference maps for the HC row. Its length also
            sets the column span of every map row.
        ad_images: Sequence of generated images for the AD row.
        ad_masks: Sequence of difference maps for the AD row.
    """
    n_rows = 5
    n_cols = len(hc_images) + 1
    panel_size = 4
    plt.figure(figsize=(n_cols * panel_size, n_rows * panel_size))

    def show_gray(position, image, title):
        # One grayscale panel at the given subplot position.
        plt.subplot(n_rows, n_cols, position)
        plt.imshow(np.squeeze(image), cmap='gray')
        plt.title(title)
        plt.axis('off')

    def show_image_row(row, generated, title_template):
        # Input image in the first column, generated images in the following ones.
        first = row * n_cols + 1
        show_gray(first, x_t0, "x_t0, age={}".format(str(t0)))
        for step, image in enumerate(generated, 1):
            show_gray(first + step, image, title_template.format(step))

    def show_map_row(row, maps, title):
        # Concatenate the difference maps so they are plotted with the same scale.
        first = row * n_cols + 1
        plt.subplot(n_rows, n_cols, (first, first + len(hc_masks) - 1))
        plt.imshow(np.hstack(tuple(maps)), cmap='bwr', vmin=-2, vmax=2)
        plt.title(title)
        plt.axis('off')

    # HC images and their difference maps
    show_image_row(0, hc_images, 'HC Generated x_t{}')
    show_map_row(1, hc_masks, "HC Generated difference maps")

    # AD images and their difference maps
    show_image_row(2, ad_images, 'AD Generated x_t{}')
    show_map_row(3, ad_masks, "AD Generated difference maps")

    # HC to AD change at every step
    hc_ad_change = [
        np.squeeze(ad_im - hc_im) for ad_im, hc_im in zip(ad_images, hc_images)
    ]
    show_map_row(4, hc_ad_change, "AD-HC")

    
def plot_iterative_predictions(hc_model, ad_model, fid, delta):
    """Roll both models forward from one scan, save the results and plot them.

    Args:
        hc_model: Model used for the HC sequence (passed to ``iterate_model``).
        ad_model: Model used for the AD sequence (passed to ``iterate_model``).
        fid: File id of the start image.
        delta: Number of chained prediction steps for both models.

    Side effects:
        Writes ``<dump_dir>/<patient id>.npz`` with the arrays ``hc_fake`` and
        ``ad_fake``, and draws the comparison figure.
    """
    start_age = round(single_stream.get_exact_age(fid), 2)

    # Any training pair will do here; it is only used to load the image
    # (presumably so that the trained model's normalization is applied).
    loader_pair = hc_wrapper.data.train_pairs[0]
    x_t0 = loader_pair.load_image(fid)

    hc_images, hc_masks = iterate_model(hc_model, x_t0, delta)
    ad_images, ad_masks = iterate_model(ad_model, x_t0, delta)

    # Dump the generated sequences, one file per patient id.
    patient_id = single_stream.get_patient_id(fid)
    np.savez(
        os.path.join(dump_dir, "{}.npz".format(patient_id)),
        hc_fake=hc_images,
        ad_fake=ad_images,
    )

    compare_hc_ad_images_and_masks(
        x_t0, start_age, hc_images, hc_masks, ad_images, ad_masks
    )
    
    

# Five chained steps, starting from the third shuffled AD test scan.
plot_iterative_predictions(
    hc_model=hc_wrapper.vagan,
    ad_model=ad_wrapper.vagan,
    fid=ad_test_fids[2],
    delta=5,
)

## Very far predictions

In [None]:
def plot_far_prediction(hc_model, ad_model, fid, deltas=[10, 20, 30]):
    """Plot HC and AD predictions at several cumulative step counts.

    For each model the image is rolled forward with ``iterate_model`` until
    every entry of ``deltas`` is reached. Only the image at each entry is
    kept, together with its difference to the previously kept image (or to
    the input image for the first entry).

    Args:
        hc_model: Model used for the HC row.
        ad_model: Model used for the AD row.
        fid: File id of the start image.
        deltas: Cumulative numbers of steps at which to record a prediction.
    """
    start_age = round(single_stream.get_exact_age(fid), 2)

    # Any training pair will do here; it is only used to load the image
    # (presumably so that the trained model's normalization is applied).
    loader_pair = hc_wrapper.data.train_pairs[0]
    x_t0 = loader_pair.load_image(fid)

    def rollout(start, model):
        # Advance from one entry of `deltas` to the next, feeding the last
        # generated image back in as the new input.
        current = start
        reached = 0
        kept_images = []
        for target in deltas:
            images, _ = iterate_model(model, current, target - reached)
            last_image = images[-1][:, :]
            # Restore the trailing channel axis removed by the squeeze.
            current = np.reshape(last_image, last_image.shape + (1,))
            kept_images.append(np.copy(current))
            reached = target

        # Change relative to the previously kept image; the first map is
        # taken relative to the start image.
        change_maps = [np.squeeze(kept_images[0] - start)]
        change_maps.extend(
            np.squeeze(later - earlier)
            for earlier, later in zip(kept_images, kept_images[1:])
        )
        return kept_images, change_maps

    hc_images, hc_masks = rollout(x_t0, hc_model)
    ad_images, ad_masks = rollout(x_t0, ad_model)
    compare_hc_ad_images_and_masks(
        x_t0, start_age, hc_images, hc_masks, ad_images, ad_masks
    )
    
# Same start scan as above, recorded after 10, 20 and 30 chained steps.
plot_far_prediction(
    hc_model=hc_wrapper.vagan,
    ad_model=ad_wrapper.vagan,
    fid=ad_test_fids[2],
    deltas=[10, 20, 30],
)