Split each protocol's data into training, validation and test sets, then augment the data and train the model.

In [None]:
# Load libraries and modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from modelval import pairptl, network, trainer, dataset, data_aug_knn, perform_eval
from modelval.ArbDataGen import arb_w_gen
from modelval.spk_visu import spk_see, raster
from modelval import gp_regressor
from modelval import data_aug_gp
from sklearn.model_selection import train_test_split

import warnings
# Silence every Python warning for the whole session (broad: this also hides
# library deprecation messages).
warnings.filterwarnings('ignore')

# IPython magics: render figures inline and reload imported modules (e.g. modelval)
# automatically before executing code, so edits to the package are picked up.
% matplotlib inline
% load_ext autoreload
% autoreload 2

In [None]:
# Read the kernel-training table from disk and preview its first rows
csv_path = '/src/Plasticity_Ker/data/kernel_training_data_auto.csv'
data = pd.read_csv(csv_path)
data.head()

In [None]:
# Obtain GP-augmented data for the STDP protocol
params = {'bias': 7.8571428571428577,
          'sigma_kernel': 0.54999999999999993,
          'sigma_obs': 0.90000000000000002,
          'if_stat_kernel': False}
x_stdp, f_stdp, x_stdp_test, y_stdp_test = data_aug_gp.stdp_gp(**params)

# Hold out 20% of the augmented points for validation. The seed is fixed so the
# split, and everything trained on it, is reproducible across runs.
x_stdp_train, x_stdp_vali, y_stdp_train, y_stdp_vali = train_test_split(
    x_stdp, f_stdp, test_size=0.2, random_state=0)

In [None]:
# Show how the STDP points are distributed across the train / validation / test splits
stdp_splits = [(x_stdp_train, y_stdp_train, 'train_data'),
               (x_stdp_vali, y_stdp_vali, 'vali_data'),
               (x_stdp_test, y_stdp_test, 'test_data')]
for x_split, y_split, split_label in stdp_splits:
    plt.plot(x_split, y_split, 'o', label=split_label)
plt.legend()

In [None]:
# Put the spike-timing (dt) information of each STDP split into a protocol data frame.
# Only the test inputs are reshaped into a column; presumably stdp_gp returns them
# 1-D while the split arrays are already 2-D — TODO confirm.
data_stdp_train, data_stdp_vali, data_stdp_test = [
    data_aug_gp.STDP_dw_gen(x_split)
    for x_split in (x_stdp_train, x_stdp_vali, x_stdp_test.reshape(-1, 1))
]

In [None]:
# Sanity check: each STDP split's data frame and target vector are fed together below
(data_stdp_train.shape, y_stdp_train.shape,
 data_stdp_vali.shape, y_stdp_vali.shape)

In [None]:
# # Generate data for STDP
# data3 = data[data['ptl_idx']==3]

# # Split into training and testing set (80%, 20%)
# # Create train/vali and test data frame
# np.random.seed(0)

# x_train, x_test, y_train, y_test = train_test_split(data3['dt2'], data3['dw_mean'],test_size=0.2, random_state=0)
# plt.plot(x_train, y_train,'o')

In [None]:
# x_train_quad, x_test_quad, y_train_quad, y_test_quad = train_test_split(data3['dt2'].values, data3['dw_mean'].values, test_size=0.2, random_state=0)

In [None]:
# x_r = x_train_quad[np.where(x_train_quad>0)[0]].reshape(-1,1)
# y_r = y_train_quad[np.where(x_train_quad>0)[0]].reshape(-1,1)
# x_test_r = np.linspace(np.min(x_r),120,120).reshape(-1,1)

# x_l = x_train_quad[np.where(x_train_quad<0)[0]].reshape(-1,1)
# y_l = y_train_quad[np.where(x_train_quad<0)[0]].reshape(-1,1)
# x_test_l = np.linspace(-120,np.max(x_l),120).reshape(-1,1)

In [None]:
# gp_rg = gp_regressor.GP_regressor(x_r, y_r, x_test_r,sigma_kernel=1.9, scale=5, bias=2.78, sigma_obs=3.0, noise_const=96.7, if_stat_kernel=False, if_stat_noise=False)
# f_r, v_f_r, lp = gp_rg.fit()
# std = np.sqrt(v_f_r.transpose().diagonal()).reshape(-1,1)
# plt.plot(x_r, y_r,'o', label='Raw_data_train')
# plt_range1 = np.arange(0,112,1)

