In [None]:
import os

from pyment.data import NiftiDataset, AsyncNiftiGenerator
from pyment.data.preprocessors import NiftiPreprocessor

# Input volumes and project directory, both located under the user's home directory.
ixi_folder = os.path.expanduser(os.path.join('~', 'data', 'IXI', 'cropped'))
project_folder = os.path.expanduser(os.path.join('~', 'projects', ''))

# IXI images with chronological age as the regression target.
dataset = NiftiDataset.from_folder(ixi_folder, target='age')
# NOTE(review): sigma=255. presumably rescales raw intensities (e.g. by 1/255) — confirm in pyment.
preprocessor = NiftiPreprocessor(sigma=255.)
generator = AsyncNiftiGenerator(dataset=dataset, preprocessor=preprocessor,
                                batch_size=4, threads=8)

In [None]:
from pyment.models import RegressionSFCN

# Pretrained SFCN regressor ('brain-age' weights), applied to every batch the generator yields.
model = RegressionSFCN(weights='brain-age')
predictions = model.predict(generator)

In [None]:
from pyment.models import Model

# Sub-network that shares `model`'s input and ends at layer 25.
# NOTE(review): layer 25 is presumably the feature/bottleneck layer — confirm the index.
encoder = Model(model.input, model.layers[25].output)
encodings = encoder.predict(generator)

In [None]:
import numpy as np


# Compare predicted and chronological age for subjects whose age is known.
# `predictions` is left untouched (the old version overwrote it with a filtered
# copy, so re-running this cell indexed the filtered array with the full mask).
has_age = ~np.isnan(dataset.y)
ages = dataset.y[has_age]
# NOTE(review): predictions presumably has shape (n_images, 1); squeeze() makes it 1-D.
predicted_ages = predictions.squeeze()[has_age]
delta = predicted_ages - ages
# Reported value is the mean absolute delta (MAE).
print(f'Brain age delta: {round(np.mean(np.abs(delta)), 2)}')

In [None]:
from plotly.figure_factory import create_distplot

import matplotlib.pyplot as plt

from explainability import LRP, LRPStrategy

# Composite LRP rule set: one dict per entry of LRPStrategy(layers=...).
# NOTE(review): presumably ordered input -> output (flat rules first, epsilon last) — confirm.
alpha = 2
beta = 1

strategy = LRPStrategy(
    layers=[
        {'flat': True},
        {'flat': True},
        {'alpha': alpha, 'beta': beta},
        {'alpha': alpha, 'beta': beta},
        {'alpha': alpha, 'beta': beta},
        {'alpha': alpha, 'beta': beta},
        {'epsilon': 0.25}
    ])

# Explain output unit 0 of the last layer (presumably the predicted age).
lrp = LRP(model, layer=len(model.layers)-1, idx=0, strategy=strategy)

# Earlier cells already ran the generator through `predict`; rewind so this
# preview always starts from the first batch (same pattern as the collection cell).
generator.reset()

for X, y in generator:
    # Relevance for the first subject of the batch, zeroed where the input is 0 (background).
    explanations = lrp(X[:1])[0].numpy()
    mask = np.zeros(X[0].shape)
    mask[np.where(X[0] != 0)] = 1
    explanations = explanations * mask
    explanations = explanations / np.amax(np.abs(explanations))

    # Centre the 8-slice windows (offsets -4..+3) on the strongest-relevance voxel,
    # clamped so they stay inside the volume: a negative index would silently wrap
    # to the opposite edge, and an index past the end would raise IndexError.
    peak = np.unravel_index(np.argmax(np.abs(explanations)), explanations.shape)
    idx = tuple(int(np.clip(p, 4, dim - 4)) for p, dim in zip(peak, explanations.shape))

    fig, ax = plt.subplots(6, 8, figsize=(15, 15))

    for i in range(-4, 4):
        ax[0][i+4].imshow(np.rot90(X[0,idx[0]+i]), cmap='Greys_r')
        ax[0][i+4].axis('off')
        ax[1][i+4].imshow(np.rot90(explanations[idx[0]+i]), cmap='seismic', clim=(-1, 1))
        ax[1][i+4].axis('off')
        ax[2][i+4].imshow(np.rot90(X[0,:,idx[1]+i]), cmap='Greys_r')
        ax[2][i+4].axis('off')
        ax[3][i+4].imshow(np.rot90(explanations[:,idx[1]+i]), cmap='seismic', clim=(-1, 1))
        ax[3][i+4].axis('off')
        ax[4][i+4].imshow(X[0,:,:,idx[2]+i], cmap='Greys_r')
        ax[4][i+4].axis('off')
        ax[5][i+4].imshow(explanations[:,:,idx[2]+i], cmap='seismic', clim=(-1, 1))
        ax[5][i+4].axis('off')

    # Show before leaving the loop (previously plt.show() sat after `break` and never ran).
    plt.show()
    # Only the first batch is previewed.
    break

