In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import tensorflow as tf, numpy as np, os, sys
p = !pwd
p = os.path.dirname(os.path.dirname(p[0]))
if p not in sys.path:
    sys.path.append(p)

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
from cnn_sys_ident.architectures.models import BaseModel3D, CorePlusReadoutModel
from cnn_sys_ident.architectures.cores import MultiScanCore, StackedFactorizedConv3dCore
from cnn_sys_ident.architectures.readouts import MultiScanReadout, SpatialXFeature3dJointL1Readout
from cnn_sys_ident.architectures.training import Trainer
from cnn_sys_ident.retina.data import Dataset
from cnn_sys_ident.retina.read_data import MultiDatasetWrapper
from cnn_sys_ident.architectures.utils import mean_sq_err, poisson, crop_responses

# Parameters

In [None]:
# Training hyper-parameters (tune here rather than inside the model cell).
LEARNING_RATE = 1e-3      # initial learning rate for Trainer.fit
BATCH_SIZE = 32           # minibatch size for Trainer.fit
PATIENCE = 5              # Trainer.fit `patience` argument
LR_DECAY_STEPS = 3        # Trainer.fit `lr_decay_steps` argument
VAL_STEPS = 10            # intended value for Trainer.fit `val_steps`

In [None]:
def print_reg_loss(session, feed_dict):
    """Training callback: print regularisation terms and prediction variance.

    Reads the notebook-global ``model`` (built in the model cell below), so it
    must only be called once that cell has run.

    Args:
        session: session whose ``run`` evaluates the fetched tensors.
        feed_dict: feed dictionary forwarded unchanged to ``session.run``.
    """
    readouts = model.readout.readouts

    def summed(attr):
        # Sum one regulariser attribute over all per-scan readouts (0 if none).
        return sum(getattr(readout, attr) for readout in readouts)

    fetches = [
        model.core.cores[0].reg_loss,  # smoothness penalty of the first core only
        summed('mask_reg'),
        summed('feature_reg'),
        model.predictions,
    ]
    smooth_val, mask_val, feature_val, predictions = session.run(fetches, feed_dict)
    print('Kernel smoothness: {} | Mask L1: {} | Feature weights L1: {} | Var[prediction]: {}'.format(
        smooth_val, mask_val, feature_val, predictions.var()))

In [None]:
# Neurons selected for evaluation:
# indices 7, 30, 42 and 45 (out of the 572 neurons retained by the current filtering)

# Linear model processing movies

In [None]:
# Load the Franke 2018-10-19 recording (experiment 1) and build the datasets.
# The stimulus directory is an absolute cluster path; change it on other machines.
STIM_PATH = '/gpfs01/berens/user/cbehrens/RGC_DNN/Stimuli/MouseCam2/'
md_wrapper = MultiDatasetWrapper(experimenter="Franke",
                                 date="2018-10-19",
                                 exp_num=1,
                                 stim_path=STIM_PATH)
md_wrapper.generate_dataset(filter_traces=True,
                            preproc_param_set_id=1,
                            quality_threshold_movie=0.4)
data = md_wrapper.multi_dataset

base = BaseModel3D(
    data,
    log_hash='testing'   # If you omit this, it will generate a new, random hash every time
)
# Single-layer core with 8 filters (spatial size 15, temporal size 20) and
# activation 'none'; together with nonlinearity=False in the readout below
# the model is linear.
core = MultiScanCore(
    base,
    base.data,
    base.inputs,
    core_type=StackedFactorizedConv3dCore,
    filter_size_spatial=[15],   # probably too large -- otherwise downsample movies
    filter_size_temporal=[20],
    num_filters=[8],
    activation_fn=['none'],
    conv_smooth_weight_spatial=1,
    conv_smooth_weight_temporal=1,
)
readout = MultiScanReadout(
    base,
    base.data,
    core.output,
    readout_type=SpatialXFeature3dJointL1Readout,
    mask_sparsity=0.0005,     # TO DO: find good value
    feature_sparsity=0.01,    # TO DO: find good value #0.005
    positive_feature_weights=False,
    nonlinearity=False
)
model = CorePlusReadoutModel(base, core, readout)
# model.load()
trainer = Trainer(base, model, error_fn=mean_sq_err)
# Use the config-cell constants (VAL_STEPS was previously defined but ignored).
iter_num, val_loss, test_corr = trainer.fit(
    val_steps=VAL_STEPS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    patience=PATIENCE,
    lr_decay_steps=LR_DECAY_STEPS,
    callback=print_reg_loss)

trainer.compute_test_corr()

In [None]:
# Plot the learned spatial kernels of the first core layer on one shared colour
# scale. Axis 4 presumably indexes filters (num_filters in the model cell) and
# axis 3 input channels -- confirm. All channels of a filter sit next to each
# other in the flattened grid (panel index ch + n_channels * filt).
spatial_weights = model.base.evaluate(model.core.cores[0].weights_spatial)[0]
n_channels, n_filters = spatial_weights.shape[3], spatial_weights.shape[4]
vmin, vmax = np.min(spatial_weights), np.max(spatial_weights)  # computed once, not per panel
fig, ax = plt.subplots(n_channels, n_filters,
                       figsize=(2.5 * n_filters, 5 * n_channels), squeeze=False)
ax = ax.flatten()
for filt in range(n_filters):
    for ch in range(n_channels):
        panel = ax[ch + n_channels * filt]
        panel.imshow(spatial_weights[0, :, :, ch, filt], vmin=vmin, vmax=vmax)
        panel.set_title('filter {} / ch {}'.format(filt, ch))
# plt.savefig('../../../figures/spatial_kernels_8_filters_prepro1_08_08.png',dpi=300)

