# Direct fit SSN

In [1]:
from GSM._imports import *
from GSM.fogsm import FoGSMModel



In [2]:
from SSN._imports import *
from SSN.ssn_2dtopoV1 import SSN2DTopoV1
from SSN.params import GridParameters

In [6]:
from tensorflow import keras

: 

In [3]:
class DirectFit:
    def __init__(self, ssn_params, fogsm_params):
        
        super(DirectFit, self).__init__()
        self.ssn_model = SSN2DTopoV1(**ssn_params)
        self.fogsm_model = FoGSMModel(**fogsm_params)

    def sample_trajectories(self, input_batch, num_samples_g, duration=500, dt=1):
        batch_size = input_batch.shape[0]

        # Expand input_batch to match the number of samples
        input_expanded = input_batch.unsqueeze(1).expand(batch_size, num_samples_g, -1)

        # Reshape input_expanded to (batch_size * num_samples, input_size)
        input_reshaped = input_expanded.reshape(-1, input_batch.shape[-1])

        # Generate trajectories for the reshaped input
        trajectories = self.ssn_model.run_simulation(input_reshaped, duration, dt)

        # Reshape trajectories to (batch_size, num_samples, N)
        trajectories = trajectories.reshape(batch_size, num_samples_g)

        return trajectories

    def calculate_log_p_g(self, trajectories):
        mu = torch.zeros_like(trajectories[0])
        log_p_g = MultivariateNormal(mu, self.fogsm_model.K_g).log_prob(trajectories).mean()
        return log_p_g

    def calculate_log_p_I_given_g(self, trajectories, A_samples):
        
        # Calculate log p(I|g,A) for each sample
        log_likelihood = []
        for a in A_samples:
            log_p = 0
            for g in trajectories:
                I = g * a
                log_p += self.fogsm_model.log_likelihood(I)
            log_likelihood.append(log_p / len(trajectories))
        
        # Approximate the expectation over A using Monte Carlo sampling
        p_I_given_g = torch.exp(torch.stack(log_likelihood)).mean()
        log_p_I_given_g = torch.log(p_I_given_g)

        return log_p_I_given_g

    def calculate_elbo(self, input_batch, num_samples_g, A_samples, duration=500, dt=1):

        # Sample trajectories from the SSN
        trajectories = self.sample_trajectories(input_batch, num_samples_g, duration, dt)

        # Calculate the first two terms of the ELBO using Monte Carlo approximation
        log_p_g = self.calculate_log_p_g(trajectories)
        log_p_I_given_g = self.calculate_log_p_I_given_g(trajectories, A_samples)

        # Approximate the entropy term
        # cov_matrix = torch.cov(trajectories.reshape(num_samples_g, -1).T)

        cov_matrix = torch.cov(trajectories.reshape(trajectories.shape[0], -1))
        entropy_term = 0.5 * torch.logdet(cov_matrix)

        elbo = log_p_g + log_p_I_given_g - entropy_term

        return elbo
    
    def optimise_elbo(self, input_data, batch_size, num_samples_a, num_samples_g, num_epochs, optimizer):
        dataset = torch.utils.data.TensorDataset(input_data)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(num_epochs):
            elbo_epoch = 0
            for input_batch in dataloader:
                input_batch = input_batch[0]

                # Sample amplitude fields from the prior p(A) using Monte Carlo sampling
                A_samples = [self.fogsm_model.compute_A() for _ in range(num_samples_a)]

                # Calculate the ELBO for the mini-batch
                elbo_batch = self.calculate_elbo(input_batch, num_samples_g, A_samples)
                elbo += elbo_batch.item()

                # Resample As for each mini-batch
                self.fogsm_model.resample_A()

                # Optimize the ELBO with respect to the model parameters
                elbo.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Print the ELBO for monitoring
            if (epoch + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], ELBO: {elbo_epoch / len(dataloader):.4f}")
                

In [4]:
torch.manual_seed(0)

# Define the parameters for the FoGSM model
length_scale_feature = 0.5
length_scale_amplitude = 1.2
kappa = 1.0
grid_size = 10
frequency = 0.9
sigma = 0.1
thetas = torch.linspace(0, 2 * torch.pi, 8)  # 8 orientations from 0 to 2*pi

fogsm_params = {
        "thetas": thetas,
        "length_scale_feature": length_scale_feature,
        "length_scale_amplitude": length_scale_amplitude,
        "kappa": kappa,
        "jitter": 1e-4,
        "grid_size": grid_size,
        "frequency": frequency,
        "sigma": sigma,
    }

