In [None]:
# Notebook configuration: auto-reload edited project modules on every cell
# execution and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import datajoint as dj
import tensorflow as tf  # NOTE(review): not referenced in the visible cells
import matplotlib.pyplot as plt
import numpy as np

import seaborn as sns
# 'ticks' axes style; the rc override makes 'bwr' the default colormap for
# imshow calls that do not pass an explicit cmap.
sns.set_style('ticks', rc={'image.cmap': 'bwr'})

import os
import sys
import inspect  # NOTE(review): not referenced in the visible cells

# Make the directory two levels above the notebook's working directory
# importable, so the `cnn_sys_ident` package resolves without installation.
# os.getcwd() is the pure-Python equivalent of the IPython `!pwd` shell call,
# and the descriptive name avoids clashing with loop variables later on.
repo_root = os.path.dirname(os.path.dirname(os.getcwd()))
if repo_root not in sys.path:  # keep the cell idempotent on re-run
    sys.path.append(repo_root)

In [None]:
from cnn_sys_ident.mesonet.data import MultiDataset
from cnn_sys_ident.mesonet.parameters import Core, Readout, Model, RegPath, Fit
from cnn_sys_ident.mesonet import MODELS
from cnn_sys_ident.mesonet.insilico import OptimalGabor, OrthPlaidsContrast, \
    OrthPlaidsContrastParams
from cnn_sys_ident.architectures.training import Trainer

In [None]:
# Restrict the MultiDataset table to the single dataset with this hash.
data_key = dict(data_hash='cfcd208495d565ef66e7dff9f98764da')
dataset = MultiDataset() & data_key

# Load the fitted model with the lowest validation loss

In [None]:
num_filters = 16
# Candidate models: HermiteSparse on the selected dataset, restricted by the
# flags below and to `num_filters` second-layer filters.
model_restriction = 'positive_feature_weights=False AND shared_biases=False'
model_rel = (MODELS['HermiteSparse'] * dataset
             & model_restriction
             & {'num_filters_2': num_filters})
# Primary key of the single fit with the smallest validation loss.
best_fit_keys = (Fit() * model_rel).fetch(dj.key, order_by='val_loss', limit=1)
key = best_fit_keys[0]
num_rotations = (model_rel & key).fetch1('num_rotations')
model = Fit().load_model(key)

In [None]:
# Evaluate the readout's spatial masks and feature weights to numpy arrays.
masks = model.base.evaluate(model.readout.masks)
w = model.base.evaluate(model.readout.feature_weights)

# Scale every row of w to unit L2 norm. Then view its columns as
# (num_rotations, num_filters) and sum the squared weights over rotations.
# This gives, per row, how the weight is spread across feature types.
# The rows presumably correspond to neurons (TODO confirm).
row_l2 = np.sqrt(np.sum(w ** 2, axis=1, keepdims=True))
w_norm = w / row_l2
w_marg = (w_norm.reshape([-1, num_rotations, num_filters]) ** 2).sum(axis=1)

print(masks.shape, w.shape, w_marg.shape, sep='\n')

In [None]:
# Validation correlations of the loaded model. r is later compared
# element-wise per neuron, so presumably one value per neuron. Its mean is
# printed as a single summary number.
trainer = Trainer(model.base, model)
r = trainer.compute_val_corr()
mean_val_corr = np.mean(r)
print(mean_val_corr)

### Show plaid stimuli (preferred-orientation Gabor + orthogonal Gabor)

In [None]:
# Build plaid stimuli around one example unit's optimal Gabor: a set of Gabors
# at the preferred orientation and a set rotated by 90 degrees, summed pairwise.
s = model.base.inputs.shape.as_list()
# Canvas as [s[2], s[1]]; assumes inputs are (batch, height, width, ...) so this
# is (width, height) -- TODO confirm.
canvas_size = [s[2], s[1]]

# Restrict to the units of the model loaded above and sort by primary key.
# A bare fetch may mix in units of other models and has no guaranteed row
# order, so a fixed index would not pick a reproducible unit.
example_idx = 3
example_unit_keys = (OptimalGabor.Unit() & key).fetch(dj.key, order_by='KEY')
example_key = example_unit_keys[example_idx]
loc, sz, sf, _, ori, ph = OptimalGabor.Unit().params(example_key)