# plt.plot(x_test_r, f_r,'g')
# plt.fill_between(np.squeeze(x_test_r), np.squeeze(f_r-1.96*std), np.squeeze(f_r+1.96*std), alpha=1, color='deeppink')

# gp_rg = gp_regressor.GP_regressor(x_l, y_l, x_test_l,sigma_kernel=1.9, scale=5, bias=2.78, sigma_obs=3.0, noise_const=96.7, if_stat_kernel=False, if_stat_noise=False)
# f_l, v_f_l, lp = gp_rg.fit()
# std = np.sqrt(v_f_l.transpose().diagonal()).reshape(-1,1)
# plt.plot(x_l, y_l,'o', label='Raw_data_train')
# plt.plot(x_test_l, f_l,'g')
# plt.fill_between(np.squeeze(x_test_l), np.squeeze(f_l-1.96*std), np.squeeze(f_l+1.96*std), alpha=1, color='deeppink')
# #plt.fill_between(np.squeeze(x_aug), np.squeeze(f-1.96*std), np.squeeze(f+1.96*std), alpha=1, color='deeppink', label="95% confidence interval")

# # Sample from the gp regression
# for i in range(len(f_l)):
#     np.random.seed(i)
#     scale = 5 * np.exp(-1 * np.abs(x_test_l[i])/96.7)
#     noise = np.random.normal(loc=0, scale=scale, size=1)
#     f_l[i] = f_l[i] + noise

# for i in range(len(f_r)):
#     np.random.seed(i)
#     scale = 5 * np.exp(-1 * np.abs(x_test_r[i])/96.7)
#     noise = np.random.normal(loc=0, scale=scale, size=1)
#     f_r[i] = f_r[i] + noise

# plt.plot(x_test_r, f_r, 'ro', label='Sampled data')
# plt.plot(x_test_l, f_l, 'ro', label='Sampled data')

# plt.legend(loc='upper left')

In [None]:
# Obtain GP-augmented data for the quadruplet protocol
params = {'bias': 6.4285714285714288,
          'sigma_kernel': 0.84999999999999998,
          'sigma_obs': 0.79999999999999993,
          'if_stat_kernel': False}

x_train_r, y_train_r, x_train_l, y_train_l, x_quad_test, y_quad_test = data_aug_gp.quad_gp(**params)

# Merge the two returned branches (_l and _r) into one pool of augmented points
x_quad = np.concatenate([x_train_l, x_train_r])
y_quad = np.concatenate([y_train_l, y_train_r])

# Hold out 20% for validation; the seed is fixed so the split is reproducible
x_quad_train, x_quad_vali, y_quad_train, y_quad_vali = train_test_split(
    x_quad, y_quad, test_size=0.2, random_state=0)

plt.plot(x_quad_train, y_quad_train, 'o', label='train_data')
plt.plot(x_quad_vali, y_quad_vali, 'o', label='vali_data')
plt.plot(x_quad_test, y_quad_test, 'o', label='test_data')
plt.legend()

In [None]:
# Put the dt information of each quadruplet split into a protocol data frame
# (the test inputs are reshaped into a column, as for STDP)
data_quad_train, data_quad_vali, data_quad_test = [
    data_aug_gp.quad_dw_gen(x_split)
    for x_split in (x_quad_train, x_quad_vali, x_quad_test.reshape(-1, 1))
]

In [None]:
# Sanity check: quadruplet data frames vs. target vectors for train and validation
(data_quad_train.shape, y_quad_train.shape,
 data_quad_vali.shape, y_quad_vali.shape)

In [None]:
# Stack STDP rows on top of quadruplet rows for every split. Targets are concatenated
# in the same order so they stay aligned with the data frames.
# NOTE: pd.concat keeps the original index labels, so index values may repeat.
data_gen_train = pd.concat([data_stdp_train, data_quad_train], axis=0)
data_gen_vali = pd.concat([data_stdp_vali, data_quad_vali], axis=0)
data_gen_test = pd.concat([data_stdp_test, data_quad_test], axis=0)

y_train = np.concatenate([y_stdp_train, y_quad_train])
y_vali = np.concatenate([y_stdp_vali, y_quad_vali])
y_test = np.concatenate([y_stdp_test, y_quad_test])

In [None]:
# Sizes of the combined training and validation sets
(data_gen_train.shape, y_train.shape,
 data_gen_vali.shape, y_vali.shape)

