In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import tensorflow as tf, numpy as np, os, sys
# Put the directory two levels above the notebook's working directory on sys.path
# (presumably the repository root, so the `cnn_sys_ident` imports below resolve).
# os.getcwd() replaces the IPython-only `!pwd` shell escape, keeping this cell plain Python.
p = os.path.dirname(os.path.dirname(os.getcwd()))
if p not in sys.path:
    sys.path.append(p)

In [None]:
from matplotlib import pyplot as plt
import pandas as pd
%matplotlib inline

In [None]:
from cnn_sys_ident.architectures.models import BaseModel3D, CorePlusReadoutModel
from cnn_sys_ident.architectures.cores import MultiScanCore, StackedFactorizedConv3dCore
from cnn_sys_ident.architectures.readouts import MultiScanReadout, SpatialXFeature3dJointL1Readout, SpatialTransformerPooled3dReadout
from cnn_sys_ident.architectures.training import Trainer
from cnn_sys_ident.retina.data import Dataset
from cnn_sys_ident.retina.read_data import MultiDatasetWrapper
from cnn_sys_ident.architectures.utils import mean_sq_err, poisson, crop_responses

# Parameters

In [None]:
# Training hyper-parameters
# -- optimiser
LEARNING_RATE = 1e-3
LR_DECAY_STEPS = 3
BATCH_SIZE = 32
# -- validation / stopping
VAL_STEPS = 10
PATIENCE = 5

In [None]:
def print_reg_loss(session, feed_dict, model=None):
    """Print the regularisation terms and the prediction variance for one feed.

    Used as the ``callback`` argument of ``trainer.fit``.

    Args:
        session: object whose ``run(fetches, feed_dict)`` evaluates the fetched
            tensors (a TensorFlow session during training).
        feed_dict: feed dictionary forwarded unchanged to ``session.run``.
        model: model to inspect. Defaults to the notebook-global ``model`` so the
            existing two-argument callback call keeps working; pass it explicitly
            to avoid relying on notebook state.
    """
    if model is None:
        # Backward-compatible fallback to the notebook global of the same name.
        model = globals()['model']
    smooth_reg = model.core.cores[0].reg_loss
    # Sum the `feature_reg` terms of every sub-readout (same as the former loop).
    feature_reg = sum(readout.feature_reg for readout in model.readout.readouts)
    smooth_val, feature_val, pred = session.run(
        [smooth_reg, feature_reg, model.predictions], feed_dict)
    print('Kernel smoothness: {} | Feature weights L1: {} | Var[prediction]: {}'.format(
        smooth_val, feature_val, pred.var()))

In [None]:
# Neurons used for evaluation:
# indices 7, 30, 42 and 45 of the 572 neurons that pass the current filtering

# Linear model processing movies

16x16 filters with spatial padding work as well as 12x12 filters without padding.

In [None]:
# Build the dataset for one recording session, assemble a MultiScanCore /
# MultiScanReadout model, train it with the Poisson error function and finally
# display trainer.compute_test_corr().
# TODO: absolute cluster path -- make this configurable.
STIM_PATH = '/gpfs01/berens/user/cbehrens/RGC_DNN/Stimuli/MouseCam2/'
md_wrapper = MultiDatasetWrapper(experimenter="Franke",
                                 date="2018-10-19",
                                 exp_num=1,
                                 stim_path=STIM_PATH)
md_wrapper.generate_dataset(detrend_traces=True,
                            quality_threshold_movie=0.5,
                            quality_threshold_chirp=0.35,
                            detrend_param_set_id=2,
                            downsample_size=28)
data = md_wrapper.multi_dataset
base = BaseModel3D(
    data,
    log_hash='testing'
)
core = MultiScanCore(
    base,
    base.data,
    base.inputs,
    core_type=StackedFactorizedConv3dCore,
    filter_size_spatial=[12],
    filter_size_temporal=[20],
    num_filters=[8],
    activation_fn=['relu'],
    # nonzero_padding=True,
    # padding_constant= np.min(data.movie_train),
    conv_smooth_weight_spatial=0.5,
    conv_smooth_weight_temporal=0.5,
)
readout = MultiScanReadout(
    base,
    base.data,
    core.output,
    readout_type=SpatialTransformerPooled3dReadout,
    pool_steps=4,
    init_range=.1,
    # mask_sparsity=0.005,
    feature_sparsity=0.001,
    positive_feature_weights=True,
    nonlinearity=True
)
model = CorePlusReadoutModel(base, core, readout)
# model.load()
trainer = Trainer(base, model, error_fn=poisson)
iter_num, val_loss, test_corr = trainer.fit(
    val_steps=VAL_STEPS,  # use the config-cell constant rather than a literal 10
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    patience=PATIENCE,
    lr_decay_steps=LR_DECAY_STEPS,
    callback=print_reg_loss)

