In [1]:
from util_fun import narx_sim_nrms, calculate_error_nrms, print_log
from model import Narx
from data import load_data, convert_to_narx, GS_Dataset, create_gs_dataset
import numpy as np
import pandas as pd
import sys
import torch
from sklearn.model_selection import train_test_split
from copy import deepcopy
import matplotlib.pyplot as plt
from dataclasses import dataclass
from time import time, sleep
from datetime import timedelta
from os import path

sys.path.insert(1, "../")

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
N_EPOCHS = 100

In [2]:
params = {
    "n_a": [5, 9, 13],
    "n_b": [20,25, 30],
    "n_layers": [1],
    "n_nodes": [10, 30, 50, 100],
}

test_params = {
    "n_a": [*range(1, 4)],
    "n_b": [*range(10, 13)],
    "n_layers": [1, 3],
    "n_nodes": [10, 20],
}


In [3]:
@dataclass
class GSResults:
    best_model: Narx = None
    best_sim_model: Narx = None
    best_nrms: float = None
    best_sim_nrms: float = None
    loss_list: list = None
    nrms_list: list = None
    sim_nrms_list: list = None

In [4]:
def train_narx_simval(
    model: Narx,
    n_a: int,
    n_b: int,
    data: GS_Dataset,
    log_file: str,
    param_msg: str = None,
    n_epochs: int = N_EPOCHS,
    device: torch.device = DEVICE,
):
    # initialise comparison values and results lists
    best_nrms = float("inf")
    best_model = None
    best_sim_nrms = float("inf")
    best_sim_model = None
    loss_list = []
    nrms_list = []
    sim_nrms_list = []

    # initialise checkpoints for validation
    checkpoints = [*range(0, n_epochs + 1, max(n_epochs // 25, 1))]
    if checkpoints[-1] != n_epochs:
        checkpoints += [n_epochs]

    # start training loop
    for epoch in range(n_epochs):
        optimizer = torch.optim.Adam(model.parameters())
        loss = torch.mean((model(data.x_train) - data.y_train) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch in checkpoints:
            print_log(f"Checkpoint at epoch {epoch+1}: " + param_msg + " \n", log_file)
            # append loss to list, check prediction and simulation nrms
            loss_list.append(loss.item())

            nrms = calculate_error_nrms(model.forward(data.x_val), data.y_val)
            nrms_list.append(nrms)
            if nrms < best_nrms:
                print_log(
                    f"current pred NRMS: {nrms}, previous best pred NRMS: {best_nrms} \n",
                    log_file,
                )
                best_nrms = nrms
                best_model = deepcopy(model)

            _, _, _, sim_nrms = narx_sim_nrms(
                model, n_a, n_b, data.x_data, data.y_data, device
            )
            sim_nrms_list.append(sim_nrms)
            if sim_nrms < best_sim_nrms:
                print_log(
                    f"current sim NRMS: {sim_nrms}, previous best sim NRMS: {best_sim_nrms} \n",
                    log_file,
                )
                best_sim_nrms = sim_nrms
                best_sim_model = deepcopy(model)

    results = GSResults(
        best_model,
        best_sim_model,
        best_nrms,
        best_sim_nrms,
        loss_list,
        nrms_list,
        sim_nrms_list,
    )
    return results

In [5]:
def make_log_file():
    filename = "narx_gs_log"
    fileext = ".txt"
    i = 0
    while path.exists(filename + str(i) + fileext):
        i += 1
    return filename + str(i) + fileext

def save_models(pred_model, sim_model):
    filename_sim = "narx_gs_best_sim"
    i = 0
    while path.exists(filename_sim + str(i)):
        i += 1
    filename_sim= filename_sim + str(i)
    torch.save(sim_model.state_dict(),filename_sim)

    filename = "narx_gs_best_pred"
    i = 0
    while path.exists(filename + str(i)):
        i += 1
    filename= filename + str(i)
    torch.save(pred_model.state_dict(),filename)

    print_log(f'Saved best model in {filename}, best sim model in {filename_sim}', log_file)

In [6]:
results_dict = {}
device = DEVICE
n_epochs = N_EPOCHS
best_nrms = float("inf")
best_sim_nrms = float("inf")
best_model = None
best_sim_model = None
best_params = None
best_sim_params = None

grid_search_params = params
x, y = load_data()
total_runs = 1
for key in grid_search_params:
    total_runs *= len(grid_search_params[key])
run_counter = 0
start_time_list = []
log_file = make_log_file()

print_log(f"Starting new Grid Search with parameters {grid_search_params} \n", log_file)
for i, n_a in enumerate(grid_search_params["n_a"]):
    for j, n_b in enumerate(grid_search_params["n_b"]):
        # n_a and n_b are the only two parameters that change the dataset
        data = create_gs_dataset(x, y, n_a, n_b, device)
        for k, n_nodes in enumerate(grid_search_params["n_nodes"]):
            for l, n_layers in enumerate(grid_search_params["n_layers"]):
                # general administration and timekeeping
                run_counter += 1
                start_time_list.append(time())
                param_string = f"{n_a=}, {n_b=}, {n_nodes=}, {n_layers=}"
                print_log(
                    f"Starting run {run_counter} out of {total_runs} \n", log_file
                )
                # generate model, do the actual training run, save the results
                model = Narx(n_a + n_b, n_nodes, n_layers).to(DEVICE)
                result = train_narx_simval(
                    model, n_a, n_b, data, log_file, param_string, n_epochs, device
                )
                results_dict[param_string] = result

                # check new results against old results, save if better
                if result.best_nrms < best_nrms:
                    print_log(
                        f"Found new best prediction model, with parameters {param_string} \n",
                        log_file,
                    )
                    print_log(
                        f"new best pred NRMS= {result.best_nrms}, previous best: {best_nrms} \n",
                        log_file,
                    )
                    best_nrms = result.best_nrms
                    best_model = deepcopy(result.best_model)
                    best_params = param_string
                if result.best_sim_nrms < best_sim_nrms:
                    print_log(
                        f"Found new best simulation model, with parameters {param_string} \n",
                        log_file,
                    )
                    print_log(
                        f"new best sim NRMS= {result.best_sim_nrms}, previous best: {best_sim_nrms} \n",
                        log_file,
                    )
                    best_sim_nrms = result.best_sim_nrms
                    best_sim_model = deepcopy(result.best_sim_model)
                    best_sim_params = param_string

                # finish the run
                run_time = timedelta(seconds=time() - start_time_list[-1])
                total_time = timedelta(seconds=time() - start_time_list[0])
                print_log(
                    f"Finished run {run_counter} out of {total_runs}. Time elapsed this run: {run_time}, total time elapsed: {total_time} \n",
                    log_file,
                )

print_log(
    f"Best prediction model found with parameters: {best_params}, and NRMS: {best_nrms}. \n "
    + f"Best simulation model found with parameters: {best_sim_params}, and NRMS: {best_sim_nrms}. \n"
    + f"Total time elapsed: {time()-start_time_list[0]} \n",
    log_file,
)
save_models(best_model, best_sim_model)
print(f'Run log saved in {log_file}')

Starting new Grid Search with parameters {'n_a': [5, 9, 13], 'n_b': [20, 25, 30], 'n_layers': [1], 'n_nodes': [10, 30, 50, 100]} 

Starting run 1 out of 36 

Checkpoint at epoch 1: n_a=5, n_b=20, n_nodes=10, n_layers=1 

current pred NRMS: 1.3527598603634752, previous best pred NRMS: inf 

current sim NRMS: 1.425197598705685, previous best sim NRMS: inf 

Checkpoint at epoch 5: n_a=5, n_b=20, n_nodes=10, n_layers=1 

current pred NRMS: 1.3011145180409576, previous best pred NRMS: 1.3527598603634752 

current sim NRMS: 1.3649383703624245, previous best sim NRMS: 1.425197598705685 

Checkpoint at epoch 9: n_a=5, n_b=20, n_nodes=10, n_layers=1 

current pred NRMS: 1.2519743765017233, previous best pred NRMS: 1.3011145180409576 

current sim NRMS: 1.3073158498743904, previous best sim NRMS: 1.3649383703624245 

Checkpoint at epoch 13: n_a=5, n_b=20, n_nodes=10, n_layers=1 

current pred NRMS: 1.2055883008469066, previous best pred NRMS: 1.2519743765017233 

current sim NRMS: 1.252642144834