In [1]:
import numpy as np
import pandas as pd

# Both raw arrays live under the same directory prefix. On Windows a leading
# backslash resolves against the root of the current drive.
DATA_ROOT = '\\'
np_all_train = np.load(DATA_ROOT + 'np_all_train.npy')
np_train_speed = np.load(DATA_ROOT + 'np_train_speed.npy')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import sampler, Dataset, DataLoader

class TopicData(Dataset):
    """Map-style dataset yielding (input, label, one-hot label) triples.

    Labels are shifted from 1-based to 0-based (``labels - 1``) and stored as
    float32. The one-hot target is built per item in ``__getitem__``. Building it
    eagerly would materialise an (N, 1, num_classes) float tensor, roughly 0.55 GB
    for the 153k-sample training split at num_classes=900.
    """

    def __init__(self, inputs, labels, num_classes=900):
        """
        Args:
            inputs: numpy array indexed as ``inputs[i, :, :]`` (3-D, first axis = sample).
            labels: numpy array of 1-based class ids indexed as ``labels[i, :]``.
            num_classes: width of the one-hot encoding (default 900).

        Raises:
            ValueError: if any shifted label (truncated to int) is outside [0, num_classes).
        """
        self.inputs = torch.as_tensor(inputs.astype(float)).to(torch.float32)
        self.labels = torch.as_tensor((labels - 1).astype(float)).to(torch.float32)
        self.num_classes = num_classes
        # Validate eagerly so out-of-range labels fail at construction time rather
        # than mid-epoch inside a DataLoader worker.
        class_ids = self.labels.to(torch.int64)
        if class_ids.numel() and (class_ids.min() < 0 or class_ids.max() >= num_classes):
            raise ValueError(
                f"labels - 1 must lie in [0, {num_classes}), got "
                f"min={class_ids.min().item()}, max={class_ids.max().item()}")
        self.len = inputs.shape[0]

    @property
    def labels_one_hot(self):
        """Full one-hot label tensor of shape labels.shape + (num_classes,), built on demand."""
        return F.one_hot(self.labels.to(torch.int64), self.num_classes).float()

    def __getitem__(self, index):
        """Return (input[index], label[index], one-hot of label[index])."""
        inp = self.inputs[index, :, :]
        label = self.labels[index, :]
        # label has shape (1,) for single-column labels -> one_hot gives (1, C);
        # squeeze(0) yields (C,) so the default collate stacks to (batch, C).
        label_one_hot = F.one_hot(label.to(torch.int64), self.num_classes).float().squeeze(0)
        return inp, label, label_one_hot

    def __len__(self):
        """Number of samples (size of the first axis of ``inputs``)."""
        return self.len
    
# Model inputs are the first 14 entries along axis 1 of the raw array;
# targets are the speed labels.
x = np_all_train[:,:14,:]
y = np_train_speed

# Seeded generator so the 80/10/10 train/val/test split is identical on every
# Restart & Run All.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)
n_samples = len(np_all_train)
train_end = int(n_samples * 0.8)
val_end = int(n_samples * 0.9)

shuffled_indices = split_rng.permutation(n_samples)
train_indices = shuffled_indices[:train_end]
val_indices = shuffled_indices[train_end:val_end]
test_indices = shuffled_indices[val_end:]
assert len(train_indices) + len(val_indices) + len(test_indices) == n_samples

x_train = x[train_indices,:,:]
y_train = y[train_indices,:]
x_val = x[val_indices,:,:]
y_val = y[val_indices,:]
x_test = x[test_indices,:,:]
y_test = y[test_indices,:]

print(x_train.shape,y_train.shape)
print(x_val.shape,y_val.shape)
print(x_test.shape,y_test.shape)

