# Direct fit SSN

In [1]:
from GSM._imports import *
from GSM.fogsm import FoGSMModel
from torch.optim import Adam



In [2]:
from SSN._imports import *
from SSN.ssn_2dtopoV1 import SSN2DTopoV1
from SSN.params import GridParameters

In [16]:
class SSNBase(torch.nn.Module):
    def __init__(self, n, k, Ne, Ni, tau_e, tau_i, device='cpu', dtype=torch.float64):
        super().__init__()
        self.n = n
        self.k = k
        self.Ne = Ne
        self.Ni = Ni
        self.N = self.Ne + self.Ni
        self.device = device
        self.dtype = dtype
        self.register_buffer('EI', torch.cat([torch.ones(Ne), torch.zeros(Ni)]).to(device).bool())
        self.register_buffer('tau_vec', torch.cat([tau_e * torch.ones(Ne), tau_i * torch.ones(Ni)]).to(device, dtype))
        self.W = torch.nn.Parameter(torch.zeros((self.N, self.N), device=device, dtype=dtype))
    
    @torch.jit.script_method
    def drdt(self, r, inp_vec):
        return (-r + self.powlaw(self.W @ r + inp_vec)) / self.tau_vec
    
    def powlaw(self, u):
        return self.k * F.relu(u).pow(self.n)
    
    @torch.jit.script_method
    def simulate(self, inp_vec, r_init=None, duration=100, dt=0.1):
        if r_init is None:
            r_init = torch.zeros((self.N,), device=self.device, dtype=self.dtype)
        
        r = r_init
        t = 0
        while t < duration:
            dr = self.drdt(r, inp_vec)
            r += dt * dr
            t += dt
        return r
    
    def jacobian(self, r):
        Phi = self.gains_from_r(r)
        return -torch.eye(self.N, device=self.device, dtype=self.dtype) + Phi[:, None] * self.W
    
    def gains_from_r(self, r):
        return self.n * self.k**(1/self.n) * r.pow(1 - 1/self.n)
    
    def fixed_point(self, inp_vec, tol=1e-6, max_iter=1000):
        r = torch.zeros((self.N,), device=self.device, dtype=self.dtype)
        for _ in range(max_iter):
            dr = self.drdt(r, inp_vec)
            r_new = r + dr
            if torch.norm(r_new - r) < tol:
                return r_new
            r = r_new
        raise RuntimeError(f"Fixed point not found after {max_iter} iterations.")

