In [None]:
import numpy as np
from brian2 import *
import matplotlib.pyplot as plt
import os
import pickle
import random
import scipy.io
# Seed every random number generator the script depends on. Python's stdlib
# `random` module is seeded for completeness, but Brian2 (PoissonInput etc.)
# does not use it: Brian2's own `seed()` (from `from brian2 import *`) seeds
# the generator used by the simulation, which makes the runs reproducible.
random.seed(67)
seed(67)



In [None]:
# ---------------------------------------------------------------------------
# Scan grid: the Poisson rates of neurons 1 and 2 are C1 * r1 and C2 * r1.
# ---------------------------------------------------------------------------
C1_values, C2_values = np.arange(0.25, 3, 0.25), np.arange(0.25, 3, 0.25)
r1 = 20 * Hz  # base Poisson input rate (drives neuron 0)

# Output containers, both keyed by (C1, C2):
#   results   -> full time series used for plotting
#   results_1 -> final weight matrices used for motif analysis
results, results_1 = {}, {}

# ---------------------------------------------------------------------------
# Neuron model (plain floats; units are attached inside the model equations)
# ---------------------------------------------------------------------------
N = 3                 # number of neurons
tau_E = 5.0           # decay time constant of g_E (ms)
gl, el = 10.0, -75.0  # leak conductance (nS) and leak/reset potential (mV)
v_th = -55.0          # spike threshold (mV)
memc = 200.0          # membrane capacitance (pF)
current = 50          # constant bias current (pA)

# External Poisson drive: n_input sources per neuron, each adding
# strength_input (nS) to g_E per input spike.
n_input, strength_input = 100, 1.0

# ---------------------------------------------------------------------------
# STDP parameters
# ---------------------------------------------------------------------------
eta = 1E-3                # scale factor applied to every weight update
tau_plus = 40.0           # presynaptic trace time constant (ms)
tau_minus = 2 * tau_plus  # postsynaptic trace time constant (ms)
Aplus = 25                # potentiation amplitude (applied on post spikes)
Aminus = -2 * Aplus       # depression amplitude (negative, on pre spikes)
w_exc = 1.0               # initial weight; also the g_E increment (nS) per pre spike
w_max = 3.0               # upper clipping bound for w

for C1 in C1_values:
    for C2 in C2_values:

        # Poisson rates of neurons 1 and 2 as multiples of neuron 0's rate r1.
        r2 = C1 * r1
        r3 = C2 * r1

        # Forget the Brian2 objects of the previous iteration so that run()
        # only simulates the network constructed below.
        start_scope()

        # Leaky integrate-and-fire neuron driven by an excitatory conductance
        # g_E. The synaptic term -(g_E * v) corresponds to a reversal
        # potential of 0 mV. Parameters are plain floats given units here.
        eqs_neurons = '''
        dv/dt = (-(gl*nsiemens)*(v - el*mV) - (g_E * v) + current*pA) / (memc*pfarad) : volt (unless refractory)
        dg_E/dt = -g_E / (tau_E*ms) : siemens
        '''

        G = NeuronGroup(N, model=eqs_neurons, threshold='v > v_th*mV',
                        reset='v = el*mV', refractory=10*ms, method='euler')
        # Brian2 initialises state variables to 0; v = 0 mV is above the
        # threshold, so every neuron would fire at t = 0 and trigger a
        # spurious first round of STDP updates. Start at the reset potential.
        G.v = el * mV

        # Pair-based STDP with exponentially decaying, clock-driven traces.
        eqs_stdp = '''
            w : 1
            dtrace_pre_plus/dt = -trace_pre_plus/(tau_plus*ms) : 1 (clock-driven)
            dtrace_post_minus/dt = -trace_post_minus/(tau_minus*ms) : 1 (clock-driven)
        '''

        # Presynaptic spike: bump the pre trace, depress w in proportion to
        # the post trace (Aminus < 0), clip to [0, w_max], excite the target.
        # NOTE(review): the conductance increment uses the constant w_exc, not
        # the plastic weight w, so learning never feeds back into the network
        # dynamics. Confirm this is intended (otherwise e.g. `g_E += w * nS`).
        eq_on_pre = '''
            trace_pre_plus += 1.0
            w = clip(w + eta * Aminus * trace_post_minus, 0, w_max)
            g_E += w_exc * nS
        '''

        # Postsynaptic spike: bump the post trace and potentiate w in
        # proportion to the pre trace, clipped to [0, w_max].
        eq_on_post = '''
            trace_post_minus += 1.0
            w = clip(w + eta * Aplus * trace_pre_plus, 0, w_max)
        '''

        # All-to-all recurrent connectivity without self-connections. The
        # synapse order 0->1, 0->2, 1->0, 1->2, 2->0, 2->1 is the order used
        # by the weight-plot labels.
        con = Synapses(G, G, model=eqs_stdp, on_pre=eq_on_pre, on_post=eq_on_post)
        con.connect(i=[0, 0, 1, 1, 2, 2], j=[1, 2, 0, 2, 0, 1])
        con.w = w_exc

        # Independent Poisson drive to each neuron, at rates r1, r2, r3.
        P1 = PoissonInput(G[0:1], 'g_E', n_input, rate=r1, weight=strength_input * nS)
        P2 = PoissonInput(G[1:2], 'g_E', n_input, rate=r2, weight=strength_input * nS)
        P3 = PoissonInput(G[2:], 'g_E', n_input, rate=r3, weight=strength_input * nS)

        # Monitors (synaptic state sampled every 1 ms).
        w_mon = StateMonitor(con, 'w', record=True, dt=1*ms)
        spike_mon = SpikeMonitor(G)
        trace_pre_mon = StateMonitor(con, 'trace_pre_plus', record=True, dt=1*ms)
        trace_post_mon = StateMonitor(con, 'trace_post_minus', record=True, dt=1*ms)

        run(300 * second, report='text')

        # Store the time series for plotting.
        # NOTE(review): at full resolution this keeps roughly 43 MB per
        # (C1, C2) pair (3 variables x 6 synapses x 300 000 samples x 8 bytes);
        # record fewer synapses or use a coarser dt if memory becomes an issue.
        results[(C1, C2)] = {
            'w_monitor_ws': np.array(w_mon.w),
            'spike_times': np.array(spike_mon.t / ms),
            'spike_indices': np.array(spike_mon.i),
            'trace_pre_values': np.array(trace_pre_mon.trace_pre_plus),
            'trace_post_values': np.array(trace_post_mon.trace_post_minus),
            'time': np.array(w_mon.t / second)
        }

        # Final weight matrix for motif analysis: entry [pre, post] holds the
        # weight of synapse pre -> post at the end of the run. Reading the
        # synapses' own i/j indices avoids depending on creation order, and
        # con.w gives the true end-of-run value rather than the last 1 ms
        # monitor sample.
        final_weight = np.array(con.w[:])
        final_weight_matrix = np.zeros((N, N))
        final_weight_matrix[np.array(con.i[:]), np.array(con.j[:])] = final_weight

        results_1[(C1, C2)] = {
            'final_weight_matrix': final_weight_matrix,
            'C1': C1,
            'C2': C2
        }