trainer.compute_test_corr()

In [None]:
# plot spatial kernels: one panel per (last-two-axis) index pair, shared colour scale.
spatial_weights = model.base.evaluate(model.core.cores[0].weights_spatial)[0]
# The last two axes are presumably (input channel, filter); derive their sizes from
# the weights instead of hard-coding a 2 x 8 grid.
n_channels, n_filters = spatial_weights.shape[-2:]
# Colour limits are loop-invariant: compute them once.
vmin, vmax = np.min(spatial_weights), np.max(spatial_weights)
fig, ax = plt.subplots(n_channels, n_filters, figsize=(2.5 * n_filters, 5 * n_channels))
ax = np.asarray(ax).flatten()
for i in range(n_channels):
    for j in range(n_filters):
        # Panels for the same filter index j are placed next to each other.
        ax[i + j * n_channels].imshow(spatial_weights[0, :, :, i, j], vmin=vmin, vmax=vmax)
# plt.savefig('../../../figures/spatial_kernels_8_filters_prepro2_09_11_padding.png',dpi=300)

In [None]:
# plot temporal kernels: one panel per (last-two-axis) index pair.
temporal_weights = model.base.evaluate(model.core.cores[0].weights_temporal)[0]
# Derive the grid from the weight shape. The former hard-coded range(4) plotted only
# the first 4 filters although the core is built with num_filters=[8].
n_channels, n_filters = temporal_weights.shape[-2:]
fig, ax = plt.subplots(n_channels, n_filters,
                       figsize=(2.5 * n_filters, 2.5 * n_channels), sharey=True)
ax = np.asarray(ax).flatten()
for i in range(n_channels):
    for j in range(n_filters):
        ax[i + j * n_channels].plot(temporal_weights[:, 0, 0, i, j])
# plt.savefig('../../../figures/temporal_kernels_8_filters_prepro2_09_11_padding.png',dpi=300)

In [None]:
# plot masks for first scan: one panel per entry along the first axis of `masks`
# (presumably one per neuron -- TODO confirm), shared colour scale, at most 4 x 6 panels.
# NOTE(review): the readout is built with SpatialTransformerPooled3dReadout; confirm it
# exposes a `masks` attribute.
masks = model.base.evaluate(model.readout.readouts[0].masks)
vmin, vmax = np.min(masks), np.max(masks)
fig, ax = plt.subplots(4, 6, figsize=(20, 15))
# Row-major panel order (same as index j + 6*i); surplus panels are hidden instead of
# raising IndexError when there are fewer than 24 masks.
for idx, panel in enumerate(ax.flat):
    if idx < len(masks):
        panel.imshow(masks[idx, :, :], vmin=vmin, vmax=vmax)
    else:
        panel.axis('off')
# plt.savefig('../../../figures/masks_8_filters_prepro2_09_11_padding.png',dpi=300)

In [None]:
# plot feature weights for first scan: one panel per row of `feature_weights`
# (presumably one per neuron -- TODO confirm), at most 4 x 6 panels.
feature_weights = model.base.evaluate(model.readout.readouts[0].feature_weights)
fig, ax = plt.subplots(4, 6, figsize=(20, 15), sharey=True)
# Row-major panel order (same as index j + 6*i); surplus panels are hidden instead of
# raising IndexError when there are fewer than 24 rows.
for idx, panel in enumerate(ax.flat):
    if idx < len(feature_weights):
        panel.plot(feature_weights[idx, :])
    else:
        panel.axis('off')
# plt.savefig('../../../figures/feature_weights_8_filters_prepro2_09_11_padding.png',dpi=300)

In [None]:
# plot readout grid positions, one panel per sub-readout (at most 10).
fig, ax = plt.subplots(2, 5, figsize=(20, 7), sharey=True)
ax = ax.flatten()
n_scans = min(len(ax), len(model.readout.readouts))
for i, panel in enumerate(ax):
    if i >= n_scans:
        panel.axis('off')
        continue
    # Scale factor 17 * 25 / 2: presumably converts normalised grid coordinates to um
    # (the figure file name ends in "_um") -- TODO confirm the constants.
    grid = model.base.evaluate(model.readout.readouts[i].grid) / 2 * 17 * 25
    panel.scatter(grid[0, :, 0], grid[0, :, 1])
    panel.set(xlim=(-40, 70), ylim=(-40, 70))
