In [4]:
from dataloader import load_data_one_hot_encoded
from trainer import train_model, evaluate_model
from snn import SNNModelSimple
import torch
import matplotlib.pyplot as plt
import pickle

In [6]:
# Learning-rate sweep: train SNNModelSimple with an MSE loss on one-hot encoded targets
# for each learning rate, record per-epoch curves plus final test accuracy, and pickle the results.
# All hyperparameters live here so the output file name (derived below) cannot drift out of
# sync with what was actually run.
root_folder = "./tactile_dataset/"
file_name = "final_merged_df_sw500.csv"

LEARNING_RATES = [1e-4, 1e-3, 5e-3, 1e-2, 5e-2]
NUM_STEPS = 100   # forwarded as `num_steps` to load_data_one_hot_encoded
BATCH_SIZE = 128
NUM_EPOCHS = 10
SEED = 42

# Per-learning-rate results, keyed by lr.
train_losses_l = {}
train_accuracies_l = {}
val_losses_l = {}
val_accuracies_l = {}
test_accuracies_l = {}

# Load the data once: neither the split nor the feature/output dimensions depend on the
# learning rate. Reloading per lr was redundant work and, if the split is random, would
# evaluate each lr on a different split.
# NOTE(review): assumes the returned loaders are re-iterable (e.g. torch DataLoader) — confirm.
torch.manual_seed(SEED)
train_loader, val_loader, test_loader, num_outputs, num_features = load_data_one_hot_encoded(
    root_folder, file_name, num_steps=NUM_STEPS, batch_size=BATCH_SIZE
)
criterion = torch.nn.MSELoss()  # stateless, safe to share across runs

for lr in LEARNING_RATES:
    print(f"Training for lr={lr}")
    # Re-seed so every lr starts from the same torch RNG state (initial weights / shuffling,
    # insofar as those draw from torch's RNG), isolating the effect of the learning rate.
    torch.manual_seed(SEED)

    model = SNNModelSimple(num_features, num_outputs)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # NOTE(review): in the recorded output, val loss is ~1000x train loss and rises while val
    # accuracy also rises — train_model may report val loss on a different scale (e.g. summed
    # rather than averaged). Verify before comparing losses across splits.
    train_losses, train_accuracies, val_losses, val_accuracies = train_model(
        model, train_loader, val_loader, optimizer, criterion, num_epochs=NUM_EPOCHS
    )
    train_losses_l[lr] = train_losses
    train_accuracies_l[lr] = train_accuracies
    val_losses_l[lr] = val_losses
    val_accuracies_l[lr] = val_accuracies

    test_accuracies_l[lr] = evaluate_model(model, test_loader, encoding='one-hot')


def _lr_tag(value: float) -> str:
    """Compact, filename-friendly learning-rate tag: 1e-4 -> '1e-4', 5e-2 -> '5e-2'.

    The mantissa is rounded to one significant digit, so the tag is a label, not an exact value.
    """
    mantissa, exponent = f"{value:.0e}".split("e")
    return f"{mantissa}e{int(exponent)}"


# Derived from the config above; with the current values this is
# 'exp3_mse_lr_1e-4_5e-2_t_100_bs_128_ep_10.pkl', identical to the previous hard-coded name.
out_file_name = (
    f"exp3_mse_lr_{_lr_tag(min(LEARNING_RATES))}_{_lr_tag(max(LEARNING_RATES))}"
    f"_t_{NUM_STEPS}_bs_{BATCH_SIZE}_ep_{NUM_EPOCHS}.pkl"
)
with open(out_file_name, 'wb') as f:
    pickle.dump(
        {
            "train_losses": train_losses_l,
            "train_accuracies": train_accuracies_l,
            "val_losses": val_losses_l,
            "val_accuracies": val_accuracies_l,
            "test_accuracies": test_accuracies_l,
        },
        f,
    )


Training for lr=0.0001
Epoch 1/10, Train Loss: 0.0737, Train Accuracy: 19.16%, Val Loss: 114.3438, Val Accuracy: 26.84%, Time: 24.83s
Epoch 2/10, Train Loss: 0.0672, Train Accuracy: 28.39%, Val Loss: 142.6147, Val Accuracy: 29.61%, Time: 26.65s
Epoch 3/10, Train Loss: 0.0659, Train Accuracy: 29.72%, Val Loss: 150.7349, Val Accuracy: 30.75%, Time: 27.94s
Epoch 4/10, Train Loss: 0.0653, Train Accuracy: 30.60%, Val Loss: 160.0719, Val Accuracy: 31.04%, Time: 30.90s
Epoch 5/10, Train Loss: 0.0649, Train Accuracy: 31.16%, Val Loss: 163.3020, Val Accuracy: 31.31%, Time: 32.58s
Epoch 6/10, Train Loss: 0.0645, Train Accuracy: 31.68%, Val Loss: 162.1619, Val Accuracy: 31.78%, Time: 32.43s
Epoch 7/10, Train Loss: 0.0643, Train Accuracy: 32.08%, Val Loss: 170.8922, Val Accuracy: 32.63%, Time: 34.08s
Epoch 8/10, Train Loss: 0.0640, Train Accuracy: 32.35%, Val Loss: 167.4145, Val Accuracy: 32.08%, Time: 35.85s
Epoch 9/10, Train Loss: 0.0638, Train Accuracy: 32.69%, Val Loss: 175.8862, Val Accuracy: