In [None]:
import functools
import os
import time

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
import numpy as np

import Neuron_models as NM
import calculation as cal
# Select the GPU for this run: enumerate devices by PCI bus ID and expose only device 7.
# NOTE(review): these variables are set after brainpy has been imported; they only take
# effect if no GPU backend has been initialised yet — confirm, or set them before the imports.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "7"})
# Configure BrainPy's backend: presumably stops up-front allocation of all GPU memory,
# then selects the GPU platform.
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')

In [None]:
# Synthetic sparse-recovery problem: recover the sparse vector x0 from b_np = A x0.
rng = np.random.RandomState(28)  # fixed seed for reproducibility
m, n = 7500, 10000  # rows x columns of A (a smaller setting that was also used: 3000, 5024)

# Gaussian random design, each column rescaled to unit L2 norm.
A = rng.randn(m, n)
A_norm = np.linalg.norm(A, ord=2, axis=0, keepdims=True)
A = A / A_norm

# Ground truth: uniform draws with every entry below 0.9 zeroed (about 10% non-zeros).
x0 = rng.rand(n)
x0[x0 < 0.9] = 0
b_np = np.dot(A, x0)

# Network inputs: feed-forward drive A^T b_np and lateral weights A^T A
# with the diagonal (self-connections) removed.
b = np.dot(A.T, b_np)
w = np.dot(A.T, A)
np.fill_diagonal(w, 0)

# Move the network inputs into BrainPy arrays.
w, b = bm.array(w), bm.array(b)
l = 0.1  # regularization parameter


In [None]:
def measure_execution_time(func):
    """Decorator that times every call of ``func``.

    The wrapped function prints the elapsed wall-clock time and returns a
    ``(result, execution_time)`` tuple, where ``result`` is whatever ``func``
    returned and ``execution_time`` is in seconds.
    """
    @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the system clock is adjusted during a long run.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        print("Execution Time: {:.6f} seconds".format(execution_time))
        return result, execution_time
    return wrapper


@measure_execution_time
def measure_runner_time(runner, total_period):
    """Call ``runner.run(total_period)``.

    Because of the decorator, calling this returns ``(None, seconds)``.
    """
    runner.run(total_period)

In [None]:
# Build one SLCA network per neuron model on the same problem (w, b, l)
# and time a full simulation of each.
pars = (w.shape[0], w, b, l)
runner_pars = dict(monitors=['N.spike'], dt=0.01)

# NOTE: the variable names are historical and do not all match the model used:
# `net_LIF` is built from NM.SLCA_rk2 (alternative: NM.SLCA_IF) and `net_Izh`
# from NM.SLCA_GIF (alternative: NM.SLCA_Izh). Later cells rely on these names.
net_LIF = NM.SLCA_rk2(*pars)
net_Izh = NM.SLCA_GIF(*pars)
net_ML = NM.SLCA_ML(*pars)
net_WB = NM.SLCA_WB(*pars)

total_period = 200
runner_LIF = bp.DSRunner(net_LIF, **runner_pars)
runner_Izh = bp.DSRunner(net_Izh, **runner_pars)
runner_ML = bp.DSRunner(net_ML, **runner_pars)
runner_WB = bp.DSRunner(net_WB, **runner_pars)

# Each timing covers the first run() call of its runner.
# NOTE(review): that presumably includes JIT compilation — confirm this is intended.
_, LIF_time = measure_runner_time(runner_LIF, total_period)
_, Izh_time = measure_runner_time(runner_Izh, total_period)
_, ML_time = measure_runner_time(runner_ML, total_period)
_, WB_time = measure_runner_time(runner_WB, total_period)


In [None]:
# Decode firing rates at `point_number` checkpoints to track how each
# network's solution estimate evolves as simulated time grows.
point_number = 15
# Simulation steps per unit of simulated time; round() guards against 1/dt
# evaluating to just below an integer in floating point.
time_interval = int(round(1.0 / runner_LIF.dt))
total_steps = int(total_period) * time_interval
# Integer checkpoint lengths, so the slice length and the divisor below agree.
slice_value = np.linspace(1, total_steps, point_number).astype(int)

