In [None]:
import importlib
import random as rd

import matplotlib.pyplot as plt
import numba
import numpy as np

import Rebound_WTA_lib as WTA
importlib.reload(WTA)

## Fig 1 — Rebound spike of a single Hodgkin–Huxley neuron after a hyperpolarizing current pulse

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

# --- Hodgkin-Huxley model ----------------------------------------------------
# Membrane capacitance (uF/cm^2)
C = 1.0
# Maximal conductances (mS/cm^2): sodium, potassium, leak
gNa, gK, gL = 120.0, 36.0, 0.3
# Reversal potentials (mV): sodium, potassium, leak
ENa, EK, EL = 50.0, -77.0, -54.387


# Gate transition rates. V is in mV and time in ms, so rates are in 1/ms.

def _linoid_rate(V, offset, gain, limit):
    """Return gain * (V + offset) / (1 - exp(-(V + offset) / 10)).

    The expression is 0/0 at V = -offset; within 1e-6 mV of that point the
    analytic limit `limit` (which equals 10 * gain) is returned instead.
    """
    shifted = V + offset
    if abs(shifted) < 1e-6:
        return limit
    return gain * shifted / (1 - np.exp(-shifted / 10))


def _exp_rate(V, gain, scale):
    """Return gain * exp(-(V + 65) / scale)."""
    return gain * np.exp(-(V + 65) / scale)


def alpha_m(V):
    """Opening rate of the Na gate m."""
    return _linoid_rate(V, 40, 0.1, 1.0)


def beta_m(V):
    """Closing rate of the Na gate m."""
    return _exp_rate(V, 4.0, 18)


def alpha_h(V):
    """Opening rate of the Na gate h."""
    return _exp_rate(V, 0.07, 20)


def beta_h(V):
    """Closing rate of the Na gate h (sigmoid in V)."""
    return 1.0 / (1 + np.exp(-(V + 35) / 10))


def alpha_n(V):
    """Opening rate of the K gate n."""
    return _linoid_rate(V, 55, 0.01, 0.1)


def beta_n(V):
    """Closing rate of the K gate n."""
    return _exp_rate(V, 0.125, 80)


def diff(t, x, u_func):
    """Right-hand side of the HH ODE system, in the form solve_ivp expects.

    Parameters
    ----------
    t : float
        Time (ms).
    x : sequence of 4 floats
        State (V, m, h, n).
    u_func : callable
        Maps t to the injected current (uA/cm^2).

    Returns
    -------
    list
        [dV/dt, dm/dt, dh/dt, dn/dt].
    """
    V, m, h, n = x
    i_ext = u_func(t)
    i_na = gNa * m**3 * h * (V - ENa)
    i_k = gK * n**4 * (V - EK)
    i_leak = gL * (V - EL)
    dVdt = (i_ext - i_na - i_k - i_leak) / C
    # Each gate relaxes as alpha * (1 - g) - beta * g
    gate_rates = [alpha(V) * (1 - g) - beta(V) * g
                  for alpha, beta, g in ((alpha_m, beta_m, m),
                                         (alpha_h, beta_h, h),
                                         (alpha_n, beta_n, n))]
    return [dVdt] + gate_rates


def u_func(t):
    """Injected current (uA/cm^2): a -20 pulse for 10 <= t < 30 ms, otherwise 0 (inhibition)."""
    return -20.0 if 10 <= t < 30 else 0.0

# Simulation window (ms) and the grid on which the solution is reported
t_start = 0
t_end = 45
t_eval = np.linspace(t_start, t_end, 30001)

# Start from rest at V = -65 mV with every gate at its steady state
# x_inf = alpha / (alpha + beta).
V0 = -65.0
m0 = alpha_m(V0) / (alpha_m(V0) + beta_m(V0))
h0 = alpha_h(V0) / (alpha_h(V0) + beta_h(V0))
n0 = alpha_n(V0) / (alpha_n(V0) + beta_n(V0))
x0 = [V0, m0, h0, n0]

# u_func jumps at 10 ms and 30 ms. Capping the step size keeps the adaptive
# integrator from straddling a jump with one long step. `args` forwards
# u_func as the third argument of diff(t, x, u_func).
sol = solve_ivp(diff, (t_start, t_end), x0, t_eval=t_eval,
                args=(u_func,), max_step=0.1)
if not sol.success:
    # Fail loudly instead of plotting a truncated solution
    raise RuntimeError(f"HH integration failed: {sol.message}")

