In [14]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
import statistics
import os
import matplotlib.pyplot as plt
import scipy.stats
import math 
import sklearn
from sklearn import metrics
from snntorch import functional as SF
import csv 
import pandas as pd
# dataset citation: https://www.kaggle.com/datasets/dmitryshkadarevich/branch-prediction?resource=download

In [15]:
# Build the path from its components instead of hard-coding "\\" separators:
# on Linux/macOS a backslash is an ordinary file-name character, so the
# original literal only resolved on Windows. On Windows the result is unchanged.
filepath = os.path.join("Datasets", "bp_data", "I04.csv")
file = pd.read_csv(filepath)

print(type(file))

<class 'pandas.core.frame.DataFrame'>


In [None]:
class Net(nn.Module):
    """Three-layer fully connected spiking network built from ``snn.Leaky`` neurons.

    Layout: ``fc1 -> lif1 -> fc2 -> lif2 -> fc3 -> lif3``. Both hidden layers
    are ``num_hidden`` wide and all three neuron layers share the same ``beta``.

    Args:
        num_inputs: Number of input features read at every time step.
        num_hidden: Width of both hidden layers.
        beta: Value passed as ``beta`` to every ``snn.Leaky`` layer.
        n_actions: Number of output neurons.
        num_steps: Number of time steps simulated by ``forward``.
        batch_size: ``forward`` reads at most this many samples from the
            start of each batch.

    Attributes:
        spikes_hist: ``[spk1_rec, spk2_rec, spk3_rec]`` recorded by the most
            recent ``forward`` call; an empty list before the first call.
    """

    def __init__(self, num_inputs, num_hidden, beta, n_actions, num_steps, batch_size):
        super(Net, self).__init__()
        self.num_steps = num_steps
        self.batch_size = batch_size
        self.input_width = num_inputs

        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_hidden)
        self.lif2 = snn.Leaky(beta=beta)
        self.fc3 = nn.Linear(num_hidden, n_actions)
        self.lif3 = snn.Leaky(beta=beta)

        # Defined up front so that reading the attribute before the first
        # forward pass yields an empty history instead of an AttributeError.
        self.spikes_hist = []

    def forward(self, x):
        """Simulate the network for ``num_steps`` time steps.

        Args:
            x: Input indexed as ``x[sample, step, feature]``. Only the first
                ``batch_size`` samples and the first ``num_inputs`` features
                are used, and ``x`` must provide at least ``num_steps`` steps.

        Returns:
            ``(spk3_rec, mem3_rec)``: output-layer spikes and membrane
            potentials, each stacked along a new leading time dimension.

        Side effects:
            Overwrites ``self.spikes_hist`` with the stacked spike recordings
            of all three layers.
        """
        # Membrane state is re-initialised on every call, so nothing is
        # carried over from one batch to the next.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        spk1_rec = []
        spk2_rec = []
        spk3_rec = []
        mem3_rec = []
        for step in range(self.num_steps):
            cur1 = self.fc1(x[0:self.batch_size, step, 0:self.input_width])
            spk1, mem1 = self.lif1(cur1, mem1)
            spk1_rec.append(spk1)

            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)

            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)
            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        # Stack the per-step records along a new leading (time) dimension.
        spk1_rec = torch.stack(spk1_rec)
        spk2_rec = torch.stack(spk2_rec)
        spk3_rec = torch.stack(spk3_rec)
        self.spikes_hist = [spk1_rec, spk2_rec, spk3_rec]

        mem3_rec = torch.stack(mem3_rec)

        return spk3_rec, mem3_rec

In [3]:
def count_spike(spikes):
    """Count the spikes emitted for each sample, summed over all layers.

    Args:
        spikes: Iterable of per-layer spike recordings, each indexed as
            ``(time, batch, neurons)`` (for example the ``spikes_hist`` list
            that ``Net.forward`` stores). Layers may differ in neuron count
            but must share the batch dimension. A single 3-D tensor is
            accepted as well and treated as one layer.

    Returns:
        Tensor with one entry per sample holding the total over time steps,
        neurons and layers.

    Raises:
        IndexError: If ``spikes`` contains no layers.
    """
    # A bare (time, batch, neurons) tensor is one layer's recording; iterating
    # over it directly would walk the time axis instead of a list of layers.
    if torch.is_tensor(spikes) and spikes.dim() == 3:
        spikes = [spikes]

    # Collapse the time (first) and neuron (last) axes of every layer, leaving
    # one count per sample.
    per_layer = [layer.sum(dim=(0, -1)) for layer in spikes]
    if not per_layer:
        raise IndexError("count_spike() needs at least one layer of recorded spikes")

    # Add the layers together in a single vectorised reduction rather than
    # looping over the samples in Python; this also copes with an empty batch.
    return torch.stack(per_layer).sum(dim=0)