# drop_last=True only for training, so every optimisation step sees a full batch.
# The evaluation loaders keep the final partial batch so every val/test sample is
# scored.
trainset = TopicData(x_train, y_train)
train_data_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
valset = TopicData(x_val, y_val)
val_data_loader = DataLoader(valset, batch_size=64, shuffle=False, num_workers=0, drop_last=False)
testset = TopicData(x_test, y_test)
test_data_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=0, drop_last=False)

(153359, 14, 8) (153359, 1)
(19170, 14, 8) (19170, 1)
(19170, 14, 8) (19170, 1)


In [3]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import numpy as np
from spikingjelly.clock_driven import neuron, encoding, functional
#from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Define and initialize the network.
# Seed torch's global RNG so weight initialisation is reproducible. Later cells
# that draw from the same RNG then also repeat across Restart & Run All:
# DataLoader shuffling, and presumably the PoissonEncoder's spike sampling.
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)

tau = 5.0        # time constant passed to the output LIFNode
num_topic = 8    # per-step feature count; must match the last axis of x, i.e. (N, 14, 8)
# NOTE(review): the Linear pairs below have no activation between them, so each
# pair composes to a single linear map (only LayerNorm is nonlinear). Inserting
# activations would shift the nn.Sequential indices used in state_dict keys and
# break loading of previously saved checkpoints.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(14*num_topic, 256, bias=False),
    nn.Linear(256, 512, bias=False),
    nn.LayerNorm(512),
    nn.Linear(512, 256, bias=False),
    nn.Linear(256, 900, bias=False),   # one output neuron per class (900 classes)
    neuron.LIFNode(tau=tau)
)
model = model.to('cpu')

No CUDA runtime is found, using CUDA_HOME='D:\Program\anaconda\envs\pytorch-gpu'
spikingjelly.clock_driven.spike_op: list index out of range




In [4]:
import math
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

def array_to_cor(inputs_speed):
    """Unpack row 0 of a speed tensor into six (int, int) pairs.

    Columns 0..11 of the first row are read as consecutive pairs
    (0, 1), (2, 3), ..., (10, 11); each value is truncated with ``int()``.
    Any further rows are ignored.
    """
    first_row = inputs_speed.cpu().numpy()[0]
    return [(int(first_row[2 * k]), int(first_row[2 * k + 1])) for k in range(6)]

def words_to_cor(words):
    """Convert a 1-based class index into 1-based (row, column) coordinates on a 30-column grid.

    For ``words`` in 1..900 this returns ``(a, b)`` with ``a`` and ``b`` both in 1..30,
    and ``(a - 1) * 30 + b == words``.

    Args:
        words: 1-based class index (int, or an integral float).

    Returns:
        tuple ``(a, b)``: ``a = ceil(words / 30)`` and ``b`` the 1-based column.
    """
    a = math.ceil(words / 30)
    # 1-based column: a plain ``words % 30`` gives 0 for multiples of 30
    # (e.g. 30 -> (1, 0)), which is not a valid column.
    b = words - (a - 1) * 30
    return (a, b)

def cor_to_words(cor):
    """Map 1-based (row, column) coordinates back to a class index.

    Computes ``(row - 1) * 30 + column`` and always returns a float.
    """
    row, col = cor
    return 30.0 * (row - 1) + col

def lon_speed(lon):
    """Return the speed value assigned to a longitudinal bin index.

    Bin indices 1..5 give 2, 6, 10, 14, 18 via ``2 + 4 * (lon - 1)``. Bin 6 returns
    the literal 22, which the same formula would also produce. Other indices are
    not rejected; they simply follow the formula.
    """
    return 22 if lon == 6 else 2 + 4 * (lon - 1)

def lat_speed(lat):
    """Return the lateral speed value for a bin index in 1..5, or None otherwise.

    Bin 3 is the zero centre; bins 1-2 and 4-5 are symmetric negative/positive values.
    """
    table = ((1, -0.4479), (2, -0.2240), (3, 0), (4, 0.2240), (5, 0.4479))
    for index, speed in table:
        if lat == index:
            return speed
    return None
    
