In [1]:
import torch
import torch.utils.data as data
import pandas as pd
from src.error_measures import get_accuracy, get_error_measures, collect_stats, print_average_stats
from src.models import SimpleANN, SynapticSNN, LeakySNN, DoubleLeakySNN
from src.train_model import training_loop

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

## Data

In [3]:
data06 = pd.read_csv("../data/processed/target06.csv")
data06.drop(columns=["weekday"], inplace=True)

In [4]:
all_data = data.TensorDataset(torch.from_numpy((data06.values[:,:-1] - data06.values[:,:-1].min(0)) / data06.values[:,:-1].ptp(0)).float(),
                              torch.from_numpy(data06.values[:,-1]).float())  # with normalization
train_dataset, test_dataset, valid_dataset = torch.utils.data.random_split(all_data, (round(0.75 * len(all_data)), round(0.15 * len(all_data)), round(0.1 * len(all_data))))

In [5]:
train_loader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=256, shuffle=False, drop_last=False)
valid_loader = data.DataLoader(valid_dataset, batch_size=256, shuffle=False, drop_last=False)

### Simple regressor

In [6]:
model1 = SimpleANN(num_inputs=8, num_hidden=3500, num_outputs=1)
print(model1)

TwoLayerPerceptron(
  (lin1): Linear(in_features=8, out_features=3500, bias=True)
  (relu): LeakyReLU(negative_slope=0.01)
  (lin2): Linear(in_features=3500, out_features=1, bias=True)
)


In [7]:
training_loop(model1, train_loader, valid_loader, device, num_epochs=20, lr=0.003, validation=True)

  5%|▌         | 1/20 [00:02<00:45,  2.39s/it]

MAE: 6.977179198733644, MSE:100.34206957404284,
RMSE: 10.017088877215917, IA: 0.8486668125695818, MAPE: 44.371183599082215%


 10%|█         | 2/20 [00:02<00:21,  1.20s/it]

MAE: 6.859676299894476, MSE:97.8839128699343,
RMSE: 9.893629913734104, IA: 0.8635434747844735, MAPE: 44.288341437970004%


 15%|█▌        | 3/20 [00:03<00:14,  1.21it/s]

MAE: 6.969098228939695, MSE:97.33048478839102,
RMSE: 9.86562135845437, IA: 0.8621816093750774, MAPE: 47.43356733014305%


 20%|██        | 4/20 [00:03<00:10,  1.54it/s]

MAE: 6.659572271666775, MSE:93.28218445162643,
RMSE: 9.65827026188574, IA: 0.8630887539125274, MAPE: 41.27137472499447%


 25%|██▌       | 5/20 [00:03<00:08,  1.82it/s]

MAE: 6.639479773306433, MSE:92.46664903629382,
RMSE: 9.615958040481136, IA: 0.861579510532865, MAPE: 41.16119479137254%


 30%|███       | 6/20 [00:04<00:06,  2.07it/s]

MAE: 7.435063612392183, MSE:102.40852001580932,
RMSE: 10.119709482777127, IA: 0.8551964159121052, MAPE: 53.959092683081266%


 35%|███▌      | 7/20 [00:04<00:05,  2.26it/s]

MAE: 6.705942194172413, MSE:94.4678576468421,
RMSE: 9.719457682753813, IA: 0.8724682135287085, MAPE: 42.567716138133704%


 40%|████      | 8/20 [00:04<00:05,  2.40it/s]

MAE: 6.892034649159867, MSE:95.52961900386084,
RMSE: 9.773925465434083, IA: 0.8677285824378084, MAPE: 46.80935050946897%


 45%|████▌     | 9/20 [00:05<00:04,  2.51it/s]

MAE: 6.6076300791922336, MSE:90.64253816187352,
RMSE: 9.520637487157755, IA: 0.8678590980466889, MAPE: 42.12130664455094%


 50%|█████     | 10/20 [00:05<00:03,  2.59it/s]

MAE: 6.619526055782517, MSE:94.29543233673624,
RMSE: 9.710583521948424, IA: 0.8713893469527922, MAPE: 38.4304342658162%


 55%|█████▌    | 11/20 [00:06<00:03,  2.64it/s]

MAE: 6.639676304084028, MSE:90.99430945085841,
RMSE: 9.539093743687522, IA: 0.8694541682390546, MAPE: 41.77840944955717%


 60%|██████    | 12/20 [00:06<00:02,  2.68it/s]

