# Import Dependencies 

In [None]:
!pip install torchvision

In [None]:
import torch
from torch import save, load
from PIL import Image
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd.function import Function
import numpy as np
import matplotlib.pyplot as plt

# Define the SNN architecture

In [None]:
class LIFNeuron(nn.Module):
    """Integrate-and-fire style activation applied element-wise to a feature map.

    Every element of the input is treated as the constant drive of one
    neuron. The membrane potential starts at zero and is pulled towards the
    input on each step; whenever it reaches ``threshold`` a spike is counted
    and ``threshold - reset`` is subtracted from the potential.

    Args:
        threshold: Firing threshold. It is also the divisor of the membrane
            update ``(x - v) / threshold``.
        reset: Level used to size the post-spike drop (``threshold - reset``).

    NOTE: spikes come from a comparison cast to float, so no gradient flows
    from the returned spike counts back to the input.
    """

    def __init__(self, threshold=1.0, reset=0.0):
        super().__init__()
        self.threshold = threshold
        self.reset = reset

    def _integrate_and_fire(self, x, v, spike, gain=None):
        """Apply one membrane update and return the new ``(v, spike)`` pair."""
        dv = (x - v) / self.threshold
        if gain is not None:
            dv = dv * gain
        v = v + dv
        # Compare against the configured threshold (not a fixed 1.0) so the
        # firing level matches the amount removed from v after a spike.
        fired = (v >= self.threshold).float()
        spike = spike + fired
        v = v - fired * (self.threshold - self.reset)
        return v, spike

    def forward(self, x, dt=1.0):
        """Simulate the neurons for ``dt`` and count their spikes.

        Args:
            x: 4-D tensor; the last two dimensions are flattened internally.
            dt: Simulation length. ``int(dt)`` full steps are run; a positive
                fractional remainder adds one more step whose update is
                scaled by ``exp(-remainder)``.

        Returns:
            Tuple ``(spike, v)``: ``spike`` holds the per-element spike count
            with the same shape as ``x``; ``v`` is the final membrane
            potential with the last two dimensions still flattened.
        """
        batch_size, channels, width, height = x.size()
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(batch_size, channels, width * height)

        # Membrane potential and spike counter both start at zero.
        v = torch.zeros_like(x)
        spike = torch.zeros_like(x)

        # Whole time steps.
        full_steps = int(dt)
        for _ in range(full_steps):
            v, spike = self._integrate_and_fire(x, v, spike)

        # Fractional remainder of dt, if any.
        remainder = dt - full_steps
        if remainder > 0:
            remainder = torch.tensor(remainder).to(x.device)
            v, spike = self._integrate_and_fire(x, v, spike, torch.exp(-remainder))

        # Restore the original layout for the spike counts only.
        spike = spike.reshape(batch_size, channels, width, height)

        return spike, v
    
class LIFLayer(nn.Module):
    """A 2-D convolution whose output is fed through a ``LIFNeuron``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.neuron = LIFNeuron()

    def forward(self, x, dt=1.0):
        """Convolve ``x``, run the neuron for ``dt`` and return its spikes.

        The membrane potential returned by the neuron is discarded.
        """
        spikes, _membrane = self.neuron(self.conv(x), dt)
        return spikes

class LIFClassifier(nn.Module):
    """Three ``LIFLayer`` convolutions followed by two linear layers (10 outputs).

    The 3x3 / padding-1 convolutions keep the spatial size and each of the
    two 2x2 average poolings halves it, so ``fc1`` (64 * 7 * 7 inputs)
    matches single-channel inputs that end up as 7x7 maps, e.g. 28x28 images.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = LIFLayer(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.AvgPool2d(kernel_size=2)
        self.conv2 = LIFLayer(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2)
        self.conv3 = LIFLayer(64, 64, kernel_size=3, stride=1, padding=1)
        # 64 channels on a 7x7 map (28 -> 14 -> 7 after the two poolings).
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return one row of 10 class scores per input sample.

        Args:
            x: Batch of single-channel images, shape (N, 1, H, W).

        Raises:
            RuntimeError: If the flattened feature size does not match
                ``fc1`` (i.e. the input has an unsupported spatial size).
        """
        x = self.conv1(x)
        x = F.relu(self.pool1(x))
        x = self.conv2(x)
        x = F.relu(self.pool2(x))
        x = self.conv3(x)
        # Flatten per sample. A fixed view(-1, 64 * 7 * 7) would silently
        # spread one over-sized image across several rows instead of failing.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [None]:
# Three spiking convolutional layers followed by two spiking linear layers.
# Input encoding used in this code: each pixel value represents the intensity of the corresponding pixel in the image.

# Define a custom spiking activation function
class SpikeActivation(Function):
    """Step activation with a pass-through gradient for positive inputs.

    Forward emits 1.0 where the input is strictly positive and 0.0 elsewhere.
    Backward forwards the incoming gradient unchanged where the input was
    positive and zeroes it everywhere else.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x > 0).to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # Block the gradient wherever the neuron did not fire.
        return grad_output.masked_fill(saved_input <= 0, 0)

