Skip to content
Synaptic Neural Network
Branch: master
Clone or download

Latest commit

Fetching latest commit…
Cannot retrieve the latest commit at this time.


Type Name Latest commit message Commit time
Failed to load latest commit information.

SynaNN: Synaptic Neural Network


Synapses play an important role in biological neural networks. They're joint points of neurons where learning and memory happened. Inspired by the synapse research of neuroscience, we found that a simple model can describe the key property of a synapse. By combining deep learning, we expect to build ultra large scale neural networks to solve real-world AI problems. At the same time, we want to create an explainable neural network model to better understand what an AI model doing instead of a black box solution.

Based on the analysis of excitatory and inhibitory channels of synapses, we proposed a synapse model called Synaptic Neural Network (SynaNN) where a synapse function is represented as the inputs of probabilities of both excitatory and inhibitory channels along with their joint probability as the output. SynaNN is constructed by synapses and neurons. A synapse graph is a connection of synapses. In particular, a synapse tensor is fully connected synapses from input neurons to output neurons. Synapse learning can work with gradient descent and backpropagation algorithms. SynaNN can be applied to construct MLP, CNN, and RNN models.

SynaNN Key Features

  • Synapses are joint points of neurons with electronic and chemical functions, location of learning and memory
  • A synapse function is nonlinear, log concavity, infinite derivative in surprisal space (negative log space)
  • Surprisal synapse is Bose-Einstein distribution with coefficient as negative chemical potential
  • SynaNN graph & tensor, surprisal space, commutative diagram, topological conjugacy, backpropagation algorithm
  • SynaMLP, SynaCNN, SynaRNN are models for various neural network architecture
  • Synapse block can be embedded into other neural network models
  • Swap equation links between swap and odds ratio for healthcare, fin-tech, and insurance applications

One challenge was to represent the links of synapses as tensors so we can apply the neural network framework such as tensorflow for deep learning. A key step is to construct a Synapse layer to replace Dense layer in Keras so we can embed synapse in deep learning neural network. This has been done by defining a custom model of Synapse.

Synapse pluses BatchNormalization can greatly speed up the processing to achieve an accuracy goal. We can think of a synapse as a statistical distribution computing unit while batch normalization makes evolution faster.

SynaNN for MNIST in Pytorch

# SynaNN for Image Classification with MNIST Dataset in Pytorch
# Copyright (c) 2019, Chang LI. All rights reserved.
# Open source, MIT License.

# header
from __future__ import print_function

import math
import argparse

import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import Module

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

train_losses = []
train_counter = []
test_counter = []
test_losses = []

class Synapse(nn.Module):
    r"""Applies a synapse function to the incoming data.`

        in_features:  size of each input sample
        out_features: size of each output sample
        bias:         if set to ``False``, the layer will not learn an additive bias.
                      Default: ``True``

        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
             additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
             are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

        weight: the learnable weights of the module of shape
            	:math:`(\text{out\_features}, \text{in\_features})`. The values are
            	initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            	:math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`


        >>> m = Synapse(64, 64)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super(Synapse, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            self.register_parameter('bias', None)

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
    # synapse core
    def forward(self, input):
        # shapex = matrix_diag(input)
        diagx = torch.stack(tuple(t.diag() for t in torch.unbind(input,0)))
        shapex = diagx.view(-1, self.out_features)
        betax = torch.log1p(-shapex @ self.weight.t())
        row = betax.size()
        allone = torch.ones(int(row[0]/self.out_features), row[0])
        if torch.cuda.is_available():
          allone = allone.cuda()
        return torch.exp(torch.log(input) + allone @ betax) # + self.bias)    

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # fully connected with synapse function
        self.fc1 = nn.Linear(320, 64)
        self.fcn = Synapse(64,64)
        self.fcb = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x,
        x = F.softmax(x, dim=1)
        # fcn is the output of synapse
        x = self.fcn(x)
        # fcb is the batch no)rmal 
        x = self.fcb(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target =,
        output = model(data)
        loss = F.nll_loss(output, target)
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
                (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
  , 'model.pth')
  , 'optimizer.pth')

def test(args, model, device, test_loader):
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target =,
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
def main():
    # Training settings
    import easydict
    args = easydict.EasyDict({
      "batch_size": 100,
      "test_batch_size": 100,
      "epochs": 200,
      "lr": 0.012,
      "momentum": 0.5,
      "no_cuda": False,
      "seed": 5,
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda:0" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader =
        datasets.MNIST('../data', train=True, download=True,
                           transforms.Normalize((0.1307,), (0.3081,))
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader =
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.Normalize((0.1307,), (0.3081,))
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(),, momentum=args.momentum)
    # optimizer = optim.Adam(model.parameters(),
    test_counter = [i*len(train_loader.dataset) for i in range(args.epochs)]
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
    # draw curves
    fig = plt.figure()
    plt.plot(train_counter, train_losses, color='blue')
    plt.scatter(test_counter, test_losses, color='red')
    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
    plt.xlabel('number of training examples seen')
    plt.ylabel('negative log likelihood loss')
if __name__ == '__main__':


"SynaNN: A Synaptic Neural Network and Synapse Learning"

You can’t perform that action at this time.