In [1]:
from tonic import datasets
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size=700, hidden_size=256, num_layers=1, num_classes=20):
        super(LSTMClassifier, self).__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, dropout=0.2)
        self.fc = nn.Linear(hidden_size, num_classes)

    
    def forward(self, x):
        lstm_out, _ = self.lstm(x)         
        out = self.fc(torch.mean(lstm_out, dim=0))
        return out


In [3]:
def convert_to_time_binned_sequences(data):
    X = []
    Y = []

    for i, (spikes, label) in enumerate(data):
        sequences = []
        current_i = 0

        while current_i < 1400000:
            filtered_spikes = spikes[(spikes['t'] > current_i) & (spikes['t'] <= current_i + 10000)]

            sequence = np.zeros(700)
            for neuron, count in Counter(filtered_spikes['x']).items():
                sequence[neuron] = count

            current_i = current_i + 10000
            sequences.append(sequence)

        X.append(sequences)
        Y.append(label)
    
    return np.array(X), np.array(Y)

In [4]:
def create_data_loader(X, Y):
    X = torch.tensor(X, dtype=torch.float32)      
    y = torch.tensor(Y, dtype=torch.long)        
    
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=1)

    return loader

In [5]:
def load_lstm_train_test_data():
    train_data = datasets.SHD("./data", train=True)
    test_data = datasets.SHD("./data", train=False)

    x_train, y_train = convert_to_time_binned_sequences(train_data)
    x_test, y_test = convert_to_time_binned_sequences(test_data)

    train_loader = create_data_loader(x_train, y_train)
    test_loader = create_data_loader(x_test, y_test)

    return train_loader, test_loader 
   

In [6]:
train_loader, test_loader = load_lstm_train_test_data()
print('DATA LOADED')

DATA LOADED


In [25]:
model = LSTMClassifier(num_classes=20)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    correct = 0
    total = 0

    for X_batch, y_batch in train_loader:
        X_batch = X_batch.permute(1, 0, 2).contiguous()

        output = model(X_batch)

        loss = criterion(output, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output, 1)
        correct += (predicted == y_batch).sum().item()
        total += y_batch.size(0)

    epoch_acc = correct / total

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Train Acc: {epoch_acc}")


Epoch 1, Loss: 0.3601, Train Acc: 0.5914664051005395
Epoch 2, Loss: 0.0199, Train Acc: 0.8527464443354585
Epoch 3, Loss: 0.0452, Train Acc: 0.9041196665031879
Epoch 4, Loss: 0.0062, Train Acc: 0.9394310936733693
Epoch 5, Loss: 0.0601, Train Acc: 0.9596615988229524


In [30]:
test_correct = 0
test_total = 0

for X_batch, y_batch in test_loader:
    X_batch = X_batch.permute(1, 0, 2).contiguous()

    output = model(X_batch)

    _, predicted = torch.max(output, 1)
    test_correct += (predicted == y_batch).sum().item()
    test_total += y_batch.size(0)

print(f'test acc = {test_correct / test_total}')

test acc = 0.7840106007067138
