TBC

In [1]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"
import torch
import torch.nn as nn
import spikingjelly
import torchvision
import torch.utils.data as data
from tqdm import tqdm
from spikingjelly.activation_based import neuron, layer, learning, surrogate, encoding, functional
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
import matplotlib.pyplot as plt
import numpy as np
from spikingjelly.datasets.n_mnist import NMNIST
import matplotlib.pyplot as plt
import numpy as np

# 指定数据集的根目录
root = '/home/lain/imperial2022/FFRSDTP/download'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 创建数据集实例
transform = Compose([
    Lambda(lambda x: x.reshape(x.shape[0], -1))])
train_dataset = NMNIST(root, train=True, data_type='frame', frames_number=15, split_by='number', transform=transform)
test_dataset = NMNIST(root, train=False, data_type='frame', frames_number=15, split_by='number', transform=transform)
train_data_loader = DataLoader(dataset=train_dataset, batch_size=50, shuffle=True, drop_last=False, num_workers=0)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=50, shuffle=True, drop_last=False, num_workers=0)

2023-08-20 23:59:25.417622: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-08-20 23:59:25.541663: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-20 23:59:25.566411: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-08-20 23:59:26.008457: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: li

The directory [/home/lain/imperial2022/FFRSDTP/download/frames_number_15_split_by_number] already exists.
The directory [/home/lain/imperial2022/FFRSDTP/download/frames_number_15_split_by_number] already exists.


In [2]:
class LayerOfLain():

  """
  This class is used to instantiate the layer object in the SFF algorithm, 
  provide a training function that is uniformly called by the network during training, 
  and perform local training independently.

  Member variables:
  threshold_pos (float): used to determine whether goodness_pos is large enough and directly participate in training.
  threshold_neg (float): used to determine whether goodness_neg is small enough and directly participate in training.
  min_weight (float): Used to provide a hard bound when calling the STDP module.
  max_weight (float): Used to provide a hard bound when calling the STDP module.
  encoder (encoder): Poisson encoder, which converts traditional data into pulse shape data that conforms to Poisson distribution.
  time_step (int): The time step length of the simulation for each data sample.
  learning_rate (float): learning rate.
  pre_time_au (float): Spiking neural network hyperparameters, time constants related to membrane potential decay and STDP.
  post_time_au (float): spiking neural network hyperparameters, time constants related to membrane potential decay and STDP.
  learner (MSTDPLearner): reward-modulated STDP learner, called when the STDP module of SFF is started.
  """
  
  def __init__(self, N_input, N_output, pre_time_au = 2.,
               post_time_au = 100., time_step = 15,
               batch_size = 50, learning_rate = 0.00003, threshold_both = 0.88):
    self.single_net = nn.Sequential(
        #layer.Flatten(start_dim=2),
        layer.Linear(N_input, N_output, bias=False),
        neuron.IFNode(surrogate_function=surrogate.ATan())
    ).to(device)
    self.threshold_pos = threshold_both
    self.threshold_neg = threshold_both
    self.min_weight = -1.
    self.max_weight = 1.
    self.encoder = encoding.PoissonEncoder()
    self.time_step = time_step
    self.learning_rate = learning_rate
    self.pre_time_au = pre_time_au
    self.post_time_au = post_time_au
    self.batch_size = batch_size
    self.N_output = N_output
    self.encoder = encoding.PoissonEncoder()
    self.learner = learning.MSTDPLearner(step_mode='s', batch_size=self.batch_size,
                     synapse=self.single_net[0], sn=self.single_net[1],
                     tau_pre=self.pre_time_au, tau_post=self.post_time_au,
                     )
    self.learner.disable()

  def goodness_cal(self, output):
    goodness = output.pow(2).mean(1)
    return goodness

  def reward_from_goodness(self, output, pos_flag):
    alpha_pos = 1.
    alpha_neg = 1.
    goodness = output.pow(2).mean(1)
    if(pos_flag==True):
      return alpha_pos * (goodness - self.threshold_pos)
    else:
      return alpha_neg * (self.threshold_neg - goodness)


  def forward_with_training(self, input_pos, input_neg, insight_pos, insight_neg, stdpflag = True):

    weight_opter_stdp = torch.optim.SGD(self.single_net.parameters(), lr=0.01, momentum=0.)
    weight_opter_surrogate = torch.optim.Adam(self.single_net.parameters(), lr=self.learning_rate)

    if(stdpflag == True):
      with torch.no_grad():
          self.learner.enable()
          reward_pos = 0.
          for t in range(self.time_step):
              # Positive update
              reward_pos = self.reward_from_goodness(self.single_net(input_pos[t]), True)

              weight_opter_stdp.zero_grad()
              self.learner.step(reward_pos, on_grad=True)
              weight_opter_stdp.step()
          self.learner.reset()

          reward_neg = 0.
          for t3 in range(self.time_step):
              # Negative update
              reward_neg = self.reward_from_goodness(self.single_net(input_neg[t3]), False)

              weight_opter_stdp.zero_grad()
              self.learner.step(reward_neg, on_grad=True)
              weight_opter_stdp.step()
          self.learner.reset()
          torch.cuda.empty_cache()
          self.learner.disable()
      functional.reset_net(self.single_net)
      torch.cuda.empty_cache()

    goodness_pos = 0.
    for t in range(self.time_step):
        # Positive update
        goodness_pos += self.goodness_cal(self.single_net(input_pos[t]))
    
    goodness_pos = goodness_pos / self.time_step
    #print(goodness_pos)
    goodness_neg = 0.
    for t3 in range(self.time_step):
        # Negative update
        goodness_neg += self.goodness_cal(self.single_net(input_neg[t3]))

    goodness_neg = goodness_neg / self.time_step
    
    combined_pos = self.threshold_pos - goodness_pos# - insight_pos
    combined_neg = - self.threshold_neg + goodness_neg# - insight_neg

    loss_mixed = torch.log(torch.exp(torch.cat([combined_pos, combined_neg])) + 1).mean()
    weight_opter_surrogate.zero_grad()
    loss_mixed.backward()
    weight_opter_surrogate.step()
    functional.reset_net(self.single_net)

  def forward_withOUT_training(self, input_pos, input_neg):
    total_output_pos_list = []
    total_output_neg_list = []
    for t2 in range(self.time_step):
      total_output_pos_list.append((self.single_net(input_pos[t2])).detach())
      total_output_neg_list.append((self.single_net(input_neg[t2])).detach())

    total_output_pos = torch.stack(total_output_pos_list, dim=0)
    total_output_neg = torch.stack(total_output_neg_list, dim=0)
    return total_output_pos, total_output_neg

  def forward_withOUT_training_single(self, input_pos, firstflag):
    total_output_pos_list = []
    if(firstflag==0):
      for t2 in range(self.time_step):
        total_output_pos_list.append(self.single_net(input_pos[t2]).detach())
      total_output_pos = torch.stack(total_output_pos_list, dim=0)
    else:
      for t2 in range(self.time_step):
        total_output_pos_list.append(self.single_net(input_pos[t2]).detach())
      total_output_pos = torch.stack(total_output_pos_list, dim=0)


    return total_output_pos