In [None]:
# Load the fitted triplet-model parameters and reorder the columns to match the
# model's parameter order: the four amplitudes (A*) first, then the time constants (Tau*).
# (The former bare `trip_para` line mid-cell was a no-op: a notebook cell only
# displays its last expression.)
trip_para = pd.read_pickle('/src/Plasticity_Ker/data/Gerstner_trip_para_df')
trip_para = trip_para[['A2_+', 'A3_-', 'A2_-', 'A3_+', 'Tau_+', 'Tau_x', 'Tau_-', 'Tau_y']]
trip_para

In [None]:
# Build the reference kernel from the 'Hippo_AlltoAll' / 'Full' triplet-model parameters.
# The original cell constructed a throw-away KernelGen() and several unused locals
# (a, tau, reso_set, tau_pre_post, tau_post_pre); those are dropped.
from modelval.kernel import KernelGen

para = trip_para.loc[('Hippo_AlltoAll', 'Full'), :]

ker_test = KernelGen(len_kernel=101)
ker_test.trip_model_ker(para, data_name='Hippocampus')

In [None]:
# Generate spike trains and targets for the STDP (ptl 1) and quadruplet (ptl 3) protocols,
# for both the training and the validation split.
# NOTE(review): spk_len is sized from the ptl 3 rows of `data` only, whereas the prediction
# cells below size it from all of `data` — confirm this is intended.
data3 = data[data['ptl_idx'] == 3]
ptl_list = [1, 3]
spk_len = int(data3['train_len'].max() * 1000 / ker_test.reso_kernel)
if_noise = 0
aug_times = [1, 1]
gen_kwargs = dict(ptl_list=ptl_list, if_noise=if_noise, spk_len=spk_len,
                  kernel=ker_test, net_type='triplet', aug_times=aug_times)
spk_pairs_train, targets_train = arb_w_gen(df=data_gen_train, targets=y_train, **gen_kwargs)
spk_pairs_vali, targets_vali = arb_w_gen(df=data_gen_vali, targets=y_vali, **gen_kwargs)

In [None]:
# Build the triplet network on top of the reference kernel; its input width is
# taken from the generated spike-pair matrix
reg_scale = (1, 1)
ground_truth_init = 0
toy_data_net = network.TripNet(
    kernel=ker_test,
    ground_truth_init=ground_truth_init,
    reg_scale=reg_scale,
    n_input=spk_pairs_train.shape[1],
)

In [None]:
# Create the trainer; checkpoints are written to save_dir. No optimizer_op is passed,
# so whatever Trainer uses by default applies
# (alternative kept for reference: tf.train.GradientDescentOptimizer).
save_dir = '/src/Plasticity_Ker/model/Pair_ptl1_3_real_aug_gp_no_noise'
toy_net_trainer = trainer.Trainer(
    toy_data_net.loss,
    toy_data_net.loss,
    input_name=toy_data_net.inputs,
    target_name=toy_data_net.target,
    save_dir=save_dir,
    optimizer_config={'learning_rate': toy_data_net.lr},
)

In [None]:
# Wrap the spike-pair inputs and their targets for the trainer
train_data, vali_data = (dataset.Dataset(spk_pairs_train, targets_train),
                         dataset.Dataset(spk_pairs_vali, targets_vali))

In [None]:
# Inspect the kernels and read-out parameters before training
w_pre, w_post, w_post_post, fc_w, bias = [
    toy_net_trainer.evaluate(ops=op)
    for op in (toy_data_net.kernel_pre, toy_data_net.kernel_post,
               toy_data_net.kernel_post_post, toy_data_net.fc_w, toy_data_net.bias)
]
for curve, curve_label in ((w_pre, 'ker_pre_init'),
                           (w_post, 'ker_post_init'),
                           (w_post_post, 'ker_post_post_init')):
    plt.plot(curve, label=curve_label)

plt.legend()
print(fc_w, bias)

In [None]:
# Train in several rounds, dividing the learning rate by 3 after each round
# (starting weights depend on TripNet's ground_truth_init=0 setting)
learning_rate = 0.001
iterations = 5
min_error = -1
for _round in range(iterations):
    toy_net_trainer.train(train_data, vali_data, batch_size=128, min_error=min_error,
                          feed_dict={toy_data_net.lr: learning_rate})
    learning_rate /= 3

