# Sensing Local Field Potentials with a Directional and Scalable Depth Array: the DISC electrode array
## Figure 1
- Associated data: F/B Ratio vs. Lead diameter
- Link: (nature excel link)

## Description:
#### This module performs simple calculations for computing front-to-back ratio and plotting the resulting voltage shown in Inset C for Figure 1.
- Input: ANSYS .csv output 
- Output: .eps file containing the plotted data from Fig. 1C

# Settings

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import random as rand # Source: https://docs.python.org/3/library/random.html
import numpy as np

## Simulation settings

# Input files: ANSYS .csv exports for the macro and the virtual model
macro_data_path = '/home/jovyan/ansys_data/8-source-macro-large.csv'
virtual_data_path = '/home/jovyan/ansys_data/8-source-macro-virtual.csv'

# Source index shown in the waveform & SNR vs. trials plots
waveforms_snr_source_index = 0

# RMS noise levels, in uV
macro_rms_noise, virtual_rms_noise = 2.7, 4.3

# Layout counts (not referenced by the code visible in this notebook)
num_columns, num_rows = 8, 12

# Number of simulated trials, and number of times the SNR curves are
# recomputed before being averaged together
num_trials = 250
snr_averages = 1

# Time axis of every waveform, in ms
sample_period = 0.25
end_time = 60

# Methods

In [None]:
## Methods
# Function to create waveforms from the amplitude of each source.
# Returns a tuple of two 'waveforms_trials' arrays: the VIRTUAL one first, then the MACRO one.
# In each array:
# Axis 0 is a "sheet", one per trial.
# Axis 1 indexes the sources (one waveform per source).
# Axis 2 is the value at each time point.
def create_waveforms(virtual_data, macro_data, signal_index):
    """Build noisy sine-wave trials for every source, for both electrode models.

    Each source (row of the inputs) becomes a sine waveform whose amplitude is
    that row's value in ``virtual_data`` / ``macro_data``. The phase is swept
    linearly, in units of 0.1*pi rad, over ``int(end_time / sample_period)``
    samples:

    * the source at ``signal_index`` always sweeps from 0 to 20, so it is the
      same in every trial;
    * every other source sweeps from a random start in [-90, -60) to a random
      stop in [20, 90), redrawn for each source and each trial.

    Virtual and macro waveforms share the same phase sweep. One Gaussian noise
    sample (sigma = ``macro_rms_noise``) is drawn per (trial, time point) and
    added to all sources of both models.

    Reads the module-level settings ``num_trials``, ``end_time``,
    ``sample_period`` and ``macro_rms_noise``.

    Args:
        virtual_data: Amplitudes for the virtual model, indexed by source
            along the first axis.
        macro_data: Amplitudes for the macro model, indexed the same way.
        signal_index: Index of the source that is kept in phase.

    Returns:
        Tuple ``(virtual, macro)`` of arrays. Axis 0 is the trial, axis 1 the
        source, axis 2 the time point; any trailing axes of an amplitude row
        are kept after the time axis.
    """
    num_samples = int(end_time / sample_period)
    # Number of sources = number of rows. ``.size`` would count every element
    # and run past the last row as soon as there is more than one column.
    num_sources = len(virtual_data)

    v_trials, m_trials = [], []
    for _ in range(num_trials):
        v_sources, m_sources = [], []
        # Drawn even if the first source is the in-phase one, so the random
        # stream is consumed identically whatever signal_index is.
        start_phase = rand.randrange(-90, -60, 1)
        stop_phase = rand.randrange(20, 90, 1)

        for idx in range(num_sources):
            if idx == signal_index:
                start_phase, stop_phase = 0, 20
            # Phase is kept without the pi factor; it is multiplied in here.
            phases = np.linspace(start_phase, stop_phase, num_samples)
            carrier = np.sin((phases * 0.1) * np.pi)
            # outer() keeps time as the leading axis whether the amplitude
            # row is a scalar or a short vector.
            v_sources.append(np.multiply.outer(carrier, virtual_data[idx]))
            m_sources.append(np.multiply.outer(carrier, macro_data[idx]))
            # New sweep for the next source
            start_phase = rand.randrange(-90, -60, 1)
            stop_phase = rand.randrange(20, 90, 1)

        v_trials.append(v_sources)
        m_trials.append(m_sources)

    v_wfms_srcs_trials = np.array(v_trials)
    m_wfms_srcs_trials = np.array(m_trials)

    # One noise sample per (trial, time point), drawn trial by trial, shared
    # by every source and by both models.
    noise = np.array([
        rand.gauss(mu=0, sigma=macro_rms_noise)
        for _ in range(num_trials * num_samples)
    ])
    for waveforms in (v_wfms_srcs_trials, m_wfms_srcs_trials):
        trailing_axes = (1,) * (waveforms.ndim - 3)
        waveforms += noise.reshape((num_trials, 1, num_samples) + trailing_axes)

    return v_wfms_srcs_trials, m_wfms_srcs_trials



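# Function to calculate SNR of the waveforms at each successive averaged trial.
# Takes waveforms_trials, source_index (which row is the source?) and rms_noise as arguments;
# rms_noise is not used in the calculation.
# Returns three lists, "snr_vec", "signal_vec" and "noise_vec", each with one entry per
# number of averaged trials (10*log10 of the SNR, the signal RMS and the noise RMS).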
def calculate_snr_each_trial(waveforms_trials, source_index, rms_noise):
    """Compute SNR, signal and noise levels as more and more trials are averaged.

    For every k from 1 to the number of trials, the first k trials are
    averaged and two RMS values are taken over the remaining axes:

    * "signal": all sources summed together (the source of interest included);
    * "noise": all sources summed together except ``source_index``.

    Each returned value is ``10 * log10`` of the corresponding quantity.

    Args:
        waveforms_trials: Array indexed as (trial, source, time point, ...).
        source_index: Index, along the source axis, of the source of interest.
        rms_noise: Not used by the calculation; kept so existing calls with
            three arguments keep working.

    Returns:
        Tuple ``(snr_vec, signal_vec, noise_vec)`` of lists, each with one
        entry per number of averaged trials.

    Raises:
        ValueError: If ``waveforms_trials`` has fewer than three axes.
    """
    waveforms = np.asarray(waveforms_trials)
    if waveforms.ndim < 3:
        raise ValueError(
            "waveforms_trials must be indexed as (trial, source, time point)"
        )
    # Take the trial count from the data itself rather than a global setting,
    # so the result always has one entry per trial actually supplied.
    trial_count = waveforms.shape[0]

    # Sum the sources together first: (trial, time point, ...)
    total_trials = np.sum(waveforms, axis=1)
    noise_trials = np.sum(np.delete(waveforms, obj=source_index, axis=1), axis=1)

    # Running mean over the first k trials for every k at once, instead of
    # re-averaging the first k trials from scratch at each step.
    counts = np.arange(1, trial_count + 1).reshape(
        (trial_count,) + (1,) * (total_trials.ndim - 1)
    )
    running_total = np.cumsum(total_trials, axis=0) / counts
    running_noise = np.cumsum(noise_trials, axis=0) / counts

    # RMS over everything but the trial axis
    other_axes = tuple(range(1, total_trials.ndim))
    signal_rms = np.sqrt(np.mean(np.square(running_total), axis=other_axes))
    noise_rms = np.sqrt(np.mean(np.square(running_noise), axis=other_axes))

    snr_vec = list(10 * np.log10(signal_rms / noise_rms))
    signal_vec = list(10 * np.log10(signal_rms))
    noise_vec = list(10 * np.log10(noise_rms))
    return snr_vec, signal_vec, noise_vec

# Main runner

In [None]:
# Load the ANSYS exports for both electrode models
data_macro = pd.read_csv(macro_data_path)
data_virtual = pd.read_csv(virtual_data_path)

# Remove the source-parameter columns so that only the voltage values remain
voltages_macro = data_macro.drop(["r_s", "theta_e", "gap"], axis=1).to_numpy(copy=True)
macro_rows, macro_columns = voltages_macro.shape
voltages_virtual = data_virtual.drop(["r_s", "theta_e", "gap"], axis=1).to_numpy(copy=True)
virtual_rows, virtual_columns = voltages_virtual.shape

# Accumulators: one entry per SNR-averaging repetition
macro_snr_vecs, macro_signal_vecs, macro_noise_vecs = [], [], []
virtual_snr_vecs, virtual_signal_vecs, virtual_noise_vecs = [], [], []

for _ in range(snr_averages):
    ## Create sine waves: one (virtual, macro) set of trials per signal source
    waveform_pairs = [
        create_waveforms(voltages_virtual, voltages_macro, source)
        for source in range(macro_rows)
    ]
    virtual_waveforms_trials = [pair[0] for pair in waveform_pairs]
    macro_waveforms_trials = [pair[1] for pair in waveform_pairs]

    ## SNR, signal and noise for each trial (macro)
    macro_stats = [
        calculate_snr_each_trial(macro_waveforms_trials[source], source, macro_rms_noise)
        for source in range(macro_rows)
    ]
    macro_snr_vec = [stats[0] for stats in macro_stats]
    macro_signal_vec = [stats[1] for stats in macro_stats]
    macro_noise_vec = [stats[2] for stats in macro_stats]
    macro_snr_vecs.append(list(macro_snr_vec))
    macro_signal_vecs.append(list(macro_signal_vec))
    macro_noise_vecs.append(list(macro_noise_vec))

    ## SNR, signal and noise for each trial (virtual)
    virtual_stats = [
        calculate_snr_each_trial(virtual_waveforms_trials[source], source, virtual_rms_noise)
        for source in range(virtual_rows)
    ]
    virtual_snr_vec = [stats[0] for stats in virtual_stats]
    virtual_signal_vec = [stats[1] for stats in virtual_stats]
    virtual_noise_vec = [stats[2] for stats in virtual_stats]
    virtual_snr_vecs.append(list(virtual_snr_vec))
    virtual_signal_vecs.append(list(virtual_signal_vec))
    virtual_noise_vecs.append(list(virtual_noise_vec))

# Average the curves over the repetitions
macro_snr_vec_avg = np.mean(macro_snr_vecs, axis=0)
macro_signal_vec_avg = np.mean(macro_signal_vecs, axis=0)
macro_noise_vec_avg = np.mean(macro_noise_vecs, axis=0)
virtual_snr_vec_avg = np.mean(virtual_snr_vecs, axis=0)
virtual_signal_vec_avg = np.mean(virtual_signal_vecs, axis=0)
virtual_noise_vec_avg = np.mean(virtual_noise_vecs, axis=0)

# Plot data

In [12]:
## Create plots
# First, clear the current pyplot figure
plt.clf()
# 4 x 2 grid of plots. Rows: waveforms, SNR, signal, noise.
# Left column: trial 1. Right column: all trials averaged.
fig, axs = plt.subplots(4,2, figsize=(15,15), sharex=False, sharey=False)
fig.tight_layout()

# Get waveforms for the source selected in the settings.
# "Trial 1" sums every source of the first trial; the averaged waveforms are
# the per-source mean over all trials.
macro_trial_1_wfm = np.sum(macro_waveforms_trials[waveforms_snr_source_index][0], axis=0)
macro_avg_wfms = np.mean(macro_waveforms_trials[waveforms_snr_source_index], axis=0)
macro_trial_1_wfm_amp = np.max(macro_trial_1_wfm) - np.min(macro_trial_1_wfm)
macro_avg_wfm_amp = np.max(macro_avg_wfms) - np.min(macro_avg_wfms)

virtual_trial_1_wfm = np.sum(virtual_waveforms_trials[waveforms_snr_source_index][0], axis=0)
virtual_avg_wfms = np.mean(virtual_waveforms_trials[waveforms_snr_source_index], axis=0)
virtual_trial_1_wfm_amp = np.max(virtual_trial_1_wfm) - np.min(virtual_trial_1_wfm)
virtual_avg_wfm_amp = np.max(virtual_avg_wfms) - np.min(virtual_avg_wfms)
print("Trial 1 Virtual amplitude is", virtual_trial_1_wfm_amp / macro_trial_1_wfm_amp * 100, "% of modeled macro.")
print("Trial", num_trials, "Virtual amplitude is", virtual_avg_wfm_amp / macro_avg_wfm_amp * 100, "% of modeled macro.")

# Chart for waveforms (macro in red, virtual in blue)
time_axis = np.arange(0, end_time, sample_period)
axs[0,0].plot(time_axis, macro_trial_1_wfm, color='r')
axs[0,1].plot(time_axis, macro_avg_wfms[waveforms_snr_source_index], color='r')
axs[0,0].set_yticks([-200,-150,-100,-50,0,50,100,150,200])
axs[0,0].title.set_text("Trial 1 Waveform")

axs[0,0].plot(time_axis, virtual_trial_1_wfm, color='b')
axs[0,1].plot(time_axis, virtual_avg_wfms[waveforms_snr_source_index], color='b')
axs[0,1].set_yticks([-50,-25,0,25,50])
# Title follows num_trials instead of a hard-coded 250
axs[0,1].title.set_text("Trial {} Waveform".format(num_trials))

# Compare SNR values
for idx in range(virtual_rows):
    ratio_in_percent = 10**((virtual_snr_vec_avg[idx][num_trials-1]-macro_snr_vec_avg[idx][num_trials-1])/10)*100
    print("Source", idx+1, "SNR ratio (virtual/macro*100%):", ratio_in_percent, "%")

# Bar charts for all sources: one grid row per quantity.
# (grid row, label, macro values, virtual values, y ticks)
bar_panels = [
    (1, "SNR", macro_snr_vec_avg, virtual_snr_vec_avg, [0,2,4,6,8,10]),
    (2, "Signal", macro_signal_vec_avg, virtual_signal_vec_avg, [0,5,10,15,20]),
    (3, "Noise", macro_noise_vec_avg, virtual_noise_vec_avg, [0,5,10,15,20]),
]
for panel_row, panel_label, macro_values, virtual_values, panel_yticks in bar_panels:
    first_trial_ax = axs[panel_row, 0]
    last_trial_ax = axs[panel_row, 1]

    # Macro (red), drawn just left of each source's tick
    for idx, row in enumerate(macro_values):
        last_trial_ax.bar(idx+0.875, height=row[num_trials-1], color='r', width=0.25)
        first_trial_ax.bar(idx+0.875, height=row[0], color='r', width=0.25)

    # Virtual (blue), drawn just right of each source's tick
    for idx, row in enumerate(virtual_values):
        last_trial_ax.bar(idx+1.125, height=row[num_trials-1], color='b', width=0.25)
        first_trial_ax.bar(idx+1.125, height=row[0], color='b', width=0.25)

    # One tick per source, taken from the data rather than fixed at 1..8
    source_ticks = list(range(1, max(len(macro_values), len(virtual_values)) + 1))
    for panel_ax, trial_number in ((last_trial_ax, num_trials), (first_trial_ax, 1)):
        panel_ax.set_xticks(source_ticks)
        panel_ax.set_yticks(panel_yticks)
        panel_ax.title.set_text("Trial {} {}".format(trial_number, panel_label))
        panel_ax.grid(which='both',axis='both')


plt.savefig('/home/jovyan/ansys_data/images/8-source-macro-virtual-compare.eps', format='eps', dpi=500)
plt.show()

FileNotFoundError: [Errno 2] No such file or directory: '/home/jovyan/ansys_data/8-source-macro-large.csv'