In [1]:
from dataloader import load_data_one_hot_encoded
from trainer import train_model, evaluate_model
from snn import SNNModelSimple
import torch
import matplotlib.pyplot as plt
import pickle

In [2]:
# Experiment 1: sweep the number of simulation time steps (num_steps) for
# SNNModelSimple trained with temporal spike encoding and MSE loss, and record
# per-epoch train/val histories plus a final test accuracy for each setting.
root_folder = "./tactile_dataset/"
file_name = "final_merged_df_sw500.csv"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Experiment configuration: tune here rather than inside the loop.
NUM_STEPS_LIST = [10, 50, 100, 150, 200]  # values of num_steps (t) to sweep
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
SEED = 42  # re-applied for every t so runs differ only in num_steps

# NOTE(review): the "t_10_250_50" part of this name suggests t=10..250 in steps
# of 50, but the sweep above is 10, 50, 100, 150, 200. The literal is kept so
# existing code that loads this pickle keeps working; bs/ep are derived from the
# config so they cannot drift (with the defaults the name is unchanged).
out_file_name = f'exp1_t_10_250_50_bs_{BATCH_SIZE}_ep_{NUM_EPOCHS}_temporal_mse.pkl'

# Histories keyed by t (num_steps).
train_losses_t = {}
train_accuracies_t = {}
val_losses_t = {}
val_accuracies_t = {}
test_accuracies_t = {}


def _save_results():
    """Pickle all per-t results collected so far to `out_file_name` (overwrites)."""
    with open(out_file_name, 'wb') as f:
        pickle.dump({"train_losses": train_losses_t, "train_accuracies": train_accuracies_t, "val_losses": val_losses_t, "val_accuracies": val_accuracies_t, "test_accuracies": test_accuracies_t}, f)


for t in NUM_STEPS_LIST:
    print(f"Training for {t} steps")
    # Seed before data loading and model construction so every t starts from the
    # same RNG state (torch.manual_seed also seeds CUDA devices). Whether the
    # loader's split/shuffle draws from torch's RNG depends on dataloader.py: verify.
    torch.manual_seed(SEED)

    # Data is reloaded per t because the spike encoding depends on num_steps.
    train_loader, val_loader, test_loader, num_outputs, num_features = load_data_one_hot_encoded(root_folder, file_name, spike_encoding='temporal', num_steps=t, batch_size=BATCH_SIZE, device=device)

    model = SNNModelSimple(num_features, num_outputs)
    model.to(device)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_losses, train_accuracies, val_losses, val_accuracies = train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs=NUM_EPOCHS)
    train_losses_t[t] = train_losses
    train_accuracies_t[t] = train_accuracies
    val_losses_t[t] = val_losses
    val_accuracies_t[t] = val_accuracies

    test_accuracies_t[t] = evaluate_model(model, test_loader, encoding='one-hot')

    # Checkpoint after every t so an interrupted sweep keeps the completed runs.
    _save_results()


Training for 10 steps
Epoch 1/10, Train Loss: 0.0709, Train Accuracy: 22.89%, Val Loss: 1.4347, Val Accuracy: 37.88%, Time: 5.31s
Epoch 2/10, Train Loss: 0.0606, Train Accuracy: 40.52%, Val Loss: 1.7483, Val Accuracy: 41.96%, Time: 4.62s
Epoch 3/10, Train Loss: 0.0585, Train Accuracy: 43.02%, Val Loss: 1.8717, Val Accuracy: 45.36%, Time: 4.86s
Epoch 4/10, Train Loss: 0.0570, Train Accuracy: 45.20%, Val Loss: 2.0143, Val Accuracy: 46.28%, Time: 6.74s
Epoch 5/10, Train Loss: 0.0561, Train Accuracy: 46.34%, Val Loss: 2.1073, Val Accuracy: 45.73%, Time: 7.88s
Epoch 6/10, Train Loss: 0.0555, Train Accuracy: 47.09%, Val Loss: 1.9111, Val Accuracy: 47.49%, Time: 5.11s
Epoch 7/10, Train Loss: 0.0550, Train Accuracy: 48.10%, Val Loss: 2.0204, Val Accuracy: 48.22%, Time: 4.35s
Epoch 8/10, Train Loss: 0.0547, Train Accuracy: 48.84%, Val Loss: 2.0827, Val Accuracy: 47.74%, Time: 5.37s
Epoch 9/10, Train Loss: 0.0543, Train Accuracy: 49.09%, Val Loss: 1.9642, Val Accuracy: 48.86%, Time: 4.90s
Epoch 