In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Parameters (times in ms, potentials in mV)
RANDOM_SEED = 0  # fixed seed so Restart & Run All reproduces the same spike trains and figures
np.random.seed(RANDOM_SEED)

num_input_neurons = 1000
num_network_neurons = 250
g_max = 0.02  # upper bound for synaptic strengths (weights are clipped to [0, g_max])
A_plus = 0.001  # potentiation amplitude
A_minus = A_plus * 1.05  # depression amplitude, 5% larger than A_plus
tau_plus = 20  # potentiation time constant (ms)
tau_minus = 20  # depression time constant (ms)
simulation_time = 1000  # number of simulation steps; reduce for testing
dt = 1  # time step (ms)
stimulus_start, stimulus_end = 601, 800  # inclusive range of input neurons that carry spike trains
time_window = 100  # STDP pairs with |delta_t| >= time_window are ignored

# LIF Parameters
C_m = 20  # Membrane time constant (ms); divides the voltage derivative in the LIF update
V_rest = -74  # Resting potential (mV)
V_thresh = -54  # Spiking threshold (mV)
V_reset = -60  # Reset potential (mV)
E_ex = 0  # Excitatory reversal potential (mV)
tau_ex = 5  # Synaptic decay constant (ms)

# Rate function parameters for input neurons
R0 = 10  # baseline firing rate (Hz)
R1 = 80  # tuning-curve amplitude added on top of R0 (Hz)
sigma = 100  # tuning-curve width, in units of input-neuron index

# Generate Corrected Spike Trains for Input Neurons (601-800)
def generate_corrected_spike_train(neuron_id, duration):
    """Generate a spike train for one input neuron over `duration` time steps.

    At every step a stimulus value s is drawn uniformly from the inclusive
    range [stimulus_start, stimulus_end]. The neuron fires with probability
    rate * dt / 1000, where rate is a Gaussian tuning curve (baseline R0,
    amplitude R1, width sigma) centred on `neuron_id`, with wrap-around
    terms over a ring of `num_input_neurons` inputs.

    Args:
        neuron_id: index of the input neuron (its preferred stimulus value).
        duration: number of time steps to simulate.

    Returns:
        1-D integer array of the steps at which the neuron fired (sorted).
    """
    steps = np.arange(duration)
    # randint's upper bound is exclusive, so +1 makes stimulus_end reachable
    s = np.random.randint(stimulus_start, stimulus_end + 1, size=duration)
    period = num_input_neurons  # ring size for the wrap-around terms
    two_sigma_sq = 2 * sigma**2
    tuning = (np.exp(-((s - neuron_id) ** 2) / two_sigma_sq)
              + np.exp(-((s + period - neuron_id) ** 2) / two_sigma_sq)
              + np.exp(-((s - period - neuron_id) ** 2) / two_sigma_sq))
    rate = R0 + R1 * tuning  # Hz
    fired = np.random.rand(duration) < rate * dt / 1000  # dt in ms, rate in Hz
    return steps[fired]

# Spike trains of the active input neurons, keyed by input-neuron index (inclusive range)
input_spikes = dict(
    (neuron_id, generate_corrected_spike_train(neuron_id, simulation_time))
    for neuron_id in range(stimulus_start, stimulus_end + 1)
)

# Generate Poisson spike trains for network neurons
def generate_poisson_spike_train(rate, duration):
    """Homogeneous Poisson spike train.

    Draws a Poisson-distributed spike count with mean rate * duration / 1000
    (rate in Hz, duration in ms), then places that many spikes uniformly in
    [0, duration). Returns the spike times as a sorted float array.
    """
    expected_count = rate * (duration / 1000)
    n_spikes = np.random.poisson(expected_count)
    spike_times = np.random.uniform(0, duration, n_spikes)
    spike_times.sort()
    return spike_times

# 15 Hz Poisson spike trains for every network neuron (consumed by the recurrent-STDP cell)
network_spikes = dict(
    (j, generate_poisson_spike_train(15, simulation_time)) for j in range(num_network_neurons)
)

# Feedforward synaptic strengths (rows: input neurons, columns: network neurons).
# Only the active inputs start with random strengths in [0, g_max).
n_active_inputs = stimulus_end - stimulus_start + 1
feedforward_strengths = np.zeros((num_input_neurons, num_network_neurons))
feedforward_strengths[stimulus_start:stimulus_end + 1, :] = np.random.uniform(
    0, g_max, size=(n_active_inputs, num_network_neurons)
)
# Cut a hole: network neurons 100-199 receive no feedforward input
feedforward_strengths[:, 100:200] = 0

