### This notebook studies the relationship between the oscillation phase at stimulus onset and task performance


In [None]:
# import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.signal import hilbert

from functions import Generate_Vectors, Generate_RandomMatrix
from functions import show_mn, show_conn
from functions import Draw_Output, Draw_Conductance,  load_config_yaml, Draw_RasterPlot, Draw_Voltage, Draw_Projection, save_model
from functions import plot_peak_envelope, peak_envelope
from functions import load_init
from lowranksnn import LowRankSNN
plt.rcParams.update({'font.size': 30})  # enlarge the default font size of every figure produced below

from pathlib import Path
import os
import csv
import datetime
import yaml



In [None]:
# Read all simulation / analysis parameters from the YAML configuration file.
config = load_config_yaml('./configures/config_test_phase_sensitivity.yaml')

# --- network size and connection probabilities ---
N_E, N_I = config['N_E'], config['N_I']
N = N_E + N_I
P_EE, P_EI, P_IE, P_II = (config[key] for key in ('P_EE', 'P_EI', 'P_IE', 'P_II'))
factor_mn = config['factor_mn']  # constant multiplied onto the low-rank matrix when building conn
RS = config['RandomStrength']  # constant multiplied onto the random matrix when building conn

# --- synaptic decay time constants ---
taud_E, taud_I = config['taud_E'], config['taud_I']

# --- parameters of the per-neuron bias current (E and I populations) ---
eta_E, eta_I = config['eta_E'], config['eta_I']
delta_E, delta_I = config['delta_E'], config['delta_I']

# --- parameters passed to the random generators of m, n and the random matrix ---
mu, si = config['mu'], config['sigma']
si_rand = config['sigma_rand']

# --- timing (ms) ---
dt = config['dt']  # ms per step
T_pre = config['T_pre']  # time before the stimulus
T_sti = config['T_sti']  # stimulus duration
T_after = config['T_after']  # time after the stimulus
T = T_pre + T_sti + T_after  # total trial length

IS = config['InputStrength']  # stimulus amplitude

color_Go, color_Nogo = config['color_Go'], config['color_Nogo']

num_phase = config['num_phase']  # number of stimulus-onset phases to probe
trails = config['trails']


In [None]:
# Initialization: create the network and draw the low-rank vectors m, n and
# the no-go stimulus.
LRSNN = LowRankSNN(N_E=N_E,N_I=N_I,taud_E=taud_E,taud_I=taud_I,RS=RS)

# Redraw until the excitatory entries of m, n and Sti_nogo each sum to less
# than 1 in absolute value. The search never gives up: after every 100 failed
# draws it prints a notice and restarts its counter.
i = 0
while True:
    m_test, n_test, Sti_nogo_test = Generate_Vectors(N, mu, si)
    if (torch.sum(m_test[:N_E]).abs() < 1
            and torch.sum(n_test[:N_E]).abs() < 1
            and torch.sum(Sti_nogo_test[:N_E]).abs() < 1):
        print(N,mu,si)
        # sums of the excitatory entries of m, n and Sti_nogo_test
        print(torch.sum(m_test[:N_E]))
        print(torch.sum(n_test[:N_E]))
        print(torch.sum(Sti_nogo_test[:N_E]))
        print('i:',i)
        print('-----------------------------------')
        m = m_test
        n = n_test
        Sti_nogo = Sti_nogo_test
        break
    i += 1
    if i == 100:
        i = 0
        print('did not find the suitable m, n, Sti_nogo')

# the inhibitory entries (indices N_E..N-1) carry no low-rank structure / stimulus
m[N_E:] = 0
n[N_E:] = 0
Sti_nogo[N_E:] = 0
Sti_go = n.clone()  # the go stimulus is a copy of n
W_out = m.clone()   # the readout weights are a copy of m
# rank-1 matrix factor_mn * m n^T (torch.outer replaces the deprecated torch.ger)
W_rank1 = factor_mn*torch.outer(m.squeeze(), n.squeeze())
conn_rand = Generate_RandomMatrix(N_E, N_I, P_EE, P_EI, P_IE, P_II, W_rank1, sigma = si_rand)


In [None]:
# Assemble the Network: add the low-rank part (W_rank1 with readout W_out) and
# the random part to the model.
LRSNN.add_lowrank(W_rank1, W_out)
LRSNN.add_random(conn_rand)

# Clip every entry of LRSNN.conn into [0, 1].
# NOTE(review): presumably conn is the combined connectivity built by the two
# calls above -- confirm in LowRankSNN.
LRSNN.conn[LRSNN.conn>1] = 1
LRSNN.conn[LRSNN.conn<0] = 0

In [None]:
# 1st simulation: a reference run driven only by the bias current (no
# stimulus). Its conductances give the oscillation phase (Hilbert transform,
# below) and its state at T_pre is loaded as the starting point of later runs.

