In [75]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import snntorch as snn
import pandas as pd

In [76]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [77]:
# setting seeds
np.random.seed(445)
torch.manual_seed(445)

<torch._C.Generator at 0x2351cb572b0>

## Model

In [78]:
class SimpleSNNPredictor(nn.Module):
    def __init__(self, num_inputs, num_hidden, beta=0.95, num_steps=25):
        super().__init__()
        self.num_steps = num_steps

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [79]:
batch_size = 128
dtype = torch.float
def print_batch_accuracy(data, targets, train=False):
    output, _ =model(data.view(batch_size, -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

## Data

In [80]:
data06 = pd.read_csv("../data/processed/target06.csv")

In [81]:
all_data = data.TensorDataset(torch.from_numpy((data06.values[:,:-1] - data06.values[:,:-1].min(0)) / data06.values[:,:-1].ptp(0)).float(), torch.from_numpy(data06.values[:,-1]).float())  # with normalization
train_dataset, test_dataset, valid_dataset = torch.utils.data.random_split(all_data, (round(0.7 * len(all_data)), round(0.2 * len(all_data)), round(0.1 * len(all_data))))

In [82]:
train_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=128, shuffle=False, drop_last=False)
valid_loader = data.DataLoader(valid_dataset, batch_size=128, shuffle=False, drop_last=False)

## Prepare model

In [83]:
num_inputs=8
num_hidden=1000
num_outputs=150

model = SimpleSNNPredictor(num_inputs, num_hidden, num_outputs)
model.to(device)
print(model)

SimpleSNNPredictor(
  (fc1): Linear(in_features=8, out_features=1000, bias=True)
  (lif1): Leaky()
  (fc2): Linear(in_features=1000, out_features=150, bias=True)
  (lif2): Leaky()
)


In [84]:
optimizer = torch.optim.Adam(model.parameters(), lr=6e-4, betas=(0.9, 0.999))
loss = nn.CrossEntropyLoss()
epochs = 1
num_steps = 20


## Training

In [85]:
num_epochs = 20
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.round().type(torch.LongTensor).to(device)

    # forward pass
        model.train()
        spk_rec, mem_rec = model(data.view(batch_size, -1))

        # initialize the loss & sum over time
        loss_val = torch.zeros(1, dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            model.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.round().type(torch.LongTensor).to(device)

            # Test set forward pass
            test_spk, test_mem = model(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros(1, dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 0
Train Set Loss: 110.53
Test Set Loss: 97.94
Train set accuracy for a single minibatch: 3.91%
Test set accuracy for a single minibatch: 7.03%


Epoch 1, Iteration 3
Train Set Loss: 77.12
Test Set Loss: 78.19
Train set accuracy for a single minibatch: 10.16%
Test set accuracy for a single minibatch: 3.91%


Epoch 2, Iteration 6
Train Set Loss: 74.85
Test Set Loss: 75.76
Train set accuracy for a single minibatch: 9.38%
Test set accuracy for a single minibatch: 6.25%


Epoch 3, Iteration 9
Train Set Loss: 72.22
Test Set Loss: 76.88
Train set accuracy for a single minibatch: 9.38%
Test set accuracy for a single minibatch: 5.47%


Epoch 4, Iteration 12
Train Set Loss: 73.23
Test Set Loss: 77.51
Train set accuracy for a single minibatch: 10.16%
Test set accuracy for a single minibatch: 7.81%


Epoch 5, Iteration 15
Train Set Loss: 72.87
Test Set Loss: 80.81
Train set accuracy for a single minibatch: 8.59%
Test set accuracy for a single minibatch: 5.47%


Epoch 6, Iteratio

In [86]:
test_mem.sum(dim=0).max(1)

torch.return_types.max(
values=tensor([-89.9639,   7.9948,   6.4838,   7.9867,  12.5874, -41.4855, -15.5760,
         11.4371, -44.0091,   0.6600, -13.6747,  -5.2879,   0.3238, -44.4139,
        -54.6852, -49.0480,   7.6722, -79.1993, -25.6788,  12.2297,   8.6836,
        -35.4031, -23.9025, -47.6584, -30.4154, -10.5103,   5.9806, -46.8056,
        -32.0493,   6.6479,  -2.9635,  11.3102, -69.8873,  -2.1698,   8.0040,
          8.8136, -19.6020,   2.1339, -13.5232, -36.1154, -23.4889, -18.5171,
        -45.0278, -45.0049, -43.8459,  11.5084,  -7.0310,  13.7153,  -1.1603,
        -49.8283,   7.7661,   2.9539, -10.6930, -65.5782,  12.5545, -45.2888,
         13.5379,   9.5674, -48.5092,  10.5516, -59.3391,   7.7039,  15.4329,
         12.8692, -58.3383, -13.7325,  10.8780,  11.4188, -28.0685, -27.4100,
          5.6322, -52.1165,   0.5348,   9.8030,   4.3293, -24.6781, -83.0302,
        -27.0966, -26.4096, -15.6699, -82.9912,  12.1810,   5.5019, -12.0608,
         -9.9057,  -6.7921, -24.6

In [87]:
test_targets

tensor([49, 18, 16, 55, 11, 25,  9, 15, 10, 21, 15, 17,  9, 35, 12, 12, 10, 26,
        17, 12,  6, 25, 15, 20, 27,  5, 17, 33, 15, 12,  9, 22, 65,  7,  5, 17,
        12, 21, 11, 55, 23, 16, 19, 14,  9, 14, 13, 13, 13, 26, 12,  9, 10, 65,
        15, 54, 20, 13, 21,  6, 15,  9, 23, 11, 32, 12,  6,  9, 32, 29, 18, 98,
        29, 25, 12,  7, 25, 32, 16, 10, 11, 12,  9,  5, 33, 32, 25, 10,  3,  8,
        64, 17, 17, 30, 12, 15, 14,  9, 16, 47, 29, 22, 25, 68, 18, 10, 24, 13,
        17, 10, 16, 16, 10, 29, 11, 35, 14, 10, 12, 16, 14, 14, 44,  9, 32, 14,
        15, 14], device='cuda:0')