# State trajectories and the resulting ionic currents (uA/cm^2)
V, m, h, n = sol.y
INa = gNa * m**3 * h * (V - ENa)
IK = gK * n**4 * (V - EK)

# Font sizes for Fig 1 (caption_size is also reused by the Fig 2 cell)
caption_size = 20
title_size = 20
axis_size = 16

# Stage boundaries (ms) drawn on the two upper panels: pulse on at 10,
# pulse off at 30, and the rebound / spike boundary at 37.
STAGE_BOUNDARIES = (10, 30, 37)
STAGE_CAPTIONS = ("0\n free state", "1\n inhibition", "2\n rebound", "3\n spike & \n reset")


def _mark_stages(ax, y_frac, label_xs):
    """Draw the dashed stage boundaries and the four stage captions on `ax`.

    Captions are placed at ``ymin + y_frac * (ymax - ymin)`` in data
    coordinates, using the y-limits at call time, so call this after the
    y-scale and data are set. The last (three-line) caption uses 90% of
    ``caption_size``.
    """
    for x_b in STAGE_BOUNDARIES:
        ax.axvline(x=x_b, color='black', linestyle='--', linewidth=2)
    y_lo, y_hi = ax.get_ylim()
    y_txt = y_lo + y_frac * (y_hi - y_lo)
    sizes = (caption_size, caption_size, caption_size, caption_size * 0.9)
    for x_txt, caption, size in zip(label_xs, STAGE_CAPTIONS, sizes):
        ax.text(x_txt, y_txt, caption, fontsize=size, color='black', ha='center')


fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 10),
                                    gridspec_kw={'height_ratios': [2, 2, 1]})

# Top: membrane voltage (time is labelled only on the bottom panel)
ax1.plot(sol.t, V, linewidth=2, color='blue')
ax1.set_ylabel("(mV)", fontsize=caption_size)
ax1.set_title("Hodgkin–Huxley Model: Membrane Voltage", fontsize=title_size)
_mark_stages(ax1, 0.4, (5, 20, 33.3, 43))

# Middle: Na and K currents on a symmetric-log axis (legend needed for two traces)
ax2.plot(sol.t, INa, linewidth=2, color='blue', label="I_Na")
ax2.plot(sol.t, IK, linewidth=2, color='red', label="I_K")
ax2.set_ylabel("(µA/cm²)", fontsize=caption_size)
ax2.set_yscale('symlog', linthresh=1)
ax2.set_title("Hodgkin–Huxley Model: INa and IK Currents", fontsize=title_size)
ax2.legend(fontsize=caption_size)
_mark_stages(ax2, 0.3, (5, 20, 33.3, 43.5))

# Bottom: injected current (single trace, no legend)
u_values = np.array([u_func(t) for t in sol.t])
ax3.plot(sol.t, u_values, linewidth=2, color='blue')
ax3.set_xlabel("Time (ms)", fontsize=caption_size)
ax3.set_ylabel("(µA/cm²)", fontsize=caption_size)
ax3.set_title("Hodgkin–Huxley Model: External Current", fontsize=title_size)

plt.tight_layout()
plt.savefig('rebound_spike.png', dpi=300)
plt.show()


## Fig 2 — Two-neuron HCO of Hodgkin–Huxley neurons: voltage traces, raster, and per-neuron stages

In [None]:
# Simulation grid: 220 ms in 22000 forward-Euler steps (dt = 0.01 ms)
Time = 220
Num_sample = 22000
dt = Time / Num_sample

num_neuron = 2
inhibit_w = 10

# State: one row per neuron, 6 state variables each (column 0 is the membrane
# voltage plotted below). Neuron 1 starts at -75 mV, neuron 2 at 0 mV, and all
# other state variables start at 0.
x0 = np.zeros((num_neuron, 6))
x0[0, 0] = -75
x0[1, 0] = 0
x = np.copy(x0)
outputs = [x0]
syn_current = [0]

for i in range(Num_sample):
    u = np.zeros(num_neuron)  # no external drive
    dx = WTA.ring_ss_hh_center(num_neuron, x, u, syn_strength=0.0, noise=0,
                               inhibit_weight=inhibit_w, current=-1.0)
    x = x + dx * dt  # forward-Euler step (creates a new array, so outputs entries are not aliased)
    # Synaptic current term: WTA.Syn_hh on neuron 1's state column 4, scaled by inhibit_w
    syn_current.append(WTA.Syn_hh(x[0, 4], -1, -65) * inhibit_w)
    outputs.append(x)