T = T_pre+T_sti+T_after # length of Period time (ms)

# stimulus templates, size (N, time); they are not passed to the reference run
Input_go = torch.zeros((LRSNN.N_E+LRSNN.N_I,int(T/dt)))
Input_go[:,int(T_pre/dt):int((T_pre+T_sti)/dt)] = IS*Sti_go
Input_nogo = torch.zeros((LRSNN.N_E+LRSNN.N_I,int(T/dt)))
Input_nogo[:,int(T_pre/dt):int((T_pre+T_sti)/dt)] = IS*Sti_nogo

# bias current: eta + delta*tan(pi*(k/(n+1) - 1/2)) for k = 1..n, separately
# for the E and I populations; constant over time
bias = torch.zeros_like(Input_go)
bias[:N_E,:] = (eta_E+delta_E*torch.tan(torch.tensor(np.pi*(np.arange(1,N_E+1)/(N_E+1)-1/2)))).unsqueeze(1)
bias[N_E:,:] = (eta_I+delta_I*torch.tan(torch.tensor(np.pi*(np.arange(1,N_I+1)/(N_I+1)-1/2)))).unsqueeze(1)

# move the model and tensors to the GPU when one is available, else stay on CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
LRSNN = LRSNN.to(device)
Input_go = Input_go.to(device)
Input_nogo = Input_nogo.to(device)
bias = bias.to(device)

# Start Simulation
Out_ref, V_ref, [g_ref,g_ref_EE,g_ref_EI,g_ref_IE,g_ref_II],[I_ref_syn,I_ref_syn_EE,I_ref_syn_EI,I_ref_syn_IE,I_ref_syn_II], spk_step_ref, spk_ind_ref, spk_ref, phase_ref = LRSNN(dt,bias)



In [None]:
# load the values at T_pre
# Hands the reference run's recorded conductances, voltages, phases, synaptic
# currents and spikes to the project helper load_init together with T_pre and
# dt. NOTE(review): presumably this sets the network state at T_pre as the
# initial state of the following simulations -- confirm in load_init.
LRSNN = load_init(LRSNN, T_pre, dt, g_ref, g_ref_EE, g_ref_EI, g_ref_IE, g_ref_II, V_ref, phase_ref, I_ref_syn, I_ref_syn_EE, I_ref_syn_EI, I_ref_syn_IE, I_ref_syn_II, spk_ref)


In [None]:

# Hilbert transform of the population-averaged conductance of the reference
# run, giving the instantaneous phase of the network oscillation.
g_ref_EE_np = g_ref_EE.clone().cpu().detach().numpy()
g_ref_II_np = g_ref_II.clone().cpu().detach().numpy()
g_ref_EI_np = g_ref_EI.clone().cpu().detach().numpy()
g_ref_IE_np = g_ref_IE.clone().cpu().detach().numpy()

# average over axis 0 (presumably neurons, leaving a time series -- confirm
# against LowRankSNN's output layout)
signal_II = np.mean(g_ref_II_np, axis=0)
signal_EE = np.mean(g_ref_EE_np, axis=0)
signal_EI = np.mean(g_ref_EI_np, axis=0)
signal_IE = np.mean(g_ref_IE_np, axis=0)

# the phase is extracted from the E->E conductance, centred on zero first
mean_signal = np.mean(signal_EE)
signal = signal_EE - mean_signal
analytic_signal = hilbert(signal)
amplitude_envelope = np.abs(analytic_signal)
instantaneous_phase = np.angle(analytic_signal)  # wrapped to (-pi, pi]
phase_diff = np.diff(instantaneous_phase)  # phase_diff[k] = phase[k+1] - phase[k]
t = np.arange(len(signal))*dt  # time axis (ms)

# plot 40 ms of the conductance and its phase, starting at T_pre
start_time = T_pre
end_time = T_pre+40
window = slice(int(start_time/dt), int(end_time/dt))
plt.figure(figsize=(12, 8))

plt.subplot(2, 1, 1)
plt.plot(t[window], signal[window]+mean_signal)
plt.title("Synaptic Conductance (E to E)")
plt.xticks([])
plt.yticks([])
plt.grid(True)

plt.subplot(2, 1, 2)
plt.plot(t[window], instantaneous_phase[window], color='green')
plt.grid(True)
plt.title("Instantaneous Phase")
plt.xlabel("Time (ms)")

ticks = [-np.pi,  0,  np.pi]
labels = [r"$-\pi$",  r"$0$",  r"$\pi$"]
plt.yticks(ticks, labels)



In [None]:
# Conductances of the I->I and E->E synapses plotted against the oscillation
# phase over a fixed 200-250 ms window of the reference run.
window_start, window_end = 200, 250  # ms
win = slice(int(window_start/dt), int(window_end/dt))