MAE: 6.578254351863971, MSE:91.89733847197041,
RMSE: 9.58630995075636, IA: 0.8657298929002797, MAPE: 39.135861200880356%


 65%|██████▌   | 13/20 [00:06<00:02,  2.71it/s]

MAE: 6.521699430349934, MSE:91.26648364863358,
RMSE: 9.553349341913211, IA: 0.8700819336024915, MAPE: 37.88006219805131%


 70%|███████   | 14/20 [00:07<00:02,  2.70it/s]

MAE: 6.662855497536632, MSE:90.8411297328617,
RMSE: 9.531061311987333, IA: 0.8744882700136389, MAPE: 43.76052564916662%


 75%|███████▌  | 15/20 [00:07<00:01,  2.66it/s]

MAE: 6.670450896334786, MSE:90.18152167454797,
RMSE: 9.496395193679966, IA: 0.8671459428730897, MAPE: 43.059054670465%


 80%|████████  | 16/20 [00:07<00:01,  2.60it/s]

MAE: 6.692971044330927, MSE:92.81798205019625,
RMSE: 9.634208947816953, IA: 0.872296162262297, MAPE: 43.77075432551715%


 85%|████████▌ | 17/20 [00:08<00:01,  2.60it/s]

MAE: 6.549028937320489, MSE:90.16558553543666,
RMSE: 9.495556094059824, IA: 0.8763357912263023, MAPE: 41.5671629640403%


 90%|█████████ | 18/20 [00:08<00:00,  2.58it/s]

MAE: 6.551630107722531, MSE:88.51039284960504,
RMSE: 9.407996218621957, IA: 0.8712973891329541, MAPE: 41.5764665007009%


 95%|█████████▌| 19/20 [00:09<00:00,  2.57it/s]

MAE: 6.593156081817054, MSE:88.8215455622859,
RMSE: 9.424518319908232, IA: 0.8761724603181775, MAPE: 43.051662889198596%


100%|██████████| 20/20 [00:09<00:00,  2.11it/s]

MAE: 6.9960809159141055, MSE:97.06604644953235,
RMSE: 9.852210231695848, IA: 0.8679841691513017, MAPE: 48.403454706458916%
Accuracy on validation dataset: 47.980000000000004%





In [8]:
print(f'Accuracy on test set: {get_accuracy(model1, test_loader, device=device, pct_close=0.25)*100}%')
get_error_measures(model1, test_loader, device=device, print_e=True)

Accuracy on test set: 47.53%
MAE: 6.584075526390311, MSE:81.99053064639483,
RMSE: 9.054862265456876, IA: 0.8827876434020421, MAPE: 46.74369985409349%


(6.584075526390311,
 81.99053064639483,
 9.054862265456876,
 0.8827876434020421,
 46.74369985409349)

### SNN with Leaky layer

In [9]:
model2 = LeakySNN(num_inputs=8, num_hidden=3500, num_outputs=1, num_steps=25)
print(model2)

LeakySNN(
  (fc1): Linear(in_features=8, out_features=3500, bias=True)
  (lif): Leaky()
  (fc2): Linear(in_features=3500, out_features=1, bias=True)
)


In [10]:
training_loop(model2, train_loader, valid_loader, device, num_epochs=25, lr=0.001, validation=True)

  4%|▍         | 1/25 [00:04<01:59,  4.97s/it]

MAE: 6.771930640556909, MSE:97.31362862957145,
RMSE: 9.86476703372013, IA: 0.8662075685537425, MAPE: 42.58734285394422%


  8%|▊         | 2/25 [00:09<01:52,  4.91s/it]

MAE: 6.901768796016715, MSE:98.0451238094351,
RMSE: 9.901773770867273, IA: 0.8591297149681488, MAPE: 39.89053553427123%


 12%|█▏        | 3/25 [00:14<01:48,  4.92s/it]

MAE: 6.961861290683636, MSE:95.61082538412307,
RMSE: 9.778078818670009, IA: 0.8637187518777154, MAPE: 45.33794373194227%


 16%|█▌        | 4/25 [00:19<01:43,  4.92s/it]

MAE: 6.720520240171797, MSE:93.47617190583425,
RMSE: 9.668307602979658, IA: 0.8782898789256761, MAPE: 39.98168358120336%


 20%|██        | 5/25 [00:24<01:38,  4.92s/it]

