In [3]:
import torch
import os
from my_snn.tonic_dataloader import DatasetLoader
from my_snn.abstract_rsnn import CHECKPOINT_PATH
# Checkpoint location. Later cells append the checkpoint file name directly to
# this string (no separator), so `path` effectively acts as a file-name prefix.
path = os.path.join(CHECKPOINT_PATH, 'bojian_model')

# Dataset configuration for the project's tonic-based loader
# ('shd' — presumably the Spiking Heidelberg Digits dataset).
dataset = 'shd'
time_window = 250
batch_size = 128  # original note: lr=1e-4
DL = DatasetLoader(
    dataset=dataset,
    caching='disk',
    num_workers=0,
    batch_size=batch_size,
    time_window=time_window,
)
# NOTE(review): unpacking order (test first, then train) is taken as-is;
# confirm against DatasetLoader.get_dataloaders.
test_loader, train_loader = DL.get_dataloaders()

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import math
import torch.nn.functional as F
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter


# ---------------------------------------------------------------------------
# Training / surrogate-gradient hyper-parameters
# ---------------------------------------------------------------------------
decay = 0.1   # neuron decay rate
thresh = 0.5  # neuronal threshold
lens = 0.5    # width of the surrogate-gradient kernels (read by ActFun_adp.backward)
num_epochs = 20  # original comment mentions 150 for longer runs

# ---------------------------------------------------------------------------
# Spiking-neuron model constants
# ---------------------------------------------------------------------------
b_j0 = 0.01   # baseline firing threshold
R_m = 1       # membrane resistance
dt = 1        # integration time step
gamma = .5    # global scale applied to surrogate gradients

# Surrogate-gradient shape used by ActFun_adp.backward:
# 'G' (Gaussian), 'MG' (multi-Gaussian), 'linear' (triangle), 'slayer' (exponential).
gradient_type = 'MG'
print('gradient_type: ', gradient_type)
def gaussian(x, mu=0., sigma=.5):
    """Element-wise Gaussian density N(mu, sigma^2) evaluated at tensor ``x``."""
    exponent = -((x - mu) ** 2) / (2 * sigma ** 2)
    root_two_pi = torch.sqrt(2 * torch.tensor(math.pi))
    return torch.exp(exponent) / root_two_pi / sigma


# define approximate firing function