plt.figure(figsize=(10, 5))
plt.plot(instantaneous_phase[win], signal_II[win], color='green', label='g_II')
# g_EE is multiplied by 4.5, presumably to bring it onto the scale of g_II;
# the y axis is therefore in arbitrary units
plt.plot(instantaneous_phase[win], signal_EE[win]*4.5, color='orange', label='g_EE')

plt.xlabel("Phase")
plt.ylabel("g (A.U.)")
plt.yticks([])
plt.legend()

ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]
plt.xticks(ticks, labels)
plt.grid(True)
plt.show()

In [None]:
# Locate the first full oscillation cycle after T_pre. A cycle starts where
# the wrapped Hilbert phase jumps from about +pi back to about -pi.
T_pre_ind = int(T_pre/dt)
phase_diff_T_pre = phase_diff[T_pre_ind:]

# a wrap-around is a jump of roughly -2*pi; any step below this counts
WRAP_THRESHOLD = -3
crossings = np.where(phase_diff_T_pre < WRAP_THRESHOLD)[0]
print(phase_diff_T_pre[crossings])

# record the beginning and end of the first cycle
if len(crossings) < 2:
    raise ValueError("didn't find the first cycle")
# phase_diff[k] = phase[k+1] - phase[k], so each new cycle begins at sample k+1
start_index = crossings[0] + 1
end_index = crossings[1] + 1
phase_start = instantaneous_phase[T_pre_ind+start_index]
phase_end = instantaneous_phase[T_pre_ind+end_index]

# both boundaries lie just after a wrap, i.e. near -pi (starts of two
# consecutive cycles)
print('First cycle start phase:', phase_start)
print('time:', T_pre+start_index*dt, 'ms')
print('First cycle end phase:', phase_end)
print('time:', T_pre+end_index*dt, 'ms')
print('period:', (end_index-start_index)*dt, 'ms')


In [None]:
from scipy.interpolate import interp1d

# Choose the stimulus-onset times: num_phase evenly spaced phases in [-pi, pi]
# are mapped onto times within the first oscillation cycle found above.
time = t[T_pre_ind+start_index:T_pre_ind+end_index]  # absolute time (ms)
phase_period = instantaneous_phase[T_pre_ind+start_index:T_pre_ind+end_index]

phase_target = np.linspace(-np.pi, np.pi, num_phase)

# Invert phase(t) -> t(phase) by linear interpolation. interp1d sorts its x
# values itself (assume_sorted=False); targets outside the cycle are
# extrapolated and then clamped to the cycle.
interp_func = interp1d(phase_period, time, kind='linear', fill_value="extrapolate")
time_target = np.clip(interp_func(phase_target), time[0], time[-1])  # absolute onset times (ms)
time_target_ind = (time_target/dt).astype(int)

plt.figure(figsize=(12, 8))
plt.plot(time, phase_period, '-', label='Original period')
plt.plot(time_target, instantaneous_phase[time_target_ind], 'o', label='Target phase')
plt.xlabel('Time (ms)')
plt.ylabel('Phase')
plt.legend()
plt.title(f'{num_phase} phases')

In [None]:
# Simulation: present the go and the no-go stimulus at each target phase and
# record the network output. Every run starts from the state loaded by
# load_init at T_pre, so the stimulus starts T_phase ms into the run.

# recordings for the different phases
reaction_times = []
Input_go_rec = []
Out_go_rec = []
Out_nogo_rec = []

T_pre_origin = T_pre
T_after_origin = T_after

# Every run has the same total length (T_sti + T_after_origin) and the same
# stimulus duration. Computing both step counts once, with rounding, keeps all
# recordings the same length. Recomputing
# int((T_phase + T_sti + (T_after_origin - T_phase))/dt) per phase can lose a
# sample to floating-point round-off, which makes the recorded arrays ragged.
n_steps = int(round((T_sti + T_after_origin)/dt))
n_sti_steps = int(round(T_sti/dt))
N_total = LRSNN.N_E+LRSNN.N_I

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
LRSNN = LRSNN.to(device)

# bias current: identical for every phase, so it is built once
bias = torch.zeros((N_total, n_steps))
bias[:N_E,:] = (eta_E+delta_E*torch.tan(torch.tensor(np.pi*(np.arange(1,N_E+1)/(N_E+1)-1/2)))).unsqueeze(1)
bias[N_E:,:] = (eta_I+delta_I*torch.tan(torch.tensor(np.pi*(np.arange(1,N_I+1)/(N_I+1)-1/2)))).unsqueeze(1)
bias = bias.to(device)

