In [1]:
# Execute this code block to install dependencies when running on colab
try:
    import torch
except:
    from os.path import exists
    from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
    accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

    !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision

try: 
    import torchbearer
except:
    !pip install torchbearer

In [2]:
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torchvision.datasets import MNIST

# fix random seed for reproducibility
seed = 7
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import numpy as np
np.random.seed(seed)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# flatten 28*28 images to a 784 vector for each image
transform = transforms.Compose([
    transforms.ToTensor(),  # convert to tensor
    transforms.Lambda(lambda x: x.view(-1))  # flatten into vector
])




# load data
trainset = MNIST(".", train=True, download=True, transform=transform)
testset = MNIST(".", train=False, download=True, transform=transform)

# https://stackoverflow.com/questions/47432168/taking-subsets-of-a-pytorch-dataset
# code from 2nd answer used to get only the first 27105 examples in the training set
# for the ANN to train on

mask = list([i for i in range(27105)])
trainset_1 = torch.utils.data.Subset(trainset, mask)

# create data loaders
trainloader = DataLoader(trainset_1, batch_size=128, sampler=SubsetRandomSampler(np.where(mask)[0]))
testloader = DataLoader(testset, batch_size=128, shuffle=True)

(len(trainloader.dataset))


27105

In [3]:
class ANN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) 
        self.fc2 = nn.Linear(hidden_size, hidden_size)  
        self.fc3 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = F.dropout(out,p=0.2)
        out = self.fc2(out)
        out = F.relu(out)
        out = F.dropout(out,p=0.2)
        out = self.fc3(out)
        
        if not self.training:
            out = F.softmax(out, dim=1)
        return out

# build the model
model = ANN(784, 1200, 10)

In [4]:
# define the loss function and the optimiser
loss_function = nn.CrossEntropyLoss()
optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.5)

# Construct a trial object with the model, optimiser and loss.
# Also specify metrics we wish to compute.
trial = torchbearer.Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy']).to(device)

# Provide the data to the trial
trial.with_generators(trainloader, test_generator=testloader)

# Run 10 epochs of training
trial.run(epochs=2)

# test the performance
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