In [None]:
# Restore the best checkpoint, then read back and plot the learned kernels
toy_net_trainer.restore_best()
w_pre, w_post, w_post_post, fc_w, bias = [
    toy_net_trainer.evaluate(ops=op)
    for op in (toy_data_net.kernel_pre, toy_data_net.kernel_post,
               toy_data_net.kernel_post_post, toy_data_net.fc_w, toy_data_net.bias)
]
# NOTE(review): the pre/post kernels are plotted sign-flipped but the post-post kernel
# is not, while the second experiment flips all three — confirm the intended convention.
plt.plot(-1 * w_pre, label='ker_pre_trained')
plt.plot(-1 * w_post, label='ker_post_trained')
plt.plot(w_post_post, label='ker_post_post_trained')

plt.legend()
print(fc_w, bias)

In [None]:
# Copy the trained parameters into the kernel object used for prediction
trained_params = (('kernel_pre', w_pre),
                  ('kernel_post', w_post),
                  ('kernel_post_post', w_post_post),
                  ('kernel_scale', fc_w),
                  ('bias', bias))
for attr_name, attr_value in trained_params:
    setattr(ker_test, attr_name, attr_value)

In [None]:
# Predict the training targets with the trained kernel and plot them against the
# targets, STDP and quadruplet samples in separate colours.
spk_len = int(data['train_len'].max() * 1000 / ker_test.reso_kernel)
_, predictions = arb_w_gen(spk_pairs=spk_pairs_train, spk_len=spk_len, kernel=ker_test, net_type='triplet')

# data_gen_train stacks STDP rows before quadruplet rows, so derive the split point from
# the data instead of hard-coding 160. Assumes arb_w_gen keeps row order and emits
# aug_times[0] samples per STDP row — TODO confirm.
n_stdp = len(y_stdp_train) * aug_times[0]
plt.plot(targets_train[:n_stdp, :], predictions[:n_stdp, :], 'o', label='stdp')
plt.plot(targets_train[n_stdp:, :], predictions[n_stdp:, :], 'o', label='quadruplet')

# Identity line: perfect predictions lie on it
plt.plot(np.linspace(-30, 60, 90), np.linspace(-30, 60, 90), '--')
plt.legend()

In [None]:
# Predict the validation targets with the trained kernel and compare to the identity line
spk_len = int(data['train_len'].max() * 1000 / ker_test.reso_kernel)
spk_pairs, predictions = arb_w_gen(spk_pairs=spk_pairs_vali, spk_len=spk_len,
                                   kernel=ker_test, net_type='triplet')
diag = np.linspace(-30, 60, 90)
plt.plot(targets_vali, predictions, 'o')
plt.plot(diag, diag, '--')

Evaluate the test results

In [None]:
# Generate spike trains and targets for the STDP + quadruplet test split
# (if_noise=0, aug_times=[1, 1]); targets come from y_test directly
data3 = data[data['ptl_idx'] == 3]
ptl_list = [1, 3]
spk_len = int(data3['train_len'].max() * 1000 / ker_test.reso_kernel)
if_noise = 0
aug_times = [1, 1]
spk_pairs_test, targets_test = arb_w_gen(
    df=data_gen_test, ptl_list=ptl_list, targets=y_test, if_noise=if_noise,
    spk_len=spk_len, kernel=ker_test, net_type='triplet', aug_times=aug_times)

In [None]:
# Predict the test targets with the trained kernel and compare to the identity line
spk_len = int(data['train_len'].max() * 1000 / ker_test.reso_kernel)
spk_pairs, predictions = arb_w_gen(spk_pairs=spk_pairs_test, spk_len=spk_len,
                                   kernel=ker_test, net_type='triplet')
diag = np.linspace(-30, 60, 90)
plt.plot(targets_test, predictions, 'o')
plt.plot(diag, diag, '--')

In [None]:
# # Generate data for Triplet
# dt = np.array([-10, -5, 0, 5, 10]).reshape(-1,1)
# data2_gen, targets2 = dw_gen.triplet_dw_gen(dt)

In [None]:
# # Visualize the triplet data
# data2_gen

In [None]:
# # Generate data for Quadruplet
# data3 = data[data['ptl_idx']==3]
# data3_gen, targets3 = dw_gen.quad_dw_gen(n_neighbors=7)

In [None]:
# width_list = np.concatenate([np.linspace(10,3,45), np.linspace(3,10,45)])
# plt.plot(data3['dt2'], data3['dw_mean'],'o', label='Raw data')
# plt.plot(data3_gen['dt2'], targets3,'o', label='KNN')
# targets3_sm = np.concatenate([dw_gen.smooth(targets3[:45],width_list = width_list), dw_gen.smooth(targets3[45:],width_list = width_list)])
# plt.plot(data3_gen['dt2'],targets3_sm,'o', label='Smoothness filter')
# plt.xlabel('dt(ms)')
# plt.ylabel('$\Delta w$')
# plt.legend(loc='upper left')