# LIF state: membrane potentials start at rest, excitatory conductances at zero
V = np.full(num_network_neurons, V_rest, dtype=np.float64)
g_ex = np.zeros_like(V)

# STDP window function
def stdp_window(delta_t):
    """Pair-based STDP weight change for spike-time differences delta_t = t_pre - t_post.

    delta_t < 0 (pre before post) gives potentiation  A_plus  * exp(-|delta_t| / tau_plus);
    delta_t >= 0                 gives depression   -A_minus * exp(-|delta_t| / tau_minus).
    Works element-wise on scalars or arrays. Using |delta_t| keeps the branch that
    np.where discards from overflowing for large time differences.
    """
    magnitude = np.abs(delta_t)
    potentiation = A_plus * np.exp(-magnitude / tau_plus)
    depression = -A_minus * np.exp(-magnitude / tau_minus)
    return np.where(delta_t < 0, potentiation, depression)

# Track spikes using a sparse dictionary
spike_times = {j: [] for j in range(num_network_neurons)}

# Precompute synaptic decay factor (avoids repeated `exp()` calls)
synaptic_decay = np.exp(-dt / tau_ex)

# Simulate LIF Neurons Over Time
for t in range(simulation_time):
    # Decay excitatory conductance
    g_ex *= synaptic_decay  

    # Compute feedforward inputs only for active neurons
    pre_spiking = np.zeros(num_input_neurons)
    for i in range(stimulus_start, stimulus_end + 1):  
        if t in input_spikes[i]:  
            pre_spiking[i] = 1  

    g_ex += feedforward_strengths.T @ pre_spiking  

    # Update membrane potential using LIF equation
    V += (V_rest - V + g_ex * (E_ex - V)) / C_m * dt  

    # Detect spiking neurons
    spiking_neurons = np.where(V >= V_thresh)[0]  
    for neuron in spiking_neurons:
        spike_times[neuron].append(t)  # Store only spike times (sparse)

    V[spiking_neurons] = V_reset  

    # Apply STDP only for neurons that spiked
    for j in spiking_neurons:
        for i in range(stimulus_start, stimulus_end + 1):  
            if t in input_spikes[i]:  
                delta_t = t - np.array(spike_times[j])  # Faster computation
                valid_times = delta_t[np.abs(delta_t) < time_window]  # Only valid pairs
                
                if valid_times.size > 0:  
                    weight_change = np.sum(stdp_window(valid_times))  
                    feedforward_strengths[i, j] += weight_change
                    feedforward_strengths[i, j] = np.clip(feedforward_strengths[i, j], 0, g_max)

# Normalize for visualization (g / g_max in [0, 1])
feedforward_strengths_normalized = feedforward_strengths / g_max

fig, ax = plt.subplots(figsize=(10, 6))

# Coordinates and strengths of the non-zero synapses (y: input neurons, x: network neurons)
y, x = np.nonzero(feedforward_strengths_normalized > 0)
weights = feedforward_strengths_normalized[y, x]

sc = ax.scatter(x, y, c=weights, cmap='gray_r', s=10, marker='.', vmin=0, vmax=1)

ax.set_xlabel('Network Neuron')
ax.set_ylabel('Input Neuron')
ax.set_title('Feedforward Synaptic Strengths After STDP')

cbar = fig.colorbar(sc, ax=ax, label='g/g_max')
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])

# Axis limits follow the configured population sizes
ax.set_xlim(0, num_network_neurons)
ax.set_ylim(0, num_input_neurons)

plt.show()


In [None]:
# Feedforward strengths learned in the Figure 5A simulation (cell above), as g/g_max.
# Recomputed from `feedforward_strengths` so this cell is idempotent on re-run.
feedforward_strengths_normalized = feedforward_strengths / g_max

# Transfer 5A -> 5B: network neurons that share feedforward input get a stronger
# initial recurrent coupling (overlap of their input-weight columns).
recurrent_strengths = feedforward_strengths_normalized.T @ feedforward_strengths_normalized
peak_overlap = np.max(recurrent_strengths)
if peak_overlap > 0:  # avoid 0/0 -> NaN when no feedforward weight survived
    recurrent_strengths = recurrent_strengths / peak_overlap * g_max