# plt.savefig('../../../figures/transformer_test_um.png',dpi=300)

In [None]:
# Validation-set explained variance, as reported by Trainer.compute_val_var_expl.
trainer.compute_val_var_expl()

In [None]:
# Test-set explained variance, as reported by Trainer.compute_test_var_expl.
trainer.compute_test_var_expl()

In [None]:
# Test-set correlation (the same call that ends the training cell).
trainer.compute_test_corr()

In [None]:
# Validation-set correlation, as reported by Trainer.compute_val_corr.
trainer.compute_val_corr()

In [None]:
# Non-averaged test correlations (average=False); presumably one value per neuron.
correlations=trainer.compute_test_corr(average=False)

In [None]:
# Indices ordered by ascending test correlation (np.argsort), i.e. worst neuron first.
corr_ranked_index=np.argsort(correlations)

In [None]:
# Indices of neurons whose test correlation lies strictly between 0.3 and 0.6.
in_mid_band = np.logical_and(correlations > 0.3, correlations < 0.6)
corr_mid_index = np.where(in_mid_band)[0]

In [None]:
# plot 10 cells with worst correlations
# Columns per row: recorded response (black) with prediction (red) overlaid,
# prediction alone, and the high-pass-filtered mean movie trace.
# NOTE: test_responses, test_predictions and flat_movie_green_avg_filtered are
# defined in cells further down; run those before this cell.
fig, ax = plt.subplots(10, 3, figsize=(20, 20))
for row in range(10):
    neuron = corr_ranked_index[row]  # ascending ranking -> lowest correlations first
    overlay_ax, pred_ax, movie_ax = ax[row]
    overlay_ax.plot(test_responses[0, :, neuron], color='black')
    overlay_ax.plot(test_predictions[0, :, neuron], color='red')
    pred_ax.plot(test_predictions[0, :, neuron], color='red')
    movie_ax.plot(flat_movie_green_avg_filtered)
# plt.savefig('../../../figures/8_filters_prepro2_09_11_padding_test_worst.png',dpi=300)

In [None]:
# plot up to 10 cells with medium correlations (0.3 < r < 0.6)
# Columns per row: recorded response (black) with prediction (red) overlaid,
# prediction alone, and the high-pass-filtered mean movie trace.
# NOTE: test_responses, test_predictions and flat_movie_green_avg_filtered are
# defined in cells further down; run those before this cell.
# Fewer than 10 neurons may fall in the band; indexing past the end used to raise IndexError.
n_mid = min(10, len(corr_mid_index))
if n_mid == 0:
    print('No neurons with 0.3 < correlation < 0.6 to plot.')
else:
    # squeeze=False keeps `ax` two-dimensional even for a single row.
    fig, ax = plt.subplots(n_mid, 3, figsize=(20, 2 * n_mid), squeeze=False)
    for k in range(n_mid):
        ind = corr_mid_index[k]
        ax[k, 0].plot(test_responses[0, :, ind], color='black')
        ax[k, 0].plot(test_predictions[0, :, ind], color='red')
        ax[k, 1].plot(test_predictions[0, :, ind], color='red')
        ax[k, 2].plot(flat_movie_green_avg_filtered)
# plt.savefig('../../../figures/8_filters_prepro2_09_11_padding_test_mid.png',dpi=300)

In [None]:
# plot 10 cells with best correlations
# Columns per row: recorded response (black) with prediction (red) overlaid,
# prediction alone, and the high-pass-filtered mean movie trace.
# NOTE: test_responses, test_predictions and flat_movie_green_avg_filtered are
# defined in cells further down; run those before this cell.
best_first = corr_ranked_index[::-1]  # reverse the ascending ranking -> highest first
fig, ax = plt.subplots(10, 3, figsize=(20, 20))
for row in range(10):
    neuron = best_first[row]
    overlay_ax, pred_ax, movie_ax = ax[row]
    overlay_ax.plot(test_responses[0, :, neuron], color='black')
    overlay_ax.plot(test_predictions[0, :, neuron], color='red')
    pred_ax.plot(test_predictions[0, :, neuron], color='red')
    movie_ax.plot(flat_movie_green_avg_filtered)
