In [None]:
import numpy as np
import torch
import json

def save_best_performing_model(models, test_accs, path):
    """Save the ``state_dict`` of the model with the highest test accuracy.

    Args:
        models: Sequence of models exposing ``state_dict()`` (e.g. one
            ``torch.nn.Module`` per epoch). ``models[i]`` must correspond to
            ``test_accs[i]``.
        test_accs: Sequence of test accuracies, same length as ``models``.
        path: Destination file passed to ``torch.save``.

    Returns:
        int: Index of the saved model. Ties resolve to the earliest index,
        as with ``np.argmax``.

    Raises:
        ValueError: If ``models`` is empty or the two sequences differ in
            length (otherwise a mismatch would raise ``IndexError`` or
            silently select from a prefix).
    """
    if len(models) == 0:
        raise ValueError("cannot select a best model from an empty sequence")
    if len(models) != len(test_accs):
        raise ValueError(
            "models and test_accs must have equal length "
            f"(got {len(models)} and {len(test_accs)})"
        )

    best_test_acc_index = int(np.argmax(test_accs))
    best_model = models[best_test_acc_index]

    torch.save(best_model.state_dict(), path)
    return best_test_acc_index

def save_all_models_per_epoch(models, path_prefix):
    """Write one checkpoint per model to ``{path_prefix}_epoch_{i}.pth``.

    ``i`` is the 0-based position of the model in ``models``. Each file holds
    that model's ``state_dict()`` serialised with ``torch.save``. An empty
    ``models`` writes nothing.
    """
    for epoch_index, net in enumerate(models):
        weights = net.state_dict()
        destination = f'{path_prefix}_epoch_{epoch_index}.pth'
        torch.save(weights, destination)

def save_configuration_output(result_data, output_file_path):
    """Write ``result_data`` to ``output_file_path`` as JSON indented by 4.

    ``json`` cannot serialise NumPy scalars such as ``np.int64``/``np.float32``,
    NumPy arrays, or PyTorch tensors. Any value exposing ``tolist()`` is
    therefore converted to its built-in Python equivalent. The document is
    fully serialised before the file is opened, so a failure never leaves a
    truncated file behind.

    Raises:
        TypeError: If a value is neither JSON-native nor convertible via
            ``tolist()``.
    """
    def _to_builtin(value):
        # json calls this only for objects it cannot encode on its own.
        if hasattr(value, "tolist"):
            return value.tolist()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    serialized = json.dumps(result_data, indent=4, default=_to_builtin)
    with open(output_file_path, "w") as file:
        file.write(serialized)

In [None]:
import os
import random

from constants import (
    BETA,
    NUMBER_INPUT_NEURONS,
    NUMBER_OUTPUT_NEURONS,
    THRESHOLD,
    TIME_STEPS,
)
from neural_nets.configurable_spiking_neural_net import ConfigurableSpikingNeuralNet
from training.train_snn import train_snn
from util.save_plots import save_history_plot, save_loss_per_time_step_plot

# --- Run configuration --------------------------------------------------------
# String sentinel forwarded to train_snn in place of an epoch count
# (presumably selects early stopping -- see train_snn).
num_epochs = 'early_stopping'
sparsity = 0.2
number_hidden_neurons = 10
number_hidden_layer = 1
loss_configuration = "population_coding"
train_data_subset = 1000  # forwarded as train_snn(use_train_data_subset=...)
seed = 0  # seeds Python, NumPy and torch RNGs so Restart & Run All is reproducible

# Every artefact of this run is named after `run_name`.
run_name = 'test'
models_dir = './models'
output_dir = './output'

# torch.save / open(..., "w") fail if the target directory does not exist.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

model = ConfigurableSpikingNeuralNet(
    number_input_neurons=NUMBER_INPUT_NEURONS,
    number_hidden_neurons=number_hidden_neurons,
    number_output_neurons=NUMBER_OUTPUT_NEURONS,
    beta=BETA,
    threshold=THRESHOLD,
    time_steps=TIME_STEPS,
    number_hidden_layers=number_hidden_layer,
)

(
    training_acc_history,
    test_acc_history,
    loss_history,
    total_training_time,
    epoch_loss_per_time_step,
    models_per_epoch,
) = train_snn(
    model,
    num_epochs=num_epochs,
    sparsity=sparsity,
    loss_configuration=loss_configuration,
    use_train_data_subset=train_data_subset,
)

# save_best_performing_model indexes models_per_epoch with positions taken from
# test_acc_history, so all histories must line up epoch-for-epoch.
assert len(training_acc_history) == len(test_acc_history) == len(models_per_epoch), (
    "train_snn returned histories of different lengths"
)

save_best_performing_model(models_per_epoch, test_acc_history, os.path.join(models_dir, f'{run_name}.pth'))

save_all_models_per_epoch(models_per_epoch, os.path.join(models_dir, run_name))

# Epoch with the highest test accuracy (first one on ties, as with np.argmax).
best_epoch = int(np.argmax(test_acc_history))

result_data = {
    "number_input_neurons": NUMBER_INPUT_NEURONS,
    "number_hidden_neurons": number_hidden_neurons,
    "number_output_neurons": NUMBER_OUTPUT_NEURONS,
    "beta": BETA,
    "threshold": THRESHOLD,
    "time_steps": TIME_STEPS,
    "number_hidden_layers": number_hidden_layer,
    "epochs": len(test_acc_history),
    "sparsity": sparsity,
    "loss_configuration": loss_configuration,
    # float() turns NumPy / tensor scalars into JSON-serialisable Python floats.
    "best_test_accuracy": float(test_acc_history[best_epoch]),
    "training_accuracy": float(training_acc_history[best_epoch]),
}

save_configuration_output(result_data, os.path.join(output_dir, f'{run_name}.json'))

save_history_plot(training_acc_history, os.path.join(output_dir, f'{run_name}_train_acc.jpg'))

save_history_plot(test_acc_history, os.path.join(output_dir, f'{run_name}_test_acc.jpg'))

save_history_plot(loss_history, os.path.join(output_dir, f'{run_name}_loss.jpg'))

# Skipped when train_snn returns an empty / falsy per-time-step loss record.
if epoch_loss_per_time_step:
    save_loss_per_time_step_plot(
        epoch_loss_per_time_step,
        path=os.path.join(output_dir, f'{run_name}_loss_per_time_steps.png'),
    )

In [None]:
from util.plot_layer_development import plot_layer_development

# One 0-based label per saved model snapshot, e.g. 'Epoch 0', 'Epoch 1', ...
epoch_count = len(models_per_epoch)
epoch_labels = [f'Epoch {i}' for i in range(epoch_count)]
plot_layer_development(models_per_epoch, epoch_labels)