class ActFun_adp(torch.autograd.Function):
    """Heaviside spike function with a configurable surrogate gradient.

    Forward emits 1.0 where the input (membrane potential minus threshold) is
    strictly positive and 0.0 elsewhere. Backward replaces the step's
    zero-almost-everywhere derivative with a smooth surrogate selected by the
    module-level ``gradient_type`` ('G', 'MG', 'linear' or 'slayer'), shaped by
    the module-level ``lens`` and scaled by the module-level ``gamma``.
    """

    @staticmethod
    def forward(ctx, input):  # input = membrane potential - threshold
        ctx.save_for_backward(input)
        return input.gt(0).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * surrogate(input) * gamma``.

        Raises:
            ValueError: if ``gradient_type`` is not a supported name (this used
                to surface as an opaque UnboundLocalError).
        """
        input, = ctx.saved_tensors
        if gradient_type == 'G':
            # Single Gaussian of width `lens`.
            surrogate = torch.exp(-(input ** 2) / (2 * lens ** 2)) / torch.sqrt(2 * torch.tensor(math.pi)) / lens
        elif gradient_type == 'MG':
            # Central Gaussian bump minus two wide side lobes, which yields
            # small negative gradients far from the threshold.
            scale = 6.0
            height = .15
            surrogate = gaussian(input, mu=0., sigma=lens) * (1. + height) \
                - gaussian(input, mu=lens, sigma=scale * lens) * height \
                - gaussian(input, mu=-lens, sigma=scale * lens) * height
        elif gradient_type == 'linear':
            # Triangle kernel of half-width 1.
            surrogate = F.relu(1 - input.abs())
        elif gradient_type == 'slayer':
            # Exponentially decaying kernel.
            surrogate = torch.exp(-5 * input.abs())
        else:
            raise ValueError(f"unknown gradient_type: {gradient_type!r}")
        # The product allocates a new tensor; grad_output is never modified in
        # place, so no defensive clone is required.
        return grad_output * surrogate.float() * gamma


act_fun_adp = ActFun_adp.apply


# tau_m = torch.FloatTensor([tau_m])

def mem_update_adp(inputs, mem, spike, tau_adp, b, tau_m, dt=1, isAdapt=1):
    """Advance an adaptive leaky integrate-and-fire layer by one time step.

    Args:
        inputs: synaptic input current for this step.
        mem: membrane potential from the previous step.
        spike: spikes emitted at the previous step.
        tau_adp: adaptation time constant(s) (tensor).
        b: adaptation variable from the previous step (tensor or scalar).
        tau_m: membrane time constant(s) (tensor).
        dt: integration time step.
        isAdapt: if truthy the threshold adapts with strength 1.8, else it is
            fixed at the baseline ``b_j0``.

    Returns:
        (mem, spike, B, b): new membrane potential, new spikes, the effective
        threshold used this step, and the updated adaptation variable.
    """
    # Decay factors follow the state tensor's device instead of being forced
    # onto CUDA, so the update also runs on CPU-only machines.
    alpha = torch.exp(-1. * dt / tau_m)
    ro = torch.exp(-1. * dt / tau_adp)
    if torch.is_tensor(mem):
        alpha = alpha.to(mem.device)
        ro = ro.to(mem.device)
    beta = 1.8 if isAdapt else 0.

    # Adaptation variable low-pass filters past spikes and raises the
    # threshold above the baseline b_j0.
    b = ro * b + (1 - ro) * spike
    B = b_j0 + beta * b

    # Leaky integration; subtract the threshold if the neuron spiked last step.
    mem = mem * alpha + (1 - alpha) * R_m * inputs - B * spike * dt
    # Spike (with surrogate gradient) where mem exceeds the adaptive threshold.
    spike = act_fun_adp(mem - B)
    return mem, spike, B, b


def output_Neuron(inputs, mem, tau_m, dt=1):
    """
    The read out neuron is leaky integrator without spike.

    Args:
        inputs: input current for this step.
        mem: previous membrane potential of the read-out units.
        tau_m: membrane time constant(s) (tensor).
        dt: integration time step.

    Returns:
        The updated membrane potential.
    """
    # Follow the state tensor's device rather than forcing CUDA, so the
    # read-out also works on CPU.
    alpha = torch.exp(-1. * dt / tau_m)
    if torch.is_tensor(mem):
        alpha = alpha.to(mem.device)
    return mem * alpha + (1. - alpha) * R_m * inputs


class RNN_custom(nn.Module):
    """Two-layer recurrent adaptive-LIF spiking network with a leaky read-out.

    Per time step:
        x_t -> i_2_h1 (+ h1_2_h1 on layer-1 spikes) -> adaptive LIF layer 1
        layer-1 spikes -> h1_2_h2 (+ h2_2_h2 on layer-2 spikes) -> adaptive LIF layer 2
        layer-2 spikes -> h2o -> non-spiking leaky integrator (read-out)

    Membrane (tau_m_*) and adaptation (tau_adp_*) time constants are learnable
    per-neuron parameters.

    Args:
        input_size: number of input features per time step.
        hidden_size: indexable with two entries, the widths of layers 1 and 2.
        output_size: number of read-out units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN_custom, self).__init__()

        self.hidden_size = hidden_size
        # Feed-forward and recurrent projections.
        self.i_2_h1 = nn.Linear(input_size, hidden_size[0])
        self.h1_2_h1 = nn.Linear(hidden_size[0], hidden_size[0])
        self.h1_2_h2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.h2_2_h2 = nn.Linear(hidden_size[1], hidden_size[1])

        self.h2o = nn.Linear(hidden_size[1], output_size)

        # Learnable per-neuron time constants. tau_adp_o is created but not
        # read by forward (the read-out does not adapt).
        self.tau_adp_h1 = nn.Parameter(torch.Tensor(hidden_size[0]))
        self.tau_adp_h2 = nn.Parameter(torch.Tensor(hidden_size[1]))
        self.tau_adp_o = nn.Parameter(torch.Tensor(output_size))
        self.tau_m_h1 = nn.Parameter(torch.Tensor(hidden_size[0]))
        self.tau_m_h2 = nn.Parameter(torch.Tensor(hidden_size[1]))
        self.tau_m_o = nn.Parameter(torch.Tensor(output_size))

        # Weight init. Call order is kept unchanged so seeded runs reproduce
        # the same random draws.
        nn.init.orthogonal_(self.h1_2_h1.weight)
        # NOTE(review): this orthogonal init of h2_2_h2 is overwritten by the
        # xavier_uniform_ call on the same weight below; one of the two is
        # probably unintended.
        nn.init.orthogonal_(self.h2_2_h2.weight)
        nn.init.xavier_uniform_(self.i_2_h1.weight)
        nn.init.xavier_uniform_(self.h1_2_h2.weight)
        nn.init.xavier_uniform_(self.h2_2_h2.weight)
        nn.init.xavier_uniform_(self.h2o.weight)

        # Zero the hidden-layer biases (h2o keeps PyTorch's default bias init).
        nn.init.constant_(self.i_2_h1.bias, 0)
        nn.init.constant_(self.h1_2_h2.bias, 0)
        nn.init.constant_(self.h2_2_h2.bias, 0)
        nn.init.constant_(self.h1_2_h1.bias, 0)

        # Time constants: adaptation ~ N(150, 10), membrane ~ N(20, 5).
        # NOTE(review): N(20, 5) can in principle draw values near or below 0.
        nn.init.normal_(self.tau_adp_h1, 150, 10)
        nn.init.normal_(self.tau_adp_h2, 150, 10)
        nn.init.normal_(self.tau_adp_o, 150, 10)
        nn.init.normal_(self.tau_m_h1, 20., 5)
        nn.init.normal_(self.tau_m_h2, 20., 5)
        nn.init.normal_(self.tau_m_o, 20., 5)

        # Adaptation state; reset to b_j0 at the start of every forward pass.
        self.b_h1 = self.b_h2 = self.b_o = 0

    def forward(self, input):
        """Unroll the network over the time axis of ``input``.

        Args:
            input: tensor unpacked as (batch_size, seq_num, input_dim); step t
                is ``input[:, t, :]``.

        Returns:
            A 5-tuple:
            - output: (batch_size, n_out) sum of softmax(read-out potential)
              over steps t > 10 (the first steps act as a warm-up).
            - hidden_spike_: per-step numpy arrays of layer-1 spikes.
            - hidden_mem_: per-step numpy arrays of layer-2 *spikes* (despite
              the name; kept for compatibility with existing callers).
            - h2o_mem_: per-step numpy arrays of read-out membrane potentials.
            - mean firing rate over both hidden layers and all steps.
        """
        batch_size, seq_num, input_dim = input.shape
        # Read-out width comes from the model itself, not from a notebook-level
        # `output_dim` global that may not exist when forward runs.
        n_out = self.h2o.out_features
        # Allocate state on the input's device rather than hard-coding CUDA.
        device = input.device
        self.b_h1 = self.b_h2 = self.b_o = b_j0

        # Random initial state, drawn from the CPU generator and then moved to
        # `device`. Membrane and spike share one tensor at t=0; this is safe
        # because every update below is out-of-place.
        mem_layer1 = spike_layer1 = torch.rand(batch_size, self.hidden_size[0]).to(device)
        mem_layer2 = spike_layer2 = torch.rand(batch_size, self.hidden_size[1]).to(device)
        mem_output = torch.rand(batch_size, n_out).to(device)
        output = torch.zeros(batch_size, n_out).to(device)

        hidden_spike_ = []
        fr = []
        hidden_mem_ = []
        h2o_mem_ = []

        for i in range(seq_num):
            input_x = input[:, i, :]

            # Layer 1: feed-forward input plus recurrent layer-1 spikes.
            h_input = self.i_2_h1(input_x.float()) + self.h1_2_h1(spike_layer1)
            mem_layer1, spike_layer1, _, self.b_h1 = mem_update_adp(
                h_input, mem_layer1, spike_layer1, self.tau_adp_h1, self.b_h1, self.tau_m_h1)

            # Layer 2: layer-1 spikes plus recurrent layer-2 spikes.
            h2_input = self.h1_2_h2(spike_layer1) + self.h2_2_h2(spike_layer2)
            mem_layer2, spike_layer2, _, self.b_h2 = mem_update_adp(
                h2_input, mem_layer2, spike_layer2, self.tau_adp_h2, self.b_h2, self.tau_m_h2)

            # Non-spiking read-out integrates the projected layer-2 spikes.
            mem_output = output_Neuron(self.h2o(spike_layer2), mem_output, self.tau_m_o)

            # Accumulate class probabilities after a short warm-up.
            if i > 10:
                output = output + F.softmax(mem_output, dim=1)

            hidden_spike_.append(spike_layer1.detach().cpu().numpy())
            hidden_mem_.append(spike_layer2.detach().cpu().numpy())
            h2o_mem_.append(mem_output.detach().cpu().numpy())
            fr.append((spike_layer1.detach().mean().cpu().numpy()
                       + spike_layer2.detach().mean().cpu().numpy()) / 2.)

        return output, hidden_spike_, hidden_mem_, h2o_mem_, np.mean(fr)


