
ZhejianglabNCRC/SPAIC


SPAIC


Spike-based artificial intelligence computing platform

The SPAIC platform is a network construction, forward simulation, and learning/training platform developed for spiking neural networks. It mainly consists of modules for front-end network modeling, back-end simulation and training, a model algorithm library, and data display and analysis.

Dependencies: PyTorch, NumPy

Installation

Currently, SPAIC uses PyTorch as its computation backend. If you also want to use CUDA, make sure a CUDA-enabled build of PyTorch is installed.

Tutorial documentation for SPAIC: https://spaic.readthedocs.io/en/latest/index.html

Install the latest stable version from PyPI:

pip install SPAIC

From GitHub:

git clone https://github.com/ZhejianglabNCRC/SPAIC.git
cd SPAIC
python setup.py install
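As noted above, CUDA support depends on the installed PyTorch build. One quick way to check which device is usable is `torch.cuda.is_available()`. The helper below is a minimal sketch (the name `pick_device` is hypothetical and not part of SPAIC); it also falls back to `'cpu'` if PyTorch cannot be imported at all.

```python
def pick_device() -> str:
    """Return 'cuda' if a CUDA-enabled PyTorch build sees a GPU, else 'cpu'.

    Hypothetical helper for checking the installation; not a SPAIC function.
    """
    try:
        import torch
    except ImportError:
        # PyTorch is not installed; SPAIC's computation backend needs it
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print("Selected device:", pick_device())
```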

If you have any further questions, please feel free to contact us:
Chaofei Hong hongchf@zhejianglab.com
Mengwen Yuan yuanmw@zhejianglab.com
Mengxiao Zhang mxzhangice@zju.edu.com

Front-end Network Modeling: Components and Functions

The platform builds networks mainly from five types of structural modules: Assembly, Connection, NeuronGroup, Node, and Network. Their functions are described below, and the relationships between them are shown in the following figure.

  • Assembly: an abstract class for neural network structure topology that can represent any network structure; all other network modules are subclasses of Assembly. An Assembly object has three dict attributes, _groups, _projections, and _connections, which store the neuron groups, projections, and connections inside the assembly. It also has the list attributes _supers, _input_connections, and _output_connections, which hold the parent assemblies that contain this assembly and the connections coming into and going out of it, respectively. As the main interface for network modeling, it provides the following modeling functions:

    • add_assembly(name, assembly): add a new assembly member to this neural assembly
    • del_assembly(assembly=None, name=None): delete an existing assembly member from this neural assembly
    • copy_assembly(name, assembly): copy an existing assembly structure and add the copy to this neural assembly
    • replace_assembly(old_assembly, new_assembly): replace an existing member of this assembly with a new neural assembly
    • merge_assembly(assembly): merge this neural assembly with another one to obtain a new neural assembly
    • select_assembly(assemblies, name=None): select some members of this neural assembly, together with the connections between them, to form a new neural assembly; the original assembly remains unchanged
    • add_connection(name, connection): add a connection between two groups of neurons inside the assembly
    • del_connection(connection=None, name=None): delete a connection inside the assembly
    • assembly_hide(): hide this neural assembly so that it does not take part in training, simulation, or display
    • assembly_show(): return this neural assembly from the hidden state to the normal state
  • Connection: a class for establishing connections between NeuronGroups, responsible for generating and managing different types of synaptic connections and their specific links. Key parameters for initializing a connection are listed below:

    • pre_assembly - the presynaptic neurons, i.e. the starting point of the connection (the previous layer)
    • post_assembly - the postsynaptic neurons, i.e. the end point of the connection (the next layer)
    • name - the name of the connection; giving it a meaningful name is recommended
    • link_type - the connection type, such as full, sparse, or convolutional connection
    • max_delay - the synaptic delay, i.e. the number of time steps by which signals from the presynaptic neurons are delayed before they reach the postsynaptic neurons
    • sparse_with_mask - enables or disables the mask filter used for sparse matrices
    • pre_var_name - the presynaptic output variable that the connection receives; by default this is the spike signal 'O' emitted by the presynaptic neurons
  • Projection: a class for establishing connections between Assemblies. It contains multiple specific Connections between NeuronGroups; these Connections can be defined manually by the user or generated automatically by a ConnectionPolicy.

  • NeuronGroup: a group of neurons, usually called a layer, that share the same neuron model, connection form, etc. Although it inherits from the Assembly class, its internal _groups, _projections, and _connections attributes are empty. Its key parameters are the number of neurons, the neuron model type, and the shape of the NeuronGroup.

  • Node: the object that passes input into and output out of the neural network, performing encoding and decoding: it converts input data into spikes, or converts spikes into output. Like NeuronGroup, its internal _groups and _connections attributes are empty.

  • Network: the top-level subclass of Assembly. All modules of a constructed neural network are contained in one Network object, which is also responsible for training, simulation, data interaction, and other network modeling work. In addition to the _groups and _connections attributes inherited from Assembly, it has attributes such as _monitors, _learners, _optimizers, and _backend, while its _supers, _input_connections, and _output_connections attributes are empty. Network provides the following interfaces for building and training networks:

    • set_backend: set the computation backend
    • build: compile the front-end network into a computation graph
    • set_runtime: set the simulation time
    • run: run a simulation
    • save_state: save the network weights
    • state_from_dict: load network weights
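To make the Connection parameters link_type='full', max_delay, and sparse_with_mask more concrete, here is a toy NumPy sketch of a fully connected link whose input is delayed by a fixed number of steps and whose weight matrix can be masked to be sparse. The class `ToyDelayedConnection`, its `density` argument, and the ring-buffer design are hypothetical illustrations, not SPAIC's Connection implementation; in particular, it applies one fixed delay to every synapse rather than per-synapse delays.

```python
import numpy as np


class ToyDelayedConnection:
    """Toy full connection with one fixed delay and an optional sparsity mask.

    Hypothetical illustration of the link_type / max_delay / sparse_with_mask
    ideas; this is not SPAIC's Connection class.
    """

    def __init__(self, n_pre, n_post, delay, density=None, seed=0):
        rng = np.random.default_rng(seed)
        # Full connection: every presynaptic neuron projects to every postsynaptic one
        self.weight = rng.normal(0.0, 0.1, size=(n_post, n_pre))
        self.mask = None
        if density is not None:
            # Keep roughly `density` of the synapses and zero out the rest
            self.mask = rng.random((n_post, n_pre)) < density
            self.weight *= self.mask
        # Ring buffer holding the presynaptic spikes of the last `delay` + 1 steps
        self.buffer = np.zeros((delay + 1, n_pre))
        self.t = 0

    def step(self, pre_spikes):
        """Store this step's spikes; return the input caused by spikes from `delay` steps ago."""
        size = self.buffer.shape[0]
        self.buffer[self.t % size] = pre_spikes
        delayed = self.buffer[(self.t + 1) % size]
        self.t += 1
        return self.weight @ delayed


conn = ToyDelayedConnection(n_pre=3, n_post=2, delay=2)
spike = np.array([1.0, 0.0, 0.0])
silent = np.zeros(3)
for t, pre in enumerate([spike, silent, silent, silent]):
    # The spike sent at t=0 reaches the postsynaptic side at t=2
    print(t, conn.step(pre))
```

A ring buffer of delay + 1 slots keeps memory constant regardless of the simulation length, because each step overwrites the oldest slot.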

Typical Use Case

Simulation and training with SPAIC mainly involve the following steps: 1) importing data or the environment, 2) selecting the parameters of the trainer's training process, 3) constructing the model (including input and output nodes, neuron groups, network connections, network topology, learning algorithms, data recorders, and other units), and 4) running the neuron simulation or training, then analyzing and saving the model data.

Import the SPAIC library

import spaic

Set training and simulation parameters

run_time = 200.0
bat_size = 100

Import training dataset

# Create Dataset (raw string so the Windows backslashes are not read as escapes)
root = r'D:\Datasets\MNIST'
dataset = spaic.MNIST(root, is_train=False)

# Create DataLoader
dataloader = spaic.Dataloader(dataset, batch_size=bat_size, shuffle=True, drop_last=True)
n_batch = dataloader.batch_num
>> Dataset loaded
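As a worked check of what n_batch should hold: the MNIST test split (is_train=False) has 10,000 images, so with bat_size = 100 and drop_last=True one expects 100 batches per epoch. The helper below is a hypothetical sketch of the usual rule, assuming spaic.Dataloader follows the same drop_last convention as PyTorch's DataLoader (an incomplete final batch is discarded when drop_last is true, kept otherwise).

```python
def num_batches(n_samples, batch_size, drop_last):
    """Number of batches a loader yields for n_samples items (hypothetical helper)."""
    full, rest = divmod(n_samples, batch_size)
    # An incomplete final batch only counts when it is not dropped
    return full if drop_last or rest == 0 else full + 1


print(num_batches(10_000, 100, drop_last=True))   # 100
print(num_batches(10_050, 100, drop_last=True))   # 100 (last 50 samples are dropped)
print(num_batches(10_050, 100, drop_last=False))  # 101
```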

Network model construction

A model can be built in two ways: first, by class inheritance, as with PyTorch's nn.Module, defining the network inside the __init__ function; second, with a with statement, as in Nengo. In addition, existing model structures from the model algorithm library can be imported into the modeling process.
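The with style works because the network acts as a context manager that components register themselves with while the block is active. The toy classes below are a hypothetical sketch of that general pattern (the one familiar from Nengo); `ToyNetwork` and `ToyComponent` are not SPAIC classes, and SPAIC's actual registration mechanism may differ.

```python
class ToyNetwork:
    """Toy container illustrating how a `with` block can collect components."""

    _active = []  # stack of networks currently entered with `with`

    def __init__(self):
        self.components = []

    def __enter__(self):
        ToyNetwork._active.append(self)
        return self

    def __exit__(self, exc_type, exc, tb):
        ToyNetwork._active.pop()
        return False  # do not suppress exceptions raised inside the block


class ToyComponent:
    """A component that registers itself with the innermost active ToyNetwork."""

    def __init__(self, name):
        self.name = name
        if ToyNetwork._active:
            ToyNetwork._active[-1].components.append(self)


net = ToyNetwork()
with net:
    layer1 = ToyComponent("layer1")
    layer2 = ToyComponent("layer2")
outside = ToyComponent("outside")  # created after the block, so not registered
print([c.name for c in net.components])  # ['layer1', 'layer2']
```

Using a stack rather than a single global slot lets with blocks nest, with each component going to the innermost network.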

Modeling Method 1: Class Inheritance Form

class ExampleNet(spaic.Network):
    def __init__(self):
        super(ExampleNet, self).__init__()

        # Create an input node and select the input encoding method
        self.input = spaic.Encoder(dataloader, encoding='latency')

        # Create NeuronGroups, select the neuron model, and set neuron parameter values
        self.layer1 = spaic.NeuronGroup(100, model='clif')
        self.layer2 = spaic.NeuronGroup(10, model='clif')

        # Create connections between the NeuronGroups
        self.connection1 = spaic.Connection(self.input, self.layer1, link_type='full')
        self.connection2 = spaic.Connection(self.layer1, self.layer2, link_type='full')

        # Create an output node and select the output decoding method
        self.output = spaic.Decoder(decoding='spike_counts', target=self.layer2)

        # Create a state monitor, which can record the state variables of various objects
        self.monitor1 = spaic.StateMonitor(self.layer1, 'V')

        # Add the learning algorithm and select the network structure to be trained
        self.learner1 = spaic.STCA(0.5, self)

        # Add the optimization algorithm
        self.optim = spaic.Adam(lr=0.01, schedule='StepLR', maxstep=1000)

# Initialize the ExampleNet network object
Net = ExampleNet()
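The example above selects encoding='latency' for the input node and decoding='spike_counts' for the output node. The NumPy functions below sketch one common form of each idea: latency coding, where a stronger input fires a single earlier spike, and spike-count decoding, where the output neuron with the most spikes gives the predicted class. The function names and the linear intensity-to-time mapping are assumptions for illustration; SPAIC's own encoder and decoder may use different formulas.

```python
import numpy as np


def latency_encode(x, steps):
    """Encode intensities in [0, 1] as one spike each; stronger inputs fire earlier.

    x: sequence of length n. Returns a spike train of shape (steps, n).
    Inputs equal to 0 never fire. Linear time mapping assumed for illustration.
    """
    x = np.clip(np.asarray(x, dtype=float), 0.0, 1.0)
    spikes = np.zeros((steps, x.size))
    active = x > 0
    # Intensity 1.0 fires at step 0; intensities near 0 fire near the last step
    fire_step = np.round((1.0 - x) * (steps - 1)).astype(int)
    spikes[fire_step[active], np.nonzero(active)[0]] = 1.0
    return spikes


def spike_count_decode(spikes):
    """Decode an output spike train of shape (steps, n_classes) by counting spikes."""
    counts = spikes.sum(axis=0)
    return counts, int(np.argmax(counts))


train = latency_encode([1.0, 0.5, 0.0], steps=11)  # spikes at steps 0 and 5; third input silent
out = np.zeros((10, 3))
out[[1, 4, 7], 2] = 1
out[3, 0] = 1
counts, label = spike_count_decode(out)
print(counts, label)  # [1. 0. 3.] 2
```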

Modeling method 2: Using "with"

# Initialize the object of the basic network class
Net = spaic.Network()

# Create the network structure by defining network components inside the with block
with Net:
    # Create an input node and select the input encoding method
    input1 = spaic.Encoder(dataloader, encoding='latency')

    # Create NeuronGroups, select the neuron model, and set neuron parameter values
    layer1 = spaic.NeuronGroup(100, model='clif')
    layer2 = spaic.NeuronGroup(10, model='clif')

    # Create connections between the NeuronGroups
    connection1 = spaic.Connection(input1, layer1, link_type='full')
    connection2 = spaic.Connection(layer1, layer2, link_type='full')

    # Create an output node and select the output decoding method
    output = spaic.Decoder(decoding='spike_counts', target=layer2)

    # Create a state monitor, which can record the state variables of various objects
    monitor1 = spaic.StateMonitor(layer1, 'V')

    # Add the learning algorithm and select the network structure to be trained
    learner1 = spaic.STCA(0.5, Net)

    # Add the optimization algorithm
    optim = spaic.Adam(lr=0.01, schedule='StepLR', maxstep=1000)

Modeling method 3: importing a model from the model library and modifying it with assembly functions

from spaic.Library import ExampleNet
Net = ExampleNet()
# neuron parameters
neuron_param = {
    'tau_m': 8.0,
    'V_th': 1.5,
}
# New NeuronGroups
layer3 = spaic.NeuronGroup(100, model='lif', param=neuron_param)
layer4 = spaic.NeuronGroup(100, model='lif', param=neuron_param)

# Add a new member to the assembly
Net.add_assembly('layer3', layer3)
# Delete a member that already exists in the assembly
Net.del_assembly(Net.layer3)
# Copy an existing assembly structure and add the copy to this assembly
Net.copy_assembly('net_layer', ExampleNet())
# Replace an existing member of the assembly with a new assembly
Net.replace_assembly(Net.layer1, layer3)
# Merge this assembly with another assembly to get a new assembly
Net2 = ExampleNet()
Net.merge_assembly(Net2)
# Connect two members inside the assembly
con = spaic.Connection(Net.layer2, Net.net_layer, link_type='full')
Net.add_connection('con3', con)
# Take out some members of this assembly, together with their connections, as a new assembly
Net3 = Net.select_assembly([Net.layer2, Net.net_layer])
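The neuron_param dict above sets the membrane time constant tau_m and the firing threshold V_th of the 'lif' neurons. The sketch below integrates textbook leaky integrate-and-fire dynamics, dV/dt = (-V + I) / tau_m with a reset to 0 after each spike, using those two values. The step size dt = 0.1, the reset value, and the constant-current input are assumptions made here, and SPAIC's 'lif' and 'clif' models may differ in detail.

```python
import numpy as np


def simulate_lif_group(current, tau_m=8.0, v_th=1.5, dt=0.1, v_reset=0.0):
    """Euler-integrate a group of leaky integrate-and-fire neurons (sketch).

    current: array of shape (steps, n_neurons) with the input current per step.
    Returns (voltages, spikes), both of shape (steps, n_neurons).
    """
    steps, n = current.shape
    v = np.full(n, v_reset)
    voltages = np.zeros((steps, n))
    spikes = np.zeros((steps, n))
    for t in range(steps):
        # Leak toward 0 and integrate the input: dV/dt = (-V + I) / tau_m
        v = v + dt / tau_m * (-v + current[t])
        fired = v >= v_th
        spikes[t] = fired
        v = np.where(fired, v_reset, v)  # reset the neurons that fired
        voltages[t] = v
    return voltages, spikes


# 3 neurons driven by constant currents below, just above, and well above V_th
drive = np.tile([1.0, 1.6, 3.0], (2000, 1))  # 2000 steps of the assumed dt = 0.1
_, spike_train = simulate_lif_group(drive)
print(spike_train.sum(axis=0))  # the first neuron never fires; stronger drive fires more
```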

Choose a backend and compile the network

backend = spaic.Torch_Backend()
sim_name = backend.backend_name
Net.build(backend)

Start training

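import torch
import torch.nn.functional as F
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

for epoch in range(100):
    print("Start training, epoch", epoch)
    train_loss = 0
    train_acc = 0
    pbar = tqdm(total=n_batch)
    for i, item in enumerate(dataloader):
        # forward
        data, label = item
        Net.input(data)
        Net.output(label)
        Net.run(run_time)
        output = Net.output.predict
        # Standardize the decoder output before computing the loss
        output = (output - torch.mean(output).detach()) / (torch.std(output).detach() + 0.1)
        label = torch.tensor(label, device=device)
        batch_loss = F.cross_entropy(output, label)

        # backward
        Net.learner1.optim_zero_grad()
        batch_loss.backward()
        Net.learner1.optim_step()

        # Accumulate loss and accuracy statistics
        train_loss += batch_loss.item()
        train_acc += (output.argmax(dim=1) == label).sum().item()
        pbar.update(1)
    pbar.close()
    print("loss:", train_loss / n_batch, "acc:", train_acc / (n_batch * bat_size))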
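The loss step in the loop first standardizes the decoder output with one scalar mean and standard deviation over the whole batch (adding 0.1 to the deviation to avoid division by zero) and then applies cross-entropy. The NumPy function below is a sketch that mirrors that computation outside PyTorch, for inspecting it on small arrays; the function name is hypothetical. It uses ddof=1 because torch.std computes the unbiased (sample) deviation by default.

```python
import numpy as np


def standardized_cross_entropy(spike_counts, labels, eps=0.1):
    """Cross-entropy on batch-standardized outputs (NumPy sketch of the loop's loss).

    spike_counts: (batch, n_classes) decoder output; labels: (batch,) integer classes.
    """
    x = (spike_counts - spike_counts.mean()) / (spike_counts.std(ddof=1) + eps)
    x = x - x.max(axis=1, keepdims=True)  # shift rows for a numerically stable softmax
    log_probs = x - np.log(np.exp(x).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


counts = np.array([[9.0, 1.0, 0.0], [2.0, 8.0, 1.0]])
print(standardized_cross_entropy(counts, np.array([0, 1])))  # lower: true classes spike most
print(standardized_cross_entropy(counts, np.array([2, 2])))  # higher: true class spikes least
```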

Single simulation run

Net.run(run_time=run_time)

Plotting the results

from matplotlib import pyplot as plt
plt.plot(Net.monitor1.times, Net.monitor1.values[0,0,:])
plt.show()

Save the network

Net.save_state(filename='TestNet')