In [3]:
def label_encoder(input, label):
    labeled_input = input.clone()
    for t in range(input.shape[0]):
      start_index = 34*34
      end_index = start_index + 10
      labeled_input[t][:, start_index:end_index] *= 0.0
      labeled_input[t][range(input[t].shape[0]), start_index + label] = 2 * input.max()
      labeled_input[t][:, -34:-24] *= 0.0
      labeled_input[t][range(input[t].shape[0]), -34+label] = 2 * input.max()
    return labeled_input

def poisson_iter(input, t):
    batch_size, dim = input.shape
    output = torch.zeros((t, batch_size, dim))
    encoder = encoding.PoissonEncoder()
    for i in range(t):
        encoden_input = encoder(input)
        output[i] = encoden_input
    return output

class NetOfLain(torch.nn.Module):
    
    """
    This class is used to instantiate the net object in the SFF algorithm, coordinate and call the training functions of each layer during training, so that they can perform local training independently.

    Member variables:
    lain_layers (LayerOfLain list): used to store layers for constructing SFF spiking neural network.
    insight_pos (float): The key constant for SFF to realize layer collaboration, which is the sum of the goodness of each layer after positive data propagation.
    """
    
    def __init__(self, lain_dimension):
        super().__init__()
        self.lain_layers = []
        self.insight_pos = 0.
        self.insight_neg = 0.
        for d in range(len(lain_dimension) - 1):
            if(d == 0):
              layer = LayerOfLain(lain_dimension[d], lain_dimension[d + 1], pre_time_au = 2., post_time_au = 100.)
              self.lain_layers.append(layer)
            else:
              layer = LayerOfLain(lain_dimension[d], lain_dimension[d + 1], pre_time_au = 2., post_time_au = 100., learning_rate=0.0001, threshold_both=0.9)
              self.lain_layers.append(layer)

    def network_train_layers(self, train_data_loader, epo):
      torch.cuda.empty_cache()
      for i, lain_layer in enumerate(self.lain_layers):
        print('training layer', i, '...')
        for features, labels in tqdm(train_data_loader):
          if(epo > i*1):
            break
          torch.cuda.empty_cache()
          features, labels = features.to(device), labels.to(device)
          features = features.transpose(0, 1)
          features_pos = label_encoder(features, labels)
          rnd = torch.randperm(features.size(1))
          features_neg = label_encoder(features, labels[rnd])
          #features_pos = poisson_iter(features_pos, lain_layer.time_step)
          features_pos = features_pos.to(device)
          #features_neg = poisson_iter(features_neg, lain_layer.time_step)
          features_neg = features_neg.to(device)
          del features, labels
          torch.cuda.empty_cache()
          #features_pos = features_pos.transpose(0, 1)
          #features_neg = features_neg.transpose(0, 1)
          #self.insight_pos = self.network_collaboration(features_pos)
          #self.insight_neg = self.network_collaboration(features_neg)
          positive_hidden, negative_hidden = features_pos, features_neg
          if(i > 0) :
            for o in range(i):
              positive_hidden, negative_hidden = self.lain_layers[o].forward_withOUT_training(positive_hidden, negative_hidden)
              positive_hidden = positive_hidden*10
              negative_hidden = negative_hidden*10
              functional.reset_net(self.lain_layers[o].single_net)
          torch.cuda.empty_cache()
          if(i==0):
            lain_layer.forward_with_training(positive_hidden, negative_hidden, self.insight_pos, self.insight_neg, stdpflag=False)
          else:
            lain_layer.forward_with_training(positive_hidden, negative_hidden, self.insight_pos, self.insight_neg, stdpflag=True)

    def network_predict(self, input):
      every_labels_goodness = []
      for label in range(10):
        hidden = label_encoder(input, label)
        #hidden = poisson_iter(hidden, 50)
        hidden = hidden.to(device)
        torch.cuda.empty_cache()
        every_layer_goodness = []
        for p, lain_layer in enumerate(self.lain_layers):
          hidden = lain_layer.forward_withOUT_training_single(hidden, p)
          goodnesstem = []
          for t in range(lain_layer.time_step):
            goodnesstem.append((hidden[t].pow(2).mean(1)).unsqueeze(0))
          every_layer_goodness += [(torch.cat(goodnesstem, dim=0)).sum(0)]
        #print(every_layer_goodness[0])
        every_labels_goodness += [sum(every_layer_goodness).unsqueeze(1)]
        del hidden
        #for lain_layer in self.lain_layers:
          #functional.reset_net(lain_layer.single_net)
        torch.cuda.empty_cache()
      every_labels_goodness = torch.cat(every_labels_goodness, 1)
      #print(every_labels_goodness.argmax(1))
      return every_labels_goodness.argmax(1)

    def network_collaboration(self, input):
        hidden = input.clone()
        every_layer_goodness = []
        for p, lain_layer in enumerate(self.lain_layers):
          hidden = lain_layer.forward_withOUT_training_single(hidden, p)
          goodnesstem = []
          for t in range(lain_layer.time_step):
            goodnesstem.append((hidden[t].pow(2).mean(1)).unsqueeze(0))
          every_layer_goodness += [(torch.cat(goodnesstem, dim=0)).sum(0)]
          functional.reset_net(lain_layer.single_net)
        del hidden
        torch.cuda.empty_cache()
        return sum(every_layer_goodness)

In [4]:
if __name__ == "__main__":
    torch.manual_seed(1000)
    torch.cuda.empty_cache()
    alice = NetOfLain([2312, 1000, 500])
    for epo in range(2):
      print("Epoch:", epo)
      torch.cuda.empty_cache()
      alice.network_train_layers(train_data_loader, epo)
      countT = 0.
      lossT = 0.
      for test_x, test_y in test_data_loader:
        test_x, test_y = test_x.to(device), test_y.to(device)
        test_x = test_x.transpose(0, 1)
        lossT += 1.0 - alice.network_predict(test_x).eq(test_y).float().mean().item()
        countT += 1
        for lain_layer in alice.lain_layers:
          functional.reset_net(lain_layer.single_net)
      print('test error:', lossT / countT)

Epoch: 0
training layer 0 ...


100%|██████████| 1200/1200 [01:34<00:00, 12.71it/s]


training layer 1 ...


100%|██████████| 1200/1200 [04:54<00:00,  4.07it/s]


test error: 0.4156000143289566
Epoch: 1
training layer 0 ...


  0%|          | 0/1200 [00:00<?, ?it/s]


training layer 1 ...


100%|██████████| 1200/1200 [04:53<00:00,  4.09it/s]


test error: 0.41490001454949377