In [None]:
# # Sample randomly the smoothed Quadruplet data
# samp_len = len(targets3_sm)
# np.random.seed(1)
# test_idx_quad = np.unique(np.random.randint(low=0, high=90, size=9))
# train_vali_idx = np.setdiff1d(np.linspace(0,89,90), test_idx_quad).astype(int)
# np.random.seed(10)
# vali_idx_idx = np.random.randint(low=0, high=80, size=18)
# vali_idx_quad = np.unique(train_vali_idx[vali_idx_idx])
# train_idx_quad = np.setdiff1d(train_vali_idx, vali_idx_quad).astype(int)
# plt.plot(data3_gen.loc[train_idx_quad]['dt2'],targets3_sm[train_idx_quad],'o', label='train_data')
# plt.plot(data3_gen.loc[vali_idx_quad]['dt2'],targets3_sm[vali_idx_quad],'o', label='vali_data')
# plt.plot(data3_gen.loc[test_idx_quad]['dt2'],targets3_sm[test_idx_quad],'o', label='test_data')
# plt.legend()
# print(len(set(train_idx_quad)), len(set(vali_idx_quad)), len(set(test_idx_quad)))

In [None]:
# Re-create the reference kernel from the 'Hippo_AlltoAll' / 'Full' triplet-model
# parameters for the second (four-protocol) experiment. This replaces the ker_test
# that the first experiment overwrote with trained parameters. The throw-away
# KernelGen() and the unused locals (a, tau, reso_set, tau_pre_post, tau_post_pre)
# from the original cell are dropped.
from modelval.kernel import KernelGen

para = trip_para.loc[('Hippo_AlltoAll', 'Full'), :]

ker_test = KernelGen(len_kernel=101)
ker_test.trip_model_ker(para, data_name='Hippocampus')

In [None]:
# Assemble the training split from the three generated sources (data1_gen, data2_gen, data3_gen).
# NOTE(review): data1_gen, train_idx_stdp and targets1_sm are not created by any active cell,
# and data2_gen / targets2 / data3_gen / targets3_sm / train_idx_quad only by commented-out
# cells, so this cell raises NameError as-is — TODO restore or remove.
train_frames = [data1_gen.loc[train_idx_stdp], data2_gen, data3_gen.loc[train_idx_quad]]
data_gen_train = pd.concat(train_frames)
targets_gen_train = np.concatenate(
    [targets1_sm[train_idx_stdp], targets2, targets3_sm[train_idx_quad]])

In [None]:
# Assemble the validation split from the three generated sources.
# NOTE(review): depends on names (data1_gen, vali_idx_stdp, targets1_sm, ...) that no
# active cell defines — TODO restore or remove.
vali_frames = [data1_gen.loc[vali_idx_stdp], data2_gen, data3_gen.loc[vali_idx_quad]]
data_gen_vali = pd.concat(vali_frames)
targets_gen_vali = np.concatenate(
    [targets1_sm[vali_idx_stdp], targets2, targets3_sm[vali_idx_quad]])

In [None]:
# Assemble the test split from the three generated sources.
# NOTE(review): depends on names (data1_gen, test_idx_stdp, targets1_sm, ...) that no
# active cell defines — TODO restore or remove.
test_frames = [data1_gen.loc[test_idx_stdp], data2_gen, data3_gen.loc[test_idx_quad]]
data_gen_test = pd.concat(test_frames)
targets_gen_test = np.concatenate(
    [targets1_sm[test_idx_stdp], targets2, targets3_sm[test_idx_quad]])

In [None]:
# len_stdp = len(vali_idx_stdp)*
# len_triplet = len(vali_idx_stdp)*20+len(data2_gen[data2_gen['ptl_idx']==2])*40
# len_trip2 = len(vali_idx_stdp)*20+len(data2_gen[data2_gen['ptl_idx']==2])*40+len(data2_gen[data2_gen['ptl_idx']==4])*40
# len_quad = len(targets_gen_vali) - len_trip2

In [None]:
# Number of training rows per protocol index
data_gen_train.loc[:, 'ptl_idx'].value_counts()

