# Direct fit SSN

In [None]:
#ratio of value at end for surround suppression should be small 

In [2]:
from GSM._imports import *
from GSM.fogsm import FoGSMModel
from torch.optim import Adam

In [3]:
from SSN._imports import *
from SSN.ssn_2dtopoV1 import SSN2DTopoV1
from SSN.params import GridParameters

In [198]:
class SSNBase(torch.nn.Module):
    def __init__(self, n, k, Ne, Ni, tau_e, tau_i, device='cpu', dtype=torch.float64):
        super().__init__()
        self.n = n
        self.k = k
        self.Ne = Ne
        self.Ni = Ni
        self.N = self.Ne + self.Ni
        self.device = device
        self.dtype = dtype
        self.register_buffer('EI', torch.cat([torch.ones(Ne), torch.zeros(Ni)]).to(device).bool())
        self.register_buffer('tau_vec', torch.cat([tau_e * torch.ones(Ne), tau_i * torch.ones(Ni)]).to(device, dtype))
    
    def drdt(self, r, inp_vec):
        return (-r + self.powlaw(self.W @ r + inp_vec)) / self.tau_vec
    
    def drdt_batch(self, r, inp_vec):
        # Ensure r and inp_vec are 2D tensors for batch processing
        if r.ndim == 1:
            r = r.unsqueeze(0)
        if inp_vec.ndim == 1:
            inp_vec = inp_vec.unsqueeze(0)

        print(f"drdt: r.shape={r.shape}, inp_vec.shape={inp_vec.shape}, W.shape={self.W.shape}")

        # Compute W @ r for batch processing
        W_r = self.W @ r.T
        W_r = W_r.T  # Transpose back to match batch dimension
    
        return (-r + self.powlaw(W_r + inp_vec)) / self.tau_vec
    
    def dvdt_batch(self, v, inp_vec):
        # Ensure v and inp_vec are 2D tensors for batch processing
        if v.ndim == 1:
            v = v.unsqueeze(0)
        if inp_vec.ndim == 1:
            inp_vec = inp_vec.unsqueeze(0)

        print(f"dvdt: v.shape={v.shape}, inp_vec.shape={inp_vec.shape}, W.shape={self.W.shape}")

        # Compute W @ r for batch processing
        W_r = self.W @ self.powlaw(v).T
        W_r = W_r.T  # Transpose back to match batch dimension
    
        return (-v + W_r + inp_vec) / self.tau_vec

    def powlaw(self, u):
        return self.k * F.relu(u).pow(self.n)
    
    def simulate(self, inp_vec, r_init=None, duration=100, dt=0.1):
        if r_init is None:
            r_init = torch.zeros((self.N,), device=self.device, dtype=self.dtype)
        
        r = r_init
        t = 0
        while t < duration:
            dr = self.drdt(r, inp_vec)
            r += dt * dr
            t += dt
        return r
    
    def simulate_batch(self, inp_vec, r_init=None, duration=100, dt=0.1):

        # Check if inp_vec is a batch of inputs
        batch_size = inp_vec.shape[0] if inp_vec.ndim > 1 else 1
        print(f"simulate: batch_size={batch_size}, inp_vec.shape={inp_vec.shape}")

    
        # Initialize r_init if not provided
        if r_init is None:
            r_init = torch.zeros((batch_size, self.N), device=self.device, dtype=self.dtype)
        else:
            if isinstance(r_init, int):
                r_init = torch.zeros((batch_size, self.N), device=self.device, dtype=self.dtype)
            if r_init.shape[0] != batch_size or r_init.shape[1] != self.N:
                raise ValueError("r_init shape does not match batch_size or neuron count N.")
    
    
        r = r_init
        t = 0
        while t < duration:
            # Calculate dr for each element in the batch
            dr = self.dvdt_batch(r, inp_vec)
            print(f"t={t}, r.shape={r.shape}, dr.shape={dr.shape}")

            r += dt * dr
            t += dt
    
        return r


