In [None]:
%matplotlib inline

import io

import cv2
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

import lucid
import lucid.modelzoo.vision_models as models
import lucid.optvis.objectives as objectives
import lucid.optvis.param as param
import lucid.optvis.render as render
import tensorflow as tf
from sklearn.decomposition import NMF

In [None]:
model = models.InceptionV1()
model.load_graphdef()

# Download a mapping of the ImageNet class ids to text. The third
# space-separated field of every line is used as the class name.
imagenet_classes_request = requests.get("https://gist.githubusercontent.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57/raw/aa66dd9dbf6b56649fa3fab83659b2acbf3cbfd1/map_clsloc.txt", timeout=60)
# Fail loudly on an HTTP error instead of parsing an error page as class names.
imagenet_classes_request.raise_for_status()
# Parse the decoded text line by line. Going through str(bytes) and splitting on
# a literal backslash-n leaves the closing quote of the bytes repr glued to the
# last class name (or raises IndexError when the file ends with a newline).
# The leading None shifts every name by one position, so the first name read
# from the file sits at index 1.
imagenet_classes_list = [None] + [
    a.split(' ')[2]
    for a in imagenet_classes_request.text.splitlines()
    if a.strip()
]

# download an image from flickr
response = requests.get("http://c1.staticflickr.com/5/4070/5148597478_0c34ec0b7e_n.jpg", timeout=60)
response.raise_for_status()
# Force three channels: the graph input used later in this notebook is declared
# with shape [None, None, 3], so a grayscale or RGBA download would not fit.
image = Image.open(io.BytesIO(response.content)).convert('RGB')
image = np.array(image, dtype=np.float32)

In [None]:
# compute a k-NMF of data A
def nfm(k, A):
    """Factorize `A` into `k` non-negative components with sklearn's NMF.

    Every axis of `A` except the last is flattened into rows, so the matrix
    handed to NMF has shape (prod(A.shape[:-1]), A.shape[-1]).

    Returns a tuple (U, V):
      * U -- the fitted coefficients reshaped to A.shape[1:-1] + (k,). The
        leading axis of `A` is dropped, so this only works when that axis has
        length 1.
      * V -- the fitted components with shape (k, A.shape[-1]).
    """
    full_shape = A.shape
    n_features = full_shape[-1]
    flattened = A.reshape((np.prod(full_shape[:-1]), n_features))

    decomposer = NMF(n_components=k)
    coefficients = decomposer.fit_transform(flattened)
    components = decomposer.components_

    return (
        coefficients.reshape(full_shape[1:-1] + (k,)),
        components.reshape((k, n_features)),
    )

# compute the necessary forward and backward passes on a model in order to obtain
# the activations at the desired layer, the gradients for specific logits
# and the evaluation of the logits
# activations is a tensor with the activations at layer `layer`
# grads is a list where each entry is a tensor with the gradient at layer `layer`
# from the class-logit with the same position in the array classes
def foward_pass_and_gradients(model, image, classes, layer):
    """Run `image` through `model` and collect activations, gradients and logits.

    The model is imported into a fresh graph and session that only live for
    the duration of this call.

    Args:
        model: model object handed to `render.import_model`.
        image: array used as the default value of the graph input, which is
            declared with shape [None, None, 3].
        classes: iterable of logit indices to differentiate.
        layer: name of the layer whose activations and gradients are wanted.

    Returns:
        (activations, grads, logits) where `activations` is the evaluated
        tensor at `layer`, `grads[i]` is the gradient of logit `classes[i]`
        (first batch entry) with respect to `layer`, evaluated with the layer
        fed with `activations`, and `logits` is the evaluated
        "softmax2_pre_activation" tensor.
    """
    with tf.Graph().as_default(), tf.Session():
        t_input = tf.placeholder_with_default(image, [None, None, 3])
        T = render.import_model(model, t_input, t_input)

        # Look the two tensors up once and reuse the handles below.
        t_layer = T(layer)
        t_logits = T("softmax2_pre_activation")

        activations = t_layer.eval()
        grads = [
            tf.gradients([t_logits[0, class_id]], [t_layer])[0].eval({t_layer: activations})
            for class_id in classes
        ]
        return activations, grads, t_logits.eval()


In [None]:
def visualize_group(
    model,
    layer,
    weights,
    color,
    color_strip_height=0.1,
    visualization_size=100,
    diversity=1,
    lambda_diversity=1e3,
):
    """Render feature visualizations for one group direction at `layer`.

    Args:
        model: model passed on to `render.render_vis`.
        layer: layer name the direction objective is attached to.
        weights: direction vector; it is flattened before use.
        color: color used to fill the strip stacked on top of the first image.
        color_strip_height: strip height as a fraction of `visualization_size`
            (at least one pixel row is always used).
        visualization_size: size passed to `param.image`; also the strip width.
        diversity: batch size of the optimized image; a diversity term is
            subtracted from the objective when it is greater than 1.
        lambda_diversity: weight of the diversity term.

    Returns:
        A list with `diversity` images. Only the first one carries the color
        strip, so it is taller than the others.
    """
    strip_rows = max(int(color_strip_height * visualization_size), 1)
    color_strip = np.ones((strip_rows, visualization_size, 3)) * color

    def param_f():
        return param.image(visualization_size, batch=diversity)

    obj = objectives.direction(layer, weights.ravel())
    if diversity > 1:
        obj -= lambda_diversity*objectives.diversity(layer)

    # render_vis returns a sequence; its last entry is the one that is kept.
    rendered = render.render_vis(model, obj, param_f, verbose=False)[-1]
    images = [chunk[0, ...] for chunk in np.split(rendered, diversity)]
    print(".", end='', flush=True) # progress indicator as this step is slow

    images[0] = np.vstack([color_strip, images[0]])
    return images