In [17]:
class SSN2DTopo(SSNBase):
    def __init__(self, n, k, tauE, tauI, grid_pars, conn_pars, thetas, device='cpu', dtype=torch.float64):
        num_orientations = thetas.shape[0]
        grid_size = grid_pars['grid_size_Nx']
        
        Ne = num_orientations * (grid_size ** 2)
        Ni = num_orientations * (grid_size ** 2)

        super(SSN2DTopo, self).__init__(n=n, k=k, Ne=Ne, Ni=Ni, tau_e=tauE, tau_i=tauI, device=device, dtype=dtype)
        
        self.num_orientations = num_orientations
        self.grid_size = grid_size
        self.Ne = Ne
        self.Ni = Ni
        self._make_maps(thetas)
        
        self.J_2x2 = nn.Parameter(torch.rand(2, 2, device=device, dtype=dtype))
        self.s_2x2 = nn.Parameter(torch.rand(2, 2, device=device, dtype=dtype))
        self.p_local = nn.Parameter(torch.rand(2, device=device, dtype=dtype))
        self.sigma_oris = nn.Parameter(torch.rand(1, device=device, dtype=dtype))
        
        self.make_W()

    def _make_maps(self,thetas):
        
        self.ori_map = torch.tensor(thetas, device=self.device, dtype=self.dtype).flatten()
        self.ori_vec = self.ori_map.repeat(self.grid_size ** 2)
        self.ori_vec = self.ori_vec.repeat(2)

        self.x_vec = torch.arange(self.grid_size, device=self.device, dtype=self.dtype).repeat_interleave(self.num_orientations).repeat(self.grid_size)
        self.y_vec = torch.arange(self.grid_size, device=self.device, dtype=self.dtype).repeat_interleave(self.num_orientations * self.grid_size)

        self.x_vec = self.x_vec.repeat(2)
        self.y_vec = self.y_vec.repeat(2)

    def make_W(self):
        
        xy_dist = self.calc_xy_dist()
        ori_dist = self.calc_ori_dist()
        
        W_ee = self.calc_W_block(xy_dist[:self.Ne, :self.Ne], ori_dist[:self.Ne, :self.Ne], self.s_2x2[0, 0], self.sigma_oris)
        W_ei = self.calc_W_block(xy_dist[:self.Ne, self.Ne:], ori_dist[:self.Ne, self.Ne:], self.s_2x2[0, 1], self.sigma_oris)
        W_ie = self.calc_W_block(xy_dist[self.Ne:, :self.Ne], ori_dist[self.Ne:, :self.Ne], self.s_2x2[1, 0], self.sigma_oris)
        W_ii = self.calc_W_block(xy_dist[self.Ne:, self.Ne:], ori_dist[self.Ne:, self.Ne:], self.s_2x2[1, 1], self.sigma_oris)
        
        W_ee = self.p_local[0] * torch.eye(self.Ne, device=self.device, dtype=self.dtype) + (1 - self.p_local[0]) * W_ee
        W_ei = self.p_local[1] * torch.eye(self.Ni, device=self.device, dtype=self.dtype) + (1 - self.p_local[1]) * W_ei
        
        self.W = nn.Parameter(torch.cat([
            torch.cat([self.J_2x2[0, 0] * W_ee, self.J_2x2[0, 1] * W_ei], dim=1),
            torch.cat([self.J_2x2[1, 0] * W_ie, self.J_2x2[1, 1] * W_ii], dim=1)
        ], dim=0).double())
        
        return self.W
    
    def calc_xy_dist(self):
        Ne = Ni = self.Ne
        x_vec_e = self.x_vec[:Ne]
        y_vec_e = self.y_vec[:Ne]
        x_vec_i = self.x_vec[Ne:Ne+Ni]
        y_vec_i = self.y_vec[Ne:Ne+Ni]
        
        xy_dist = torch.cdist(torch.stack([x_vec_e, y_vec_e], dim=1), torch.stack([x_vec_i, y_vec_i], dim=1), p=2).repeat(2, 2) #Distance Squared

        return xy_dist
    
    def calc_ori_dist(self,L=np.pi, method="absolute"):

        Ne = Ni = self.num_orientations * self.grid_size ** 2
        
        ori_vec_e = self.ori_vec[:Ne]
        ori_vec_i = self.ori_vec[Ne:Ne+Ni]
        
        if method == "absolute":
            ori_dist = torch.cdist(ori_vec_e.unsqueeze(1), ori_vec_i.unsqueeze(1)).repeat(2,2)
        elif method == "cos":
            ori_vec_e_norm = ori_vec_e / ori_vec_e.norm(dim=1, keepdim=True)
            ori_vec_i_norm = ori_vec_i / ori_vec_i.norm(dim=1, keepdim=True)
            ori_dist = 1 - torch.mm(ori_vec_e_norm, ori_vec_i_norm.t())
            ori_dist = ori_dist.repeat(2, 2)
        else:
            #1 - cos(2(pi/L) * |theta1 - theta2|)
            ori_dist = 1 - torch.cos((2 * np.pi / L) * torch.abs(ori_vec_e.unsqueeze(1) - ori_vec_i.unsqueeze(1)))
            ori_dist = ori_dist.repeat(2, 2)

        return ori_dist

    def calc_W_block(self, xy_dist, ori_dist, s, sigma_oris,CellWiseNormalised = True):

        #Add a small constant to s and sigma_oris to avoid division by zero
        s = s + 1e-8
        sigma_oris = sigma_oris + 1e-8

        W =  torch.exp(-xy_dist / s - ori_dist ** 2 / (2 * sigma_oris ** 2))
        W = torch.where(W < 1e-4, torch.zeros_like(W), W)
        
        sW = torch.sum(W, dim=1, keepdim=True)
        if CellWiseNormalised:
            W = W / sW[:, None]
        else:
            sW = sW.mean()
            W = W / sW
            
        return W