print("Simulation completed. Results stored!!!")

# Export the final weight matrices to a MATLAB .mat file.
#   matrix_<k>     -> final weight matrix of the k-th (C1, C2) pair
#                     (k starts at 1, in the insertion order of results_1)
#   C1_C2_mapping  -> array of rows [k, C1, C2] linking k to its ratios
mat_filename = 'results_2_2.mat'

matlab_data = {}
C1_C2_list = []

for i, values in enumerate(results_1.values(), start=1):
    C1 = values['C1']
    C2 = values['C2']
    final_weight_matrix = values['final_weight_matrix']

    matlab_data[f"matrix_{i}"] = final_weight_matrix
    C1_C2_list.append([i, C1, C2])

matlab_data["C1_C2_mapping"] = np.array(C1_C2_list)

scipy.io.savemat(mat_filename, matlab_data)

# Report the file that was actually written (the message previously named a
# different file than the one passed to savemat).
print(f"✅ Successfully saved '{mat_filename}' with only the values.")

# Legend labels for the six synapses, in the order the synapses were created
# by con.connect (rows of data['w_monitor_ws']).
weight_labels = [
    "Synaptic Weight (Neuron 0 → 1)",
    "Synaptic Weight (Neuron 0 → 2)",
    "Synaptic Weight (Neuron 1 → 0)",
    "Synaptic Weight (Neuron 1 → 2)",
    "Synaptic Weight (Neuron 2 → 0)",
    "Synaptic Weight (Neuron 2 → 1)",
]

# For each (C1, C2) pair, save and show two figures:
#   1. STDP traces of synapse 0 (neuron 0 -> 1) with the spikes of neurons 0
#      and 1, over the first second of simulation;
#   2. evolution of all six synaptic weights over the whole 300 s run.
for (C1, C2), data in results.items():

    fig1, ax1 = plt.subplots(figsize=(12, 5))

    # Spike times are stored in ms; neuron 0 is the pre- and neuron 1 the
    # postsynaptic side of synapse 0.
    pre_spike_times = data['spike_times'][data['spike_indices'] == 0]
    post_spike_times = data['spike_times'][data['spike_indices'] == 1]

    ax1.plot(data['time'], data['trace_pre_values'][0], label="Pre-Synaptic Trace (P)", color='blue')
    ax1.plot(data['time'], data['trace_post_values'][0], label="Post-Synaptic Trace (M)", color='red', linestyle='dashed')

    # Spike rasters drawn as short vertical ticks; / 1000 converts ms -> s.
    ax1.vlines(pre_spike_times / 1000, 0, 0.1, color='blue', alpha=0.5, label="Pre-Synaptic Spikes")
    ax1.vlines(post_spike_times / 1000, 0.3, 0.4, color='red', alpha=0.5, label="Post-Synaptic Spikes")

    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Trace / Neuron Index")
    ax1.set_title(f"Pre & Post Synaptic Traces with Spikes (C1={C1}, C2={C2})")
    ax1.legend(loc="upper right")
    ax1.set_xlim(0, 1)

    filename1 = f"STDP_Traces_Spikes_C1_{C1}_C2_{C2}_2_2_20.png"
    fig1.savefig(filename1, dpi=300)
    print(f"Saved: {filename1}")
    plt.show()
    # Release the figure explicitly; with non-inline backends the figures
    # would otherwise accumulate (two per parameter pair) in pyplot.
    plt.close(fig1)

    fig2, ax2 = plt.subplots(figsize=(10, 5))

    for weight_trace, label in zip(data['w_monitor_ws'], weight_labels):
        ax2.plot(data['time'], weight_trace, label=label)

    ax2.set_xlabel("Time (s)")
    ax2.set_ylabel("Weight (w)")
    ax2.set_title(f"Synaptic Weight Evolution (C1={C1}, C2={C2})")
    ax2.legend()
    ax2.grid()
    ax2.set_xlim(0, 300)

    filename2 = f"Synaptic_Weights_C1_{C1}_C2_{C2}_2_2_20.png"
    fig2.savefig(filename2, dpi=300)
    print(f"Saved: {filename2}")
    plt.show()
    plt.close(fig2)