In [None]:
# Plot the learned temporal kernels of the first core layer (shared y-axis).
# Axis 0 is the kernel time axis; axes 3/4 presumably index input channels and
# filters -- confirm. Layout matches the spatial-kernel figure
# (panel index ch + n_channels * filt).
temporal_weights = model.base.evaluate(model.core.cores[0].weights_temporal)[0]
n_channels, n_filters = temporal_weights.shape[3], temporal_weights.shape[4]
fig, ax = plt.subplots(n_channels, n_filters,
                       figsize=(2.5 * n_filters, 2.5 * n_channels),
                       sharey=True, squeeze=False)
ax = ax.flatten()
for filt in range(n_filters):
    for ch in range(n_channels):
        panel = ax[ch + n_channels * filt]
        panel.plot(temporal_weights[:, 0, 0, ch, filt])
        panel.set_title('filter {} / ch {}'.format(filt, ch))
# plt.savefig('../../../figures/temporal_kernels_8_filters_prepro1_08_08.png',dpi=300)

In [None]:
# Plot the spatial readout masks of the first scan, one panel per neuron on a
# shared colour scale. At most 5 x 9 = 45 neurons are shown; unused panels are
# hidden so scans with fewer neurons no longer raise IndexError.
masks = model.base.evaluate(model.readout.readouts[0].masks)
n_rows, n_cols = 5, 9
n_shown = min(masks.shape[0], n_rows * n_cols)
vmin, vmax = np.min(masks), np.max(masks)  # computed once, not per panel
fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 20))
for k, panel in enumerate(ax.flat):  # row-major: k == j + n_cols * i
    if k < n_shown:
        panel.imshow(masks[k, :, :], vmin=vmin, vmax=vmax)
        panel.set_title('neuron {}'.format(k))
    else:
        panel.axis('off')
# plt.savefig('../../../figures/masks_8_filters_prepro1_08_08.png',dpi=300)

In [None]:
# Plot the readout feature weights of the first scan, one panel per neuron
# (shared y-axis). At most 5 x 9 = 45 neurons are shown; unused panels are
# hidden so scans with fewer neurons no longer raise IndexError.
feature_weights = model.base.evaluate(model.readout.readouts[0].feature_weights)
n_rows, n_cols = 5, 9
n_shown = min(feature_weights.shape[0], n_rows * n_cols)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 20), sharey=True)
for k, panel in enumerate(ax.flat):  # row-major: k == j + n_cols * i
    if k < n_shown:
        panel.plot(feature_weights[k, :])
        panel.set_title('neuron {}'.format(k))
    else:
        panel.axis('off')
# Distinct filename so saving this figure does not overwrite the masks figure.
# plt.savefig('../../../figures/feature_weights_8_filters_prepro1_08_08.png',dpi=300)

In [None]:
val_var_expl = trainer.compute_val_var_expl()  # variance explained, validation set
val_var_expl

In [None]:
test_var_expl = trainer.compute_test_var_expl()  # variance explained, test set
test_var_expl

In [None]:
test_corr_now = trainer.compute_test_corr()  # correlation on the test set
test_corr_now

In [None]:
val_corr = trainer.compute_val_corr()  # correlation on the validation set
val_corr

In [None]:
# Predict on the test set. The feed must use the responses just loaded:
# the previous `responses` name was never defined (NameError on a fresh kernel,
# or silently stale data otherwise).
inputs, test_responses = data.test()
feed_dict = {model.base.inputs: inputs,
             model.base.responses: test_responses,
             model.base.is_training: False}
test_predictions = model.base.evaluate(model.predictions, feed_dict)
# Presumably trims the recorded responses to match the predictions' extent
# (see cnn_sys_ident.architectures.utils.crop_responses) -- confirm.
test_responses = crop_responses(test_predictions, test_responses)

In [None]:
i = 7  # neuron index to inspect (one of the evaluation neurons noted above)
# Overlay recorded and predicted test traces for element 0 of the test arrays;
# assumes arrays are indexed (clip, time, neuron) -- confirm.
fig, ax_trace = plt.subplots(figsize=(12, 4))
ax_trace.plot(test_responses[0, :, i], label='response')
ax_trace.plot(test_predictions[0, :, i], label='prediction')
ax_trace.set(title='test[0], neuron {}'.format(i), xlabel='sample')
ax_trace.legend();

In [None]:
# Predict on the validation set. The feed must use the responses just loaded:
# the previous `responses` name was never defined (NameError on a fresh kernel,
# or silently stale data otherwise).
inputs, val_responses = data.val()
feed_dict = {model.base.inputs: inputs,
             model.base.responses: val_responses,
             model.base.is_training: False}
val_predictions = model.base.evaluate(model.predictions, feed_dict)
# Presumably trims the recorded responses to match the predictions' extent
# (see cnn_sys_ident.architectures.utils.crop_responses) -- confirm.
val_responses = crop_responses(val_predictions, val_responses)

In [None]:
# Overlay recorded and predicted validation traces of neuron `i` (set in the
# single-trace cell above) for up to 15 validation elements; panels beyond the
# number of available elements are hidden instead of raising IndexError.
fig, ax = plt.subplots(3, 5, figsize=(20, 10))
ax = ax.flatten()
n_shown = min(len(ax), val_responses.shape[0])
for k, panel in enumerate(ax):
    if k < n_shown:
        panel.plot(val_responses[k, :, i], label='response')
        panel.plot(val_predictions[k, :, i], label='prediction')
        panel.set_title('val[{}], neuron {}'.format(k, i))
    else:
        panel.axis('off')
ax[0].legend();