In [None]:
%matplotlib inline

import os

from keras import backend as K
from keras.models import load_model

import matplotlib.pyplot as plt
import numpy as np

from models import get_single_modality_model, get_stream_model, get_two_path_cascade_local, get_type_2_model

# Directory containing the .npy arrays loaded in the next cell.
data_dir = 'data'

In [None]:
def _load_vis_array(filename):
    """Load one ``.npy`` array stored under ``data_dir``."""
    return np.load(os.path.join(data_dir, filename))

# Arrays fed to the models' predict() calls below (y_vis is loaded but not
# used in the cells shown here; presumably labels — verify).
pet_local = _load_vis_array('pet_local_vis.npy')
ct = _load_vis_array('ct_vis.npy')
pet = _load_vis_array('pet_vis.npy')
y = _load_vis_array('y_vis.npy')

# Reference images shown by segment(); size-1 axes are squeezed away.
ct_image, pet_image, y_image = (
    np.squeeze(_load_vis_array(filename))
    for filename in ('ct_image.npy', 'pet_image.npy', 'y_image.npy')
)

In [None]:
def segment(preds, name, ct=False):
    """Plot a predicted segmentation next to its inputs and ground truth.

    Panels, left to right: CT (only when ``ct`` is true), PET, prediction,
    label. The figure is saved to ``<name>.png`` and then shown.

    Args:
        preds: Prediction array with as many elements as the module-level
            ``pet_image``; it is reshaped to that image's shape.
        name: File stem for the saved PNG.
        ct: If true, prepend the module-level ``ct_image`` as the first panel.
            (This flag shadows the global ``ct`` array inside the function.)
    """
    preds = preds.reshape(pet_image.shape)

    panels = [(pet_image, 'PET'), (preds, 'Prediction'), (y_image, 'Label')]
    if ct:
        panels.insert(0, (ct_image, 'CT'))

    # 'box-forced' was deprecated in Matplotlib 2.2 and removed in 3.x (it now
    # raises ValueError); 'box' is the supported replacement and works with
    # shared axes.
    subplot_kw = {'adjustable': 'box', 'aspect': 'equal'}
    f, axes = plt.subplots(1, len(panels), sharey=True, subplot_kw=subplot_kw)
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)

    f.savefig(f'{name}.png', bbox_inches='tight')
    plt.show()

def graph_filters(filters, name):
    """Show each filter image in its own panel of a single-row figure.

    The figure is saved to ``<name>.png`` and then shown.

    Args:
        filters: Sequence of arrays accepted by ``imshow``.
        name: File stem for the saved PNG.

    Raises:
        ValueError: if ``filters`` is empty.
    """
    num_filters = len(filters)
    if num_filters == 0:
        raise ValueError('graph_filters needs at least one filter to plot')

    # 'box-forced' was removed in Matplotlib 3.x; 'box' supports shared axes.
    subplot_kw = {'adjustable': 'box', 'aspect': 'equal'}
    # squeeze=False keeps `axes` 2-D even for a single filter; otherwise
    # plt.subplots(1, 1) returns a bare Axes that cannot be indexed.
    f, axes = plt.subplots(1, num_filters, sharey=True, squeeze=False,
                           subplot_kw=subplot_kw)
    for ax, image in zip(axes[0], filters):
        ax.imshow(image)

    f.savefig(f'{name}.png', bbox_inches='tight')
    plt.show()

In [None]:
def deprocess(x, eps=None):
    """Standardise an array to zero mean and (approximately) unit variance.

    Returns a new array and leaves the input untouched. The earlier in-place
    ``-=``/``/=`` version mutated the caller's array and raised on integer
    input.

    Args:
        x: Array-like of numbers.
        eps: Small constant added to the standard deviation so a constant
            array does not divide by zero. Defaults to ``K.epsilon()``.

    Returns:
        ``(x - mean(x)) / (std(x) + eps)`` as an ndarray. Floating inputs keep
        their dtype; integer inputs are promoted to float.
    """
    if eps is None:
        eps = K.epsilon()
    x = np.asarray(x)
    centred = x - x.mean()  # new array; promotes ints to float
    return centred / (x.std() + eps)

def get_vis_filter_step(input_img, layer_filter):
    """Build a Keras function that evaluates one gradient-ascent step.

    The objective is the mean of ``layer_filter``. The returned function takes
    a list of input arrays (one per tensor in ``input_img``) and returns
    ``[loss, grad_0, grad_1, ...]``. Each gradient is divided by its RMS
    (plus ``K.epsilon()``). Inputs for which ``K.gradients`` yields ``None``
    get a constant 0 instead.

    Args:
        input_img: List of input tensors to differentiate with respect to.
        layer_filter: Symbolic tensor whose mean is maximised.
    """
    objective = K.mean(layer_filter)

    def _normalise(gradient):
        # K.gradients returns None for inputs the objective does not depend on.
        if gradient is None:
            return K.constant(0)
        rms = K.sqrt(K.mean(K.square(gradient)))
        return gradient / (rms + K.epsilon())

    normalised = [_normalise(g) for g in K.gradients(objective, input_img)]
    return K.function(input_img, [objective] + normalised)