def cor_to_speed(cor):
    """Decode (a, b) grid coordinates into bin indices and speed values.

    Each coordinate v is split into a 1-based group ``ceil(v / 5)`` (the "lon" index)
    and a 1-based position ``v - 5 * (group - 1)`` within that group (the "lat" index).

    Returns:
        - ``[a_lon, a_lat, b_lon, b_lat]``
        - ``[lon_speed(a_lon), lon_speed(b_lon)]``
        - ``[lat_speed(a_lat), lat_speed(b_lat)]``
    """
    def split_index(value):
        group = math.ceil(value / 5)
        return group, value - (group - 1) * 5

    first, second = cor
    first_lon, first_lat = split_index(first)
    second_lon, second_lat = split_index(second)

    indices = [first_lon, first_lat, second_lon, second_lat]
    lon_speeds = [lon_speed(first_lon), lon_speed(second_lon)]
    lat_speeds = [lat_speed(first_lat), lat_speed(second_lat)]
    return indices, lon_speeds, lat_speeds

def batch_speed_trans(outputs):
    """Decode a batch of 1-based class indices into longitudinal and lateral speed pairs.

    Each element goes through ``words_to_cor`` then ``cor_to_speed``.

    Args:
        outputs: iterable of scalars supporting ``.item()``. In this notebook it is a
            1-D tensor of class indices already shifted to 1-based (callers pass ``+1``).

    Returns:
        (lons, lats): two float64 arrays of shape (len(outputs), 2). An empty input
        yields two (0, 2) arrays. A ``None`` from ``lat_speed`` (index outside 1..5)
        becomes NaN.
    """
    lon_rows = []
    lat_rows = []
    for output in outputs:
        _, lon_v_pre, lat_v_pre = cor_to_speed(words_to_cor(output.item()))
        lon_rows.append(lon_v_pre)
        lat_rows.append(lat_v_pre)
    # Build each array once; concatenating inside the loop copied the whole
    # accumulated array on every element (quadratic in batch size).
    lons = np.array(lon_rows, dtype=float).reshape(-1, 2)
    lats = np.array(lat_rows, dtype=float).reshape(-1, 2)
    return lons, lats

In [None]:
import copy

device = 'cpu'
train_epoch = 100
T = 20                  # simulation time steps per sample during training
train_times = 0         # global count of training batches processed
min_test_loss = 10000
log_interval = 100
train_accs = []         # per-batch training accuracy
test_accs = []          # per-epoch validation accuracy
train_losses = []       # per-epoch mean training loss
test_losses = []        # per-epoch mean validation loss

lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
encoder = encoding.PoissonEncoder()
# NOTE(review): unused. The loss below is MSE between output firing rates and one-hot labels.
criterion = nn.CrossEntropyLoss()
# Initialised up front so the final load_state_dict cannot raise NameError if no
# epoch satisfies test_loss <= min_test_loss (e.g. a NaN validation loss every epoch).
best_model_state_dict = copy.deepcopy(model.state_dict())

