In [4]:
import torch
import torch.nn as nn

In [6]:
# Correcting the use of torch.sqrt to accept a tensor as argument
class FitzHughNagumoNetwork(nn.Module):
    def __init__(self, n_neurons, a, b, J, epsilon, sigma, dt):
        super(FitzHughNagumoNetwork, self).__init__()
        self.n_neurons = n_neurons
        self.a = a
        self.b = b
        self.J = J
        self.epsilon = epsilon
        self.sigma = sigma
        self.dt = dt
        
        # Initialize the variables v and w for all neurons
        self.v = torch.randn(n_neurons, device=device) * 0.1
        self.w = torch.randn(n_neurons, device=device) * 0.1
        
        # Coupling matrix with J/n off-diagonal elements and zeros on the diagonal
        self.coupling = (torch.ones(n_neurons, n_neurons, device=device) - torch.eye(n_neurons, device=device)) * (J / n_neurons)
    
    def cubic_nonlinearity(self, v):
        return v * (1 - v) * (v - self.a)
    
    def forward(self, I, step):
        # Brownian motion term for the noise
        dW = torch.randn(self.n_neurons, device=device) * torch.sqrt(torch.tensor(self.dt).to(device))
        
        # dv/dt = f(v) - w + J/n * sum(v_j - v_i) + I + sigma*dW
        # dw/dt = epsilon * (b*v - w)
        dv = (self.cubic_nonlinearity(self.v) - self.w + torch.matmul(self.coupling, self.v) + I + 
              self.sigma * dW) * self.dt
        dw = self.epsilon * (self.b * self.v - self.w) * self.dt
        
        # Update the variables v and w
        self.v += dv
        self.w += dw

        return self.v, self.w

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Updated parameters as provided
a = 4.0         # Controls excitability
b = 4.0         # Controls coupling between voltage and recovery variable
epsilon = 0.01  # Time-scale ratio of the recovery variable
J = 1.5         # Coupling strength
sigma = 1.5     # Level of noise
I_peak = 1.0    # Peak input current
omega = 1.0     # Frequency of the input current (rad/s)
n_neurons = 4000  # Number of neurons in the network
T = 1000         # Number of simulation steps
dt = 0.01        # Time step for the simulation


# Initialize the network
fhn_network = FitzHughNagumoNetwork(n_neurons, a, b, J, epsilon, sigma, dt).to(device)

# Input current using cosine function
I = I_peak * torch.cos(omega * torch.arange(T, device=device) * dt)

# Record the membrane potentials and recovery variables over time
v_rec = torch.zeros(T, n_neurons, device=device)
w_rec = torch.zeros(T, n_neurons, device=device)

# Simulate for T time steps
for step in range(T):
    current_input = I[step].expand(n_neurons)  # Current input is the same for all neurons at each step
    v, w = fhn_network(current_input, step)
    v_rec[step] = v
    w_rec[step] = w

# The data is now stored in v_rec and w_rec


In [7]:
v_rec

tensor([[ 2.3663e-03, -9.7379e-03, -1.9094e-01,  ...,  7.9870e-02,
          1.5942e-04,  2.7146e-02],
        [ 1.3386e-02,  4.9621e-04, -1.7072e-01,  ...,  8.6687e-02,
          1.1223e-02,  3.6544e-02],
        [ 2.5697e-02,  8.6225e-03, -1.5140e-01,  ...,  9.5532e-02,
          2.0740e-02,  4.4985e-02],
        ...,
        [ 4.2939e+00,  4.2892e+00,  4.2905e+00,  ...,  4.2893e+00,
          4.2918e+00,  4.2975e+00],
        [ 4.2919e+00,  4.2901e+00,  4.2936e+00,  ...,  4.2892e+00,
          4.2923e+00,  4.2981e+00],
        [ 4.2918e+00,  4.2915e+00,  4.2925e+00,  ...,  4.2889e+00,
          4.2941e+00,  4.2986e+00]])