In [None]:
# Load the processed dataset used by every later cell.
# Later cells slice it as np_all_train[:, :14, :] (inputs) and
# np_all_train[:, 14:, :] (targets), so it is expected to be at least 3-D.
# NOTE(review): this is a root-relative Windows path; consider making the
# data directory configurable so the notebook runs on other machines.

import numpy as np
import pandas as pd

DATA_PATH = r'\np_all_train_new.npy'
np_all_train = np.load(DATA_PATH)

In [None]:
from torch.utils.data import sampler, Dataset, DataLoader
import torch

class TopicData(Dataset):
    """Map-style dataset pairing input windows with their target windows.

    Both arrays are converted once to float32 tensors. Samples are indexed
    along the first axis, so ``dataset[i]`` returns ``(inputs[i], labels[i])``.

    Args:
        inputs: array-like whose first axis indexes samples (in this notebook,
            slices of ``np_all_train`` such as ``x_train``).
        labels: array-like with the same first-axis length as ``inputs``.

    Raises:
        ValueError: if ``inputs`` and ``labels`` have different sample counts.
    """

    def __init__(self, inputs, labels):
        # Convert through float64 first (as before) so integer/object arrays
        # round identically, then store as float32 for the network.
        # ascontiguousarray makes from_numpy safe for any strided view; .float()
        # copies, so the dataset never aliases the caller's array.
        self.inputs = torch.from_numpy(np.ascontiguousarray(inputs, dtype=np.float64)).float()
        self.labels = torch.from_numpy(np.ascontiguousarray(labels, dtype=np.float64)).float()
        if self.inputs.shape[0] != self.labels.shape[0]:
            raise ValueError(
                "inputs and labels must have the same number of samples, got {} and {}".format(
                    self.inputs.shape[0], self.labels.shape[0]))
        self.len = self.inputs.shape[0]

    def __getitem__(self, index):
        # Indexing the first axis only works for any rank; for 3-D data it is
        # identical to t[index, :, :].
        return self.inputs[index], self.labels[index]

    def __len__(self):
        return self.len

# Split each sample along axis 1: the first N_INPUT_STEPS rows are the model
# input, the remaining row(s) are the target (the training loop squeezes axis 1
# of the labels, so presumably exactly one target row — verify against the data).
N_INPUT_STEPS = 14
SPLIT_SEED = 0        # fixes the train/val/test split and the training shuffle order
BATCH_SIZE = 64
TRAIN_FRAC, VAL_FRAC = 0.8, 0.1   # the remaining fraction is the test split

x = np_all_train[:, :N_INPUT_STEPS, :]
y = np_all_train[:, N_INPUT_STEPS:, :]

# Seeded permutation so the split is identical after a kernel restart.
split_rng = np.random.default_rng(SPLIT_SEED)
n_samples = len(np_all_train)
shuffled_indices = split_rng.permutation(n_samples)
n_train_end = int(n_samples * TRAIN_FRAC)
n_val_end = int(n_samples * (TRAIN_FRAC + VAL_FRAC))
train_indices = shuffled_indices[:n_train_end]
val_indices = shuffled_indices[n_train_end:n_val_end]
test_indices = shuffled_indices[n_val_end:]

x_train = x[train_indices, :, :]
y_train = y[train_indices, :, :]
x_val = x[val_indices, :, :]
y_val = y[val_indices, :, :]
x_test = x[test_indices, :, :]
y_test = y[test_indices, :, :]

print(x_train.shape, y_train.shape)
print(x_val.shape, y_val.shape)
print(x_test.shape, y_test.shape)

trainset = TopicData(x_train, y_train)
# drop_last stays on for training only (keeps every optimisation step at full
# batch size); the generator makes the per-epoch shuffle reproducible.
train_data_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True,
                               generator=torch.Generator().manual_seed(SPLIT_SEED))
# Evaluation must see every sample: drop_last=True here silently discarded up
# to BATCH_SIZE-1 validation/test samples and biased the averaged metrics.
valset = TopicData(x_val, y_val)
val_data_loader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, drop_last=False)
testset = TopicData(x_test, y_test)
test_data_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, drop_last=False)

In [None]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import numpy as np
from spikingjelly.clock_driven import neuron, encoding, functional
#from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Build the network: a feed-forward stack ending in a spikingjelly LIF neuron
# layer, plus a Poisson encoder that turns inputs into spike trains.
tau = 2.0        # passed to LIFNode as its tau
num_topic = 5    # width of the output layer (one output per topic)

# The list order fixes the Sequential indices (0..5), so state_dict keys are
# the same as for an inline nn.Sequential(...) and saved checkpoints still load.
# Flatten feeds Linear(14*num_topic, ...), so each sample must flatten to
# 14*num_topic features (presumably a (14, num_topic) window — confirm).
# NOTE(review): there is no activation between the Linear layers at indices
# 1->2 and 4->5, so each pair composes to a single linear map; LayerNorm is the
# only non-linearity before the LIF node. Confirm this is intentional.
snn_layers = [
    nn.Flatten(),
    nn.Linear(14 * num_topic, 256, bias=False),
    nn.Linear(256, 512, bias=False),
    nn.LayerNorm(512),
    nn.Linear(512, 256, bias=False),
    nn.Linear(256, num_topic, bias=False),
    neuron.LIFNode(tau=tau),
]
model = nn.Sequential(*snn_layers).to('cpu')
encoder = encoding.PoissonEncoder()

In [None]:
import copy


