# Installs and Imports

Disclaimer: Before running the code below, create a virtual environment from the provided environment.yml and activate it.

In [1]:
import numpy as np

# Define the Spring-Mass-Damper System Class

In [None]:
class SMDSystem:
    """
    Noisy discrete-time linear system (spring-mass-damper plant).

    One call to `step` advances the dynamics
        x_next = A x + B u + process_noise,   process_noise ~ N(0, V)
        y      = C x_next + sensor_noise,     sensor_noise  ~ N(0, W)
    """

    def __init__(self, A, B, C, noise_params):
        """
        Store the system matrices and the noise covariances.

        Parameters:
            A : state-transition matrix.
            B : control input matrix.
            C : observation matrix.
            noise_params : dict
                Must provide 'V' (process noise covariance) and
                'W' (observation noise covariance).
        """
        self.A, self.B, self.C = A, B, C
        self.V = noise_params['V']
        self.W = noise_params['W']

    @staticmethod
    def _zero_mean_gaussian(dim, cov):
        """Draw a single sample from N(0, cov) of length `dim`."""
        return np.random.multivariate_normal(np.zeros(dim), cov)

    def step(self, x, u):
        """
        Advance the system by one time step and observe the new state.

        Parameters:
            x : ndarray, shape (m,)
                Current state vector.
            u : ndarray, shape (k,)
                Control input vector.

        Returns:
            x_next : ndarray, shape (m,)
                Next state vector.
            y : ndarray, shape (n,)
                Noisy observation of x_next.
        """
        state_dim = self.A.shape[0]
        obs_dim = self.C.shape[0]

        # The process noise is drawn before the sensor noise so that the
        # sequence of random draws is fixed for a given seed.
        w = self._zero_mean_gaussian(state_dim, self.V)
        x_next = np.dot(self.A, x) + np.dot(self.B, u) + w

        v = self._zero_mean_gaussian(obs_dim, self.W)
        y = np.dot(self.C, x_next) + v

        return x_next, y

# Define the Spiking Neural Estimator Class (SCN)

In [2]:
class SpikingNeuralEstimator:
    """
    Spike-coding network (SCN) used as a state estimator.

    Holds the decoding matrix D and the recurrent connectivity derived from
    it, and provides the per-time-step updates for membrane potentials,
    filtered spike trains, and the decoded state estimate.
    """

    def __init__(self, num_neurons, state_dim, params):
        """
        Initialize the SCN.

        Parameters:
            num_neurons: int
                Number of spiking neurons (N).
            state_dim: int
                Dimension of the state (e.g., 2 for [position, velocity]).
            params: dict
                May contain:
                  - 'D_init': initial decoding matrix as an array-like of
                    shape (state_dim, num_neurons); missing or None means
                    random unit-norm columns are generated.
                  - 'leak_rate': leak rate (lambda), default 0.1
                  - 'threshold': spike threshold, default 1.0
                  - 'dt': time step, default 1.0

        Attributes set here:
            D       : decoding matrix, shape (state_dim, num_neurons).
            Omega_f : fast recurrent connectivity, -D^T D, shape (N, N).
            Omega_s : slow recurrent connectivity, currently all zeros.
        """
        self.N = num_neurons
        self.state_dim = state_dim
        self.dt = params.get('dt', 1.0)
        self.leak_rate = params.get('leak_rate', 0.1)
        self.threshold = params.get('threshold', 1.0)

        d_init = params.get('D_init')
        if d_init is not None:
            # Accept any array-like (e.g. nested lists). np.asarray returns
            # an ndarray input unchanged, so callers that pass an array keep
            # sharing it with this object exactly as before.
            self.D = np.asarray(d_init)
        else:
            self.D = np.random.randn(state_dim, num_neurons)
            # Normalize each column (one per neuron) to unit norm.
            self.D /= np.linalg.norm(self.D, axis=0, keepdims=True)

        # Fast recurrent connectivity: Omega_f = -D^T D
        self.Omega_f = -self.D.T.dot(self.D)

        # Slow recurrent connectivity is a placeholder set to zero for now.
        self.Omega_s = np.zeros_like(self.Omega_f)

    def update_neurons(self, v, net_input):
        """
        Advance the membrane potentials by one Euler step and emit spikes:
            v_new = v + dt * (-leak_rate * v + net_input)
        A neuron spikes when v_new >= threshold; spiking neurons are reset
        to 0.

        Parameters:
            v: ndarray, shape (N,)
                Current membrane potentials.
            net_input: ndarray, shape (N,)
                Total input current to each neuron (feedforward plus any
                recurrent contribution computed by the caller).

        Returns:
            s: ndarray, shape (N,)
                Spike vector of floats (1.0 if the neuron spiked, else 0.0).
            v_new: ndarray, shape (N,)
                Updated membrane potentials (0 for neurons that spiked).
        """
        v_new = v + self.dt * (-self.leak_rate * v + net_input)

        # Threshold crossing (inclusive) produces a spike.
        s = (v_new >= self.threshold).astype(float)

        # Multiplying by (1 - s) zeroes exactly the neurons that spiked.
        v_new = v_new * (1 - s)

        return s, v_new

    def update_filtered_spikes(self, r, s):
        """
        Update the filtered spike train r using an exponential decay:
            r_new = r + dt * (-leak_rate * r + s)

        Parameters:
            r: ndarray, shape (N,)
                Current filtered spike train.
            s: ndarray, shape (N,)
                Current spike vector.

        Returns:
            r_new: ndarray, shape (N,)
                Updated filtered spike train.
        """
        return r + self.dt * (-self.leak_rate * r + s)

    def decode_state(self, r):
        """
        Decode the current state estimate from the filtered spike train:
            x_hat = D * r

        Parameters:
            r: ndarray, shape (N,)
                Filtered spike train.

        Returns:
            x_hat: ndarray, shape (state_dim,)
                Decoded state estimate.
        """
        return self.D.dot(r)