MAE: 6.520843923178022, MSE:89.6802288570574,
RMSE: 9.46996456472026, IA: 0.8816546016881952, MAPE: 37.79932812395198%


 24%|██▍       | 6/25 [00:29<01:34,  4.97s/it]

MAE: 6.454150306007077, MSE:85.95201904565258,
RMSE: 9.271031174883007, IA: 0.8790284043720602, MAPE: 41.635009143622796%


 28%|██▊       | 7/25 [00:34<01:28,  4.94s/it]

MAE: 6.384115606649763, MSE:85.64019349884134,
RMSE: 9.254198695664652, IA: 0.8831132778098126, MAPE: 39.063750661854584%


 32%|███▏      | 8/25 [00:39<01:23,  4.92s/it]

MAE: 6.418624311788923, MSE:84.76355539804662,
RMSE: 9.20671251848599, IA: 0.8837554879138177, MAPE: 39.31964606900376%


 36%|███▌      | 9/25 [00:44<01:19,  4.99s/it]

MAE: 6.452265844593159, MSE:83.87951184244623,
RMSE: 9.158575863224929, IA: 0.8822788428796944, MAPE: 42.283987081299266%


 40%|████      | 10/25 [00:49<01:15,  5.02s/it]

MAE: 6.690867027381941, MSE:85.64757252610487,
RMSE: 9.254597372447105, IA: 0.8763375519780078, MAPE: 45.63211340104512%


 44%|████▍     | 11/25 [00:54<01:11,  5.11s/it]

MAE: 6.646658954386077, MSE:83.81300822066162,
RMSE: 9.15494446846411, IA: 0.8845853366094114, MAPE: 45.46390846460197%


 48%|████▊     | 12/25 [01:00<01:06,  5.09s/it]

MAE: 6.476462414326695, MSE:84.4609130058195,
RMSE: 9.190261857304149, IA: 0.886270356973127, MAPE: 36.955597818819335%


 52%|█████▏    | 13/25 [01:05<01:00,  5.08s/it]

MAE: 6.607884364734495, MSE:88.59070902237626,
RMSE: 9.412263756524052, IA: 0.8890412989091947, MAPE: 43.50608609686243%


 56%|█████▌    | 14/25 [01:10<00:56,  5.15s/it]

MAE: 6.210601162497019, MSE:76.93435882655574,
RMSE: 8.771223336944267, IA: 0.8930755680539831, MAPE: 39.15253553098941%


 60%|██████    | 15/25 [01:15<00:52,  5.21s/it]

MAE: 6.915999995766348, MSE:90.46376376045102,
RMSE: 9.511244070070488, IA: 0.8838685589357955, MAPE: 49.1038675391145%


 64%|██████▍   | 16/25 [01:20<00:46,  5.14s/it]

MAE: 6.318948000781454, MSE:77.98540622253626,
RMSE: 8.83093461772514, IA: 0.8873202769518008, MAPE: 41.56156574377161%


 68%|██████▊   | 17/25 [01:25<00:40,  5.09s/it]

MAE: 6.336150337057996, MSE:79.77658074931736,
RMSE: 8.931773662006744, IA: 0.8984971337183131, MAPE: 41.706208352268135%


 72%|███████▏  | 18/25 [01:30<00:35,  5.08s/it]

MAE: 6.1198612785752795, MSE:76.9745595321123,
RMSE: 8.773514662443569, IA: 0.9013242241974695, MAPE: 36.60769196213682%


 76%|███████▌  | 19/25 [01:35<00:30,  5.04s/it]

MAE: 6.16842774931406, MSE:76.40167817360714,
RMSE: 8.74080535040148, IA: 0.8976166860787187, MAPE: 40.62938177767238%


 80%|████████  | 20/25 [01:40<00:25,  5.02s/it]

MAE: 6.907121393911411, MSE:89.52649206953248,
RMSE: 9.461844009997865, IA: 0.893173761723787, MAPE: 48.88903863490625%


 84%|████████▍ | 21/25 [01:45<00:20,  5.02s/it]

MAE: 6.170427740447094, MSE:76.88922579743519,
RMSE: 8.76865016963473, IA: 0.8969650129631844, MAPE: 35.77531257867898%


 88%|████████▊ | 22/25 [01:50<00:15,  5.03s/it]