list_value = {}
for i, runner_method in enumerate([runner_LIF, runner_Izh, runner_ML, runner_WB]):
    # Convert the spike monitor to a NumPy array once, not once per checkpoint.
    spikes = np.asarray(runner_method.mon['N.spike'])
    # Per-neuron rate = spike count over the first `steps` steps / steps * steps-per-unit-time.
    # Summing before scaling avoids materialising a (steps, neurons) float array per checkpoint.
    list_value[f"list_value_{i}"] = [
        spikes[:steps, :].sum(axis=0) * (time_interval / steps)
        for steps in slice_value
    ]

In [None]:
def _nmse_curve(rate_estimates):
    """Return cal.calculate_nmse(x0, x) for every intermediate estimate x, as an array."""
    return np.array([cal.calculate_nmse(x0, x) for x in rate_estimates])


nmse_LIF = _nmse_curve(list_value['list_value_0'])
nmse_Izh = _nmse_curve(list_value['list_value_1'])
nmse_ML = _nmse_curve(list_value['list_value_2'])
nmse_WB = _nmse_curve(list_value['list_value_3'])

# Spread each model's total run time linearly over the checkpoints.
# np.max() keeps this cell idempotent. On the first run, *_time holds the
# scalar total from the timing cell. On a re-run, it already holds the array
# built here, whose maximum (last element) is that same total.
LIF_time = np.linspace(0, np.max(LIF_time), point_number)
Izh_time = np.linspace(0, np.max(Izh_time), point_number)
ML_time = np.linspace(0, np.max(ML_time), point_number)
WB_time = np.linspace(0, np.max(WB_time), point_number)

In [None]:
# Baseline: solve the same problem with FISTA from calculation.py.
# NOTE(review): 1600 is presumably the iteration budget — confirm against cal.fista.
fista_value, L2_error, time_consume = cal.fista(
    A, b_np, x0, l,
    1600,
)

In [None]:
# One colour and marker per curve: FISTA first, then the four SLCA variants.
colors = ['#2878b5', '#c82423', '#800080', '#006400', '#000000']
markers = ['*', 'o', 's', 'D', 'X']

# Global matplotlib styling. Each key appears exactly once: a dict literal
# silently keeps only the last value for a duplicated key.
plt_style = {
    'figure.autolayout': True,
    'font.size': 18,
    'lines.linewidth': 2,
    'lines.markersize': 8,
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'xtick.major.size': 6,
    'ytick.major.size': 6,
    'legend.fontsize': 16,
    'axes.labelsize': 24,
    'axes.titlesize': 24,
}
plt.rcParams.update(plt_style)
    
# Subsample the FISTA trace to roughly `point_number` markers. max(1, ...)
# prevents a zero slice step (ValueError) when the trace is shorter than that.
interval = max(1, len(L2_error) // point_number)

plt.plot(time_consume[::interval], L2_error[::interval],
         color=colors[0], marker=markers[0], markerfacecolor='none', label='FISTA')

# (time axis, NMSE curve, legend label) for each SLCA variant, in colour/marker order.
slca_curves = [
    (LIF_time, nmse_LIF, 'SLCA-LIF'),
    (Izh_time, nmse_Izh, 'SLCA-GIF'),  # net_Izh is built from NM.SLCA_GIF
    (ML_time, nmse_ML, 'SLCA-ML'),
    (WB_time, nmse_WB, 'SLCA-WB'),
]
for idx, (curve_time, curve_nmse, curve_label) in enumerate(slca_curves, start=1):
    plt.plot(curve_time, curve_nmse, color=colors[idx], marker=markers[idx],
             markerfacecolor='none', label=curve_label)

plt.legend()
plt.xlim([0, 15])
plt.xlabel('Execution Time (s)')
plt.ylabel('NMSE (dB)')

# savefig fails if the output directory does not exist.
os.makedirs('Figure', exist_ok=True)
plt.savefig('Figure/mult_neurons.pdf', format='pdf')
# When format= is given, savefig uses fname verbatim and appends no extension,
# so spell out ".png".
plt.savefig('Figure/multi_neurons.png', format='png')

In [None]:
# Persist every plotted curve so Fig. 2b can be regenerated without re-running the simulations.
os.makedirs('Figure', exist_ok=True)  # do not rely on an earlier cell having created it
np.savez(
    'Figure/Fig2b.npz',
    time_consume=time_consume, L2_error=L2_error,
    LIF_time=LIF_time, nmse_LIF=nmse_LIF,
    Izh_time=Izh_time, nmse_Izh=nmse_Izh,
    ML_time=ML_time, nmse_ML=nmse_ML,
    WB_time=WB_time, nmse_WB=nmse_WB,
)