for epoch in range(train_epoch):
    print("Epoch {}:".format(epoch))
    print("Training...")
    train_correct_sum = 0
    train_sum = 0
    train_loss = 0
    model.train()
    for batch_idx, (inputs, labels, labels_one_hot) in enumerate(train_data_loader):
        inputs = inputs.to(device)
        labels = labels.squeeze(1).to(device)
        labels_one_hot = labels_one_hot.to(device)

        optimizer.zero_grad()

        # Run the SNN for T steps, each on a fresh PoissonEncoder sample of the inputs.
        # out_spikes_counter sums the output-layer spikes, shape [batch_size, 900].
        for t in range(T):
            if t == 0:
                out_spikes_counter = model(encoder(inputs).float())
            else:
                out_spikes_counter += model(encoder(inputs).float())

        # Mean firing rate of each of the 900 output neurons over the T steps.
        out_spikes_counter_frequency = out_spikes_counter / T

        loss = F.mse_loss(out_spikes_counter_frequency, labels_one_hot)
        loss.backward()
        optimizer.step()
        # Reset neuron state: SNN neurons keep membrane state between forward calls.
        functional.reset_net(model)

        batch_correct = (out_spikes_counter_frequency.max(1)[1] == labels).float()
        train_correct_sum += batch_correct.sum().item()
        train_sum += inputs.shape[0]
        train_loss += loss.item() * inputs.shape[0]

        train_batch_accuracy = batch_correct.mean().item()
        if train_times % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Train Loss : {:.4f} Train batch Accuracy: {:.4f}'.format(
                                epoch,batch_idx*inputs.shape[0],len(x_train),batch_idx*inputs.shape[0]/len(trainset)*100,loss.item(),train_batch_accuracy*100))
        train_accs.append(train_batch_accuracy)

        train_times += 1
    train_accuracy = train_correct_sum / train_sum
    # Average over the samples actually trained on. The train loader drops its
    # last partial batch, so len(trainset) would overcount the denominator.
    train_loss = train_loss / train_sum
    train_losses.append(train_loss)

    print("Testing...")
    model.eval()
    with torch.no_grad():
        test_sum = 0
        correct_sum = 0
        test_loss = 0
        all_outputs_lons = np.zeros([0,2])
        all_outputs_lats = np.zeros([0,2])
        all_labels_lons = np.zeros([0,2])
        all_labels_lats = np.zeros([0,2])
        for batch_idx, (inputs, labels, labels_one_hot) in enumerate(val_data_loader):
            inputs = inputs.to(device)
            labels = labels.squeeze(1).to(device)
            labels_one_hot = labels_one_hot.to(device)
            for t in range(T):
                if t == 0:
                    out_spikes_counter = model(encoder(inputs).float())
                else:
                    out_spikes_counter += model(encoder(inputs).float())

            out_spikes_counter = out_spikes_counter / T
            predictions = out_spikes_counter.max(1)[1]
            correct_sum += (predictions == labels).float().sum().item()
            loss = F.mse_loss(out_spikes_counter, labels_one_hot)
            test_loss += loss.item() * inputs.shape[0]
            test_sum += inputs.shape[0]
            functional.reset_net(model)
            # Class indices are 0-based here; batch_speed_trans expects 1-based.
            outputs_lons, outputs_lats = batch_speed_trans(predictions + 1)
            labels_lons, labels_lats = batch_speed_trans(labels + 1)
            all_outputs_lons = np.concatenate([all_outputs_lons, outputs_lons], axis=0)
            all_outputs_lats = np.concatenate([all_outputs_lats, outputs_lats], axis=0)
            all_labels_lons = np.concatenate([all_labels_lons, labels_lons], axis=0)
            all_labels_lats = np.concatenate([all_labels_lats, labels_lats], axis=0)
        test_accuracy = correct_sum / test_sum
        # Average over samples actually evaluated (correct even if the loader drops a tail).
        test_loss = test_loss / test_sum
        test_losses.append(test_loss)
        test_accs.append(test_accuracy)
        if test_loss <= min_test_loss:
            min_test_loss = test_loss
            best_model_state_dict = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), r'\SNN_0424_speed(%d).tar' %(epoch+1))
    print("Epoch {}: train_loss={}, train_accuracy={}, test_loss={}, test_accuracy={}, min_test_loss={}, train_times={}".format(
                epoch,train_loss,train_accuracy,test_loss,test_accuracy,min_test_loss,train_times))
    # Speed errors over the whole validation set (accumulated all_* arrays),
    # not just the final batch.
    print(f"Lon_v: MSE：{mean_squared_error(all_outputs_lons, all_labels_lons)}")
    print(f"Lon_v: RMSE：{np.sqrt(mean_squared_error(all_outputs_lons, all_labels_lons))}")
    print(f"Lon_v: MAE：{mean_absolute_error(all_outputs_lons, all_labels_lons)}")
    print(f"Lat_v: MSE：{mean_squared_error(all_outputs_lats, all_labels_lats)}")
    print(f"Lat_v: RMSE：{np.sqrt(mean_squared_error(all_outputs_lats, all_labels_lats))}")
    print(f"Lat_v: MAE：{mean_absolute_error(all_outputs_lats, all_labels_lats)}")
    print()

