In [None]:
import numpy as np

class STDPNetwork:
    """Two-layer spiking network: LIF output neurons with pair-based STDP synapses.

    Input neurons drive ``n_output`` leaky integrate-and-fire (LIF) neurons
    through ``weights`` of shape ``(n_input, n_output)``. Weights adapt with
    nearest-neighbour pair-based STDP: each synapse pairs the current spike
    with the most recent earlier spike on the other side of the synapse.
    Weights are not clipped, so LTD can drive them negative.
    """

    def __init__(self, n_input, n_output, tau_m=20.0, tau_s=5.0, threshold=1.0, A_plus=0.01, A_minus=0.01):
        """
        Initialize STDP network.

        :param n_input: Number of input neurons
        :param n_output: Number of output neurons
        :param tau_m: Membrane time constant (ms)
        :param tau_s: STDP time constant (ms)
        :param threshold: Firing threshold
        :param A_plus: LTP strength
        :param A_minus: LTD strength
        """
        self.n_input = n_input
        self.n_output = n_output

        # Neuron parameters
        self.tau_m = tau_m
        self.tau_s = tau_s
        self.threshold = threshold
        self.A_plus = A_plus
        self.A_minus = A_minus

        # Synaptic weights (input → output), small random values in [0, 0.1)
        self.weights = np.random.rand(n_input, n_output) * 0.1

        # Neuron state tracking
        self.membrane = np.zeros(n_output)  # Output neuron membrane potentials
        self.spike_times = {}               # {time_step: array of spiking output neuron indices}
        self.time_step = 0                  # Number of lif_neuron() calls since construction/reset

        # Most recent spike time (ms) of every neuron, as seen by stdp_update().
        # -inf means "never spiked" and makes exp(-delta_t / tau_s) evaluate to 0.
        self.last_pre_spike = np.full(n_input, -np.inf)
        self.last_post_spike = np.full(n_output, -np.inf)

    def reset(self):
        """Reset network state: membrane potentials, spike history and clocks.

        Weights are kept, so learned structure survives a reset.
        """
        self.membrane.fill(0)
        self.spike_times.clear()
        self.time_step = 0
        self.last_pre_spike.fill(-np.inf)
        self.last_post_spike.fill(-np.inf)

    def lif_neuron(self, inputs, dt=1.0):
        """
        Advance the output LIF neurons by one time step.

        Forward-Euler integration of dv/dt = -v / tau_m + I with
        I = inputs @ weights. Neurons reaching ``threshold`` spike, are reset
        to 0, and are recorded in ``spike_times`` under the current step index.

        :param inputs: Input currents or spikes (array of size n_input)
        :param dt: Time step (ms)
        :return: Boolean array of size n_output, True where a neuron spiked
        """
        # Compute input current to output neurons (weighted sum)
        current = np.dot(inputs, self.weights)

        # Update membrane potential (LIF equation, forward Euler)
        self.membrane += (-self.membrane / self.tau_m + current) * dt

        # Check for spikes
        spikes = self.membrane >= self.threshold
        spike_indices = np.where(spikes)[0]

        # Key by the real step index: steps without spikes must still advance
        # the clock, otherwise recorded times drift from simulation time.
        if spike_indices.size > 0:
            self.spike_times[self.time_step] = spike_indices
            self.membrane[spike_indices] = 0  # Reset spiked neurons
        self.time_step += 1

        return spikes

    def stdp_update(self, pre_spikes, post_spikes, current_time, dt=1.0):
        """
        Update weights via nearest-neighbour pair-based STDP.

        Call once per simulation step, including steps without output spikes,
        so that presynaptic spike times are recorded.

        - LTP: when output j spikes at time t, each input i whose last spike
          was at t_pre < t gets w[i, j] += A_plus * exp(-(t - t_pre) / tau_s).
        - LTD: when input i spikes at time t, each output j whose last spike
          was at t_post < t gets w[i, j] -= A_minus * exp(-(t - t_post) / tau_s).

        Pre/post spikes in the same step (delta_t == 0) are not paired.

        :param pre_spikes: Input neuron spikes (n_input array, nonzero = spike)
        :param post_spikes: Output neuron spikes (n_output array, nonzero = spike)
        :param current_time: Current time step
        :param dt: Duration of one time step (ms); spike time is current_time * dt
        :raises ValueError: if the spike arrays do not match the layer sizes
        """
        pre = np.asarray(pre_spikes, dtype=bool)
        post = np.asarray(post_spikes, dtype=bool)
        n_in, n_out = self.weights.shape
        if pre.shape != (n_in,) or post.shape != (n_out,):
            raise ValueError(
                f"expected spike arrays of shape ({n_in},) and ({n_out},), "
                f"got {pre.shape} and {post.shape}"
            )

        t = current_time * dt

        # Interval since each neuron's previous spike (inf if it never spiked).
        # Computed before recording this step's spikes so a spike is never
        # paired with a partner from the same step.
        since_pre = t - self.last_pre_spike
        since_post = t - self.last_post_spike

        if post.any():
            # LTP: presynaptic spike strictly before the postsynaptic one.
            decay = np.exp(-np.maximum(since_pre, 0.0) / self.tau_s)
            ltp = np.where(since_pre > 0, self.A_plus * decay, 0.0)
            self.weights[:, post] += ltp[:, np.newaxis]

        if pre.any():
            # LTD: postsynaptic spike strictly before the presynaptic one.
            # The exponent must be negative so depression fades with the interval.
            decay = np.exp(-np.maximum(since_post, 0.0) / self.tau_s)
            ltd = np.where(since_post > 0, self.A_minus * decay, 0.0)
            self.weights[pre, :] -= ltd[np.newaxis, :]

        self.last_pre_spike[pre] = t
        self.last_post_spike[post] = t