In [6]:
# Execute this code block to install dependencies when running on colab
try:
    import torch
except:
    from os.path import exists
    from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
    accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

    !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision

try: 
    import torchbearer
except:
    !pip install torchbearer

In [8]:
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torchvision.datasets import MNIST

# fix random seed for reproducibility
seed = 7
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import numpy as np
np.random.seed(seed)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# flatten 28*28 images to a 784 vector for each image
transform = transforms.Compose([
    transforms.ToTensor(),  # convert to tensor
    transforms.Lambda(lambda x: x.view(-1))  # flatten into vector
])


trainset = MNIST(".", train=True, download=True, transform=transform)
testset = MNIST(".", train=False, download=True, transform=transform)

trainset.data = trainset.data[0:27105]
trainset.targets = trainset.targets[0:27105]

trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=128, shuffle=True)

(len(trainloader.dataset))


27105

In [9]:
class ANN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) 
        self.fc2 = nn.Linear(hidden_size, hidden_size)  
        self.fc3 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = F.dropout(out,p=0.2)
        out = self.fc2(out)
        out = F.relu(out)
        out = F.dropout(out,p=0.2)
        out = self.fc3(out)
        
        if not self.training:
            out = F.softmax(out, dim=1)
        return out

# build the model
model = ANN(784, 1200, 10)

In [10]:
# define the loss function and the optimiser
loss_function = nn.CrossEntropyLoss()
optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.5)

# Construct a trial object with the model, optimiser and loss.
# Also specify metrics we wish to compute.
trial = torchbearer.Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy']).to(device)

# Provide the data to the trial
trial.with_generators(trainloader, test_generator=testloader)

# Run 10 epochs of training
trial.run(epochs=2)

# test the performance
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

