In [None]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

import random

In [None]:
batch_size = 8192
data_path='/data/mnist'   # NOTE(review): absolute path; point this at a writable local dir if /data is unavailable

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Seed both RNGs the GA draws from -- python `random` (mutation decisions) and torch
# (weight initialisation, crossover points, DataLoader shuffling) -- so that a
# Restart & Run All reproduces the same run.
seed = 0
random.seed(seed)
torch.manual_seed(seed)

# Transform. Normalize with mean 0 / std 1 is an identity map: inputs keep ToTensor's [0, 1] range.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# DataLoaders. drop_last=True keeps every training batch at exactly batch_size samples.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Network architecture: a flattened 28x28 image in, one output neuron per digit class out.
num_inputs, num_outputs = 28 * 28, 10

# Temporal dynamics: number of simulated time steps, and the beta passed to snn.Leaky.
num_steps, beta = 30, 0.95

In [None]:
# Network

class Net(nn.Module):
    """Single-layer spiking classifier: a bias-free linear layer feeding snntorch Leaky neurons.

    The same input ``x`` is presented at every one of ``num_steps`` time steps.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_outputs, bias=False)
        self.lif1 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate ``num_steps`` time steps on input ``x``.

        Args:
            x: flattened input; callers pass ``data.view(batch, -1)``, presumably
                (batch, num_inputs) -- TODO confirm for other callers.

        Returns:
            (spikes, membranes): the per-step outputs of ``lif1`` stacked along a
            new leading time dimension of length ``num_steps``.
        """
        mem1 = self.lif1.init_leaky()

        # x does not change across time steps, so the input current is loop-invariant:
        # one matmul instead of num_steps identical ones (forward is the GA's hot path).
        cur1 = self.fc1(x)

        spk1_rec = []
        mem1_rec = []
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            spk1_rec.append(spk1)
            mem1_rec.append(mem1)

        return torch.stack(spk1_rec, dim=0), torch.stack(mem1_rec, dim=0)


net = Net().to(device)

In [None]:
# Reference to the network's parameter dict. state_dict() returns tensors that share
# storage with the live parameters, so later load_state_dict() calls show through here;
# this is NOT a frozen copy of the initial weights.
# NOTE(review): later cells fetch net.state_dict() afresh, so this global looks unused -- confirm.
state_dict = net.state_dict()

In [None]:
# Criterion used by the GA fitness: cross-entropy with the default mean reduction over the batch.
loss = torch.nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Draw one training batch into the module-level `data`/`targets` that the fitness
# function scores chromosomes on.
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

# Smoke-test a forward pass. no_grad: the outputs are only inspected, so there is no
# reason to build an autograd graph over num_steps steps of a full batch.
with torch.no_grad():
    spk_rec, mem_rec = net(data.view(data.size(0), -1))
print(data.size())

In [None]:
# Function for testing network on test dataset.

def test_acc():
    """Return ``net``'s classification accuracy on the full MNIST test set, in percent.

    The predicted class is the output neuron with the most spikes summed over all
    time steps. Returns 0.0 if the test loader yields no samples.
    """
    # A dedicated loader with drop_last=False so every test sample is counted (the
    # global test_loader drops the final partial batch). Order cannot change the
    # accuracy, so no shuffling is needed.
    loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=False)

    total = 0
    correct = 0
    with torch.no_grad():
        net.eval()
        for batch_data, batch_targets in loader:
            batch_data = batch_data.to(device)
            batch_targets = batch_targets.to(device)

            # forward pass
            test_spk, _ = net(batch_data.view(batch_data.size(0), -1))

            # spike count per output neuron over time -> highest count wins
            _, predicted = test_spk.sum(dim=0).max(1)
            total += batch_targets.size(0)
            correct += (predicted == batch_targets).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total

In [None]:
# Cross-entropy fitness function for GA