try:
    for T_phase in time_target-T_pre_origin:
        T_pre = T_phase
        T_after = T_after_origin-T_phase  # length of time after sti (ms) for this run
        T = T_pre+T_sti+T_after

        sti_start = int(T_pre/dt)
        sti_stop = sti_start + n_sti_steps

        Input_go = torch.zeros((N_total, n_steps))  # size: (N, time)
        Input_go[:, sti_start:sti_stop] = IS*Sti_go
        Input_nogo = torch.zeros((N_total, n_steps))
        Input_nogo[:, sti_start:sti_stop] = IS*Sti_nogo

        Input_go_rec.append(Input_go.tolist())

        Input_go = Input_go.to(device)
        Input_nogo = Input_nogo.to(device)

        # Note: initial values have been loaded by load_init
        Out_go, V_go, g_go, I_syn_go, spk_step_go, spk_ind_go,_,_ = LRSNN(dt,Input_go+bias)
        Out_nogo, V_nogo, g_nogo, I_syn_nogo, spk_step_nogo, spk_ind_nogo,_,_ = LRSNN(dt,Input_nogo+bias)

        Out_go_rec.append(Out_go.cpu().tolist())
        Out_nogo_rec.append(Out_nogo.cpu().tolist())
finally:
    # restore the globals the loop overwrote, even if a run failed
    T_pre = T_pre_origin
    T_after = T_after_origin

In [None]:
# Convert the recordings to arrays (first axis: probed phase) and save them in
# a time-stamped folder under ./data_phase_sensitivity.
Input_go_rec = np.array(Input_go_rec)
Out_go_rec = np.array(Out_go_rec)
Out_nogo_rec = np.array(Out_nogo_rec)

# effective phase = Hilbert phase at the sample where each stimulus starts
phases_eff = instantaneous_phase[time_target_ind]

import datetime
import os
now = datetime.datetime.now()
# create a folder to store the data, name=now
folder = f'./data_phase_sensitivity/{now.strftime("%y%m%d%H%M%S")}'
os.makedirs(folder)
np.save(folder+'/Input_go_rec'+'.npy', Input_go_rec)
np.save(folder+'/Out_go_rec'+'.npy', Out_go_rec)
np.save(folder+'/Out_nogo_rec'+'.npy', Out_nogo_rec)
np.save(folder+'/phases_eff'+'.npy', phases_eff)
# the analysis cells below read the probed (target) phases from 'phases.npy'
np.save(folder+'/phases'+'.npy', phase_target)




In [None]:
# import csv
# import datetime
# now = datetime.datetime.now()
# filename = './data_phase_to_reaction_times/reaction_times_'+now.strftime('%y%m%d%H%M%S')+'.csv'
# with open(filename, mode='w') as file:
#     writer = csv.writer(file)
#     for i in range(len(phases_eff)):
#         writer.writerow([phases_eff[i], reaction_times[i]])

# draw the result plot

In [None]:
def reaction_time_amplitude(phases_eff, Out_go_rec, Input_go_rec, dt):
    """Return the stimulus-to-peak latency and the peak output energy per phase.

    For each probed phase ``i`` the output energy is ``Out_go_rec[i]**2``
    after squeezing singleton dimensions. The latency is the time of the
    energy maximum minus the stimulus onset. The onset is the first time step
    at which any neuron receives non-zero input in ``Input_go_rec[i]``.

    Args:
        phases_eff: sequence of probed phases; only its length is used.
        Out_go_rec: indexable of per-phase outputs; each ``Out_go_rec[i]``
            must squeeze to a 1-D time series (assumed (1, time) -- confirm).
        Input_go_rec: indexable of per-phase inputs, each of shape (N, time).
        dt: simulation time step (ms).

    Returns:
        (interval_sti_peak_go, peak_go_values): two lists of length
        ``len(phases_eff)`` holding the latencies (ms) and the peak energies.

    Raises:
        ValueError: if a trial's input contains no stimulus at all.
    """
    interval_sti_peak_go = []
    peak_go_values = []

    for i in range(len(phases_eff)):
        go_signal_energy = np.asarray(Out_go_rec[i]).squeeze()**2
        Input_go = np.asarray(Input_go_rec[i])
        # time at which the output energy reaches its peak
        peak_time_go = np.argmax(go_signal_energy)*dt
        peak_value_go = np.max(go_signal_energy)
        # Onset over all neurons, not just neuron 0, so a zero stimulus entry
        # for the first neuron cannot break the lookup.
        active_steps = np.flatnonzero(np.any(Input_go != 0, axis=0))
        if active_steps.size == 0:
            raise ValueError(f"no stimulus found in Input_go_rec[{i}]")
        sti_start_time = active_steps[0]*dt
        interval_sti_peak_go.append(peak_time_go - sti_start_time)
        peak_go_values.append(peak_value_go)
    return interval_sti_peak_go, peak_go_values


In [None]:

# Read the saved phase-sensitivity runs (one time-stamped sub-folder per run)
# and extract the per-phase reaction time and peak energy of the go and no-go
# outputs.
data_dir = '/mnt/SanDisk/Li/LowRank_ModifiedTheta_SNN/PYTHON/data_phase_sensitivity/'  # please change to your path

