In [None]:
import contextlib
import heapq

import matplotlib.pyplot as plt # type: ignore
import numpy as np # type: ignore

# Seed NumPy's legacy global RNG so the random network wiring (connection
# probabilities and weights) and the probabilistic spike routing are
# reproducible from run to run.
np.random.seed(seed=0)

In [None]:
class Neuron:
    """A spiking neuron based on a simplified Spike Response Model (SRM).

    Each presynaptic source keeps an exponentially decaying PSP trace, and a
    refractory term (pushed negative by every output spike) decays back
    towards zero. State is decayed lazily, i.e. only when the neuron next
    receives input.

    The constants below are class-level defaults; they can be overridden per
    instance through the constructor.
    """
    tau = 10.0  # PSP decay time constant (ms)
    tau_ref = 5.0  # Refractory decay time constant (ms)
    threshold = 1.0  # Firing threshold
    tau_abs = 2.0  # Absolute refractory period (ms)

    def __init__(self, tau=None, tau_ref=None, threshold=None, tau_abs=None):
        """
        Create a neuron with no connections and a resting state.

        Args:
            tau (float, optional): PSP decay time constant (ms).
            tau_ref (float, optional): Refractory decay time constant (ms).
            threshold (float, optional): Firing threshold.
            tau_abs (float, optional): Absolute refractory period (ms).

            Any argument left as ``None`` falls back to the class attribute
            of the same name.
        """
        # Only shadow a class default when a value is given, so neurons built
        # without overrides keep following the class attributes.
        for name, value in (("tau", tau), ("tau_ref", tau_ref),
                            ("threshold", threshold), ("tau_abs", tau_abs)):
            if value is not None:
                setattr(self, name, value)

        self.presynaptic = []  # List of (pre_neuron_id, weight) tuples
        self.postsynaptic = {}  # {post_neuron_id: transmission probability}
        self.v = {}  # Unweighted PSP traces {pre_neuron_id: v_j}
        self.r = 0.0  # Refractory contribution to the membrane potential
        self.last_update_time = 0.0  # Time up to which v and r have been decayed
        self.last_spike_time = -float('inf')
        self.spike_list = []  # Record of spike times

    def update_state(self, time):
        """Decay PSP traces and the refractory term from the last update up to ``time``.

        Calls with ``time`` not later than the last update leave the state unchanged.
        """
        if time > self.last_update_time:
            delta_t = time - self.last_update_time
            psp_decay = np.exp(-delta_t / self.tau)
            for pre_id in self.v:
                self.v[pre_id] *= psp_decay
            self.r *= np.exp(-delta_t / self.tau_ref)
            self.last_update_time = time

    def compute_potential(self):
        """Return the membrane potential: weighted sum of PSP traces plus the refractory term.

        Only sources listed in ``presynaptic`` contribute; a trace recorded for
        any other ID is ignored.
        """
        u = sum(weight * self.v.get(pre_id, 0.0) for pre_id, weight in self.presynaptic)
        u += self.r
        return u

    def receive_spike(self, time, pre_neuron_ids):
        """
        Deliver spikes from one or more presynaptic neurons at ``time``.

        Each ID in ``pre_neuron_ids`` increments its PSP trace by 1.0 (the
        synaptic weight is applied in ``compute_potential``). The neuron fires
        when the potential reaches ``threshold`` and more than ``tau_abs`` ms
        have passed since its last spike; firing records the time and lowers
        the refractory term by ``threshold``.

        Args:
            time (float): Current simulation time (ms).
            pre_neuron_ids (iterable): IDs of the presynaptic neurons that spiked.

        Returns:
            list | bool: ``[]`` if the neuron did not fire; ``True`` if it fired
            but has no postsynaptic connections (e.g. an output neuron);
            otherwise the list of postsynaptic IDs the spike is transmitted
            to, each chosen independently with its connection probability.
        """
        self.update_state(time)

        # Bump each sender's PSP trace by one unit.
        for pre_id in pre_neuron_ids:
            self.v[pre_id] = self.v.get(pre_id, 0.0) + 1.0

        # Fire only when above threshold AND outside the absolute refractory period.
        u = self.compute_potential()
        fired = u >= self.threshold and time > self.last_spike_time + self.tau_abs
        if not fired:
            return []

        self.spike_list.append(time)
        self.last_spike_time = time
        self.r -= self.threshold  # Refractory effect after a spike

        # A neuron without outgoing connections (e.g. an output neuron) just
        # reports that it fired.
        if not self.postsynaptic:
            return True

        # Each outgoing connection transmits independently with its probability
        # (one RNG draw per connection, in insertion order).
        return [post_id for post_id, probability in self.postsynaptic.items()
                if np.random.rand() < probability]