def vis_filter(vis_filter_step, num_iters, input_shape, num_inputs=1, step=1.0, rng=None):
    """Run gradient ascent from random noise to maximise one filter's activation.

    Each input starts as uniform noise on ``[-sqrt(3), sqrt(3))``, which has
    zero mean and unit variance. Every iteration moves it by ``step * grad``.
    Iteration stops early once the reported loss is ``<= 0``.

    Args:
        vis_filter_step: Callable mapping a list of input arrays to
            ``[loss, grad_0, grad_1, ...]`` (e.g. from ``get_vis_filter_step``).
        num_iters: Maximum number of ascent steps; must be >= 1.
        input_shape: Shape of each generated input array.
        num_inputs: Number of input arrays to generate.
        step: Ascent step size. The default 1.0 reproduces the earlier
            behaviour: the old ``grads * 1`` was list repetition, not scaling.
        rng: Optional ``numpy.random.Generator`` for reproducible starting
            noise; by default the global ``np.random`` state is used.

    Returns:
        ``(inputs, loss)``: the final input arrays, and the loss reported by
        the last ``vis_filter_step`` call (evaluated before the final update).

    Raises:
        ValueError: if ``num_iters < 1``. This used to surface as an
            ``UnboundLocalError`` on ``loss``.
    """
    if num_iters < 1:
        raise ValueError(f'num_iters must be >= 1, got {num_iters}')

    random = np.random.random if rng is None else rng.random
    inputs = [(random(input_shape) - 0.5) * 2 * np.sqrt(3) for _ in range(num_inputs)]

    for _ in range(num_iters):
        outputs = vis_filter_step(inputs)
        loss = outputs[0]
        grads = outputs[1:]
        inputs = [x + step * g for x, g in zip(inputs, grads)]

        # A non-positive mean activation means the filter is not responding.
        if loss <= 0.:
            break

    return inputs, loss

def vis_layer(model, layer, best=4, num_iters=20, print_every=10, input_i=None):
    """Visualise a conv layer's filters by activation maximisation.

    For every filter of ``layer``, build a fresh gradient-ascent function and
    run it from random noise. Filters whose final loss is positive are kept,
    and the ``best`` highest-loss ones are returned.

    Args:
        model: Keras model whose input(s) are optimised.
        layer: Layer of ``model`` with a ``filters`` attribute (e.g. Conv2D).
            Its output is indexed as ``[:, :, :, i]``, which assumes a 4-D
            channels-last output — TODO confirm for other layouts.
        best: Maximum number of filters to return.
        num_iters: Maximum gradient-ascent steps per filter.
        print_every: Print progress after this many filters (0 disables).
        input_i: For multi-input models, optimise only this input. The other
            inputs are not fed, so the layer must not depend on them.
            ``None`` optimises all inputs.

    Returns:
        List of ``(imgs, loss)`` tuples sorted by decreasing loss. ``imgs``
        has one standardised array per optimised input.
    """
    model_inputs = model.input if isinstance(model.input, list) else [model.input]

    if input_i is not None:
        # Only the selected input is fed to the step function, so generate
        # exactly one noise array shaped like *that* input.
        input_img = [model_inputs[input_i]]
    else:
        input_img = model_inputs
    # vis_filter takes a single shape, so inputs optimised together are
    # assumed to share the shape of the first one.
    input_shape = [1] + input_img[0].shape[1:].as_list()
    num_inputs = len(input_img)

    kept_filters = []
    for i in range(layer.filters):
        layer_filter = layer.output[:, :, :, i]
        vis_filter_step = get_vis_filter_step(input_img, layer_filter)

        inputs, loss = vis_filter(vis_filter_step, num_iters, input_shape, num_inputs=num_inputs)
        imgs = [deprocess(img) for img in inputs]
        if loss > 0:
            kept_filters.append((imgs, loss))

        if print_every and (i + 1) % print_every == 0:
            print(f'{i + 1} Filters done')

    kept_filters.sort(key=lambda x: x[1], reverse=True)
    return kept_filters[:best]

def vis_layers(model, layer_names, **kwargs):
    """Run ``vis_layer`` on each named layer of ``model``.

    Args:
        model: Keras model containing the layers.
        layer_names: Names of the layers to visualise. Each must match the
            ``name`` of a layer in ``model.layers``; otherwise ``KeyError``.
        **kwargs: Forwarded unchanged to ``vis_layer``.

    Returns:
        Dict mapping each layer name to the list ``vis_layer`` returned.
    """
    layers_by_name = dict((lyr.name, lyr) for lyr in model.layers)
    results = {}
    for name in layer_names:
        results[name] = vis_layer(model, layers_by_name[name], **kwargs)
        print(f'Vis for layer {name}')
    return results