interval_rec = []
peak_rec = []
interval_rec_nogo = []
peak_rec_nogo = []
phases_eff_rec = []
dt = 0.1  # NOTE(review): hard-coded; must match the dt the data were generated with

required_files = ('Input_go_rec.npy', 'Out_go_rec.npy', 'Out_nogo_rec.npy')
for folder in sorted(os.listdir(data_dir)):
    path_folder = os.path.join(data_dir, folder)
    if not os.path.isdir(path_folder):
        continue
    # Load only this folder's files; a missing file must not silently fall
    # back to the arrays of the previously processed folder.
    paths = [os.path.join(path_folder, name) for name in required_files]
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        print(f'skipping {path_folder}: missing {missing}')
        continue
    Input_go_rec, Out_go_rec, Out_nogo_rec = (np.load(p) for p in paths)

    # The probed phases are stored as 'phases.npy'. Older runs may only have
    # 'phases_eff.npy' or no phase file at all. reaction_time_amplitude only
    # uses their count, so fall back to evenly spaced phases.
    phases_path = os.path.join(path_folder, 'phases.npy')
    phases_eff_path = os.path.join(path_folder, 'phases_eff.npy')
    if os.path.isfile(phases_path):
        phases = np.load(phases_path)
    elif os.path.isfile(phases_eff_path):
        phases = np.load(phases_eff_path)
    else:
        phases = np.linspace(-np.pi, np.pi, len(Out_go_rec))

    interval_sti_peak_go, peak_go_values = reaction_time_amplitude(phases, Out_go_rec, Input_go_rec, dt)
    # go and no-go stimuli occupy the same time window, so the go input also
    # gives the no-go onset
    interval_sti_peak_nogo, peak_nogo_values = reaction_time_amplitude(phases, Out_nogo_rec, Input_go_rec, dt)
    interval_rec_nogo.append(interval_sti_peak_nogo)
    peak_rec_nogo.append(peak_nogo_values)

    interval_rec.append(interval_sti_peak_go)
    peak_rec.append(peak_go_values)
    phases_eff_rec.append(phases)



In [None]:
# Turn the per-run lists into arrays (one row per run, one column per phase)
# and summarise them across runs (axis 0).
interval_rec, peak_rec, interval_rec_nogo, peak_rec_nogo, phases_eff_rec = (
    np.array(records)
    for records in (interval_rec, peak_rec, interval_rec_nogo, peak_rec_nogo, phases_eff_rec)
)

# go condition: mean and std of the reaction time and peak energy per phase
mean_interval, std_interval = interval_rec.mean(axis=0), interval_rec.std(axis=0)
mean_peak, std_peak = peak_rec.mean(axis=0), peak_rec.std(axis=0)

# no-go condition
mean_interval_nogo, std_interval_nogo = interval_rec_nogo.mean(axis=0), interval_rec_nogo.std(axis=0)
mean_peak_nogo, std_peak_nogo = peak_rec_nogo.mean(axis=0), peak_rec_nogo.std(axis=0)

# nominal, evenly spaced phase axis used by the plots below
phases_eff = np.linspace(-np.pi, np.pi, num_phase)

In [None]:
# Mean +/- std (across runs) of the reaction time versus stimulus phase.
ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]

plt.figure(figsize=(12, 8))
plt.errorbar(phases_eff, mean_interval, yerr=std_interval, fmt='o', capsize=5)
plt.xlabel('Phase')
plt.ylabel('Reaction Time (ms)')
plt.xticks(ticks, labels)
plt.show()

In [None]:
# Standard deviation (across runs) of the reaction time versus stimulus phase.
ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]

plt.figure(figsize=(12, 8))
plt.plot(phases_eff, std_interval, '-o')
plt.xlabel('Phase')
plt.ylabel('Std. of Reaction Time')
plt.xticks(ticks, labels)
plt.show()

In [None]:
# Polar bar chart of the mean peak output energy (go condition) per stimulus
# phase, drawn clockwise over the full circle, with a mean +/- std band.
ax = plt.subplot(111, projection='polar')
ax.set_theta_direction(-1)  # clockwise

ax.set_thetamin(-180)  # -π
ax.set_thetamax(180)   #  π

# +pi duplicates -pi on the circle, so the last phase is left out of the bars.
# A width of 2*pi/num_phase is slightly narrower than the phase spacing
# 2*pi/(num_phase-1), which leaves a small gap between neighbouring bars.
ax.bar(
    phases_eff[:-1], mean_peak[:-1], width=2*np.pi/num_phase,
    color=color_Go,
    align='center',
    bottom=0,
    alpha=0.5,
    label='Peak Energy'
)