In [None]:
class SpikingNeuralMatrix:
    """Input -> recurrent hidden -> output network of SRM ``Neuron`` objects.

    Neurons are stored in dicts keyed by string IDs ``"inK"``, ``"hidK"`` and
    ``"outK"``. Input neuron ``"inK"`` has the single presynaptic source
    ``"senK"``. Every input neuron projects to every hidden neuron; every
    hidden neuron projects to every output neuron and to every other hidden
    neuron. Each connection has a transmission probability (on the sender's
    ``postsynaptic`` dict) and a weight (on the receiver's ``presynaptic`` list).
    """

    # Type of the event scheduled after processing an event of the given type.
    _NEXT_EVENT_TYPE = {"Input": "Hidden", "Hidden": "Output", "Output": "Output"}

    def __init__(self, num_input_neurons, num_output_neurons, num_hidden_neurons):
        """
        Build and randomly wire the network.

        Args:
            num_input_neurons (int): Number of input neurons.
            num_output_neurons (int): Number of output neurons.
            num_hidden_neurons (int): Number of hidden neurons.
        """
        self.input_neurons = {f"in{i}": Neuron() for i in range(num_input_neurons)}
        self.hidden_neurons = {f"hid{i}": Neuron() for i in range(num_hidden_neurons)}
        self.output_neurons = {f"out{i}": Neuron() for i in range(num_output_neurons)}

        # NOTE: the order of the np.random draws below determines the network a
        # given seed produces; keep it unchanged when editing the wiring.

        # Sensory -> input (weight 1) and input -> hidden (uniform probability,
        # weight ~ N(0.5, 0.1)).
        for index, (in_id, in_neuron) in enumerate(self.input_neurons.items()):
            in_neuron.presynaptic.append((f"sen{index}", 1))
            for hid_id, hid_neuron in self.hidden_neurons.items():
                in_neuron.postsynaptic[hid_id] = np.random.rand()
                hid_neuron.presynaptic.append((in_id, np.random.normal(0.5, 0.1)))

        # Hidden -> output (uniform probability and weight) and recurrent
        # hidden -> hidden without self-loops (uniform probability,
        # weight ~ N(0.5, 0.1)).
        for hid_id, hid_neuron in self.hidden_neurons.items():
            for out_id, out_neuron in self.output_neurons.items():
                hid_neuron.postsynaptic[out_id] = np.random.rand()
                out_neuron.presynaptic.append((hid_id, np.random.rand()))
            for other_id, other_neuron in self.hidden_neurons.items():
                if other_id != hid_id:
                    hid_neuron.postsynaptic[other_id] = np.random.rand()
                    other_neuron.presynaptic.append((hid_id, np.random.normal(0.5, 0.1)))

        # Normalize each output neuron's incoming weights so they sum to 1.
        for out_neuron in self.output_neurons.values():
            total_weight = sum(weight for _, weight in out_neuron.presynaptic)
            if total_weight > 0:
                out_neuron.presynaptic[:] = [
                    (pre_id, weight / total_weight) for pre_id, weight in out_neuron.presynaptic
                ]

    @staticmethod
    def _group_by_receiver(pre_neuron_ids, post_neuron_ids):
        """Invert an event into ``{receiver_id: [sender_id, ...]}`` in first-seen order.

        ``post_neuron_ids[k]`` lists the receivers of ``pre_neuron_ids[k]``.
        """
        grouped = {}
        for idx, receivers in enumerate(post_neuron_ids):
            sender_id = pre_neuron_ids[idx]
            for receiver_id in receivers:
                grouped.setdefault(receiver_id, []).append(sender_id)
        return grouped

    def _deliver(self, event_type, current_time, post_neuron_spikes):
        """Deliver one event's grouped spikes; return (senders, targets) for the next event.

        Output neurons are sinks: they record spikes but schedule nothing.
        """
        new_pre_neurons, new_post_neurons = [], []
        for receiver_id, sender_ids in post_neuron_spikes.items():
            if event_type == "Input":
                neuron = self.input_neurons[receiver_id]
            elif event_type == "Hidden":
                neuron = self.hidden_neurons[receiver_id]
            elif receiver_id in self.output_neurons:
                self.output_neurons[receiver_id].receive_spike(current_time, sender_ids)
                continue
            else:
                neuron = self.hidden_neurons[receiver_id]
            targets = neuron.receive_spike(current_time, sender_ids)
            new_pre_neurons.append(receiver_id)
            # receive_spike returns True (not a list) when a neuron fires but has
            # no postsynaptic connections; such a spike has nowhere to go.
            new_post_neurons.append(targets if isinstance(targets, list) else [])
        return new_pre_neurons, new_post_neurons

    def simulate(self, input_spikes, max_time=100.0, time_step=1, debug_log="debug_log.txt"):
        """
        Run an event-driven simulation and return the output neurons' spike times.

        Args:
            input_spikes (list): Tuples ``(sensory_ids, input_neuron_ids, spike_time)``
                where ``sensory_ids`` is a list of sensory IDs and
                ``input_neuron_ids`` is a parallel list whose k-th entry lists
                the input-neuron IDs reached by ``sensory_ids[k]``.
            max_time (float): Events scheduled after this time are not processed (ms).
            time_step (float): Delay added to every hop between neurons (ms).
            debug_log (str | None): Path of a text log of processed events,
                overwritten at the start of each call. ``None`` disables logging.

        Returns:
            dict: ``{output_neuron_id: [spike_time, ...]}`` holding a copy of each
            output neuron's spike record. Neuron state is not reset between
            calls, so records accumulate across calls.

        The processing flow:
        1. "Input" events deliver sensory spikes to input neurons; input neurons
           that fire schedule a "Hidden" event ``time_step`` later.
        2. "Hidden" events deliver spikes to hidden neurons; their spikes
           schedule an "Output" event.
        3. "Output" events deliver to output neurons (sinks) and, through the
           recurrent connections, to hidden neurons, whose spikes schedule
           another "Output" event. This repeats until the queue is empty or
           ``max_time`` is passed.
        """
        event_queue = []
        for sensory_ids, input_neuron_ids, spike_time in input_spikes:
            heapq.heappush(event_queue, (spike_time, sensory_ids, input_neuron_ids, "Input"))

        log_context = open(debug_log, "w") if debug_log is not None else contextlib.nullcontext()
        with log_context as log:
            if log is not None:
                log.write("Starting new simulation\n\n")

            while event_queue:
                current_time, pre_neuron_ids, post_neuron_ids, event_type = heapq.heappop(event_queue)
                if current_time < 0:
                    continue
                if current_time > max_time:
                    break  # heap order: every remaining event is later still
                if not pre_neuron_ids or not any(post_neuron_ids):
                    continue

                post_neuron_spikes = self._group_by_receiver(pre_neuron_ids, post_neuron_ids)

                if log is not None:
                    log.write(f"Time: {current_time}, Pre_neuron_ids: {pre_neuron_ids}, Post_neuron_ids: {post_neuron_ids}, Event_type: {event_type}\n")
                    log.write(f"Post neuron receiving spikes: {post_neuron_spikes}\n\n")

                next_type = self._NEXT_EVENT_TYPE.get(event_type)
                if next_type is None:
                    continue  # unknown event type: logged but not delivered

                new_pre_neurons, new_post_neurons = self._deliver(
                    event_type, current_time, post_neuron_spikes
                )
                # An event with nothing to deliver would be discarded on pop, so skip it.
                if new_pre_neurons and any(new_post_neurons):
                    heapq.heappush(
                        event_queue,
                        (current_time + time_step, new_pre_neurons, new_post_neurons, next_type),
                    )

        return {
            neuron_id: list(neuron.spike_list)
            for neuron_id, neuron in self.output_neurons.items()
        }