def get_group_colors(k):
    """Return `k` colors taken from matplotlib's inferno color list.

    Color j (0-based) is read at index int((j + 1) * 255 / (k + 1)), i.e. at
    the fraction (j + 1) / (k + 1) of index 255. Each color is returned as a
    numpy array.
    """
    palette = matplotlib.cm.inferno.colors
    picked = []
    for j in range(k):
        position = int((j + 1) * (256 - 1) / (k + 1))
        picked.append(np.array(palette[position]))
    return picked

def get_saliency_map(groups, colors, treshold):
    """Build an RGBA saliency map from per-group spatial intensities.

    Args:
        groups: array of shape (height, width, k); channel j holds the
            intensity of group j at every position.
        colors: sequence of k RGB colors (length-3 array-likes), one per group.
        treshold: fraction of a group's maximum below which the group does not
            color a pixel.

    Returns:
        Float array of shape (height, width, 4). The RGB channels are the sum
        over groups of `color * normalized intensity`, restricted to pixels at
        or above the threshold. The alpha channel is the per-pixel maximum of
        the normalized (unthresholded) intensities over all groups.

    A group whose maximum is not positive contributes nothing. Normalizing it
    would divide by zero and spread NaN over the whole map.
    """
    saliency_map = np.zeros(groups.shape[:-1] + (4,))
    for j in range(groups.shape[-1]):
        group = groups[:, :, j]
        peak = group.max()
        if peak <= 0:
            # Dead group (e.g. an all-zero factor): nothing to draw.
            continue
        intensities = group / peak
        # Broadcast (h, w, 1) * (1, 1, 3) -> (h, w, 3) colored intensity.
        cmap = intensities[:, :, None] * np.asarray(colors[j])[None, None, :]
        cmap[group < treshold * peak] = 0
        saliency_map[:, :, :3] += cmap
        saliency_map[:, :, 3] = np.maximum(intensities, saliency_map[:, :, 3])
    return saliency_map

def compute_effects(groups, weights, gradients):
    """Compute a (classes x groups) table of scalar effects.

    Args:
        groups: array of shape (height, width, k) with spatial group factors.
        weights: array whose row j is the channel vector of group j.
        gradients: one array per class; only the first entry along the
            leading axis of each is used.

    Returns:
        Array of shape (len(gradients), k). Entry [i, j] is the sum over all
        positions and channels of `groups[..., j] * weights[j] * gradients[i][0]`.
    """
    n_groups = groups.shape[-1]
    effects = np.zeros((len(gradients), n_groups))
    for group_idx in range(n_groups):
        # (h, w, 1) * (channels,) -> (h, w, channels) reconstruction of the group
        weighted = groups[:, :, group_idx, None] * weights[group_idx, :]
        effects[:, group_idx] = [
            np.sum(weighted * class_grad[0, ...]) for class_grad in gradients
        ]
    return effects