# Define a spiking linear layer
class SpikeLinear(nn.Linear):
    """Linear layer whose output is binarised by ``SpikeActivation``.

    A learnable scalar ``threshold`` (initialised to 0.5) is subtracted from
    the affine output before the spike function is applied.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.threshold = nn.Parameter(torch.Tensor([0.5]))

    def forward(self, input):
        # nn.Linear.forward computes F.linear(input, weight, bias).
        pre_activation = super().forward(input)
        return SpikeActivation.apply(pre_activation - self.threshold)

# Define a spiking convolutional layer
class SpikeConv2d(nn.Conv2d):
    """2-D convolution whose output is binarised by ``SpikeActivation``.

    A learnable scalar ``threshold`` (initialised to 0.5) is subtracted from
    the convolution output before the spike function is applied.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.threshold = nn.Parameter(torch.Tensor([0.5]))

    def forward(self, input):
        # The parent forward runs the same zero-padded conv2d with this
        # layer's stride, padding, dilation and groups.
        pre_activation = super().forward(input)
        return SpikeActivation.apply(pre_activation - self.threshold)

class SNN(nn.Module):
    """CNN classifier whose 10 outputs are squashed into near-binary values.

    The body is an ordinary ReLU network (three 3x3 convolutions, two 2x2
    max-poolings, two linear layers). Only the output stage is spike-like:
    the logits are shifted by a learnable ``threshold``, scaled by their
    largest magnitude and passed through a steep sigmoid.

    NOTE: the scale is the maximum over the whole output tensor, so one
    sample's result depends on the other samples in the same batch.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Fully connected head (64 channels on a 7x7 map)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)

        # Learnable offset subtracted from the logits before squashing
        self.threshold = nn.Parameter(torch.Tensor([0.5]), requires_grad=True)

    def forward(self, x):
        """Return values in (0, 1), one row of 10 per input sample.

        Args:
            x: Single-channel images whose feature map is 7x7 after the two
                poolings (e.g. 28x28 inputs).

        Raises:
            RuntimeError: If the flattened feature size does not match
                ``fc1`` (unsupported input size).
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))

        # Merge (C, H, W) of every sample into one feature axis. A fixed
        # view(-1, 64 * 7 * 7) would silently turn one over-sized image into
        # several rows instead of raising.
        x = x.flatten(start_dim=-3)
        x = x.reshape(-1, x.size(-1))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        # Spike-like output stage
        x = x - self.threshold
        # Guard the normalisation: an all-zero tensor would otherwise give
        # 0 / 0 = NaN. Any non-zero maximum is left untouched by the clamp.
        scale = torch.max(torch.abs(x)).clamp_min(torch.finfo(x.dtype).tiny)
        x = x / scale
        x = torch.sigmoid(100 * (x - 0.2))

        return x

