In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import pandas as pd
import numpy as np

# Convert dataframe to numpy and normalize features
from sklearn.preprocessing import StandardScaler

df = pd.read_csv("spiking_time_dataset.csv")

# Extract features and labels: column 0 holds the class label, every
# remaining column is a feature.
X = df.iloc[:, 1:].to_numpy()
y = df.iloc[:, 0].to_numpy()

# Split sizes: 80% train / 10% validation / 10% test. For a 2000-row file this
# is exactly 1600 / 200 / 200; deriving the sizes from the row count means the
# split no longer breaks when the CSV has a different number of rows.
n_samples = len(df)
val_size = n_samples // 10
test_size = n_samples // 10
train_size = n_samples - val_size - test_size

# Shuffle row indices with a fixed seed so the split is reproducible between runs.
split_generator = torch.Generator().manual_seed(42)
permutation = torch.randperm(n_samples, generator=split_generator)
train_idx = permutation[:train_size]
val_idx = permutation[train_size:train_size + val_size]
test_idx = permutation[train_size + val_size:]

# Fit the scaler on the training rows ONLY, then apply it to every row.
# Fitting on the full dataset (as before) leaks validation/test statistics
# into the features and inflates the reported validation/test accuracy.
scaler = StandardScaler()
scaler.fit(X[train_idx.numpy()])
X = scaler.transform(X)

# Convert to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)  # class indices for CrossEntropyLoss

train_data = TensorDataset(X_tensor[train_idx], y_tensor[train_idx])
val_data = TensorDataset(X_tensor[val_idx], y_tensor[val_idx])
test_data = TensorDataset(X_tensor[test_idx], y_tensor[test_idx])

# Create DataLoaders (only the training loader is shuffled each epoch)
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Define RNN model
class RNNModel(nn.Module):
    """Vanilla ``nn.RNN`` followed by a linear classification head.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the RNN hidden state.
        num_layers: number of stacked RNN layers.
        output_size: number of output logits (classes).
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return logits of shape ``(batch, output_size)``.

        ``x`` may be a 2-D ``(batch, input_size)`` tensor, which is treated as a
        sequence of length 1, or an already-sequenced 3-D
        ``(batch, seq_len, input_size)`` tensor.
        """
        if x.dim() == 2:
            # Flat feature vectors: add a length-1 time axis.
            # NOTE: with seq_len == 1 the hidden-to-hidden weights only ever
            # see the zero initial state, so no recurrence actually happens.
            x = x.unsqueeze(1)
        out, _ = self.rnn(x)
        # Classify from the RNN output at the final time step.
        return self.fc(out[:, -1, :])

# Model hyper-parameters. The input width comes from the feature matrix (one
# column per feature); two output logits match the binary label setup.
input_size = X.shape[1]
hidden_size, num_layers, output_size = 64, 1, 2

# Build the network, a cross-entropy objective over its logits, and Adam.
model = RNNModel(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    output_size=output_size,
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train the model
def evaluate_accuracy(net, loader):
    """Return the fraction of samples in ``loader`` that ``net`` classifies correctly.

    Switches ``net`` to eval mode and disables gradient tracking. Returns NaN
    when the loader yields no samples instead of dividing by zero.
    """
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            # Predicted class = index of the largest logit.
            predicted = net(X_batch).argmax(dim=1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()
    return correct / total if total else float("nan")


num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # criterion (CrossEntropyLoss, default reduction) yields a per-batch mean;
    # averaging over batches keeps the reported loss independent of how many
    # batches the epoch has.
    avg_loss = total_loss / max(len(train_loader), 1)

    # Evaluate on validation set
    val_accuracy = evaluate_accuracy(model, val_loader)

    print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

# Evaluate on test set
test_accuracy = evaluate_accuracy(model, test_loader)

print(f"Test Accuracy: {test_accuracy:.4f}")


Epoch [1/20], Loss: 26.2553, Validation Accuracy: 0.9050
Epoch [2/20], Loss: 13.5218, Validation Accuracy: 0.9700
Epoch [3/20], Loss: 6.3309, Validation Accuracy: 0.9950
Epoch [4/20], Loss: 3.0523, Validation Accuracy: 1.0000
Epoch [5/20], Loss: 1.7056, Validation Accuracy: 1.0000
Epoch [6/20], Loss: 1.0639, Validation Accuracy: 1.0000
Epoch [7/20], Loss: 0.7272, Validation Accuracy: 1.0000
Epoch [8/20], Loss: 0.5264, Validation Accuracy: 1.0000
Epoch [9/20], Loss: 0.4013, Validation Accuracy: 1.0000
Epoch [10/20], Loss: 0.3146, Validation Accuracy: 1.0000
Epoch [11/20], Loss: 0.2543, Validation Accuracy: 1.0000
Epoch [12/20], Loss: 0.2096, Validation Accuracy: 1.0000
Epoch [13/20], Loss: 0.1760, Validation Accuracy: 1.0000
Epoch [14/20], Loss: 0.1502, Validation Accuracy: 1.0000
Epoch [15/20], Loss: 0.1293, Validation Accuracy: 1.0000
Epoch [16/20], Loss: 0.1122, Validation Accuracy: 1.0000
Epoch [17/20], Loss: 0.0987, Validation Accuracy: 1.0000
Epoch [18/20], Loss: 0.0873, Validatio