# Import Dependencies 

In [None]:
!pip install torchvision

In [None]:
import torch
from torch import save, load
from PIL import Image
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt

# Define the SNN architecture

In [None]:
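# Conceptual basis: TPW-SDP learning, Leaky Integrate-and-Fire (LIF) neurons, and rate coding for input encoding.
# Note: as written, the classes below use standard ReLU convolutional/linear layers; no LIF dynamics or rate encoding is implemented in this cell.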


class SNN_TPWSDP(nn.Module):
    """Convolutional MNIST classifier built from conv + "tpwsdp" layer pairs.

    Despite the name, every layer here is a standard ``nn.Conv2d`` /
    ``nn.Linear`` followed by ReLU: no spiking (LIF) dynamics, rate coding, or
    ``TPWSDP`` gating module is applied in this class. The ``tpwsdp*``
    attributes are ordinary bias-free 3x3 convolutions.

    Expects input of shape (batch, 1, 28, 28): the 3x3/padding=1 convolutions
    keep the spatial size, and the two 2x2 max-pools reduce 28x28 to 7x7, which
    is what ``fc1`` (64 * 7 * 7 inputs) is sized for.

    Returns:
        Raw (un-normalized) class logits of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 32 channels, then a bias-free 32 -> 32 conv
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.tpwsdp1 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False)

        # Stage 2: 32 -> 64 channels, then a bias-free 64 -> 64 conv
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.tpwsdp2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # Stage 3: 64 -> 64 channels, then a bias-free 64 -> 64 conv
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.tpwsdp3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # Each pool halves height and width
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head: flattened 64x7x7 feature map -> 512 -> 10 logits
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits for a batch of single-channel 28x28 images."""
        # Stage 1 (for a 28x28 input): 28x28 -> pool -> 14x14
        x = F.relu(self.conv1(x))
        x = F.relu(self.tpwsdp1(x))
        x = self.pool1(x)

        # Stage 2: 14x14 -> pool -> 7x7
        x = F.relu(self.conv2(x))
        x = F.relu(self.tpwsdp2(x))
        x = self.pool2(x)

        # Stage 3: spatial size unchanged
        x = F.relu(self.conv3(x))
        x = F.relu(self.tpwsdp3(x))

        # Flatten per sample, keeping the batch dimension explicit. The old
        # view(-1, 64 * 7 * 7) silently folded a wrong-sized input into extra
        # "samples"; now such an input fails loudly at fc1 instead.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        return self.fc2(x)


class TPWSDP(nn.Module):
    """Scale each feature slice by a learnable sigmoid gate of its mean value.

    For a 4-D input the mean is taken over dims 2 and 3; for a 3-D input over
    dim 2 (the reduced dims are kept with size 1 so the gate broadcasts back).
    The gate is ``sigmoid(alpha * (mean - beta))`` and is multiplied into the
    input elementwise.

    Args:
        beta: Fixed (non-learnable) offset subtracted from the mean before the
            sigmoid.

    Raises:
        ValueError: if the input is not 3-D or 4-D.

    Note: ``SNN_TPWSDP`` in this notebook does not instantiate this module;
    its ``tpwsdp*`` layers are plain ``nn.Conv2d``.
    """

    def __init__(self, beta=0.001):
        super().__init__()
        self.beta = beta
        # Learnable gate slope, stored as a 1-element tensor (shape [1]).
        self.alpha = nn.Parameter(torch.tensor([0.5]))

    def forward(self, x):
        ndim = x.dim()
        if ndim not in (3, 4):
            raise ValueError(f"Expected input tensor to have 3 or 4 dimensions, but got {ndim} dimensions.")
        # Average over every axis after the first two; keepdim for broadcasting.
        reduce_dims = tuple(range(2, ndim))
        mean_act = x.mean(dim=reduce_dims, keepdim=True)
        gate = torch.sigmoid(self.alpha * (mean_act - self.beta))
        return gate * x

In [None]:
# The SNN consists of three convolutional and three TPW-SDP layers, two pooling layers, and two fully connected layers.
# The TPW-SDP module (`TPWSDP`) is defined separately and is intended to be used in the convolutional and fully connected layers
# to modify the synaptic weights of the SNN based on the temporal relationship between the spikes in the input and output layers.
# Note: as written, `SNN_TPWSDP` never instantiates `TPWSDP`; its `tpwsdp1`/`tpwsdp2`/`tpwsdp3` layers are plain bias-free `nn.Conv2d` layers.

In [None]:
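# Intended approach: a temporal-precision-weighted spike-dependent plasticity (TPW-SDP) algorithm to train the SNN,
# i.e. a form of unsupervised learning that adjusts the synaptic weights between neurons based on the temporal precision and timing of the spikes.
# Note: the training loop below actually trains the network with supervised backpropagation (CrossEntropyLoss + Adam), not a plasticity rule.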

# Define the training parameters

In [None]:
# Training hyperparameters
num_epochs = 10
# Author's note: batch_size = 1 reportedly gave better results (not verified here).
batch_size = 16
learning_rate = 0.001
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the MNIST dataset and create a data loader

In [None]:
# MNIST training split; ToTensor() converts each PIL image into a float tensor.
train_dataset = torchvision.datasets.MNIST(
    root='data', download=True, train=True, transform=transforms.ToTensor()
)
# shuffle=True: draw a fresh random sample order every epoch. Without it, each
# epoch replays the dataset in the same fixed order, which hurts SGD/Adam training.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Create the Model Instance, Loss Function and Optimizer

In [None]:
# Build the network and move its parameters onto the selected device.
snn = SNN_TPWSDP()
snn = snn.to(device)

# nn.CrossEntropyLoss takes raw logits: internally it applies log-softmax and
# then the negative log-likelihood of the true class, so the model's forward()
# should not apply softmax itself.
criterion = nn.CrossEntropyLoss()

# Adam over all trainable parameters of the model.
optimizer = torch.optim.Adam(params=snn.parameters(), lr=learning_rate)

# Training the SNN

In [None]:
# Skip this cell if you load the pretrained weights instead.

snn.train()  # ensure training mode (matters if dropout/batch-norm are ever added)
for epoch in range(num_epochs):
    # Accumulated on-device to avoid a host sync on every step.
    running_loss = 0.0
    seen = 0
    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)

        outputs = snn(images)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss (default reduction='mean') returns the batch mean,
        # so weight it by batch size to get an exact per-sample epoch average.
        batch = targets.size(0)
        running_loss = running_loss + loss.detach() * batch
        seen += batch

    # Report the epoch-average loss instead of only the (noisy) last-batch
    # loss; also avoids a NameError if the loader yields no batches.
    epoch_loss = float(running_loss) / seen if seen else float('nan')
    print('Epoch [{}/{}], Loss: {}'.format(epoch + 1, num_epochs, epoch_loss))

# Save the model

In [None]:
# Skip this cell if you load the pretrained weights instead.

# Persist only the learned parameters (the state_dict), not the module object.
with open('SNN_TPW-SDP_model_state.pt', 'wb') as fh:
    save(obj=snn.state_dict(), f=fh)

# Load the trained model's weights

In [None]:
# Restore the saved weights into the existing `snn` instance.
# map_location remaps tensors saved on another device (e.g. a checkpoint
# written on a GPU, opened on a CPU-only machine) onto the current `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch offers weights_only=True to restrict this).
with open('SNN_TPW-SDP_model_state.pt', 'rb') as f:
    snn.load_state_dict(load(f, map_location=device))

# Test an Image

In [None]:
# Classify a single image file with the trained network.
# Training used single-channel 28x28 tensors produced by ToTensor(), so the
# image is converted to grayscale and resized to 28x28 first. Otherwise an
# RGB or differently-sized image would not match conv1 (1 input channel) or
# fc1 (sized for a 7x7 map after two pools).
# NOTE(review): MNIST digits are light-on-dark; a dark-on-light photo may need
# inverting before prediction. Confirm against your input images.
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])

# Context manager closes the underlying file handle once the image is read.
with Image.open("Filepath to image\\img_3.jpg") as img:
    img_tensor = preprocess(img).unsqueeze(0).to(device)

snn.eval()  # inference mode
with torch.no_grad():  # no autograd graph is needed for prediction
    prediction = torch.argmax(snn(img_tensor), dim=1)

print(prediction.item())