In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import datajoint as dj
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

import seaborn as sns
sns.set_style('ticks', rc={'image.cmap': 'bwr'})

import os
import sys
import inspect

# Make the directory two levels above the current working directory
# importable (presumably the repository root holding `cnn_sys_ident`, which
# the next cell imports). os.getcwd() replaces the shell escape `!pwd`, which
# only works where a POSIX `pwd` command is available.
p = os.path.dirname(os.path.dirname(os.getcwd()))
if p not in sys.path:
    sys.path.append(p)

In [None]:
from cnn_sys_ident.mesonet.data import MultiDataset
from cnn_sys_ident.mesonet.parameters import Core, Readout, Model, RegPath, Fit
from cnn_sys_ident.mesonet import MODELS

In [None]:
DATA_HASH = 'cfcd208495d565ef66e7dff9f98764da'
data_key = dict(data_hash=DATA_HASH)
# Restrict the MultiDataset table to the entry matching this hash.
dataset = MultiDataset() & data_key

# Load the best model (lowest validation loss)

In [None]:
num_filters = 16
# Candidate models: HermiteSparse architecture on the selected dataset, with
# positive_feature_weights and shared_biases disabled, restricted to
# num_filters_2 == num_filters (presumably the second layer's filter count).
model_rel = (MODELS['HermiteSparse'] * dataset
             & 'positive_feature_weights=False AND shared_biases=False'
             & {'num_filters_2': num_filters})
# Pick the fitted model with the lowest validation loss.
best_keys = (Fit() * model_rel).fetch(dj.key, order_by='val_loss', limit=1)
if len(best_keys) == 0:
    # Fail with a readable message instead of a bare IndexError on [0].
    raise LookupError(
        'No fitted HermiteSparse model with num_filters_2={} found for the '
        'selected dataset'.format(num_filters))
key = best_keys[0]
num_rotations = (model_rel & key).fetch1('num_rotations')
model = Fit().load_model(key)

### Find approximate receptive field locations

In [None]:
# Readout masks, one spatial map per unit; the reductions over axes (1, 2)
# and the `ny, nx` unpacking below imply shape (units, ny, nx).
masks = model.base.evaluate(model.readout.masks)

# Sharpen each unit's mask with a spatial softmax at inverse temperature k.
# Subtracting the per-unit maximum before exponentiating leaves the result
# mathematically unchanged but keeps exp() from overflowing to inf (and the
# ratio from becoming NaN) once k * masks exceeds ~88 in float32 or ~709 in
# float64.
k = 50
logits = k * masks
logits = logits - logits.max(axis=(1, 2), keepdims=True)
m = np.exp(logits)
m /= m.sum(axis=(1, 2), keepdims=True)

# Overlay the column (x) and row (y) bounds of the receptive-field region on
# the per-pixel maximum over units; the bounds look chosen by eye from this
# figure.
ny, nx = m.shape[1:]
x = [27, 49]
y = [5, 28]
fig, ax = plt.subplots(figsize=(20, 10))
im = ax.imshow(m.max(axis=0), cmap='gray')
for xi in x:
    ax.plot([xi, xi], [0, ny - 1], 'w')
for yi in y:
    ax.plot([0, nx - 1], [yi, yi], 'w')
fig.colorbar(im, ax=ax)
plt.show()

## Generate set of Gabors

In [None]:
from cnn_sys_ident.utils.stimuli import GaborSet

In [None]:
# Parameter grid handed to GaborSet.
canvas_size = (64, 36)
# Same numbers as the x and y bounds drawn in the receptive-field figure
# (presumably x_min, x_max, y_min, y_max).
center_range = (27, 49, 5, 28)
sizes = 8 * np.power(1.3, np.arange(8))
spatial_frequencies = np.power(1.35, np.arange(-1, 3))
contrasts = np.power(2.0, np.arange(-5, 1))
num_orientations = 12
num_phases = 8
g = GaborSet(
    canvas_size,
    center_range,
    sizes,
    spatial_frequencies,
    contrasts,
    num_orientations,
    num_phases,
)