def test(model, dataloader=test_loader,is_test=0):
    """Return the classification accuracy (%) of ``model`` over ``dataloader``.

    Each batch is reshaped with ``view(-1, seq_dim, input_dim)`` and moved to
    ``device``; these are notebook-level globals that must be defined before
    the call. Labels are presumably one-hot; the class index is taken as their
    argmax over dim 1.

    Args:
        model: network returning (output, _, _, _, firing_rate) per batch.
        dataloader: iterable of (images, labels) batches; the default is the
            ``test_loader`` bound when this function was defined.
        is_test: if truthy, also print the firing rate averaged over batches.

    Returns:
        Accuracy in percent (0.0 if the dataloader yields no samples).
    """
    correct = 0
    total = 0
    firing_rates = []
    # Inference only: skip autograd so the graph of every unrolled time step
    # is not recorded.
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.view(-1, seq_dim, input_dim).to(device)
            _, labels = torch.max(labels.data, 1)

            outputs, _, _, _, fr_ = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # .cpu() is a no-op on CPU tensors, so one path covers both devices.
            correct += (predicted.cpu() == labels.long().cpu()).sum().item()
            firing_rates.append(fr_)

    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100. * correct / total if total else 0.0
    if is_test:
        # Average over all batches (previously only the last batch was shown).
        print('Mean FR: ', np.mean(firing_rates) if firing_rates else float('nan'))
    return accuracy
    