# Define the parameters for the SSN model
grid_pars = GridParameters(
    gridsize_Nx=17,  # Number of grid points in one dimension
    gridsize_deg=3.2,  # Size of the grid in degrees of visual angle
    magnif_factor=2,  # Magnification factor to convert degrees to mm
    hyper_col=800,  # Hypercolumn
    )

psi = torch.tensor(0.774)

#LEARN all the conn parameters in the optimisation 
conn_pars = {
        'J_2x2': torch.tensor([[1.124, -0.931], [1.049, -0.537]], dtype=torch.float64) * torch.pi * psi,
        's_2x2': torch.tensor([[0.2955, 0.09], [0.5542, 0.09]]),
        'p_local': [0.72, 0.7],
        'sigma_oris': 45,
    }

n = 2
k = 0.04
tauE = 20  # Time constant for excitatory neurons
tauI = 10  # Time constant for inhibitory neurons

ssn_params = {
        "n": n,
        "k": k,
        "tauE": tauE,
        "tauI": tauI,
        "grid_pars": grid_pars,
        "conn_pars": conn_pars,
    }

direct_fit = DirectFit(ssn_params, fogsm_params)

# Define the number of samples and epochs for ELBO optimisation
num_samples = 100
num_epochs = 1000

# Create an optimiser for the ELBO
optimiser = Adam(list(direct_fit.ssn_model.parameters()), lr=0.001)

# Run the ELBO optimisation
direct_fit.optimise_elbo(None, num_samples, num_epochs, optimiser)

SSN2DTopoV1.__init__
grid_pars:  <SSN.params.GridParameters object at 0x7fea39f147c0>
conn_pars:  {'J_2x2': tensor([[ 2.7331, -2.2638],
        [ 2.5507, -1.3058]], dtype=torch.float64), 's_2x2': tensor([[0.2955, 0.0900],
        [0.5542, 0.0900]]), 'p_local': [0.72, 0.7], 'sigma_oris': 45}
Making W
W made
SSN2DTopoV1.__init__ done


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  theta1 = torch.tensor(theta1)
  theta2 = torch.tensor(theta2)


NameError: name 'Adam' is not defined

In [None]:
class DirectFit:
    def __init__(self, ssn_params, fogsm_params):
        
        self.ssn_model = SSN2DTopoV1(**ssn_params)

        self.fogsm_model = FoGSMModel(**fogsm_params)

    def sample_trajectories(self, input, num_samples, duration=500, dt=1):
        trajectories = []
        for _ in range(num_samples):
            traj = self.ssn_model.run_simulation(input, duration, dt)
            trajectories.append(traj)
        return torch.stack(trajectories)

    def calculate_log_p_g(self, trajectories):
        mu = torch.zeros_like(trajectories[0])
        log_p_g = MultivariateNormal(mu, self.fogsm_model.K_g).log_prob(trajectories).mean()
        return log_p_g

    def calculate_log_p_I_given_g(self, trajectories, num_samples_a=10):
        # Sample amplitude fields from the prior p(A) using Monte Carlo sampling
        amplitude_fields = [self.fogsm_model.compute_A() for _ in range(num_samples_a)]
        
        # Calculate log p(I|g,A) for each sample
        log_likelihood = []
        for a in amplitude_fields:
            log_p = 0
            for g in trajectories:
                I = g * a
                log_p += self.fogsm_model.log_likelihood(I)
            log_likelihood.append(log_p / len(trajectories))
        
        # Approximate the expectation over A using Monte Carlo sampling
        p_I_given_g = torch.exp(torch.stack(log_likelihood)).mean()
        log_p_I_given_g = torch.log(p_I_given_g)

        return log_p_I_given_g

    def calculate_elbo(self, input, num_samples, duration=500, dt=1):
        # Sample trajectories from the SSN
        trajectories = self.sample_trajectories(input, num_samples, duration, dt)
        
        # Calculate the first two terms of the ELBO using Monte Carlo approximation
        log_p_g = self.calculate_log_p_g(trajectories)
        log_p_I_given_g = self.calculate_log_p_I_given_g(trajectories)
        
        # Approximate the entropy term
        cov_matrix = torch.cov(trajectories.reshape(num_samples, -1).T)
        entropy_term = 0.5 * torch.logdet(cov_matrix)
        
        elbo = log_p_g + log_p_I_given_g - entropy_term
        return elbo

    def optimise_elbo(self, input, num_samples, num_epochs, optimizer):
        for epoch in range(num_epochs):
            # Calculate the ELBO
            elbo = self.calculate_elbo(input, num_samples)
            
            #resample As for each minibatch

            # Optimise the ELBO with respect to the model parameters
            elbo.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            # Print the ELBO for monitoring
            if (epoch + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], ELBO: {elbo.item():.4f}")