In [199]:
class SSN2DTopo(SSNBase):
    def __init__(self, n, k, tauE, tauI, grid_pars, thetas, L = np.pi,device='cpu', dtype=torch.float64):

        """
        Initialises the SSN2DTopo model, which represents a two-dimensional topographic Stabilized Supralinear Network (SSN).

        Args:
            n (float): The power law exponent for the activation function, typically chosen between 1 and 2 to model the supralinear response properties of cortical neurons.
            k (float): The gain factor for the activation function, which scales the overall activity levels of the network.
            tauE (float): The time constant for excitatory neurons, reflecting their temporal integration properties.
            tauI (float): The time constant for inhibitory neurons, typically shorter than tauE to model faster inhibition dynamics.
            grid_pars (dict): Dictionary containing grid parameters, such as the grid size.
            conn_pars (dict): Dictionary containing connectivity parameters, such as the number of orientations.
            thetas (numpy.ndarray): Array of orientation values representing the preferred orientations of neurons.
            L (float, optional): Range of orientation values (default: np.pi).
            device (str, optional): Device to use for tensor operations (default: 'cpu').
            dtype (torch.dtype, optional): Data type for tensors (default: torch.float64).
        """
         
        num_orientations = thetas.shape[0]
        grid_size = grid_pars['grid_size_Nx']
                
        Ne = num_orientations * (grid_size ** 2)
        Ni = num_orientations * (grid_size ** 2)

        super(SSN2DTopo, self).__init__(n=n, k=k, Ne=Ne, Ni=Ni, tau_e=tauE, tau_i=tauI, device=device, dtype=dtype)
        
        self.num_orientations = num_orientations
        self.grid_size = grid_size
        self.Ne = Ne
        self.Ni = Ni
        self.L = L
        self._make_maps(thetas)
        
        # Initialise trainable parameters
        self.J_2x2 = nn.Parameter(torch.rand(2, 2, device=device, dtype=dtype)) # Interaction strengths
        self.s_2x2 = nn.Parameter(torch.rand(2, 2, device=device, dtype=dtype)) # Spatial length scales
        self.p_local = nn.Parameter(torch.rand(2, device=device, dtype=dtype)) # Local connectivity strengths - set to 0?
        self.sigma_oris = nn.Parameter(torch.rand(1, device=device, dtype=dtype)) # Orientation tuning width
        
        self.make_W()

    def _make_maps(self,thetas):
        """
        Creates orientation and spatial maps for the SSN2DTopo model.

        Args:
            thetas (numpy.ndarray): Array of orientation values representing the preferred orientations of neurons.
        """
        
        self.ori_map = torch.tensor(thetas, device=self.device, dtype=self.dtype).flatten()
        self.ori_vec = self.ori_map.repeat(self.grid_size ** 2)  # Repeat orientation values for each grid cell
        self.ori_vec = self.ori_vec.repeat(2) # Repeat for E and I populations

        # Create x and y vectors for grid cells
        self.x_vec = torch.arange(self.grid_size, device=self.device, dtype=self.dtype).repeat_interleave(self.num_orientations).repeat(self.grid_size)
        self.y_vec = torch.arange(self.grid_size, device=self.device, dtype=self.dtype).repeat_interleave(self.num_orientations * self.grid_size)

        # Repeat for E and I populations
        self.x_vec = self.x_vec.repeat(2)
        self.y_vec = self.y_vec.repeat(2)

    def make_W(self):

        """
        Computes the weight matrix W for the network, representing the strength of connections between neurons.

        The weight matrix W is constructed from four submatrices (W_ee, W_ei, W_ie, W_ii), computed based on the spatial and orientation distances between neurons, 
        as well as the length scale parameters (s_2x2 and sigma_oris) and local connectivity strengths (p_local).
        """
        
        xy_dist = self.calc_xy_dist()
        ori_dist = self.calc_ori_dist()
        print(f"xy_dist.shape: {xy_dist.shape}, ori_dist.shape: {ori_dist.shape}")

        # Compute weight blocks
        W_ee = self.calc_W_block(xy_dist[:self.Ne, :self.Ne], ori_dist[:self.Ne, :self.Ne], self.s_2x2[0][0], self.sigma_oris)
        W_ei = self.calc_W_block(xy_dist[:self.Ne, self.Ne:], ori_dist[:self.Ne, self.Ne:], self.s_2x2[0][1], self.sigma_oris)
        W_ie = self.calc_W_block(xy_dist[self.Ne:, :self.Ne], ori_dist[self.Ne:, :self.Ne], self.s_2x2[1][0], self.sigma_oris)
        W_ii = self.calc_W_block(xy_dist[self.Ne:, self.Ne:], ori_dist[self.Ne:, self.Ne:], self.s_2x2[1][1], self.sigma_oris)
        print(f"W_ee.shape: {W_ee.shape}, W_ei.shape: {W_ei.shape}, W_ie.shape: {W_ie.shape}, W_ii.shape: {W_ii.shape}")

        
        # Apply local connectivity strengths
        W_ee = self.p_local[0] * torch.eye(self.Ne, device=self.device, dtype=self.dtype) + (1 - self.p_local[0]) * W_ee
        W_ei = self.p_local[1] * torch.eye(self.Ni, device=self.device, dtype=self.dtype) + (1 - self.p_local[1]) * W_ei
        
        # Concatenate submatrices to form W
        W = torch.cat([
            torch.cat([self.J_2x2[0, 0] * W_ee, self.J_2x2[0, 1] * W_ei], dim=1),
            torch.cat([self.J_2x2[1, 0] * W_ie, self.J_2x2[1, 1] * W_ii], dim=1)
        ], dim=0).double()

        # Register W as a buffer
        self.register_buffer('W', W)

        print("W shape:", W.shape)
        
        return self.W
    
    def calc_xy_dist(self):
        Ne = Ni = self.Ne
        x_vec_e = self.x_vec[:Ne]
        y_vec_e = self.y_vec[:Ne]
        x_vec_i = self.x_vec[Ne:Ne+Ni]
        y_vec_i = self.y_vec[Ne:Ne+Ni]
        
        xy_dist = torch.cdist(torch.stack([x_vec_e, y_vec_e], dim=1), torch.stack([x_vec_i, y_vec_i], dim=1), p=2).repeat(2, 2) #Distance Squared

        return xy_dist
    
    def calc_ori_dist(self,L=np.pi, method=None):

        Ne = Ni = self.num_orientations * self.grid_size ** 2
        
        ori_vec_e = self.ori_vec[:Ne]
        ori_vec_i = self.ori_vec[Ne:Ne+Ni]
        
        # define everything as squared distance ori_sqdist = (ori_vec_e - ori_vec_i) ** 2
        if method == "absolute":
            ori_dist = torch.cdist(ori_vec_e.unsqueeze(1), ori_vec_i.unsqueeze(1)).repeat(2,2)
        elif method == "cos":
            ori_vec_e_norm = ori_vec_e / ori_vec_e.norm(dim=1, keepdim=True)
            ori_vec_i_norm = ori_vec_i / ori_vec_i.norm(dim=1, keepdim=True)
            ori_dist = 1 - torch.mm(ori_vec_e_norm, ori_vec_i_norm.t())
        else:
            #1 - cos(2(pi/L) * |theta1 - theta2|^2)
            ori_dist = (1 - torch.cos((2 * np.pi / L) * (ori_vec_e.unsqueeze(1) - ori_vec_i.unsqueeze(0))**2)) / (2 * np.pi / L)**2

            #ori_vec[:,None] - ori_ve
        
        ori_dist = ori_dist.repeat(2, 2)

        return ori_dist

    def calc_W_block(self, xy_dist, ori_dist, s, sigma_oris, CellWiseNormalised = True):

        """
        Computes a weight block for the weight matrix W.

        Args:
            xy_dist (torch.Tensor): Distance matrix between grid cells.
            ori_dist (torch.Tensor): Orientation distance matrix.
            s (float): Length scale parameter for the spatial kernel.
            sigma_oris (float): Length scale parameter for the orientation kernel.
            CellWiseNormalised (bool, optional): Whether to perform cell-wise normalization (default: True).

        Returns:
            W (torch.Tensor): Weight block.
        """

        #Add a small constant to s and sigma_oris to avoid division by zero
        s = s + 1e-8
        sigma_oris = 2*np.pi*sigma_oris/self.L + 1e-8
        
        W =  torch.exp(-xy_dist / s - ori_dist ** 2 / (2 * sigma_oris ** 2))
        #print(f"calc_W_block before sparse: xy_dist.shape={xy_dist.shape}, ori_dist.shape={ori_dist.shape}, W.shape={W.shape}")
        W = torch.where(W < 1e-4, torch.zeros_like(W), W)
        #print(f"calc_W_block after sparse: xy_dist.shape={xy_dist.shape}, ori_dist.shape={ori_dist.shape}, W.shape={W.shape}")

        
        sW = torch.sum(W, dim=1, keepdim=True)
        if CellWiseNormalised:
            W = W / sW
        else:
            sW = sW.mean()
            W = W / sW
        #print(f"calc_W_block: xy_dist.shape={xy_dist.shape}, ori_dist.shape={ori_dist.shape}, W.shape={W.shape}")

        return W.squeeze()

