In [3]:
# Torch related imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Data visualisation and numerical computing
import matplotlib.pyplot as plt
import numpy as np
import math

# Network analysis
import networkx as nx
from matplotlib.gridspec import GridSpec

# Additional useful imports
import pandas as pd  
import seaborn as sns  
import os  
import sys  
import logging  

# Seed every RNG source this notebook draws from so runs are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Notebook-wide display and logging defaults.
plt.rcParams['figure.figsize'] = (10, 6)
logging.basicConfig(level=logging.INFO)


In [4]:
class SSNBase(torch.nn.Module):
    """Rate network with a supralinear (power-law) input/output function.

    Dynamics (per unit): tau * dr/dt = -r + k * relu(W @ r + inp)^n.
    Units are ordered with the ``Ne`` excitatory units first, followed by the
    ``Ni`` inhibitory units.

    Args:
        n: exponent of the power-law nonlinearity.
        k: prefactor of the power-law nonlinearity.
        Ne: number of excitatory units.
        Ni: number of inhibitory units.
        tau_e: time constant of every excitatory unit.
        tau_i: time constant of every inhibitory unit.
        device: device on which all tensors are allocated.
        dtype: floating-point dtype of all tensors.
    """

    def __init__(self, n, k, Ne, Ni, tau_e, tau_i, device='cpu', dtype=torch.float64):
        super().__init__()
        self.n = n
        self.k = k
        self.Ne = Ne
        self.Ni = Ni
        self.N = self.Ne + self.Ni
        self.device = device
        self.dtype = dtype
        # Boolean mask over units: True for the first Ne (E), False for the last Ni (I).
        self.register_buffer('EI', torch.cat([torch.ones(Ne), torch.zeros(Ni)]).to(device).bool())
        # Per-unit time constants, aligned with the E-then-I ordering.
        self.register_buffer('tau_vec', torch.cat([tau_e * torch.ones(Ne), tau_i * torch.ones(Ni)]).to(device, dtype))
        # Zero-initialised connectivity; since the dynamics use W @ r,
        # W[i, j] is the weight onto unit i from unit j.
        self.W = torch.nn.Parameter(torch.zeros((self.N, self.N), device=device, dtype=dtype))

    # NOTE: no @torch.jit.script_method here. That decorator is only valid inside a
    # torch.jit.ScriptModule; on a plain nn.Module it replaces the method with a
    # non-callable stub, so every call to drdt/simulate/fixed_point would fail.
    def drdt(self, r, inp_vec):
        """Return dr/dt = (-r + powlaw(W @ r + inp_vec)) / tau_vec for rate vector r."""
        return (-r + self.powlaw(self.W @ r + inp_vec)) / self.tau_vec

    def powlaw(self, u):
        """Rectified power-law nonlinearity k * relu(u)^n."""
        return self.k * F.relu(u).pow(self.n)

    def simulate(self, inp_vec, r_init=None, t_final=100, dt=0.1):
        """Integrate the dynamics with forward Euler and return the final rates.

        Args:
            inp_vec: external input added to W @ r at every step.
            r_init: initial rates; zeros when None. Never modified in place.
            t_final: total integration time.
            dt: Euler step size.

        Returns:
            Rate vector after ceil(t_final / dt) Euler steps.
        """
        if r_init is None:
            r_init = torch.zeros((self.N,), device=self.device, dtype=self.dtype)

        # Fix the step count up front. Accumulating ``t += dt`` drifts in floating
        # point (ten additions of 0.1 give 0.9999999999999999 < 1.0), which made a
        # ``while t < t_final`` loop take one extra step when t_final/dt is integral.
        # Rounding absorbs that drift; ceil still covers t_final for non-integral ratios.
        n_steps = max(0, math.ceil(round(t_final / dt, 9)))

        r = r_init
        for _ in range(n_steps):
            # Out-of-place update: never mutates the caller's r_init, and stays
            # valid when r_init is a leaf tensor that requires grad.
            r = r + dt * self.drdt(r, inp_vec)
        return r

    def jacobian(self, r):
        """Return -I + diag(gains(r)) @ W evaluated at rates r.

        Note: unlike drdt, this matrix is not divided by tau_vec.
        """
        Phi = self.gains_from_r(r)
        # Phi[:, None] * W scales row i of W by the gain of unit i (== diag(Phi) @ W).
        return -torch.eye(self.N, device=self.device, dtype=self.dtype) + Phi[:, None] * self.W

    def gains_from_r(self, r):
        """Slope of powlaw at the input that produces rate r.

        For r = k * u^n, dr/du = n * k * u^(n-1) = n * k^(1/n) * r^(1 - 1/n).
        """
        return self.n * self.k**(1/self.n) * r.pow(1 - 1/self.n)

    def fixed_point(self, inp_vec, tol=1e-6, max_iter=1000):
        """Relax from zero rates towards a steady state.

        Iterates r <- r + drdt(r, inp_vec), i.e. forward Euler with a unit step,
        until the update norm drops below ``tol``.

        Raises:
            RuntimeError: if no convergence within ``max_iter`` iterations.
        """
        r = torch.zeros((self.N,), device=self.device, dtype=self.dtype)
        for _ in range(max_iter):
            dr = self.drdt(r, inp_vec)
            r_new = r + dr
            if torch.norm(r_new - r) < tol:
                return r_new
            r = r_new
        raise RuntimeError(f"Fixed point not found after {max_iter} iterations.")