'''class SNN(nn.Module):
    def __init__(self):
        super(SNN, self).__init__()

        # Spiking convolutional layers
        self.conv1 = SpikeConv2d(1, 32, kernel_size=3, stride=1, padding=1)
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        self.conv2 = SpikeConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        self.conv3 = SpikeConv2d(64, 64, kernel_size=3, stride=1, padding=1)
        nn.init.kaiming_uniform_(self.conv3.weight, nonlinearity='relu')

        # Spiking linear layers
        self.fc1 = SpikeLinear(64 * 7 * 7, 512)
        nn.init.kaiming_uniform_(self.fc1.weight, nonlinearity='relu')
        self.fc2 = SpikeLinear(512, 10)
        nn.init.kaiming_uniform_(self.fc2.weight, nonlinearity='relu')

    def forward(self, x):
        x = torch.poisson(x)
        x = self.conv1(x)
        x = F.avg_pool2d(x, 2)
        x = self.conv2(x)
        x = F.avg_pool2d(x, 2)
        x = self.conv3(x)
        x = x.view(-1, 64 * 7 * 7)
        x = self.fc1(x)
        x = self.fc2(x)
        return x'''

# Define the spiking neural network
'''class SNN(nn.Module):
    def __init__(self):
        super(SNN, self).__init__()

        # Spiking convolutional layers
        self.conv1 = SpikeConv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = SpikeConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = SpikeConv2d(64, 64, kernel_size=3, stride=1, padding=1)

        # Spiking linear layers
        self.fc1 = SpikeLinear(64 * 7 * 7, 512)
        self.fc2 = SpikeLinear(512, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.avg_pool2d(x, 2)
        x = self.conv2(x)
        x = F.avg_pool2d(x, 2)
        x = self.conv3(x)
        x = x.view(-1, 64 * 7 * 7)
        x = self.fc1(x)
        x = self.fc2(x)
        return x'''

In [None]:
# Based on TPW-SDP, Weight Normalization to simulate Spiking


class SNN_TPWSDP(nn.Module):
    """CNN classifier with an extra bias-free convolution after every conv layer.

    Layout: three (conv -> ReLU -> "tpwsdp" conv -> ReLU) stages with a 2x2
    max-pooling after the first two, followed by two linear layers that
    produce 10 class scores. The ``tpwsdp*`` layers are plain 3x3
    ``nn.Conv2d`` modules without bias.
    """

    def __init__(self):
        super().__init__()

        # Convolutional layers, each paired with a bias-free tpwsdp layer
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.tpwsdp1 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.tpwsdp2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.tpwsdp3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # Pooling layers
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected layers (64 channels on a 7x7 map)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return one row of 10 class scores (raw logits) per input sample.

        Args:
            x: Single-channel images whose feature map is 7x7 after the two
                poolings (e.g. 28x28 inputs).

        Raises:
            RuntimeError: If the flattened feature size does not match
                ``fc1`` (unsupported input size).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.tpwsdp1(x))

        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = F.relu(self.tpwsdp2(x))

        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.tpwsdp3(x))

        # Merge (C, H, W) of every sample into one feature axis. A fixed
        # view(-1, 64 * 7 * 7) would silently turn one over-sized image into
        # several rows instead of raising.
        x = x.flatten(start_dim=-3)
        x = x.reshape(-1, x.size(-1))

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x


class TPWSDP(nn.Module):
    """Scale a tensor by a sigmoid gate computed from its own mean.

    The mean is taken over every dimension after the first two (kept as
    size-1 dimensions), shifted by ``beta``, multiplied by the learnable
    ``alpha`` (initialised to 0.5) and squashed with a sigmoid. The input is
    then multiplied by that gate.

    Args:
        beta: Constant subtracted from the mean before gating.
    """

    def __init__(self, beta=0.001):
        super().__init__()
        self.beta = beta
        self.alpha = nn.Parameter(torch.Tensor([0.5]), requires_grad=True)

    def forward(self, x):
        """Return ``x`` multiplied by its gate.

        Raises:
            ValueError: If ``x`` is neither 3- nor 4-dimensional.
        """
        rank = x.dim()
        if rank not in (3, 4):
            raise ValueError(f"Expected input tensor to have 3 or 4 dimensions, but got {x.dim()} dimensions.")

        reduce_dims = (2, 3) if rank == 4 else 2
        mean_activity = x.mean(dim=reduce_dims, keepdim=True)
        gate = torch.sigmoid(self.alpha * (mean_activity - self.beta))
        return x * gate