# Define the Spiking Controller Class

In [3]:
import numpy as np

class SpikingController:
    """
    Linear readout controller driven by the SCN's filtered spike train.

    The control is u = D_u r + xi, where xi is Gaussian exploration noise.
    The readout weights D_u are adapted with a cost-modulated
    (policy-gradient style) rule.
    """

    def __init__(self, control_dim, num_neurons, learning_rate, sigma=0.1):
        """
        Set up the controller.

        Parameters:
            control_dim : int
                Dimension of the control signal (for SMD, typically 1).
            num_neurons : int
                Number of neurons; same as the SCN's N.
            learning_rate : float
                Step size for the policy gradient updates.
            sigma : float
                Standard deviation of the exploration noise.
        """
        self.control_dim = control_dim
        self.N = num_neurons
        self.learning_rate = learning_rate
        self.sigma = sigma
        # Readout weights start as small random values.
        init_scale = 0.01
        self.D_u = init_scale * np.random.randn(control_dim, num_neurons)

    def get_control(self, r):
        """
        Produce a noisy control signal u = D_u * r + xi, xi ~ N(0, sigma^2).

        Parameters:
            r : ndarray, shape (num_neurons,)
                The filtered spike train from the SCN.

        Returns:
            u : ndarray, shape (control_dim,)
                The control signal including exploration noise.
            xi : ndarray, shape (control_dim,)
                The exploration noise that was added; needed later by
                `update_weights`.
        """
        mean_control = np.dot(self.D_u, r)
        xi = self.sigma * np.random.randn(self.control_dim)
        return mean_control + xi, xi

    def update_weights(self, r, xi, cost):
        """
        Apply the reward-modulated plasticity rule in place:
            D_u <- D_u - learning_rate * cost * xi * r^T

        Parameters:
            r : ndarray, shape (num_neurons,)
                The filtered spike train from the SCN.
            xi : ndarray, shape (control_dim,)
                The exploration noise sampled during control.
            cost : float
                The cost (or error signal) incurred at the current timestep.
        """
        # xi r^T has shape (control_dim, num_neurons), matching D_u.
        noise_activity = np.outer(xi, r)
        self.D_u -= self.learning_rate * cost * noise_activity