# Installs and Imports

In [2]:
import numpy as np

# A simulator for a discrete-time linear Spring-Mass-Damper (SMD) environment.

In [3]:
class SpringMassDamperSimulator:
    """Discrete-time spring-mass-damper (SMD) plant with Gaussian noise.

    The state is ``[position, velocity]``. Each call to :meth:`step` advances
    the dynamics by one forward-Euler step of size ``dt`` and adds process
    noise; only the position is observed (``C = [[1, 0]]``).
    """

    def __init__(self, mass, spring_const, damping_coeff, dt, noise_cov):
        """
        Initialize the SMD simulator.

        Args:
            mass (float): mass m.
            spring_const (float): spring constant k.
            damping_coeff (float): damping coefficient b.
            dt (float): time-step size.
            noise_cov (ndarray): Process noise covariance matrix, shape (2, 2).
        """
        self.m = mass
        self.k = spring_const
        self.b = damping_coeff
        self.dt = dt
        self.noise_cov = noise_cov
        # Observation matrix C: only the position is measured.
        self.C = np.array([[1.0, 0.0]])  # shape (1, 2)

    def step(self, state, control_input):
        """
        Update the system state using Euler integration.

        Discrete-time dynamics:
            x₁(t+1) = x₁(t) + dt * x₂(t)
            x₂(t+1) = x₂(t) + dt * (1/m * (control_input - b*x₂(t) - k*x₁(t)))

        Also adds process noise v ~ N(0, noise_cov).

        Args:
            state (ndarray): Current state [position; velocity], shape (2,).
            control_input (float or ndarray): Control input applied at time t.
                Either a scalar or a size-1 array (e.g. a (1,) control vector
                shared with a ``B`` matrix of shape (2, 1)).

        Returns:
            next_state (ndarray): Updated state, shape (2,).
            observation (ndarray): Noisy position observation, shape (1,).

        Raises:
            ValueError: If ``control_input`` has more than one element.
        """
        pos, vel = state  # Unpack position and velocity

        # Normalise the control to a Python float. Mixing a (1,)-shaped control
        # with the scalar dpos_dt below would otherwise build a ragged
        # state_dot array and raise.
        u = np.asarray(control_input, dtype=float)
        if u.size != 1:
            raise ValueError(
                f"control_input must be a scalar or a size-1 array, got shape {u.shape}"
            )
        u = u.item()

        # Compute derivatives using SMD dynamics.
        dpos_dt = vel
        dvel_dt = (u - self.b * vel - self.k * pos) / self.m

        # Euler integration to update state.
        state_dot = np.array([dpos_dt, dvel_dt])
        next_state = state + self.dt * state_dot

        # Add process noise v_t ~ N(0, noise_cov)
        process_noise = np.random.multivariate_normal(mean=np.zeros(2), cov=self.noise_cov)
        next_state = next_state + process_noise

        # Noisy observation with observe()'s default observation noise.
        observation = self.observe(next_state)

        return next_state, observation

    def observe(self, state, obs_noise_cov=None):
        """
        Generate a noisy observation of the current state.

        Observation equation:
            y_t = C * x_t + w_t,
        where w_t ~ N(0, obs_noise_cov).

        Args:
            state (ndarray): State vector, shape (2,).
            obs_noise_cov (ndarray, optional): Observation noise covariance,
                shape (1, 1). Defaults to ``[[0.1]]`` when omitted or None.

        Returns:
            observation (ndarray): Noisy observation, shape (1,).
        """
        if obs_noise_cov is None:
            # Built per call instead of as a default argument, so no ndarray
            # instance is shared between calls.
            obs_noise_cov = np.array([[0.1]])

        # Linear observation: only the position is observed.
        observation = self.C @ state  # shape (1,)

        # Add observation noise.
        w = np.random.multivariate_normal(mean=np.zeros(self.C.shape[0]), cov=obs_noise_cov)
        return observation + w

# Spike Coding Network implementation (LIF neurons, spike encoding)

