In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Parameters
# Seed the global NumPy RNG so Restart & Run All reproduces the same spike trains.
random_seed = 42
np.random.seed(random_seed)

num_input_neurons = 1000
num_network_neurons = 250
g_max = 0.02  # Upper bound on feedforward synaptic strength (weights are clipped to [0, g_max])
A_plus = 0.001  # STDP potentiation amplitude (added directly to the weight)
A_minus = A_plus * 1.05  # STDP depression amplitude, 5% larger than A_plus
tau_plus = 20  # Potentiation time constant (ms)
tau_minus = 20  # Depression time constant (ms)
simulation_time = 1000  # Number of dt steps to simulate; reduce for testing
dt = 1  # Integration step (ms)
stimulus_start, stimulus_end = 601, 800  # Inclusive range of stimulus-driven input neurons
time_window = 100  # STDP ignores spike pairs with |delta_t| >= time_window

# LIF Parameters
C_m = 20  # Membrane time constant tau_m (ms); despite the name it divides the LIF update, it is not a capacitance
V_rest = -74  # Resting potential (mV)
V_thresh = -54  # Spiking threshold (mV)
V_reset = -60  # Reset potential (mV)
E_ex = 0  # Excitatory reversal potential (mV)
tau_ex = 5  # Synaptic decay constant (ms)

# Rate function parameters for input neurons
R0 = 10  # Baseline input rate (Hz)
R1 = 80  # Peak rate added on top of R0 by the Gaussian tuning (Hz)
sigma = 100  # Width of the Gaussian tuning curve (input-neuron index units)

# Generate Corrected Spike Trains for Input Neurons (601-800, inclusive)
def generate_corrected_spike_train(neuron_id, duration):
    """Bernoulli spike train for one input neuron driven by a random stimulus.

    At every step t in range(duration) a stimulus location s is drawn uniformly
    from the inclusive range [stimulus_start, stimulus_end]. The firing rate is
    a Gaussian tuning curve (width sigma) around s, wrapped periodically over the
    num_input_neurons-long input array, on top of the baseline R0:

        rate = R0 + R1 * sum_{k in {-N, 0, +N}} exp(-(s + k - neuron_id)^2 / (2 sigma^2))

    The neuron spikes at step t with probability rate * dt / 1000.

    NOTE(review): s is redrawn independently at every step, so there is no
    temporally persistent stimulus shared across inputs; if the model intends a
    stimulus held for an epoch, draw s less often.

    Parameters
    ----------
    neuron_id : int
        Position of the input neuron on the input array.
    duration : int
        Number of time steps to simulate.

    Returns
    -------
    np.ndarray of int
        Sorted step indices at which the neuron spiked (empty if none).
    """
    if duration <= 0:
        return np.array([], dtype=int)

    # randint's upper bound is exclusive; +1 keeps stimulus_end reachable so the
    # stimulus range matches the inclusive neuron range used for input_spikes.
    stimulus = np.random.randint(stimulus_start, stimulus_end + 1, size=duration)

    two_sigma_sq = 2 * sigma ** 2
    # Periodic boundary: also count the Gaussian's images one array-length away.
    tuning = sum(
        np.exp(-((stimulus + shift - neuron_id) ** 2) / two_sigma_sq)
        for shift in (0, num_input_neurons, -num_input_neurons)
    )
    rate = R0 + R1 * tuning

    fired = np.random.rand(duration) < (rate * dt / 1000)
    return np.flatnonzero(fired)

# One pre-generated spike train per stimulus-driven input neuron (601-800 inclusive),
# keyed by input-neuron index.
input_spikes = dict(
    (neuron_id, generate_corrected_spike_train(neuron_id, simulation_time))
    for neuron_id in range(stimulus_start, stimulus_end + 1)
)

# Generate Poisson spike trains for network neurons
def generate_poisson_spike_train(rate, duration):
    """Homogeneous Poisson spike train on [0, duration).

    The spike count is drawn from Poisson(rate * duration / 1000); that many
    spike times are then placed uniformly at random and returned sorted.
    rate is presumably in Hz with duration in ms (hence the /1000).

    Returns
    -------
    np.ndarray of float
        Sorted spike times.
    """
    expected_count = rate * (duration / 1000)
    spike_count = np.random.poisson(expected_count)
    spike_times_unsorted = np.random.uniform(0, duration, spike_count)
    return np.sort(spike_times_unsorted)

# Independent 15-rate Poisson trains for every network neuron, keyed by neuron index.
network_spikes = {}
for neuron in range(num_network_neurons):
    network_spikes[neuron] = generate_poisson_spike_train(15, simulation_time)

# Feedforward strengths (input x network): only the stimulus-driven inputs
# (rows stimulus_start..stimulus_end inclusive) start with random strength in [0, g_max).
n_stimulus_inputs = stimulus_end - stimulus_start + 1
feedforward_strengths = np.zeros((num_input_neurons, num_network_neurons))
feedforward_strengths[stimulus_start:stimulus_end + 1] = np.random.uniform(
    0, g_max, (n_stimulus_inputs, num_network_neurons)
)
# "Hole": network neurons 100-199 receive no feedforward input at all.
feedforward_strengths[:, 100:200] = 0

# LIF state: every membrane starts at rest; excitatory conductances start at zero.
V = V_rest * np.ones(num_network_neurons)
g_ex = np.zeros_like(V)

# STDP window function
def stdp_window(delta_t):
    """Exponential STDP kernel, evaluated element-wise.

    delta_t < 0 :  +A_plus  * exp(delta_t / tau_plus)    (potentiation)
    otherwise   :  -A_minus * exp(-delta_t / tau_minus)  (depression, incl. delta_t == 0)

    With delta_t = t_pre - t_post, a negative value means the presynaptic spike
    came first. Accepts a scalar or an array; returns an ndarray.
    """
    potentiation = A_plus * np.exp(delta_t / tau_plus)
    depression = -A_minus * np.exp(-delta_t / tau_minus)
    return np.where(delta_t < 0, potentiation, depression)

# Post-synaptic spike times per network neuron (sparse dict, filled during the run).
spike_times = {j: [] for j in range(num_network_neurons)}

# Precompute synaptic decay factor (avoids repeated `exp()` calls)
synaptic_decay = np.exp(-dt / tau_ex)

# Dense boolean rasters (neuron x step). "Who fired at step t" becomes a column
# read instead of scanning every input's spike array with `t in ...`.
input_raster = np.zeros((num_input_neurons, simulation_time), dtype=bool)
for input_id, input_times in input_spikes.items():
    input_raster[input_id, np.asarray(input_times, dtype=int)] = True
post_raster = np.zeros((num_network_neurons, simulation_time), dtype=bool)

# Simulate LIF Neurons Over Time
for t in range(simulation_time):
    # Decay excitatory conductance
    g_ex *= synaptic_decay

    # Feedforward drive: every input spiking at t adds its strength to g_ex.
    pre_spiking = input_raster[:, t].astype(np.float64)
    g_ex += feedforward_strengths.T @ pre_spiking

    # Membrane update (forward Euler); C_m acts as the membrane time constant here.
    V += (V_rest - V + g_ex * (E_ex - V)) / C_m * dt

    # Detect spiking neurons, record them, and reset.
    spiking_neurons = np.flatnonzero(V >= V_thresh)
    for neuron in spiking_neurons:
        spike_times[neuron].append(t)
    V[spiking_neurons] = V_reset
    post_raster[spiking_neurons, t] = True

    # --- STDP: all-to-all pairing with |delta_t| < time_window ---
    # delta_t = t_pre - t_post, the convention stdp_window expects (negative -> LTP).
    # Each pre/post pair is counted exactly once, at the step of its later spike.
    # Pairing must look back in both directions: pairing only pre spikes at t with
    # post spikes <= t gives delta_t >= 0 and would never potentiate.
    window_start = max(0, t - time_window + 1)

    # (a) Post spikes at t vs. pre spikes strictly before t: delta_t < 0 (LTP).
    if spiking_neurons.size > 0 and t > window_start:
        past_lags = np.arange(window_start, t) - t
        ltp = input_raster[:, window_start:t].astype(np.float64) @ stdp_window(past_lags)
        feedforward_strengths[:, spiking_neurons] += ltp[:, None]

    # (b) Pre spikes at t vs. post spikes at or before t (including those just
    #     recorded above): delta_t >= 0 (LTD).
    pre_ids = np.flatnonzero(input_raster[:, t])
    if pre_ids.size > 0:
        lags = t - np.arange(window_start, t + 1)
        ltd = post_raster[:, window_start:t + 1].astype(np.float64) @ stdp_window(lags)
        feedforward_strengths[pre_ids, :] += ltd[None, :]

    # Keep every weight within [0, g_max].
    np.clip(feedforward_strengths, 0, g_max, out=feedforward_strengths)

# Normalize for Visualization (g / g_max, so 1.0 is the strongest allowed synapse)
feedforward_strengths_normalized = feedforward_strengths / g_max

# PLOT
plt.figure(figsize=(10, 6))

# Coordinates and strengths of non-zero synapses; zero entries are left blank.
y, x = np.where(feedforward_strengths_normalized > 0)  # y: input neurons, x: network neurons
weights = feedforward_strengths_normalized[y, x]  # Synaptic strengths
# One-line summary instead of printing the raw weight array.
mean_weight = weights.mean() if weights.size else 0.0
print(f"{weights.size} non-zero feedforward synapses, mean g/g_max = {mean_weight:.3f}")

# Plot the scatter points
plt.scatter(x, y, c=weights, cmap='gray_r', s=10, marker='.', vmin=0, vmax=1)

# Add labels and title
plt.xlabel('Network Neuron')
plt.ylabel('Input Neuron')
plt.title('Feedforward Synaptic Strengths After STDP')

# Add colorbar
cbar = plt.colorbar(label='g/g_max')
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])

# Axis limits follow the configured population sizes.
plt.xlim(0, num_network_neurons)  # Network neurons
plt.ylim(0, num_input_neurons)  # Input neurons

plt.show()


In [None]:
# Use the feedforward strengths learned in the simulation above, as g/g_max.
# (A uniform-random placeholder here would discard that result entirely.)
feedforward_strengths_normalized = feedforward_strengths / g_max

# Initial recurrent strengths: overlap of the learned feedforward profiles of
# each pair of network neurons, F^T F (network x network).
recurrent_strengths = feedforward_strengths_normalized.T @ feedforward_strengths_normalized

# Scale so the strongest recurrent synapse equals g_max. An all-zero matrix
# (e.g. every feedforward weight depressed to 0) stays zero instead of becoming NaN.
peak_recurrent = np.max(recurrent_strengths)
if peak_recurrent > 0:
    recurrent_strengths = recurrent_strengths / peak_recurrent * g_max


# Pair-based STDP on the recurrent matrix, driven by the Poisson network trains.
# recurrent_strengths[i, j] is the synapse from neuron i (pre) to neuron j (post);
# delta_t = t_pre - t_post, and only pairs with |delta_t| < time_window count.
for i in range(num_network_neurons):
    for j in range(num_network_neurons):
        # Skip self-connections and the 100-199 x 100-199 "hole".
        if i == j or (100 <= i < 200 and 100 <= j < 200):
            continue

        pre_spikes = np.asarray(network_spikes[i])
        post_spikes = np.asarray(network_spikes[j])
        if pre_spikes.size == 0 or post_spikes.size == 0:
            continue

        delta_t = np.subtract.outer(pre_spikes, post_spikes)
        delta_t_windowed = delta_t[np.abs(delta_t) < time_window]
        if delta_t_windowed.size == 0:
            continue

        weight_change = np.sum(stdp_window(delta_t_windowed))
        recurrent_strengths[i, j] = np.clip(recurrent_strengths[i, j] + weight_change, 0, g_max)


# Re-impose the recurrent connectivity mask after STDP:
#   * only postsynaptic neurons 100-199 keep recurrent input
#     (columns 0-99 and 200-249 are cleared), and
#   * no synapses inside the 100-199 x 100-199 block (the "hole").
recurrent_strengths[0:250, 0:100] = 0
recurrent_strengths[0:250, 200:250] = 0
recurrent_strengths[100:200, 100:200] = 0

# Express strengths as g/g_max for plotting.
recurrent_strengths_normalized = np.divide(recurrent_strengths, g_max)

plt.figure(figsize=(10, 6))

# Coordinates of non-zero recurrent synapses: rows (y) are presynaptic,
# columns (x) postsynaptic neurons of recurrent_strengths[i, j].
y, x = np.where(recurrent_strengths_normalized > 0)
weights = recurrent_strengths_normalized[y, x]

plt.scatter(x, y, c=weights, cmap='gray_r', s=10, marker='.', vmin=0, vmax=1)

# Add labels and title
plt.xlabel('Postsynaptic Network Neuron')
plt.ylabel('Presynaptic Network Neuron')
plt.title('Figure 5B')

# Add colorbar
cbar = plt.colorbar(label='g/g_max')
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])

# Axis limits follow the configured network size.
plt.xlim(0, num_network_neurons)
plt.ylim(0, num_network_neurons)

plt.show()


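Latest Update: standalone STDP model. The next cell re-imports its libraries and redefines every parameter and state variable it uses, so it does not depend on the cells above.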

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Parameters
# Seed the global NumPy RNG so the random network activity is reproducible on re-run.
random_seed = 42
np.random.seed(random_seed)

num_network_neurons = 250
num_input_neurons = 1000  # Match paper's input neuron count
learning_rate = 0.005  # A+ from paper (LTP amplitude used by stdp_update)
B = 1.02  # LTD/LTP amplitude ratio (LTD uses learning_rate * B); reduced LTD dominance
tau_plus = 20  # LTP time window in ms
tau_minus = 100  # LTD time window in ms
g_max = 1.0  # Maximum synaptic weight
num_timesteps = 100  # Simulation duration in ms (one step per ms)
connection_prob = 0.5  # Probability of recurrent connection
background_rate = 500  # Hz, Poisson background input frequency
background_weight = 0.096  # Background input strength. NOTE(review): not used anywhere in this cell
sigma_correlation = 100  # Width of the Gaussian tuning curve (input-neuron index units)
stimulus_location = 700  # Peak of the Gaussian input tuning (input index; the array spans 0-999, so this is off-centre)
R0 = 10  # Baseline firing rate in Hz
R1 = 80  # Peak rate above baseline in Hz (peak total = R0 + R1)

# Feedforward weights (network x input) start at zero; the soft-bounded LTP term
# (g_max - weight) in stdp_update lets them grow from there.
synaptic_weights = np.zeros((num_network_neurons, num_input_neurons))

# Recurrent connectivity: each ordered pair (i, j) is connected with probability
# connection_prob, with an initial strength drawn uniformly from [0, g_max).
# An all-zero matrix would leave connection_prob unused and make the recurrent
# STDP pass a no-op, because that pass only updates weights > 0.
# NOTE(review): the initial-strength distribution is a modelling choice, not taken
# from the paper — confirm. Real connections also make the recurrent pass slower.
recurrent_shape = (num_network_neurons, num_network_neurons)
is_connected = np.random.rand(*recurrent_shape) < connection_prob
recurrent_weights = np.where(is_connected, np.random.uniform(0, g_max, recurrent_shape), 0.0)

np.fill_diagonal(recurrent_weights, 0)  # No self-connections

# Gaussian firing rates for input neurons (Hz), peaked at stimulus_location:
#   rate(a) = R0 + R1 * exp(-(stimulus_location - a)^2 / (2 sigma^2))
# computed for all input indices at once instead of a Python loop.
input_index = np.arange(num_input_neurons)
input_firing_rates = R0 + R1 * np.exp(
    -((stimulus_location - input_index) ** 2 / (2 * sigma_correlation ** 2))
)

# Deterministic input raster (input x step): each input fires
# int(rate * num_timesteps / 1000) times, spread evenly over [0, num_timesteps - 1].
input_spikes = np.zeros((num_input_neurons, num_timesteps), dtype=bool)
for neuron, rate_hz in enumerate(input_firing_rates):
    n_spikes = int(rate_hz * num_timesteps / 1000)  # Hz -> spike count over the run
    evenly_spaced = np.linspace(0, num_timesteps - 1, n_spikes, dtype=int)
    input_spikes[neuron, evenly_spaced] = True

# Network raster (network x step). On steps where any input fires, every network
# neuron fires; otherwise each fires independently with probability 0.1.
# Background Poisson spikes (probability background_rate / 1000 per step) are OR-ed on top.
network_spikes = np.zeros((num_network_neurons, num_timesteps), dtype=bool)
for t in range(num_timesteps):
    input_active = input_spikes[:, t].any()
    network_spikes[:, t] = True if input_active else (np.random.rand(num_network_neurons) < 0.1)

    background_spikes = np.random.rand(num_network_neurons) < (background_rate / 1000)
    network_spikes[:, t] = network_spikes[:, t] | background_spikes

# STDP update function with exponential timing dependence
def stdp_update(pre_times, post_times, weight):
    """Apply pair-based, soft-bounded STDP to a single synaptic weight.

    Every (pre, post) spike pair is applied in nested-loop order (pre_times
    outer, post_times inner), with delta_t = post_t - pre_t:
      * delta_t > 0: weight += learning_rate * exp(-delta_t / tau_plus) * (g_max - weight)   (LTP)
      * delta_t < 0: weight -= learning_rate * B * exp(delta_t / tau_minus) * weight       (LTD)
      * delta_t == 0: no change
    The final weight is clipped to [0, g_max].

    The exponentials for all pairs are computed in one vectorized call. Only the
    weight fold, which depends on pair order, stays in Python; this avoids one
    scalar np.exp call per pair.

    Parameters
    ----------
    pre_times, post_times : array-like
        Spike times (same units as tau_plus / tau_minus).
    weight : float
        Starting weight.

    Returns
    -------
    NumPy scalar
        Updated weight, clipped to [0, g_max].
    """
    pre = np.asarray(pre_times, dtype=float)
    post = np.asarray(post_times, dtype=float)
    # Row-major ravel of the (n_pre, n_post) grid keeps the pre-outer / post-inner order.
    lags = (post[None, :] - pre[:, None]).ravel()
    # -|lag| equals -lag for LTP pairs and lag for LTD pairs; it never overflows exp.
    neg_abs_lags = -np.abs(lags)
    ltp_steps = learning_rate * np.exp(neg_abs_lags / tau_plus)
    ltd_steps = (learning_rate * B) * np.exp(neg_abs_lags / tau_minus)

    for lag, ltp_step, ltd_step in zip(lags.tolist(), ltp_steps.tolist(), ltd_steps.tolist()):
        if lag > 0:
            weight += ltp_step * (g_max - weight)  # LTP
        elif lag < 0:
            weight -= ltd_step * weight  # LTD
    return np.clip(weight, 0, g_max)  # Keep within bounds

# Spike-time index arrays: pre_times[i] for input neuron i, post_times[n] for network neuron n.
pre_times = [np.where(input_spikes[i])[0] for i in range(num_input_neurons)]
post_times = [np.where(network_spikes[i])[0] for i in range(num_network_neurons)]

progress_every = 50  # print one progress line every N neurons instead of every neuron

# Simulate STDP over time for feedforward connections
for n in range(num_network_neurons):
    if n % progress_every == 0:
        print('network neuron=', n)
    for i in range(num_input_neurons):
        synaptic_weights[n, i] = stdp_update(pre_times[i], post_times[n], synaptic_weights[n, i])

# Simulate STDP for recurrent connections (spike times of neuron i act as the presynaptic train)
for i in range(num_network_neurons):
    if i % progress_every == 0:
        print('network neuron=', i)
    for j in range(num_network_neurons):
        if recurrent_weights[i, j] > 0:  # Only update existing connections
            recurrent_weights[i, j] = stdp_update(post_times[i], post_times[j], recurrent_weights[i, j])

# Grayscale feedforward weight matrix, as in the paper. synaptic_weights is
# (network x input), so it is transposed to put input neurons on the y-axis.
fig, ax = plt.subplots(figsize=(6, 5))
image = ax.imshow(synaptic_weights.T, cmap='gray_r', aspect='auto', vmin=0, vmax=1)
fig.colorbar(image, ax=ax, label='Synaptic Strength')
ax.set_xlabel("Network Neuron")
ax.set_ylabel("Input Neuron (Gaussian Correlated)")
ax.set_title("Grayscale Representation of Feedforward Synaptic Weights")
ax.invert_yaxis()  # input neuron 0 at the bottom
plt.show()