In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import neuron
import linear
import time

In [3]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [4]:
# modify here if you use other datasets
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import random_split

input_dimmension = 784                 # number of input features (MNIST: 784)
hidden_dimmension = 100                # number of neurons in the hidden layer
output_dimmension = 10                 # number of output labels (MNIST: 10)
# if you want to use another dataset
#data_path = './WineQT.csv'             # path to the dataset file // assumes the dataset is a CSV file whose last column is the output label


# Dataset wrapper that allows an arbitrary CSV dataset to be used.
class SNN_Dataset(Dataset):
    '''
    Map-style dataset backed by a CSV file.

    Every column except the last is treated as an input feature and the last
    column is treated as the output label.
    '''

    def __init__(self, csv_path):
        '''Read `csv_path` with pandas and split it into features and labels.'''
        table = pd.read_csv(csv_path)
        feature_columns = table.iloc[:, :-1]
        label_column = table.iloc[:, -1]

        self.x_data = feature_columns.values
        self.y_data = label_column.values

    def __len__(self):
        '''Number of rows (samples) read from the CSV.'''
        return self.x_data.shape[0]

    def __getitem__(self, idx):
        '''Return the `(features, label)` pair stored at position `idx`.'''
        sample = self.x_data[idx]
        target = self.y_data[idx]
        return sample, target

In [5]:
# Nearly all parameters of this simulation, collected in one dict
# (original author's note: storing them in a dict may not be of much use).
param = {'G_min': -1.0, # minimum conductance; normalized, because the real conductance values
        'G_max': 1.0,   # (on the order of 1e-4) make the gradients too small even in a two-layer net. G_max is the maximum.
        'Rd': 10e9,    # device resistance; manually set, presumably to get a smaller leakage current
        'Cm': 80e-12,   # capacitance; the real capacitance is certainly larger than this value
        'Rs': 1205000,      # series resistance; manually set, presumably to get a larger injected current
        'Vth': 3.6,     # threshold voltage of the real device
        'V_reset': 3.7,  # NOTE(review): reset voltage is higher than Vth (3.6) -- confirm this is intended by the neuron model
        'dt': 1.75e-4,   # time step used in the first-order differential equation of the neuron
        'T_sim': 50,   # number of simulation time steps per batch; controls the total number of spikes collected
        'dim_in': input_dimmension,
        'dim_h': hidden_dimmension,
        'dim_out': output_dimmension,
        'amp' : 3,    # the gain of the TIAs (not read by the code visible in this notebook)
        'q_bit': 7,    # quantization bit width passed to the crossbar layers
        'epoch': 100,  # number of training epochs
        'batch_size': 2000,  # batch size for both data loaders and the neuron layers
        'learning_rate': 0.022,  # learning rate given to the Adam optimizer
        'data_dir': './MNIST',  # root directory where MNIST is downloaded / read
        'train_file': 'trainning_log_7bit.txt',  # per-epoch training log (opened in append mode)
        'test_file': 'test_log.txt',  # not read by the code visible in this notebook
        'model_dir': 'Model.pth'  # file path (not a directory) the trained state_dict is saved to
}

def Poisson_encoder(x):
    '''
    Encode a batch of intensities as one time step of Poisson (rate-coded) spikes.

    input: a batch of input data x.
    output: a float tensor with the same shape as x holding 1.0 (spike) or 0.0 (no spike).
            An element spikes when a fresh uniform sample from [0, 1) does not exceed its
            value, so for values in [0, 1] the spike probability is proportional to the value.
    '''
    uniform_noise = torch.rand_like(x)
    fired = uniform_noise <= x
    return fired.float()
    
    
