# RCS simulation tutorial

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
from ast import literal_eval
from tkinter import Tk, filedialog

from rcssim import *

## Load data

In [None]:
# Ask the user for the folder holding the dataset and its settings file.
# A hidden Tk root window hosts the native folder dialog; making it topmost
# keeps the dialog from opening behind the notebook window.
import os

root = Tk()
root.withdraw()
root.call('wm', 'attributes', '.', '-topmost', True)
try:
    print('Select folder containing data and settings files.')
    data_folder = filedialog.askdirectory()
finally:
    # Tear down the Tk root even if the dialog raises.
    root.destroy()

# askdirectory() returns an empty value when the dialog is cancelled; without
# this check the paths below would silently resolve against the filesystem root.
if not data_folder:
    raise RuntimeError('No data folder selected.')

data = pd.read_csv(os.path.join(data_folder, 'dataset_ld9.csv'))
settings = pd.read_csv(os.path.join(data_folder, 'dataset_ld9_config.csv'))

# These settings columns store Python literals serialized as text in the CSV;
# parse them back into Python objects. literal_eval only accepts literals, so
# it cannot execute arbitrary code from the file.
literal_columns = [
    'band_edges_hz', 'subtract_vec', 'multiply_vec', 'update_rate',
    'weights', 'threshold', 'blank_duration', 'onset', 'termination',
    'blank_both', 'target_amp',
]
for column in literal_columns:
    settings[column] = settings[column].apply(literal_eval)

# The analysis cells below refer to the neural data table as `df`.
df = data

In [None]:
# Show the neural data table read from dataset_ld9.csv; the bare expression on
# the last line is rendered by the notebook as the cell output.
print('Data')
data

In [None]:
# Show the settings table read from dataset_ld9_config.csv (with its
# literal-valued columns already parsed); the bare expression on the last line
# is rendered by the notebook as the cell output.
print('Settings')
settings

## Isolated TD &rarr; PB

In [None]:
# Offline reproduction of the RCS time-domain (TD) -> FFT -> power-band (PB)
# stage; the next cell overlays the result on the PB values the device logged.
td_valid = ~np.isnan(df.TD.values)
td_data = df.TD.values[td_valid, np.newaxis]  # column vector of TD samples
t_td = df.timestamp.values[td_valid]

# Sensing parameters, hard-coded for this dataset.
# NOTE(review): presumably these must match the device configuration used when
# the data were recorded — confirm against `settings`.
amp_gain = 231
fs_td = 250
L = 64
interval = 50
bit_shift = 3
band_edges_hz = [[19, 24]]

# perf_counter() is monotonic and high-resolution, unlike the wall clock
# time.time(), so it is the appropriate clock for elapsed-time measurement.
t_start = time.perf_counter()
td_data = transform_mv_to_rcs(td_data, amp_gain)
hann_win = create_hann_window(L, percent=100)
data_fft, t_pb_comp = rcs_td_to_fft(td_data, t_td, fs_td, L, interval, hann_win,
                                    output_in_mv=False)
pb_data_comp = rcs_fft_to_pb(data_fft, fs_td, L, bit_shift,
                             band_edges_hz=band_edges_hz, input_is_mv=False)
t_end = time.perf_counter()
print(f'Time elapsed: {t_end - t_start}')

In [None]:
%matplotlib widget
fig, ax = plt.subplots(1,1, figsize=(6,2), sharex='col', sharey=False)

t0 = df.timestamp.values[0]
pb_data = df.PB.values[~np.isnan(df.PB.values),np.newaxis]
t_pb = df.timestamp.values[~np.isnan(df.PB.values)]

ax.plot(t_pb_comp-t0, pb_data_comp, label='Computed')
ax.plot(t_pb-t0, pb_data, label='Measured')
ax.legend(bbox_to_anchor=(1.02, 0.6))
ax.set_xlim([1516, 1526])
ax.set_ylim([0, 100])
ax.axhline(y=19, color='k', linestyle='--')
ax.axhline(y=20, color='k', linestyle='--')
ax.grid()
ax.set_xlabel('Time [grid=2sec]')
ax.set_ylabel('PB Output \n [RCS units]')

plt.tight_layout()

## Isolated PB &rarr; stim

