In [1]:
from dataloader import load_data_label_encoded
from trainer import train_model, evaluate_model
from snn import SNNModelSimple
import torch
import matplotlib.pyplot as plt
import pickle

Parameters that worked best:
1. Optimizer: Adam with lr ~0.005 or ~0.001 (both worked well)
2. Spike encoding: temporal, with 50 time steps
3. Batch size: 128 (performs on par with 32)
4. Beta: 0.8, 0.9, or 0.999
5. Loss: CrossEntropyLoss
6. Number of epochs: 100

In [2]:
# --- Experiment configuration -------------------------------------------------
# All hyper-parameters for this sweep live here so the values behind a saved
# results file are visible in one place.
# NOTE(review): the "best parameters" notes list 100 epochs, but this run uses
# NUM_EPOCHS = 50 — confirm which is intended.
SEED = 42
BETAS = [0.9]
LR = 5e-3
NUM_EPOCHS = 50
NUM_STEPS = 50
BATCH_SIZE = 128
SPIKE_ENCODING = 'temporal'

root_folder = "./tactile_dataset/"
file_name = "final_merged_df_sw500.csv"
# The filename encodes beta and lr; update it by hand whenever BETAS or LR change,
# otherwise a new sweep will overwrite/mislabel an earlier results file.
out_file_name = "exp_snn_beta_09_lr5e-3.pkl"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch before loading so any torch-based shuffling/splitting is repeatable.
# NOTE(review): if load_data_label_encoded splits with numpy/sklearn RNGs, those
# would need seeding too — that code is not visible from this notebook.
torch.manual_seed(SEED)

# The data does not depend on beta, so load it once and reuse the same loaders
# for every beta instead of reloading the CSV on each iteration.
train_loader, val_loader, test_loader, num_outputs, num_features = load_data_label_encoded(
    root_folder,
    file_name,
    spike_encoding=SPIKE_ENCODING,
    num_steps=NUM_STEPS,
    batch_size=BATCH_SIZE,
    device=device,
)

# Per-beta learning curves and final test accuracy, keyed by the beta value.
train_losses_beta = {}
train_accuracies_beta = {}
val_losses_beta = {}
val_accuracies_beta = {}
test_accuracies_beta = {}
for beta in BETAS:
    print(f"Training for beta={beta}")

    # Re-seed per beta so every model starts from the same RNG state and the
    # comparison between betas is not confounded by different initial weights.
    torch.manual_seed(SEED)
    model = SNNModelSimple(num_features, num_outputs, beta=beta)
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    train_losses, train_accuracies, val_losses, val_accuracies = train_model(
        model, train_loader, val_loader, optimizer, criterion, num_epochs=NUM_EPOCHS
    )
    train_losses_beta[beta] = train_losses
    train_accuracies_beta[beta] = train_accuracies
    val_losses_beta[beta] = val_losses
    val_accuracies_beta[beta] = val_accuracies

    test_accuracies_beta[beta] = evaluate_model(model, test_loader)

# Persist all per-beta results in a single pickle for later plotting/comparison.
with open(out_file_name, "wb") as f:
    pickle.dump({
        "train_losses": train_losses_beta,
        "train_accuracies": train_accuracies_beta,
        "val_losses": val_losses_beta,
        "val_accuracies": val_accuracies_beta,
        "test_accuracies": test_accuracies_beta
    }, f)

Training for beta=0.9
Epoch 1/50, Train Loss: 1.9804, Train Accuracy: 27.91%, Val Loss: 1.5687, Val Accuracy: 42.19%, Time: 24.41s
Epoch 2/50, Train Loss: 1.4517, Train Accuracy: 46.23%, Val Loss: 1.3653, Val Accuracy: 48.93%, Time: 22.65s
Epoch 3/50, Train Loss: 1.3285, Train Accuracy: 50.47%, Val Loss: 1.2941, Val Accuracy: 52.25%, Time: 25.12s
Epoch 4/50, Train Loss: 1.2585, Train Accuracy: 52.84%, Val Loss: 1.2160, Val Accuracy: 55.33%, Time: 23.36s
Epoch 5/50, Train Loss: 1.2030, Train Accuracy: 54.17%, Val Loss: 1.1774, Val Accuracy: 54.40%, Time: 24.02s
Epoch 6/50, Train Loss: 1.1636, Train Accuracy: 55.92%, Val Loss: 1.1529, Val Accuracy: 54.81%, Time: 22.93s
Epoch 7/50, Train Loss: 1.1302, Train Accuracy: 56.88%, Val Loss: 1.1662, Val Accuracy: 58.79%, Time: 21.46s
Epoch 8/50, Train Loss: 1.0968, Train Accuracy: 58.09%, Val Loss: 1.0784, Val Accuracy: 58.34%, Time: 21.10s
Epoch 9/50, Train Loss: 1.0835, Train Accuracy: 58.63%, Val Loss: 1.0679, Val Accuracy: 59.59%, Time: 22.9