In [None]:
# Generate spike trains and targets for all four protocols (training split), with
# if_noise=1 and per-protocol aug_times.
# NOTE(review): data3_gen is only defined by a commented-out cell above.
ptl_list = [1, 2, 4, 3]
spk_len = int(data3_gen['train_len'].max() * 1000 / ker_test.reso_kernel)
if_noise = 1
aug_times = [20, 40, 40, 20]
spk_pairs_train, targets_train = arb_w_gen(
    df=data_gen_train, ptl_list=ptl_list, targets=targets_gen_train, if_noise=if_noise,
    spk_len=spk_len, kernel=ker_test, net_type='triplet', aug_times=aug_times)

In [None]:
# Generate spike trains and targets for all four protocols (validation split), with
# the same noise / augmentation settings as the training split.
# NOTE(review): data3_gen is only defined by a commented-out cell above.
ptl_list = [1, 2, 4, 3]
spk_len = int(data3_gen['train_len'].max() * 1000 / ker_test.reso_kernel)
if_noise = 1
aug_times = [20, 40, 40, 20]
spk_pairs_vali, targets_vali = arb_w_gen(
    df=data_gen_vali, ptl_list=ptl_list, targets=targets_gen_vali, if_noise=if_noise,
    spk_len=spk_len, kernel=ker_test, net_type='triplet', aug_times=aug_times)

In [None]:
# Sizes of the generated training and validation input matrices
(spk_pairs_train.shape, spk_pairs_vali.shape)

In [None]:
# Build the triplet network for the four-protocol experiment with fixed init seeds
reg_scale = (1, 1)
init_seed = (4, 5, 6, 7)
ground_truth_init = 0
toy_data_net = network.TripNet(
    kernel=ker_test,
    ground_truth_init=ground_truth_init,
    init_seed=init_seed,
    reg_scale=reg_scale,
    n_input=spk_pairs_train.shape[1],
)

In [None]:
# Create the trainer for the four-protocol network; checkpoints are written to save_dir.
# optimizer_op is no longer passed: its only definitions in this notebook are commented
# out, so referencing it raised NameError. This now matches the first experiment's
# trainer. Pass optimizer_op=tf.train.GradientDescentOptimizer explicitly to switch.
save_dir = '/src/Plasticity_Ker/model/Trip_ptl1_4_real_aug'
toy_net_trainer = trainer.Trainer(toy_data_net.loss, toy_data_net.loss,
                                  input_name=toy_data_net.inputs,
                                  target_name=toy_data_net.target,
                                  save_dir=save_dir,
                                  optimizer_config={'learning_rate': toy_data_net.lr})

In [None]:
# Wrap the four-protocol spike-pair inputs and targets for the trainer
train_data, vali_data = (dataset.Dataset(spk_pairs_train, targets_train),
                         dataset.Dataset(spk_pairs_vali, targets_vali))

In [None]:
# Inspect the kernels and read-out parameters before training. The labels now say
# "init" (this runs before training) and the post-post curve gets its own label
# instead of duplicating 'ker_post_trained'.
w_pre = toy_net_trainer.evaluate(ops=toy_data_net.kernel_pre)
w_post = toy_net_trainer.evaluate(ops=toy_data_net.kernel_post)
w_post_post = toy_net_trainer.evaluate(ops=toy_data_net.kernel_post_post)
fc_w = toy_net_trainer.evaluate(ops=toy_data_net.fc_w)
bias = toy_net_trainer.evaluate(ops=toy_data_net.bias)
plt.plot(w_pre, label='ker_pre_init')
plt.plot(w_post, label='ker_post_init')
plt.plot(w_post_post, label='ker_post_post_init')
plt.legend()
print(fc_w, bias)

In [None]:
# Train in several rounds, dividing the learning rate by 3 after each round
learning_rate = 0.001
iterations = 5
min_error = -1
for _round in range(iterations):
    toy_net_trainer.train(train_data, vali_data, batch_size=128, min_error=min_error,
                          feed_dict={toy_data_net.lr: learning_rate})
    learning_rate /= 3

In [None]:
# Restore the best checkpoint, read back the learned parameters and plot the
# sign-flipped kernels
toy_net_trainer.restore_best()
w_pre, w_post, w_post_post, fc_w, bias = [
    toy_net_trainer.evaluate(ops=op)
    for op in (toy_data_net.kernel_pre, toy_data_net.kernel_post,
               toy_data_net.kernel_post_post, toy_data_net.fc_w, toy_data_net.bias)
]
for curve, curve_label in ((w_pre, 'ker_pre_trained'),
                           (w_post, 'ker_post_trained'),
                           (w_post_post, 'ker_post_post_trained')):
    plt.plot(-1 * curve, label=curve_label)