MAE: 5.954393073175684, MSE:73.21774050281033,
RMSE: 8.556736556819446, IA: 0.9063309292598926, MAPE: 37.324126288608824%


 92%|█████████▏| 23/25 [01:55<00:10,  5.01s/it]

MAE: 6.11515609935529, MSE:73.42830192414766,
RMSE: 8.569031562793294, IA: 0.9051800888525983, MAPE: 40.48702337767838%


 96%|█████████▌| 24/25 [02:00<00:04,  4.98s/it]

MAE: 6.042265920280721, MSE:72.06809979207522,
RMSE: 8.489293244556652, IA: 0.9079222372639418, MAPE: 34.68317112339825%


100%|██████████| 25/25 [02:05<00:00,  5.03s/it]

MAE: 6.309930639976711, MSE:79.58138106118966,
RMSE: 8.920839706058485, IA: 0.894345374289621, MAPE: 35.43424808735098%
Accuracy on validation dataset: 49.830000000000005%





In [11]:
print(f'Accuracy on test set: {get_accuracy(model2, test_loader, device=device, pct_close=0.25)*100}%')
get_error_measures(model2, test_loader, device=device, print_e=True)

Accuracy on test set: 51.160000000000004%
MAE: 5.972411362158683, MSE:73.31110708098069,
RMSE: 8.562190553881681, IA: 0.8962826474424391, MAPE: 34.6520643229448%


(5.972411362158683,
 73.31110708098069,
 8.562190553881681,
 0.8962826474424391,
 34.6520643229448)

### SNN with Synaptic layer

In [12]:
model3 = SynapticSNN(num_inputs=8, num_hidden=3500, num_outputs=1, num_steps=50)
print(model3)

SynapticSNN(
  (fc1): Linear(in_features=8, out_features=3500, bias=True)
  (lif): Synaptic()
  (fc2): Linear(in_features=3500, out_features=1, bias=True)
)


In [13]:
training_loop(model3, train_loader, valid_loader, device, num_epochs=25, lr=0.0005, validation=True)

  4%|▍         | 1/25 [00:10<04:14, 10.62s/it]

MAE: 6.909225411635603, MSE:95.00316251195699,
RMSE: 9.74695657689912, IA: 0.8730246451840118, MAPE: 46.58436028497872%


  8%|▊         | 2/25 [00:20<03:58, 10.37s/it]

MAE: 6.61289022313377, MSE:86.37375911138491,
RMSE: 9.293748388641953, IA: 0.891136062933712, MAPE: 45.24222761881174%


 12%|█▏        | 3/25 [00:31<03:50, 10.49s/it]

MAE: 7.4061040227812835, MSE:98.5948264660103,
RMSE: 9.929492759754162, IA: 0.8788456464442631, MAPE: 54.398678502505035%


 16%|█▌        | 4/25 [00:42<03:42, 10.57s/it]

MAE: 6.85529481592895, MSE:88.3944516697486,
RMSE: 9.401832357032783, IA: 0.8901833513242269, MAPE: 47.86134655235981%


 20%|██        | 5/25 [00:52<03:29, 10.47s/it]

MAE: 10.154624281315446, MSE:150.65065658564495,
RMSE: 12.273982914508434, IA: 0.8292753319943889, MAPE: 79.89974117605733%


 24%|██▍       | 6/25 [01:02<03:17, 10.40s/it]

MAE: 6.88276304874806, MSE:82.02942635213785,
RMSE: 9.0570097908823, IA: 0.8845243902924794, MAPE: 50.27797844430815%


 28%|██▊       | 7/25 [01:12<03:06, 10.36s/it]

MAE: 6.681699407272945, MSE:82.1582405304412,
RMSE: 9.064118298568328, IA: 0.8954078669645948, MAPE: 46.52560800252502%


 32%|███▏      | 8/25 [01:23<02:58, 10.50s/it]

MAE: 6.113687837054964, MSE:69.50652966600765,
RMSE: 8.337057614410954, IA: 0.9093543455757851, MAPE: 39.8033875445438%


 36%|███▌      | 9/25 [01:33<02:46, 10.41s/it]

MAE: 6.962739857266834, MSE:85.84820551012368,
RMSE: 9.265430670515196, IA: 0.893151076941457, MAPE: 38.430910586193406%


 40%|████      | 10/25 [01:44<02:35, 10.36s/it]

MAE: 6.338254382741245, MSE:75.5733924864318,
RMSE: 8.69329583566738, IA: 0.9022235996460952, MAPE: 38.162365973878444%


 44%|████▍     | 11/25 [01:54<02:25, 10.37s/it]