In [None]:
# Spot-check ten randomly chosen stimuli. A fixed seed makes the sample, and
# therefore the figures, reproducible on Restart & Run All.
rng = np.random.RandomState(0)
for idx in rng.randint(np.prod(g.num_stims), size=10):
    fig, ax = plt.subplots()
    im = ax.imshow(g.gabor_from_idx(idx), vmin=-1, vmax=1)
    ax.set_title('stimulus {}'.format(idx))
    fig.colorbar(im, ax=ax)
    plt.show()

# Populate database tables (Gabor parameters and optimal Gabors)

In [None]:
from cnn_sys_ident.mesonet.insilico import GaborParams, OptimalGabor

In [None]:
# Display the GaborParams table (DataJoint renders its contents as the cell output).
GaborParams()

In [None]:
# Compute every OptimalGabor entry whose key is not yet in the table
# (DataJoint populate); re-running only fills in what is missing.
OptimalGabor().populate()

# Size-contrast experiment

In [None]:
from cnn_sys_ident.mesonet.insilico import OptimalGabor, SizeContrastTuning, SizeContrastTuningParams

In [None]:
# Compute every SizeContrastTuning entry whose key is not yet in the table
# (DataJoint populate); re-running only fills in what is missing.
SizeContrastTuning().populate()

In [None]:
# Build the size/contrast Gabor set for the parameter entry in
# SizeContrastTuningParams (fetch1 requires exactly one matching entry).
tuning_params_key = SizeContrastTuningParams().fetch1(dj.key)
g = SizeContrastTuningParams().gabor_set(
    tuning_params_key,
    [64, 36],   # NOTE(review): presumably canvas size, as in the GaborSet above
    [12, 26],   # NOTE(review): presumably the stimulus centre — confirm
    1 / 8,      # NOTE(review): presumably spatial frequency
    np.pi / 4,  # NOTE(review): presumably orientation in radians
    0,          # NOTE(review): presumably phase
)

In [None]:
# Show every stimulus of the set on a 12-column grid. The row count is derived
# from the number of images, so nothing is silently dropped when the set is
# larger than 12 x 12 (for exactly 144 images the layout is the original one).
images = list(g.images())
n_cols = 12
n_rows = max(1, int(np.ceil(len(images) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, n_rows), squeeze=False)
for ax, img in zip(axes.flat, images):
    ax.matshow(img, cmap='gray', vmin=-1, vmax=1)
# Tick labels on this many small panels are only clutter; hide every axis,
# including unused trailing panels.
for ax in axes.flat:
    ax.axis('off')
plt.show()

### Plot tuning curves

In [None]:
# Fetch one tuning-curve array per unit and combine them into a single
# ndarray (a regular stack provided all curves share a shape).
tuning_curves = SizeContrastTuning.Unit().fetch('tuning_curve')
tc = np.array(list(tuning_curves))

In [None]:
# Grid of tuning curves for the first n*n units. Each column of a unit's array
# `t` is drawn as one line, coloured along a reversed gist_earth colormap.
# The number of colours comes from the data (assumes tc is regular with the
# curve index on its last axis). A hard-coded 12 made zip() silently drop any
# curves beyond the 12th.
k = tc.shape[-1]
colors = plt.cm.gist_earth(np.linspace(0, 1, k))
colors = np.flipud(colors)
n = 10
fig, axes = plt.subplots(n, n, figsize=(2*n, 2*n))
for ax, t in zip(axes.flatten(), tc):
    for ti, ci in zip(t.T, colors):
        ax.plot(ti, color=ci)
    # For a unit with no positive response, [0, 1.1*max] would be degenerate
    # (identical limits) or inverted; keep matplotlib's autoscaling then.
    top = 1.1 * t.max()
    if top > 0:
        ax.set_ylim([0, top])
    sns.despine(ax=ax)

# Orthogonal plaids

In [None]:
from cnn_sys_ident.mesonet.insilico import OrthPlaidsContrast, OrthPlaidsContrastParams