In [1]:
import numpy as np
import torch
import torch.utils.data as data
import pandas as pd
from src.error_measures import get_error_measures
from src.models import TwoLayerLeaky
from src.train_model import training_loop

In [2]:
# Run on the current CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Seed both RNGs up front. The torch RNG feeds random_split and DataLoader shuffling
# (and presumably weight init inside TwoLayerLeaky). The numpy RNG feeds the
# hyperparameter draws in the search loop.
torch.manual_seed(333)
np.random.seed(333)

<torch._C.Generator at 0x23b3c4d7230>

In [4]:
# Load the processed dataset and drop the weekday column before the features are built.
data06 = (
    pd.read_csv("../data/processed/target06.csv")
    .drop(columns=["weekday"])
)

In [5]:
# Features are every column except the last, which is the regression target.
# Features are min-max scaled to [0, 1].
# NOTE(review): the min/range come from ALL rows, including rows that end up in
# test_dataset, so test statistics leak into the scaling. Consider fitting the
# scaler on the training indices only.
features_raw = data06.values[:, :-1].astype(np.float64)
target = data06.values[:, -1].astype(np.float64)
feature_min = features_raw.min(axis=0)
# Use the np.ptp function: the ndarray.ptp method was removed in NumPy 2.0.
feature_range = np.ptp(features_raw, axis=0)
# A constant column has zero range. Divide by 1 there so it maps to 0 instead of NaN.
feature_range[feature_range == 0] = 1.0
all_data = data.TensorDataset(
    torch.from_numpy((features_raw - feature_min) / feature_range).float(),
    torch.from_numpy(target).float(),
)
# 40 % train / 60 % test. The test size is the remainder, so the two lengths
# always sum to len(all_data), as random_split requires.
n_train = round(0.4 * len(all_data))
train_dataset, test_dataset = torch.utils.data.random_split(all_data, (n_train, len(all_data) - n_train))

In [6]:
# Evaluation loader: fixed order, and the final partial batch is kept so every
# test sample is scored on each pass. The training loader is built inside the
# search loop, because its batch size is one of the searched hyperparameters.
test_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=256,
    drop_last=False,
    shuffle=False,
)

In [7]:
# Random hyperparameter search over hidden size, epochs, num_steps and batch size.
# NOTE(review): each candidate is scored on test_loader, which is also the held-out
# evaluation split, so the best trial's score is optimistically biased. Consider
# carving a separate validation split out of train_dataset.
N_TRIALS = 30
# Derive the input width from the data instead of hard-coding it, so the model
# stays in sync if feature columns are added to or dropped from the CSV.
NUM_INPUTS = train_dataset[0][0].shape[-1]
# One record per finished trial. The list lives in the kernel, so completed trials
# survive an interrupted search.
search_results = []
for i in range(N_TRIALS):
    # The draw order (hidden, epochs, steps, batch) is fixed, so a given seed
    # reproduces the same trial sequence.
    n_hidden = np.random.randint(1000, 2500)     # [1000, 2500)
    n_epochs = np.random.randint(20, 50)         # [20, 50)
    n_steps = np.random.randint(20, 30)          # [20, 30), passed as num_steps
    batch_size = int(np.random.choice([32, 64]))
    print(f'\n------------- Random set: {i} -----------------\nHidden: {n_hidden}, epochs: {n_epochs}, steps: {n_steps}, batch size: {batch_size}')
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    model = TwoLayerLeaky(num_inputs=NUM_INPUTS, num_hidden=n_hidden, num_outputs=1, beta=0.95, num_steps=n_steps)
    training_loop(model, train_loader, test_loader, device, num_epochs=n_epochs, validation=False)
    measures = get_error_measures(model, test_loader, device=device, print_e=True)
    search_results.append({
        "trial": i,
        "hidden": n_hidden,
        "epochs": n_epochs,
        "steps": n_steps,
        "batch_size": batch_size,
        "measures": measures,
    })


------------- Random set: 0 -----------------
Hidden: 1973, epochs: 23, steps: 23, batch size: 32


100%|██████████| 23/23 [00:50<00:00,  2.22s/it]