# Apply pair-based STDP to the transferred recurrent connections, driven by the
# Poisson spike trains of the network neurons (delta_t = t_pre - t_post).
hole = range(100, 200)
for pre in range(num_network_neurons):
    pre_train = np.array(network_spikes[pre])
    if pre_train.size == 0:
        continue
    for post in range(num_network_neurons):
        # No STDP on self-connections or on synapses inside the 100-199 hole
        if pre == post or (pre in hole and post in hole):
            continue
        post_train = np.array(network_spikes[post])
        if post_train.size == 0:
            continue
        lags = np.subtract.outer(pre_train, post_train)
        lags = lags[np.abs(lags) < time_window]
        if lags.size == 0:
            continue
        updated = recurrent_strengths[pre, post] + np.sum(stdp_window(lags))
        recurrent_strengths[pre, post] = np.clip(updated, 0, g_max)


# Re-apply the hole after STDP: no connections among network neurons 100-199
recurrent_strengths[100:200, 100:200] = 0
# Keep only synapses onto postsynaptic neurons 100-199 by zeroing every other column
for col_lo, col_hi in ((0, 100), (200, 250)):
    recurrent_strengths[0:250, col_lo:col_hi] = 0

# Normalize to g/g_max for plotting
recurrent_strengths_normalized = recurrent_strengths / g_max

fig, ax = plt.subplots(figsize=(10, 6))

# Coordinates and strengths of the non-zero recurrent synapses (y: presynaptic, x: postsynaptic)
y, x = np.nonzero(recurrent_strengths_normalized > 0)
weights = recurrent_strengths_normalized[y, x]

sc = ax.scatter(x, y, c=weights, cmap='gray_r', s=10, marker='.', vmin=0, vmax=1)

ax.set_xlabel('Postsynaptic Network Neuron')
ax.set_ylabel('Presynaptic Network Neuron')
ax.set_title('Figure 5B')

cbar = fig.colorbar(sc, ax=ax, label='g/g_max')
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])

# Both axes span the network population
ax.set_xlim(0, num_network_neurons)
ax.set_ylim(0, num_network_neurons)

plt.show()


## Latest update: class-based STDP network (LIF neurons with feedforward and recurrent plasticity)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