# plt.savefig('../../../figures/8_filters_prepro2_09_11_padding_test_best.png',dpi=300)

In [None]:
from scipy import signal
def butter_highpass(cutoff, fs, order=5):
    """Design a digital Butterworth high-pass filter.

    Parameters
    ----------
    cutoff : float
        Cut-off frequency, in the same units as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Filter order (default 5).

    Returns
    -------
    tuple
        ``(b, a)`` numerator / denominator coefficients from ``scipy.signal.butter``.
    """
    nyquist = fs * 0.5
    return signal.butter(order, cutoff / nyquist, btype='high',
                         analog=False, output='ba')
def butter_highpass_filter(data, cutoff, fs, order=5, axis=0):
    """High-pass filter ``data`` along ``axis`` without phase shift.

    The filter is designed by ``butter_highpass`` and applied forwards and
    backwards with ``scipy.signal.filtfilt``.

    Parameters
    ----------
    data : array_like
        Signal to filter.
    cutoff, fs : float
        Cut-off and sampling frequency (same units).
    order : int, optional
        Butterworth filter order (default 5).
    axis : int, optional
        Axis of ``data`` to filter along (default 0).

    Returns
    -------
    numpy.ndarray
        The filtered signal, same shape as ``data``.
    """
    coeffs = butter_highpass(cutoff, fs, order=order)
    return signal.filtfilt(*coeffs, data, axis=axis)

In [None]:
# Reference stimulus trace for the per-neuron plots: take data.movie_test[0, 0, ..., 0],
# average it over all pixels of each frame and high-pass filter the result.
MOVIE_FS = 30                # sampling rate given to the filter; presumably frame rate in Hz
MOVIE_HIGHPASS_CUTOFF = 0.1  # filter cut-off, same units as MOVIE_FS
movie_green = data.movie_test[0, 0, :, :, :, 0]  # assumes (frames, height, width) -- TODO confirm
# Derive the shape from the data instead of hard-coding (750, 28*28), which breaks for a
# different movie length or downsample_size.
flat_movie_green = movie_green.reshape(movie_green.shape[0], -1)
flat_movie_green_avg = np.average(flat_movie_green, axis=1)
flat_movie_green_avg_filtered = butter_highpass_filter(
    flat_movie_green_avg, cutoff=MOVIE_HIGHPASS_CUTOFF, fs=MOVIE_FS, order=2)
# plt.plot(flat_movie_green_avg_filtered)

In [None]:
# Evaluate the model on the test set with is_training=False, then pass the recorded
# responses through crop_responses together with the predictions
# (presumably trimming them to the prediction length -- TODO confirm).
inputs, test_responses = data.test()
feed_dict = {
    model.base.inputs: inputs,
    model.base.responses: test_responses,
    model.base.is_training: False,
}
test_predictions = model.base.evaluate(model.predictions, feed_dict)
test_responses = crop_responses(test_predictions, test_responses)

In [None]:
# Compare recorded and predicted test responses of the neuron with the highest test
# correlation. `i` is also read by the validation plot at the end of the notebook.
i = corr_ranked_index[-1]  # alternatively inspect a fixed neuron, e.g. i = 7
fig, ax = plt.subplots()
ax.plot(test_responses[0, :, i], label='recorded')
ax.plot(test_predictions[0, :, i], label='predicted')
ax.set(title='neuron {}'.format(i))
ax.legend();

In [None]:
# Evaluate the model on the validation set with is_training=False, then pass the
# recorded responses through crop_responses together with the predictions
# (presumably trimming them to the prediction length -- TODO confirm).
inputs, val_responses = data.val()
feed_dict = {
    model.base.inputs: inputs,
    model.base.responses: val_responses,
    model.base.is_training: False,
}
val_predictions = model.base.evaluate(model.predictions, feed_dict)
val_responses = crop_responses(val_predictions, val_responses)

In [None]:
# Recorded vs. predicted validation responses of neuron `i`, one panel per entry along
# the first axis of the arrays (at most 15). Surplus panels are hidden instead of
# raising IndexError when there are fewer than 15 entries.
n_val_panels = min(15, len(val_responses), len(val_predictions))
fig, ax = plt.subplots(3, 5, figsize=(20, 10))
ax = ax.flatten()
for k, panel in enumerate(ax):
    if k < n_val_panels:
        panel.plot(val_responses[k, :, i])
        panel.plot(val_predictions[k, :, i])
    else:
        panel.axis('off')