ax.fill_between(phases_eff, mean_peak - std_peak, mean_peak + std_peak, color='gray', alpha=0.3, label='Mean ± Std')
ax.legend(loc='upper right', fontsize=12)

In [None]:
# Mean +/- std (across runs) of the go peak output energy versus stimulus phase.
ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]

plt.figure(figsize=(10, 5))
plt.errorbar(phases_eff, mean_peak, yerr=std_peak, fmt='o', capsize=5, ecolor=color_Go, color=color_Go)
plt.xlabel('Phase')
plt.ylabel('Peak Energy')
plt.xticks(ticks, labels)
plt.show()

In [None]:
# Polar bar chart of the mean reaction time per stimulus phase, drawn
# clockwise over the full circle, with a mean +/- std band.
ax = plt.subplot(111, projection='polar')
ax.set_theta_direction(-1)  # clockwise

# angular range of the axes in degrees: the full circle
ax.set_thetamin(-180)  # -π
ax.set_thetamax(180)   #  π

# Bar chart: +pi duplicates -pi on the circle, so the last phase is dropped.
# The bar width follows the configured number of phases.
ax.bar(
    phases_eff[:-1], mean_interval[:-1], width=2*np.pi/num_phase,
    color='#32B6F7',
    align='center',
    bottom=0,
    alpha=0.5,
    label='Reaction Time (ms)'
)

ax.fill_between(phases_eff, mean_interval - std_interval, mean_interval + std_interval, color='gray', alpha=0.3, label='Mean ± Std')
ax.legend(loc='upper right', fontsize=12)

In [None]:


# Single-run curves. These use the current values of interval_sti_peak_go and
# peak_go_values, i.e. those of the last folder processed by the most recently
# executed loading loop.
for y_values, y_label, title in (
        (interval_sti_peak_go, 'Time  (ms)', 'Reaction time at different phases'),
        (peak_go_values, 'Peak value of the output', 'Peak energy for different phases')):
    plt.figure(figsize=(12, 8))
    plt.plot(phases_eff, y_values, 'o-')
    plt.xlabel('Phase')
    plt.ylabel(y_label)
    plt.title(title)
    plt.show()