In [5]:
def train_loop(data, model, num_epochs, loss, optimizer, batch_size, device):
    """Train ``model`` on ``data`` for ``num_epochs`` epochs.

    Args:
        data: Iterable of ``(item, targets)`` batches; it is iterated once per
            epoch. ``targets`` is 2-D and the class index is read from its
            column 1.
        model: Module whose call returns ``(spk_rec, mem_rec)`` with
            ``mem_rec`` indexed as ``(time, batch, classes)``.
        num_epochs: Number of passes over ``data``.
        loss: Criterion called as ``loss(logits, targets)`` on the
            ``(batch, classes)`` membrane potentials of one time step
            (e.g. ``nn.CrossEntropyLoss()``).
        optimizer: Optimizer that updates the parameters of ``model``.
        batch_size: Upper bound on the number of target rows used per batch.
        device: Kept for interface compatibility. Tensors are not moved here;
            batches are expected to already live on the model's device.

    Returns:
        ``(loss_hist, loss_by_epoch)``: the loss of every iteration as a flat
        list, and the same values grouped into one list per epoch.
    """
    loss_hist = []
    loss_by_epoch = []

    model.train()

    # Outer training loop
    for epoch in range(num_epochs):
        epoch_loss = []

        # Minibatch training loop
        for i, (item, targets) in enumerate(data):
            targets = targets[0:batch_size, 1:2].flatten().type(torch.long)

            # forward pass
            spk_rec, mem_rec = model(item)

            # The criterion expects (batch, classes) scores, so apply it to
            # each time step separately and sum the results over time.
            # Passing the whole (time, batch, classes) tensor at once makes it
            # read the batch axis as the class axis.
            loss_val = torch.stack([loss(mem_step, targets) for mem_step in mem_rec]).sum()

            # Gradient calculation + weight update. Every forward pass builds
            # a fresh graph, so there is no need to retain it after backward.
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Store loss history for future plotting
            batch_loss = loss_val.item()
            loss_hist.append(batch_loss)
            epoch_loss.append(batch_loss)

            # Predicted class: largest membrane potential accumulated over
            # time, taken over the class (last) axis.
            pred = mem_rec.detach().sum(dim=0).argmax(dim=-1)
            accuracy = (pred == targets).float().mean().item()

            # Report every 2nd iteration while the loss is noticeable and
            # every 20th once it has dropped to 0.001 or below.
            log_every = 2 if batch_loss > 0.001 else 20
            if i % log_every == 0:
                print(f"Epoch {epoch}, Iteration {i}, Train Loss: {loss_val.item():.2f}, Accuracy: {accuracy * 100:.2f}%")

        loss_by_epoch.append(epoch_loss)

    return (loss_hist, loss_by_epoch)

In [6]:
def eval_loop(data_eval, model, net, batch_size):
    """Evaluate ``model`` on ``data_eval`` and collect per-batch metrics.

    Args:
        data_eval: Iterable of ``(item, targets)`` batches. ``targets`` is 2-D
            and the class index is read from its column 1.
        model: Module that is called on every batch; returns
            ``(spk_rec, mem_rec)`` with ``mem_rec`` indexed as
            ``(time, batch, classes)``.
        net: Module switched to eval mode for the duration of the loop and
            back to train mode afterwards (normally the same object as
            ``model``).
        batch_size: Upper bound on the number of target rows used per batch.

    Returns:
        ``(spike_counts, preds, f1_score, accuracy)``:
        per-sample spike totals, per-sample predicted classes, and one F1
        score / accuracy value per batch.
    """
    spike_counts = []
    preds = []
    f1_scores = []
    accuracies = []

    # Put the module that is actually evaluated into eval mode as well, and
    # remember its mode so it can be restored afterwards.
    model_was_training = model.training
    net.eval()
    model.eval()
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for item, targets in data_eval:
                targets = targets[0:batch_size, 1:2].flatten().type(torch.long)

                # Get predictions
                spk_rec, mem_rec = model(item)

                # count_spike expects a list of per-layer (time, batch, neurons)
                # recordings. Use the model's full spike history when it keeps
                # one; otherwise count the output-layer spikes only.
                layer_spikes = getattr(model, "spikes_hist", None)
                if not isinstance(layer_spikes, (list, tuple)) or len(layer_spikes) == 0:
                    layer_spikes = [spk_rec]
                spike_counts.extend(count_spike(layer_spikes).tolist())

                # Predicted class: largest membrane potential accumulated over
                # time, taken over the class (last) axis rather than the batch.
                pred = mem_rec.sum(dim=0).argmax(dim=-1)

                # Calculate metrics on host arrays.
                y_true = targets.cpu().numpy()
                y_pred = pred.cpu().numpy()
                f1_scores.append(metrics.f1_score(y_true, y_pred))
                accuracies.append(metrics.accuracy_score(y_true, y_pred))
                preds.extend(pred.tolist())
    finally:
        # Restore the modes even if evaluation fails part-way through.
        model.train(model_was_training)
        net.train()

    return (spike_counts, preds, f1_scores, accuracies)

In [None]:
# Tunable hyperparameters for the network and its training setup.
num_inputs = 80     # input features read at every time step
num_steps = 1       # time steps simulated per forward pass
batch_size = 128    # maximum number of samples read from each batch
num_hidden = 128    # width of both hidden layers
num_outputs = 2     # number of output neurons
beta = 0.95         # `beta` shared by all snn.Leaky layers

# Net signature: (num_inputs, num_hidden, beta, n_actions, num_steps, batch_size)
net = Net(
    num_inputs=num_inputs,
    num_hidden=num_hidden,
    beta=beta,
    n_actions=num_outputs,
    num_steps=num_steps,
    batch_size=batch_size,
)
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))