MAE: 5.933907544268349, MSE:66.37123088966956,
RMSE: 8.14685404862942, IA: 0.9133602776387405, MAPE: 39.72608475973113%


 48%|████▊     | 12/25 [02:04<02:14, 10.32s/it]

MAE: 6.560312133407317, MSE:77.15029395014707,
RMSE: 8.783524005212662, IA: 0.9048939970048228, MAPE: 36.062290025781984%


 52%|█████▏    | 13/25 [02:15<02:03, 10.30s/it]

MAE: 6.556751434237971, MSE:79.10952511168229,
RMSE: 8.894353552208406, IA: 0.8999616429059705, MAPE: 46.150159015689354%


 56%|█████▌    | 14/25 [02:25<01:53, 10.28s/it]

MAE: 6.022265860111038, MSE:68.31749070485213,
RMSE: 8.265439534885736, IA: 0.9177026445463119, MAPE: 38.40245121524752%


 60%|██████    | 15/25 [02:35<01:42, 10.29s/it]

MAE: 6.0873410193217286, MSE:67.75917325335767,
RMSE: 8.2315960331735, IA: 0.9147233134076402, MAPE: 41.98466052927882%


 64%|██████▍   | 16/25 [02:45<01:32, 10.27s/it]

MAE: 7.57091329556626, MSE:93.34564982087737,
RMSE: 9.66155524855483, IA: 0.8959970125598946, MAPE: 43.73354151489742%


 68%|██████▊   | 17/25 [02:56<01:22, 10.25s/it]

MAE: 6.22139879182584, MSE:72.09993593190444,
RMSE: 8.491168113510911, IA: 0.917470540313773, MAPE: 37.03478962810397%


 72%|███████▏  | 18/25 [03:06<01:11, 10.27s/it]

MAE: 6.661861299090303, MSE:75.4785357399242,
RMSE: 8.687838381319269, IA: 0.9034099653282196, MAPE: 48.43366781862938%


 76%|███████▌  | 19/25 [03:16<01:01, 10.33s/it]

MAE: 7.919641613822452, MSE:100.25824100590887,
RMSE: 10.012903724989513, IA: 0.887501305453359, MAPE: 59.68249460625999%


 80%|████████  | 20/25 [03:27<00:52, 10.41s/it]

MAE: 5.668913288388638, MSE:60.86539225921909,
RMSE: 7.80162753912407, IA: 0.9265198858891592, MAPE: 34.402331725852854%


 84%|████████▍ | 21/25 [03:38<00:42, 10.55s/it]

MAE: 5.711942245783231, MSE:59.54558069605794,
RMSE: 7.716578302334392, IA: 0.9228790261284239, MAPE: 37.879658116981055%


 88%|████████▊ | 22/25 [03:48<00:31, 10.48s/it]

MAE: 8.674427825729282, MSE:118.93292186322516,
RMSE: 10.905637159892363, IA: 0.8751615402194524, MAPE: 63.64140981206755%


 92%|█████████▏| 23/25 [03:58<00:20, 10.44s/it]

MAE: 5.914231181282529, MSE:65.17750602677856,
RMSE: 8.073258699359172, IA: 0.9218099389258596, MAPE: 34.949823254561196%


 96%|█████████▌| 24/25 [04:09<00:10, 10.40s/it]

MAE: 6.90268213294145, MSE:80.44699312128886,
RMSE: 8.969224778167222, IA: 0.9081367779385366, MAPE: 49.10808721514611%


100%|██████████| 25/25 [04:19<00:00, 10.38s/it]

MAE: 5.509052003015673, MSE:55.67434434941341,
RMSE: 7.461524264479303, IA: 0.9316703376768273, MAPE: 33.72133850987771%
Accuracy on validation dataset: 52.14%





In [14]:
print(f'Accuracy on test set: {get_accuracy(model3, test_loader, device=device, pct_close=0.25)*100}%')
get_error_measures(model3, test_loader, device=device, print_e=True)

Accuracy on test set: 53.0%
MAE: 5.332241890380857, MSE:53.110854451668686,
RMSE: 7.287719427342733, IA: 0.9303150895932372, MAPE: 34.26196541916594%


(5.332241890380857,
 53.110854451668686,
 7.287719427342733,
 0.9303150895932372,
 34.26196541916594)