In [None]:
# Drive the LD and stim models with the *measured* PB values, isolating them
# from any error introduced by the TD -> PB stage.
# perf_counter() is monotonic and high-resolution, unlike the wall clock
# time.time(), so it is the appropriate clock for elapsed-time measurement.
t_start = time.perf_counter()
pb_valid = ~np.isnan(df.PB.values)
pb_data = df.PB.values[pb_valid, np.newaxis]
t_adapt = df.timestamp.values[pb_valid]
# NOTE(review): the two-element list arguments appear to configure two
# detectors, with only the first one populated here — confirm against rcssim.
ld_output_comp, ld_state_comp, time_ld_comp = rcs_pb_to_ld(
                        pb_data,
                        t_adapt,
                        update_rate=[2,[]],
                        weights=[[1],[]],
                        dual_threshold=[True,[]],
                        threshold=[[19,20],[]],
                        onset_duration=[0,[]],
                        termination_duration=[0,[]],
                        blank_duration=[11,[]],
                        blank_both=[False, False],
                        subtract_vec=[np.zeros(4), np.zeros(4)],
                        multiply_vec=[np.ones(4), np.ones(4)])
stim_comp, time_stim_comp = rcs_ld_to_stim(ld_state_comp, time_ld_comp,
                                           target_amp=[2.0, 2.0, 2.6,0,0,0,0,0],
                                           rise_time=1.6, fall_time=1.6)
t_end = time.perf_counter()
print(f'Time elapsed: {t_end - t_start}')

In [None]:
# %matplotlib widget
# Compare computed vs. measured LD output, LD state and stim amplitude for the
# isolated PB -> stim simulation. The computed state/stim traces are nudged up
# slightly (+0.05 / +0.01), presumably so they stay visible where they overlap
# the measured traces.
fig, ax = plt.subplots(3, 1, figsize=(6, 4), sharex='col', sharey=False)

t0 = df.timestamp.values[0]
feature_valid = ~np.isnan(df.feature.values)
ld_output = df.feature.values[feature_valid]
t_output = df.timestamp.values[feature_valid]
state_valid = ~np.isnan(df.state.values)
ld_state = df.state.values[state_valid]
t_state = df.timestamp.values[state_valid]
stim_valid = ~np.isnan(df.stim.values)
stim = df.stim.values[stim_valid]
t_stim = df.timestamp.values[stim_valid]

ax[0].plot(time_ld_comp-t0, ld_output_comp[0], label='Computed')
ax[0].plot(t_output-t0, ld_output, label='Measured')
ax[0].legend(bbox_to_anchor=(1.02, 0.6))
ax[0].set_xlim([1516, 1526])
ax[0].set_ylim([0, 100])
# Dashed lines mark the first detector's thresholds (19, 20) used in the simulation.
ax[0].axhline(y=19, color='k', linestyle='--')
ax[0].axhline(y=20, color='k', linestyle='--')
ax[0].grid()
ax[0].set_ylabel('LD Output')

ax[1].plot(time_ld_comp-t0, ld_state_comp+0.05, label='Computed')
ax[1].plot(t_state-t0, ld_state, label='Measured')
ax[1].set_ylim([0, 2.5])
ax[1].grid()
ax[1].set_ylabel('LD State')

ax[2].plot(time_stim_comp-t0, stim_comp+0.01, label='Computed')
ax[2].plot(t_stim-t0, stim, label='Measured')
ax[2].grid()
ax[2].set_xlabel('Time [grid=2sec]')
# The simulation's target amplitudes reach 2.6 mA and ticks are placed up to
# 2.8, so the upper limit must exceed both (the old 2.05 clipped them).
ax[2].set_ylim([1.6, 3.0])
ax[2].set_yticks([2.0,2.4,2.8])
ax[2].set_yticklabels([2.0,'',2.8])
ax[2].set_ylabel('Stim amplitude \n [mA]')

plt.tight_layout()

## Full TD &rarr; stim

In [None]:
# Run the complete TD -> FFT -> PB -> LD -> stim chain offline and collect the
# measured device outputs for the comparison plot in the next cell.
# NOTE: reuses amp_gain, fs_td, L, interval, bit_shift and band_edges_hz from
# the "Isolated TD -> PB" cell — run that cell first.
td_valid = ~np.isnan(df.TD.values)
td_data = df.TD.values[td_valid, np.newaxis]
t_td = df.timestamp.values[td_valid]
feature_valid = ~np.isnan(df.feature.values)
ld_output = df.feature.values[feature_valid]
t_output = df.timestamp.values[feature_valid]
state_valid = ~np.isnan(df.state.values)
ld_state = df.state.values[state_valid]
t_state = df.timestamp.values[state_valid]
stim_valid = ~np.isnan(df.stim.values)
stim = df.stim.values[stim_valid]
t_stim = df.timestamp.values[stim_valid]