In [None]:
# Mean +/- std (across runs) of the no-go peak output energy versus phase.
ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
labels = [r"$-\pi$", r"$-\frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]

plt.figure(figsize=(10, 5))
plt.errorbar(phases_eff, mean_peak_nogo, yerr=std_peak_nogo, fmt='o', capsize=5, ecolor=color_Nogo, color=color_Nogo)
plt.xlabel('Phase')
plt.ylabel('Peak Energy')
plt.xticks(ticks, labels)
plt.show()

In [None]:
# Save the per-phase summary statistics: one single-row CSV file per array,
# named after the variable (e.g. mean_peak -> ./data_phase_sensitivity/mean_peak.csv).
import csv
import os

folder = './data_phase_sensitivity'
os.makedirs(folder, exist_ok=True)
summary_arrays = {
    'phases_eff': phases_eff,
    'std_interval': std_interval,
    'std_peak': std_peak,
    'mean_interval': mean_interval,
    'mean_peak': mean_peak,
    'std_interval_nogo': std_interval_nogo,
    'std_peak_nogo': std_peak_nogo,
    'mean_interval_nogo': mean_interval_nogo,
    'mean_peak_nogo': mean_peak_nogo,
}
for name, values in summary_arrays.items():
    filename = folder + '/' + name + '.csv'
    # newline='' as the csv module requires; otherwise rows get extra blank
    # lines on Windows
    with open(filename, mode='w', newline='') as csv_file:
        csv.writer(csv_file).writerow(values)





In [None]:
# Single-trial variant used for the per-state comparison below.
# NOTE: this redefines (and shadows) the four-argument reaction_time_amplitude.
def reaction_time_amplitude( Out_go_rec, Input_go_rec, dt):
    """Return the stimulus-to-peak latency and peak output energy of one trial.

    The output energy is ``Out_go_rec**2`` after squeezing singleton
    dimensions. The latency is the time of the energy maximum minus the
    stimulus onset. The onset is the first time step at which any neuron
    receives non-zero input in ``Input_go_rec``.

    Args:
        Out_go_rec: output of one trial; must squeeze to a 1-D time series.
        Input_go_rec: input of the same trial, shape (N, time).
        dt: simulation time step (ms).

    Returns:
        (interval_sti_peak_go, peak_go_values): two one-element lists with the
        latency (ms) and the peak energy.

    Raises:
        ValueError: if the input contains no stimulus at all.
    """
    go_signal_energy = np.asarray(Out_go_rec).squeeze()**2
    Input_go = np.asarray(Input_go_rec)
    # time at which the output energy reaches its peak
    peak_time_go = np.argmax(go_signal_energy)*dt
    peak_value_go = np.max(go_signal_energy)
    # Onset over all neurons, not just neuron 0, so a zero stimulus entry for
    # the first neuron cannot break the lookup.
    active_steps = np.flatnonzero(np.any(Input_go != 0, axis=0))
    if active_steps.size == 0:
        raise ValueError("no stimulus found in Input_go_rec")
    sti_start_time = active_steps[0]*dt
    return [peak_time_go - sti_start_time], [peak_value_go]


# Load every run of the gamma-oscillation state and collect the peak output
# energy for the go and the no-go stimulus.
peak_rec_gamma = []
peak_rec_nogo_gamma = []
dt = 0.1  # NOTE(review): hard-coded; must match the dt the data were generated with

data_dir = '/mnt/SanDisk/Li/LowRank_ModifiedTheta_SNN/PYTHON/data_peak_energy_gamma/'  # please change to your path
required_files = ('Input_go_rec.npy', 'Out_go_rec.npy', 'Out_nogo_rec.npy')
for folder in sorted(os.listdir(data_dir)):
    path_folder = os.path.join(data_dir, folder)
    if not os.path.isdir(path_folder):
        continue
    paths = [os.path.join(path_folder, name) for name in required_files]
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        # skip incomplete runs instead of silently reusing the previous folder's arrays
        print(f'skipping {path_folder}: missing {missing}')
        continue
    Input_go_rec, Out_go_rec, Out_nogo_rec = (np.load(p) for p in paths)
    interval_sti_peak_go, peak_go_values = reaction_time_amplitude(Out_go_rec, Input_go_rec, dt)
    # go and no-go stimuli share the same time window, so the go input gives the onset
    interval_sti_peak_nogo, peak_nogo_values = reaction_time_amplitude(Out_nogo_rec, Input_go_rec, dt)
    peak_rec_nogo_gamma.append(peak_nogo_values)
    peak_rec_gamma.append(peak_go_values)

In [None]:

# Load every run of the stationary state and collect the peak output energy
# for the go and the no-go stimulus.
peak_rec_sta = []
peak_rec_nogo_sta = []
dt = 0.1  # NOTE(review): hard-coded; must match the dt the data were generated with

data_dir = '/mnt/SanDisk/Li/LowRank_ModifiedTheta_SNN/PYTHON/data_peak_energy_stationary/'  # please change to your path
required_files = ('Input_go_rec.npy', 'Out_go_rec.npy', 'Out_nogo_rec.npy')
for folder in sorted(os.listdir(data_dir)):
    path_folder = os.path.join(data_dir, folder)
    if not os.path.isdir(path_folder):
        continue
    paths = [os.path.join(path_folder, name) for name in required_files]
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        # skip incomplete runs instead of silently reusing the previous folder's arrays
        print(f'skipping {path_folder}: missing {missing}')
        continue
    Input_go_rec, Out_go_rec, Out_nogo_rec = (np.load(p) for p in paths)
    interval_sti_peak_go, peak_go_values = reaction_time_amplitude(Out_go_rec, Input_go_rec, dt)
    # go and no-go stimuli share the same time window, so the go input gives the onset
    interval_sti_peak_nogo, peak_nogo_values = reaction_time_amplitude(Out_nogo_rec, Input_go_rec, dt)
    peak_rec_nogo_sta.append(peak_nogo_values)
    peak_rec_sta.append(peak_go_values)

In [None]:

# Load every run of the low-frequency-oscillation state and collect the peak
# output energy for the go and the no-go stimulus.
peak_rec_lowfre = []
peak_rec_nogo_lowfre = []
dt = 0.1  # NOTE(review): hard-coded; must match the dt the data were generated with

data_dir = '/mnt/SanDisk/Li/LowRank_ModifiedTheta_SNN/PYTHON/data_peak_energy_lowfrequency/'  # please change to your path
required_files = ('Input_go_rec.npy', 'Out_go_rec.npy', 'Out_nogo_rec.npy')
for folder in sorted(os.listdir(data_dir)):
    path_folder = os.path.join(data_dir, folder)
    if not os.path.isdir(path_folder):
        continue
    paths = [os.path.join(path_folder, name) for name in required_files]
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        # skip incomplete runs instead of silently reusing the previous folder's arrays
        print(f'skipping {path_folder}: missing {missing}')
        continue
    Input_go_rec, Out_go_rec, Out_nogo_rec = (np.load(p) for p in paths)
    interval_sti_peak_go, peak_go_values = reaction_time_amplitude(Out_go_rec, Input_go_rec, dt)
    # go and no-go stimuli share the same time window, so the go input gives the onset
    interval_sti_peak_nogo, peak_nogo_values = reaction_time_amplitude(Out_nogo_rec, Input_go_rec, dt)
    peak_rec_nogo_lowfre.append(peak_nogo_values)
    peak_rec_lowfre.append(peak_go_values)

In [None]:
# Bar plot comparing the peak output energy (mean +/- std across runs) of the
# go and no-go stimulus in the three network states.
import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(10, 5))