### SNN with two Leaky layers

In [15]:
model4 = DoubleLeakySNN(num_inputs=8, num_hidden=3500, num_outputs=1)
print(model4)

DoubleLeakySNN(
  (lin1): Linear(in_features=8, out_features=3500, bias=True)
  (lif1): Leaky()
  (lif2): Leaky()
  (lin2): Linear(in_features=3500, out_features=1, bias=True)
)


In [16]:
training_loop(model4, train_loader, valid_loader, device, num_epochs=25, lr=0.002, validation=True)

  4%|▍         | 1/25 [00:08<03:18,  8.27s/it]

MAE: 8.698820817539458, MSE:138.23902871559272,
RMSE: 11.757509460578492, IA: 0.7005968754790048, MAPE: 60.63466696968963%


  8%|▊         | 2/25 [00:16<03:11,  8.31s/it]

MAE: 7.131606938521986, MSE:102.52068320709769,
RMSE: 10.125249784923712, IA: 0.8469535881528957, MAPE: 45.17123081718359%


 12%|█▏        | 3/25 [00:25<03:08,  8.55s/it]

MAE: 7.146393032845734, MSE:105.15033381109744,
RMSE: 10.25428368103289, IA: 0.8443156633866629, MAPE: 43.801508610456104%


 16%|█▌        | 4/25 [00:34<03:00,  8.59s/it]

MAE: 7.353734079812993, MSE:105.74844436907466,
RMSE: 10.283406262959499, IA: 0.8515757087972873, MAPE: 49.72287274450645%


 20%|██        | 5/25 [00:42<02:50,  8.54s/it]

MAE: 7.0508785997512025, MSE:101.63153575625462,
RMSE: 10.081246736205529, IA: 0.8512148840354006, MAPE: 42.71340939680549%


 24%|██▍       | 6/25 [00:51<02:41,  8.52s/it]

MAE: 7.078786136097991, MSE:100.65202621973802,
RMSE: 10.03254834126096, IA: 0.8545221329994781, MAPE: 47.906482561055064%


 28%|██▊       | 7/25 [00:59<02:33,  8.52s/it]

MAE: 6.933572222318263, MSE:99.00504570346477,
RMSE: 9.950127923974886, IA: 0.8549457045337052, MAPE: 44.29434113619051%


 32%|███▏      | 8/25 [01:07<02:24,  8.49s/it]

MAE: 6.996658915591378, MSE:98.64449193154827,
RMSE: 9.931993351364483, IA: 0.8581193575123288, MAPE: 46.01858305623641%


 36%|███▌      | 9/25 [01:16<02:16,  8.54s/it]

MAE: 6.970612685253165, MSE:98.40618078345774,
RMSE: 9.919988950772966, IA: 0.8541286567549796, MAPE: 42.01529141218816%


 40%|████      | 10/25 [01:25<02:07,  8.51s/it]

MAE: 6.972265907794754, MSE:98.53739313380126,
RMSE: 9.926600280750769, IA: 0.856288011134571, MAPE: 41.6013750491906%


 44%|████▍     | 11/25 [01:34<02:01,  8.70s/it]

MAE: 6.974647403452438, MSE:95.69301147121215,
RMSE: 9.78228048418221, IA: 0.8627312810927534, MAPE: 46.11772835606155%


 48%|████▊     | 12/25 [01:42<01:52,  8.63s/it]

MAE: 8.099976914466462, MSE:118.54335410118078,
RMSE: 10.88776166625541, IA: 0.8395284391860356, MAPE: 60.28945855587331%


 52%|█████▏    | 13/25 [01:51<01:43,  8.59s/it]

MAE: 6.937930653963475, MSE:97.70930122663039,
RMSE: 9.884801526921539, IA: 0.8600715962649388, MAPE: 44.28378788995665%


 56%|█████▌    | 14/25 [01:59<01:34,  8.55s/it]

MAE: 7.031722527432304, MSE:97.68878943692248,
RMSE: 9.883763930655288, IA: 0.8621289305317149, MAPE: 46.44314497528758%


 60%|██████    | 15/25 [02:08<01:25,  8.53s/it]

MAE: 7.0732369745397845, MSE:99.78459503573764,
RMSE: 9.98922394561948, IA: 0.8587066555490734, MAPE: 41.72139982396599%


 64%|██████▍   | 16/25 [02:16<01:16,  8.51s/it]