In [None]:
from tqdm import tqdm


images = []
labels = []
predictions = []
# Allocated when the first image arrives, so its shape comes from this loop rather
# than from an `X` left over by an earlier cell.
# NOTE(review): assumes len(generator) is the number of images (not batches) — confirm.
all_explanations = None

generator.reset()

for X, y in tqdm(generator, total=generator.batches):
    # One forward pass per batch instead of one per image.
    batch_predictions = model.predict(X)
    for j, image in enumerate(X):
        # Row for this image = number of images stored so far. Unlike the old
        # `i*4 + j`, this does not hard-code batch_size=4.
        position = len(images)
        labels.append(y[j])
        predictions.append(batch_predictions[j])
        expl = lrp.predict(np.expand_dims(image, 0))[0]
        # Zero relevance outside the brain (voxels where the input is exactly 0).
        mask = np.zeros(image.shape)
        mask[np.where(image != 0)] = 1
        expl = expl * mask
        if all_explanations is None:
            all_explanations = np.zeros((len(generator),) + image.shape)
        all_explanations[position] = expl
        images.append(image)

In [None]:
# Compare the first subject's relevance on slice 80 (first axis) with the
# cohort-average relevance on the same slice; each is scaled to [-1, 1] by its own max |value|.
im1 = all_explanations[0][80]
im1 = im1 / np.amax(np.abs(im1))
im1 = np.rot90(im1)
plt.imshow(im1, cmap='seismic', clim=(-1, 1))
plt.show()

# Average of slice 80 across subjects: equal to np.mean(all_explanations, axis=0)[80],
# computed here so the cell no longer depends on the later cell that defines mean_explanation.
im2 = np.mean(all_explanations[:, 80], axis=0)
im2 = im2 / np.amax(np.abs(im2))
im2 = np.rot90(im2)
plt.imshow(im2, cmap='seismic', clim=(-1, 1))
plt.show()

# The difference lies in [-2, 2]; clim=(-1, 1) saturates anything beyond +-1.
im3 = im1 - im2
plt.imshow(im3, cmap='seismic', clim=(-1, 1))
plt.show()

brain = images[0][80]
brain = np.rot90(brain)
plt.imshow(brain, cmap='Greys_r', clim=(0, 1))
plt.show()

In [None]:
# Cohort-average relevance map, rescaled so its largest magnitude is 1.
mean_explanation = np.mean(all_explanations, axis=0)
mean_explanation /= np.amax(np.abs(mean_explanation))
# The 8-slice windows are centred on the middle voxel of the volume.
idx = np.asarray(mean_explanation.shape) // 2

fig, ax = plt.subplots(6, 8, figsize=(15, 15))

# Rows 0/2/4: the last batch's first image along each axis; rows 1/3/5: the matching mean relevance.
for i in range(-4, 4):
    column = ax[:, i + 4]
    panels = [
        (np.rot90(X[0, idx[0] + i]), {'cmap': 'Greys_r'}),
        (np.rot90(mean_explanation[idx[0] + i]), {'cmap': 'seismic', 'clim': (-1, 1)}),
        (np.rot90(X[0, :, idx[1] + i]), {'cmap': 'Greys_r'}),
        (np.rot90(mean_explanation[:, idx[1] + i]), {'cmap': 'seismic', 'clim': (-1, 1)}),
        (X[0, :, :, idx[2] + i], {'cmap': 'Greys_r'}),
        (mean_explanation[:, :, idx[2] + i], {'cmap': 'seismic', 'clim': (-1, 1)}),
    ]
    for panel_ax, (panel, options) in zip(column, panels):
        panel_ax.imshow(panel, **options)
        panel_ax.axis('off')

plt.show()