In [18]:
class DirectFit:
    def __init__(self, ssn_params, fogsm_params):
        
        super(DirectFit, self).__init__()
        self.ssn_model = SSN2DTopo(**ssn_params)
        self.fogsm_model = FoGSMModel(**fogsm_params)

    def sample_trajectories(self, input_batch, num_samples_g, duration=500, dt=1):
        batch_size = input_batch.shape[0]

        # Expand input_batch to match the number of samples
        input_expanded = input_batch.unsqueeze(1).expand(batch_size, num_samples_g, *input_batch.shape[1:])
        print("Input expanded shape: ", input_expanded.shape)

        # Reshape input_expanded to (batch_size * num_samples, input_size)
        input_reshaped = input_expanded.reshape(-1, *input_batch.shape[1:])

        # Generate trajectories for the reshaped input
        trajectories = self.ssn_model.simulate(input_reshaped, duration, dt)

        # Reshape trajectories to (batch_size, num_samples, N)
        trajectories = trajectories.reshape(batch_size, num_samples_g)

        return trajectories

    def calculate_log_p_g(self, trajectories):
        mu = torch.zeros_like(trajectories[0])
        log_p_g = MultivariateNormal(mu, self.fogsm_model.K_g).log_prob(trajectories).mean()
        return log_p_g

    def calculate_log_p_I_given_g(self, trajectories, A_samples):
        
        # Calculate log p(I|g,A) for each sample
        log_likelihood = []
        for a in A_samples:
            log_p = 0
            for g in trajectories:
                I = g * a
                log_p += self.fogsm_model.log_likelihood(I)
            log_likelihood.append(log_p / len(trajectories))
        
        # Approximate the expectation over A using Monte Carlo sampling
        p_I_given_g = torch.exp(torch.stack(log_likelihood)).mean()
        log_p_I_given_g = torch.log(p_I_given_g)

        return log_p_I_given_g

    def calculate_elbo(self, input_batch, num_samples_g, A_samples, duration=500, dt=1):

        # Sample trajectories from the SSN
        trajectories = self.sample_trajectories(input_batch, num_samples_g, duration, dt)

        # Calculate the first two terms of the ELBO using Monte Carlo approximation
        log_p_g = self.calculate_log_p_g(trajectories)
        log_p_I_given_g = self.calculate_log_p_I_given_g(trajectories, A_samples)

        # Approximate the entropy term
        # cov_matrix = torch.cov(trajectories.reshape(num_samples_g, -1).T)

        cov_matrix = torch.cov(trajectories.reshape(trajectories.shape[0], -1))
        entropy_term = 0.5 * torch.logdet(cov_matrix)

        elbo = log_p_g + log_p_I_given_g - entropy_term

        return elbo
    
    def optimise_elbo(self, batch_size, num_samples_a, num_samples_g, convergence_threshold=1e-3, optimizer=Adam):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ssn_model.to(device)
        self.fogsm_model.to(device)

        batch = 0
        prev_elbo = float('-inf')

        while True:
            elbo_batch = 0

            # Sample amplitude fields from the prior p(A) using Monte Carlo sampling
            A_samples = [self.fogsm_model.compute_A() for _ in range(num_samples_a)]

            # Generate a mini-batch of input data
            input_batch = self.fogsm_model.generate_fogsm_dataset(batch_size)

            # Calculate the ELBO for the mini-batch
            elbo_batch = self.calculate_elbo(input_batch, num_samples_g, A_samples)
            elbo += elbo_batch.item()

            # Optimise the ELBO with respect to the model parameters
            elbo.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Print the ELBO for monitoring
            if (batch + 1) % 100 == 0:
                print(f"Batch [{batch+1}], ELBO: {elbo_batch / batch_size:.4f}")

            # Check for convergence
            if abs(elbo_batch - prev_elbo) < convergence_threshold:
                print(f"Converged after {batch+1} batches.")
                break

            prev_elbo = elbo_batch
            batch += 1
        
        return self.ssn_model, self.fogsm_model

In [19]:
torch.manual_seed(0)

# Define the parameters for the FoGSM model
thetas = torch.linspace(0, 2 * torch.pi, 8)  # 8 orientations from 0 to 2*pi
fogsm_params = {
        "thetas": thetas,
        "length_scale_feature": 0.5,
        "length_scale_amplitude": 1.2,
        "kappa": 1.0,
        "jitter": 1e-4,
        "grid_size": 3,
        "frequency": 0.9,
        "sigma": 0.1,
    }
fogsm_model = FoGSMModel(**fogsm_params)

  theta1 = torch.tensor(theta1)
  theta2 = torch.tensor(theta2)


In [20]:
# Set network parameters
n = 2
k = 0.4
tauE = 20.0
tauI = 10.0
grid_pars = {'grid_size_Nx': 3}
conn_pars = {'num_orientations': 8}
thetas = np.linspace(0, np.pi, conn_pars['num_orientations'])

ssn_params = {
        "n": 2,
        "k": 0.04,
        "tauE": 20,
        "tauI": 10,
        "grid_pars": grid_pars,
        "conn_pars": conn_pars,
        "thetas": thetas,
    }

In [21]:
direct_fit = DirectFit(ssn_params, fogsm_params)

# Create an optimiser for the ELBO
optimiser = Adam(list(direct_fit.ssn_model.parameters()), lr=0.001)

# Run the ELBO optimisation
input_data, _ = torch.load("fogsm_dataset.pt")
direct_fit.optimise_elbo(batch_size=32, num_samples_a=10, num_samples_g=100, optimizer=optimiser)

