In [1]:
"""
 * @file fine_search.py
 * @author Leif Huender
 * @brief 
 * @version 0.1
 * @date 2024-06-13
 * This script runs a fine search narrowing down further the best performant model on a search space of 352 models 
 * @copyright Copyright (c) 2024 Leif Huender
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
"""

import lstm
import torch
import plotly
import torch.nn as nn
from torch.utils.data import DataLoader

In [2]:
#load all the datasets
train = torch.load('../../../data/cleaned/train.pt')
val = torch.load('../../../data/cleaned/val.pt')
test = torch.load('../../../data/cleaned/test.pt')

#make them into dataloaders
batch_size = 32
train_loader = DataLoader(train, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val, batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test, batch_size, shuffle=True, drop_last=True)

'''
Fine Search results
Min Train RMSE: 11.33083176851998
Min Validation RMSE: 70.89486087741568
{'hidden_size': 200, 'num_layers': 2, 'bias': False, 'batch_first': True, 'dropout': 0, 'bidirectional': False, 'proj_size': 0}
'''

model = lstm.LSTM(input_size=19, hidden_size=200, num_layers=2, bias=False, batch_first=True, dropout=0, bidirectional=False, proj_size=0)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
trainer = lstm.Trainer(model, train_loader, val_loader, loss_fn, optimizer)
model, train_losses, val_losses = trainer.train(epochs=100)

Epoch: 1/100
---------
Train Loss RMSE: 289.5264564023874, Validation Loss RMSE: 139.48220110045628

Epoch: 2/100
---------
Train Loss RMSE: 287.92893388338155, Validation Loss RMSE: 136.34371821638985

Epoch: 3/100
---------
Train Loss RMSE: 288.32619343088083, Validation Loss RMSE: 138.51910819436046

Epoch: 4/100
---------
Train Loss RMSE: 287.93211921730676, Validation Loss RMSE: 103.31949934698919

Epoch: 5/100
---------
Train Loss RMSE: 283.1783927472261, Validation Loss RMSE: 142.430250799377

Epoch: 6/100
---------
Train Loss RMSE: 281.60943850744576, Validation Loss RMSE: 133.98978373740326

Epoch: 7/100
---------
Train Loss RMSE: 276.946527045703, Validation Loss RMSE: 152.60717129436856

Epoch: 8/100
---------
Train Loss RMSE: 270.7905172173636, Validation Loss RMSE: 131.62426038862122

Epoch: 9/100
---------
Train Loss RMSE: 280.60648374257516, Validation Loss RMSE: 136.94628878046382

Epoch: 10/100
---------
Train Loss RMSE: 270.4463264691871, Validation Loss RMSE: 139.289

In [3]:
print(train_losses)

[289.5264564023874, 287.92893388338155, 288.32619343088083, 287.93211921730676, 283.1783927472261, 281.60943850744576, 276.946527045703, 270.7905172173636, 280.60648374257516, 270.4463264691871, 275.5212306675792, 263.8531826518442, 258.15918304046, 252.63983388223807, 272.96763670004344, 261.7487733879408, 270.1776803859669, 247.36560500987449, 239.90934002501444, 225.55062854795835, 203.61614689166765, 260.8288501931771, 263.77040602622327, 272.1688575268679, 260.4015152586397, 240.37840237495737, 223.8314675067749, 273.79914185759804, 249.73499667890007, 242.5623277627277, 263.29149087695356, 283.7944753425953, 261.72286917151945, 229.04416070432467, 213.05666714971514, 208.62772442279044, 205.96247036194262, 196.7831203095114, 187.26978115577398, 210.8288038846771, 283.77627789252364, 179.6366608704434, 169.36258568100956, 159.01073248261815, 219.52383666033077, 159.38294802679434, 161.0188510535832, 147.98631223400082, 150.39579209407472, 163.76748977917723, 157.3175688613123, 156

In [4]:
print(val_losses)

[139.48220110045628, 136.34371821638985, 138.51910819436046, 103.31949934698919, 142.430250799377, 133.98978373740326, 152.60717129436856, 131.62426038862122, 136.94628878046382, 139.28982210131372, 137.47049754515513, 149.65350340550992, 172.90100793495068, 70.98303535770889, 143.86624842722298, 131.04960776295448, 131.99305290667763, 126.54639943885208, 129.53076866805927, 130.1860588239362, 206.88149595635812, 137.31838458530302, 144.36016182881684, 101.88118271116114, 136.34177917832713, 193.67048771689505, 133.06817095969492, 128.2660956144674, 141.55775292123485, 131.00997677151108, 137.45944536878415, 131.86650779825123, 127.48105854630302, 107.55590672431273, 119.97315398964327, 135.26582805084678, 129.70180455748775, 76.756208237052, 177.11657358826824, 136.03798818166794, 85.69740395102592, 90.71221423020648, 150.87134209396552, 142.39255775120301, 112.91640276625834, 167.27011548517515, 167.54179894740074, 74.61903487858133, 153.37364339189668, 215.8133303452025, 189.2218320