torch.save(model.state_dict(), r'\SNN_0424_speed(Final).tar')
model.load_state_dict(best_model_state_dict)
torch.save(model.state_dict(), r'\SNN_0424_speed.tar')


In [None]:
device = 'cpu'
T = 50  # more simulation steps than training (T = 20) for a less noisy rate estimate
encoder = encoding.PoissonEncoder()

# map_location remaps saved tensors onto `device`, so loading works on this CPU-only setup.
model.load_state_dict(torch.load(r'\SNN_0424_speed.tar', map_location=device))
model.eval()
# Clear any neuron state left over from an interrupted earlier run.
functional.reset_net(model)

# The checkpoint was selected on val_data_loader, so report on the untouched test split.
eval_loader = test_data_loader
eval_set_size = len(x_test)

test_sum = 0
correct_sum = 0
test_loss = 0
all_outputs_lons = np.zeros([0,2])
all_outputs_lats = np.zeros([0,2])
all_labels_lons = np.zeros([0,2])
all_labels_lats = np.zeros([0,2])
# Inference only: no_grad avoids building an autograd graph across all T steps.
with torch.no_grad():
    for batch_idx, (inputs, labels, labels_one_hot) in enumerate(eval_loader):
        if batch_idx % 100 == 0:
            print(batch_idx * inputs.shape[0] / eval_set_size * 100)
        inputs = inputs.to(device)
        labels = labels.squeeze(1).to(device)
        labels_one_hot = labels_one_hot.to(device)
        for t in range(T):
            if t == 0:
                out_spikes_counter = model(encoder(inputs).float())
            else:
                out_spikes_counter += model(encoder(inputs).float())

        out_spikes_counter = out_spikes_counter / T
        predictions = out_spikes_counter.max(1)[1]
        correct_sum += (predictions == labels).float().sum().item()
        loss = F.mse_loss(out_spikes_counter, labels_one_hot)
        test_loss += loss.item() * inputs.shape[0]
        test_sum += inputs.shape[0]
        functional.reset_net(model)
        # Class indices are 0-based here; batch_speed_trans expects 1-based.
        outputs_lons, outputs_lats = batch_speed_trans(predictions + 1)
        labels_lons, labels_lats = batch_speed_trans(labels + 1)
        all_outputs_lons = np.concatenate([all_outputs_lons, outputs_lons], axis=0)
        all_outputs_lats = np.concatenate([all_outputs_lats, outputs_lats], axis=0)
        all_labels_lons = np.concatenate([all_labels_lons, labels_lons], axis=0)
        all_labels_lats = np.concatenate([all_labels_lats, labels_lats], axis=0)

test_accuracy = correct_sum / test_sum
test_loss = test_loss / test_sum
print(f"test_loss={test_loss}, test_accuracy={test_accuracy}")

In [None]:
# Regression errors of the decoded speeds, one group of three lines per axis.
for axis_name, predicted, actual in (("Lon_v", all_outputs_lons, all_labels_lons),
                                     ("Lat_v", all_outputs_lats, all_labels_lats)):
    mse = mean_squared_error(predicted, actual)
    print(f"{axis_name}: MSE：{mse}")
    print(f"{axis_name}: RMSE：{np.sqrt(mse)}")
    print(f"{axis_name}: MAE：{mean_absolute_error(predicted, actual)}")