# Simulating Cooperatively Coding Networks

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.optimize import root

# Use double precision for every tensor created without an explicit dtype
# from here on (this includes the network buffers and parameters below).
torch.set_default_dtype(torch.float64)

## 1D network

In [None]:
# Define rate network
class IELagRateNetwork(torch.nn.Module):
    """1D ring rate network with lagged (I-E) inhibitory feedback.

    N units with nearest-neighbour coupling and periodic boundary conditions
    are integrated with a midpoint scheme (see ``midpoint_method_step``).
    The recurrent net-excitatory term acts on the current state, while the
    inhibitory term acts on the state change ``x_state - x_state_delayed``,
    scaled by ``tau/tau_lag``.

    Args:
        N: Number of units.
        tau: Time constant of the rate dynamics.
        dt: Integration time step.
        n_lag: I-E lag in time steps (``tau_lag = n_lag*dt``).
        width: Width of the target exponential receptive field.
    """

    def __init__(self, N, tau, dt,
        n_lag, width):
        super().__init__()
        self.N = N
        self.tau = tau
        self.dt = dt
        self.n_lag = n_lag  # IE lag in time steps
        self.width = width  # Width of target exponential RF

        # Exponential filter factors; not read by any method of this class.
        self.gamma_filter = np.exp(-1/n_lag)
        self.gamma_filter_longer = np.exp(-1/(n_lag*1))
        self.d_state = N

        # Initialize network parameters
        self.initialize_base_conn()
        # Alternating -1/+1 pattern with unit Euclidean norm
        self.oscillating_state = torch.ones(N)
        self.oscillating_state[::2] = -1
        self.oscillating_state /= np.sqrt(N)

        self.w_in_netE = torch.nn.Parameter(torch.tensor(.001), requires_grad=True)
        self.w_in_I = torch.nn.Parameter(torch.tensor(.001), requires_grad=True)
        self.w_rec_netE_param = torch.nn.Parameter(torch.tensor(.001), requires_grad=True)
        self.w_rec_I_param = torch.nn.Parameter(torch.tensor(.001), requires_grad=True)
        # reset() allocates the state and all history buffers, so they are
        # not allocated a second time here.
        self.reset()

    @property
    def tau_lag(self):
        """I-E lag in time units (n_lag steps of size dt)."""
        return self.n_lag*self.dt

    def initialize_base_conn(self):
        """Set the saturation bounds and the nearest-neighbour ring adjacency."""
        self.w_max_rec_netE = 0.5
        self.w_max_rec_I = -0.5*(1-self.dt/self.tau)

        # Nearest neighbor connections with periodic boundary conditions
        self.A_rec = torch.roll(torch.eye(self.N), 1, dims=0) + torch.roll(torch.eye(self.N), -1, dims=0)
        self.A_rec_I = self.A_rec.clone()
        self.set_width(self.width)

    def reset(self):
        """(Re)allocate the state and all history/filter buffers as zeros."""
        self.state = torch.zeros(self.N)
        self.x_states_delayed = torch.zeros(3*self.n_lag+1, self.N)
        # Derivatives of past integration steps, oldest first
        self.dxdt_nows_hist = torch.zeros(self.n_lag+1, self.N)
        self.tot_state_change_old = torch.zeros(2*self.n_lag, self.N)
        self.tot_rec_input_old = torch.zeros(2*self.n_lag, self.N)
        self.rec_net_input_old = torch.zeros(2*self.n_lag, self.N)
        self.tot_rec_input_expFiltered = torch.zeros(self.N)
        self.rec_net_input_expFiltered = torch.zeros(self.N)

    def set_width(self, width):
        """Set the target RF width and its per-unit decay factor gamma = exp(-1/width)."""
        self.width = width
        self.gamma = np.exp(-1/width)

    @property
    def w_tot(self):
        """Return a 2N x 2N linear update matrix built from the current weights.

        NOTE(review): judging from the (disabled) eigenvector code below, the
        upper and lower halves act on the 'x' and 'dx' parts of a stacked
        state. The matrix is assembled independently of
        ``get_state_derivative`` (tau_lag does not enter); confirm it matches
        the integrated dynamics before relying on it.
        """
        w_tot = torch.zeros((2*self.N, 2*self.N))

        w_tot[:self.N, :self.N] \
            = (1-self.dt/self.tau)*torch.eye(self.N) + (self.dt/self.tau)*self.w_rec_netE*self.A_rec
        w_tot[:self.N, self.N:] \
            = -self.w_rec_I*self.A_rec_I
        w_tot[self.N:, :self.N] \
            = (-self.dt/self.tau)*torch.eye(self.N) + (self.dt/self.tau)*self.w_rec_netE*self.A_rec
        w_tot[self.N:, self.N:] \
            = -self.w_rec_I*self.A_rec_I
        return w_tot

    def get_eig(self):
        """Eigen-decompose ``w_tot``; eigenpairs sorted by ascending real part."""
        w_tot = self.w_tot

        eigvals, eigvecs = torch.linalg.eig(w_tot)
        # Return sorted from smallest to largest real part
        order = torch.argsort(eigvals.real)
        eigvals = eigvals[order]
        eigvecs = eigvecs[:, order]
        return eigvals, eigvecs

    # def set_problematic_eigvecs(self):
    #     eigvals, eigvecs = self.get_eig()
    #     n_problematic = torch.sum(eigvals.real<0)
    #     print(f'Found {n_problematic} problematic (=negative) eigenvalues')
    #     # n_problematic = torch.sum(torch.abs(eigvals.real)>1+1e-6)
    #     if n_problematic>0:
    #         self.problematic_eigvals = eigvals[:n_problematic]
    #         """ Here we take only the 'x' part of the eigenvectors,
    #             but neglect the 'dx' part (x_state-x_state_prev) """
    #         self.problematic_eigvecs = eigvecs.real[:, :n_problematic].T
    #         # To be sure, normalize eigvecs
    #         self.problematic_eigvecs /= torch.linalg.norm(self.problematic_eigvecs, axis=1)[:, np.newaxis]
    #     else:
    #         self.problematic_eigvals = torch.tensor([])
    #         self.problematic_eigvecs = torch.tensor([])

    def saturation(self, x, w_max=0.5):
        """Smooth bounded map of x into (-|w_max|, |w_max|) with slope 1 at x = 0."""
        return 2*w_max*(torch.sigmoid(2*x/w_max)-0.5)

    @property
    def w_rec_netE(self):
        # Parametrize such that w_rec_netE <= w_max_rec_netE (= 0.5) and its
        # gradient diminishes as w_rec_netE approaches that bound. Uses the
        # same bound as the setter so that get/set round-trip.
        return self.saturation(self.w_rec_netE_param, w_max=self.w_max_rec_netE)
    @w_rec_netE.setter
    def w_rec_netE(self, w_rec_netE):
        self.w_rec_netE_param.data = self.get_param_value(w_rec_netE, w_max=self.w_max_rec_netE)

    @property
    def w_rec_I(self):
        # Parametrize such that |w_rec_I| < |w_max_rec_I| = 0.5*(1-dt/tau)
        # (dt/tau taken at construction) and its gradient diminishes as
        # w_rec_I approaches the bound.
        return self.saturation(self.w_rec_I_param, w_max=self.w_max_rec_I)
    @w_rec_I.setter
    def w_rec_I(self, w_rec_I):
        self.w_rec_I_param.data = self.get_param_value(w_rec_I, w_max=self.w_max_rec_I)

    @property
    def w_in_E(self):
        """Excitatory input weight implied by net and inhibitory weights."""
        return self.w_in_netE - self.w_in_I

    @property
    def w_rec_E(self):
        """Excitatory recurrent weight implied by net and inhibitory weights."""
        return self.w_rec_netE - self.w_rec_I

    @property
    def w_rec_netE_mat(self):
        """N x N net-excitatory recurrent weight matrix."""
        return self.w_rec_netE*self.A_rec

    @property
    def w_rec_I_mat(self):
        """N x N inhibitory recurrent weight matrix."""
        return self.w_rec_I * self.A_rec_I

    def print_weights(self):
        """Print the four scalar weights."""
        print(f'w_in_netE:      {self.w_in_netE.item()}')
        print(f'w_in_I:         {self.w_in_I.item()}')
        print(f'w_rec_netE:     {self.w_rec_netE.item()}')
        print(f'w_rec_I:        {self.w_rec_I.item()}')

    def get_weights(self):
        """Return all weights, raw parameters and saturation bounds as a dict of floats."""
        weights = {
            'w_in_netE': self.w_in_netE.item(),
            'w_in_E': self.w_in_E.item(),
            'w_in_I': self.w_in_I.item(),
            'w_rec_netE': self.w_rec_netE.item(),
            'w_rec_netE_param': self.w_rec_netE_param.item(),
            'w_rec_E': self.w_rec_E.item(),
            'w_rec_I': self.w_rec_I.item(),
            'w_rec_I_param': self.w_rec_I_param.item(),
            'w_max_rec_netE': self.w_max_rec_netE,
            'w_max_rec_I': self.w_max_rec_I,
        }
        return weights

    def set_weights(self, weights):
        """Restore weights from a ``get_weights()`` dict (raw params are used), then sanitize."""
        self.w_in_netE.data = torch.tensor(weights['w_in_netE'])
        self.w_in_I.data = torch.tensor(weights['w_in_I'])
        self.w_rec_netE_param.data = torch.tensor(weights['w_rec_netE_param'])
        self.w_rec_I_param.data = torch.tensor(weights['w_rec_I_param'])
        self.sanitize()

    def get_param_value(self, w_target, w_max):
        """Invert ``saturation``: return the parameter that maps to ``w_target``.

        ``w_target`` must lie strictly inside (-|w_max|, |w_max|); otherwise the
        log argument is non-positive and the assertion below fails.
        """
        # Convert w_target to torch.tensor if necessary
        if not isinstance(w_target, torch.Tensor):
            w_target = torch.tensor(w_target)
        w_param = - (w_max/2) * torch.log( (w_max-w_target) / (w_max+w_target) )
        assert torch.isclose(self.saturation(w_param, w_max), w_target), f'For w_param = {w_param}, self.saturation(w_param, w_max) (={self.saturation(w_param, w_max)}) != w_target (={w_target})'
        return w_param

    def set_optimal_net_weights(self, width):
        """Set w_rec_netE = 1/(gamma + 1/gamma), gamma = exp(-1/width), and w_in_netE = 1 - 2*gamma*w_rec_netE."""
        gamma = np.exp(-1/width)
        w_2 = torch.tensor(1 / (gamma + 1/gamma))
        # The w_rec_netE setter inverts the saturation (sigmoid) parametrization
        self.w_rec_netE = w_2
        self.w_in_netE.data = 1.*(1.-w_2*2*gamma)

    def stabilize(self, x_state_concat):
        """Project out flagged ("problematic") eigenvectors from a stacked state.

        The vectors come from ``self.problematic_eigvecs`` (one unit vector per
        row). The routine that used to set them is disabled, so if the
        attribute was never assigned this is a no-op instead of raising
        AttributeError. ``x_state_concat`` is modified in place and returned.
        """
        problematic_eigvecs = getattr(self, 'problematic_eigvecs', None)
        if problematic_eigvecs is None:
            return x_state_concat
        # Stabilize by subtracting the projection onto each flagged vector
        for problematic_eigvec in problematic_eigvecs:
            x_state_concat -= (x_state_concat@problematic_eigvec)*problematic_eigvec
        return x_state_concat

    def sanitize(self):
        """Clip w_in_I and w_rec_I_param to <= 0 (so w_rec_I <= 0 given the negative w_max_rec_I)."""
        with torch.no_grad():
            self.w_in_I.data = torch.clip(self.w_in_I.data, None, 0)
            self.w_rec_I_param.data = torch.clip(self.w_rec_I_param.data, None, 0)

    def get_state_derivative(self, t_, x_state, x_state_delayed, x_in):
        """Return dx/dt for the lagged-inhibition rate dynamics.

        tau*dx/dt = -x + w_in_netE*x_in + W_netE@x - (tau/tau_lag)*W_I@(x - x_delayed)
        """
        # t_ is not used, only given for better readability
        x_state_ = x_state.clone()
        dx_state = x_state_-x_state_delayed
        x_state_deriv_times_tau \
            = - x_state_ \
              + self.w_in_netE*x_in \
              + self.w_rec_netE_mat@x_state_  \
              - (self.tau/self.tau_lag)*self.w_rec_I_mat@(dx_state)

        x_state_deriv = x_state_deriv_times_tau/self.tau
        return x_state_deriv

    def midpoint_method_step(self, x_state, x_state_delayed, x_in, dxdt):
        """Advance x_state by one midpoint-method step of size dt.

        The delayed state is advanced to the half step with the oldest entry
        of ``self.dxdt_nows_hist`` (recorded n_lag+1 calls earlier; zero during
        warm-up). The current derivative is then pushed onto that history.
        """
        t0 = 0 # Does not matter, is just given for readability
        dt = self.dt
        dxdt_now = dxdt(t0, x_state, x_state_delayed, x_in)
        dxdt_mid = dxdt(t0+dt/2, x_state+(dt/2)*dxdt_now, x_state_delayed+(dt/2)*self.dxdt_nows_hist[0], x_in)
        x_state_new = x_state + dt*dxdt_mid
        # Shift history (oldest first) and append the newest derivative
        self.dxdt_nows_hist = torch.roll(self.dxdt_nows_hist, -1, dims=0)
        self.dxdt_nows_hist[-1] = dxdt_now

        return x_state_new

    def forward(self, x_in, x_in_delayed, x_state, x_state_delayed, force_stability=False):
        """Return the state after one integration step.

        ``x_in_delayed`` and ``force_stability`` are accepted for interface
        compatibility but are not used.
        """
        x_state_new = self.midpoint_method_step(x_state, x_state_delayed, x_in, self.get_state_derivative)

        return x_state_new

