In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import tensorflow as tf, numpy as np, os, sys
p = !pwd
p = os.path.dirname(os.path.dirname(p[0]))
if p not in sys.path:
    sys.path.append(p)

In [None]:
# Plotting. The inline backend is already enabled by the `%matplotlib inline`
# magic in the setup cell, so it is not re-declared here.
from matplotlib import pyplot as plt

In [None]:
from cnn_sys_ident.architectures.models import BaseModel3D, CorePlusReadoutModel
from cnn_sys_ident.architectures.cores import MultiScanCore, StackedFactorizedConv3dCore
from cnn_sys_ident.architectures.readouts import MultiScanReadout, SpatialXFeature3dJointL1Readout
from cnn_sys_ident.architectures.training import Trainer
from cnn_sys_ident.retina.data import Dataset
from cnn_sys_ident.retina.read_data import MultiDatasetWrapper
from cnn_sys_ident.architectures.utils import mean_sq_err, poisson

# Parameters

In [None]:
# Training hyper-parameters; each one is forwarded by keyword to
# `trainer.fit(...)` in the model cell below.
LEARNING_RATE = 1e-3    # -> learning_rate
BATCH_SIZE = 128        # -> batch_size

# Validation / early-stopping / decay schedule knobs. Their exact meaning is
# defined by `Trainer.fit` (not visible here) -- TODO confirm units (steps vs. epochs).
VAL_STEPS = 50          # -> val_steps
PATIENCE = 5            # -> patience
LR_DECAY_STEPS = 2      # -> lr_decay_steps

# Linear-core model trained on movie stimuli

In [None]:
# Stimulus movie directory. Defaults to the original cluster path; set the
# RGC_STIM_PATH environment variable to run this notebook on another machine
# without editing code.
STIM_PATH = os.environ.get(
    'RGC_STIM_PATH',
    '/gpfs01/berens/user/cbehrens/RGC_DNN/Stimuli/MouseCam2/')
# Set to True to call `model.load()` before training. This replaces a
# commented-out `model.load()` line; what gets loaded is defined by the model class.
LOAD_CHECKPOINT = False

# --- Data: build the dataset for one recording ---------------------------
md_wrapper = MultiDatasetWrapper(experimenter="Franke",
                                 date="2018-10-19",
                                 exp_num=1,
                                 stim_path=STIM_PATH)
# NOTE(review): filtering semantics live in MultiDatasetWrapper.generate_dataset;
# presumably traces below quality 0.5 are dropped -- confirm.
md_wrapper.generate_dataset(filter_traces=True, quality_threshold=0.5)
data = md_wrapper.multi_dataset  # TODO: data loading

# --- Model: convolutional core + readout ---------------------------------
base = BaseModel3D(
    data,
    log_hash='testing'   # If you omit this, it will generate a new, random hash every time
)
core = MultiScanCore(
    base,
    base.data,
    base.inputs,
    core_type=StackedFactorizedConv3dCore,
    filter_size_spatial=[15],   # probably too large -- otherwise downsample movies
    filter_size_temporal=[20],
    num_filters=[8],
    activation_fn=['none'],     # presumably no activation, i.e. a linear core
    conv_smooth_weight_spatial=0.005,
    conv_smooth_weight_temporal=0.005
)
readout = MultiScanReadout(
    base,
    base.data,
    core.output,
    readout_type=SpatialXFeature3dJointL1Readout,
    mask_sparsity=0.03,     # TO DO: find good value
    feature_sparsity=0.001, # TO DO: find good value
    positive_feature_weights=True,
    nonlinearity=True
)
model = CorePlusReadoutModel(base, core, readout)
if LOAD_CHECKPOINT:
    model.load()

# --- Training and evaluation ---------------------------------------------
trainer = Trainer(base, model, error_fn=poisson)
iter_num, val_loss, test_corr = trainer.fit(
    val_steps=VAL_STEPS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    patience=PATIENCE,
    lr_decay_steps=LR_DECAY_STEPS)

# Training happens in `trainer.fit` above. As the cell's last expression, the
# value returned by `compute_test_corr()` is displayed as the cell output.
trainer.compute_test_corr()