g_pref = OrthPlaidsContrastParams().gabor_set(key, canvas_size, loc, sz, sf, ori, ph)
g_orth = OrthPlaidsContrastParams().gabor_set(key, canvas_size, loc, sz, sf, ori + np.pi / 2, ph)
comps_pref = g_pref.images()
comps_orth = g_orth.images()
# Broadcast sum: plaids[i, j] = comps_orth[i] + comps_pref[j].
plaids = comps_pref[None, ...] + comps_orth[:, None, ...]

In [None]:
# Grid of plaids: row i = orthogonal component i, column j = preferred
# component j. Only a fixed crop of each image is displayed.
crop_rows, crop_cols = slice(0, 30), slice(30, 60)
# squeeze=False keeps `axes` 2-D even if one plaid dimension has length 1,
# so the nested loop below always iterates rows of Axes.
fig, axes = plt.subplots(plaids.shape[0], plaids.shape[1], figsize=(20, 20),
                         squeeze=False)
for row_axes, plaid_row in zip(axes, plaids):
    for ax, plaid in zip(row_axes, plaid_row):
        ax.imshow(plaid[crop_rows, crop_cols], cmap='gray', vmin=-2, vmax=2)
        ax.axis('off')

### Plot plaid tuning curves, grouped by each neuron's dominant feature type

In [None]:
# Plaid tuning curves of the loaded model's units, ordered by unit_id.
# The `& key` restriction is essential. Without it, curves from every model in
# the table would be returned, and row indices would no longer match the
# neuron indices used with r / w_marg below.
tc = (OrthPlaidsContrast.Unit() & key).fetch('tuning_curve', order_by='unit_id')
tc = np.stack(tc)  # (num_units, ...) array; fails loudly if nothing was fetched
# Assumes unit_id enumerates the same neurons, in the same order, as the
# readout weight rows -- TODO confirm. At least the counts must agree.
assert tc.shape[0] == w_marg.shape[0], \
    'number of tuning curves does not match number of neurons'

contrasts = OrthPlaidsContrastParams().contrasts(key)

In [None]:
# Contrast-level indices and one colour per level, taken from gist_earth and
# reversed. The number of colours is tied to contrast_idx; the previous
# `b - a` referenced undefined names and raised NameError on a fresh kernel.
contrast_idx = np.arange(10)
colors = plt.cm.gist_earth(np.linspace(0, 1, len(contrast_idx)))
colors = np.flipud(colors)

min_corr = 0.1  # only show neurons whose validation correlation exceeds this
n = [4, 4]      # panel grid per group (rows, cols); at most n[0]*n[1] units shown
# Group each neuron by the feature type that carries most of its weight.
type_id = np.argmax(np.abs(w_marg), axis=1)
unit_ids = []   # NOTE(review): not populated here; kept in case later cells use it

os.makedirs('figures', exist_ok=True)  # savefig below fails if the dir is missing
for i in range(num_filters):
    # Well-predicted neurons of type i, strongest weight on type i first.
    idx, = np.where((type_id == i) & (r > min_corr))
    order = np.argsort(-w_marg[idx, i])
    fig, axes = plt.subplots(n[0], n[1], figsize=(n[1], n[0]), squeeze=False)
    panels = axes.flatten()
    for t, ax in zip(tc[idx[order]], panels):
        ax.imshow(t, vmin=0, vmax=1.1 * t.max(), extent=[0, 1, 1, 0])
        ax.axis('off')

    num_shown = min(len(idx), panels.size)
    # Hide the frames of panels that received no tuning curve.
    for unused in panels[num_shown:]:
        unused.axis('off')
    # Show ticks on the last filled panel only, as a scale reference. An
    # empty group is skipped; previously this styled a stale Axes from the
    # previous figure.
    if num_shown > 0:
        last_ax = panels[num_shown - 1]
        last_ax.axis('on')
        last_ax.yaxis.tick_right()
        last_ax.set_yticks([0, 1])

    fig.suptitle('Group {:d}'.format(i + 1))
    fig.savefig('figures/orth_plaids_{:d}.eps'.format(i + 1), format='eps')