In [1]:
from dataloader import load_data_label_encoded
from trainer import train_model, evaluate_model
from snn import SNNModelSimple
import torch
import matplotlib.pyplot as plt
import pickle

In [None]:
# Experiment 1: sweep the number of simulation time steps used for temporal
# spike encoding. For each value, a fresh SNNModelSimple is trained; per-epoch
# train/val curves and a final test accuracy are recorded and pickled.
root_folder = "./tactile_dataset/"
file_name = "final_merged_df_sw500.csv"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Sweep configuration. BATCH_SIZE and NUM_EPOCHS are also encoded in
# `out_file_name` below; the assertion there keeps the two in sync.
NUM_STEPS_SWEEP = [10, 50, 100, 150, 200]
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
SEED = 0  # re-applied before every run so each t starts from the same RNG state

# CrossEntropyLoss keeps no state between calls, so one instance serves all runs.
criterion = torch.nn.CrossEntropyLoss()

# Per-run results, keyed by the number of time steps t.
train_losses_t = {}
train_accuracies_t = {}
val_losses_t = {}
val_accuracies_t = {}
test_accuracies_t = {}
for t in NUM_STEPS_SWEEP:
    print(f"Training for {t} steps")
    # Only the torch RNG (all devices) is seeded here. Any numpy / `random`
    # usage inside load_data_label_encoded or train_model is not controlled
    # by this call. NOTE(review): confirm whether those helpers need seeding too.
    torch.manual_seed(SEED)
    train_loader, val_loader, test_loader, num_outputs, num_features = load_data_label_encoded(
        root_folder, file_name, spike_encoding='temporal', num_steps=t,
        batch_size=BATCH_SIZE, device=device,
    )

    model = SNNModelSimple(num_features, num_outputs)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_losses, train_accuracies, val_losses, val_accuracies = train_model(
        model, train_loader, val_loader, optimizer, criterion, num_epochs=NUM_EPOCHS,
    )
    train_losses_t[t] = train_losses
    train_accuracies_t[t] = train_accuracies
    val_losses_t[t] = val_losses
    val_accuracies_t[t] = val_accuracies

    test_accuracies_t[t] = evaluate_model(model, test_loader)

# NOTE: despite the "t_10_250_50" in its name, this file holds results for
# t in NUM_STEPS_SWEEP = [10, 50, 100, 150, 200] (not 10..250 in steps of 50).
# The name is kept unchanged because other notebooks may load it by this path.
# TODO: rename consistently across consumers.
out_file_name = 'exp1_t_10_250_50_bs_128_ep_10.pkl'
assert f"bs_{BATCH_SIZE}_ep_{NUM_EPOCHS}" in out_file_name, "output file name out of sync with config"
with open(out_file_name, 'wb') as f:
    pickle.dump({"train_losses": train_losses_t, "train_accuracies": train_accuracies_t, "val_losses": val_losses_t, "val_accuracies": val_accuracies_t, "test_accuracies": test_accuracies_t}, f)


Training for 10 steps
Epoch 1/10, Train Loss: 1.9828, Train Accuracy: 27.89%, Val Loss: 1.6521, Val Accuracy: 36.85%, Time: 2.58s
Epoch 2/10, Train Loss: 1.5549, Train Accuracy: 41.42%, Val Loss: 1.4934, Val Accuracy: 43.79%, Time: 2.46s
Epoch 3/10, Train Loss: 1.4625, Train Accuracy: 44.36%, Val Loss: 1.4452, Val Accuracy: 45.80%, Time: 2.44s
Epoch 4/10, Train Loss: 1.4220, Train Accuracy: 45.78%, Val Loss: 1.3924, Val Accuracy: 47.65%, Time: 2.42s
Epoch 5/10, Train Loss: 1.3917, Train Accuracy: 46.80%, Val Loss: 1.3740, Val Accuracy: 47.03%, Time: 2.40s
Epoch 6/10, Train Loss: 1.3774, Train Accuracy: 47.09%, Val Loss: 1.3560, Val Accuracy: 47.35%, Time: 2.42s
Epoch 7/10, Train Loss: 1.3581, Train Accuracy: 47.74%, Val Loss: 1.3572, Val Accuracy: 47.75%, Time: 2.47s
Epoch 8/10, Train Loss: 1.3508, Train Accuracy: 48.04%, Val Loss: 1.3340, Val Accuracy: 49.87%, Time: 2.44s
Epoch 9/10, Train Loss: 1.3334, Train Accuracy: 48.77%, Val Loss: 1.3335, Val Accuracy: 49.24%, Time: 2.42s
Epoch 