In [6]:
class SpikeCodingNetwork:
    """Minimal spike coding network of leaky threshold neurons.

    Holds a decoder ``D`` (square, random-normal), slow and fast recurrent
    weight matrices ``Omega_s`` / ``Omega_f`` (zero-initialised), a voltage
    leak rate and a spiking threshold.
    """

    def __init__(self, lambda_decay, dims, threshold=1.0):
        """
        Build the network.

        Args:
            lambda_decay (float): Leak rate λ applied to the membrane voltages.
            dims (int): Neuron count; also the size of the state, input and
                spike vectors (D is dims x dims).
            threshold (float): Voltage at or above which a neuron spikes.
        """
        self.lambda_decay = lambda_decay
        self.dims = dims
        self.threshold = threshold

        square = (dims, dims)
        # NOTE(review): SCNs usually set the fast weights to Ω_f = -Dᵀ D; both
        # recurrent matrices start at zero here and are expected to be assigned
        # (or learned) later.
        self.Omega_s = np.zeros(square)  # slow recurrent weights
        self.Omega_f = np.zeros(square)  # fast recurrent weights

        # Square decoder between spike space and continuous space, drawn from
        # a standard normal.
        self.D = np.random.randn(dims, dims)

    def encode(self, x):
        """
        Map a continuous vector to a binary spike pattern.

        Projects ``x`` with ``Dᵀ`` and marks every neuron whose projection
        reaches the threshold.

        Args:
            x (ndarray): Continuous input, shape (dims,).

        Returns:
            ndarray: Float array of 0.0 / 1.0 spikes, shape (dims,).
        """
        return np.where(self.D.T @ x >= self.threshold, 1.0, 0.0)

    def update_voltages(self, voltages, inputs, r, s, noise_std=0.01):
        """
        Advance the membrane voltages by one Euler step (dt = 1).

        Uses
            dv/dt = -λ v + Ω_s r + Ω_f s + inputs + ε,   ε ~ N(0, noise_std²)
        and returns v + dv/dt.

        Args:
            voltages (ndarray): Present membrane voltages, shape (dims,).
            inputs (ndarray): External input currents, shape (dims,).
            r (ndarray): Filtered spike trains, shape (dims,).
            s (ndarray): Most recent spikes, shape (dims,).
            noise_std (float): Standard deviation of the additive Gaussian noise.

        Returns:
            ndarray: Voltages after the step, shape (dims,).
        """
        eps = np.random.normal(0, noise_std, size=voltages.shape)
        leak = -self.lambda_decay * voltages
        slow_drive = self.Omega_s @ r
        fast_drive = self.Omega_f @ s
        dv = leak + slow_drive + fast_drive + inputs + eps
        return voltages + dv

    def generate_spikes(self, voltages):
        """
        Threshold the voltages, emitting spikes and resetting fired neurons.

        The input array is not modified; a reset copy is returned instead.

        Args:
            voltages (ndarray): Present membrane voltages, shape (dims,).

        Returns:
            spikes (ndarray): Float array of 0.0 / 1.0 spikes, shape (dims,).
            updated_voltages (ndarray): Copy of ``voltages`` with every neuron
                that spiked set to 0.0, shape (dims,).
        """
        fired = voltages >= self.threshold
        after_reset = np.copy(voltages)
        after_reset[fired] = 0.0
        return fired.astype(float), after_reset

    def decode(self, spikes):
        """
        Read out a continuous estimate from spikes with the decoder.

        Computes ``D @ spikes`` directly; a fuller model would low-pass filter
        the spikes (r) before decoding.

        Args:
            spikes (ndarray): Spike vector, shape (dims,).

        Returns:
            ndarray: Decoded estimate, shape (dims,).
        """
        return self.D @ spikes

# Kalman Filter module to perform optimal estimation (with delayed feedback)

In [9]:
class KalmanFilterSCN:
    """Linear state estimator with an adaptable gain and delayed observations.

    This class does not buffer past observations or estimates: the caller
    supplies the delayed quantities (``y_delayed``, ``x_hat_delayed``,
    ``error_delayed``). ``delay`` is stored for the caller's bookkeeping and is
    not read by any method here.
    """

    def __init__(self, A_init, B_init, C_init, L_init, delay):
        """
        Initialize the Kalman filter with delay.

        Args:
            A_init (ndarray): Initial system matrix A, shape (n, n).
            B_init (ndarray): Initial control matrix B, shape (n, p).
            C_init (ndarray): Initial observation matrix C, shape (m, n).
            L_init (ndarray): Initial Kalman gain matrix L, shape (n, m).
                Copied (as float), so later adaptation never modifies the
                caller's array.
            delay (int): Sensory feedback delay τ (in timesteps).
        """
        self.A = A_init
        self.B = B_init
        self.C = C_init
        # Private float copy: adapt_kalman_gain updates L in place, which would
        # otherwise mutate the caller's array, fail for an int-typed L_init, or
        # extend (rather than add to) a list L_init.
        self.L = np.array(L_init, dtype=float)
        self.delay = delay

    def predict_state(self, x_hat, u):
        """
        Kalman prediction step.

        Predict the next state estimate based on the current estimate and control:
            x_hat_pred = A x_hat + B u

        Args:
            x_hat (ndarray): Current state estimate, shape (n,).
            u (ndarray or float): Control input at time t, shape (p,); a
                scalar is accepted when p == 1.

        Returns:
            x_hat_pred (ndarray): Predicted next state, shape (n,).
        """
        # `@` rejects 0-d operands; promote a scalar control to shape (1,).
        return self.A @ x_hat + self.B @ np.atleast_1d(u)

    def update_state(self, x_hat_pred, y_delayed, x_hat_delayed):
        """
        Kalman update step, incorporating delayed observation.

        This function implements:
            x_hat_updated = x_hat_pred + L (y_{t+1-delay} - C x_hat_delayed)

        Args:
            x_hat_pred (ndarray): Predicted state estimate at time t+1, shape (n,).
            y_delayed (ndarray): Observation at time (t+1 - delay), shape (m,).
            x_hat_delayed (ndarray): State estimate at time (t+1 - delay), shape (n,).

        Returns:
            x_hat_updated (ndarray): Updated state estimate, shape (n,).
        """
        innovation = y_delayed - self.C @ x_hat_delayed  # prediction error (innovation)
        return x_hat_pred + self.L @ innovation

    def adapt_kalman_gain(self, error_current, error_delayed, learning_rate):
        """
        Adapt the Kalman gain L using a local plasticity rule.

        Based on the Bio-OFC rule:
            ΔL = η · L · (error_current) (error_delayed)^T

        where error_current = y_t - C x_hat_t and
        error_delayed = y_{t-delay} - C x_hat_{t-delay}.
        L (n, m) multiplies the (m, m) error outer product from the left as a
        matrix product, so ΔL has L's shape (n, m).

        Args:
            error_current (ndarray): Current prediction error, shape (m,).
            error_delayed (ndarray): Delayed prediction error, shape (m,).
            learning_rate (float): Step size η for the adaptation.

        Raises:
            ValueError: If the error lengths do not match L's column count
                (raised by the matrix product).

        Updates self.L in place.
        """
        error_outer = np.outer(error_current, error_delayed)  # shape (m, m)
        # Matrix product, not element-wise `*`: the Hadamard form is wrong for
        # square L and fails to broadcast when n != m and m > 1.
        delta_L = learning_rate * (self.L @ error_outer)
        self.L += delta_L