class Three_Layer_SNN(nn.Module):
    '''
    This net model contains 2 linear layers, 2 self-defined BatchNorm layers and 2 neuron layers.

    linear layer: a memristor crossbar on which the MAC operation is implemented.
    BatchNorm layer: a row of TIAs as the output interface of the preceding linear layer; normalizes
                    the output current to a -2.0~2.0 V voltage.
    neuron layer: nonlinear activation; receives an input voltage and outputs spikes. The spiking
                    rate is what the loss is computed on.

    param: dict providing 'dim_in', 'dim_h', 'dim_out', 'G_min', 'G_max', 'q_bit', 'batch_size',
           'Rd', 'Cm', 'Rs', 'Vth', 'V_reset' and 'dt'.
    '''
    def __init__(self, param):
        super().__init__()
        self.linear1 = linear.MAC_Crossbar(param['dim_in'], param['dim_h'],
                                            param['G_min'], param['G_max'], param['q_bit'])
        # The parameters of the TIA are manually set for a moderate input voltage to the neurons.
        # NOTE(review): TIA_Norm is given the fan-in of the preceding crossbar (dim_in), not the
        # width of the tensor it normalizes (dim_h) -- confirm against linear.TIA_Norm.
        self.BatchNorm1 = linear.TIA_Norm(param['dim_in'], 0.0, 200.0)
        self.neuron1 = neuron.LIFNeuron(param['batch_size'], param['dim_h'], param['Rd'], param['Cm'],
                                            param['Rs'], param['Vth'], param['V_reset'], param['dt'])
        self.linear2 = linear.MAC_Crossbar(param['dim_h'], param['dim_out'],
                                            param['G_min'], param['G_max'], param['q_bit'])
        self.BatchNorm2 = linear.TIA_Norm(param['dim_h'], 0.0, 200.0)    # same as above
        self.neuron2 = neuron.LIFNeuron(param['batch_size'], param['dim_out'], param['Rd'], param['Cm'],
                                            param['Rs'], param['Vth'], param['V_reset'], param['dt'])

    def forward(self, input_vector):
        '''
        Run one simulation time step.

        input: a batch of (already spike-encoded) input vectors.
        output: the output of the second neuron layer for this time step.
        '''
        # The neuron state `v` is moved explicitly on every call (presumably it is a plain
        # attribute that Module.to() does not move -- verify in neuron.LIFNeuron). The device is
        # taken from the input rather than from a notebook-level global, so the module also works
        # when it is used outside this notebook.
        state_device = input_vector.device
        self.neuron1.v = self.neuron1.v.to(state_device)
        self.neuron2.v = self.neuron2.v.to(state_device)

        out_vector = self.linear1(input_vector)
        out_vector = self.BatchNorm1(out_vector)
        out_vector = self.neuron1(out_vector)

        out_vector = self.linear2(out_vector)
        out_vector = self.BatchNorm2(out_vector)
        out_vector = self.neuron2(out_vector)

        return out_vector

    def reset_(self):
        '''
        Reset all neurons after one forward pass,
        to ensure the independence of every input image.

        Calls reset() on every sub-module that defines it.
        '''
        for item in self.modules():
            if hasattr(item, 'reset'):
                item.reset()

    def quant_(self):
        '''
        The quantization functions in PyTorch only support int8,
        so we need our own quantization function for an adjustable quantization precision.

        Calls Gquant_() on every sub-module that defines it.
        '''
        for item in self.modules():
            if hasattr(item, 'Gquant_'):
                item.Gquant_()



In [6]:
# Poisson Encoding

# transforms.ToTensor() scales the MNIST pixels to [0, 1], so Poisson_encoder can use each
# pixel value directly as a per-time-step spike probability.
trainset = torchvision.datasets.MNIST(root=param['data_dir'], train=True,
                                        download=True, transform=transforms.ToTensor())
testset = torchvision.datasets.MNIST(root=param['data_dir'], train=False,
                                        download=True, transform=transforms.ToTensor())
# drop_last=True keeps every batch at exactly param['batch_size'] samples, which is the batch
# size the neuron layers of Three_Layer_SNN are constructed with.
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=param['batch_size'],
                                            shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=param['batch_size'],
                                            shuffle=False, drop_last=True, num_workers=2, pin_memory=True)


def _accumulate_output_spikes(model, images, n_steps):
    '''
    Present `images` to `model` for `n_steps` time steps and sum the output spikes.

    The neurons are reset first, then a fresh Poisson encoding of `images` is drawn at every
    time step. Returns a (batch, param['dim_out']) tensor of accumulated output spikes that
    lives on the same device as `images`.
    '''
    # A new accumulator is allocated per call: it becomes part of the autograd graph in training.
    spike_count = torch.zeros(images.shape[0], param['dim_out'], device=images.device)
    model.reset_()
    for _ in range(n_steps):
        spike_count += model(Poisson_encoder(images))
    return spike_count


#Train the SNN with BP
net = Three_Layer_SNN(param).to(device)
loss_func = torch.nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adam(net.parameters(), param['learning_rate'])

start = time.time()
for epoch in range(param['epoch']):
    net.train()
    train_accuracy = []

    for img, label in trainloader:
        img = img.reshape(-1, input_dimmension)
        # using gpu
        img, label = img.to(device), label.to(device)

        spike_num_img = _accumulate_output_spikes(net, img, param['T_sim'])

        # The loss is computed on the output spiking rate (spikes per time step).
        spike_rate = spike_num_img/param['T_sim']
        loss = loss_func(spike_rate, label)

        optim.zero_grad()
        loss.backward()
        optim.step()

        with torch.no_grad():
            net.reset_()    # reset the neuron voltage every batch, to ensure independence between batches
            net.quant_()    # quantize the weights after the weight update

        train_accuracy.append((spike_rate.max(1)[1] == label).float().mean().item())
    accuracy_epoch = np.mean(train_accuracy)
    # The reported loss is the loss of the last batch of the epoch.
    print('tranning epoch %d: the SNN loss is %.6f' %(epoch, loss.item()), end=' ')
    print('trainning accuracy: %.4f' %accuracy_epoch, end=' ')

    # Validate on the test set every epoch to see whether the network is overfitting.
    net.eval()
    validation_accuracy = []
    with torch.no_grad():
        for img_test, label_test in testloader:
            img_test = img_test.reshape(-1, input_dimmension)
            # using gpu
            img_test, label_test = img_test.to(device), label_test.to(device)

            spike_num_img_test = _accumulate_output_spikes(net, img_test, param['T_sim'])

            validation_accuracy.append((spike_num_img_test.max(1)[1]==label_test).float().mean().item())
    accuracy_val = np.mean(validation_accuracy)
    print('validation accuracy: %.4f' %accuracy_val)
    print('elapsed time : {0:.1f}s' .format(time.time() - start))

    # Append one record per epoch: epoch, last-batch loss, training accuracy, validation accuracy.
    with open(param['train_file'], 'a') as f_t:
        s = str(epoch).ljust(6, ' ') + str(round(loss.item(), 6)).ljust(12, ' ')
        s += str(round(accuracy_epoch, 4)).ljust(10, ' ') + str(round(accuracy_val, 4)).ljust(10, ' ')
        # Terminate the record so that every epoch ends up on its own line of the log.
        f_t.write(s + '\n')