In [None]:
# One figure per subject: 8 consecutive slices along each axis around the voxel
# with the largest absolute relevance. Rows 0/2/4 show the image, rows 1/3/5 the relevance.
for i in range(len(all_explanations)):
    e = all_explanations[i]
    e = e / np.amax(np.abs(e))
    peak = np.unravel_index(np.argmax(np.abs(e)), e.shape)
    # Clamp the centre so offsets -4..+3 stay in range: a negative index would
    # silently wrap to the opposite edge, and one past the end would raise IndexError.
    idx = tuple(int(np.clip(p, 4, dim - 4)) for p, dim in zip(peak, e.shape))

    fig, ax = plt.subplots(6, 8, figsize=(15, 15))

    for j in range(-4, 4):
        ax[0][j+4].imshow(np.rot90(images[i][idx[0]+j]), cmap='Greys_r')
        ax[0][j+4].axis('off')
        ax[1][j+4].imshow(np.rot90(e[idx[0]+j]), cmap='seismic', clim=(-1, 1))
        ax[1][j+4].axis('off')
        ax[2][j+4].imshow(np.rot90(images[i][:,idx[1]+j]), cmap='Greys_r')
        ax[2][j+4].axis('off')
        ax[3][j+4].imshow(np.rot90(e[:,idx[1]+j]), cmap='seismic', clim=(-1, 1))
        ax[3][j+4].axis('off')
        ax[4][j+4].imshow(images[i][:,:,idx[2]+j], cmap='Greys_r')
        ax[4][j+4].axis('off')
        ax[5][j+4].imshow(e[:,:,idx[2]+j], cmap='seismic', clim=(-1, 1))
        ax[5][j+4].axis('off')

    plt.show()
    # Release the figure explicitly; this loop creates one per subject.
    plt.close(fig)

In [None]:
import matplotlib.pyplot as plt

from matplotlib import cm
from PIL import Image, ImageDraw, ImageFont
from typing import Tuple


def pad_to_size(image, size: int = 212, value: float = 0):
    """Centre a 2-D array on a ``size`` x ``size`` canvas filled with ``value``.

    When the padding needed along an axis is odd, the extra row/column goes to the
    top/left.

    Args:
        image: 2-D array no larger than ``size`` along either axis.
        size: Side length of the square output.
        value: Scalar used for the padded border.

    Returns:
        A ``(size, size)`` array with ``image`` centred in it.

    Raises:
        ValueError: If ``image`` is not 2-D or exceeds ``size`` along an axis.
    """
    if image.ndim != 2 or max(image.shape) > size:
        raise ValueError(
            f'Expected a 2-D image of at most {size}x{size}, got shape {image.shape}'
        )

    padding = []
    for extent in image.shape:
        missing = size - extent
        before = int(np.ceil(missing / 2))
        padding.append((before, missing - before))

    return np.pad(image, padding, constant_values=value)

def concat_horizontal(i1, i2, color=(0, 0, 0)):
    """Place ``i2`` to the right of ``i1`` on a new RGB canvas.

    The canvas is as tall as the taller input. Any area not covered by the inputs
    is filled with ``color``, which the old version accepted but ignored. Its default
    of black matches the previous implicit background.
    """
    dst = Image.new('RGB', (i1.width + i2.width, max(i1.height, i2.height)), color)
    dst.paste(i1, (0, 0))
    dst.paste(i2, (i1.width, 0))
    return dst

def concat_vertical(i1, i2, color=(0, 0, 0)):
    """Place ``i2`` below ``i1`` on a new RGB canvas.

    The canvas is as wide as the wider input. Any uncovered area is filled with
    ``color`` (default black, the previous implicit background), mirroring
    ``concat_horizontal``.
    """
    dst = Image.new('RGB', (max(i1.width, i2.width), i1.height + i2.height), color)
    dst.paste(i1, (0, 0))
    dst.paste(i2, (0, i1.height))
    return dst

# Order subjects by predicted brain age so the GIF sweeps from youngest- to oldest-looking brain.
idx = np.argsort([pred[0] for pred in predictions])
sorted_labels = [labels[i] for i in idx]
sorted_predictions = [predictions[i] for i in idx]
sorted_images = [images[i] for i in idx]
sorted_explanations = [all_explanations[i] for i in idx]

# Fixed slice per view (sagittal, coronal, axial) shown for every subject.
# NOTE(review): positions were chosen for the cropped IXI volumes — confirm they hit mid-brain.
GIF_VIEWS = (np.s_[84], np.s_[:, 106], np.s_[:, :, 80])


def _slice_bitmaps(volume, relevance, view):
    """Render the ``view`` slice of an image volume and of its relevance map as PIL images.

    Both slices are rotated by 90 degrees and centred on a square canvas with
    ``pad_to_size``. ``relevance`` must already be mapped to [0, 1] with 0.5 meaning
    zero relevance; 0.5 is also the padding value, which is the neutral (white) point
    of the ``seismic`` colormap.
    """
    brain_slice = pad_to_size(np.rot90(volume[view]))
    heat_slice = pad_to_size(np.rot90(relevance[view]), value=0.5)
    return (Image.fromarray(np.uint8(cm.Greys_r(brain_slice) * 255)),
            Image.fromarray(np.uint8(cm.seismic(heat_slice) * 255)))


# Load the font once rather than once per frame; arial.ttf is often missing on Linux.
try:
    font = ImageFont.truetype('arial.ttf', 20)