outputs_HCO = np.array(outputs)
syn_current_HCO = np.array(syn_current)

In [None]:
# Fig 2 layout (row-major): top-left = both voltage traces, top-right = neuron 1
# zoom, bottom-left = voltage raster, bottom-right = neuron 2 zoom.
fig, axs = plt.subplots(2, 2, figsize=(6, 6))
ax1, ax2, ax3, ax4 = axs.flatten()

# outputs_HCO[k] is the state after k Euler steps, i.e. at time k * dt (ms)
t_HCO = np.arange(len(outputs_HCO)) * dt

# --- Top-left: voltage traces of both neurons ---
ax1.plot(t_HCO, outputs_HCO[:, :, 0])
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane Voltage (mV)')
ax1.set_title('Voltage Trajectories of \n Neurons in HCO')
ax1.legend([f'Neuron {i+1}' for i in range(2)], loc='best')

# --- Bottom-left: voltage raster ---
data = outputs_HCO[:, :, 0]
im = ax3.imshow(data.T, aspect='auto', interpolation='nearest',
                cmap='inferno', extent=[0, data.shape[0] * dt, 0, data.shape[1]])
# imshow draws row 0 (neuron 1) at the top, so tick labels run bottom-up
neuron_names = ['Neuron 2', 'Neuron 1']
ax3.set_yticks(np.arange(0.5, data.shape[1], 1))  # centre of each row
ax3.set_yticklabels(neuron_names)
ax3.set_xlabel('Time (ms)')
ax3.set_ylabel('Neuron')
ax3.set_title('Voltage Raster Plot of \n Neurons in a HCO')
fig.colorbar(im, ax=ax3)

# --- Right column: zoom on one neuron with its stage boundaries ---
zoom = slice(1000, 7000)          # samples 1000..6999, i.e. 10-70 ms
STAGE_EDGES = (33, 39, 45, 51)    # boundaries (ms); each stage label sits 3 ms after its boundary
STAGE_COLORS = {0: 'brown', 1: 'grey', 2: 'red', 3: 'green'}


def _plot_stage_trace(ax, trace, stages, title, **line_kw):
    """Plot `trace` over the zoom window and mark the stage boundaries on `ax`.

    stages[k] is the stage number (0-3) labelled just after STAGE_EDGES[k];
    the boundary line and its label are both coloured by that stage.
    Extra keyword arguments are passed to ax.plot.
    """
    ax.plot(t_HCO[zoom], trace, **line_kw)
    for edge, stage in zip(STAGE_EDGES, stages):
        ax.axvline(x=edge, color=STAGE_COLORS[stage], linestyle='--')
    y_lo, y_hi = ax.get_ylim()
    y_txt = y_lo + 0.07 * (y_hi - y_lo)  # labels near the bottom of the panel
    for edge, stage in zip(STAGE_EDGES, stages):
        ax.text(edge + 3, y_txt, str(stage), fontsize=caption_size * 0.5,
                color=STAGE_COLORS[stage], ha='center')
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Membrane Voltage (mV)')
    ax.set_title(title)


# Both neurons run through the same stage sequence, offset by two stages
_plot_stage_trace(ax2, outputs_HCO[zoom, 0, 0], (3, 0, 1, 2),
                  'Voltage Trajectories of \n Neuron1 in HCO')
_plot_stage_trace(ax4, outputs_HCO[zoom, 1, 0], (1, 2, 3, 0),
                  'Voltage Trajectories of \n Neuron2 in HCO', color='darkorange')

# NOTE: syn_current_HCO (computed in the simulation cell) is not plotted in this figure.
plt.tight_layout()
plt.savefig('HCO_v_syn.png', dpi=300)
plt.show()

## Fig 4 — Five-neuron ring oscillators: Hodgkin–Huxley vs. Ribar–Sepulchre neurons

In [None]:
# Simulation grid: 220 ms in 22000 forward-Euler steps (dt = 0.01 ms)
Time = 220
Num_sample = 22000
dt = Time / Num_sample

num_neuron = 5
inhibit_w = 10

# Neurons 1-4 start at -75 mV, neuron 5 at 0 mV; all other state variables start at 0.
x0 = np.zeros((num_neuron, 6))
x0[:4, 0] = -75
x = np.copy(x0)
outputs = [x0]