# statistics across runs (axis 0)
mean_peak_rec_gamma, std_peak_rec_gamma = np.mean(peak_rec_gamma, axis=0), np.std(peak_rec_gamma, axis=0)
mean_peak_rec_nogo_gamma, std_peak_rec_nogo_gamma = np.mean(peak_rec_nogo_gamma, axis=0), np.std(peak_rec_nogo_gamma, axis=0)

mean_peak_rec_sta, std_peak_rec_sta = np.mean(peak_rec_sta, axis=0), np.std(peak_rec_sta, axis=0)
mean_peak_rec_nogo_sta, std_peak_rec_nogo_sta = np.mean(peak_rec_nogo_sta, axis=0), np.std(peak_rec_nogo_sta, axis=0)

mean_peak_rec_lowfre, std_peak_rec_lowfre = np.mean(peak_rec_lowfre, axis=0), np.std(peak_rec_lowfre, axis=0)
mean_peak_rec_nogo_lowfre, std_peak_rec_nogo_lowfre = np.mean(peak_rec_nogo_lowfre, axis=0), np.std(peak_rec_nogo_lowfre, axis=0)

bar_width = 0.25
x = np.arange(len(mean_peak_rec_gamma))

# Bars are placed side by side in this order; no legend or tick labels are shown.
bar_specs = [
    (mean_peak_rec_gamma, std_peak_rec_gamma, 'Gamma', 'blue'),
    (mean_peak_rec_nogo_gamma, std_peak_rec_nogo_gamma, 'Nogo Gamma', 'orange'),
    (mean_peak_rec_sta, std_peak_rec_sta, 'STA', 'green'),
    (mean_peak_rec_nogo_sta, std_peak_rec_nogo_sta, 'Nogo STA', 'red'),
    (mean_peak_rec_lowfre, std_peak_rec_lowfre, 'Low Frequency', 'purple'),
    (mean_peak_rec_nogo_lowfre, std_peak_rec_nogo_lowfre, 'Nogo Low Frequency', 'pink'),
]
for offset, (heights, errors, label, color) in enumerate(bar_specs):
    plt.bar(x + offset*bar_width, heights, bar_width, yerr=errors, label=label, color=color, alpha=0.7)

plt.xlabel('States')
plt.ylabel('Peak Energy (A.U.)')
plt.title('Peak Energy for Different States')
plt.show()



In [None]:

# Gamma-oscillation state: go vs no-go peak energy (mean +/- std across runs).
bar_width = 0.25
plt.figure(figsize=(12, 8))
for offset, (heights, errors, label, color) in enumerate((
        (mean_peak_rec_gamma, std_peak_rec_gamma, 'Gamma', color_Go),
        (mean_peak_rec_nogo_gamma, std_peak_rec_nogo_gamma, 'Nogo Gamma', color_Nogo))):
    plt.bar(x + offset*bar_width, heights, bar_width, yerr=errors, label=label, color=color, alpha=0.7, capsize=5)
plt.ylabel('Peak Energy')
plt.title('Gamma Oscillation State')
plt.xticks([0, bar_width], ['go', 'nogo'])
plt.xlim(-0.25, 0.5)
plt.show()

In [None]:
# Stationary state: go vs no-go peak energy (mean +/- std across runs).
plt.figure(figsize=(12, 8))
for offset, (heights, errors, label, color) in enumerate((
        (mean_peak_rec_sta, std_peak_rec_sta, 'STA', color_Go),
        (mean_peak_rec_nogo_sta, std_peak_rec_nogo_sta, 'Nogo STA', color_Nogo))):
    plt.bar(x + offset*bar_width, heights, bar_width, yerr=errors, label=label, color=color, alpha=0.7, capsize=5)
plt.ylabel('Peak Energy')
plt.title('Stationary State')
plt.xticks([0, bar_width], ['go', 'nogo'])
plt.xlim(-0.25, 0.5)
plt.show()


In [None]:
# Low-frequency-oscillation state: go vs no-go peak energy (mean +/- std across runs).
plt.figure(figsize=(12, 8))
for offset, (heights, errors, label, color) in enumerate((
        (mean_peak_rec_lowfre, std_peak_rec_lowfre, 'Low Frequency', color_Go),
        (mean_peak_rec_nogo_lowfre, std_peak_rec_nogo_lowfre, 'Nogo Low Frequency', color_Nogo))):
    plt.bar(x + offset*bar_width, heights, bar_width, yerr=errors, label=label, color=color, alpha=0.7, capsize=5)
plt.ylabel('Peak Energy')
plt.title('Low frequency oscillation state')
plt.xticks([0, bar_width], ['go', 'nogo'])
plt.xlim(-0.25, 0.5)
plt.show()