def _mean_firing_rate(net, spike_encoder, inputs, n_steps):
    """Simulate `net` for `n_steps` steps and return the mean output per step.

    A fresh encoding of `inputs` is drawn at every step and the network outputs
    are summed, then divided by `n_steps`. Assuming the output layer emits 0/1
    spikes, the result is each output neuron's firing frequency.

    The network is NOT reset here: the caller must call functional.reset_net
    once it is done with the result (after backward/step when training).

    Raises:
        ValueError: if n_steps < 1.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1, got {}".format(n_steps))
    spike_count = 0
    for _ in range(n_steps):
        # Out-of-place add keeps autograd from seeing in-place edits of outputs.
        spike_count = spike_count + net(spike_encoder(inputs).float())
    return spike_count / n_steps


def _num_correct(rates, targets):
    """Count rows where the most active output matches the target's argmax."""
    return (rates.max(1)[1] == targets.max(1)[1]).sum().item()


device = 'cpu'
train_epoch = 200
T = 50                       # simulation steps per sample
train_times = 0              # global count of training batches processed
min_test_loss = float('inf')
log_interval = 100
train_accs = []              # per-batch training accuracy
test_accs = []               # per-epoch validation accuracy
train_losses = []            # per-epoch mean training loss
test_losses = []             # per-epoch mean validation loss ("test" = val loader)
best_model_state_dict = None

lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# With drop_last=True a split smaller than one batch yields zero batches and
# every average below would divide by zero; fail early with a clear message.
if len(train_data_loader) == 0 or len(val_data_loader) == 0:
    raise ValueError("train/val DataLoader yields no batches; the split is smaller than batch_size with drop_last=True?")

for epoch in range(train_epoch):
    print("Epoch {}:".format(epoch))
    print("Training...")
    train_correct_sum = 0
    train_sum = 0
    train_loss = 0
    model.train()
    for batch_idx, (inputs, labels) in enumerate(train_data_loader):
        inputs = inputs.to(device)
        # Drops axis 1 of the labels (presumably a singleton target-step axis)
        # so they match the (batch, num_topic) network output.
        labels = labels.squeeze(1).to(device)

        optimizer.zero_grad()
        # One firing frequency per output neuron (num_topic of them), averaged over T steps.
        out_spikes_counter_frequency = _mean_firing_rate(model, encoder, inputs, T)
        loss = F.mse_loss(out_spikes_counter_frequency, labels)
        loss.backward()
        optimizer.step()
        # LIF neurons keep membrane state between calls; clear it per batch.
        functional.reset_net(model)

        batch_size = inputs.shape[0]
        batch_correct = _num_correct(out_spikes_counter_frequency, labels)
        train_correct_sum += batch_correct
        train_sum += batch_size
        train_loss += loss.item() * batch_size

        train_batch_accuracy = batch_correct / batch_size
        if train_times % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Train Loss : {:.4f} Train batch Accuracy: {:.4f}'.format(
                epoch, batch_idx * batch_size, len(trainset), batch_idx * batch_size / len(trainset) * 100,
                loss.item(), train_batch_accuracy * 100))
        train_accs.append(train_batch_accuracy)
        train_times += 1
    # Average over samples actually seen: drop_last=True means train_sum can be
    # smaller than len(trainset), so dividing by len(trainset) underestimated the loss.
    train_accuracy = train_correct_sum / train_sum
    train_loss = train_loss / train_sum
    train_losses.append(train_loss)

    print("Testing...")
    model.eval()
    test_sum = 0
    correct_sum = 0
    test_loss = 0
    with torch.no_grad():
        for inputs, labels in val_data_loader:
            inputs = inputs.to(device)
            labels = labels.squeeze(1).to(device)
            out_spikes_counter = _mean_firing_rate(model, encoder, inputs, T)
            functional.reset_net(model)
            batch_size = inputs.shape[0]
            correct_sum += _num_correct(out_spikes_counter, labels)
            test_loss += F.mse_loss(out_spikes_counter, labels).item() * batch_size
            test_sum += batch_size
    test_accuracy = correct_sum / test_sum
    # Normalise by samples actually evaluated, independent of the loader's drop_last.
    test_loss = test_loss / test_sum
    test_losses.append(test_loss)
    test_accs.append(test_accuracy)
    # `<=` keeps the original tie behaviour (re-save on an equal loss); a NaN
    # loss never compares true, so it can never become the "best" model.
    if test_loss <= min_test_loss:
        min_test_loss = test_loss
        best_model_state_dict = copy.deepcopy(model.state_dict())
        # NOTE(review): root-relative Windows paths; consider a configurable output dir.
        torch.save(model.state_dict(), r'\SNN_topic_new(%d).tar' % (epoch + 1))
    print("Epoch {}: train_loss={}, train_accuracy={}, test_loss={}, test_accuracy={}, min_test_loss={}, train_times={}".format(
        epoch, train_loss, train_accuracy, test_loss, test_accuracy, min_test_loss, train_times))
    print()

# TODO(review): test_data_loader is never evaluated; report held-out metrics
# with the best weights once training is final.
torch.save(model.state_dict(), r'\SNN_topic_new(Final).tar')
if best_model_state_dict is None:
    # Only reachable when train_epoch == 0 or every validation loss was NaN;
    # previously this raised NameError. Keep the final weights instead.
    print("Warning: no best checkpoint was recorded; saving the final weights as the best model.")
else:
    model.load_state_dict(best_model_state_dict)
torch.save(model.state_dict(), r'\SNN_topic_new.tar')