### 1D with SFA

In [None]:
def simulate_1D_SFA(d_RF, a_SFA, tau_SFA, N_1D=100, tau=1, dt=0.01, n_steps=1_000, x_init=0., t_pulse_on=None, t_pulse_off=None):
    """Simulate a 1D ring network with spike-frequency adaptation (SFA).

    A unit input pulse is applied to the centre neuron (index N_1D//2) from
    step ``t_pulse_on`` until step ``t_pulse_off``. While it is on, the target
    is the exponential profile from ``get_target``; afterwards the target is 0.
    Rates and adaptation variables are integrated with forward Euler.

    Args:
        d_RF: RF width (decay length of the target profile, in neurons).
        a_SFA: Adaptation strength (0 disables SFA).
        tau_SFA: Adaptation time constant.
        N_1D: Number of neurons on the ring (periodic boundaries).
        tau: Rate time constant.
        dt: Integration step.
        n_steps: Number of integration steps.
        x_init: Initial value of every rate unit.
        t_pulse_on, t_pulse_off: Pulse on/off steps; default n_steps//3 and
            2*n_steps//3.

    Returns:
        x_state_data, u_SFA_data, r_in_data: (n_steps, N_1D) histories.
        loss_data: (n_steps,) L1 deviation from the current target, divided by
            the L1 norm of the target profile.
        t_resp: Time (in units of tau) from pulse offset until the normalized
            loss first drops below e^-1; nan if that never happens.
        violation_data: (n_steps,) (sum|d| - |sum d|)/sum|d| for the deviation
            d; zero when all deviations share a sign.
    """
    if t_pulse_on is None:
        t_pulse_on = n_steps//3
    if t_pulse_off is None:
        t_pulse_off = 2*n_steps//3
    w_rec, w_ff = get_weights(d_RF, a_SFA)
    # Nearest-neighbour coupling on a ring
    w_rec_mat = np.roll(np.eye(N_1D), 1, axis=0) * w_rec + np.roll(np.eye(N_1D), -1, axis=0) * w_rec

    x_state_data = np.zeros((n_steps, N_1D))
    u_SFA_data = np.zeros((n_steps, N_1D))
    r_in_data = np.zeros((n_steps, N_1D))
    loss_data = np.zeros(n_steps)
    violation_data = np.zeros(n_steps)

    x_state = np.ones(N_1D) * x_init
    u_SFA = np.zeros(N_1D)
    r_in = np.zeros(N_1D)
    x_tar = np.zeros(N_1D)

    x_tar_0 = get_target(d_RF, N_1D)

    for i_t in range(n_steps):

        # Set input pulse
        if i_t==t_pulse_on:
            r_in[N_1D//2] = 1
            x_tar = np.copy(x_tar_0)

        if i_t==t_pulse_off:
            r_in[N_1D//2] = 0
            x_tar *= 0

        # Dynamics (forward Euler; the SFA update uses the pre-update rates)
        x_state_old = np.copy(x_state)
        x_state += (dt/tau) * ( -x_state
                               + w_rec_mat@x_state
                               + w_ff*r_in
                               - a_SFA*u_SFA)
        u_SFA += (dt/tau_SFA) * (-u_SFA + x_state_old)

        deviation = x_state - x_tar
        dev_abs_sum = np.sum(np.abs(deviation))
        x_state_data[i_t] = x_state
        u_SFA_data[i_t] = u_SFA
        r_in_data[i_t] = r_in
        loss_data[i_t] = dev_abs_sum
        violation_data[i_t] = (dev_abs_sum-np.abs(np.sum(deviation)))/(dev_abs_sum if dev_abs_sum>0 else 1) # Is zero if deviation all have same sign

    x_tar_sum = np.sum(np.abs(x_tar_0))
    loss_data /= x_tar_sum

    # np.argmax on an all-False mask returns 0, which would masquerade as an
    # instantaneous response; report nan when the threshold is never reached.
    below_thr = loss_data[t_pulse_off:] < np.exp(-1)
    if np.any(below_thr):
        t_resp = np.argmax(below_thr) * (dt/tau)  # In units of tau
    else:
        t_resp = np.nan

    return x_state_data, u_SFA_data, r_in_data, loss_data, t_resp, violation_data

def get_weights(d_RF, a_SFA):
    """Return ``(w_rec, w_ff)`` for the 1D SFA network.

    With g = exp(-1/d_RF), the SFA-free weights are g/(1+g^2) (recurrent) and
    (1-g^2)/(1+g^2) (feed-forward). Both are multiplied by (1 + a_SFA). At
    steady state the adaptation variable equals the rate, so the effective
    leak becomes (1 + a_SFA)*x.
    """
    g = np.exp(-1/d_RF)
    g_sq = g**2
    denom = 1 + g_sq
    sfa_gain = 1 + a_SFA
    w_rec = (g/denom) * sfa_gain
    w_ff = ((1 - g_sq)/denom) * sfa_gain
    return w_rec, w_ff

def get_target(d_RF, N_1D):
    """Return the exponential target profile gamma**|i - N_1D//2|, gamma = exp(-1/d_RF).

    The peak sits at index N_1D//2, where ``simulate_1D_SFA`` injects its
    input. The previous formulation used ``-N_1D//2``, which parses as
    ``(-N_1D)//2``. That made the peak two units wide for odd ``N_1D``; the
    result is unchanged for even ``N_1D``.
    """
    gamma = np.exp(-1/d_RF)
    distance = np.abs(np.arange(N_1D) - N_1D//2)
    x_tar = gamma**distance
    return x_tar

## 2D networks

In [None]:
class IELag2DConvNetwork(torch.nn.Module):
    """2D rate network on an N x N grid with lagged (I-E) inhibitory feedback.

    Nearest-neighbour recurrence is implemented with 3x3 convolutions applied
    to a periodically padded copy of the state (see ``pad_state``). The
    inhibitory term acts on the state change ``x_state - x_state_delayed``,
    scaled by ``tau/tau_lag``.

    Args:
        N: Side length of the grid (N*N units).
        tau: Time constant of the rate dynamics.
        dt: Integration time step.
        n_lag: I-E lag in time steps (``tau_lag = n_lag*dt``).
        width: Width of the target exponential receptive field.
        selectivity: 'NonlinearMixed' or 'LinearMixed'; see ``transform_input``.
    """

    # Saturation bounds of the w_rec_netE / w_rec_I parametrizations. The
    # forward mapping (``saturation``) and its gradient
    # (``compute_w_param_grads``) both read these, so they describe the same
    # function.
    w_max_rec_netE = 0.25
    w_max_rec_I = -0.25

    def __init__(self, N, tau, dt,
        n_lag, width, selectivity='NonlinearMixed'):
        super().__init__()
        self.N = N
        self.tau = tau
        self.dt = dt
        self.n_lag = n_lag  # IE lag in time steps
        self.width = width  # Width of target exponential RF
        self.selectivity = selectivity
        self.dim = 2

        self.state = torch.zeros(N, N)
        self.d_state = N*N

        # Conv kernel for recurrent connections: the four nearest neighbours
        self.A_rec = torch.zeros(3,3)
        self.A_rec[0, 1] = 1 # Neighbor at row offset -1
        self.A_rec[1, 0] = 1 # Neighbor at column offset -1
        self.A_rec[1, 2] = 1 # Neighbor at column offset +1
        self.A_rec[2, 1] = 1 # Neighbor at row offset +1

        self.set_width(width)

        self.w_in_netE = torch.nn.Parameter(torch.tensor(.001), requires_grad=True)
        self.w_in_I = torch.nn.Parameter(torch.tensor(-.001), requires_grad=True)
        self.w_rec_netE_param = torch.nn.Parameter(torch.tensor(.001), requires_grad=True)
        self.w_rec_I_param = torch.nn.Parameter(torch.tensor(-.001), requires_grad=True)

        # If set, overrides the parametrized w_rec_netE (see property below)
        self.w_rec_netE_overwrite = None

        # Convolutional layers; padding is done manually in pad_state
        self.conv_rec_netE = torch.nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, padding_mode='zeros', bias=False)
        self.conv_rec_I = torch.nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, padding_mode='zeros', bias=False)

        self.set_conv_weights()
        self.reset()

    @property
    def tau_lag(self):
        """I-E lag in time units (n_lag steps of size dt)."""
        return self.n_lag*self.dt

    def reset(self):
        """Zero the state and the history of past derivatives."""
        self.state = torch.zeros(self.N, self.N)
        self.dxdt_nows_hist = torch.zeros(self.n_lag+1, self.N, self.N)

    def pad_state(self, x_state):
        """Pad a 2D state by one cell on each side with wrap-around values.

        Edges are filled periodically; the four corners stay zero, which is
        harmless because the kernel ``A_rec`` has zero corners.
        """
        x_state_padded = torch.zeros(x_state.shape[0]+2, x_state.shape[1]+2)
        x_state_padded[1:-1, 1:-1] = x_state
        # Add periodic boundary conditions
        x_state_padded[0,1:-1] = x_state[-1,:]
        x_state_padded[-1,1:-1] = x_state[0,:]
        x_state_padded[1:-1,0] = x_state[:,-1]
        x_state_padded[1:-1,-1] = x_state[:,0]
        return x_state_padded

    def compute_w_param_grads(self):
        """Chain the conv-kernel gradients back to the scalar weight parameters.

        Requires ``conv_rec_netE.weight.grad`` and ``conv_rec_I.weight.grad`` to
        be populated. The saturation derivative uses the same bounds as the
        forward properties (``w_max_rec_netE`` / ``w_max_rec_I``).
        NOTE(review): d(kernel)/d(w) is A_rec, so the extra factor 0.25 only
        rescales the gradient; confirm that this is intended.
        """
        with torch.no_grad():
            grad_netE = torch.sum(self.conv_rec_netE.weight.grad * self.A_rec)
            grad_I = torch.sum(self.conv_rec_I.weight.grad * self.A_rec)
            self.w_rec_netE_param.grad = 0.25*grad_netE * self.saturation_deriv(self.w_rec_netE_param, w_max=self.w_max_rec_netE)
            self.w_rec_I_param.grad = 0.25*grad_I * self.saturation_deriv(self.w_rec_I_param, w_max=self.w_max_rec_I)

    def set_conv_weights(self):
        """Write the current scalar weights into the 3x3 convolution kernels."""
        self.conv_rec_netE.weight.data = self.w_rec_netE * self.A_rec[np.newaxis,np.newaxis,:,:]
        self.conv_rec_I.weight.data = self.w_rec_I * self.A_rec[np.newaxis,np.newaxis,:,:]

    def flatten(self, x):
        """Flatten a grid state (2D -> 1D) or a grid-to-grid operator (4D -> 2D).

        Raises:
            ValueError: if x is neither 2D nor 4D.
        """
        if len(x.shape) == 2:
            return x.flatten()
        if len(x.shape) == 4:
            # Flatten first and last two dimensions
            return x.view(x.shape[0]*x.shape[1], x.shape[2]*x.shape[3])
        raise ValueError(f'flatten expects a 2D or 4D tensor, got shape {tuple(x.shape)}')

    def unflatten(self, x):
        """Inverse of ``flatten``: 1D -> (N, N) or 2D -> (N, N, N, N).

        Raises:
            ValueError: if x is neither 1D nor 2D.
        """
        if len(x.shape) == 1:
            return x.view(self.N, self.N)
        if len(x.shape) == 2:
            # Unflatten first and last dimension
            return x.view(self.N, self.N, self.N, self.N)
        raise ValueError(f'unflatten expects a 1D or 2D tensor, got shape {tuple(x.shape)}')

    def set_width(self, width):
        """Set the target RF width, its decay factors and the border padding vectors.

        ``pad_vector_x`` / ``pad_vector_y`` are computed here but not read
        elsewhere in this class.
        """
        self.width = width
        self.gamma = np.exp(-1/width)
        self.gamma_squared = np.exp(-2/width)
        assert self.width>0 and self.gamma<1, f'width = {self.width}, gamma = {self.gamma}'

        def get_padding_factor(i, j, dim=0):
            # i, j range from 0 to self.N-1 + 2
            # because state is already padded
            dx = np.abs(i-(self.N//2+1))
            dy = np.abs(j-(self.N//2+1))
            x_amp = self.gamma**(dx-1) # Amplitude contribution at border
            y_amp = self.gamma**(dy-1)
            if dim==0: # Extend along x
                # x direction decays, but y direction does not
                factor = (x_amp*self.gamma+y_amp)/(x_amp+y_amp)
            elif dim==1: # Extend along y
                # y direction decays, but x direction does not
                factor = (x_amp+y_amp*self.gamma)/(x_amp+y_amp)
            return factor

        self.pad_vector_x = get_padding_factor(0, torch.arange(self.N)+1, dim=0)
        self.pad_vector_y = get_padding_factor(torch.arange(self.N)+1, 0, dim=1)
        assert torch.all(self.pad_vector_x<=1) and torch.all(self.pad_vector_y<=1), f'max(pad_vector_x) = {torch.max(self.pad_vector_x)}, max(pad_vector_y) = {torch.max(self.pad_vector_y)}'

    def saturation(self, x, w_max=0.25):
        """Smooth bounded map of x into (-|w_max|, |w_max|) with slope 1 at x = 0."""
        return 2*w_max*(torch.sigmoid(2*x/w_max)-0.5)

    def get_param_value(self, w_target, w_max):
        """Invert ``saturation``: return the parameter that maps to ``w_target``.

        ``w_target`` must lie strictly inside (-|w_max|, |w_max|); otherwise the
        log argument is non-positive and the assertion below fails.
        """
        # Convert w_target to torch.tensor if necessary
        if not isinstance(w_target, torch.Tensor):
            w_target = torch.tensor(w_target)
        w_param = - (w_max/2) * torch.log( (w_max-w_target) / (w_max+w_target) )
        assert torch.isclose(self.saturation(w_param, w_max), w_target), f'For w_param = {w_param}, self.saturation(w_param, w_max) (={self.saturation(w_param, w_max)}) != w_target (={w_target})'
        return w_param

    def saturation_deriv(self, x, w_max=0.25):
        """Derivative of ``saturation`` w.r.t. x (the w_max prefactors cancel)."""
        return 4*torch.sigmoid(2*x/w_max)*(1-torch.sigmoid(2*x/w_max))

    @property
    def w_rec_netE(self):
        # Parametrize such that |w_rec_netE| < w_max_rec_netE (= 0.25) and its
        # gradient diminishes as w_rec_netE approaches that bound.
        if self.w_rec_netE_overwrite is not None:
            return self.w_rec_netE_overwrite
        else:
            return self.saturation(self.w_rec_netE_param, w_max=self.w_max_rec_netE)
    @w_rec_netE.setter
    def w_rec_netE(self, w_rec_netE):
        self.w_rec_netE_param.data = self.get_param_value(w_rec_netE, w_max=self.w_max_rec_netE)
        self.set_conv_weights()

    @property
    def w_rec_I(self):
        # Parametrize such that |w_rec_I| < |w_max_rec_I| (= 0.25) and its
        # gradient diminishes as w_rec_I approaches that bound. Unlike the 1D
        # network, the bound carries no (1-dt/tau) factor.
        return self.saturation(self.w_rec_I_param, w_max=self.w_max_rec_I)
    @w_rec_I.setter
    def w_rec_I(self, w_rec_I):
        self.w_rec_I_param.data = self.get_param_value(w_rec_I, w_max=self.w_max_rec_I)
        self.set_conv_weights()

    @property
    def w_in_E(self):
        """Excitatory input weight implied by net and inhibitory weights."""
        return self.w_in_netE - self.w_in_I

    @property
    def w_rec_E(self):
        """Excitatory recurrent weight implied by net and inhibitory weights."""
        return self.w_rec_netE - self.w_rec_I

    @property
    def w_rec_netE_mat(self):
        """3x3 net-excitatory recurrent kernel."""
        return self.w_rec_netE*self.A_rec

    @property
    def w_rec_I_mat(self):
        """3x3 inhibitory recurrent kernel."""
        return self.w_rec_I * self.A_rec

    def print_weights(self):
        """Print the four scalar weights."""
        print(f'w_in_netE:      {self.w_in_netE.item()}')
        print(f'w_in_I:         {self.w_in_I.item()}')
        print(f'w_rec_netE:     {self.w_rec_netE.item()}')
        print(f'w_rec_I:        {self.w_rec_I.item()}')

    def get_weights(self):
        """Return all weights and raw parameters as a dict of floats."""
        weights = {
            'w_in_netE': self.w_in_netE.item(),
            'w_in_E': self.w_in_E.item(),
            'w_in_I': self.w_in_I.item(),
            'w_rec_netE': self.w_rec_netE.item(),
            'w_rec_netE_param': self.w_rec_netE_param.item(),
            'w_rec_E': self.w_rec_E.item(),
            'w_rec_I': self.w_rec_I.item(),
            'w_rec_I_param': self.w_rec_I_param.item(),
        }
        return weights

    def set_weights(self, weights):
        """Restore weights from a ``get_weights()`` dict and refresh the conv kernels.

        The raw parameters are restored directly (rather than through the
        saturating setters) for higher precision.
        """
        self.w_in_netE.data = torch.tensor(weights['w_in_netE'])
        self.w_in_I.data = torch.tensor(weights['w_in_I'])
        self.w_rec_netE_param.data = torch.tensor(weights['w_rec_netE_param'])
        self.w_rec_I_param.data = torch.tensor(weights['w_rec_I_param'])
        # Update convolutional weights
        self.set_conv_weights()

    def set_optimal_net_weights(self, width):
        """Set w_rec_netE = 0.5/(gamma + 1/gamma), gamma = exp(-1/width), and w_in_netE = 1 - 4*gamma*w_rec_netE."""
        gamma = np.exp(-1/width)
        w_2 = torch.tensor(0.5 / (gamma + 1/gamma))
        # The w_rec_netE setter inverts the saturation (sigmoid) parametrization
        self.w_rec_netE = w_2
        self.w_in_netE.data = 1.*(1.-w_2*4*gamma)
        self.set_conv_weights()

    def sanitize(self):
        """Clip w_in_I and w_rec_I_param to <= 0 (so w_rec_I <= 0) and refresh kernels."""
        with torch.no_grad():
            self.w_in_I.data = torch.clip(self.w_in_I.data, None, 0)
            self.w_rec_I_param.data = torch.clip(self.w_rec_I_param.data, None, 0)
            self.set_conv_weights()

    def transform_input(self, x_in, x_in_delayed):
        """Map the raw input onto the grid according to ``self.selectivity``.

        'LinearMixed': every unit (i, j) receives the sum of row i plus the sum
        of column j of the input. 'NonlinearMixed': the input is used as is.

        Raises:
            ValueError: for any other selectivity.
        """
        if self.selectivity=='LinearMixed':
            x_in_ = torch.sum(x_in, dim=1)[:,np.newaxis] + torch.sum(x_in, dim=0)[np.newaxis,:]
            x_in_delayed_ = torch.sum(x_in_delayed, dim=1)[:,np.newaxis] + torch.sum(x_in_delayed, dim=0)[np.newaxis,:]
        elif self.selectivity=='NonlinearMixed':
            x_in_ = x_in
            x_in_delayed_ = x_in_delayed
        else:
            raise ValueError(f"Unknown selectivity {self.selectivity!r}; expected 'LinearMixed' or 'NonlinearMixed'")
        return x_in_, x_in_delayed_

    def get_state_derivative(self, t_, x_state, x_state_delayed, x_in):
        """Return dx/dt for the lagged-inhibition rate dynamics on the grid.

        tau*dx/dt = -x + w_in_netE*in(x_in) + K_netE*x - (tau/tau_lag)*K_I*(x - x_delayed),
        where * is the periodic 3x3 convolution.
        """
        # t_ is not used, only given for better readability
        x_state_ = x_state.clone()
        x_state_padded = self.pad_state(x_state)
        dx_state = x_state_-x_state_delayed
        dx_state_padded = self.pad_state(dx_state)
        x_in_, _ = self.transform_input(x_in, 0*x_in)

        # Conv2d receives unbatched (C=1, H, W) inputs
        x_state_deriv_times_tau \
            = - x_state_ \
              + self.w_in_netE*x_in_ \
              + self.conv_rec_netE(x_state_padded[np.newaxis, :, :])[0,:,:] \
              - (self.tau/self.tau_lag)*self.conv_rec_I(dx_state_padded[np.newaxis, :, :])[0,:,:]

        x_state_deriv = x_state_deriv_times_tau/self.tau
        return x_state_deriv

    def midpoint_method_step(self, x_state, x_state_delayed, x_in, dxdt):
        """Advance x_state by one midpoint-method step of size dt.

        The delayed state is advanced to the half step with the oldest entry
        of ``self.dxdt_nows_hist`` (recorded n_lag+1 calls earlier; zero during
        warm-up). The current derivative is then pushed onto that history.
        """
        t0 = 0 # Does not matter, is just given for readability
        dt = self.dt
        dxdt_now = dxdt(t0, x_state, x_state_delayed, x_in)
        dxdt_mid = dxdt(t0+dt/2, x_state+(dt/2)*dxdt_now, x_state_delayed+(dt/2)*self.dxdt_nows_hist[0], x_in)
        x_state_new = x_state + dt*dxdt_mid
        # Shift history (oldest first) and append the newest derivative
        self.dxdt_nows_hist = torch.roll(self.dxdt_nows_hist, -1, dims=0)
        self.dxdt_nows_hist[-1] = dxdt_now

        return x_state_new

    def forward(self, x_in, x_in_delayed, x_state, x_state_delayed, force_stability=False):
        """Return the state after one integration step.

        ``x_in_delayed`` and ``force_stability`` are accepted for interface
        compatibility but are not used.
        """
        x_state_new = self.midpoint_method_step(x_state, x_state_delayed, x_in, self.get_state_derivative)

        return x_state_new

## Functions for simulating and evaluating

In [None]:
def evolve_net(net, n_steps, x_in, x_target, x_state_init=None, x_in_prev=None, only_loss=False, mask=None):
    """Integrate ``net`` for ``n_steps`` steps under constant input and track the loss.

    Args:
        net: Network module (e.g. IELagRateNetwork or IELag2DConvNetwork); it is
            reset before integration.
        n_steps: Number of integration steps.
        x_in: Input, held constant over the whole run.
        x_target: Target state for the per-step ``error`` and ``violations_x``.
        x_state_init: Initial state (default: zeros shaped like ``net.state``).
        x_in_prev: Input before t=0 (default: zeros shaped like ``net.state``).
        only_loss: If True, the histories are kept as a sliding window (the
            oldest entry is dropped each step). ``(step, state)`` from about 2/3
            of the run is returned in place of ``prev_states``.
        mask: Unused.

    Returns:
        loss: Sum of per-step errors (a tensor that carries gradients if
            autograd is enabled).
        loss_curve, violations_x, violations_dx, summed_activity_curve,
        summed_activity_diff_curve: (n_steps,) numpy arrays.
        states, prev_states, x_in_history: stacked numpy histories
            (``prev_states`` replaced as described for ``only_loss``).
    """
    net.reset()
    state_list_len = net.n_lag
    if x_state_init is None:
        x_state_init = torch.zeros_like(net.state)
    # Pre-fill the histories so lagged indices such as prev_states[-net.n_lag]
    # are valid from the first step on.
    states = [x_state_init,]*(state_list_len)
    prev_states = states.copy() # No state jumps at t=0
    if x_in_prev is None:
        x_in_prev = torch.zeros_like(net.state)
    x_in_history = [x_in_prev,]*(state_list_len)

    # Defined up front so the only_loss return value exists even if the 2/3
    # snapshot step is never reached (e.g. n_steps == 0).
    dt_twoThirds = None
    x_state_twoThirds = None

    loss = 0
    loss_curve = np.zeros(n_steps)
    violations_x = np.zeros(n_steps)
    violations_dx = np.zeros(n_steps)
    summed_activity_curve = np.zeros(n_steps)
    summed_activity_diff_curve = np.zeros(n_steps)
    for t in tqdm(range(n_steps)):
        x_in_history.append(x_in)
        x_state_new = net(x_in_history[-1], x_in_history[-1-net.n_lag], states[-1], prev_states[-net.n_lag])
        loss_bit = error(x_state_new, x_target)
        loss += loss_bit

        prev_states.append(states[-1])
        states.append(x_state_new)
        loss_curve[t] = loss_bit.item()
        summed_activity_curve[t] = torch.sum(x_state_new).detach().numpy()
        summed_activity_diff_curve[t] = (torch.sum(x_state_new)-torch.sum(prev_states[-net.n_lag])).detach().numpy()

        # Overshoot above the target, and decreases relative to the previous state
        violations_x[t] = torch.sum(torch.clamp(x_state_new-x_target, 0, None)).detach().numpy()
        violations_dx[t] = torch.sum(torch.clamp(x_state_new-prev_states[-1], None, 0)).detach().numpy()

        if only_loss:
            # Pop oldest elements of prev_states, states, x_in_history
            prev_states.pop(0)
            states.pop(0)
            x_in_history.pop(0)
            if t==np.clip(2*n_steps//3, 0, n_steps-1):
                dt_twoThirds = t
                x_state_twoThirds = x_state_new.detach().clone().numpy()

    states = torch.stack(states).detach().numpy()
    prev_states = torch.stack(prev_states).detach().numpy()
    x_in_history = torch.stack(x_in_history).detach().numpy()
    if only_loss:
        prev_states = (dt_twoThirds, x_state_twoThirds)

    return loss, loss_curve, states, prev_states, x_in_history, violations_x, violations_dx, summed_activity_curve, summed_activity_diff_curve

def error(x, x_target):
    """Return the mean absolute deviation between ``x`` and ``x_target``."""
    deviation = x - x_target
    return torch.abs(deviation).mean()

In [None]:
def get_loss_response_times(net, loss_curve, p_thr_loss, n_powers=1, plot=False):
    """Return the response time(s) of a loss curve, in units of ``net.tau``.

    For each k = 1..n_powers, the response time is the earliest step from which
    the loss stays strictly below ``p_thr_loss**k * loss_curve[0]`` for the rest
    of the curve. It is converted to time via ``net.dt/net.tau``. If the loss
    never stays below a threshold, the corresponding entry is nan.

    Args:
        net: Object with ``dt`` and ``tau`` attributes.
        loss_curve: 1D sequence of per-step losses.
        p_thr_loss: Threshold factor (e.g. e^-1).
        n_powers: Number of successive powers of ``p_thr_loss`` to evaluate.
        plot: If True, plot the curve with thresholds and crossing times.

    Returns:
        A float if ``n_powers == 1``, otherwise an array of length ``n_powers``.
    """
    c_t = net.dt/net.tau
    loss_curve = np.asarray(loss_curve)

    t_loss_reduction_abs_values = np.zeros(n_powers)
    for i_pow, n_power in enumerate(range(1, n_powers+1)):
        loss_threshold_abs = (p_thr_loss**n_power)*loss_curve[0]
        is_below_thr_abs = loss_curve < loss_threshold_abs
        # all_later_below[i] is True iff every entry from i onward is below the
        # threshold: a cumulative AND over the reversed mask (O(n)).
        all_later_below_thr_abs = np.logical_and.accumulate(is_below_thr_abs[::-1])[::-1]
        if np.any(all_later_below_thr_abs):
            t_loss_reduction_abs_values[i_pow] = c_t*np.argmax(all_later_below_thr_abs)
        else:
            # argmax of an all-False mask is 0, which would wrongly report an
            # instantaneous response.
            t_loss_reduction_abs_values[i_pow] = np.nan

    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        ax.plot(loss_curve)
        for i_pow, n_power in enumerate(range(1, n_powers+1)):
            ax.axhline(y=(p_thr_loss**n_power)*loss_curve[0], c='gray', ls='dashed')
            if np.isfinite(t_loss_reduction_abs_values[i_pow]):
                ax.axvline(x=t_loss_reduction_abs_values[i_pow]/c_t, c='gray', ls='dashed')
        ax.set_yscale('log')
        plt.show()

    if n_powers==1:
        t_loss_reduction_abs_values = t_loss_reduction_abs_values[0]

    return t_loss_reduction_abs_values

In [None]:
def get_tresp_and_wrisc(w_rec_netE_sum, tau, tau_lag):
    """Solve for the response time of the lagged-inhibition network.

    With t_net = tau/(1 - w_rec_netE_sum), the response time t solves
        1/t = 1/t_net + (1/tau_lag)*(1 - exp(-tau_lag/t)).

    Returns:
        (t_resp_full, w_rec_I_sum, t_dec_I), where
        w_rec_I_sum = -exp(-tau_lag/t_resp_full) and
        t_dec_I = tau/(1 + w_rec_I_sum).
    """
    t_resp_net = tau/(1-w_rec_netE_sum)

    def residual(sqrt_t):
        # Solve for sqrt(t) so that t = sqrt_t**2 cannot become negative.
        t = sqrt_t**2
        return (1/t) - (  (1/t_resp_net) + (1/tau_lag)*(1-np.exp(-tau_lag/t))  )

    solution = root(residual, 0.001*t_resp_net, args=(), method='hybr', tol=1e-15)

    t_resp_full = (solution.x**2)[0]
    w_rec_I_sum = - np.exp(-tau_lag/t_resp_full)
    return t_resp_full, w_rec_I_sum, tau/(1+w_rec_I_sum)

In [None]:
def get_1D_RF_gamma(w_rec_netE):
    """Return the RF decay factor gamma implied by ``w_rec_netE``.

    gamma is the smaller root of gamma + 1/gamma = 1/w_rec_netE, which is real
    for 0 < w_rec_netE <= 0.5.
    """
    half_inv = 1/(2*w_rec_netE)
    return half_inv - np.sqrt(half_inv**2 - 1)

def get_1D_target(w_rec_netE, N_1D):
    """Return the target RF gamma**|i - N_1D//2|, with gamma set by ``w_rec_netE``.

    Asserts that gamma reproduces w_rec_netE via 1/(gamma + 1/gamma).
    """
    gamma = get_1D_RF_gamma(w_rec_netE)
    implied_w = 1/(gamma + 1/gamma)
    assert np.isclose(implied_w, w_rec_netE), \
        f'For gamma = {gamma}, 1/(gamma+1/gamma) = {1/(gamma+1/gamma)} != w_rec_netE = {w_rec_netE}'
    offsets = torch.abs(torch.arange(N_1D) - N_1D//2)
    return gamma**offsets

# Running the simulations

In [None]:
def get_x_in(N):
    """Return a length-N input with a single unit pulse at index N//2."""
    centre = N//2
    pulse = torch.zeros(N)
    pulse[centre] = 1
    return pulse

def get_x_target(N, width):
    """Return a length-N profile peaking at 1 at index N//2 and decaying by exp(-1/width) per unit."""
    decay = np.exp(-1/width)
    centre = N//2
    target = torch.zeros(N)
    target[centre] = 1
    # Walk outward from the centre, multiplying by the decay factor each step
    for idx in range(centre + 1, N):
        target[idx] = target[idx - 1]*decay
    for idx in reversed(range(centre)):
        target[idx] = target[idx + 1]*decay
    return target

def get_x_in_2D(N):
    """Return an N x N input with a single unit pulse at (N//2, N//2)."""
    centre = N//2
    grid = torch.zeros(N, N)
    grid[centre, centre] = 1
    return grid

In [None]:
# Parameters
N_1D = 200
N_2D = 200
tau = 1
# Time step of the balanced network; the purely excitatory ("vanilla")
# network below uses the coarser dt_vanilla instead.
dt = 0.01
dt_vanilla = 0.1
n_lag = 10
tau_lag = n_lag*dt

# 1 -> IELagRateNetwork (ring), 2 -> IELag2DConvNetwork (grid)
dimensionality = 2
# Only used by the 2D network: 'LinearMixed' or 'NonlinearMixed'
selectivity = 'LinearMixed'
isBalanced = False # Set True for balanced, False for excitatory network


t_decay = 100 # Target response time of an excitatory network



# Set weights and input. The per-neighbour net weight is chosen so that the
# summed recurrent weight is 1 - tau/t_decay, i.e. the excitatory network's
# response time tau/(1 - wrns) equals t_decay.
# x_target is only defined here for 1D; the next cell overwrites it in both
# cases with the simulated final state.
if dimensionality==1:
    x_in = get_x_in(N_1D)
    w_rec_netE = 0.5*(1 - tau/t_decay)
    wrns = 2* w_rec_netE
    t_resp_full, w_rec_I_sum, t_dec_I = get_tresp_and_wrisc(wrns, tau, tau_lag)
    w_rec_I = 0.5 * w_rec_I_sum
    x_target = get_1D_target(w_rec_netE, N_1D)
elif dimensionality==2:
    x_in = get_x_in_2D(N_2D)
    w_rec_netE = 0.25*(1 - tau/t_decay)
    wrns = 4* w_rec_netE
    t_resp_full, w_rec_I_sum, t_dec_I = get_tresp_and_wrisc(wrns, tau, tau_lag)
    w_rec_I = 0.25 * w_rec_I_sum

# Initialize network
if dimensionality==1:
    net = IELagRateNetwork(N_1D, tau=tau, dt=dt, n_lag=n_lag, width=1)
elif dimensionality==2:
    net = IELag2DConvNetwork(N_2D, tau=tau, dt=dt, n_lag=n_lag, width=1, selectivity=selectivity)

if isBalanced:
    net.dt = dt
    net.w_in_netE.data = torch.tensor(1.) # Input weight for excitatory input;
                                          # rescaled later so that the network
                                          # response peaks at 1
    net.w_in_I.data = torch.tensor(0.)
    net.w_rec_netE = torch.tensor(w_rec_netE) # Recurrent net weight
    net.w_rec_I = torch.tensor(w_rec_I)   # Strength of balanced interactions
else:
    # Purely excitatory network: w_rec_I_param = 0 gives w_rec_I = 0
    net.dt = dt_vanilla
    net.w_in_netE.data = torch.tensor(1.) 
    net.w_in_I.data = torch.tensor(0.)
    net.w_rec_netE = torch.tensor(w_rec_netE) # Recurrent net weight
    net.w_rec_I_param.data = torch.tensor(0.)

# Round-trip through get/set_weights so derived quantities are refreshed
# (for the 2D network this rewrites the conv kernels, which the direct
# w_rec_I_param assignment above bypassed)
net.set_weights(net.get_weights())
net.sanitize()
net.print_weights()

In [None]:
# Integrate for ten analytic response times
if isBalanced:
    n_steps = 10*int(t_resp_full/net.dt)
else:
    n_steps = 10*int(t_decay/net.dt)

p_thr_loss = np.exp(-1)  # Threshold factor for loss reduction


with torch.no_grad():
    # Run once to obtain final state (used as target, make sure n_steps is
    # sufficiently large)
    loss, loss_curve, states, prev_states, x_in_history, violations_x, violations_dx, summed_activity_curve, summed_activity_diff_curve \
        = evolve_net(net, 2*n_steps, x_in, x_target=0*x_in, x_state_init=0*x_in, x_in_prev=None, only_loss=True)
    
    
    # Obtain target state (evolve_net returns numpy arrays, so x_target is a
    # numpy array from here on)
    x_target = states[-1].copy()

    # Normalize input weights and (max value of) target to 1. The state
    # derivative is linear in the input, so dividing w_in_netE by the peak
    # rescales the final state to a peak of 1.
    net.w_in_netE.data /= np.max(x_target)
    x_target /= np.max(x_target)

    # Run with obtained target to obtain the loss evolution
    loss, loss_curve, states, prev_states, x_in_history, violations_x, violations_dx, summed_activity_curve, summed_activity_diff_curve \
        = evolve_net(net, n_steps, x_in, x_target=x_target, x_state_init=0*x_in, x_in_prev=None, only_loss=True)
    
    # Calculate response time (in units of tau) for thresholds e^-1, e^-2, e^-3
    t_loss_reduction_abs_values = get_loss_response_times(net, loss_curve, p_thr_loss, n_powers=3, plot=True)

    
print(f'Loss for target: {loss}')

In [None]:
# Plot loss curve and final state
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
c_t = net.dt/net.tau # Conversion factor for time steps to tau units
time_grid = np.arange(len(loss_curve))*c_t
axes[0].set_title('Loss curve for target')
axes[0].plot(time_grid, loss_curve)
axes[0].set_xlabel(r't $(\tau)$')
axes[0].set_ylabel('Loss')
axes[0].set_yscale('log')
# Plot e^-1 of initial loss threshold and time of threshold crossing
axes[0].axhline(y=p_thr_loss*loss_curve[0], color='gray', linestyle='dashed', label=f'Threshold: {p_thr_loss:.2f} * initial loss')
# get_loss_response_times returns times in units of tau (c_t * step); only
# draw the crossing marker if it is a finite value.
t_cross = t_loss_reduction_abs_values[0]
if np.isfinite(t_cross):
    axes[0].axvline(x=t_cross, color='gray', linestyle='dashed')
# Plot analytic response time to check
if isBalanced:
    t_resp_analytic = t_resp_full
else:
    t_resp_analytic = t_decay
# Both analytic times are in units of tau, like the x-axis
axes[0].axvline(x=t_resp_analytic, color='red', linestyle='dotted', label=rf'Analytic response time: {t_resp_analytic:.2f} $\tau$')
axes[0].legend()

axes[1].set_title('Final state and target')
if dimensionality==1:
    axes[1].plot(states[-1], label='Final state')
    axes[1].plot(x_target, label='Target state', linestyle='dashed')
    axes[1].legend()
    axes[1].set_xlabel('Neuron index')
    axes[1].set_ylabel('State')
elif dimensionality==2:
    im = axes[1].imshow(states[-1], cmap='viridis', origin='lower', vmin=0, vmax=1)
    axes[1].set_title('Final state')
    fig.colorbar(im, ax=axes[1])
plt.tight_layout()
plt.show()

## 1D SFA network

In [None]:
N_1D = 200
tau = 1
dt = 0.01

# Values from supplementary Fig. Aa) - i.e., non-optimal, but illustrative
d_RF = 4.5            # RF width, determines n_RF = 2*d_RF + 1
a_SFA = 0.09           # Strength of SFA (0 means no SFA)
tau_SFA = 10*tau      # Time constant of SFA (here long for illustration)

# Step counts: int(tau/dt) steps per tau, so the run lasts 500 tau and the
# pulse is on from t = 100 tau to t = 300 tau.
n_steps = 500*int(tau/dt)    
t_pulse_on = 100*int(tau/dt)  # Time step at which the input pulse is turned on
t_pulse_off = 300*int(tau/dt) # Time step at which the input pulse is turned off

x_init = 0.

x_state_data, u_SFA_data, r_in_data, loss_data_exc, t_resp, violation_data \
    = simulate_1D_SFA(d_RF=d_RF, a_SFA=a_SFA, tau_SFA=tau_SFA, N_1D=N_1D, tau=tau, dt=dt, n_steps=n_steps, x_init=x_init, t_pulse_on=t_pulse_on, t_pulse_off=t_pulse_off)

In [None]:
# Plot the results
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
# Step index i corresponds to time i*dt; express it in units of tau so the
# x-axes match their "Time (tau)" labels.
time_grid_SFA = np.arange(x_state_data.shape[0])*(dt/tau)

axes[0].set_title('Dynamics')
axes[0].plot(time_grid_SFA, x_state_data[:,N_1D//2], label=r'$x_{j_0}$', c='black')
axes[0].plot(time_grid_SFA, u_SFA_data[:,N_1D//2], label=r'$u_{j_0}$', c='green')
axes[0].plot(time_grid_SFA, r_in_data[:,N_1D//2], label=r'$r_{j_0}$', c='blue')
axes[0].set_xlabel(r'Time $(\tau)$')
axes[0].legend()

# Activity around the centre neuron at step t_pulse_off (recorded after the
# first update without input)
axes[1].set_title('Final activity')
neuron_range = np.arange(N_1D//2-20, N_1D//2+21)
axes[1].plot(neuron_range, x_state_data[t_pulse_off, neuron_range], label='Final state')
axes[1].set_xlabel('Neuron index')
axes[1].set_ylabel('Response')

axes[2].set_title('Loss evolution')
axes[2].plot(time_grid_SFA, loss_data_exc, label='Loss')
axes[2].set_xlabel(r'Time $(\tau)$')
axes[2].set_ylabel('Loss')
plt.tight_layout()
plt.show()