In [1]:
from tonic import datasets
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn.functional as F


class LSTMClassifier(nn.Module):
    def __init__(self, input_size=700, hidden_size=128, num_layers=10, num_classes=20):
        super(LSTMClassifier, self).__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, dropout=0.2)
        self.fc = nn.Linear(hidden_size, num_classes)

    
    def forward(self, x):
        lstm_out, _ = self.lstm(x)         
        out = self.fc(lstm_out[-1])
        return out


In [3]:
def convert_to_time_binned_sequences(data):
    X = []
    Y = []

    for i, (spikes, label) in enumerate(data):
        sequences = []
        current_i = 0

        while current_i < 1400000:
            filtered_spikes = spikes[(spikes['t'] > current_i) & (spikes['t'] <= current_i + 10000)]

            sequence = np.zeros(700)
            for neuron, count in Counter(filtered_spikes['x']).items():
                sequence[neuron] = count

            current_i = current_i + 10000
            sequences.append(sequence)

        X.append(sequences)
        Y.append(label)
    
    return np.array(X), np.array(Y)

In [4]:
def create_data_loader(X, Y):
    X = torch.tensor(X, dtype=torch.float32)      
    y = torch.tensor(Y, dtype=torch.long)        
    
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=1)

    return loader

In [5]:
def load_lstm_train_test_data():
    train_data = datasets.SHD("./data", train=True)
    test_data = datasets.SHD("./data", train=False)

    x_train, y_train = convert_to_time_binned_sequences(train_data)
    x_test, y_test = convert_to_time_binned_sequences(test_data)

    train_loader = create_data_loader(x_train, y_train)
    test_loader = create_data_loader(x_test, y_test)

    return train_loader, test_loader 
   

In [6]:
train_loader, test_loader = load_lstm_train_test_data()
print('DATA LOADED')

DATA LOADED


In [7]:
model = LSTMClassifier(num_classes=20)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1000):
    correct = 0
    total = 0

    for X_batch, y_batch in train_loader:
        X_batch = X_batch.permute(1, 0, 2).contiguous()

        output = model(X_batch)

        loss = criterion(output, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output, 1)
        correct += (predicted == y_batch).sum().item()
        total += y_batch.size(0)

    epoch_acc = correct / total

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Train Acc: {epoch_acc}")


Epoch 1, Loss: 2.9927, Train Acc: 0.04916625796959294
Epoch 2, Loss: 2.9622, Train Acc: 0.048675821481118195
Epoch 3, Loss: 2.9864, Train Acc: 0.054561059342815105
Epoch 4, Loss: 3.0013, Train Acc: 0.05186365865620402
Epoch 5, Loss: 2.9906, Train Acc: 0.049901912702305054
Epoch 6, Loss: 2.9943, Train Acc: 0.05223148602256008
Epoch 7, Loss: 2.9997, Train Acc: 0.05247670426679745
Epoch 8, Loss: 2.9890, Train Acc: 0.05259931338891614
Epoch 9, Loss: 2.9925, Train Acc: 0.05051495831289848
Epoch 10, Loss: 2.9963, Train Acc: 0.05272192251103482
Epoch 11, Loss: 3.0010, Train Acc: 0.05492888670917116
Epoch 12, Loss: 2.9910, Train Acc: 0.05345757724374693
Epoch 13, Loss: 2.9989, Train Acc: 0.050882785679254534
Epoch 14, Loss: 3.0002, Train Acc: 0.05358018636586562
Epoch 15, Loss: 2.9955, Train Acc: 0.051495831289847964
Epoch 16, Loss: 3.0100, Train Acc: 0.048675821481118195


KeyboardInterrupt: 

In [None]:
test_correct = 0
test_total = 0

for X_batch, y_batch in test_loader:
    output = model(X_batch)

    _, predicted = torch.max(output, 1)
    test_correct += (predicted == y_batch).sum().item()
    test_total += y_batch.size(0)

print(f'test acc = {test_correct / test_total}')

# LSTM improvement

In [None]:
# Author: Robert Guthrie

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

In [None]:
torch.randn(1, 1, 3).shape

In [None]:
lstm = nn.LSTM(3, 3)  # Input dim is 3, output dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5

# initialize the hidden state.
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))

for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)

# alternatively, we can do the entire sequence all at once.
# the first value returned by LSTM is all of the hidden states throughout
# the sequence. the second is just the most recent hidden state
# (compare the last slice of "out" with "hidden" below, they are the same)
# The reason for this is that:
# "out" will give you access to all hidden states in the sequence
# "hidden" will allow you to continue the sequence and backpropagate,
# by passing it as an argument  to the lstm at a later time
# Add the extra 2nd dimension
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # clean out hidden state
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)