Input expanded shape:  torch.Size([32, 100, 3, 3])


TypeError: 'ScriptMethodStub' object is not callable

In [4]:
#SSN2DTopoV1
torch.manual_seed(0)


# Define the parameters for the SSN model
grid_pars = GridParameters(
    gridsize_Nx=17,  # Number of grid points in one dimension
    gridsize_deg=3.2,  # Size of the grid in degrees of visual angle
    magnif_factor=2,  # Magnification factor to convert degrees to mm
    hyper_col=800,  # Hypercolumn
    )
psi = torch.tensor(0.774)
conn_pars = {
        'J_2x2': torch.tensor([[1.124, -0.931], [1.049, -0.537]], dtype=torch.float64) * torch.pi * psi,
        's_2x2': torch.tensor([[0.2955, 0.09], [0.5542, 0.09]]),
        'p_local': [0.72, 0.7],
        'sigma_oris': 45,
    }
ssn_params = {
        "n": 2,
        "k": 0.04,
        "tauE": 20,
        "tauI": 10,
        "grid_pars": grid_pars,
        "conn_pars": conn_pars,
    }



  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  theta1 = torch.tensor(theta1)
  theta2 = torch.tensor(theta2)


SSN2DTopoV1.__init__
grid_pars:  <SSN.params.GridParameters object at 0x7f9af186e3d0>
conn_pars:  {'J_2x2': tensor([[ 2.7331, -2.2638],
        [ 2.5507, -1.3058]], dtype=torch.float64), 's_2x2': tensor([[0.2955, 0.0900],
        [0.5542, 0.0900]]), 'p_local': [0.72, 0.7], 'sigma_oris': 45}
Making W
W made
SSN2DTopoV1.__init__ done
Input expanded shape:  torch.Size([32, 100, 3, 3])


RuntimeError: The size of tensor a (578) must match the size of tensor b (3) at non-singleton dimension 2

### Unbatched DirectFit archived

In [None]:
class DirectFit:
    def __init__(self, ssn_params, fogsm_params):
        
        self.ssn_model = SSN2DTopoV1(**ssn_params)
        self.fogsm_model = FoGSMModel(**fogsm_params)

    def sample_trajectories(self, input, num_samples, duration=500, dt=1):
        trajectories = []
        for _ in range(num_samples):
            traj = self.ssn_model.run_simulation(input, duration, dt)
            trajectories.append(traj)
        return torch.stack(trajectories)

    def calculate_log_p_g(self, trajectories):
        mu = torch.zeros_like(trajectories[0])
        log_p_g = MultivariateNormal(mu, self.fogsm_model.K_g).log_prob(trajectories).mean()
        return log_p_g

    def calculate_log_p_I_given_g(self, trajectories, num_samples_a=10):
        # Sample amplitude fields from the prior p(A) using Monte Carlo sampling
        amplitude_fields = [self.fogsm_model.compute_A() for _ in range(num_samples_a)]
        
        # Calculate log p(I|g,A) for each sample
        log_likelihood = []
        for a in amplitude_fields:
            log_p = 0
            for g in trajectories:
                I = g * a
                log_p += self.fogsm_model.log_likelihood(I)
            log_likelihood.append(log_p / len(trajectories))
        
        # Approximate the expectation over A using Monte Carlo sampling
        p_I_given_g = torch.exp(torch.stack(log_likelihood)).mean()
        log_p_I_given_g = torch.log(p_I_given_g)

        return log_p_I_given_g

    def calculate_elbo(self, input, num_samples, duration=500, dt=1):
        # Sample trajectories from the SSN
        trajectories = self.sample_trajectories(input, num_samples, duration, dt)
        
        # Calculate the first two terms of the ELBO using Monte Carlo approximation
        log_p_g = self.calculate_log_p_g(trajectories)
        log_p_I_given_g = self.calculate_log_p_I_given_g(trajectories)
        
        # Approximate the entropy term
        cov_matrix = torch.cov(trajectories.reshape(num_samples, -1).T)
        entropy_term = 0.5 * torch.logdet(cov_matrix)
        
        elbo = log_p_g + log_p_I_given_g - entropy_term
        return elbo

    def optimise_elbo(self, input, num_samples, num_epochs, optimizer):
        for epoch in range(num_epochs):
            # Calculate the ELBO
            elbo = self.calculate_elbo(input, num_samples)
            
            #resample As for each minibatch

            # Optimise the ELBO with respect to the model parameters
            elbo.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            # Print the ELBO for monitoring
            if (epoch + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], ELBO: {elbo.item():.4f}")