In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import datajoint as dj
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

import seaborn as sns
sns.set_style('ticks', rc={'image.cmap': 'bwr'})

import os
import sys
import inspect

# Make project code importable: add the directory two levels above the
# current working directory to sys.path (presumably the repository root that
# contains the `cnn_sys_ident` package -- confirm against the repo layout).
# os.getcwd() is used instead of the IPython-only `!pwd` shell escape so the
# cell is plain Python and does not depend on a POSIX `pwd` command.
p = os.path.dirname(os.path.dirname(os.getcwd()))
if p not in sys.path:
    sys.path.append(p)

In [None]:
from cnn_sys_ident.mesonet.data import MultiDataset
from cnn_sys_ident.mesonet.parameters import Core, Readout, Model, RegPath, Fit
from cnn_sys_ident.mesonet import MODELS
from cnn_sys_ident.mesonet.insilico import GaborParams, OptimalGabor, \
    SizeContrastTuning, SizeContrastTuningParams
from cnn_sys_ident.architectures.training import Trainer

In [None]:
# Restrict MultiDataset to the entries whose data_hash matches this key.
data_key = dict(data_hash='cfcd208495d565ef66e7dff9f98764da')
dataset = MultiDataset() & data_key

# Load a model

In [None]:
# Candidate models: 'HermiteSparse' entries joined with the selected dataset,
# restricted to positive_feature_weights=False, shared_biases=False and
# num_filters_2 == num_filters.
num_filters = 16
model_rel = (
    (MODELS['HermiteSparse'] * dataset)
    & 'positive_feature_weights=False AND shared_biases=False'
    & dict(num_filters_2=num_filters)
)

# Take the first fit when ordered by val_loss (the lowest loss under
# DataJoint's default ascending order), then load that model and look up its
# number of rotations.
key = (Fit() * model_rel).fetch(
    dj.key, order_by='val_loss', limit=1)[0]
num_rotations = (model_rel & key).fetch1('num_rotations')
model = Fit().load_model(key)

In [None]:
# Evaluate the readout's masks and feature weights through the model base
# (presumably yielding NumPy arrays -- they are used as such below).
masks = model.base.evaluate(model.readout.masks)
w = model.base.evaluate(model.readout.feature_weights)

# Scale every row of w to unit L2 norm.
w_norm = w / np.sqrt(np.square(w).sum(axis=1, keepdims=True))

# View each row as (num_rotations, num_filters) and sum the squared weights
# over the rotation axis, leaving one value per filter.
w_marg = np.square(
    w_norm.reshape(-1, num_rotations, num_filters)).sum(axis=1)

print(masks.shape, w.shape, w_marg.shape, sep='\n')

In [None]:
# Build a Trainer from the loaded model (arguments: model.base and the model
# itself), compute its validation correlations and print their mean.
# NOTE(review): r is presumably one correlation per unit -- the plotting cell
# below compares it element-wise against min_corr next to per-unit weights;
# confirm against Trainer.compute_val_corr.
trainer = Trainer(model.base, model)
r = trainer.compute_val_corr()
print(r.mean())

### Plot tuning curves

In [None]:
# Fetch every stored tuning curve, ordered by unit_id, and stack them into a
# single array.
# NOTE(review): the fetch is not restricted by `key`; if the table holds
# curves for more than one model they would all be mixed here -- confirm.
tc = np.array(list(
    SizeContrastTuning.Unit().fetch('tuning_curve', order_by='unit_id')))

# Stimulus sizes form a geometric series:
#   min_size * size_increment ** j   for j = 0 .. num_sizes - 1
ms, k, inc = SizeContrastTuningParams().fetch(
    'min_size', 'num_sizes', 'size_increment')
sizes = ms * (inc ** np.arange(k))

# convert to degrees
# masks.shape[2] is presumably the stimulus width in pixels, so the first
# factor is cm per pixel -- TODO confirm.
# NOTE(review): this uses arctan(s / d); the angle subtended by a centred
# stimulus of extent s is usually 2 * arctan(s / (2 * d)) -- confirm intent.
monitor_distance = 15.0  # cm
monitor_width = 55.0     # cm
sizes_deg = np.arctan(
    monitor_width / masks.shape[2] * sizes / monitor_distance
) / np.pi * 180

In [None]:
# Plot tuning curves grouped by dominant feature.
#
# Each unit is assigned to the filter with the largest marginalised weight
# (type_id). For every filter, units with validation correlation above
# min_corr are sorted by decreasing weight on that filter and up to
# n[0] * n[1] of them are drawn in a grid, one panel per unit. Within a panel,
# rows a..b-1 of the transposed tuning curve are drawn as separate lines
# (presumably one line per contrast level -- confirm against
# SizeContrastTuning), coloured along the reversed gist_earth colormap.
a, b = 5, 12
colors = plt.cm.gist_earth(np.linspace(0, 1, b-a))
colors = np.flipud(colors)

min_corr = 0.1
n = [4, 4]
type_id = np.argmax(np.abs(w_marg), axis=1)
unit_ids = []

# savefig does not create missing directories, so make sure the output
# folder exists before the first figure is written.
os.makedirs('figures', exist_ok=True)

for i in range(num_filters):
    idx, = np.where((type_id == i) & (r > min_corr))
    order = np.argsort(-w_marg[idx,i])
    fig, axes = plt.subplots(n[0], n[1], figsize=(n[1], n[0]))
    # Track the last populated panel of *this* figure. Relying on the loop
    # variable after the loop raises NameError when the first group is empty
    # and otherwise silently refers to a panel of the previous figure.
    last_ax = None
    for t, ax in zip(tc[idx[order]], axes.flatten()):
        for ti, ci in zip(t.T[a:b], colors):
            ax.plot(sizes_deg, ti, color=ci)
        ax.set_ylim([0, 1.1*t.max()])
        ax.axis('off')
        last_ax = ax
    # Re-enable the axis of the last populated panel only, so a single size
    # axis remains visible; its y ticks are removed.
    if last_ax is not None:
        last_ax.axis('on')
        last_ax.set_yticks([])
    sns.despine(fig=fig)
    fig.suptitle('Group {:d}'.format(i+1))
    fig.savefig('figures/size_contrast_{:d}.eps'.format(i+1), format='eps')