In [1]:
from tonic import datasets
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size=700, hidden_size=128, num_layers=1, num_classes=20):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        _, (hn, _) = self.lstm(x)         
        out = self.fc(hn[-1])             
        return out


In [3]:
def convert_to_time_binned_sequences(data):
    X = []
    Y = []

    for i, (spikes, label) in enumerate(data):
        sequences = []
        current_i = 0

        while current_i < 1400000:
            filtered_spikes = spikes[(spikes['t'] > current_i) & (spikes['t'] <= current_i + 10000)]

            sequence = np.zeros(700)
            for neuron, count in Counter(filtered_spikes['x']).items():
                sequence[neuron] = count

            current_i = current_i + 10000
            sequences.append(sequence)

        X.append(sequences)
        Y.append(label)
    
    return X, Y

In [4]:
def create_data_loader(X, Y):
    X = torch.tensor(X, dtype=torch.float32)      
    y = torch.tensor(Y, dtype=torch.long)        
    
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    return loader

In [5]:
def load_lstm_train_test_data():
    train_data = datasets.SHD("./data", train=True)
    test_data = datasets.SHD("./data", train=False)

    x_train, y_train = convert_to_time_binned_sequences(train_data)
    x_test, y_test = convert_to_time_binned_sequences(test_data)

    train_loader = create_data_loader(x_train, y_train)
    test_loader = create_data_loader(x_test, y_test)

    return train_loader, test_loader 
   

In [7]:
train_loader, test_loader = load_lstm_train_test_data()
print('DATA LOADED')

model = LSTMClassifier(num_classes=20)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    correct = 0
    total = 0

    for X_batch, y_batch in train_loader:
        output = model(X_batch)
        loss = criterion(output, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output, 1)
        correct += (predicted == y_batch).sum().item()
        total += y_batch.size(0)

    epoch_acc = correct / total

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Train Acc: {epoch_acc}")


  X = torch.tensor(X, dtype=torch.float32)


DATA LOADED
Epoch 1, Loss: 2.9938, Train Acc: 0.04769494850416871
Epoch 2, Loss: 2.9989, Train Acc: 0.049288867091711625
Epoch 3, Loss: 2.9990, Train Acc: 0.04818538499264345
Epoch 4, Loss: 2.9932, Train Acc: 0.04732712113781265
Epoch 5, Loss: 2.9935, Train Acc: 0.046468857282981856
Epoch 6, Loss: 2.9957, Train Acc: 0.04769494850416871
Epoch 7, Loss: 2.9948, Train Acc: 0.04843060323688082
Epoch 8, Loss: 3.0163, Train Acc: 0.04855321235899951
Epoch 9, Loss: 2.9947, Train Acc: 0.046223639038744484
Epoch 10, Loss: 2.9961, Train Acc: 0.04659146640510054


In [None]:
test_correct = 0
test_total = 0

for X_batch, y_batch in test_loader:
    output = model(X_batch)

    _, predicted = torch.max(output, 1)
    test_correct += (predicted == y_batch).sum().item()
    test_total += y_batch.size(0)

print(f'test acc = {test_correct / test_total}')

test acc = 0.053886925795053005