# perf_counter() is monotonic and high-resolution, unlike the wall clock
# time.time(), so it is the appropriate clock for elapsed-time measurement.
t_start = time.perf_counter()
td_data = transform_mv_to_rcs(td_data, amp_gain)
hann_win = create_hann_window(L, percent=100)
data_fft, t_pb_comp = rcs_td_to_fft(td_data, t_td, fs_td, L, interval, hann_win,
                                    output_in_mv=False)
pb_data_comp = rcs_fft_to_pb(data_fft, fs_td, L, bit_shift,
                             band_edges_hz=band_edges_hz, input_is_mv=False)
# Same LD/stim configuration as the "Isolated PB -> stim" cell, but driven by
# the computed PB instead of the measured one.
ld_output_comp, ld_state_comp, time_ld_comp = rcs_pb_to_ld(
                        pb_data_comp,
                        t_pb_comp,
                        update_rate=[2,[]],
                        weights=[[1],[]],
                        dual_threshold=[True,[]],
                        threshold=[[19,20],[]],
                        onset_duration=[0,[]],
                        termination_duration=[0,[]],
                        blank_duration=[11,[]],
                        blank_both=[False, False],
                        subtract_vec=[np.zeros(4), np.zeros(4)],
                        multiply_vec=[np.ones(4), np.ones(4)])
stim_comp, time_stim_comp = rcs_ld_to_stim(ld_state_comp, time_ld_comp,
                                           target_amp=[2.0, 2.0, 2.6,0,0,0,0,0],
                                           rise_time=1.6, fall_time=1.6)
t_end = time.perf_counter()
print(f'Time elapsed: {t_end - t_start}')

In [None]:
# %matplotlib widget
# Compare every stage of the full TD -> stim simulation with the logged device
# outputs: PB, LD output, LD state and stim amplitude.
fig, ax = plt.subplots(4, 1, figsize=(6, 6), sharex='col', sharey=False)

# Recompute the time origin and measured PB here rather than relying on
# variables left behind by the earlier plotting cells.
t0 = df.timestamp.values[0]
pb_valid = ~np.isnan(df.PB.values)
pb_data = df.PB.values[pb_valid, np.newaxis]
t_pb = df.timestamp.values[pb_valid]

ax[0].plot(t_pb_comp-t0, pb_data_comp, label='Computed')
ax[0].plot(t_pb-t0, pb_data, label='Measured')
ax[0].legend(bbox_to_anchor=(1.02, 0.6))
ax[0].set_xlim([1516, 1526])
ax[0].set_ylim([0, 100])
# Dashed lines mark the first detector's thresholds (19, 20) used in the simulation.
ax[0].axhline(y=19, color='k', linestyle='--')
ax[0].axhline(y=20, color='k', linestyle='--')
ax[0].grid()
ax[0].set_ylabel('PB Output \n [RCS units]')

ax[1].plot(time_ld_comp-t0, ld_output_comp[0], label='Computed')
ax[1].plot(t_output-t0, ld_output, label='Measured')
ax[1].set_xlim([1516, 1526])
ax[1].set_ylim([0, 100])
ax[1].axhline(y=19, color='k', linestyle='--')
ax[1].axhline(y=20, color='k', linestyle='--')
ax[1].grid()
ax[1].set_ylabel('LD Output')

# Computed traces are nudged up slightly, presumably so they stay visible
# where they overlap the measured ones.
ax[2].plot(time_ld_comp-t0, ld_state_comp+0.05, label='Computed')
ax[2].plot(t_state-t0, ld_state, label='Measured')
ax[2].set_ylim([0, 2.5])
ax[2].grid()
ax[2].set_ylabel('LD State')

ax[3].plot(time_stim_comp-t0, stim_comp+0.01, label='Computed')
ax[3].plot(t_stim-t0, stim, label='Measured')
ax[3].grid()
ax[3].set_xlabel('Time [grid=2sec]')
# The simulation's target amplitudes reach 2.6 mA and ticks are placed up to
# 2.8, so the upper limit must exceed both (the old 2.05 clipped them).
ax[3].set_ylim([1.6, 3.0])
ax[3].set_yticks([2.0,2.4,2.8])
ax[3].set_yticklabels([2.0,'',2.8])
ax[3].set_ylabel('Stim amplitude \n [mA]')

plt.tight_layout()