In [4]:
from dataloader import load_data_label_encoded
from trainer import train_model, evaluate_model
from snn import SNNModelSimple
import torch
import matplotlib.pyplot as plt
import pickle

In [5]:
# Experiment 1: sweep the `num_steps` argument passed to the data loader (with
# spike_encoding='temporal') and, for every setting, train a fresh
# SNNModelSimple and record its train/validation/test metrics.
root_folder = "./tactile_dataset/"
file_name = "final_merged_df_sw500.csv"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Experiment settings, gathered in one place so the sweep is easy to adjust.
SEED = 42
NUM_STEPS_SWEEP = [10, 50, 100, 150, 200]
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# Per-setting results, keyed by the swept value `t`.
train_losses_t = {}
train_accuracies_t = {}
val_losses_t = {}
val_accuracies_t = {}
test_accuracies_t = {}

# The loss module does not depend on `t`, so it is built once and shared by all runs.
criterion = torch.nn.CrossEntropyLoss()

for t in NUM_STEPS_SWEEP:
    print(f"Training for {t} steps")

    # Re-seed torch before every run so each setting starts from the same RNG
    # state and a fresh-kernel re-run reproduces the same torch random draws.
    # NOTE(review): if the data loader or trainer draw from numpy/`random`,
    # those generators need seeding as well - confirm in dataloader/trainer.
    torch.manual_seed(SEED)

    # The data is reloaded for each setting because `num_steps` is an argument
    # of the loader itself.
    train_loader, val_loader, test_loader, num_outputs, num_features = load_data_label_encoded(
        root_folder,
        file_name,
        spike_encoding='temporal',
        num_steps=t,
        batch_size=BATCH_SIZE,
        device=device,
    )

    # A new model and optimizer per setting, so no weights carry over between runs.
    model = SNNModelSimple(num_features, num_outputs)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_losses, train_accuracies, val_losses, val_accuracies = train_model(
        model, train_loader, val_loader, optimizer, criterion, num_epochs=NUM_EPOCHS
    )
    train_losses_t[t] = train_losses
    train_accuracies_t[t] = train_accuracies
    val_losses_t[t] = val_losses
    val_accuracies_t[t] = val_accuracies

    test_accuracies_t[t] = evaluate_model(model, test_loader)

# File name pattern: exp1, t range, step, batch_size=128, epochs=10.
# NOTE: the "10_250_50" part of the name does not match the sweep actually run
# above (t in [10, 50, 100, 150, 200]); the name is left unchanged so that any
# code that already reads this file keeps working.
out_file_name = 'exp1_t_10_250_50_bs_128_ep_10.pkl'
results = {
    "train_losses": train_losses_t,
    "train_accuracies": train_accuracies_t,
    "val_losses": val_losses_t,
    "val_accuracies": val_accuracies_t,
    "test_accuracies": test_accuracies_t,
}
with open(out_file_name, 'wb') as f:
    pickle.dump(results, f)


Training for 10 steps
Epoch 1/10, Train Loss: 1.9986, Train Accuracy: 27.06%, Val Loss: 1.6613, Val Accuracy: 36.93%, Time: 2.60s
Epoch 2/10, Train Loss: 1.5908, Train Accuracy: 39.58%, Val Loss: 1.5516, Val Accuracy: 40.50%, Time: 2.56s
Epoch 3/10, Train Loss: 1.5162, Train Accuracy: 42.11%, Val Loss: 1.4965, Val Accuracy: 43.33%, Time: 2.64s
Epoch 4/10, Train Loss: 1.4754, Train Accuracy: 43.63%, Val Loss: 1.4697, Val Accuracy: 42.66%, Time: 2.53s
Epoch 5/10, Train Loss: 1.4396, Train Accuracy: 44.74%, Val Loss: 1.4285, Val Accuracy: 45.89%, Time: 2.49s
Epoch 6/10, Train Loss: 1.4077, Train Accuracy: 45.49%, Val Loss: 1.4246, Val Accuracy: 44.60%, Time: 2.71s
Epoch 7/10, Train Loss: 1.3805, Train Accuracy: 46.55%, Val Loss: 1.3891, Val Accuracy: 46.13%, Time: 2.48s
Epoch 8/10, Train Loss: 1.3736, Train Accuracy: 46.72%, Val Loss: 1.3975, Val Accuracy: 46.63%, Time: 2.46s
Epoch 9/10, Train Loss: 1.3606, Train Accuracy: 47.07%, Val Loss: 1.3812, Val Accuracy: 46.64%, Time: 2.46s
Epoch 