class STDP_Network:
    """Feedforward + recurrent network of LIF neurons with trace-based STDP.

    Input neurons (``num_inputs``) fire as Poisson units on a ring, with rates
    set by a stimulus held constant over exponentially distributed intervals.
    Network neurons (``num_neurons``) are leaky integrate-and-fire units driven
    by an excitatory conductance ``g_ex``. Feedforward weights have shape
    (num_inputs, num_neurons) and recurrent weights (num_neurons, num_neurons);
    rows are presynaptic, columns postsynaptic. Depression amplitudes are
    scaled by ``B_ff`` / ``B_recur`` relative to potentiation.
    """

    def __init__(self, num_neurons=250, num_poisson=250, num_inputs=1000, dt=0.1,
                 tau_pre=20.0, tau_post=20.0, tau_m=20.0, V_rest=-74.0, V_reset=-60.0,
                 V_thresh=-54.0, C_m=0.9, g_leak=0.2, tau_s=5.0,
                 A_plus_ff=0.005, A_minus_ff=0.005, A_plus_recur=0.001, A_minus_recur=0.001,
                 B_ff=1.06, B_recur=1.04, g_max=1.0, poisson_rate=10, stimulus_width=100,
                 mean_stim_time=100, R0=10, R1=80, sigma=100):
        """
        Initialize the network with LIF neuron dynamics for the network neurons (not input neurons).

        Times are in ms and potentials in mV.
            num_neurons: number of network (postsynaptic) LIF neurons.
            num_poisson: accepted for compatibility; unused (one Poisson source per network neuron).
            num_inputs: number of input (presynaptic) neurons on the stimulus ring.
            dt: integration time step.
            tau_pre, tau_post: STDP trace time constants.
            tau_m, C_m, g_leak, V_rest, V_reset, V_thresh: LIF membrane parameters.
            tau_s: decay time constant of the excitatory conductance.
            A_plus_*, A_minus_*: STDP amplitudes; B_*: extra factor applied to depression.
            g_max: upper bound for all weights.
            poisson_rate: rate (Hz) of the per-neuron background Poisson input.
            stimulus_width: width of the Gaussian profile of the initial feedforward weights.
            mean_stim_time: mean duration of one stimulus presentation.
            R0, R1, sigma: baseline rate, tuning amplitude and tuning width of the input neurons.
        """
        self.num_neurons = num_neurons  # Postsynaptic (network) neurons
        self.num_inputs = num_inputs    # Presynaptic (input) neurons
        self.dt = dt
        self.tau_pre = tau_pre
        self.tau_post = tau_post

        # LIF neuron parameters
        self.tau_m = tau_m
        self.V_rest = V_rest
        self.V_reset = V_reset
        self.V_thresh = V_thresh
        self.C_m = C_m
        self.g_leak = g_leak
        self.E_ex = 0  # excitatory reversal potential (mV)
        self.tau_s = tau_s

        # Input stimulus parameters
        self.mean_stim_time = mean_stim_time
        self.R0 = R0
        self.R1 = R1
        self.sigma = sigma

        # Feedforward learning rates and depression scale factor
        self.A_plus_ff = A_plus_ff
        self.A_minus_ff = A_minus_ff
        self.B_ff = B_ff

        # Recurrent learning rates and depression scale factor
        self.A_plus_recur = A_plus_recur
        self.A_minus_recur = A_minus_recur
        self.B_recur = B_recur

        self.g_max = g_max

        # Initial feedforward weights: random in [0.6, 1.0) times a Gaussian profile
        # over input index that peaks at input neuron 700.
        center_neuron = 700
        weight_distribution = np.exp(-((np.arange(self.num_inputs) - center_neuron) ** 2)
                                     / (2 * stimulus_width ** 2))
        weight_distribution = weight_distribution / np.max(weight_distribution)
        self.ff_weights = (np.random.uniform(0.6, 1.0, (self.num_inputs, self.num_neurons))
                           * weight_distribution[:, None])

        # Sparse connectivity: each feedforward synapse exists with probability 0.2
        self.mask = (np.random.rand(num_inputs, num_neurons) < 0.2).astype(int)
        self.ff_weights *= self.mask

        # Recurrent (network-to-network) weights start at 0
        self.recur_weights = np.zeros((num_neurons, num_neurons))

        # Background Poisson input: one source per network neuron, added to g_ex
        self.poisson_rate = poisson_rate
        self.g_ex = np.zeros(num_neurons)

        # Membrane potentials (float even if V_rest is given as an int) and per-step spike flags
        self.V_mem = np.full(num_neurons, V_rest, dtype=float)
        self.spike_train = np.zeros(num_neurons)

        # Stimulus properties
        self.stimulus_width = stimulus_width
        self.preferred_stimulus = np.linspace(0, 1000, num_inputs)
        self.current_stimulus = None

        # STDP traces for feedforward updates
        self.pre_trace_feed = np.zeros(num_inputs)
        self.post_trace_feed = np.zeros(num_neurons)

        # STDP traces for recurrent updates
        self.pre_trace_recur = np.zeros(num_neurons)
        self.post_trace_recur = np.zeros(num_neurons)

    def decay_traces(self):
        """Decay all STDP traces exponentially over one time step."""
        self.pre_trace_feed *= np.exp(-self.dt / self.tau_pre)
        self.post_trace_feed *= np.exp(-self.dt / self.tau_post)
        self.pre_trace_recur *= np.exp(-self.dt / self.tau_pre)
        self.post_trace_recur *= np.exp(-self.dt / self.tau_post)

    def update_poisson_input(self):
        """
        Each network neuron receives input from one dedicated Poisson neuron.
        If that Poisson neuron fires (probability poisson_rate * dt / 1000), the
        neuron's g_ex is increased by a fixed amount (0.5; 0.096 in the paper).
        The increment is added so that conductance from earlier input is kept.
        """
        poisson_spikes = np.random.rand(self.num_neurons) < (self.poisson_rate * self.dt / 1000.0)
        self.g_ex += 0.5 * poisson_spikes.astype(float)

    def generate_stimulus(self, sim_time):
        """Return a (num_steps, num_inputs) 0/1 int spike raster covering `sim_time` ms.

        The step count is rounded (not truncated) so that, e.g., 3.0 ms at dt = 0.1
        gives 30 steps despite floating-point error in 3.0 / 0.1.
        """
        num_steps = int(round(sim_time / self.dt))
        return self._generate_stimulus_steps(num_steps)

    def _generate_stimulus_steps(self, num_steps):
        """Build the input spike raster for `num_steps` steps.

        A stimulus s ~ U(1, num_inputs) is held for an exponentially distributed
        interval (mean `mean_stim_time` ms). During it, input a fires with
        probability rate(a) * dt / 1000, where rate is a Gaussian tuning curve
        around s with wrap-around on a ring of `num_inputs` neurons.
        """
        spike_array = np.zeros((num_steps, self.num_inputs), dtype=int)
        a_vals = np.arange(self.num_inputs)
        period = self.num_inputs  # ring size for the wrap-around terms
        two_sigma_sq = 2 * self.sigma ** 2
        current_step = 0

        while current_step < num_steps:
            interval_duration = np.random.exponential(scale=self.mean_stim_time)
            interval_steps = max(1, int(interval_duration / self.dt))
            end_step = min(num_steps, current_step + interval_steps)

            s = np.random.uniform(1, period)
            rates = self.R0 + self.R1 * (
                np.exp(-((s - a_vals) ** 2) / two_sigma_sq)
                + np.exp(-((s + period - a_vals) ** 2) / two_sigma_sq)
                + np.exp(-((s - period - a_vals) ** 2) / two_sigma_sq)
            )
            p = rates * self.dt / 1000

            # Draw the whole interval at once (same row-by-row random order as a per-step loop)
            n_rows = end_step - current_step
            spike_array[current_step:end_step] = np.random.rand(n_rows, self.num_inputs) < p
            current_step = end_step

        return spike_array

    def update_membrane_potential(self, pre_spikes_feed, pre_spikes_recur):
        """
        Advance the LIF dynamics of all network neurons by one step.

        - Exponential decay of the excitatory conductance g_ex.
        - g_ex increments from feedforward and recurrent presynaptic spikes.
        - Membrane update driven by synaptic AND leak currents.
        - Threshold crossing: spike flag set in spike_train, V reset.
        """
        input_gain = 3.0  # ad-hoc gain on synaptic input, kept from the original tuning

        self.g_ex *= np.exp(-self.dt / self.tau_s)
        self.g_ex += input_gain * np.dot(pre_spikes_feed, self.ff_weights)     # input -> network
        self.g_ex += input_gain * np.dot(pre_spikes_recur, self.recur_weights)  # network -> network

        synaptic_current = self.g_ex * (self.E_ex - self.V_mem)
        leak_current = self.g_leak * (self.V_rest - self.V_mem)  # pulls V back toward rest
        total_current = synaptic_current + leak_current

        self.V_mem += (total_current / self.C_m) * (self.dt / self.tau_m)

        spike = self.V_mem >= self.V_thresh
        self.V_mem[spike] = self.V_reset
        self.spike_train[spike] = 1

    def update_pre_feedforward(self, pre_spikes):
        """Handle presynaptic (input) spikes for feedforward STDP.

        `pre_spikes` is a per-input-neuron 0/1 or boolean vector. It is converted to
        a boolean mask so that an int 0/1 vector selects the spiking neurons and
        not indices 0 and 1. Spiking inputs bump their pre trace, and their existing
        synapses are depressed by B_ff * A_minus_ff * post trace.
        """
        spiking = np.asarray(pre_spikes, dtype=bool)
        self.pre_trace_feed[spiking] += 1
        self.ff_weights[spiking] -= (self.B_ff * self.A_minus_ff
                                     * self.post_trace_feed * self.mask[spiking])
        self.ff_weights *= self.mask  # enforce sparsity
        self.ff_weights = np.clip(self.ff_weights, 0, self.g_max)

    def update_post_feedforward(self, post_spikes):
        """Handle postsynaptic (network) spikes for feedforward STDP.

        Spiking network neurons bump their post trace. Their existing input synapses
        are potentiated by A_plus_ff * pre trace; B_ff scales depression only.
        """
        spiking = np.asarray(post_spikes, dtype=bool)
        self.post_trace_feed[spiking] += 1
        self.ff_weights[:, spiking] += (self.A_plus_ff * self.pre_trace_feed[:, None]
                                        * self.mask[:, spiking])
        self.ff_weights *= self.mask
        self.ff_weights = np.clip(self.ff_weights, 0, self.g_max)

    def update_recurrent_weights(self, pre_spikes, post_spikes):
        """
        Update recurrent weights based on pre and post spikes (boolean or 0/1 vectors).

        Presynaptic spikes: bump the pre trace, then depress outgoing weights by
        B_recur * A_minus_recur * post trace. Postsynaptic spikes: bump the post
        trace, then potentiate incoming weights by A_plus_recur * pre trace.
        Weights are clipped to [0, g_max].
        """
        pre = np.asarray(pre_spikes, dtype=bool)
        post = np.asarray(post_spikes, dtype=bool)

        self.pre_trace_recur[pre] += 1
        self.recur_weights[pre, :] -= self.B_recur * self.A_minus_recur * self.post_trace_recur

        self.post_trace_recur[post] += 1
        self.recur_weights[:, post] += self.A_plus_recur * self.pre_trace_recur[:, None]

        self.recur_weights = np.clip(self.recur_weights, 0, self.g_max)

    def simulate(self, T=500, feed_pre_times=(10, 50, 120), feed_post_times=(15, 55, 130),
                 recur_pre_times=(60, 140), recur_post_times=(65, 150), record_every=1):
        """
        Run the simulation for T time steps (each of length dt).

        Args:
            T: number of time steps.
            feed_pre_times / feed_post_times: steps at which presynaptic / postsynaptic
                feedforward STDP is applied; pass None to apply it at every step.
            recur_pre_times / recur_post_times: accepted for compatibility; currently
                unused (recurrent STDP is applied at every step).
            record_every: store a weight snapshot every `record_every` steps (the last
                step is always stored). Each feedforward snapshot holds
                num_inputs * num_neurons floats, so raise this for long runs.

        Returns:
            (ff_history, recur_history): arrays of stacked weight snapshots.
        """
        num_steps = int(T)
        record_every = max(1, int(record_every))
        feed_pre = None if feed_pre_times is None else set(feed_pre_times)
        feed_post = None if feed_post_times is None else set(feed_post_times)

        # Generate the input raster once for the whole run (one row per step)
        input_spikes = self._generate_stimulus_steps(num_steps)

        weights_history_ff = []
        weights_history_recur = []

        for t in range(num_steps):
            self.decay_traces()
            self.update_poisson_input()

            # Network spikes from the previous step drive recurrent input; they must
            # be captured before spike_train is cleared for this step.
            pre_spikes_recur = self.spike_train.astype(bool)
            self.spike_train.fill(0)
            pre_spikes_feed = input_spikes[t]

            self.update_membrane_potential(pre_spikes_feed, pre_spikes_recur)
            post_spikes = self.spike_train.astype(bool)

            if feed_pre is None or t in feed_pre:
                self.update_pre_feedforward(pre_spikes_feed)
            if feed_post is None or t in feed_post:
                self.update_post_feedforward(post_spikes)

            self.update_recurrent_weights(pre_spikes_recur, post_spikes)

            if t % record_every == 0 or t == num_steps - 1:
                weights_history_ff.append(self.ff_weights.copy())
                weights_history_recur.append(self.recur_weights.copy())

        return np.array(weights_history_ff), np.array(weights_history_recur)