plt.legend()
print([fc_w, bias])

In [None]:
# # Test effect of smoothed kernel
# w_pre_sm = w_pre
# w_post_sm = w_post 
# w_post_post_sm = w_post_post
# w_pre_sm[:50] = dw_gen.smooth(w_pre[:50], width=2)
# w_post_sm[:48] = dw_gen.smooth(w_post[:48], width=2)
# w_post_post_sm[:49] = dw_gen.smooth(w_post_post[:49], width=3)

# plt.plot(-1 * w_pre_sm)
# plt.plot(-1 * w_post_sm)
# plt.plot(-1*w_post_post_sm)

# ker_test.kernel_pre = w_pre_sm
# ker_test.kernel_post = w_post_sm
# ker_test.kernel_post_post= w_post_post_sm
# ker_test.kernel_scale = fc_w
# ker_test.bias = bias

Compare targets and predictions

In [None]:
# Load the trained parameters into the kernel used for prediction. All three kernels are
# updated together, as in the first experiment, so predictions come from the jointly
# trained model. Previously kernel_post_post kept whatever trip_model_ker had set.
ker_test.kernel_pre = w_pre
ker_test.kernel_post = w_post
ker_test.kernel_post_post = w_post_post
ker_test.kernel_scale = fc_w
ker_test.bias = bias

In [None]:
# Predict the training targets with the trained kernel.
# NOTE(review): data1_gen is not defined by any active cell in this notebook.
max_train_len = data1_gen['train_len'].max()
spk_len = int(max_train_len * 1000 / ker_test.reso_kernel)
spk_pairs, predictions = arb_w_gen(spk_pairs=spk_pairs_train, spk_len=spk_len,
                                   kernel=ker_test, net_type='triplet')

In [None]:
# Split the flat training targets/predictions back into per-protocol groups.
# NOTE(review): `dw_gen` is not imported anywhere in this notebook, and train_idx_stdp /
# data2_gen / train_idx_quad are not defined by any active cell.
ptl_len = [
    len(train_idx_stdp),                           # stdp
    len(data2_gen[data2_gen['ptl_idx'] == 2]),     # triplet (ptl_idx 2)
    len(data2_gen[data2_gen['ptl_idx'] == 4]),     # trip2 (ptl_idx 4)
    len(train_idx_quad),                           # quadruplet
]
ptl_whole_len, targets_ptl, predictions_ptl = dw_gen.target_pred_gen(
    targets_train, predictions, ptl_len, [20, 40, 40, 20])

In [None]:
# Scatter training predictions against targets per protocol, and report R2 / correlation
# computed over all training samples.
ptl_name = ['stdp', 'triplet', 'trip2', 'quadruplet']
# The fit grid must span the training targets scored below (it was built from targets_vali).
x_fit = np.linspace(np.min(targets_train) - 1, np.max(targets_train) + 1, 100)
# Identity line (prediction == target), drawn once rather than once per protocol
plt.plot(np.linspace(-30, 70, 100), np.linspace(-30, 70, 100), 'k--')
for i in range(len(ptl_len)):
    n_samples = targets_ptl[i].shape[0]
    plt.plot(targets_ptl[i], predictions_ptl[i], 'o',
             label=ptl_name[i] + '(n={a})'.format(a=n_samples))

R2, corr, y_fit = perform_eval.R2_corr(predictions, targets_train, x_fit)
plt.xlabel('targets')
plt.ylabel('predictions')
plt.title('R2=%.4f, Corr=%.4f' % (R2, corr))
plt.legend()

In [None]:
# Regenerate the validation spike trains for all four protocols for evaluation:
# aug_times all 1 (no augmentation), if_noise=1.
# NOTE(review): data3_gen is only defined by a commented-out cell above.
ptl_list = [1, 2, 4, 3]
spk_len = int(data3_gen['train_len'].max() * 1000 / ker_test.reso_kernel)
if_noise = 1
aug_times = [1, 1, 1, 1]
spk_pairs_vali, targets_vali = arb_w_gen(
    df=data_gen_vali, ptl_list=ptl_list, targets=targets_gen_vali, if_noise=if_noise,
    spk_len=spk_len, kernel=ker_test, net_type='triplet', aug_times=aug_times)

In [None]:
# Predict the validation targets with the trained kernel (spk_len from the previous cell)
spk_pairs, predictions = arb_w_gen(
    spk_pairs=spk_pairs_vali, spk_len=spk_len, kernel=ker_test, net_type='triplet')

