In [2]:
from GSM._imports import *
from torch.optim import Adam
from SSN._imports import *
from SSN.params import GridParameters



# FoGSM

In [3]:
class FoGSMModel(nn.Module):
    def __init__(self,thetas=None, length_scale_feature=1.0, length_scale_amplitude=1.2, kappa=1.0, jitter=1e-5, grid_size=10, frequency=1.0,sigma=0.1):
        super(FoGSMModel, self).__init__()

        self.dtype = torch.float64

        self.length_scale_feature = Parameter(torch.tensor(length_scale_feature, dtype=self.dtype))
        self.length_scale_amplitude = Parameter(torch.tensor(length_scale_amplitude, dtype=self.dtype))
        self.kappa = Parameter(torch.tensor(kappa, dtype=self.dtype))
        self.frequency = Parameter(torch.tensor(frequency, dtype=self.dtype)) 
        self.sigma = Parameter(torch.tensor(sigma))

        if thetas is None:
            thetas = torch.linspace(0, 2 * np.pi, 8)  # 8 orientations from 0 to 2*pi
        self.thetas = thetas

        self.jitter = jitter
        self.grid_size = grid_size
        self.grid = torch.stack(torch.meshgrid(torch.linspace(-5, 5, grid_size), 
                                               torch.linspace(-5, 5, grid_size)), 
                                dim=-1).reshape(-1, 2)
        
        self.K_g = self.generate_K_g()
        
    def von_mises_kernel(self, theta1, theta2):
        theta_diff = theta1 - theta2  
        return torch.clamp(torch.exp(self.kappa * torch.cos(theta_diff)), min=1e-6)
    
    def squared_exponential_kernel(self, x1, x2, length_scale,jitter="True"):
        x1 = x1.unsqueeze(1) # Shape: [N, 1, 2]
        x2 = x2.unsqueeze(0) # Shape: [1, N, 2]
        sq_dist = torch.sum((x1 - x2) ** 2, dim=2) # Shape: [N, N]

        exp_term = torch.exp(-sq_dist / (2*length_scale**2))

        if jitter:
            return exp_term + self.jitter * torch.eye(x1.size(0))
        else:
            return exp_term

    def composite_feature_kernel(self, theta1, theta2):
         
        sq_exp_component = self.squared_exponential_kernel(self.grid, self.grid, self.length_scale_feature,jitter="False")        
        
        # Ensure theta1 and theta2 are tensors
        theta1 = torch.tensor(theta1)
        theta2 = torch.tensor(theta2)
        x1 = self.grid.unsqueeze(1) # Shape: [N, 1, 2]
        x2 = self.grid.unsqueeze(0) # Shape: [1, N, 2]

        n1 = torch.tensor([torch.cos(theta1), torch.sin(theta1)]).view(1, 1, 2)  # Shape: [1, 1, 2]
        n2 = torch.tensor([torch.cos(theta2), torch.sin(theta2)]).view(1, 1, 2)  # Shape: [1, 1, 2]
        average_orientation = (n1 + n2) / 2

        # Broadcasting average_orientation for dot product computation
        average_orientation = average_orientation.repeat(x1.size(0), x1.size(1), 1)
        dot_product = torch.sum((x1 - x2) * average_orientation, dim=2)
        periodic_component = torch.cos(2 * torch.pi * self.frequency * dot_product)

        # Composite Kernel
        return sq_exp_component * periodic_component

    def generate_K_g(self):
        
        theta1_grid, theta2_grid = torch.meshgrid(self.thetas, self.thetas)
        ori_kernel_val = self.von_mises_kernel(theta1_grid, theta2_grid)
    
        # Spatial kernel
        loc_kernel_val = torch.zeros((len(self.thetas), len(self.thetas), self.grid_size**2, self.grid_size**2))

        for i in range(len(self.thetas)):
            for j in range(len(self.thetas)):
                loc_kernel_val[i,j] = self.composite_feature_kernel(self.thetas[i], self.thetas[j])
        
        print("LOC ",loc_kernel_val.size())
        K_spatial = torch.sum(loc_kernel_val, dim=[0, 1])
        print("K_spatial ",K_spatial.size())
        K_g = torch.kron(K_spatial, ori_kernel_val)
        
        #K_g = ori_kernel_val.unsqueeze(-1).unsqueeze(-1) * loc_kernel_val
        #K_g = K_g.transpose(0, 2).transpose(1, 3).reshape((len(self.thetas) * self.grid_size**2, len(self.thetas) * self.grid_size**2))
        K_g = K_g + self.jitter * torch.eye(len(self.thetas)*self.grid_size**2)
        print(K_g.size())
        return K_g

    def compute_A(self):
        kernel_vals = self.squared_exponential_kernel(self.grid, self.grid, self.length_scale_amplitude)
        return torch.sqrt(torch.exp(MultivariateNormal(torch.zeros(self.grid.size(0)), kernel_vals).sample()))

    def samples(self):

        g = MultivariateNormal(torch.zeros(len(self.thetas)*(self.grid_size**2)), self.K_g).sample()  
        A = self.compute_A()
        
        # Tile amplitudes to match feature fields 
        A = A.repeat(len(self.thetas))
    
        # Combine
        I = g * A  + torch.randn_like(g) * self.sigma
        #I = torch.sum(I.reshape(len(self.thetas), self.grid_size, self.grid_size), dim=0)

        return I, g
    
    def log_likelihood(self, I, g,A):
        A = A.repeat(len(self.thetas))
        I_hat = g * A
        I_hat = torch.sum(I_hat.reshape(len(self.thetas), self.grid_size, self.grid_size), dim=0)
        return MultivariateNormal(I_hat.flatten(), self.sigma * torch.eye(self.grid_size**2)).log_prob(I.flatten())

    def likelihood(self, I, g,A):
        A = A.repeat(len(self.thetas))
        I_hat = g * A
        I_hat = torch.sum(I_hat.reshape(len(self.thetas), self.grid_size, self.grid_size), dim=0)
        return torch.exp(MultivariateNormal(I_hat.flatten(), self.sigma * torch.eye(self.grid_size**2)).log_prob(I.flatten()))

    def visualise(self, combined_fields):

        # Normalise the combined image for visualisation
        combined_fields_normalised = combined_fields / combined_fields.max()

        # Reshape to image format
        combined_image = combined_fields_normalised.view(self.grid_size, self.grid_size).detach().numpy()

        # Visualise the combined image
        plt.figure(figsize=(5,5))
        plt.imshow(combined_image, cmap='gray') 
        plt.title('FoGSM Sample')
        plt.axis("off")
        plt.show()

    def generate_fogsm_dataset(self, num_samples, save=False,save_path=None):
        samples = []
        for _ in range(num_samples):
            I, g = self.samples()
            samples.append((I, g))
    
        # Convert samples to tensors
        images, gs = zip(*samples)
        images = torch.stack(images)
        gs = torch.stack(gs)
    
        if save:
            torch.save(images, gs, save_path)
        else:
            return images

    def visualise_samples(self, save_path, num_samples_to_visualise, grid_size):
        # Load the saved samples
        images, _ = torch.load(save_path)

        # Select a subset of samples to visualise
        selected_samples = images[:num_samples_to_visualise]

        # Create a grid of images
        fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
        fig.subplots_adjust(hspace=0.1, wspace=0.1)

        for i, ax in enumerate(axes.flat):
            if i < num_samples_to_visualise:
                image = selected_samples[i].detach().numpy()
                image = (image - image.min()) / (image.max() - image.min())

                ax.imshow(image, cmap='gray')
        
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis('off')

        plt.tight_layout()
        plt.show()

In [4]:
torch.manual_seed(0)

# Define the parameters for the FoGSM model
thetas = torch.linspace(0, 2 * torch.pi, 8)  # 8 orientations from 0 to 2*pi
fogsm_params = {
        "thetas": thetas,
        "length_scale_feature": .55,
        "length_scale_amplitude": .9,
        "kappa": .4,
        "jitter": 1e-2,
        "grid_size": 3,
        "frequency": .2,
        "sigma": 0.1,
    }
fogsm_model = FoGSMModel(**fogsm_params)

# Generate and visualise a sample from the FoGSM model
I, g = fogsm_model.samples()
#fogsm_model.visualise(I)


LOC  torch.Size([8, 8, 9, 9])
K_spatial  torch.Size([9, 9])
torch.Size([72, 72])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  theta1 = torch.tensor(theta1)
  theta2 = torch.tensor(theta2)


# SSN

## SSN Base

In [5]:
class SSNBase(torch.nn.Module):
    def __init__(self, n, k, Ne, Ni, tau_e, tau_i, device='cpu', dtype=torch.float64):
        super().__init__()
        self.n = n
        self.k = k
        self.Ne = Ne
        self.Ni = Ni
        self.N = self.Ne + self.Ni
        self.device = device
        self.dtype = dtype
        self.register_buffer('EI', torch.cat([torch.ones(Ne), torch.zeros(Ni)]).to(device).bool())
        self.register_buffer('tau_vec', torch.cat([tau_e * torch.ones(Ne), tau_i * torch.ones(Ni)]).to(device, dtype))
    
    def drdt(self, r, inp_vec):
        return (-r + self.powlaw(self.W @ r + inp_vec)) / self.tau_vec

    def powlaw(self, u):
        return self.k * F.relu(u).pow(self.n)
    
    def dvdt_batch(self, v, inp_vec):
        # Ensure v and inp_vec are 2D tensors for batch processing
        if v.ndim == 1:
            v = v.unsqueeze(0)
        if inp_vec.ndim == 1:
            inp_vec = inp_vec.unsqueeze(0)

        print(f"dvdt: v.shape={v.shape}, inp_vec.shape={inp_vec.shape}, W.shape={self.W.shape}")

        # Compute W @ r for batch processing
        W_r = self.W @ self.powlaw(v).T
        W_r = W_r.T  # Transpose back to match batch dimension
    
        return (-v + W_r + inp_vec) / self.tau_vec

    def simulate_batch(self, inp_vec, v_init=None, duration=500, dt=0.1):
        # inp_vec: [batch_size, num_neurons]
        # r_init: [batch_size, num_neurons]
        # returns: [batch_size, time_steps, num_neurons]

        # Check if inp_vec is a batch of inputs
        if inp_vec.ndim == 2:
            # inp_vec: [batch_size, num_neurons]
            batch_size, self.N = inp_vec.shape
            time_steps = int(duration / dt)

            #print duration, dt, time_steps
            print("duration, dt, time_steps",duration, dt, time_steps)
                  
            inp_vec = inp_vec.unsqueeze(1).expand(-1, time_steps, -1)
        else:
            raise ValueError("inp_vec must be a 2D tensor of shape [batch_size, num_neurons].")

        print(f"simulate: batch_size={batch_size}, time_steps={time_steps}, inp_vec.shape={inp_vec.shape}")

        # Initialise v_init if not provided
        if v_init is None:
            v_init = torch.zeros((batch_size, self.N), device=self.device, dtype=self.dtype)
        else:
            if isinstance(v_init, (int, float)):
                v_init = torch.full((batch_size, self.N), v_init, device=self.device, dtype=self.dtype)
            if v_init.shape[0] != batch_size or v_init.shape[1] != self.N:
                raise ValueError("v_init shape does not match batch_size or neuron count N.")

        # Initialise the rates tensor to store the rates at all time points
        rates = torch.zeros((batch_size, time_steps, self.N), device=self.device, dtype=self.dtype)
        v = v_init

        for t in range(time_steps):
            # Calculate dv for each element in the batch at the current time step
            dv = self.dvdt_batch(v, inp_vec[:, t, :])  # dv: [batch_size, num_neurons]

            # Update the membrane potentials for the current time step
            v += dt * dv

            # Compute the firing rates from the updated membrane potentials
            r = self.powlaw(v)
            rates[:, t, :] = r
            print(f"t={t}, v.shape={r.shape}, dv.shape={dv.shape}")


        return rates

## SSN 2DTopo

In [6]:
class SSN2DTopo(SSNBase):
    def __init__(self, n, k, tauE, tauI, grid_pars, thetas, L = np.pi,device='cpu', dtype=torch.float64):
         
        num_orientations = thetas.shape[0]
        grid_size = grid_pars['grid_size_Nx']
                
        Ne = num_orientations * (grid_size ** 2)
        Ni = num_orientations * (grid_size ** 2)

        super(SSN2DTopo, self).__init__(n=n, k=k, Ne=Ne, Ni=Ni, tau_e=tauE, tau_i=tauI, device=device, dtype=dtype)
        
        self.num_orientations = num_orientations
        self.grid_size = grid_size
        self.Ne = Ne
        self.Ni = Ni
        self.L = L
        self._make_maps(thetas)
        
        # Initialise trainable parameters
        self.J_2x2 = nn.Parameter(torch.rand(2, 2, device=device, dtype=dtype)) # Interaction strengths
        self.s_2x2 = nn.Parameter(torch.rand(2, 2, device=device, dtype=dtype)) # Spatial length scales
        self.p_local = nn.Parameter(torch.rand(2, device=device, dtype=dtype)) # Local connectivity strengths - set to 0?
        self.sigma_oris = nn.Parameter(torch.rand(1, device=device, dtype=dtype)) # Orientation tuning width
        
        self.make_W()

    def _make_maps(self,thetas):

        self.ori_map = torch.tensor(thetas, device=self.device, dtype=self.dtype).flatten()
        self.ori_vec = self.ori_map.repeat(self.grid_size ** 2)  # Repeat orientation values for each grid cell
        self.ori_vec = self.ori_vec.repeat(2) # Repeat for E and I populations

        # Create x and y vectors for grid cells
        self.x_vec = torch.arange(self.grid_size, device=self.device, dtype=self.dtype).repeat_interleave(self.num_orientations).repeat(self.grid_size)
        self.y_vec = torch.arange(self.grid_size, device=self.device, dtype=self.dtype).repeat_interleave(self.num_orientations * self.grid_size)

        # Repeat for E and I populations
        self.x_vec = self.x_vec.repeat(2)
        self.y_vec = self.y_vec.repeat(2)

    def make_W(self):
        
        xy_dist = self.calc_xy_dist()
        ori_dist = self.calc_ori_dist()

        # Compute weight blocks
        W_ee = self.calc_W_block(xy_dist[:self.Ne, :self.Ne], ori_dist[:self.Ne, :self.Ne], self.s_2x2[0][0], self.sigma_oris)
        W_ei = self.calc_W_block(xy_dist[:self.Ne, self.Ne:], ori_dist[:self.Ne, self.Ne:], self.s_2x2[0][1], self.sigma_oris)
        W_ie = self.calc_W_block(xy_dist[self.Ne:, :self.Ne], ori_dist[self.Ne:, :self.Ne], self.s_2x2[1][0], self.sigma_oris)
        W_ii = self.calc_W_block(xy_dist[self.Ne:, self.Ne:], ori_dist[self.Ne:, self.Ne:], self.s_2x2[1][1], self.sigma_oris)
        
        # Apply local connectivity strengths
        W_ee = self.p_local[0] * torch.eye(self.Ne, device=self.device, dtype=self.dtype) + (1 - self.p_local[0]) * W_ee
        W_ei = self.p_local[1] * torch.eye(self.Ni, device=self.device, dtype=self.dtype) + (1 - self.p_local[1]) * W_ei
        
        # Concatenate submatrices to form W
        W = torch.cat([
            torch.cat([self.J_2x2[0, 0] * W_ee, self.J_2x2[0, 1] * W_ei], dim=1),
            torch.cat([self.J_2x2[1, 0] * W_ie, self.J_2x2[1, 1] * W_ii], dim=1)
        ], dim=0).double()

        # Register W as a buffer
        self.register_buffer('W', W)
        
        return self.W
    
    def calc_xy_dist(self):
        Ne = Ni = self.Ne
        x_vec_e = self.x_vec[:Ne]
        y_vec_e = self.y_vec[:Ne]
        x_vec_i = self.x_vec[Ne:Ne+Ni]
        y_vec_i = self.y_vec[Ne:Ne+Ni]
        
        xy_dist = torch.cdist(torch.stack([x_vec_e, y_vec_e], dim=1), torch.stack([x_vec_i, y_vec_i], dim=1), p=2).repeat(2, 2) #Distance Squared

        return xy_dist
    
    def calc_ori_dist(self,L=np.pi, method=None):

        Ne = Ni = self.num_orientations * self.grid_size ** 2
        
        ori_vec_e = self.ori_vec[:Ne]
        ori_vec_i = self.ori_vec[Ne:Ne+Ni]
        
        # define everything as squared distance ori_sqdist = (ori_vec_e - ori_vec_i) ** 2
        if method == "absolute":
            ori_dist = torch.cdist(ori_vec_e.unsqueeze(1), ori_vec_i.unsqueeze(1)).repeat(2,2)
        elif method == "cos":
            ori_vec_e_norm = ori_vec_e / ori_vec_e.norm(dim=1, keepdim=True)
            ori_vec_i_norm = ori_vec_i / ori_vec_i.norm(dim=1, keepdim=True)
            ori_dist = 1 - torch.mm(ori_vec_e_norm, ori_vec_i_norm.t())
        else:
            #1 - cos(2(pi/L) * |theta1 - theta2|^2)
            ori_dist = (1 - torch.cos((2 * np.pi / L) * (ori_vec_e.unsqueeze(1) - ori_vec_i.unsqueeze(0))**2)) / (2 * np.pi / L)**2

            #ori_vec[:,None] - ori_ve
        
        ori_dist = ori_dist.repeat(2, 2)

        return ori_dist

    def calc_W_block(self, xy_dist, ori_dist, s, sigma_oris, CellWiseNormalised = True):

        #Add a small constant to s and sigma_oris to avoid division by zero
        s = s + 1e-8
        sigma_oris = 2*np.pi*sigma_oris/self.L + 1e-8
        
        W =  torch.exp(-xy_dist / s - ori_dist ** 2 / (2 * sigma_oris ** 2))
        W = torch.where(W < 1e-4, torch.zeros_like(W), W)

        
        sW = torch.sum(W, dim=1, keepdim=True)
        if CellWiseNormalised:
            W = W / sW
        else:
            sW = sW.mean()
            W = W / sW

        return W.squeeze()

## SSN Main

In [7]:
# Set network parameters
n = 2
k = 0.4
tauE = 20.0
tauI = 10.0
grid_pars = {'grid_size_Nx': 3}
conn_pars = {'num_orientations': 8}
thetas = np.linspace(0, np.pi, conn_pars['num_orientations'])

ssn_params = {
        "n": 2,
        "k": 0.04,
        "tauE": 20,
        "tauI": 10,
        "grid_pars": grid_pars,
        "thetas": thetas,
    }

# ELBO

In [30]:
class DirectFit:
    def __init__(self, ssn_params, fogsm_params):
        
        super(DirectFit, self).__init__()
        self.ssn_model = SSN2DTopo(**ssn_params)
        self.fogsm_model = FoGSMModel(**fogsm_params)

        self.C_E = nn.Parameter(torch.tensor(2.4)) # taken from Echeveste et al. 2020
        self.C_I = nn.Parameter(torch.tensor(2.4)) # taken from Echeveste et al. 2020

    def sample_trajectories(self, input_batch, num_samples_g, duration=500, dt=.1):

        batch_size = input_batch.shape[0]   
        print("Input batch shape: ", input_batch.shape)  

        # Expand input_batch to match the number of samples
        input_expanded = input_batch.unsqueeze(1).expand(batch_size, num_samples_g, *input_batch.shape[1:])
        print("Input expanded shape: ", input_expanded.shape)

        # Multiply the input batch by C_E and C_I scalars to get the input to the network
        input_weighted = torch.cat([
            input_expanded * self.C_E,
            input_expanded * self.C_I
        ], dim=-1)
        print("Input weighted shape: ", input_weighted.shape)

        # Duplex the input into on and off channels (RELU and -RELU)
        input_duplexed = torch.cat([
            F.relu(input_weighted),
            F.relu(-input_weighted)
        ], dim=-1)
        print("Input duplexed shape: ", input_duplexed.shape)

        # Calculate total neurons per grid point (16 neurons per grid point: 8 excitatory + 8 inhibitory)
        neurons_per_grid_point = 2 * self.ssn_model.num_orientations
        grid_height, grid_width = self.ssn_model.grid_size, self.ssn_model.grid_size
        total_neurons = grid_height * grid_width * neurons_per_grid_point  # 3*3*16 = 144

        # Reshape input_expanded to (batch_size * num_samples, total_neurons)
        input_reshaped = input_weighted.reshape(batch_size * num_samples_g, total_neurons)
        print("Input reshaped shape: ", input_reshaped.shape)

        # Ensure input_reshaped has correct dimensions before calling simulate
        if input_reshaped.shape[1] != self.ssn_model.N:
            input_reshaped = input_reshaped.view(input_reshaped.shape[0], self.ssn_model.N)

        # Generate trajectories for the reshaped input
        trajectories = self.ssn_model.simulate_batch(input_reshaped, duration=duration, dt=dt)
        print("Trajectories shape after simulation: ", trajectories.shape)

        # Reshape trajectories to (batch_size, num_samples, N)
        trajectories = trajectories.reshape(batch_size, num_samples_g, self.ssn_model.N, int(duration/dt))        
        print("Trajectories final shape: ", trajectories.shape)

        return trajectories

    def calculate_log_p_g(self, trajectories):
        num_samples_g, N, duration = trajectories.shape
        mu = torch.zeros_like(trajectories)
        log_p_g = 0
        for sample in range(num_samples_g):
            print("Shapes ", mu[:, sample].shape, trajectories[:, sample].shape, self.fogsm_model.K_g.shape)
            log_p_g = log_p_g + MultivariateNormal(mu[:, sample], self.fogsm_model.K_g).log_prob(trajectories[:, sample]).mean()
        return log_p_g / num_samples_g

    def calculate_log_p_I_given_g(self, I_data, trajectories, A_samples):

        log_likelihood = 0
        for g in trajectories:
            p_I_g = 0
            for a in A_samples:
                p_I_g = p_I_g + self.fogsm_model.likelihood(I_data, g, a) # I should come from the dataset and g from the SSN
            log_likelihood = log_likelihood + torch.log(p_I_g / len(A_samples))

        return log_likelihood / len(trajectories)

        
    def calculate_elbo(self, input_batch, num_samples_g, A_samples, duration=500, dt=.1):

        # Sample trajectories from the SSN
        trajectories = self.sample_trajectories(input_batch, num_samples_g, duration=duration, dt=dt)
        print("ELBO trajectories received")

        elbo = 0
        for I, trajectory in zip(input_batch, trajectories): # trajectory is the trajectory for a single image
            print("trajectory shape: ", trajectory.shape)
            log_p_g = self.calculate_log_p_g(trajectory)
            print("Log p_g calculated")
            log_p_I_given_g = self.calculate_log_p_I_given_g(I.unsqueeze(0), trajectory, A_samples)
            print("Log p_I_given_g calculated")

            cov_matrix = torch.cov(trajectory.reshape(trajectory.shape[0], -1))
            entropy_term = 0.5 * torch.logdet(cov_matrix)

            elbo = elbo + log_p_g + log_p_I_given_g - entropy_term

        return elbo / len(input_batch)
        
    def optimise_elbo(self, batch_size, num_samples_a, num_samples_g, convergence_threshold=1e-3, optimizer=Adam):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ssn_model.to(device)
        self.fogsm_model.to(device)

        batch = 0
        prev_elbo = float('-inf')

        while True:
            elbo_batch = 0

            # Sample amplitude fields from the prior p(A) using Monte Carlo sampling
            A_samples = [self.fogsm_model.compute_A() for _ in range(num_samples_a)]

            # Generate a mini-batch of input data
            input_batch = self.fogsm_model.generate_fogsm_dataset(batch_size)

            # Calculate the ELBO for the mini-batch
            elbo_batch = self.calculate_elbo(input_batch, num_samples_g, A_samples, duration=10, dt=1)
            elbo += elbo_batch.item()

            # Optimise the ELBO with respect to the model parameters
            elbo.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Print the ELBO for monitoring
            if (batch + 1) % 1 == 0:
                print(f"Batch [{batch+1}], ELBO: {elbo_batch / batch_size:.4f}")

            # Check for convergence
            if abs(elbo_batch - prev_elbo) < convergence_threshold:
                print(f"Converged after {batch+1} batches.")
                break

            prev_elbo = elbo_batch
            batch += 1
        
        return self.ssn_model, self.fogsm_model

In [31]:
direct_fit = DirectFit(ssn_params, fogsm_params)

# Create an optimiser for the ELBO
optimiser = Adam(list(direct_fit.ssn_model.parameters()), lr=0.001)

# Run the ELBO optimisation
input_data, _ = torch.load("fogsm_dataset.pt")
direct_fit.optimise_elbo(batch_size=32, num_samples_a=10, num_samples_g=100, optimizer=optimiser)

LOC  torch.Size([8, 8, 9, 9])
K_spatial  torch.Size([9, 9])
torch.Size([72, 72])
Input batch shape:  torch.Size([32, 72])
Input expanded shape:  torch.Size([32, 100, 72])
Input weighted shape:  torch.Size([32, 100, 144])
Input duplexed shape:  torch.Size([32, 100, 288])
Input reshaped shape:  torch.Size([3200, 144])
duration, dt, time_steps 10 1 10
simulate: batch_size=3200, time_steps=10, inp_vec.shape=torch.Size([3200, 10, 144])
dvdt: v.shape=torch.Size([3200, 144]), inp_vec.shape=torch.Size([3200, 144]), W.shape=torch.Size([144, 144])
t=0, v.shape=torch.Size([3200, 144]), dv.shape=torch.Size([3200, 144])
dvdt: v.shape=torch.Size([3200, 144]), inp_vec.shape=torch.Size([3200, 144]), W.shape=torch.Size([144, 144])
t=1, v.shape=torch.Size([3200, 144]), dv.shape=torch.Size([3200, 144])
dvdt: v.shape=torch.Size([3200, 144]), inp_vec.shape=torch.Size([3200, 144]), W.shape=torch.Size([144, 144])
t=2, v.shape=torch.Size([3200, 144]), dv.shape=torch.Size([3200, 144])
dvdt: v.shape=torch.Size(

  theta1 = torch.tensor(theta1)
  theta2 = torch.tensor(theta2)


RuntimeError: shape '[-1, 10, 10]' is invalid for input of size 5184