MAE: 6.905109781199108, MSE:95.94232523192562,
RMSE: 9.795015325762671, IA: 0.8680285786967724, MAPE: 43.891297295633294%


 68%|██████▊   | 17/25 [02:25<01:07,  8.50s/it]

MAE: 6.8891098168544, MSE:96.2600818778904,
RMSE: 9.811222241794873, IA: 0.861104906649226, MAPE: 41.91022056434629%


 72%|███████▏  | 18/25 [02:33<01:00,  8.58s/it]

MAE: 6.943144512452142, MSE:98.81590173953416,
RMSE: 9.940618780515333, IA: 0.855242843521357, MAPE: 43.00775718525104%


 76%|███████▌  | 19/25 [02:42<00:52,  8.68s/it]

MAE: 7.539988428457624, MSE:107.07137796481011,
RMSE: 10.34753004174475, IA: 0.8520599603954205, MAPE: 54.40326357433809%


 80%|████████  | 20/25 [02:51<00:43,  8.73s/it]

MAE: 6.92744507596672, MSE:95.45852190519804,
RMSE: 9.770287708414632, IA: 0.8622021819597516, MAPE: 45.246524769196945%


 84%|████████▍ | 21/25 [03:00<00:34,  8.70s/it]

MAE: 7.321988435976767, MSE:104.25280437560491,
RMSE: 10.210426258271733, IA: 0.8554190744456193, MAPE: 51.03696195638911%


 88%|████████▊ | 22/25 [03:08<00:25,  8.65s/it]

MAE: 6.949375733061333, MSE:99.92242437290605,
RMSE: 9.996120466106142, IA: 0.8572765687523192, MAPE: 42.54713775814095%


 92%|█████████▏| 23/25 [03:17<00:17,  8.59s/it]

MAE: 6.878647420034243, MSE:97.08553656817278,
RMSE: 9.853199306223983, IA: 0.8607174308468796, MAPE: 45.34781950406298%


 96%|█████████▌| 24/25 [03:26<00:08,  8.66s/it]

MAE: 6.911283267577948, MSE:97.30215012302486,
RMSE: 9.864185223475118, IA: 0.8578450871803953, MAPE: 45.27262108124026%


100%|██████████| 25/25 [03:34<00:00,  8.57s/it]

MAE: 6.9349132863083325, MSE:99.96194535219932,
RMSE: 9.998097086555987, IA: 0.8525111392066496, MAPE: 42.31609397360791%
Accuracy on validation dataset: 47.28%





In [17]:
print(f'Accuracy on test set: {get_accuracy(model4, test_loader, device=device, pct_close=0.25)*100}%')
get_error_measures(model4, test_loader, device=device, print_e=True)

Accuracy on test set: 48.0%
MAE: 6.554738069168408, MSE:84.30115118115977,
RMSE: 9.181565834930323, IA: 0.8682392415719199, MAPE: 39.55185705235553%


(6.554738069168408,
 84.30115118115977,
 9.181565834930323,
 0.8682392415719199,
 39.55185705235553)

### Repeated training on random datasets

In [9]:
err_model1 = [0., 0., 0., 0., 0., 0.]
err_model2 = [0., 0., 0., 0., 0., 0.]
err_model3 = [0., 0., 0., 0., 0., 0.]
err_model4 = [0., 0., 0., 0., 0., 0.]

In [10]:
trials = 5
for i in range(trials):
    # Prepare data loaders.
    train_dataset, test_dataset = torch.utils.data.random_split(all_data, (round(0.8 * len(all_data)), round(0.2 * len(all_data))))
    train_loader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=256, shuffle=False, drop_last=False)

    # Prepare models.
    model1 = SimpleANN(num_inputs=8, num_hidden=4000, num_outputs=1)
    model2 = LeakySNN(num_inputs=8, num_hidden=4000, num_outputs=1, num_steps=25)
    model3 = SynapticSNN(num_inputs=8, num_hidden=4000, num_outputs=1, num_steps=50)
    model4 = DoubleLeakySNN(num_inputs=8, num_hidden=4000, num_outputs=1, num_steps=20)

    # Train models.
    print(f'------------------------------------\nTraining {i+1}')
    training_loop(model1, train_loader, test_loader, device, num_epochs=25, lr=0.005, validation=False)
    training_loop(model2, train_loader, test_loader, device, num_epochs=25, lr=0.001, validation=False)
    training_loop(model3, train_loader, test_loader, device, num_epochs=25, lr=0.0005, validation=False)
    training_loop(model4, train_loader, test_loader, device, num_epochs=25, lr=0.001, validation=False)

    # Gather results.
    collect_stats(model1, err_model1, test_loader, device)
    collect_stats(model2, err_model2, test_loader, device)
    collect_stats(model3, err_model3, test_loader, device)
    collect_stats(model4, err_model4, test_loader, device)