In [None]:
# Split the flat validation targets/predictions back into per-protocol groups.
# NOTE(review): `dw_gen` is not imported anywhere in this notebook, and vali_idx_stdp /
# data2_gen / vali_idx_quad are not defined by any active cell.
ptl_len = [
    len(vali_idx_stdp),                            # stdp
    len(data2_gen[data2_gen['ptl_idx'] == 2]),     # triplet (ptl_idx 2)
    len(data2_gen[data2_gen['ptl_idx'] == 4]),     # trip2 (ptl_idx 4)
    len(vali_idx_quad),                            # quadruplet
]
ptl_whole_len, targets_ptl, predictions_ptl = dw_gen.target_pred_gen(
    targets_vali, predictions, ptl_len, aug_times)

In [None]:
# Scatter validation predictions against targets per protocol, and report R2 / correlation
# computed over all validation samples.
ptl_name = ['stdp', 'triplet', 'trip2', 'quadruplet']
x_fit = np.linspace(np.min(targets_vali) - 1, np.max(targets_vali) + 1, 100)
# Identity line (prediction == target), drawn once rather than once per protocol
plt.plot(np.linspace(-30, 70, 100), np.linspace(-30, 70, 100), 'k--')
for i in range(len(ptl_len)):
    n_samples = targets_ptl[i].shape[0]
    plt.plot(targets_ptl[i], predictions_ptl[i], 'o',
             label=ptl_name[i] + '(n={a})'.format(a=n_samples))

R2, corr, y_fit = perform_eval.R2_corr(predictions, targets_vali, x_fit)

plt.xlabel('targets')
plt.ylabel('predictions')
plt.title('R2=%.4f, Corr=%.4f' % (R2, corr))
plt.legend()

In [None]:
# Generate spike trains and targets for all four protocols on the test split
# (if_noise=0, aug_times all 1). The unused alias `test = spk_pairs_test` was removed.
# NOTE(review): data3_gen is only defined by a commented-out cell above.
ptl_list = [1, 2, 4, 3]
spk_len = int(data3_gen['train_len'].max() * 1000 / ker_test.reso_kernel)
if_noise = 0
aug_times = [1, 1, 1, 1]
spk_pairs_test, targets_test = arb_w_gen(df=data_gen_test, targets=targets_gen_test, ptl_list=ptl_list,
                                         if_noise=if_noise, spk_len=spk_len, kernel=ker_test,
                                         net_type='triplet', aug_times=aug_times)

Generate test results

In [None]:
# Predict the test targets with the trained kernel (spk_len from the previous cell)
spk_pairs, predictions = arb_w_gen(
    spk_pairs=spk_pairs_test, spk_len=spk_len, kernel=ker_test, net_type='triplet')

In [None]:
# Split the flat test targets/predictions back into per-protocol groups.
# NOTE(review): `dw_gen` is not imported anywhere in this notebook, and test_idx_stdp /
# data2_gen / test_idx_quad are not defined by any active cell.
ptl_len = [
    len(test_idx_stdp),                            # stdp
    len(data2_gen[data2_gen['ptl_idx'] == 2]),     # triplet (ptl_idx 2)
    len(data2_gen[data2_gen['ptl_idx'] == 4]),     # trip2 (ptl_idx 4)
    len(test_idx_quad),                            # quadruplet
]
ptl_whole_len, targets_ptl, predictions_ptl = dw_gen.target_pred_gen(
    targets_test, predictions, ptl_len, aug_times)

In [None]:
# Scatter test predictions against targets per protocol, and report R2 / correlation
# computed over all test samples.
ptl_name = ['Stdp', 'Triplet', 'Trip2', 'Quadruplet']
# The fit grid must span the test targets scored below (it was built from targets_vali).
x_fit = np.linspace(np.min(targets_test) - 1, np.max(targets_test) + 1, 100)
# Identity line (prediction == target), drawn once rather than once per protocol
plt.plot(np.linspace(-30, 70, 100), np.linspace(-30, 70, 100), 'k--')
for i in range(len(ptl_len)):
    n_samples = targets_ptl[i].shape[0]
    plt.plot(targets_ptl[i], predictions_ptl[i], 'o',
             label=ptl_name[i] + '(n={a})'.format(a=n_samples))

R2, corr, y_fit = perform_eval.R2_corr(predictions, targets_test, x_fit)

plt.xlabel('targets')
plt.ylabel('predictions')
plt.title('R2=%.4f, Corr=%.4f' % (R2, corr))
plt.legend()