HBox(children=(FloatProgress(value=0.0, description='0/2(t)', max=212.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='1/2(t)', max=212.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='0/1(e)', max=79.0, style=ProgressStyle(description_width=…


{'test_loss': 1.566105842590332, 'test_acc': 0.9332999587059021}


In [11]:
print(model.fc1.weight)
print(model.fc2.weight)
print(model.fc3.weight)

Parameter containing:
tensor([[ 2.4945e-03, -2.1514e-02,  1.1372e-02,  ...,  2.3490e-02,
         -3.4548e-02,  2.3653e-02],
        [ 3.3076e-05, -2.2735e-02, -1.0311e-02,  ...,  1.7026e-02,
         -1.3917e-02, -2.0039e-02],
        [-1.6752e-02, -1.9683e-02, -1.9263e-02,  ...,  6.4771e-03,
         -1.4055e-02,  1.4966e-02],
        ...,
        [ 2.5927e-02,  4.1935e-03, -2.5759e-02,  ..., -3.3507e-02,
         -2.5739e-02, -4.2646e-04],
        [ 3.1898e-02, -1.6602e-02,  3.0114e-02,  ...,  8.6647e-04,
         -2.3524e-02, -8.1189e-03],
        [-1.3606e-02, -2.0262e-02, -2.4811e-02,  ...,  2.5272e-03,
          2.9828e-02, -3.3756e-02]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[-0.0130, -0.0204,  0.0043,  ..., -0.0198,  0.0176,  0.0024],
        [ 0.0230,  0.0019,  0.0015,  ..., -0.0249, -0.0057,  0.0088],
        [-0.0150, -0.0297,  0.0174,  ...,  0.0089, -0.0118,  0.0136],
        ...,
        [-0.0053, -0.0224, -0.0144,  ..., -0.0270, -0.0239, -0.0

In [12]:
def convertANNtoSNN(model):
    
    W1 = torch.clone(model.fc1.weight.data)
    W2 = torch.clone(model.fc2.weight.data)
    W3 = torch.clone(model.fc3.weight.data)
    
    SNN = ANN(W1.shape[1],W1.shape[0],W3.shape[0])
    
    SNN.fc1.weight.data = W1
    SNN.fc2.weight.data = W2
    SNN.fc3.weight.data = W3
    
    return SNN

In [13]:
SNN = convertANNtoSNN(model)

In [14]:
def sleep(SNN,inputs):
  
  maxFiringRate = 40.0
  current = 2.19

  W1 = torch.clone(SNN.fc1.weight).cuda(device)
  W2 = torch.clone(SNN.fc2.weight).cuda(device)
  W3 = torch.clone(SNN.fc3.weight).cuda(device)

  SNN.fc1.weight = torch.nn.Parameter(W1*current)
  SNN.fc2.weight = torch.nn.Parameter(W2*current)
  SNN.fc3.weight = torch.nn.Parameter(W3*current)


    
  voltages_1 = torch.zeros(1200,1).cuda(device)
  voltages_2 = torch.zeros(1200,1).cuda(device)
  voltages_3 = torch.zeros(10,1).cuda(device)
    
    
    
  for i in range(inputs.shape[0]):
        
    this_input = torch.reshape(inputs[i],(inputs.shape[1],1)).cuda(device)
        
    firingRates = this_input * maxFiringRate # list of firing rates for each feature
    firingRatesPer10ms = firingRates / 100 # expected spikes per 10ms per feature
        
    spike_trains = torch.rand(this_input.shape[0],100).cuda(device)
        
    
        
    spike_trains = spike_trains < firingRatesPer10ms
        
    spike_trains = spike_trains.float().cuda(device)
        
        
        
        
    # propogate spike trains through snn:
        
    for t in range(100):
      torch.cuda.empty_cache() 
      
            
      spike_trains_t = spike_trains[:,t].cuda(device)
      spike_trains_t = torch.reshape(spike_trains_t, (spike_trains_t.shape[0],1))
            
      # pass through input layer to first hidden layer:
            
      W1 = torch.clone(SNN.fc1.weight).cuda(device)
            
      voltages_1 = voltages_1 + (W1 @  spike_trains_t)

            
      x = voltages_1 > 0.03618
            
      spike_trains_t1 = x.float().cuda(device)
            
      W1 = W1 + (0.063 * (spike_trains_t1 @ spike_trains_t.T))
      W1 = W1 - (0.069 * (spike_trains_t1 @ (1-spike_trains_t).T))
            
            
            
      # pass through first hidden layer to second hidden layer:
            
      W2 = torch.clone(SNN.fc2.weight).cuda(device)
            
      
      voltages_2 = voltages_2 + (W2 @  spike_trains_t1)
        

      y = voltages_2 > 0.02336
            
      spike_trains_t2 = y.float().cuda(device)
            
      W2 = W2 + (0.063 * (spike_trains_t2 @ spike_trains_t1.T))
      W2 = W2 - (0.069 * (spike_trains_t2 @ (1-spike_trains_t1).T))
            
      # pass through second hidden layer to output layer:
            
      W3 = torch.clone(SNN.fc3.weight).cuda(device)
            
      voltages_3 = voltages_3 + (W3 @  spike_trains_t2)
            
      z = voltages_3 > 0.03638
            
      spike_trains_t3 = z.float().cuda(device)

      
      
            
      W3 = W3 + (0.063 * (spike_trains_t3 @ spike_trains_t2.T))
      W3 = W3 - (0.069 * (spike_trains_t3 @ (1-spike_trains_t2).T))

      
            
      SNN.fc1.weight = torch.nn.Parameter(W1)
      SNN.fc2.weight = torch.nn.Parameter(W2)
      SNN.fc3.weight = torch.nn.Parameter(W3)

      voltages_1 = voltages_1 - (voltages_1 * spike_trains_t1)
      voltages_2 = voltages_2 - (voltages_2 * spike_trains_t2)
      voltages_3 = voltages_3 - (voltages_3 * spike_trains_t3)

      #print(spike_trains_t1)
      #print("-----------------------")
      #print(spike_trains_t2)
      #print("-----------------------")
      #print(spike_trains_t3)
      #print("-----------------------")

  W1 = torch.clone(SNN.fc1.weight.data).cuda(device)
  W2 = torch.clone(SNN.fc2.weight.data).cuda(device)
  W3 = torch.clone(SNN.fc3.weight.data).cuda(device)

  SNN.fc1.weight = torch.nn.Parameter(W1/current)
  SNN.fc2.weight = torch.nn.Parameter(W2/current)
  SNN.fc3.weight = torch.nn.Parameter(W3/current)

       



        
      
        
        


In [15]:
with torch.no_grad():
  for data in trainloader:
      
      inputs = data[0]
      
      sleep(SNN,inputs)

Spike in train 1
Spike in train 2
Spike in train 3
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 3
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 1
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 1
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train 2
Spike in train 1
Spike in train

KeyboardInterrupt: ignored

In [None]:
print(SNN.fc1.weight.data.shape)
print(SNN.fc2.weight.data.shape)
print(SNN.fc3.weight.data.shape)