Accuracy on validation dataset: 38.45116547871316%
MAE: 6.929131181087137, MSE:98.79313193820037,
RMSE: 9.939473423587406, Index of Agreement: 0.8500468653125463

------------- Random set: 1 -----------------
Hidden: 2398, epochs: 34, steps: 23, batch size: 64


100%|██████████| 34/34 [00:35<00:00,  1.06s/it]


Accuracy on validation dataset: 38.932768252745134%
MAE: 6.866106725891898, MSE:98.50688131193391,
RMSE: 9.925063290071954, Index of Agreement: 0.8507236827152364

------------- Random set: 2 -----------------
Hidden: 1071, epochs: 43, steps: 26, batch size: 32


100%|██████████| 43/43 [01:42<00:00,  2.39s/it]


Accuracy on validation dataset: 38.45116547871316%
MAE: 6.904280479508681, MSE:101.25382162002806,
RMSE: 10.062495794783125, Index of Agreement: 0.8432570857401828

------------- Random set: 3 -----------------
Hidden: 1116, epochs: 32, steps: 24, batch size: 64


100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Accuracy on validation dataset: 37.14120593334618%
MAE: 7.177809658479516, MSE:107.18825281208647,
RMSE: 10.353175977065515, Index of Agreement: 0.8313244801154727

------------- Random set: 4 -----------------
Hidden: 1151, epochs: 33, steps: 25, batch size: 64


100%|██████████| 33/33 [00:37<00:00,  1.15s/it]


Accuracy on validation dataset: 37.603544596416874%
MAE: 6.984923908462921, MSE:102.36238043588122,
RMSE: 10.11742953698622, Index of Agreement: 0.8423753684443211

------------- Random set: 5 -----------------
Hidden: 1584, epochs: 28, steps: 27, batch size: 32


100%|██████████| 28/28 [01:13<00:00,  2.63s/it]


Accuracy on validation dataset: 38.48969370063571%
MAE: 6.900265835172476, MSE:97.68478570428842,
RMSE: 9.883561387692618, Index of Agreement: 0.8508499852196194

------------- Random set: 6 -----------------
Hidden: 1595, epochs: 23, steps: 23, batch size: 64


100%|██████████| 23/23 [00:25<00:00,  1.09s/it]


Accuracy on validation dataset: 36.1009439414371%
MAE: 7.073130424806687, MSE:104.04626597146874,
RMSE: 10.200307150839564, Index of Agreement: 0.8410390856954644

------------- Random set: 7 -----------------
Hidden: 1441, epochs: 39, steps: 23, batch size: 64


100%|██████████| 39/39 [00:42<00:00,  1.10s/it]


Accuracy on validation dataset: 39.067617029474086%
MAE: 6.937934870665713, MSE:99.23975165614411,
RMSE: 9.96191505967322, Index of Agreement: 0.8493174236197023

------------- Random set: 8 -----------------
Hidden: 1279, epochs: 37, steps: 20, batch size: 64


100%|██████████| 37/37 [00:35<00:00,  1.04it/s]


Accuracy on validation dataset: 36.33211327297246%
MAE: 6.993848967296129, MSE:103.68205285430862,
RMSE: 10.182438453254143, Index of Agreement: 0.8400348051965585

------------- Random set: 9 -----------------
Hidden: 1834, epochs: 27, steps: 20, batch size: 64


100%|██████████| 27/27 [00:25<00:00,  1.04it/s]


Accuracy on validation dataset: 39.02908880755153%
MAE: 6.978874967980536, MSE:101.25223338320724,
RMSE: 10.062416875840876, Index of Agreement: 0.8468243162079679

------------- Random set: 10 -----------------
Hidden: 1269, epochs: 30, steps: 20, batch size: 32


100%|██████████| 30/30 [00:55<00:00,  1.86s/it]


Accuracy on validation dataset: 35.96609516470815%
MAE: 7.046362911247146, MSE:104.21872219904488,
RMSE: 10.208757132924893, Index of Agreement: 0.840001686740031

------------- Random set: 11 -----------------
Hidden: 1645, epochs: 39, steps: 23, batch size: 64


100%|██████████| 39/39 [00:42<00:00,  1.10s/it]