for i in range(Num_sample):
    u = np.zeros(num_neuron)  # no external drive
    dx = WTA.ring_ss_hh(num_neuron, x, u, syn_strength=0.5, noise=0,
                        inhibit_weight=inhibit_w, current=-1.0)
    x = x + dx * dt  # forward-Euler step
    outputs.append(x)

In [None]:
# Raster of the 5-neuron HH ring; the first plot_start samples (10 ms) are skipped
plot_start = 1000
data = np.array(outputs)[plot_start:, :, 0]
# Reverse neuron order so that neuron 1 ends up in the bottom row of the image
data = np.flip(data, -1)
# The time axis starts at the first plotted sample, not at 0
t_first = plot_start * dt
t_last = (plot_start + data.shape[0]) * dt

fig, ax = plt.subplots()  # explicit figure: never draw onto a leftover one
im = ax.imshow(data.T, aspect='auto', interpolation='none', cmap='Greys',
               extent=[t_first, t_last, 0, data.shape[1]])
ax.set_yticks(np.arange(0.5, data.shape[1], 1))
ax.set_yticklabels([f'Neuron {k + 1}' for k in range(data.shape[1])])
ax.set_xlabel('Time(ms)')
ax.set_ylabel('Neuron')
ax.set_title('Ring Oscillator consisting of 5 Hodgkin Huxley Neurons')
fig.colorbar(im, ax=ax, label='Membrane Voltage (mV)')

fig.savefig('HH_ring_raster.png')
plt.show()


In [None]:
# Simulation grid: 500 ms in 50000 forward-Euler steps (dt = 0.01 ms)
Time = 500
Num_sample = 50000
dt = Time / Num_sample

num_neuron = 5

# NOTE: this run is stochastic (random initial state and input, no seed);
# call np.random.seed(...) beforehand for a reproducible figure.
# State: one row per neuron, 3 state variables each, drawn uniformly in [0, 1).
x0 = np.random.rand(num_neuron, 3)
x = np.copy(x0)
outputs = [x0]

for i in range(Num_sample):
    u = np.zeros(num_neuron)
    # Small uniform random input in [0, 0.1) on every neuron at each step
    u_noisy = np.random.rand(num_neuron) * 0.1 + u
    dx = WTA.ring_ss_Luka(num_neuron, x, u_noisy, 0.3, noise=0.1,
                          inhibit_weight=5, current=1)
    x = x + dx * dt  # forward-Euler step
    outputs.append(x)

In [None]:
# Raster of the 5-neuron Ribar-Sepulchre ring; the first plot_start samples (10 ms) are skipped
plot_start = 1000
data = np.array(outputs)[plot_start:, :, 0]
# Reverse neuron order so that neuron 1 ends up in the bottom row of the image
data = np.flip(data, -1)
# The time axis starts at the first plotted sample, not at 0
t_first = plot_start * dt
t_last = (plot_start + data.shape[0]) * dt

fig, ax = plt.subplots()  # explicit figure: never draw onto a leftover one
im = ax.imshow(data.T, aspect='auto', interpolation='none', cmap='Greys',
               extent=[t_first, t_last, 0, data.shape[1]])
ax.set_yticks(np.arange(0.5, data.shape[1], 1))
ax.set_yticklabels([f'Neuron {k + 1}' for k in range(data.shape[1])])
ax.set_xlabel('Time(ms)')
ax.set_ylabel('Neuron')
ax.set_title('Ring Oscillator consisting of 5 Ribar-Sepulchre Neurons')
fig.colorbar(im, ax=ax, label='Membrane Voltage (mV)')

fig.savefig('Luka_ring_raster.png')
plt.show()

## Fig 5 — HH ring oscillator driven by a periodic input (entrainment and current/frequency control)

In [None]:
# Simulation grid: 6000 ms in 600000 forward-Euler steps (dt = 0.01 ms)
Time = 6000
Num_sample = 600000
dt = Time / Num_sample

num_neuron = 5

# Neurons 2-4 start at -75 mV; neurons 1 and 5, and all other state
# variables, start at 0.
x0 = np.zeros((num_neuron, 6))
x0[1:4, 0] = -75

# Parameters forwarded to simulation_optimized / WTA.ss_hh_topology
noise = 0.1
current = 2
i_shift = -65
e_shift = 10
inhibition_matrix, excitation_matrix = WTA.ring_topology_gen(num_neuron, 1, 15)