In [None]:
# Build a test network of 40 input, 1000 hidden and 2 output neurons.
num_input_neurons = 40
num_hidden_neurons = 1000
num_output_neurons = 2

# Keyword arguments spell out the constructor's (input, output, hidden)
# parameter order, which differs from the layer order above.
snm = SpikingNeuralMatrix(
    num_input_neurons=num_input_neurons,
    num_output_neurons=num_output_neurons,
    num_hidden_neurons=num_hidden_neurons,
)

NameError: name 'np' is not defined

In [None]:
if __name__ == "__main__":
    # Build a rich set of test stimuli. Each event has the form
    #     ([sensory ids], [[input ids reached by each sensory id]], spike_time)
    # and sensory neuron "senK" always drives input neuron "inK".
    def make_event(channels, spike_time):
        """Return one stimulus event in which each sensory channel K fires into input K."""
        return (
            [f"sen{k}" for k in channels],
            [[f"in{k}"] for k in channels],
            spike_time,
        )

    schedule = [
        # Basic sequential inputs
        ((0,), 0.0),
        ((1,), 2.5),
        ((2,), 5.0),
        # Multiple sensory neurons firing simultaneously
        ((3, 4, 5), 7.5),
        # Burst pattern (high-frequency firing of the same neuron)
        ((6,), 10.0),
        ((6,), 10.5),
        ((6,), 11.0),
        ((6,), 11.5),
        # Regular interval pattern
        ((7,), 15.0),
        ((8,), 20.0),
        ((9,), 25.0),
        # Overlapping patterns
        ((10,), 30.0),
        ((11,), 30.5),
        # Sequential firing with decreasing intervals
        ((12,), 35.0),
        ((13,), 37.0),
        ((14,), 38.5),
        ((15,), 39.5),
        # Cyclic pattern (channel 16 repeats at 55 ms)
        ((16,), 40.0),
        ((17,), 45.0),
        ((18,), 50.0),
        ((16,), 55.0),
        # Synchronized multiple inputs
        ((20, 21, 22), 60.0),
        # Mixed frequency patterns
        ((23,), 65.0),
        ((23,), 65.5),
        ((24,), 70.0),
        ((24,), 75.0),
        # Four neurons firing in sequence
        ((25,), 80.0),
        ((26,), 81.0),
        ((27,), 82.0),
        ((28,), 83.0),
        # Large synchronized burst near the end
        ((30, 31, 32, 33, 34), 90.0),
        # Final sequence up to the end of the simulation window
        ((35,), 95.0),
        ((36,), 97.0),
        ((37,), 98.0),
        ((38,), 99.0),
        ((39,), 100.0),
    ]
    input_spikes = [make_event(channels, t) for channels, t in schedule]

    # Run the network over the whole window; each hop between neurons adds time_step.
    output_spikes = snm.simulate(input_spikes, max_time=100.0, time_step=1.0)

    # Show the spike times recorded for each output neuron.
    print(output_spikes)

{'out0': [10.5, 18.0, 28.0, 38.0, 48.0, 58.0, 68.0, 78.0, 86.0, 98.0], 'out1': [10.5, 18.0, 28.0, 38.0, 48.0, 58.0, 68.0, 78.0, 86.0, 98.0]}