In [39]:
class SSN2DTopo(SSNBase):
    """Topographic SSN: one E and one I unit per (orientation, x, y) on a square grid.

    Within each of the E and I halves, units are ordered as
    ``index = y * grid_size * num_orientations + x * num_orientations + o``
    (orientation fastest, then x, then y); the I half mirrors the E half.

    Args:
        n, k: power-law exponent and prefactor (see SSNBase).
        tauE, tauI: excitatory / inhibitory time constants.
        grid_pars: dict; only ``'grid_size_Nx'`` (grid side length) is read.
        conn_pars: accepted for interface compatibility; not read by this class.
        thetas: 1-D array of preferred orientations shared by every grid site
            (presumably radians in [0, pi), matching calc_ori_dist's default
            period L=np.pi — confirm).
        device, dtype: placement and precision of all tensors.
    """

    def __init__(self, n, k, tauE, tauI, grid_pars, conn_pars, thetas, device='cpu', dtype=torch.float64):
        num_orientations = thetas.shape[0]
        grid_size = grid_pars['grid_size_Nx']

        # One E and one I unit for every (orientation, x, y) combination.
        Ne = num_orientations * (grid_size ** 2)
        Ni = num_orientations * (grid_size ** 2)

        super(SSN2DTopo, self).__init__(n=n, k=k, Ne=Ne, Ni=Ni, tau_e=tauE, tau_i=tauI, device=device, dtype=dtype)

        self.num_orientations = num_orientations
        self.grid_size = grid_size
        self.Ne = Ne
        self.Ni = Ni
        self._make_maps(thetas)

        # Randomly initialised connectivity hyper-parameters. Entry [a, b] applies to
        # the W block with rows in population a and columns in population b (0=E, 1=I).
        self.J_2x2 = nn.Parameter(torch.rand(2, 2, device=device, dtype=dtype))
        self.s_2x2 = nn.Parameter(torch.rand(2, 2, device=device, dtype=dtype))
        self.p_local = nn.Parameter(torch.rand(2, device=device, dtype=dtype))
        self.sigma_oris = nn.Parameter(torch.rand(1, device=device, dtype=dtype))

        self.make_W()

    def _make_maps(self, thetas):
        """Build per-unit orientation (ori_vec) and grid-coordinate (x_vec, y_vec) vectors of length N."""
        self.ori_map = torch.tensor(thetas, device=self.device, dtype=self.dtype).flatten()
        # Orientation cycles fastest; the whole E layout is then repeated for the I half.
        self.ori_vec = self.ori_map.repeat(self.grid_size ** 2).repeat(2)

        coords = torch.arange(self.grid_size, device=self.device, dtype=self.dtype)
        # x is constant over each run of num_orientations units and cycles every grid row.
        self.x_vec = coords.repeat_interleave(self.num_orientations).repeat(self.grid_size).repeat(2)
        # y is constant over each full grid row of grid_size * num_orientations units.
        self.y_vec = coords.repeat_interleave(self.num_orientations * self.grid_size).repeat(2)

    def make_W(self):
        """Assemble the (N, N) connectivity from the four E/I blocks and store it in self.W.

        Returns:
            The new ``self.W`` parameter.
        """
        xy_dist = self.calc_xy_dist()
        ori_dist = self.calc_ori_dist()

        W_ee = self.calc_W_block(xy_dist[:self.Ne, :self.Ne], ori_dist[:self.Ne, :self.Ne], self.s_2x2[0, 0], self.sigma_oris)
        W_ei = self.calc_W_block(xy_dist[:self.Ne, self.Ne:], ori_dist[:self.Ne, self.Ne:], self.s_2x2[0, 1], self.sigma_oris)
        W_ie = self.calc_W_block(xy_dist[self.Ne:, :self.Ne], ori_dist[self.Ne:, :self.Ne], self.s_2x2[1, 0], self.sigma_oris)
        W_ii = self.calc_W_block(xy_dist[self.Ne:, self.Ne:], ori_dist[self.Ne:, self.Ne:], self.s_2x2[1, 1], self.sigma_oris)

        # Mix a purely local (same-unit-index) component into selected blocks.
        # NOTE(review): p_local[1] is applied to W_ei (rows E, columns I); confirm
        # this is the intended block rather than W_ie (rows I, columns E).
        W_ee = self.p_local[0] * torch.eye(self.Ne, device=self.device, dtype=self.dtype) + (1 - self.p_local[0]) * W_ee
        W_ei = self.p_local[1] * torch.eye(self.Ni, device=self.device, dtype=self.dtype) + (1 - self.p_local[1]) * W_ei

        W_full = torch.cat([
            torch.cat([self.J_2x2[0, 0] * W_ee, self.J_2x2[0, 1] * W_ei], dim=1),
            torch.cat([self.J_2x2[1, 0] * W_ie, self.J_2x2[1, 1] * W_ii], dim=1)
        ], dim=0)

        # Respect the model's dtype (a hard-coded .double() broke float32 models,
        # where W @ r would mix float64 and float32).
        # NOTE(review): nn.Parameter makes W a fresh leaf tensor, so gradients on W
        # do not flow back to J_2x2, s_2x2, p_local or sigma_oris; call make_W()
        # again after changing those.
        self.W = nn.Parameter(W_full.to(dtype=self.dtype))

        return self.W

    def calc_xy_dist(self):
        """Pairwise Euclidean grid distance between all N units, shape (N, N).

        E and I units share grid positions, so the (Ne, Ne) distance matrix is
        tiled 2x2 to cover all four E/I blocks.
        """
        Ne = Ni = self.Ne
        xy_e = torch.stack([self.x_vec[:Ne], self.y_vec[:Ne]], dim=1)
        xy_i = torch.stack([self.x_vec[Ne:Ne + Ni], self.y_vec[Ne:Ne + Ni]], dim=1)
        # Plain Euclidean distance (p=2), not squared.
        return torch.cdist(xy_e, xy_i, p=2).repeat(2, 2)

    def calc_ori_dist(self, L=np.pi, method="absolute"):
        """Pairwise orientation distance between all N units, shape (N, N).

        Orientations live on a ring of period ``L``.

        Args:
            L: period of the orientation ring.
            method: ``"absolute"`` -> shortest arc length on the ring,
                min(|d| mod L, L - |d| mod L); ``"cos"`` -> cosine distance between
                the units' unit-vector embeddings on the ring, 1 - cos(2*pi*d/L);
                any other value -> 1 - cos(2*pi*|d|/L).
        """
        Ne = Ni = self.num_orientations * self.grid_size ** 2

        ori_vec_e = self.ori_vec[:Ne]
        ori_vec_i = self.ori_vec[Ne:Ne + Ni]
        # (Ne, 1) - (1, Ni) -> (Ne, Ni) matrix of all pairwise differences.
        diff = ori_vec_e.unsqueeze(1) - ori_vec_i.unsqueeze(0)

        if method == "absolute":
            # Wrap onto the ring so that orientations L apart count as identical.
            d = torch.remainder(diff.abs(), L)
            ori_dist = torch.minimum(d, L - d)
        elif method == "cos":
            # Embed each orientation as a 2-D unit vector on the ring, then take
            # 1 - (dot product) as the cosine distance.
            phase_e = (2 * np.pi / L) * ori_vec_e
            phase_i = (2 * np.pi / L) * ori_vec_i
            u_e = torch.stack([torch.cos(phase_e), torch.sin(phase_e)], dim=1)
            u_i = torch.stack([torch.cos(phase_i), torch.sin(phase_i)], dim=1)
            ori_dist = 1 - u_e @ u_i.t()
        else:
            # 1 - cos(2(pi/L) * |theta1 - theta2|)
            ori_dist = 1 - torch.cos((2 * np.pi / L) * diff.abs())

        return ori_dist.repeat(2, 2)

    def calc_W_block(self, xy_dist, ori_dist, s, sigma_oris, CellWiseNormalised=True):
        """One connectivity block: exponential in space, Gaussian in orientation.

        W = exp(-xy_dist / s - ori_dist**2 / (2 * sigma_oris**2)). Entries below
        1e-4 are zeroed. W is then normalised either per row (each row sums to 1)
        or by the mean row sum.

        Returns:
            Tensor with the same shape as ``xy_dist``.
        """
        # Small offset avoids division by zero when s or sigma_oris is 0.
        s = s + 1e-8
        sigma_oris = sigma_oris + 1e-8

        W = torch.exp(-xy_dist / s - ori_dist ** 2 / (2 * sigma_oris ** 2))
        W = torch.where(W < 1e-4, torch.zeros_like(W), W)

        row_sums = torch.sum(W, dim=1, keepdim=True)  # shape (rows, 1)
        if CellWiseNormalised:
            # row_sums already carries the trailing singleton dim; indexing it again
            # with [:, None] would broadcast W into a 3-D (rows, rows, cols) tensor.
            W = W / row_sums
        else:
            W = W / row_sums.mean()

        return W

In [40]:
# Set network parameters
n = 2          # exponent of the power-law rate nonlinearity
k = 0.4        # prefactor of the power-law rate nonlinearity
tauE = 20.0    # excitatory time constant
tauI = 10.0    # inhibitory time constant
grid_pars = {'grid_size_Nx': 3}
conn_pars = {'num_orientations': 8}
# Orientation distance in SSN2DTopo is periodic with default period pi, so
# theta = 0 and theta = pi are the same orientation. Excluding the endpoint gives
# num_orientations distinct, evenly spaced preferred orientations in [0, pi).
thetas = np.linspace(0, np.pi, conn_pars['num_orientations'], endpoint=False)

# Create an instance of SSN2DTopo
ssn_topo = SSN2DTopo(n, k, tauE, tauI, grid_pars, conn_pars, thetas)