Accuracy on validation dataset: 37.29531882103641%
MAE: 6.947408963939951, MSE:99.54358062709152,
RMSE: 9.977152931928602, Index of Agreement: 0.8493134931586613

------------- Random set: 12 -----------------
Hidden: 2077, epochs: 47, steps: 20, batch size: 64


100%|██████████| 47/47 [00:46<00:00,  1.02it/s]


Accuracy on validation dataset: 38.027355037565016%
MAE: 6.90287998452441, MSE:100.60447382141672,
RMSE: 10.030178155018818, Index of Agreement: 0.8460974256547953

------------- Random set: 13 -----------------
Hidden: 2090, epochs: 48, steps: 28, batch size: 32


100%|██████████| 48/48 [02:04<00:00,  2.60s/it]


Accuracy on validation dataset: 38.85571180890002%
MAE: 6.851918699148417, MSE:98.63818581658893,
RMSE: 9.931675881571495, Index of Agreement: 0.847948854439009

------------- Random set: 14 -----------------
Hidden: 1596, epochs: 45, steps: 26, batch size: 32


100%|██████████| 45/45 [01:48<00:00,  2.41s/it]


Accuracy on validation dataset: 39.318050471970714%
MAE: 6.857518774217045, MSE:98.22411546602375,
RMSE: 9.910808012771902, Index of Agreement: 0.8517614033912232

------------- Random set: 15 -----------------
Hidden: 2078, epochs: 40, steps: 21, batch size: 64


100%|██████████| 40/40 [00:39<00:00,  1.00it/s]


Accuracy on validation dataset: 38.412637256790596%
MAE: 6.852371410146841, MSE:98.13500175818977,
RMSE: 9.90631120842616, Index of Agreement: 0.8508897081404853

------------- Random set: 16 -----------------
Hidden: 1460, epochs: 21, steps: 20, batch size: 32


100%|██████████| 21/21 [00:37<00:00,  1.79s/it]


Accuracy on validation dataset: 37.43016759776536%
MAE: 7.313827784207027, MSE:105.92909411172928,
RMSE: 10.292186070594006, Index of Agreement: 0.8395182098603532

------------- Random set: 17 -----------------
Hidden: 2133, epochs: 29, steps: 23, batch size: 64


100%|██████████| 29/29 [00:30<00:00,  1.06s/it]


Accuracy on validation dataset: 38.62454247736467%
MAE: 6.881088433663433, MSE:98.16843765548464,
RMSE: 9.90799867054314, Index of Agreement: 0.850624270898964

------------- Random set: 18 -----------------
Hidden: 1691, epochs: 48, steps: 22, batch size: 32


100%|██████████| 48/48 [01:38<00:00,  2.04s/it]


Accuracy on validation dataset: 38.23926025813908%
MAE: 6.874879591360737, MSE:97.45593289563654,
RMSE: 9.871977152305233, Index of Agreement: 0.8501710719701114

------------- Random set: 19 -----------------
Hidden: 2324, epochs: 32, steps: 22, batch size: 64


100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Accuracy on validation dataset: 38.528221922558274%
MAE: 6.871178959426887, MSE:98.76849967709256,
RMSE: 9.938234233358186, Index of Agreement: 0.8499566215772401

------------- Random set: 20 -----------------
Hidden: 1272, epochs: 26, steps: 25, batch size: 32


100%|██████████| 26/26 [00:58<00:00,  2.24s/it]


Accuracy on validation dataset: 38.412637256790596%
MAE: 7.016079769408259, MSE:100.4245742873016,
RMSE: 10.021206229157325, Index of Agreement: 0.8473470316242413

------------- Random set: 21 -----------------
Hidden: 1367, epochs: 36, steps: 27, batch size: 32


100%|██████████| 36/36 [01:28<00:00,  2.47s/it]


Accuracy on validation dataset: 38.586014255442116%
MAE: 6.794721629343564, MSE:98.20495831912395,
RMSE: 9.90984148809273, Index of Agreement: 0.8497950972433074

------------- Random set: 22 -----------------
Hidden: 1432, epochs: 21, steps: 28, batch size: 32


 10%|▉         | 2/21 [00:05<00:48,  2.54s/it]


KeyboardInterrupt: 