<a href="https://colab.research.google.com/github/Maya7991/gsc_classification/blob/main/test_aifes_conv_lif.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install snntorch --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import spikeplot as splt
import torch.nn.functional as F

In [3]:
# Helper to cleanly print 2D tensors
def print_2d_tensor(tensor, title=""):
    if title:
        print(title)
    arr = tensor.squeeze().detach().cpu().numpy()
    for row in arr:
        print("  ", ["{:6.2f}".format(x) for x in row])
    print()

In [7]:
# Set random seed for reproducibility
torch.manual_seed(0)
beta = 1.0
T = 2  # timesteps
N = 1  # batch size
C_in = 2
H = W = 3
F = 2

# Input: [T, N, C, H, W] -> [2, 1, 1, 3, 3]
# input_data = torch.tensor([
#     [[[[1.0, 2.0, 3.0],
#        [4.0, 5.0, 6.0],
#        [7.0, 8.0, 9.0]]]],

#     [[[[9.0, 8.0, 7.0],
#        [6.0, 5.0, 4.0],
#        [3.0, 2.0, 1.0]]]]
# ])
input_data = torch.tensor([
    [
        [  # timestep 1
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],  # Channel 0
            [[10, 20, 30], [40, 50, 60], [70, 80, 90]]   # Channel 1
        ]
    ],
    [
        [  # timestep 2
            [[9, 8, 7], [6, 5, 4], [3, 2, 1]] ,  # Channel 0
            [[90, 80, 70], [60, 50, 40], [30, 20, 10]]  # Channel 1
        ]
    ]
], dtype=torch.float32)

# weights shape: (out_channels, in_channels, kernel_height, kernel_width)
weights = torch.tensor([
    [  # Filter 1
        [[1, 0],
         [0, -1]],
        [[17, -6],
         [-7, 4]]
    ]
    ,[  # Filter 2
        [[-2, 1],
         [6, 11]],
        [[-3, 5],
         [1, 13]]
    ]
], dtype=torch.float32)

In [8]:
# Manually define Conv2d layer with 1 filter of shape 1x2x2
conv = nn.Conv2d(in_channels=C_in, out_channels=F, kernel_size=2, stride=1, bias=False)
lif = snn.Leaky(beta=beta)

with torch.no_grad():
  conv.weight.copy_(weights)

In [9]:

# Track spikes and membrane potential
spk_out = []
mem_out = []

# Simulate over 2 timesteps (T=2)
for t in range(2):
    x_t = input_data[t]         # Shape: [1, 1, 3, 3]
    conv_out = conv(x_t)        # Shape: [1, 1, 2, 2]
    spk, mem = lif(conv_out)
    spk_out.append(spk)
    mem_out.append(mem)
    # print(conv_out)

# Stack outputs
spk_out = torch.stack(spk_out)  # Shape: [2, 1, 1, 2, 2]
mem_out = torch.stack(mem_out)  # Shape: [2, 1, 1, 2, 2]

In [14]:
# Print outputs
print("Convolution output at each timestep:")
for t in range(2):
    print(f"Time step {t}:\n", conv(input_data[t]))

Convolution output at each timestep:
Time step 0:
 tensor([[[[ -34.,   46.],
          [ 206.,  286.]],

         [[ 839., 1015.],
          [1367., 1543.]]]], grad_fn=<ConvolutionBackward0>)
Time step 1:
 tensor([[[[834., 754.],
          [594., 514.]],

         [[921., 745.],
          [393., 217.]]]], grad_fn=<ConvolutionBackward0>)


In [15]:
print("\nLIF membrane potential:")
print(mem_out)

print("\nLIF spike output:")
print(spk_out)


LIF membrane potential:
tensor([[[[[ -34.,   46.],
           [ 206.,  286.]],

          [[ 839., 1015.],
           [1367., 1543.]]]],



        [[[[ 800.,  799.],
           [ 799.,  799.]],

          [[1759., 1759.],
           [1759., 1759.]]]]], grad_fn=<StackBackward0>)

LIF spike output:
tensor([[[[[0., 1.],
           [1., 1.]],

          [[1., 1.],
           [1., 1.]]]],



        [[[[1., 1.],
           [1., 1.]],

          [[1., 1.],
           [1., 1.]]]]], grad_fn=<StackBackward0>)


In [13]:
print("=== Convolution Output at Each Timestep ===")
for t in range(T):
    print(f"Time step {t}:")
    # print_2d_tensor(conv(input_data[t]))
    # print_2d_tensor(spk_out[t], title="Spikes:")
    # print_2d_tensor(mem_out[t], title="Membrane:")
    # print()

print("=== LIF Membrane Potential at Each Timestep ===")
for t in range(T):
    print(f"Time step {t}:")
    print_2d_tensor(mem_out[t], title="Membrane:")

print("=== LIF Spike Output at Each Timestep ===")
for t in range(T):
    print(f"Time step {t}:")
    print_2d_tensor(spk_out[t], title="Spikes:")


=== Convolution Output at Each Timestep ===
Time step 0:
Time step 1:
=== LIF Membrane Potential at Each Timestep ===
Time step 0:
Membrane:


TypeError: unsupported format string passed to numpy.ndarray.__format__