In [None]:
# The S-CNN consists of three convolutional and three TPW-SDP layers, two pooling layers, and two fully connected layers.
# The TPW-SDP module is defined separately and is used in the convolutional and fully connected layers...
# ...to modify the synaptic weights of the SNN based on the temporal relationship between the spikes in the input and output layers.

In [None]:
# Uses a simulation of the temporal-precision-weighted spike-dependent plasticity (TPW-SDP) algorithm to train the SNN.
# TPW-SDP is a form of unsupervised learning that adjusts the synaptic weights between neurons based on the temporal precision and timing of the spikes.

# Define the training parameters

In [None]:
# Training hyper-parameters.
# batch_size: better results were observed with batch_size = 1.
num_epochs, batch_size = 10, 16
learning_rate = 0.001

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the MNIST dataset and Create a data loader

In [None]:
# MNIST training split, downloaded into ./data on first use; every image is
# converted to a tensor by ToTensor.
train_dataset = torchvision.datasets.MNIST(
    root='data',
    download=True,
    train=True,
    transform=transforms.ToTensor(),
)
# shuffle=True re-orders the samples every epoch so the mini-batches are not
# presented to the optimizer in the same fixed order each time.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define the Loss Function, Optimizer and Instance

In [None]:
# Model under training, placed on the selected device.
snn = LIFClassifier()
snn = snn.to(device)

# CrossEntropyLoss turns the raw model outputs into a probability
# distribution over the classes (softmax) and then takes the negative
# log-likelihood of the true class under that distribution.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters.
optimizer = torch.optim.Adam(params=snn.parameters(), lr=learning_rate)

# Training the SNN

In [None]:
# Skip this cell if you load the pretrained weights instead

# Supervised training: one optimizer step per mini-batch, and one log line
# per epoch with the mean loss over all samples seen in that epoch.
for epoch in range(num_epochs):
    snn.train()

    # Accumulate on the device so the loss is only copied to the host once
    # per epoch.
    loss_sum = torch.zeros((), device=device)
    samples_seen = 0

    for images, targets in train_loader:

        images = images.to(device)
        targets = targets.to(device)

        outputs = snn(images)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size to
        # get a per-sample average even when the last batch is smaller.
        batch_count = targets.size(0)
        loss_sum += loss.detach() * batch_count
        samples_seen += batch_count

    # An empty loader yields NaN instead of failing on an undefined loss.
    epoch_loss = loss_sum.item() / samples_seen if samples_seen else float('nan')
    print ('Epoch [{}/{}], Loss: {}'.format(epoch+1, num_epochs, epoch_loss))

# Save the model

In [None]:
# Skip this cell if you load the pretrained weights instead

# Write only the learned parameters (the state dict) to SNN.pt.
state = snn.state_dict()
with open('SNN.pt', 'wb') as checkpoint_file:
    save(state, checkpoint_file)

# Load the trained model's weights

In [None]:
# Restore the parameters saved in SNN.pt into the existing model.
# map_location moves the stored tensors onto the current device, so a
# checkpoint written on a GPU can still be opened on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
with open('SNN.pt', 'rb') as f:
    snn.load_state_dict(load(f, map_location=device))

# Test an Image

In [None]:
# Classify one image from disk and print the index of the highest score.
with Image.open("Z:\\Jupyter Scripts\\PyTorch Object Detection\\PTODSNN\\Pytorch\\data\\test\\img_3.jpg") as raw_img:
    # The model's first convolution takes a single channel, so force
    # greyscale; convert() also loads the pixels before the file is closed.
    img = raw_img.convert("L")

# The linear head is sized for 7x7 feature maps, i.e. 28x28 inputs.
if img.size != (28, 28):
    img = img.resize((28, 28))

img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)

# Inference only: evaluation mode and no autograd bookkeeping.
snn.eval()
with torch.no_grad():
    print(torch.argmax(snn(img_tensor)))