def predict(model, dataloader=None):
    """Return the predicted class index for every sample, in loader order.

    Each batch is reshaped with ``view(-1, seq_dim, input_dim)`` and moved to
    ``device`` (notebook-level globals that must exist at call time).

    Args:
        model: network returning (output, ...) per batch.
        dataloader: iterable of (images, labels) batches; defaults to the
            notebook-level ``test_loader``, looked up at call time.

    Returns:
        1-D float64 numpy array of predicted class indices (float dtype kept
        for compatibility with existing callers).
    """
    if dataloader is None:
        dataloader = test_loader
    batch_predictions = []
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.view(-1, seq_dim, input_dim).to(device)
            outputs, _, _, _, _ = model(images)
            batch_predictions.append(outputs.cpu().numpy().argmax(axis=1))
    # Concatenate once rather than np.append per batch, which re-copies the
    # growing array on every iteration.
    if not batch_predictions:
        return np.zeros(0)
    return np.concatenate(batch_predictions).astype(np.float64)



gradient_type:  MG


In [4]:
# Load the whole pickled RNN_custom module (not just a state_dict); the
# RNN_custom class must therefore be defined before this line runs.
# NOTE: `path` is used as a prefix - the file name is appended without a
# separator, i.e. '<CHECKPOINT_PATH>/bojian_model1.0-frame-v3-inte.pth'.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted
# checkpoints. Recent PyTorch releases default to weights_only=True, which
# rejects full-module pickles - confirm the installed version. With no
# map_location, tensors are restored to the device recorded in the file.
model = torch.load(path+'1.0-frame-v3-inte.pth')
# Alternative loading path via a saved state_dict (currently unused):
# model.load_state_dict(models['model_state_dict'])
# Put the module into evaluation mode before scoring.
model.eval()

RNN_custom(
  (i_2_h1): Linear(in_features=700, out_features=512, bias=True)
  (h1_2_h1): Linear(in_features=512, out_features=512, bias=True)
  (h1_2_h2): Linear(in_features=512, out_features=512, bias=True)
  (h2_2_h2): Linear(in_features=512, out_features=512, bias=True)
  (h2o): Linear(in_features=512, out_features=20, bias=True)
)

In [8]:
# Dimensions used by test()/predict() to reshape each batch.
seq_dim = 250  # time steps per sample; NOTE(review): appears to mirror time_window - confirm
# Feature, hidden and class counts are read from the loaded checkpoint so they
# cannot drift from its layer sizes (700 inputs, [512, 512] hidden and
# 20 outputs for this checkpoint).
input_dim = model.i_2_h1.in_features
hidden_dim = [model.h1_2_h1.out_features, model.h2_2_h2.out_features]
output_dim = model.h2o.out_features
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

accuracy = test(model)
print(' Accuracy: ', accuracy)

 Accuracy:  54.87132352941177