In [224]:
class DirectFit:
    def __init__(self, ssn_params, fogsm_params):
        
        super(DirectFit, self).__init__()
        self.ssn_model = SSN2DTopo(**ssn_params)
        self.fogsm_model = FoGSMModel(**fogsm_params)

    def sample_trajectories(self, input_batch, num_samples_g, duration=500, dt=1):
        batch_size = input_batch.shape[0]
        print("Batch size: ", batch_size)
    
        print("Input shape: ", input_batch.shape)
        # Expand input_batch to match the number of samples
        input_expanded = input_batch.unsqueeze(1).expand(batch_size, num_samples_g, *input_batch.shape[1:])
        print(input_expanded)
        print("Input expanded shape: ", input_expanded.shape)

        #TODO: take the input batch (which should be 72*72) and multiply it by C_E and C_I to get the input to the network (both should be optimised)
        #TODO: duplex the input into on and off channels (RELU and -RELU)
        #TODO: ignore the phase - phase preference is not important for the model

        # Reshape input_expanded to (batch_size * num_samples, input_size)
        #input_reshaped = input_expanded.reshape(-1, *input_batch.shape[1:])
        #input_reshaped = input_expanded.reshape(batch_size * num_samples_g, -1)

        # Calculate total neurons per grid point (16 neurons per grid point: 8 excitatory + 8 inhibitory)
        neurons_per_grid_point = 2 * self.ssn_model.num_orientations
        grid_height, grid_width = input_batch.shape[1], input_batch.shape[2]
        total_neurons = grid_height * grid_width * neurons_per_grid_point  # 3*3*16 = 144

        # Reshape input_expanded to (batch_size * num_samples, total_neurons)
        input_reshaped = input_expanded.reshape(batch_size * num_samples_g, total_neurons)
        print("Input reshaped shape: ", input_reshaped.shape)



        # Ensure input_reshaped has correct dimensions before calling simulate
        if input_reshaped.shape[1] != self.ssn_model.N:
            input_reshaped = input_reshaped.view(input_reshaped.shape[0], self.ssn_model.N)

        # Generate trajectories for the reshaped input
        trajectories = self.ssn_model.simulate_batch(input_reshaped, duration, dt)
        print("Trajectories shape after simulation: ", trajectories.shape)

        # Reshape trajectories to (batch_size, num_samples, N)
        trajectories = trajectories.reshape(batch_size, num_samples_g)
        print("Trajectories final shape: ", trajectories.shape)

        return trajectories

    def calculate_log_p_g(self, trajectories):
        mu = torch.zeros_like(trajectories[0])
        log_p_g = MultivariateNormal(mu, self.fogsm_model.K_g).log_prob(trajectories).mean()
        return log_p_g

    def calculate_log_p_I_given_g(self, I_data, trajectories, A_samples):
        
        # Calculate log p(I|g,A) for each sample
        # log_likelihood = []
        # for a in A_samples:
        #     log_p = 0
        #     for g in trajectories:
        #         # I shouyld come from the dataset 
        #         I = g * a
        #         log_p += self.fogsm_model.log_likelihood(I,g,A)
        #     log_likelihood.append(log_p / len(trajectories))
        
        # log_likelihood = []
        # for g in trajectories:
        #     p_I_g = 0
        #     for a in A_samples:
        #         # I shouyld come from the dataset 
        #         I = I_data
        #         p_I_g += self.fogsm_model.likelihood(I,g,A)
        #     log_likelihood.append(torch.log(p_I_g / len(A_samples)))
            
        log_likelihood = 0
        for g in trajectories:
            p_I_g = 0
            for a in A_samples:
                # I shouyld come from the dataset 
                I = I_data
                p_I_g = p_I_g + self.fogsm_model.likelihood(I, g, A)
            log_likelihood = log_likelihood + torch.log(p_I_g / len(A_samples))

        return log_likelihood / len(trajectories)
    
        # # Approximate the expectation over A using Monte Carlo sampling
        # p_I_given_g = torch.exp(torch.stack(log_likelihood)).mean()
        # log_p_I_given_g = torch.log(p_I_given_g)

        # return log_p_I_given_g
        

    def calculate_elbo(self, input_batch, num_samples_g, A_samples, duration=500, dt=1):

        # Sample trajectories from the SSN
        trajectories = self.sample_trajectories(input_batch, num_samples_g, duration, dt)

        elbo = 0
        for I, trajectory in zip(input_batch, trajectories): # trajectory is the trajectory for a single image

        # Calculate the first two terms of the ELBO using Monte Carlo approximation
            log_p_g = self.calculate_log_p_g(trajectories)
            log_p_I_given_g = self.calculate_log_p_I_given_g(I, trajectory, A_samples)

        # Approximate the entropy term
        # cov_matrix = torch.cov(trajectories.reshape(num_samples_g, -1).T)

            cov_matrix = torch.cov(trajectory.reshape(trajectory.shape[0], -1))
            entropy_term = 0.5 * torch.logdet(cov_matrix)

            elbo = elbo + log_p_g + log_p_I_given_g - entropy_term

        return elbo / len(input_batch)
        
    def optimise_elbo(self, batch_size, num_samples_a, num_samples_g, convergence_threshold=1e-3, optimizer=Adam):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ssn_model.to(device)
        self.fogsm_model.to(device)

        batch = 0
        prev_elbo = float('-inf')

        while True:
            elbo_batch = 0

            # Sample amplitude fields from the prior p(A) using Monte Carlo sampling
            A_samples = [self.fogsm_model.compute_A() for _ in range(num_samples_a)]

            # Generate a mini-batch of input data
            input_batch = self.fogsm_model.generate_fogsm_dataset(batch_size)

            # Calculate the ELBO for the mini-batch
            elbo_batch = self.calculate_elbo(input_batch, num_samples_g, A_samples)
            elbo += elbo_batch.item()

            # Optimise the ELBO with respect to the model parameters
            elbo.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Print the ELBO for monitoring
            if (batch + 1) % 100 == 0:
                print(f"Batch [{batch+1}], ELBO: {elbo_batch / batch_size:.4f}")

            # Check for convergence
            if abs(elbo_batch - prev_elbo) < convergence_threshold:
                print(f"Converged after {batch+1} batches.")
                break

            prev_elbo = elbo_batch
            batch += 1
        
        return self.ssn_model, self.fogsm_model

