In [1]:
import os
import sys
from collections import OrderedDict

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.nn import init

import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

from demo_utils import snn_utils as snnu
from demo_utils import demo_utils as du

batch_size = 256

# Single place for the dataset location (was duplicated for train/test).
# Set the FMNIST_ROOT environment variable to use a different directory
# without editing the notebook; the default keeps the original path.
DATA_ROOT = os.environ.get('FMNIST_ROOT', 'D:/SoftProject/Python/PyTorch_demo/Datasets/FashionMNIST')

# Resize/Grayscale are presumably no-ops for FashionMNIST (natively 28x28
# grayscale) but keep the pipeline valid if another dataset is swapped in.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),  # float tensor in [0, 1], shape (1, 28, 28)
    # transforms.Normalize((0.5,), (0.5,))
])

fmnist_train = torchvision.datasets.FashionMNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
fmnist_test = torchvision.datasets.FashionMNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Shuffle only the training split; keep test order fixed for evaluation.
train_loader = torch.utils.data.DataLoader(fmnist_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(fmnist_test, batch_size=batch_size, shuffle=False)




In [2]:
num_inputs = 784    # flattened 28x28 image
num_hiddens = 1500
num_outputs = 10    # number of FashionMNIST classes

num_epochs = 10
lr = 0.01

# Temporal Dynamics
num_steps = 25      # simulation time steps per forward pass
beta = 0.95         # membrane potential decay rate of the Leaky neurons

# Seed the global torch RNG so weight initialisation and DataLoader shuffling
# are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer CUDA, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Sequential SNN. In this notebook the inputs are rate-encoded into spike
# trains before being fed to `net`, so no encoding layer is part of the stack.
net = nn.Sequential(
    OrderedDict([
        ('linear1', nn.Linear(num_inputs, num_hiddens)),
        # init_hidden=True makes Leaky keep its membrane state internally so it
        # can live inside nn.Sequential; the layer then returns only spikes.
        ('leaky1', snn.Leaky(beta=beta, init_hidden=True)),
        ('linear2', nn.Linear(num_hiddens, num_outputs)),
        # output=True on the last layer makes it also return the membrane potential.
        ('leaky2', snn.Leaky(beta=beta, init_hidden=True, output=True)),
    ])
).to(device)

# Small-Gaussian weights and zero biases for every Linear layer of `net`.
for layer in net.modules():
    if isinstance(layer, nn.Linear):
        init.normal_(layer.weight, mean=0, std=0.01)
        init.constant_(layer.bias, val=0)


class Net(nn.Module):
    """Two-layer spiking MLP with an explicit time-step loop.

    Same topology as the Sequential `net` (Linear -> Leaky -> Linear -> Leaky),
    but membrane potentials are threaded through the loop explicitly instead of
    being stored inside the layers via ``init_hidden=True``.  Layer sizes,
    ``beta`` and ``num_steps`` are read from the notebook-level config variables.
    """

    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hiddens)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hiddens, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: input batch of shape (batch, num_inputs). Image batches with more
               than two dims, e.g. (batch, 1, 28, 28), are flattened to
               (batch, -1) first; 1-D/2-D inputs are used unchanged.

        Returns:
            (spk2_rec, mem2_rec): output-layer spikes and membrane potentials,
            each stacked along a new leading time dimension, i.e.
            (num_steps, batch, num_outputs).
        """
        # Accept unflattened image batches as well as pre-flattened inputs.
        if x.dim() > 2:
            x = x.flatten(start_dim=1)

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # x is identical at every step, so the first projection is
        # loop-invariant: compute it once instead of num_steps times.
        cur1 = self.fc1(x)

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Classification criterion (kept under the name `loss` used by later cells).
loss = nn.CrossEntropyLoss()

# Plain SGD over the Sequential `net` only -- `net2` is not registered here.
optimizer = torch.optim.SGD(params=net.parameters(), lr=lr)

# Class-based variant, placed on whichever device was selected above
# (CUDA, MPS or CPU).
net2 = Net().to(device=device)



In [3]:
# findex=du.get_next_demo_index("rec/","demo")

# print(findex)

In [4]:
# results=list()

# filename="Test"

# for i in range(4):
#     results.append(np.random.uniform(0,100,10))

# du.store_record(results,f"{filename}",one_time=True)

# res_fr_file=du.load_sin_res(f"rec/{filename}.csv")

# du.plot_acc(len(res_fr_file[0]),res_fr_file[0],res_fr_file[1],res_fr_file[2],res_fr_file[3],
#              suptitle=filename)

In [5]:
# num_steps=25

# X,y=next(iter(train_loader))

# spk_in=snnu.rate_encoding_images(X,num_steps)

# spk_out,mem_out=net(spk_in[0])

# print(X)

# print(spk_in)

# print(spk_out.shape)
# print(spk_out)
# print(mem_out.shape)
# print(mem_out)

# # mem_rec = []
# # spk_rec = []
# # for step in range(num_steps):
# #             spk_out, mem_out = net(spk_in[step])
# #             spk_rec.append(spk_out)
# #             mem_rec.append(mem_out)
# # print(mem_rec.shape)
# # print(spk_rec.shape)

In [9]:
# Shape check: run both network variants on the SAME training batch.
# `num_steps` comes from the config cell; re-assigning it here would silently
# override that value for every later cell.
X, y = next(iter(train_loader))
X = X.to(device)
y = y.to(device)

# print(X)
print(X.shape)

# Encode the images into spike trains, (num_steps, batch, 784).
spk_in = snnu.rate_encoding_images(X, num_steps)

# print(spk_in)
print(spk_in.shape)

# Sequential `net` consumes the pre-encoded spike trains.
# NOTE(review): `net` uses init_hidden=True, so its Leaky layers keep hidden
# state between calls -- confirm snnu.forward_pass resets it before the loop.
out11, out12 = snnu.forward_pass(net, num_steps, spk_in).spkin()

# Class-based `net2` takes the flattened pixels of the same batch as a constant
# input. Use the actual batch size so a short final batch also works.
out21, out22 = net2(X.view(X.size(0), -1))

print(out11.shape)
print(out12.shape)
print(out21.shape)
print(out22.shape)

torch.Size([256, 1, 28, 28])
torch.Size([25, 256, 784])
torch.Size([25, 256, 10])
torch.Size([25, 256, 10])
torch.Size([25, 256, 10])
torch.Size([25, 256, 10])