------------------------------------
Training 1


100%|██████████| 25/25 [00:11<00:00,  2.21it/s]


Accuracy on validation dataset: 48.79%


100%|██████████| 25/25 [02:15<00:00,  5.43s/it]


Accuracy on validation dataset: 50.690000000000005%


100%|██████████| 25/25 [04:36<00:00, 11.08s/it]


Accuracy on validation dataset: 53.410000000000004%


100%|██████████| 25/25 [03:27<00:00,  8.29s/it]


Accuracy on validation dataset: 51.790000000000006%
------------------------------------
Training 2


100%|██████████| 25/25 [00:10<00:00,  2.45it/s]


Accuracy on validation dataset: 48.5%


100%|██████████| 25/25 [02:17<00:00,  5.49s/it]


Accuracy on validation dataset: 46.42%


100%|██████████| 25/25 [04:39<00:00, 11.18s/it]


Accuracy on validation dataset: 49.08%


100%|██████████| 25/25 [03:26<00:00,  8.27s/it]


Accuracy on validation dataset: 52.83%
------------------------------------
Training 3


100%|██████████| 25/25 [00:10<00:00,  2.28it/s]


Accuracy on validation dataset: 50.74999999999999%


100%|██████████| 25/25 [02:19<00:00,  5.58s/it]


Accuracy on validation dataset: 54.56999999999999%


100%|██████████| 25/25 [04:41<00:00, 11.25s/it]


Accuracy on validation dataset: 50.06%


100%|██████████| 25/25 [03:31<00:00,  8.46s/it]


Accuracy on validation dataset: 52.949999999999996%
------------------------------------
Training 4


100%|██████████| 25/25 [00:10<00:00,  2.32it/s]


Accuracy on validation dataset: 47.980000000000004%


100%|██████████| 25/25 [02:19<00:00,  5.60s/it]


Accuracy on validation dataset: 51.33%


100%|██████████| 25/25 [04:47<00:00, 11.49s/it]


Accuracy on validation dataset: 53.7%


100%|██████████| 25/25 [03:28<00:00,  8.35s/it]


Accuracy on validation dataset: 54.39000000000001%
------------------------------------
Training 5


100%|██████████| 25/25 [00:11<00:00,  2.19it/s]


Accuracy on validation dataset: 51.21%


100%|██████████| 25/25 [02:23<00:00,  5.76s/it]


Accuracy on validation dataset: 54.1%


100%|██████████| 25/25 [04:40<00:00, 11.20s/it]


Accuracy on validation dataset: 43.53%


100%|██████████| 25/25 [03:30<00:00,  8.43s/it]


Accuracy on validation dataset: 54.1%


In [11]:
# Print average results
print('------Simple ANN ------')
print_average_stats(err_model1, trials)
print('------LIF SNN ------')
print_average_stats(err_model2, trials)
print('------Synaptic SNN ------')
print_average_stats(err_model3, trials)
print('------Double LIF SNN ------')
print_average_stats(err_model4, trials)

------Simple ANN ------
MAE: 6.208275137208618, MSE:80.33695688073658,
RMSE: 8.96052686643614, IA: 0.8727338962870699,
MAPE: 37.073615240599054 %, acc: 49.446 %

------LIF SNN ------
MAE: 5.84460810140206, MSE:66.32328005726299,
RMSE: 8.138005750333555, IA: 0.9009785371074507,
MAPE: 38.07243964931761 %, acc: 51.422000000000004 %

------Synaptic SNN ------
MAE: 5.832383813015213, MSE:61.60718308915766,
RMSE: 7.840439236958607, IA: 0.9151266930054927,
MAPE: 38.96341075980955 %, acc: 49.955999999999996 %

------Double LIF SNN ------
MAE: 5.750469359445434, MSE:66.74097353878959,
RMSE: 8.16253254611893, IA: 0.9000361252886762,
MAPE: 37.34842986271641 %, acc: 53.212 %