In [225]:
torch.manual_seed(0)

# Define the parameters for the FoGSM model
thetas = torch.linspace(0, 2 * torch.pi, 8)  # 8 orientations from 0 to 2*pi
fogsm_params = {
        "thetas": thetas,
        "length_scale_feature": 0.5,
        "length_scale_amplitude": 1.2,
        "kappa": 1.0,
        "jitter": 1e-4,
        "grid_size": 3,
        "frequency": 0.9,
        "sigma": 0.1,
    }
fogsm_model = FoGSMModel(**fogsm_params)

  theta1 = torch.tensor(theta1)
  theta2 = torch.tensor(theta2)


In [226]:
# Set network parameters
n = 2
k = 0.4
tauE = 20.0
tauI = 10.0
grid_pars = {'grid_size_Nx': 3}
conn_pars = {'num_orientations': 8}
thetas = np.linspace(0, np.pi, conn_pars['num_orientations'])

ssn_params = {
        "n": 2,
        "k": 0.04,
        "tauE": 20,
        "tauI": 10,
        "grid_pars": grid_pars,
        "thetas": thetas,
    }

In [229]:
direct_fit = DirectFit(ssn_params, fogsm_params)

# Create an optimiser for the ELBO
optimiser = Adam(list(direct_fit.ssn_model.parameters()), lr=0.001)