def fitness_function(chromosome_input, parameter_input):
    """Load one chromosome into ``net`` and return its loss on the current training batch.

    Lower is fitter. The loss is ``loss`` (the notebook's criterion) applied to the
    output membrane potential at every time step against ``targets``, summed over
    steps. The batch is the module-level ``data``/``targets``.

    Args:
        chromosome_input: candidate tensor for ``net.state_dict()[parameter_input]``.
        parameter_input: state-dict key to overwrite, e.g. ``'fc1.weight'``.

    Returns:
        The summed loss as a Python float.

    Side effect: the chromosome stays loaded in ``net``.
    """
    # Loading chromosome onto network
    state_dict = net.state_dict()
    state_dict[parameter_input] = chromosome_input.clone().detach()
    net.load_state_dict(state_dict)

    # data.size(0) rather than the global batch_size, so a smaller final batch also works.
    _, mem_rec = net(data.view(data.size(0), -1))

    # Sum the per-step losses. Iterating mem_rec walks its leading time dimension;
    # sum() starts from 0 exactly like the original zero accumulator.
    loss_val = sum(loss(step_mem, targets) for step_mem in mem_rec)

    # float() also copes with an empty time dimension (sum() then returns the int 0).
    return float(loss_val)

In [None]:
# Crossover function for GA. Produces one offspring per member of population.

def crossover(population, row, column):
    """Single-point column crossover that appends two complementary children per parent pair.

    Parents are paired in order (0, 1), (2, 3), ...; an odd trailing chromosome has
    no partner. For each pair one cut column ``c`` is drawn uniformly from
    ``[0, column)``. Child A takes columns ``[:c]`` from the first parent and ``[c:]``
    from the second; child B takes the reverse, so the pair is complementary.

    Args:
        population: tensor of shape (n, row, column).
        row: number of rows per chromosome (kept for interface compatibility).
        column: number of columns per chromosome.

    Returns:
        Tensor of shape (n + 2 * (n // 2), row, column): the parents followed by
        their children, on the population's device and dtype.
    """
    children = []
    for first in range(0, population.size(0) - 1, 2):
        cut = int(torch.randint(0, column, (1,)).item())
        parent_a, parent_b = population[first], population[first + 1]
        children.append(torch.cat((parent_a[:, :cut], parent_b[:, cut:]), dim=1))
        children.append(torch.cat((parent_b[:, :cut], parent_a[:, cut:]), dim=1))

    if not children:
        return population
    return torch.cat((population, torch.stack(children)))

In [None]:
# Mutation function for GA.

def mutation(population, best_chromosome, population_number, row, column, best_prob,
             mutation_rate=0.0003, mutation_scale=0.5):
    """Mutate the first ``population_number`` chromosomes in place, or replace them with the best.

    For each chromosome, with probability ``best_prob`` every gene independently
    mutates with probability ``mutation_rate`` by an offset drawn uniformly from
    ``[-mutation_scale, mutation_scale)``. Otherwise the chromosome is overwritten
    with ``best_chromosome``.

    Args:
        population: tensor of shape (>= population_number, row, column); modified in place.
        best_chromosome: tensor of shape (row, column) to re-insert.
        population_number: how many leading chromosomes to process.
        row, column: chromosome dimensions.
        best_prob: probability of mutating rather than replacing with the best.
        mutation_rate: per-gene mutation probability (default 0.0003).
        mutation_scale: maximum absolute mutation offset (default 0.5).

    Returns:
        ``population`` (the same tensor object).
    """
    for i in range(int(population_number)):
        if random.random() < best_prob:
            # Vectorised per-gene coin flips: which genes mutate, and by how much.
            hit = torch.rand(row, column, device=population.device) < mutation_rate
            offset = (torch.rand(row, column, device=population.device, dtype=population.dtype)
                      * 2 - 1) * mutation_scale
            population[i] += offset * hit
        else:
            population[i] = best_chromosome    # Chance of inserting best solution into population
    return population

In [None]:
# Main loop of GA.

