In [1]:
from dataloader import load_data_one_hot_encoded
from trainer import train_model, evaluate_model
from snn import SNNModelSimple
import torch
import matplotlib.pyplot as plt
import pickle

In [2]:
# Weight-decay sweep for SNNModelSimple with temporal spike encoding and MSE loss.
# For every weight-decay value a fresh model is trained with Adam and then
# evaluated on the test loader. Whatever train_model / evaluate_model return is
# stored in dicts keyed by the weight-decay value, and all of those dicts are
# pickled to disk at the end of the cell.

# --- Configuration -----------------------------------------------------------
root_folder = "./tactile_dataset/"
file_name = "final_merged_df_sw500.csv"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

SEED = 0                      # same seed for every run so only weight decay differs
WEIGHT_DECAYS = [0, 1e-5, 1e-4, 5e-4, 1e-3, 5e-3]
SPIKE_ENCODING = 'temporal'
NUM_STEPS = 100
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# The file name spells out the hyperparameters above by hand; keep it in sync
# if any of the constants change.
out_file_name = 'exp4_temporal_wd_0_5e-3_t_100_bs_128_ep_10_lr_1e-3_temporal_mse.pkl'

# --- Result containers (weight decay -> value returned by the trainer) --------
train_losses_wd = {}
train_accuracies_wd = {}
val_losses_wd = {}
val_accuracies_wd = {}
test_accuracies_wd = {}

# --- Sweep -------------------------------------------------------------------
for wd in WEIGHT_DECAYS:
    print(f"Training for weight decay={wd}")

    # Reset torch's RNG before each run so every weight-decay value starts from
    # the same random state; otherwise differences between runs are confounded
    # with different random draws.
    # NOTE(review): this only seeds torch. If the data loader or the model draw
    # from numpy / random, seed those as well - confirm.
    torch.manual_seed(SEED)

    # NOTE(review): the loaders are rebuilt with identical arguments on every
    # iteration. If load_data_one_hot_encoded is deterministic and the loaders
    # are reusable, this call can be hoisted above the loop - confirm.
    train_loader, val_loader, test_loader, num_outputs, num_features = load_data_one_hot_encoded(
        root_folder,
        file_name,
        spike_encoding=SPIKE_ENCODING,
        num_steps=NUM_STEPS,
        batch_size=BATCH_SIZE,
        device=device,
    )

    # Fresh model, loss and optimizer per run so no state leaks between
    # weight-decay values.
    model = SNNModelSimple(num_features, num_outputs)
    model.to(device)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=wd)

    train_losses, train_accuracies, val_losses, val_accuracies = train_model(
        model, train_loader, val_loader, optimizer, criterion, num_epochs=NUM_EPOCHS
    )
    train_losses_wd[wd] = train_losses
    train_accuracies_wd[wd] = train_accuracies
    val_losses_wd[wd] = val_losses
    val_accuracies_wd[wd] = val_accuracies

    test_accuracies_wd[wd] = evaluate_model(model, test_loader, encoding='one-hot')

# --- Persist results -----------------------------------------------------------
results = {
    "train_losses": train_losses_wd,
    "train_accuracies": train_accuracies_wd,
    "val_losses": val_losses_wd,
    "val_accuracies": val_accuracies_wd,
    "test_accuracies": test_accuracies_wd,
}
with open(out_file_name, 'wb') as f:
    pickle.dump(results, f)


Training for weight decay=0
Epoch 1/10, Train Loss: 0.0767, Train Accuracy: 9.42%, Val Loss: 71.3671, Val Accuracy: 13.14%, Time: 53.67s
Epoch 2/10, Train Loss: 0.0733, Train Accuracy: 20.32%, Val Loss: 115.2372, Val Accuracy: 25.78%, Time: 49.69s
Epoch 3/10, Train Loss: 0.0695, Train Accuracy: 30.84%, Val Loss: 117.8568, Val Accuracy: 32.65%, Time: 50.38s
Epoch 4/10, Train Loss: 0.0677, Train Accuracy: 35.38%, Val Loss: 126.8988, Val Accuracy: 38.86%, Time: 50.33s
Epoch 5/10, Train Loss: 0.0664, Train Accuracy: 38.09%, Val Loss: 135.5742, Val Accuracy: 39.14%, Time: 47.02s
Epoch 6/10, Train Loss: 0.0654, Train Accuracy: 37.77%, Val Loss: 136.0088, Val Accuracy: 39.44%, Time: 50.41s
Epoch 7/10, Train Loss: 0.0641, Train Accuracy: 40.75%, Val Loss: 145.3115, Val Accuracy: 43.17%, Time: 48.76s
Epoch 8/10, Train Loss: 0.0631, Train Accuracy: 42.46%, Val Loss: 156.6035, Val Accuracy: 44.72%, Time: 49.80s
Epoch 9/10, Train Loss: 0.0625, Train Accuracy: 43.57%, Val Loss: 160.8859, Val Accura