HBox(children=(FloatProgress(value=0.0, description='0/2(t)', max=212.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='1/2(t)', max=212.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='0/1(e)', max=79.0, style=ProgressStyle(description_width=…


{'test_loss': 1.5664728879928589, 'test_acc': 0.9335999488830566}


In [5]:
def convertANNtoSNN(model):
    
    W1 = torch.clone(model.fc1.weight.data)
    W2 = torch.clone(model.fc2.weight.data)
    W3 = torch.clone(model.fc3.weight.data)
    
    SNN = ANN(W1.shape[1],W1.shape[0],W3.shape[0])
    
    SNN.fc1.weight.data = W1
    SNN.fc2.weight.data = W2
    SNN.fc3.weight.data = W3
    
    return SNN

In [6]:
SNN = convertANNtoSNN(model)

In [7]:
def sleep(SNN,inputs):
  
  maxFiringRate = 40.0
  current = 2.19

  W1 = torch.clone(SNN.fc1.weight).cuda(device)
  W2 = torch.clone(SNN.fc2.weight).cuda(device)
  W3 = torch.clone(SNN.fc3.weight).cuda(device)

  SNN.fc1.weight = torch.nn.Parameter(W1*current)
  SNN.fc2.weight = torch.nn.Parameter(W2*current)
  SNN.fc3.weight = torch.nn.Parameter(W3*current)


    
  voltages_1 = torch.zeros(1200,1).cuda(device)
  voltages_2 = torch.zeros(1200,1).cuda(device)
  voltages_3 = torch.zeros(10,1).cuda(device)
    
    
    
  for i in range(inputs.shape[0]):
        
    this_input = torch.reshape(inputs[i],(inputs.shape[1],1)).cuda(device)
        
    firingRates = this_input * maxFiringRate # list of firing rates for each feature
    firingRatesPer10ms = firingRates / 100 # expected spikes per 10ms per feature
        
    spike_trains = torch.rand(this_input.shape[0],100).cuda(device)
        
    print(firingRatesPer10ms.shape)
    print(spike_trains.shape)
        
    spike_trains = spike_trains < firingRatesPer10ms
        
    spike_trains = spike_trains.float().cuda(device)
        
        
        
        
    # propogate spike trains through snn:
        
    for t in range(100):
      torch.cuda.empty_cache() 
      print("t=",t)
            
      spike_trains_t = spike_trains[:,t].cuda(device)
      spike_trains_t = torch.reshape(spike_trains_t, (spike_trains_t.shape[0],1))
            
      # pass through input layer to first hidden layer:
            
      W1 = torch.clone(SNN.fc1.weight).cuda(device)
            
      voltages_1 = voltages_1 + (W1 @  spike_trains_t)
            
      x = voltages_1 > 36.18
            
      spike_trains_t1 = x.float().cuda(device)
            
      W1 = W1 + (0.063 * (spike_trains_t1 @ spike_trains_t.T))
      W1 = W1 - (0.069 * (spike_trains_t1 @ (1-spike_trains_t).T))
            
            
            
      # pass through first hidden layer to second hidden layer:
            
      W2 = torch.clone(SNN.fc2.weight).cuda(device)
            
      print(W2.shape)
      voltages_2 = voltages_2 + (W2 @  spike_trains_t1)
        
      print(voltages_2[0:500])

      y = voltages_2 > 23.36
            
      spike_trains_t2 = y.float().cuda(device)
            
      W2 = W2 + (0.063 * (spike_trains_t2 @ spike_trains_t1.T))
      W2 = W2 - (0.069 * (spike_trains_t2 @ (1-spike_trains_t1).T))
            
      # pass through second hidden layer to output layer:
            
      W3 = torch.clone(SNN.fc3.weight).cuda(device)
            
      voltages_3 = voltages_3 + (W3 @  spike_trains_t2)
            
      z = voltages_3 > 36.38
            
      spike_trains_t3 = z.float().cuda(device)

      print(spike_trains_t3)  
            
      W3 = W3 + (0.063 * (spike_trains_t3 @ spike_trains_t2.T))
      W3 = W3 - (0.069 * (spike_trains_t3 @ (1-spike_trains_t2).T))
            
      SNN.fc1.weight = torch.nn.Parameter(W1)
      SNN.fc2.weight = torch.nn.Parameter(W2)
      SNN.fc3.weight = torch.nn.Parameter(W3)

  W1 = torch.clone(SNN.fc1.weight.data).cuda(device)
  W2 = torch.clone(SNN.fc2.weight.data).cuda(device)
  W3 = torch.clone(SNN.fc3.weight.data).cuda(device)

  SNN.fc1.weight = torch.nn.Parameter(W1/current)
  SNN.fc2.weight = torch.nn.Parameter(W2/current)
  SNN.fc3.weight = torch.nn.Parameter(W3/current)

       



        
      
        
        


In [8]:
with torch.no_grad():
  for data in trainloader:
      
      inputs = data[0]
      
      sleep(SNN,inputs)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        [ 1.9459e-01],
        [-1.1712e-01],
        [ 6.7815e-01],
        [ 1.2607e-01],
        [-3.0969e-01],
        [ 1.4737e-02],
        [-5.5289e-02],
        [ 4.8391e-02],
        [ 2.2453e-02],
        [ 3.6575e-01],
        [ 6.7376e-01],
        [-3.0923e-01],
        [ 3.5319e-01],
        [-5.5001e-01],
        [-4.0671e-01],
        [ 1.6408e-02],
        [ 3.1586e-01],
        [ 3.9989e-01],
        [-5.8150e-01],
        [ 2.0500e-01],
        [ 1.2112e-01],
        [-5.4558e-01],
        [-4.8595e-01],
        [ 2.7405e-01],
        [-3.4192e-01],
        [-6.0409e-01],
        [ 1.5181e-01],
        [-6.5684e-02],
        [-1.8557e-01],
        [ 5.7386e-01],
        [ 3.7092e-02],
        [ 6.0639e-02],
        [-2.8973e-01],
        [ 3.0545e-01],
        [ 2.4576e-01],
        [ 1.8653e-01],
        [-5.0607e-01],
        [ 3.5822e-01],
        [-4.4173e-01],
        [-2.7079e-01],
        [-4.287

KeyboardInterrupt: ignored

In [11]:
print(SNN.fc1.weight.data.shape)
print(SNN.fc2.weight.data.shape)
print(SNN.fc3.weight.data.shape)

torch.Size([1200, 784])
torch.Size([1200, 1200])
torch.Size([10, 1200])
