In [5]:
import torch
import torch.nn as nn

class IzhikevichLayer(nn.Module):
    def __init__(self, size, a=0.02, b=0.2, c=-65, d=8, dt=1.0):
        super(IzhikevichLayer, self).__init__()
        
        # Izhikevich parameters
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.dt = dt
        self.size = size
        
        # Initial values for v and u
        self.v = torch.full((self.size,), self.c, dtype=torch.float32)
        self.u = torch.full((self.size,), self.b * self.c, dtype=torch.float32)
        
    def forward(self, I):
        # Izhikevich equations for a layer of neurons
        dv = (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + I) * self.dt
        du = (self.a * (self.b * self.v - self.u)) * self.dt
        
        self.v += dv
        self.u += du
        
        # Record spikes
        spikes = (self.v >= 30).float()
        
        # If the neurons fire
        self.v[self.v >= 30] = self.c
        self.u[self.v >= 30] += self.d
        
        return spikes

class SNN(nn.Module):
    def __init__(self,input_size=10):
        super(SNN, self).__init__()
        self.layer1 = IzhikevichLayer(size=input_size)  # 10 neurons
        self.weight = nn.Parameter(torch.randn(input_size, 5))  # Weight matrix to transform spikes from 10 to 5 neurons
        self.layer2 = IzhikevichLayer(size=5)   # 5 neurons
        
    def forward(self, I):
        spikes1 = self.layer1(I)
        
        # Using a weight matrix to project spikes from the first layer to the second layer
        projected_spikes = torch.matmul(spikes1.unsqueeze(0), self.weight).squeeze()
        
        spikes2 = self.layer2(projected_spikes)
        
        return spikes2
    
# Test the new network
network_direct_spikes = SNN()
output_direct_spikes = network_direct_spikes(torch.full((10,), 10.0))  # Apply input current of 10 to each neuron in the first layer

output_direct_spikes

tensor([0., 0., 0., 0., 0.])

In [7]:
# Modify the edge_stim function to handle the theta value as a torch tensor
def generate_stimuli(x0, y0, F=5, pixel_h=320, pixel_w=240, theta=0, w=1):
    x = torch.arange(pixel_h).float().reshape(-1, 1) - x0
    y = torch.arange(pixel_w).float().reshape(1, -1) - y0
    theta_tensor = torch.tensor(theta)
    u = x * torch.cos(theta_tensor) + y * torch.sin(theta_tensor)
    v = -x * torch.sin(theta_tensor) + y * torch.cos(theta_tensor)
    return torch.exp(-(u ** 2 + (v / w) ** 2) / (2 * F ** 2))

# Generate the visual stimuli again
stimulation = edge_stim(angle=45, num_stim=1, pixel_h=320, pixel_w=240, plot_stimuli=True)

stimulation.shape

TypeError: edge_stim() got an unexpected keyword argument 'angle'

In [3]:
output_direct_spikes = network_direct_spikes(torch.full((10,), 100.0))  # Apply input current of 10 to each neuron in the first layer

output_direct_spikes

tensor([0., 0., 0., 0., 0.])

In [4]:
weights = network_direct_spikes.weight

weights

Parameter containing:
tensor([[-2.4717,  1.5611, -0.2171,  0.7864,  0.6821],
        [ 0.4987, -0.5248,  0.6906,  0.7741,  1.4685],
        [-0.1269, -2.1137, -0.0367,  1.0572, -1.4493],
        [-0.8473, -0.0659,  0.6453, -0.0652, -0.9526],
        [-0.1136,  0.7116, -0.3975, -1.7800, -0.7283],
        [-0.5156,  0.4596, -0.9580,  1.8983,  1.7530],
        [ 0.0429, -0.8344,  0.5445,  0.7898,  1.0992],
        [-0.7353, -0.5417, -0.3213, -0.6271, -0.5443],
        [ 1.4306,  0.6735,  0.1599,  0.4299, -1.2846],
        [ 1.2354,  0.4694,  0.2001,  0.5502, -1.3301]], requires_grad=True)