In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from tqdm import tqdm

from utils import TimeseriesDataset
from model import nts_RNN, train_model

In [None]:
def run_training(dataset_name, seq_len, batch_size, learning_rate, hidden_size, weight_decay, save_model, model_save_name,
                 num_epoch=1000, log_every=100, save_dir="trained_models", seed=None):
    """Train an ``nts_RNN`` on a ``TimeseriesDataset`` with Adam and an MSE loss.

    Parameters
    ----------
    dataset_name : str
        Passed to ``TimeseriesDataset`` (e.g. ``'TrainSet.mat'``).
    seq_len : int
        Passed to ``TimeseriesDataset``; presumably the input window length — TODO confirm.
    batch_size : int
        Mini-batch size of the shuffled ``DataLoader``.
    learning_rate : float
        Adam learning rate.
    hidden_size : list of int
        Passed unchanged to ``nts_RNN`` — NOTE(review): appears to be one entry per layer; confirm in ``model.py``.
    weight_decay : float
        Adam weight decay (L2 penalty); ``0`` disables it.
    save_model : bool
        If truthy, write ``model.state_dict()`` to ``save_dir/model_save_name`` after training.
    model_save_name : str
        File name of the saved state dict.
    num_epoch : int, optional
        Number of calls to ``train_model`` (default 1000, the previously hard-coded value).
    log_every : int, optional
        Print the loss every ``log_every`` epochs and on the last epoch (default 100).
    save_dir : str, optional
        Directory the state dict is written to; created if missing (default ``"trained_models"``).
    seed : int or None, optional
        If given, seeds torch's global RNG before building the model and loader, so that weight
        initialisation and shuffling are reproducible. ``None`` (default) keeps the old unseeded behaviour.

    Returns
    -------
    model : nts_RNN
        The trained model.
    loss_log : np.ndarray, shape (num_epoch,)
        The value returned by ``train_model`` for each epoch.

    Raises
    ------
    ValueError
        If ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    if seed is not None:
        # Model init and DataLoader shuffling both draw from torch's global generator.
        torch.manual_seed(seed)

    train_dataset = TimeseriesDataset(dataset_name, seq_len)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    model = nts_RNN(hidden_size)
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    loss_log = np.zeros(num_epoch)
    for epoch_id in tqdm(range(num_epoch)):
        loss_log[epoch_id] = train_model(train_loader, model, loss_function, optimizer)
        if epoch_id % log_every == 0 or epoch_id == num_epoch - 1:
            # tqdm.write prints above the progress bar instead of breaking it up.
            tqdm.write(f"Epoch {epoch_id}, loss: {loss_log[epoch_id]}\n-----------")

    if save_model:
        # torch.save does not create parent directories; make sure the target exists.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(save_dir, model_save_name))
    return model, loss_log

In [None]:
# Run: hidden_size=[16], no weight decay; the trained state dict is saved as L1_16_S10_B50_LR1e3_NR.pt.
model, loss_log = run_training(
    dataset_name='TrainSet.mat',
    seq_len=10,
    batch_size=50,
    learning_rate=1e-3,
    hidden_size=[16],
    weight_decay=0,
    save_model=True,
    model_save_name="L1_16_S10_B50_LR1e3_NR.pt",
)

In [None]:
# Run: hidden_size=[8], no weight decay; the trained state dict is saved as L1_8_S10_B50_LR1e3_NR.pt.
model, loss_log = run_training(
    dataset_name='TrainSet.mat',
    seq_len=10,
    batch_size=50,
    learning_rate=1e-3,
    hidden_size=[8],
    weight_decay=0,
    save_model=True,
    model_save_name="L1_8_S10_B50_LR1e3_NR.pt",
)

# Derive the x-axis from the log itself rather than a hard-coded 1000, so the plot stays
# correct if the number of training epochs changes.
fig, ax = plt.subplots()
ax.plot(np.arange(len(loss_log)), loss_log)
ax.set(title="Training loss: hidden_size=[8], weight_decay=0", xlabel="Epoch", ylabel="Training loss")
plt.show()

In [None]:
# Run: hidden_size=[8, 7, 6] with weight decay 1e-4; the trained state dict is saved as L3_S10_B50_LR1e3_R1e4.pt.
model, loss_log = run_training(
    dataset_name='TrainSet.mat',
    seq_len=10,
    batch_size=50,
    learning_rate=1e-3,
    hidden_size=[8, 7, 6],
    weight_decay=1e-4,
    save_model=True,
    model_save_name="L3_S10_B50_LR1e3_R1e4.pt",
)

# Derive the x-axis from the log itself rather than a hard-coded 1000, so the plot stays
# correct if the number of training epochs changes.
fig, ax = plt.subplots()
ax.plot(np.arange(len(loss_log)), loss_log)
ax.set(title="Training loss: hidden_size=[8, 7, 6], weight_decay=1e-4", xlabel="Epoch", ylabel="Training loss")
plt.show()