def show(image, saliency_map, group_visualizations, layer):
    """Plot the input image, the saliency map and the group visualizations.

    Two figures are created:
      1. A row of three panels: the image, the saliency map, and the image
         with the saliency map (resized to the image size with
         nearest-neighbour interpolation) drawn on top.
      2. A grid with one column per group and one row per visualization of
         that group.

    Args:
        image: image array with values on a 0-255 scale (it is divided by 255
            for display).
        saliency_map: array drawn with `imshow` and resized with `cv2.resize`.
        group_visualizations: list with one entry per group; each entry is a
            list of images. The length of the first entry sets the row count.
        layer: layer name used in the figure title.
    """
    n_groups = len(group_visualizations)
    n_variants = len(group_visualizations[0])

    fig, axes = plt.subplots(1, 3, figsize=(18, 4))
    fig.suptitle(layer + ' visualized with ' + str(n_groups) + ' groups', fontsize=16)
    for ax in axes:
        ax.axis('off')
    axes[0].imshow(image/255.0)
    axes[1].imshow(saliency_map)
    axes[2].imshow(image/255.0)
    # cv2 expects dsize as (width, height), hence shape[1] before shape[0].
    saliency_map_large = cv2.resize(saliency_map, dsize=(image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
    axes[2].imshow(saliency_map_large)

    # squeeze=False always yields a 2-D array of axes. Without it a 1x1 grid
    # (one group, one visualization) comes back as a bare Axes object and the
    # 2-D indexing below fails.
    _, grid = plt.subplots(n_variants, n_groups, figsize=(18, 4*n_variants), squeeze=False)
    for col, variants in enumerate(group_visualizations):
        for row in range(n_variants):
            grid[row, col].axis('off')
            grid[row, col].imshow(variants[row])

def print_effects(effects, target_classes):
    """Print the effects table as tab-separated text.

    The output starts with an 'Effects:' line, followed by a header row with
    the group indices and then one row per class name with the corresponding
    effects rounded to two decimals. Every cell, including the last one of a
    row, is followed by a tab.

    Args:
        effects: 2-D array indexed as [class, group].
        target_classes: class names, in the same order as the rows of `effects`.
    """
    _, n_groups = effects.shape
    group_ids = range(n_groups)

    print('Effects:')
    print('Class/Group', *group_ids, sep='\t', end='\t\n')
    for row, classname in enumerate(target_classes):
        cells = (effects[row, col].round(2) for col in group_ids)
        print(classname, *cells, sep='\t', end='\t\n')

def building_blocks_of_interpretability(model, image, layer, classes, imagenet_classes_list, k=5,
                                        saliency_map_treshold=0.7, visualization_size=100, diversity=1,
                                        lambda_diversity=1e3):
    """Run the whole pipeline for one image and one layer, then plot and print it.

    Steps, in order: forward/backward pass, NMF of the layer activations into
    `k` groups, one feature visualization per group, saliency map, per-class
    effects, plotting and printing.

    Args:
        model: model object; its `layers` entries must expose a `name`.
        image: input image array.
        layer: name of the layer to analyze; must be one of the model's layers.
        classes: class names to compute gradients and effects for; each must
            occur in `imagenet_classes_list`.
        imagenet_classes_list: list mapping class id (position) to class name.
        k: number of groups (at least 1).
        saliency_map_treshold: threshold passed to `get_saliency_map`.
        visualization_size, diversity, lambda_diversity: passed through to
            `visualize_group`.

    Returns:
        (group_visualizations, colors, saliency_map, effects)
    """
    layer_names = [entry.name for entry in model.layers]
    assert 1 <= k
    assert layer in layer_names

    # Translate class names to their positions in the class list.
    classes_ids = [imagenet_classes_list.index(name) for name in classes]
    activations, gradients, logits = foward_pass_and_gradients(model, image, classes_ids, layer)

    top_classes = logits[0, :].ravel().argsort()[::-1][:10]
    print('Top 10 classes:')
    print([imagenet_classes_list[class_id] for class_id in top_classes.tolist()])

    print('Computing NMF')
    groups, weights = nfm(k, activations)

    print('Visualizing Groups')
    colors = get_group_colors(k)
    group_visualizations = []
    for j in range(k):
        group_visualizations.append(
            visualize_group(model, layer, weights[j, :], colors[j],
                            visualization_size=visualization_size,
                            diversity=diversity,
                            lambda_diversity=lambda_diversity)
        )
    print()

    print('Creating Saliency Map')
    saliency_map = get_saliency_map(groups, colors, saliency_map_treshold)

    print('Computing effects')
    effects = compute_effects(groups, weights, gradients)

    # Three empty lines separate the log above from the figures and the table.
    for _ in range(3):
        print('')
    show(image, saliency_map, group_visualizations, layer)
    print_effects(effects, classes)

    return group_visualizations, colors, saliency_map, effects
    

Note: If you are running on a CPU, visualizing a 60x60 image takes roughly 1 or 2 minutes on a modern laptop. So during development/debugging you will want to keep k and the visualization_size small, or move to a GPU (for example, move the notebook to Colab, where you did the Lucid tutorials, which allows free GPU access for 2-hour sessions).

In [None]:
# List the names of all layers of the loaded model; the `layer` argument used
# further down has to be one of these.
layer_names = [entry.name for entry in model.layers]
print(layer_names)

In [None]:
# Class names whose logits are differentiated; each must occur in the class list.
target_classes = ["bloodhound", "tiger_cat"]
# Layer whose activations are factorized and visualized.
layer = "mixed5a"

In [None]:
# Small, fast run: 2 groups, 20-pixel visualizations, one image per group.
(group_visualizations, colors, saliency_map, effects) = building_blocks_of_interpretability(
    model, image, layer, target_classes, imagenet_classes_list,
    k=2, visualization_size=20, diversity=1,
)

In [None]:
# Larger run: 5 groups, 160-pixel visualizations, 3 diverse images per group.
(group_visualizations, colors, saliency_map, effects) = building_blocks_of_interpretability(
    model, image, layer, target_classes, imagenet_classes_list,
    k=5, visualization_size=160, diversity=3,
)

In [None]:
# Same as the previous run, but with 4 images per group and a ten times
# larger weight on the diversity term than the default.
(group_visualizations, colors, saliency_map, effects) = building_blocks_of_interpretability(
    model, image, layer, target_classes, imagenet_classes_list,
    k=5, visualization_size=160, diversity=4, lambda_diversity=1e4,
)