In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.nn.parameter import Parameter
import torch.nn.functional as tf
import torchvision
import numpy as np
from SpykeTorch import snn
from SpykeTorch import functional as sf
from SpykeTorch import visualization as vis
from SpykeTorch import utils
from torchvision import transforms

use_cuda = False

### SNN Module

In [2]:
class SNN(nn.Module):
    def __init__(self):
        super(SNN, self).__init__()

        self.conv1 = snn.Convolution(in_channels=2, out_channels=32, kernel_size=5, padding=2, weight_mean=0.8, weight_std=0.05)
        self.conv1_t = 10
        self.k1 = 5
        self.r1 = 2

        self.conv2 = snn.Convolution(32, 150, 2, 1, 0.8, 0.05)
        self.conv2_t = 1
        self.k2 = 8
        self.r2 = 1

        self.tconv1 = snn.TransposeConvolution(in_channels=32,out_channels=2,kernel_size=5,padding=2)
        self.tconv2 = snn.TransposeConvolution(150,32,2,1)

        self.stdp1 = snn.STDP(self.conv1, (0.004, -0.003))
        self.stdp2 = snn.STDP(self.conv2, (0.004, -0.003))
        self.max_ap = Parameter(torch.Tensor([0.15]))

        self.ctx = {"input_spikes":None, "potentials":None, "output_spikes":None, "winners":None}
        self.spk_cnt1 = 0
        self.spk_cnt2 = 0

    def forward(self, input, max_layer):
        spk_in = input.float()

        if self.training:
            # convolution 1
            pot = self.conv1(spk_in)
            spk, pot = sf.fire(pot, self.conv1_t, True)

            # apply stdp to convolution 1
            if max_layer == 1:
                self.spk_cnt1 += 1
                if self.spk_cnt1 >= 500:
                    self.spk_cnt1 = 0
                    ap = torch.tensor(self.stdp1.learning_rate[0][0].item(), device=self.stdp1.learning_rate[0][0].device) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -0.75
                    self.stdp1.update_all_learning_rate(ap.item(), an.item())
                pot = sf.pointwise_inhibition(pot)
                spk = pot.sign()
                winners = sf.get_k_winners(pot, self.k1, self.r1, spk)
                self.ctx["input_spikes"] = input
                self.ctx["potentials"] = pot
                self.ctx["output_spikes"] = spk
                self.ctx["winners"] = winners
                return spk, pot
            
            # max pool 1
            spk_pooled, pool1_indices = tf.max_pool2d(spk, 2, 2,return_indices=True)

            # print(f"input conv1:\t{spk_in.shape}")
            # print(f"output conv1:\t{spk.shape}")

            # print(f"pool1:\t\t{spk_pooled.shape}")

            # convolution 2
            spk_in = spk_pooled
            pot = self.conv2(spk_in)
            spk, pot = sf.fire(pot, self.conv2_t, True)

            # apply stdp to convolution 2
            if max_layer == 2:
                self.spk_cnt2 += 1
                if self.spk_cnt2 >= 500:
                    self.spk_cnt2 = 0
                    ap = torch.tensor(self.stdp2.learning_rate[0][0].item(), device=self.stdp2.learning_rate[0][0].device) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -0.75
                    self.stdp2.update_all_learning_rate(ap.item(), an.item())
                pot = sf.pointwise_inhibition(pot)
                spk = pot.sign()
                winners = sf.get_k_winners(pot, self.k2, self.r2, spk)
                self.ctx["input_spikes"] = spk_in
                self.ctx["potentials"] = pot
                self.ctx["output_spikes"] = spk
                self.ctx["winners"] = winners
                return spk, pot
            
            # print(f"input conv2:\t{spk_in.shape}")
            # print(f"output conv2:\t{spk.shape}")

            # load weights for transpose convolutions
            self.tconv2.load_weight(torch.transpose(self.conv2.weight,2,3))
            self.tconv1.load_weight(torch.transpose(self.conv1.weight,2,3))
            
            # transpose convolution 2
            spk_in = spk
            pot = self.tconv2(spk_in)
            spk, pot = sf.fire(pot, self.conv2_t, True)

            # max unpool 1
            spk_unpooled = tf.max_unpool2d(spk, pool1_indices, 2, 2)

            # print(f"input tconv2:\t{spk_in.shape}")
            # print(f"output conv2:\t{spk.shape}")
            
            # print(f"unpool2:\t{spk_unpooled.shape}")

            # transpose convolution 1
            spk_in = spk_unpooled
            pot = self.tconv1(spk_in)
            spk = sf.fire(pot, self.conv1_t)

            # print(f"input tconv2:\t{spk_in.shape}")
            # print(f"output conv2:\t{spk.shape}")

            return spk, pot
        else:
            # convolution 1
            pot = self.conv1(spk_in)
            spk, pot = sf.fire(pot, self.conv1_t, True)
            if max_layer == 1:
                return spk, pot
            
            # max pool 1
            spk_pooled, pool1_indices = tf.max_pool2d(spk, 2, 2,return_indices=True)

            # convolution 2
            spk_in = spk_pooled
            pot = self.conv2(spk_in)
            spk, pot = sf.fire(pot, self.conv2_t, True)
            if max_layer == 2:
                return spk, pot
            
            # load weights for transpose convolutions
            self.tconv2.load_weight(torch.transpose(self.conv2.weight,2,3))
            self.tconv1.load_weight(torch.transpose(self.conv1.weight,2,3))
            
            # transpose convolution 2
            spk_in = spk
            pot = self.tconv2(spk_in)
            spk, pot = sf.fire(pot, self.conv2_t, True)

            # max unpool 1
            spk_unpooled = tf.max_unpool2d(spk, pool1_indices, 2, 2)

            # transpose convolution 1
            spk_in = spk_unpooled
            pot = self.tconv1(spk_in)
            spk = sf.fire(pot, self.conv1_t)

            return spk
    
    def stdp(self, layer_idx):
        if layer_idx == 1:
            self.stdp1(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])
        if layer_idx == 2:
            self.stdp2(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

In [3]:
# import torch
# import torch.nn as nn
# import matplotlib.pyplot as plt
# import torchvision


# # Define some random input data
# input_data = torch.randn(1, 2, 28, 28)  # Batch size of 1, 1 channel, 5x5 input
# plt.imshow(input_data[0][0])
# plt.show()
# print(input_data.shape)

# # Define a 2D convolution layer with padding
# conv = snn.Convolution(in_channels=2, out_channels=32, kernel_size=5, padding=2, weight_mean=0.8, weight_std=0.05)

# # Apply the convolution
# conv_output = conv(input_data)
# print(conv_output.shape)

# # Define a transpose convolution layer with the same parameters
# tconv = snn.TransposeConvolution(in_channels=32,out_channels=2,kernel_size=5,padding=2)
# tconv.load_weight(torch.transpose(conv.weight, 2,3))
# # print(conv.weight[0][0],'\n',tconv.weight[0][0])

# # Apply the transpose convolution
# trans_conv_output = tconv(conv_output)
# print(trans_conv_output.shape)
# plt.imshow(trans_conv_output[0][0])
# plt.show()

# # Check if the output of the transpose convolution matches the input
# print(torch.allclose(input_data, trans_conv_output, atol=1e-5))

### Training/Testing functions

In [4]:
def train_unsupervise(network, data, layer_idx):
    network.train()
    for i in range(len(data)):
        data_in = data[i]
        if use_cuda:
            data_in = data_in.cuda()
        network(data_in, layer_idx)
        network.stdp(layer_idx)

def test(network, data, target, layer_idx):
    network.eval()
    ans = [None] * len(data)
    t = [None] * len(data)
    for i in range(len(data)):
        data_in = data[i]
        if use_cuda:
            data_in = data_in.cuda()
        output,_ = network(data_in, layer_idx).max(dim = 0)
        ans[i] = output.reshape(-1).cpu().numpy()
        t[i] = target[i]
    return np.array(ans), np.array(t)

### Convert to Spike Wave

In [5]:
class S1Transform:
    def __init__(self, filter, timesteps = 15):
        self.to_tensor = transforms.ToTensor()
        self.filter = filter
        self.temporal_transform = utils.Intensity2Latency(timesteps)
        self.cnt = 0
    def __call__(self, image):
        if self.cnt % 1000 == 0:
            print(self.cnt)
        self.cnt+=1
        image = self.to_tensor(image) * 255
        image.unsqueeze_(0)
        image = self.filter(image)
        image = sf.local_normalization(image, 8)
        temporal_image = self.temporal_transform(image)
        return temporal_image.sign().byte()

kernels = [ utils.DoGKernel(7,1,2),
            utils.DoGKernel(7,2,1),]
filter = utils.Filter(kernels, padding = 3, thresholds = 50)
s1 = S1Transform(filter)

### Data Prep

In [6]:
data_root = "data"
MNIST_train = utils.CacheDataset(torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform = s1))
MNIST_test = utils.CacheDataset(torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform = s1))
MNIST_loader = DataLoader(MNIST_train, batch_size=len(MNIST_train), shuffle=False)
MNIST_testLoader = DataLoader(MNIST_test, batch_size=len(MNIST_test), shuffle=False)

