# Fig. 6: Visualization of receptive fields

This notebook produces the figure showing the gradient receptive fields. 

*Note that in the [arXiv preprint](https://arxiv.org/abs/1809.10504) of the paper, we show Maximally Exciting Images (MEIs), computed by maximizing the activity predicted by the model. However, one of the ICLR reviewers did not trust these MEIs to be meaningful and made us replace them with linear (gradient) RFs for the final version. In the meantime, we have shown in a [separate paper](https://www.biorxiv.org/content/10.1101/506956v1) that these MEIs do in fact reveal real properties of neurons. This is why we also include the code used to generate MEIs below and did not update the preprint version.*

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import datajoint as dj
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

import seaborn as sns
sns.set_style('ticks', rc={'image.cmap': 'bwr'})

import os
import sys
import inspect

# Make the directory two levels above the current working directory
# (presumably the repository root containing `cnn_sys_ident`) importable, so the
# `cnn_sys_ident` imports in the next cell resolve.
# Uses os.getcwd() rather than the IPython shell escape `!pwd`, so this is plain
# Python and does not depend on a POSIX shell being available.
p = os.path.dirname(os.path.dirname(os.getcwd()))
if p not in sys.path:
    sys.path.append(p)

In [None]:
from cnn_sys_ident.mesonet.data import MultiDataset
from cnn_sys_ident.mesonet.vis import MEIParams, MEIGroup, MEI
from cnn_sys_ident.mesonet.parameters import Core, Readout, Model, RegPath, Fit
from cnn_sys_ident.architectures.training import Trainer
from cnn_sys_ident.mesonet import MODELS
from cnn_sys_ident.utils.mei import ActivityMaximization

# Load model and extract masks

In [None]:
# Restrict the MultiDataset table to the entry identified by this data hash;
# every model queried below is fitted on this `dataset`.
data_key = dict(data_hash='cfcd208495d565ef66e7dff9f98764da')
dataset = MultiDataset() & data_key

In [None]:
# Select, among all HermiteSparse fits on `dataset` without positivity constraint
# on the feature weights and without shared biases, the one with `num_filters`
# feature types (the `num_filters_2` attribute — presumably the second layer)
# and the lowest validation loss, then load it.
num_filters = 16
model_rel = (
    (MODELS['HermiteSparse'] * dataset)
    & 'positive_feature_weights=False AND shared_biases=False'
    & {'num_filters_2': num_filters}
)
best_fits = (Fit() * model_rel).fetch(dj.key, order_by='val_loss', limit=1)
key = best_fits[0]
num_rotations = (model_rel & key).fetch1('num_rotations')
model = Fit().load_model(key)

In [None]:
# Evaluate the loaded model. `r` presumably holds one validation correlation per
# unit; it is used below (`r > min_corr`) to drop poorly predicted units.
# Its mean is printed as a quick sanity check.
trainer = Trainer(model.base, model)
r = trainer.compute_val_corr()
print(np.mean(r))

In [None]:
# Fetch the trained readout parameters (masks and feature weights) from the model.
masks = model.base.evaluate(model.readout.masks)
w = model.base.evaluate(model.readout.feature_weights)

# Scale each row of `w` to unit L2 norm. Then split the columns into
# (rotation, type) pairs and sum the squared weights over rotations, giving one
# value per feature type in `w_marg`.
# NOTE(review): this assumes rows of `w` are units and columns are laid out as
# rotation-major / type-minor — confirm against the readout definition.
row_norms = np.sqrt((w ** 2).sum(axis=1, keepdims=True))
w_norm = w / row_norms
w_marg = (w_norm.reshape([-1, num_rotations, num_filters]) ** 2).sum(axis=1)

for arr in (masks, w, w_marg):
    print(arr.shape)

# Gradient RFs

In [None]:
from cnn_sys_ident.utils.mei import GradientRF

In [None]:
# Re-instantiate the model selected in the "Load model" section for the
# gradient computation.
#
# `key` (and with it `num_filters`, `model_rel` and `num_rotations`) is reused
# deliberately instead of re-running the `order_by='val_loss', limit=1` query.
# A second query could select a different fit if several have the same
# validation loss. The gradients would then no longer match the `masks`,
# `w_marg` and `r` extracted from the first model.
# NOTE(review): presumably `get_model` builds the graph without restoring the
# trained weights (unlike `load_model`), and GradientRF below restores them from
# the checkpoint itself — confirm against cnn_sys_ident.mesonet.parameters.Fit.
model = Fit().get_model(key)

In [None]:
# Compute linear (gradient) receptive fields for the best-predicted units of
# each cell type. Each type's RFs are tiled into one panel of a 4x4 figure.
# NOTE(review): GradientRF is assumed to restore the trained weights from
# `checkpoint_file` into `graph` — confirm in cnn_sys_ident.utils.mei.
tfs = model.base.tf_session
graph = tfs.graph
checkpoint_file = os.path.join(tfs.log_dir, 'model.ckpt')
input_shape = [model.base.data.input_shape[1], model.base.data.input_shape[2]]
gradRF = GradientRF(graph, checkpoint_file, input_shape)
print('Computing gradient RFs. This may take a few minutes...')

min_corr = 0.2   # only show units whose correlation `r` exceeds this
k = 12           # half-width of each RF crop (crops are 2k x 2k pixels)
n = [4, 4]       # grid of RFs shown per cell type
n_tiles = n[0] * n[1]
type_id = np.argmax(np.abs(w_marg), axis=1)   # dominant type of every unit
fig, axes = plt.subplots(n[0], n[1], figsize=(4*n[1], 4*n[0]))
for i, ax in enumerate(axes.flatten()):
    # Well-predicted units of type i, sorted by decreasing weight on type i.
    idx, = np.where((type_id == i) & (r > min_corr))
    order = np.argsort(-w_marg[idx, i])
    rfs = []
    for unit_id in idx[order[:n_tiles]]:
        rf = gradRF.gradient(unit_id)
        # Zero-pad by k so the 2k x 2k crop below never leaves the array.
        rf = np.pad(rf, k, 'constant')
        rf /= np.abs(rf).max() + 1e-3   # scale into roughly [-1, 1]
        # Centre the crop on the peak of the unit's (equally padded) readout mask.
        mask = np.pad(masks[unit_id], k, 'constant')
        rf_i, rf_j = np.unravel_index(mask.argmax(), mask.shape)
        rfs.append(rf[rf_i-k:rf_i+k, rf_j-k:rf_j+k])

    # Fewer than n_tiles units may pass the filter for some types. Fill the
    # remaining slots with blank tiles; otherwise the reshape below raises.
    rfs += [np.zeros((2*k, 2*k))] * (n_tiles - len(rfs))
    rfs = np.array(rfs).reshape([n[0], n[1]*2*k, 2*k])
    rfs = np.concatenate(rfs, axis=1)
    ax.imshow(rfs, vmin=-1, vmax=1)
    ax.axis('off')

# savefig does not create missing directories.
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/gradients_4x4_all.eps', format='eps')

# Visualize MEIs for different cell types

This is Fig. 6 in the [preprint on arXiv](https://arxiv.org/abs/1809.10504). MEIs are loaded from the database; for the actual computation, refer to [cnn_sys_ident.mesonet.vis](../../cnn_sys_ident/mesonet/vis.py).

In [None]:
# Plot the stored MEIs of the best-predicted units of each cell type, one
# n[0] x n[1] figure per type. Then save the ids of the units shown.
min_corr = 0.2   # only show units whose correlation `r` exceeds this
k = 15           # half-width of each MEI crop (crops are 2k x 2k pixels)
n = [4, 4]       # grid of MEIs per cell type
type_id = np.argmax(np.abs(w_marg), axis=1)   # dominant type of every unit
# savefig / np.save do not create missing directories.
os.makedirs('figures', exist_ok=True)
unit_ids = []
for i in range(num_filters):
    # Well-predicted units of type i, sorted by decreasing weight on type i.
    idx, = np.where((type_id == i) & (r > min_corr))
    order = np.argsort(-w_marg[idx, i])
    fig, axes = plt.subplots(n[0], n[1], figsize=(n[1], n[0]))
    unit_ids.append(idx[order[:n[0]*n[1]]])
    for unit_id, ax in zip(idx[order], axes.flatten()):
        unit_key = {'unit_id': unit_id, 'param_id': 1}
        rel = MEI() & unit_key & key
        if len(rel):   # skip units for which no MEI is stored
            img = rel.fetch1('max_image')
            # Zero-pad by k so the 2k x 2k crop below never leaves the array.
            img = np.pad(img, k, 'constant')
            m = np.abs(img).max() + 1e-3   # symmetric colour limits
            # Centre the crop on the peak of the unit's (equally padded) readout mask.
            mask = np.pad(masks[unit_id], k, 'constant')
            rf_i, rf_j = np.unravel_index(mask.argmax(), mask.shape)
            img = img[rf_i-k:rf_i+k, rf_j-k:rf_j+k]
            ax.imshow(img, vmin=-m, vmax=m)
        ax.axis('off')
    # NOTE(review): the file name has no separator between "4x4" and the type
    # index (type 1 -> "meis_4x41.eps"). It is kept unchanged in case other
    # tooling relies on these names.
    fig.savefig('figures/meis_4x4{:d}.eps'.format(i+1), format='eps')

# Cell types can have different numbers of qualifying units (fewer than
# n[0]*n[1]). np.array() on such a ragged list raises ValueError in
# NumPy >= 1.24, so build an explicit object array in that case. Loading it
# back then requires np.load(..., allow_pickle=True).
if len({len(ids) for ids in unit_ids}) > 1:
    ragged = np.empty(len(unit_ids), dtype=object)
    for j, ids in enumerate(unit_ids):
        ragged[j] = ids
    unit_ids = ragged
else:
    unit_ids = np.array(unit_ids)
np.save('figures/unit_ids', unit_ids)