except OSError:
    font = ImageFont.load_default()

sorted_bitmaps = []

for i in tqdm(range(len(images))):
    expl = sorted_explanations[i]
    # Map relevance from [-1, 1] (after max-|value| scaling) onto the colormap's [0, 1],
    # keeping 0 at 0.5. Adding 0.5 alone gave [-0.5, 1.5], so everything above half
    # the maximum was clipped to full red or full blue.
    expl = expl / (2 * np.amax(np.abs(expl))) + 0.5

    brain_tiles, relevance_tiles = zip(
        *(_slice_bitmaps(sorted_images[i], expl, view) for view in GIF_VIEWS)
    )
    brain_bitmap = concat_horizontal(concat_horizontal(brain_tiles[0], brain_tiles[1]),
                                     brain_tiles[2])
    explanations_bitmap = concat_horizontal(concat_horizontal(relevance_tiles[0], relevance_tiles[1]),
                                            relevance_tiles[2])
    bitmap = concat_vertical(brain_bitmap, explanations_bitmap)

    draw = ImageDraw.Draw(bitmap)
    draw.text((180, 180), f'Age={sorted_labels[i]:.2f}, brain age {sorted_predictions[i][0]:.2f}',
              (255, 255, 255), font=font)

    sorted_bitmaps.append(bitmap)

# Write to the current user's home directory instead of a hard-coded absolute path.
sorted_bitmaps[0].save(os.path.join(os.path.expanduser('~'), 'demo.gif'),
                       save_all=True, append_images=sorted_bitmaps[1:], optimize=False,
                       duration=40, loop=0)

In [None]:
import tensorflow as tf
import time

def correlate(a, b):
    """Return the cosine similarity of ``a`` and ``b`` over all their elements.

    Computed as the sum of element-wise products divided by the product of the two
    Euclidean norms. The result is 1 for the same direction and -1 for the opposite
    direction, and NaN if either input is all zeros.
    """
    return np.sum(a * b) / np.sqrt(np.sum(a ** 2) * np.sum(b ** 2))

# Pairwise cosine similarity between all explanation maps, each flattened to a vector.
# This equals correlate(all_explanations[i], all_explanations[j]) for every i, j, but is
# computed as one Gram matrix instead of N^2 Python-level calls.
# Cosine similarity is invariant to positive rescaling, so the old per-map division by
# np.abs(np.amax(...)) did nothing useful. It also produced NaN/inf whenever a map's
# maximum was 0 (e.g. all in-brain relevance negative, background zeroed).
flattened_explanations = all_explanations.reshape(len(all_explanations), -1)
gram = flattened_explanations @ flattened_explanations.T
norms = np.sqrt(np.diag(gram))
with np.errstate(divide='ignore', invalid='ignore'):
    # An all-zero map has norm 0 and yields NaN, as correlate() would.
    correlations = gram / np.outer(norms, norms)

In [None]:
# Reorder the similarity matrix (rows and columns alike) by chronological age.
# NOTE(review): `labels` must be the per-image list from the collection cell; the
# grouping cell later rebinds it to dataset.labels.
idx = np.argsort(labels)
sorted_correlations = correlations[idx][:, idx]

# A tick every 100 subjects, labelled with that subject's age. The positions are derived
# from the data rather than a hard-coded 600, which raised IndexError for smaller cohorts.
tick_positions = np.arange(0, len(labels), 100)
tick_labels = [round(labels[idx[i]], 2) for i in tick_positions]

fig = plt.figure(figsize=(15, 15))
heatmap = plt.imshow(sorted_correlations, cmap='YlGnBu', clim=(0, 1))
plt.colorbar(heatmap)
plt.xticks(tick_positions, tick_labels)
plt.xlabel('Chronological age')
plt.yticks(tick_positions, tick_labels)
plt.ylabel('Chronological age')
# Save to the current user's home directory instead of a hard-coded absolute path.
plt.savefig(os.path.join(os.path.expanduser('~'), 'sorted_correlations.png'))
plt.show()

In [None]:
labels = dataset.labels
groups = {}

# Bucket encodings under the key '<scanner>-<sex>-<age rounded to whole years>'.
# NOTE(review): the scanner is taken as the second '-'-separated field of the subject id — confirm the id format.
for i, encoding in enumerate(encodings):
    sex = labels['sex'][i]
    age = labels['age'][i]
    scanner = dataset.ids[i].split('-')[1]
    key = f'{scanner}-{sex}-{int(np.round(age))}'
    groups.setdefault(key, []).append(encoding)

# Keep only buckets with at least five members.
groups = {key: members for key, members in groups.items() if len(members) >= 5}