In [None]:
%pylab inline
import os
import sys
import numpy as np
import importlib

%cd ..
p = os.getcwd()
print("path:" + p)
if p not in sys.path:
    sys.path.append(p)
    
from cnn_sys_ident.data import Dataset, MonkeyDataset
from cnn_sys_ident.lnpsysid import LNP

### Load Data

In [None]:
# Dataset.get_clean_data() is a convenience wrapper. The manual pipeline it
# replaced in this notebook was:
#   Dataset.load_data -> Dataset.manage_repeats -> Dataset.preprocess_nans
#   -> Dataset.add_train_test_types(types_train='all', types_test='all')
# NOTE(review): confirm in cnn_sys_ident.data that the wrapper still performs
# exactly these steps before relying on this description.
data_dict = Dataset.get_clean_data()

In [None]:
# Wrap the cleaned data in a MonkeyDataset. The per-argument notes below are
# inferred from the parameter names — confirm against MonkeyDataset.
data = MonkeyDataset(
    data_dict,
    seed=1000,       # fixed seed, presumably for the train/validation split
    train_frac=0.8,  # presumably the fraction of non-test data used for training
    subsample=2,     # presumably a spatial subsampling factor
    crop=30,         # presumably pixels cropped from the image border
)

### Define the Model

In [None]:
# LNP model fit to `data` with a Poisson observation-noise model; training logs
# are organised under the 'monkey' log directory with hash 'lnp'.
model = LNP(
    data,
    obs_noise_model='poisson',
    log_dir='monkey',
    log_hash='lnp',
)

In [None]:
# Print where this run logs to, then the average response variance of each
# data split as a reference scale for the losses reported during training.
# Previously the training-set array was built but never used; it is now
# reported alongside the validation and test sets.
print('Log dir: %s' % model.log_hash)
_, test_responses = data.test_av()
_, val_responses, real_val_resps = data.val()
_, tr_responses, real_tr_resps = data.train()

# NOTE(review): data.nanarray presumably marks entries without a real
# recording as NaN (hence the NaN-aware statistics) — verify in MonkeyDataset.
val_array = data.nanarray(real_val_resps, val_responses)
tr_array = data.nanarray(real_tr_resps, tr_responses)

# Variance along axis 0 (presumably over samples, i.e. per neuron), then
# averaged, ignoring NaNs.
print('Average variances | training set: %f | validation set: %f | test set: %f' % (
    np.nanmean(np.nanvar(tr_array, axis=0)),
    np.nanmean(np.nanvar(val_array, axis=0)),
    np.nanmean(np.nanvar(test_responses, axis=0))))

### Build the Model

In [None]:
# Build the model with a smoothness and a sparsity regularization weight; the
# exact penalty forms are defined by LNP.build in cnn_sys_ident.lnpsysid.
model.build(sparse_reg_weight=0.01, smooth_reg_weight=0.5)


In [None]:
# Train in N_LR_ROUNDS rounds, dividing the learning rate by LR_DECAY_FACTOR
# between rounds (3e-4 -> 1e-4 -> ~3.3e-5). Each round iterates over
# model.train(...), which yields (step, (log-likelihood, total loss, MSE,
# predictions)) tuples that are printed as progress.
N_LR_ROUNDS = 3       # number of training rounds
LR_DECAY_FACTOR = 3   # learning rate is divided by this between rounds
learning_rate = 3e-4

for lr_round in range(N_LR_ROUNDS):
    if lr_round > 0:
        # Decay only *between* rounds: previously the rate was reduced (and
        # the message printed) after the final round, when no training follows.
        learning_rate /= LR_DECAY_FACTOR
        # %g keeps small rates readable (%f printed 1.1e-05 as 0.000011).
        print('Reducing learning rate to %g' % learning_rate)
    # NOTE(review): early_stopping_steps presumably counts validation checks
    # without improvement before stopping — confirm in LNP.train.
    training = model.train(max_iter=10000,
                           val_steps=100,
                           save_steps=1000,
                           early_stopping_steps=5,
                           learning_rate=learning_rate)
    for i, (logl, total_loss, mse, pred) in training:
        print('Step %d | Total loss: %s | %s: %s | MSE: %s | Var(y): %s'
              % (i, total_loss, model.obs_noise_model, logl, mse,
                 np.mean(np.var(pred, axis=0))))

print('Done fitting')

### Test Performance of the Model

In [None]:
# Evaluate on the test set; performance_test() populates model.eve, which is
# then averaged (presumably over neurons — verify) into a single score.
model.performance_test()
eve = model.eve.mean()
print('Explainable variance explained on test set: {score}'.format(score=eve))

In [None]:
# Same evaluation on the validation set; performance_val() populates
# model.eve_val, which is averaged into a single score.
model.performance_val()
eve_val = model.eve_val.mean()
print('Explainable variance explained on validation set: {score}'.format(score=eve_val))

In [None]:
# Average single-trial correlation between predictions and responses on the
# validation set, as computed by the model's own evaluation helper.
avg_correlation_valset = model.evaluate_avg_corr_val()
print('Mean single trial correlation on validation set: {corr}'.format(corr=avg_correlation_valset))