def get_best_filters(filter_dict, layer_names, input_index=0):
    """Take the top-ranked filter image of each layer, squeezed.

    Args:
        filter_dict: Mapping ``layer_name -> [(imgs, loss), ...]`` as built by
            ``vis_layers``. Each list is sorted best-first by ``vis_layer``.
        layer_names: Layers to extract, in output order.
        input_index: Which optimised input's image to take from ``imgs``
            (0 matches the previous behaviour).

    Returns:
        List with one ``np.squeeze``-d array per layer name.

    Raises:
        ValueError: if a layer has no kept filters (every filter's loss was
            <= 0). This replaces an opaque IndexError.
    """
    best_filters = []
    for layer_name in layer_names:
        ranked = filter_dict[layer_name]
        if not ranked:
            raise ValueError(
                f'No filters with positive activation were kept for layer {layer_name!r}')
        best_imgs = ranked[0][0]
        best_filters.append(np.squeeze(best_imgs[input_index]))
    return best_filters

In [None]:
# Start from a fresh Keras graph before building the type-II model. The filter
# cell below hard-codes auto-generated names ('conv2d_1'..'conv2d_8'); those
# names come from Keras' layer-name counter, which clear_session resets in
# Keras 2.x. Keep this call before the model is built.
K.clear_session()
model_II = get_type_2_model()
# Weights file must match the architecture returned by get_type_2_model().
model_II.load_weights('best_type_II_model.h5')

In [None]:
# Threshold the model output at 0.5 (presumably sigmoid probabilities) and
# plot CT / PET / prediction / label side by side.
probs_II = model_II.predict([ct, pet])
preds_II = probs_II >= 0.5
segment(preds_II, name='type_II', ct=True)

In [None]:
# Layer names rely on Keras auto-naming after K.clear_session(). Presumably
# conv2d_1..4 belong to the CT stream (input 0) and conv2d_5..8 to the PET
# stream (input 1) — verify with model_II.summary().
type_II_ct_layers = [f'conv2d_{i}' for i in range(1, 5)]
type_II_pet_layers = [f'conv2d_{i}' for i in range(5, 9)]

ct_filters = vis_layers(model_II, type_II_ct_layers, input_i=0, best=1)
pet_filters = vis_layers(model_II, type_II_pet_layers, input_i=1, best=1)
best_type_II_filters = (
    get_best_filters(ct_filters, type_II_ct_layers)
    + get_best_filters(pet_filters, type_II_pet_layers)
)
graph_filters(best_type_II_filters, 'type_II_filters')

In [None]:
# Fresh Keras graph before building the PET-only model. In Keras 2.x this also
# resets the auto-generated layer names that the 'conv2d_N' names in the
# filter cell below depend on.
K.clear_session()
model_pet = get_single_modality_model()
model_pet.load_weights('best_pet_model.h5')

In [None]:
# Threshold the PET-only model's output at 0.5 and plot PET / prediction / label.
probs_pet = model_pet.predict(pet)
preds_pet = probs_pet >= 0.5
segment(preds_pet, name='pet')

In [None]:
# Auto-generated names after K.clear_session(); verify with model_pet.summary().
pet_only_layers = [f'conv2d_{i}' for i in range(1, 5)]

pet_only_filters = vis_layers(model_pet, pet_only_layers, best=1)
best_pet_only_filters = get_best_filters(pet_only_filters, pet_only_layers)
graph_filters(best_pet_only_filters, 'pet_filters')

In [None]:
# Fresh Keras graph before building the two-path cascade model. In Keras 2.x
# this also resets auto-generated layer names, which the 'conv2d_N' names in
# the filter cell below depend on.
K.clear_session()
# get_stream_model is passed as a builder to get_two_path_cascade_local; the
# keyword arguments are interpreted in models.py (not shown) — the weights
# file must match the resulting architecture.
model_local = get_two_path_cascade_local(get_stream_model, mode='pet', maxout=True, dropout=False)
model_local.load_weights('best_local_model.h5')

In [None]:
# Threshold the cascade model's output at 0.5 and plot PET / prediction / label.
probs_local = model_local.predict(pet_local)
preds_local = probs_local >= 0.5
segment(preds_local, name='local')

In [None]:
# conv2d_8..13 are auto-generated names after K.clear_session(); presumably a
# subset of the cascade's conv layers — verify with model_local.summary().
local_layers = [f'conv2d_{i}' for i in range(8, 14)]

local_filters = vis_layers(model_local, local_layers, best=1)
best_local_filters = get_best_filters(local_filters, local_layers)
graph_filters(best_local_filters, 'local_filters')

In [None]:
# Fresh Keras graph before building the single stream model. In Keras 2.x this
# also resets the auto-generated 'conv2d_N' names used by the next cell.
K.clear_session()
# The positional arguments 49 and 21 are interpreted by models.get_stream_model
# (not shown) — presumably patch sizes; TODO confirm against models.py.
stream_model = get_stream_model(49, 21, mode='pet', n_feature_maps=2, maxout=True, dropout=False)
stream_model.load_weights('local_stream_pet.h5')

In [None]:
# Auto-generated names after K.clear_session(); verify with stream_model.summary().
stream_layers = [f'conv2d_{i}' for i in range(1, 7)]

stream_filters = vis_layers(stream_model, stream_layers, best=1)
best_stream_filters = get_best_filters(stream_filters, stream_layers)
graph_filters(best_stream_filters, 'stream_filters')