def evolve(input, population_number, generation_goal, parameter_input):
    """Optimise one network parameter with a genetic algorithm.

    Each generation does the following:
      * score every chromosome with ``fitness_function`` (lower loss is fitter);
      * keep the fittest half, best-first;
      * refill the population with ``crossover``;
      * perturb it with ``mutation``, which may also re-insert the best chromosome
        found since the last batch change.
    Every 500 generations (starting at 0) the current best is evaluated with
    ``test_acc``, the training batch is replaced and the running best is reset.

    Args:
        input: initial population of shape (population_number, rows, columns);
            each slice is a candidate for ``net.state_dict()[parameter_input]``.
        population_number: number of chromosomes scored per generation.
        generation_goal: number of generations to run.
        parameter_input: state-dict key the chromosomes are loaded into,
            e.g. ``'fc1.weight'``.

    Returns:
        (solution_input, loss_hist): ``solution_input`` has shape
        (generation_goal, rows, columns) and holds each generation's best
        chromosome. ``loss_hist`` holds, per generation, the lowest loss seen
        since the last batch change.

    Side effects: rebinds the module-level ``data``/``targets`` batch, prints
    progress, and leaves the last generation's best chromosome loaded in ``net``.
    """
    # fitness_function scores on the module-level batch. Without this declaration the
    # batch refresh below would only bind locals, and every generation would keep
    # training on the very first batch.
    global data, targets

    input_row = input.size(1)
    input_column = input.size(2)
    n_survivors = int(population_number / 2)

    test_accuracy_hist = []   # test accuracies at each evaluation (printed; not returned)
    loss_hist = []

    solution_input = torch.zeros(generation_goal, input_row, input_column, device=device)
    best_acc = 99999          # lowest loss since the last batch change (a loss, despite the name)
    best_loss_index = 0       # generation whose best chromosome achieved best_acc
    best_prob = .98           # per-chromosome chance that mutation mutates rather than inserts the best

    for evolution_number in range(generation_goal):

        # Fitness of every chromosome on the current batch, kept as exact Python floats.
        losses = [
            fitness_function(input[chromosome_number, :, :], parameter_input)
            for chromosome_number in range(population_number)
        ]

        # Chromosome indices ordered best-first. sorted() is stable, so ties keep
        # population order (earliest chromosome wins, as with a strict '<' scan).
        ranked = sorted(range(population_number), key=losses.__getitem__)

        # Save the best solution of this generation, and the overall best.
        best_acc_gen = losses[ranked[0]]
        solution_input[evolution_number, :, :] = input[ranked[0], :, :].detach().clone()
        if best_acc_gen < best_acc:
            best_acc = best_acc_gen
            best_loss_index = evolution_number

        # Selection: copy out the fittest half, best-first, indexed by chromosome number.
        temp_input = input[ranked[:n_survivors], :, :].detach().clone()

        loss_hist.append(best_acc)

        # If this generation's best is worse than the running best recorded for the
        # previous generation, lower best_prob so that mutation re-inserts the best
        # solution more often. Once lowered, it is never raised again.
        if evolution_number > 0 and best_acc_gen > loss_hist[evolution_number - 1]:
            best_prob = .9

        if (evolution_number % 500) == 0:   # Testing the best solution on test dataset every 500th interval.
            state_dict = net.state_dict()
            state_dict[parameter_input] = solution_input[best_loss_index, :, :]
            net.load_state_dict(state_dict)
            accuracy_test = test_acc()
            test_accuracy_hist.append(accuracy_test)

            print(f"""
            Generation:         {evolution_number}/{generation_goal}
            Training loss:      {best_acc:.2f}
            Test accuracy:      {accuracy_test:.2f}%\n"""
            )

            # Change batch: rebinds the globals declared above.
            data, targets = next(iter(train_loader))
            data = data.to(device)
            targets = targets.to(device)

            # Restart best tracking on the new batch.
            best_acc = 99999

        temp_input = crossover(temp_input, input_row, input_column)
        temp_input = mutation(temp_input, solution_input[best_loss_index, :, :], population_number,
                              input_row, input_column, best_prob)

        input = temp_input.detach().clone()

    # Leave the final generation's best chromosome loaded into the network.
    state_dict = net.state_dict()
    state_dict[parameter_input] = solution_input[generation_goal - 1].detach().clone()
    net.load_state_dict(state_dict)

    return solution_input, loss_hist
            

In [None]:
# GA hyper-parameters: chromosomes per generation, and the number of generations to run.
# NOTE: keep population_number a multiple of 4. evolve keeps half of the population,
# crossover pairs those survivors, and the refilled population must have exactly
# population_number rows for mutation and the next generation's fitness loop.
population_number, generation_goal = 8, 10_000

In [None]:
with torch.no_grad():
    net.eval()

    # Initial population: population_number candidate fc1 weight matrices
    # (num_outputs x num_inputs), each standard-normal noise scaled by 0.01.
    mul = torch.tensor([0.01], device=device)
    fc1w = torch.randn(population_number, num_outputs, num_inputs, device=device) * mul

    input, loss_hist = evolve(fc1w, population_number, generation_goal, 'fc1.weight')

# Accuracy of the weights evolve left loaded in the network.
test_acc()

In [None]:
# Plotting the training loss: one point per GA generation (best loss since the last batch change).

fig, ax = plt.subplots(facecolor="w", figsize=(10, 5))
ax.plot(loss_hist)
ax.set_title("Loss Curves")
ax.legend(["Train Loss"])
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
plt.show()