print(inhibition_matrix)
print(excitation_matrix)

In [None]:
@numba.njit
def simulation_optimized(num_neuron, Num_sample, dt, x0, excitation_matrix, inhibition_matrix, noise, current, i_shift,e_shift):
    """Forward-Euler simulation of the coupled HH network with an event-driven controller.

    Schedule (time i*dt): network coupling is enabled after 100, the rhythmic
    drive after 1100 (``ent_control``), and the current controller after
    2100 (``fre_control``).

    The rhythm is ``sin(i*dt / 7)`` (period 14*pi, about 44 time units).
    A drive event fires the first step it exceeds 0.97 and re-arms once it
    drops below 0.95. While the event is latched the drive level is 40,
    otherwise -80. A spike event fires when neuron 1 (row 0) rises above
    -40 mV and re-arms below -45 mV.

    The difference (drive events - spike events) is accumulated with a slow
    leak into ``e_event``. ``e_event`` is integrated into the bias
    correction ``d_current``, which is clamped so that
    ``current + d_current`` stays within [-2, 3].

    Parameters
    ----------
    num_neuron : int
    Num_sample : int
        Number of Euler steps.
    dt : float
        Step size.
    x0 : (num_neuron, 6) array
        Initial state.
    excitation_matrix, inhibition_matrix : arrays
        Coupling matrices, zeroed until coupling is enabled, then passed to
        WTA.ss_hh_topology.
    noise, current, i_shift, e_shift :
        Forwarded to WTA.ss_hh_topology. ``current`` is the baseline bias,
        offset by ``d_current``.

    Returns
    -------
    outputs : (Num_sample + 1, num_neuron, 6) array
        State after each step (row 0 is x0).
    inputs : (Num_sample + 1, 1) array
        Recorded drive level: -80, or 40 while latched once the drive is on;
        row 0 is 0.
    d_current_path, e_event_path, i_e_event_path : (Num_sample + 1,) arrays
        Controller state traces.
    """
    # Controller state
    e_event = 0.0
    i_e_event = 0.0
    d_current = 0.0
    excitation_inhibition_control = 0
    event_counter = True   # armed for the next neuron-1 spike event
    event_counter1 = True  # armed for the next drive event

    outputs = np.empty((Num_sample + 1, num_neuron, 6))
    inputs = np.empty((Num_sample + 1, 1))
    d_current_path = np.empty(Num_sample + 1)
    e_event_path = np.empty(Num_sample + 1)
    i_e_event_path = np.empty(Num_sample + 1)

    d_current_path[0] = d_current
    i_e_event_path[0] = i_e_event
    e_event_path[0] = e_event
    outputs[0] = x0.copy()
    inputs[0, 0] = 0.0
    x = x0.copy()

    ent_control = 0.0
    fre_control = 0.0

    for i in range(Num_sample):
        # Schedule: drive after 1100, current controller after 2100
        if i * dt > 1100.0:
            ent_control = 1.0
        if i * dt > 2100.0:
            fre_control = 1.0
        # Network coupling is switched on after 100
        if i * dt > 100:
            excitation_inhibition_control = 1

        # External rhythm
        y = np.sin(i * dt / 7.0)

        event = 0   # 1 on the step a neuron-1 spike event fires
        event1 = 0  # 1 on the step a drive event fires

        # Spike detector with hysteresis (fire above -40 mV, re-arm below -45 mV)
        if event_counter and (x[0, 0] > -40.0):
            event = 1
            event_counter = False
        elif x[0, 0] < -45.0:
            event_counter = True

        # Drive-event detector with hysteresis (fire above 0.97, re-arm below 0.95)
        if event_counter1 and (y > 0.97):
            event1 = 1
            event_counter1 = False
        elif y < 0.95:
            event_counter1 = True

        # Leaky accumulation of (drive events - spike events); frozen until
        # fre_control is on. NOTE: e_event is left unclamped.
        delta_e = (((event1 - event) - e_event * dt) / 250.0) * fre_control
        e_event = e_event + delta_e

        # First-order filter of |e_event| (time constant 500). It is only
        # recorded; it does not feed back into the dynamics.
        i_e_event += (-i_e_event + 1000*abs(e_event)) * fre_control *dt/500

        # Integrate e_event into the bias correction, clamped so that
        # current + d_current stays within [-2, 3]
        temp_d = d_current + e_event * 2/250 * dt
        if temp_d > (3.0 - current):
            temp_d = 3.0 - current
        elif temp_d < (-2 - current):
            temp_d = -2 - current
        d_current = temp_d * fre_control

        # Drive level: 40 while the drive event is latched, otherwise -80
        u = np.zeros(num_neuron)
        temp_val = (int(not event_counter1) * 120.0) - 80.0
        # The drive is gated only by the schedule. NOTE(review): i_e_event does
        # not gate the drive; add an explicit |i_e_event| threshold here if one
        # was intended.
        factor = ent_control

        # Drive every neuron through WTA.syn_hh_numba ...
        syn_val = WTA.syn_hh_numba(temp_val, 0, -45.0)
        for j in range(num_neuron):
            u[j] += syn_val * factor
        # ... except neuron 1 (row 0), which uses second argument 2 instead of 0
        u[0] = WTA.syn_hh_numba(temp_val, 2, -45.0) * factor

        # Coupling matrices are multiplied by 0 until coupling is switched on
        dx = WTA.ss_hh_topology(num_neuron, x, u, excitation_matrix*excitation_inhibition_control, inhibition_matrix*excitation_inhibition_control,
                                  noise, current+d_current , i_shift,e_shift, 0.1, 0.1)
        x = x + dx * dt

        # Record the drive level (-80 until the drive is switched on) and controller state
        inputs[i + 1, 0] = (int(not event_counter1) * 120.0) * ent_control - 80.0
        outputs[i + 1] = x
        d_current_path[i + 1] = d_current
        e_event_path[i + 1] = e_event
        i_e_event_path[i + 1] = i_e_event

    return outputs, inputs, d_current_path, e_event_path,i_e_event_path