torch.save(net.state_dict(), param['model_dir'])

tranning epoch 0: the SNN loss is 2.257298 trainning accuracy: 0.1667 validation accuracy: 0.2531
elapsed time : 1.4s
tranning epoch 1: the SNN loss is 2.209493 trainning accuracy: 0.3732 validation accuracy: 0.4945
elapsed time : 2.7s
tranning epoch 2: the SNN loss is 2.184917 trainning accuracy: 0.5515 validation accuracy: 0.5955
elapsed time : 3.9s
tranning epoch 3: the SNN loss is 2.164946 trainning accuracy: 0.6292 validation accuracy: 0.6545
elapsed time : 5.1s
tranning epoch 4: the SNN loss is 2.154344 trainning accuracy: 0.6691 validation accuracy: 0.6768
elapsed time : 6.3s
tranning epoch 5: the SNN loss is 2.143142 trainning accuracy: 0.6971 validation accuracy: 0.7123
elapsed time : 7.5s
tranning epoch 6: the SNN loss is 2.135429 trainning accuracy: 0.7165 validation accuracy: 0.7240
elapsed time : 8.8s
tranning epoch 7: the SNN loss is 2.131116 trainning accuracy: 0.7334 validation accuracy: 0.7396
elapsed time : 10.0s
tranning epoch 8: the SNN loss is 2.132857 trainning ac

In [7]:
# Test process after training
net_test = Three_Layer_SNN(param).to(device)
print('Loading Model, please wait......')
# map_location lets a checkpoint that was saved on the GPU be loaded on a CPU-only machine.
net_test.load_state_dict(torch.load(param['model_dir'], map_location=device, weights_only=True))
print('Model loaded successfully!')
# A freshly constructed module is in training mode; switch to inference mode so this pass
# matches the per-epoch validation pass (which calls eval() before evaluating).
net_test.eval()

# list_num_spike[c] = [number of test images whose true label is c,
#                      per-class count of the labels predicted for those images],
# i.e. row c of the confusion matrix preceded by its row total.
list_num_spike = []
for i in range(param['dim_out']):
    list_num_spike.append([0])
    list_num_spike[i].append(torch.zeros(param['dim_out']))

total_correct = 0  # total number of correct predictions
total_samples = 0  # total number of test samples

with torch.no_grad():
    for img_test, label_test in testloader:
        img_test = img_test.reshape(-1, param['dim_in'])
        # using gpu
        img_test, label_test = img_test.to(device), label_test.to(device)
        spike_num_img_test = torch.zeros(param['batch_size'], param['dim_out'], device=device)

        net_test.reset_()  # set the neuron voltage as reset voltage

        for t in range(param['T_sim']):
            new_test_img = Poisson_encoder(img_test)
            out_spike = net_test(new_test_img)
            spike_num_img_test += out_spike

        # Accuracy calculation
        pred_label_class = spike_num_img_test.max(1)[1]  # predicted class for each image in the batch
        total_correct += (pred_label_class == label_test).sum().item()  # count correct predictions
        total_samples += label_test.size(0)  # update total samples

        # convert the index of the most active output neuron to a one-hot vector
        pred_label = F.one_hot(pred_label_class, num_classes=param['dim_out']).to('cpu')

        # tolist() copies the labels to the host once instead of once per sample.
        for j, index in enumerate(label_test.tolist()):
            list_num_spike[index][0] += 1
            list_num_spike[index][1] += pred_label[j]  # statistics of prediction for every input image

# Calculate and print overall accuracy
accuracy = total_correct / total_samples * 100
print(f'Test Accuracy: {accuracy:.2f}%')

# Save confusion matrix (the file is opened in append mode, so re-running this cell adds
# another block of rows to it)
with open('confusion_matrix.txt', 'a') as f2:
    for i in range(len(list_num_spike)):
        s = str(list_num_spike[i][0]) + ' ' + str(list_num_spike[i][1].numpy()).replace('[', '').replace(']', '') + '\n'
        f2.write(s)

Loading Model, please wait......
Model loaded successfully!
Test Accuracy: 83.91%