# Run the ELBO optimisation
input_data, _ = torch.load("fogsm_dataset.pt")
direct_fit.optimise_elbo(batch_size=32, num_samples_a=10, num_samples_g=100, optimizer=optimiser)

xy_dist.shape: torch.Size([144, 144]), ori_dist.shape: torch.Size([144, 144])
W_ee.shape: torch.Size([72, 72]), W_ei.shape: torch.Size([72, 72]), W_ie.shape: torch.Size([72, 72]), W_ii.shape: torch.Size([72, 72])
W shape: torch.Size([144, 144])
Batch size:  32
Input shape:  torch.Size([32, 3, 3])
tensor([[[[-39.7896,  -5.3572, -18.7189],
          [  2.9520,  58.4735,  -0.1313],
          [ -4.9396,   6.8304, -18.0139]],

         [[-39.7896,  -5.3572, -18.7189],
          [  2.9520,  58.4735,  -0.1313],
          [ -4.9396,   6.8304, -18.0139]],

         [[-39.7896,  -5.3572, -18.7189],
          [  2.9520,  58.4735,  -0.1313],
          [ -4.9396,   6.8304, -18.0139]],

         ...,

         [[-39.7896,  -5.3572, -18.7189],
          [  2.9520,  58.4735,  -0.1313],
          [ -4.9396,   6.8304, -18.0139]],

         [[-39.7896,  -5.3572, -18.7189],
          [  2.9520,  58.4735,  -0.1313],
          [ -4.9396,   6.8304, -18.0139]],

         [[-39.7896,  -5.3572, -18.7189],
     

RuntimeError: shape '[3200, 144]' is invalid for input of size 28800

In [197]:
#SSN2DTopoV1
torch.manual_seed(0)


# Define the parameters for the SSN model
grid_pars = GridParameters(
    gridsize_Nx=17,  # Number of grid points in one dimension
    gridsize_deg=3.2,  # Size of the grid in degrees of visual angle
    magnif_factor=2,  # Magnification factor to convert degrees to mm
    hyper_col=800,  # Hypercolumn
    )
psi = torch.tensor(0.774)
conn_pars = {
        'J_2x2': torch.tensor([[1.124, -0.931], [1.049, -0.537]], dtype=torch.float64) * torch.pi * psi,
        's_2x2': torch.tensor([[0.2955, 0.09], [0.5542, 0.09]]),
        'p_local': [0.72, 0.7],
        'sigma_oris': 45,
    }
ssn_params = {
        "n": 2,
        "k": 0.04,
        "tauE": 20,
        "tauI": 10,
        "grid_pars": grid_pars,
        "conn_pars": conn_pars,
    }



### Unbatched DirectFit archived

In [None]:
class DirectFit:
    def __init__(self, ssn_params, fogsm_params):
        
        self.ssn_model = SSN2DTopoV1(**ssn_params)
        self.fogsm_model = FoGSMModel(**fogsm_params)

    def sample_trajectories(self, input, num_samples, duration=500, dt=1):
        trajectories = []
        for _ in range(num_samples):
            traj = self.ssn_model.run_simulation(input, duration, dt)
            trajectories.append(traj)
        return torch.stack(trajectories)

    def calculate_log_p_g(self, trajectories):
        mu = torch.zeros_like(trajectories[0])
        log_p_g = MultivariateNormal(mu, self.fogsm_model.K_g).log_prob(trajectories).mean()
        return log_p_g

    def calculate_log_p_I_given_g(self, trajectories, num_samples_a=10):
        # Sample amplitude fields from the prior p(A) using Monte Carlo sampling
        amplitude_fields = [self.fogsm_model.compute_A() for _ in range(num_samples_a)]
        
        # Calculate log p(I|g,A) for each sample
        log_likelihood = []
        for a in amplitude_fields:
            log_p = 0
            for g in trajectories:
                I = g * a
                log_p += self.fogsm_model.log_likelihood(I)
            log_likelihood.append(log_p / len(trajectories))
        
        # Approximate the expectation over A using Monte Carlo sampling
        p_I_given_g = torch.exp(torch.stack(log_likelihood)).mean()
        log_p_I_given_g = torch.log(p_I_given_g)

        return log_p_I_given_g

    def calculate_elbo(self, input, num_samples, duration=500, dt=1):
        # Sample trajectories from the SSN
        trajectories = self.sample_trajectories(input, num_samples, duration, dt)
        
        # Calculate the first two terms of the ELBO using Monte Carlo approximation
        log_p_g = self.calculate_log_p_g(trajectories)
        log_p_I_given_g = self.calculate_log_p_I_given_g(trajectories)
        
        # Approximate the entropy term
        cov_matrix = torch.cov(trajectories.reshape(num_samples, -1).T)
        entropy_term = 0.5 * torch.logdet(cov_matrix)
        
        elbo = log_p_g + log_p_I_given_g - entropy_term
        return elbo

    def optimise_elbo(self, input, num_samples, num_epochs, optimizer):
        for epoch in range(num_epochs):
            # Calculate the ELBO
            elbo = self.calculate_elbo(input, num_samples)
            
            #resample As for each minibatch

            # Optimise the ELBO with respect to the model parameters
            elbo.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            # Print the ELBO for monitoring
            if (epoch + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], ELBO: {elbo.item():.4f}")