In [None]:
# Run the Fig 5 simulation; the excitatory weights are scaled by 10 for this run
outputs, inputs, d_current_path, e_event_path, i_e_event_path = simulation_optimized(
    num_neuron, Num_sample, dt, x0,
    excitation_matrix * 10,
    inhibition_matrix,
    noise, current, i_shift, e_shift,
)

In [None]:
# Fig 5 raster data: neurons 5..1 followed by the recorded drive, one column each.
# The initial state and the first start_time samples (100 ms) are dropped, so
# data[r] is the sample at absolute index r + first_index.
start_time = int(100 / dt)
first_index = start_time + 1
data1 = np.array(outputs)[first_index:, ::-1, 0]
data2 = inputs[first_index:]
data = np.concatenate((data1, data2), axis=1)

print(data.shape)  # (number of time samples, num_neuron + 1)

vmin_val = -90
vmax_val = np.max(data)

tick_size = 15
label_size = 20

# Absolute time windows (ms) of the four panels. In simulation_optimized the
# drive switches on at 1100 ms and the current controller at 2100 ms.
panel_windows_ms = [(200, 1100), (1100, 2000), (2900, 3800), (3800, 4700)]

# imshow puts column 0 (neuron 5) at the top and the drive at the bottom,
# so the row labels run bottom-up: Input, Neuron 1, ..., Neuron 5.
row_names = ['Input'] + [f'Neuron {k + 1}' for k in range(data.shape[1] - 1)]
tick_positions = np.arange(0.5, data.shape[1], 1)

fig, axes = plt.subplots(len(panel_windows_ms), 1, figsize=(20, 12), sharey=True)
for ax, (t0, t1) in zip(axes, panel_windows_ms):
    k0 = int(round(t0 / dt))  # absolute sample indices of the window
    k1 = int(round(t1 / dt))
    segment = data[k0 - first_index:k1 - first_index, :].T
    # The extent is the true time of the plotted samples
    im = ax.imshow(segment, aspect='auto', interpolation='none', cmap='Greys',
                   vmin=vmin_val, vmax=vmax_val,
                   extent=[k0 * dt, k1 * dt, 0, data.shape[1]])
    ax.set_yticks(tick_positions, row_names, fontsize=tick_size)
    ax.tick_params(axis='x', labelsize=tick_size)

axes[-1].set_xlabel('Time (ms)', fontsize=label_size)
cbar = fig.colorbar(im, ax=axes.ravel().tolist())
cbar.set_label('Membrane Voltage (mV)', fontsize=label_size)
cbar.ax.tick_params(labelsize=tick_size)

# Save before showing, so the file never depends on the backend's show/close behaviour
fig.savefig('new_total_voltage_raster_2.png')
plt.show()