### Train SNN

In [7]:
kheradpisheh = SNN()
if use_cuda:
    kheradpisheh.cuda()

In [8]:
# data,target = next(iter(MNIST_loader))
# train_unsupervise(kheradpisheh, data, 1)

In [9]:
# train_unsupervise(kheradpisheh, data, 3)

In [10]:
# Training The First Layer
print("Training the first layer")
if os.path.isfile("saved_l1.net"):
    kheradpisheh.load_state_dict(torch.load("saved_l1.net"))
else:
    for epoch in range(2):
        print("Epoch", epoch)
        iter = 0
        for data,_ in MNIST_loader:
            print("Iteration", iter)
            train_unsupervise(kheradpisheh, data, 1)
            print("Done!")
            iter+=1
    torch.save(kheradpisheh.state_dict(), "saved_l1.net")

Training the first layer
Epoch 0
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
Iteration 0
HERE1
(15, 2, 28, 28) (15, 32, 28, 28) (15, 32, 28, 28) (5, 3)
YO1
YO2
YO3
torch.Size([2, 4, 5]) torch.Size([5, 5])


RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1

In [None]:
# Training The Second Layer
print("Training the second layer")
if os.path.isfile("saved_l2.net"):
    kheradpisheh.load_state_dict(torch.load("saved_l2.net"))
for epoch in range(20):
    print("Epoch", epoch)
    iter = 0
    for data,_ in MNIST_loader:
        print("Iteration", iter)
        train_unsupervise(kheradpisheh, data, 2)
        print("Done!")
        iter+=1
torch.save(kheradpisheh.state_dict(), "saved_l2.net")

### Determine Performance

In [None]:
from sklearn.svm import LinearSVC

In [None]:
# Classification
# Get train data
for data,target in MNIST_loader:
    train_X, train_y = test(kheradpisheh, data, target, 2)
    

# Get test data
for data,target in MNIST_testLoader:
    test_X, test_y = test(kheradpisheh, data, target, 2)

# SVM
clf = LinearSVC(C=2.4)
clf.fit(train_X, train_y)
predict_train = clf.predict(train_X)
predict_test = clf.predict(test_X)

def get_performance(X, y, predictions):
    correct = 0
    silence = 0
    for i in range(len(predictions)):
        if X[i].sum() == 0:
            silence += 1
        else:
            if predictions[i] == y[i]:
                correct += 1
    return (correct/len(X), (len(X)-(correct+silence))/len(X), silence/len(X))

print(get_performance(train_X, train_y, predict_train))
print(get_performance(test_X, test_y, predict_test))