# Run the simulation
network = STDP_Network()
ff_history, recur_history = network.simulate()

# Final feedforward weights. origin='lower' puts row 0 (input neuron 0) at the bottom,
# so the image agrees with the extent-based axis labels (imshow's default draws row 0 at the top).
fig_ff, ax_ff = plt.subplots(figsize=(8, 6))
im_ff = ax_ff.imshow(ff_history[-1], aspect='auto', cmap='gray_r', origin='lower',
                     extent=[0, network.num_neurons, 0, network.num_inputs])
fig_ff.colorbar(im_ff, ax=ax_ff, label="Feedforward Weight Strength")
ax_ff.set_xlabel("Network Neuron")
ax_ff.set_ylabel("Input Neuron")
ax_ff.set_title("Final Feedforward Synaptic Weight Distribution")
plt.show()

# Final recurrent weights: same orientation and colormap as the other figures (darker = stronger)
fig_rec, ax_rec = plt.subplots(figsize=(8, 6))
im_rec = ax_rec.imshow(recur_history[-1], aspect='auto', cmap='gray_r', origin='lower',
                       extent=[0, network.num_neurons, 0, network.num_neurons])
fig_rec.colorbar(im_rec, ax=ax_rec, label="Recurrent Weight Strength")
ax_rec.set_xlabel("Network Neuron (Postsynaptic)")
ax_rec.set_ylabel("Network Neuron (Presynaptic)")
ax_rec.set_title("Final Recurrent Synaptic Weight Distribution")
plt.show()
