In [1]:
import torch
import torch.utils.data as data
import pandas as pd
import matplotlib.pyplot as plt
from src.error_measures import get_accuracy, get_error_measures, collect_stats, print_average_stats, gather_predictions
from src.models import MultiLayerANN, SynapticSNN, LeakySNN, DoubleLeakySNN
from src.train_model import training_loop

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

## Data

In [3]:
data12 = pd.read_csv("../data/processed/target12.csv")
data12.drop(columns=["weekday"], inplace=True)

In [4]:
all_data = data.TensorDataset(torch.from_numpy((data12.values[:,:-1] - data12.values[:,:-1].min(0)) / data12.values[:,:-1].ptp(0)).float(),
                              torch.from_numpy(data12.values[:,-1]).float())  # with normalization
train_dataset, test_dataset, valid_dataset = torch.utils.data.random_split(all_data, (round(0.8 * len(all_data)), round(0.1 * len(all_data)), round(0.1 * len(all_data))))

In [5]:
train_loader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=256, shuffle=False, drop_last=False)
valid_loader = data.DataLoader(valid_dataset, batch_size=256, shuffle=False, drop_last=False)

### Simple regressor

In [6]:
model1 = MultiLayerANN(num_inputs=8, num_hidden=3500, num_outputs=1)
print(model1)

SimpleANN(
  (lin1): Linear(in_features=8, out_features=3500, bias=True)
  (tanh): Tanh()
  (lin2): Linear(in_features=3500, out_features=3500, bias=True)
  (relu): ReLU()
  (lin3): Linear(in_features=3500, out_features=1, bias=True)
)


In [7]:
training_loop(model1, train_loader, valid_loader, device, num_epochs=100, lr=0.003, validation=True)

  1%|          | 1/100 [00:04<07:42,  4.67s/it]

MAE: 8.5964, MSE:138.7623,
RMSE: 11.7797, IA: 0.7844, MAPE: 59.7695%


  2%|▏         | 2/100 [00:06<04:44,  2.91s/it]

MAE: 8.2042, MSE:144.9369,
RMSE: 12.0390, IA: 0.7464, MAPE: 51.1733%


  3%|▎         | 3/100 [00:08<03:49,  2.36s/it]

MAE: 8.2704, MSE:139.2700,
RMSE: 11.8013, IA: 0.7789, MAPE: 54.1842%


  4%|▍         | 4/100 [00:09<03:21,  2.10s/it]

MAE: 8.0976, MSE:135.0155,
RMSE: 11.6196, IA: 0.8179, MAPE: 51.7031%


  5%|▌         | 5/100 [00:11<03:04,  1.94s/it]

MAE: 7.7986, MSE:133.5057,
RMSE: 11.5545, IA: 0.8085, MAPE: 43.6454%


  6%|▌         | 6/100 [00:13<02:53,  1.84s/it]

MAE: 7.8416, MSE:135.2949,
RMSE: 11.6316, IA: 0.8205, MAPE: 43.2545%


  7%|▋         | 7/100 [00:14<02:47,  1.80s/it]

MAE: 8.2357, MSE:128.6745,
RMSE: 11.3435, IA: 0.8254, MAPE: 57.0285%


  8%|▊         | 8/100 [00:16<02:42,  1.76s/it]

MAE: 7.6609, MSE:121.5701,
RMSE: 11.0259, IA: 0.8386, MAPE: 48.3948%


  9%|▉         | 9/100 [00:18<02:40,  1.76s/it]

MAE: 7.5670, MSE:121.7344,
RMSE: 11.0333, IA: 0.8315, MAPE: 45.6334%


 10%|█         | 10/100 [00:19<02:35,  1.73s/it]

MAE: 7.6371, MSE:118.4872,
RMSE: 10.8852, IA: 0.8189, MAPE: 49.6367%


 11%|█         | 11/100 [00:21<02:33,  1.72s/it]

MAE: 8.2124, MSE:126.8914,
RMSE: 11.2646, IA: 0.8240, MAPE: 57.5989%


 12%|█▏        | 12/100 [00:23<02:31,  1.72s/it]

MAE: 7.8829, MSE:116.5097,
RMSE: 10.7940, IA: 0.8422, MAPE: 55.0920%


 13%|█▎        | 13/100 [00:25<02:29,  1.72s/it]

MAE: 7.4768, MSE:120.8418,
RMSE: 10.9928, IA: 0.8331, MAPE: 43.7773%


 14%|█▍        | 14/100 [00:26<02:27,  1.71s/it]

MAE: 7.4577, MSE:114.1084,
RMSE: 10.6822, IA: 0.8474, MAPE: 46.3214%


 15%|█▌        | 15/100 [00:28<02:24,  1.70s/it]

MAE: 7.5725, MSE:120.6999,
RMSE: 10.9864, IA: 0.8276, MAPE: 45.5618%


 16%|█▌        | 16/100 [00:30<02:23,  1.71s/it]

MAE: 7.5783, MSE:110.8910,
RMSE: 10.5305, IA: 0.8627, MAPE: 50.4713%


 17%|█▋        | 17/100 [00:31<02:20,  1.70s/it]

MAE: 7.4422, MSE:118.8478,
RMSE: 10.9017, IA: 0.8516, MAPE: 40.4686%


 18%|█▊        | 18/100 [00:33<02:20,  1.71s/it]

MAE: 7.8636, MSE:116.2510,
RMSE: 10.7820, IA: 0.8543, MAPE: 54.2783%


 19%|█▉        | 19/100 [00:35<02:18,  1.71s/it]

MAE: 7.0930, MSE:105.6775,
RMSE: 10.2800, IA: 0.8567, MAPE: 43.2535%


 20%|██        | 20/100 [00:36<02:17,  1.72s/it]

MAE: 7.3597, MSE:123.0002,
RMSE: 11.0905, IA: 0.8351, MAPE: 39.0932%


 21%|██        | 21/100 [00:38<02:14,  1.70s/it]

MAE: 7.2819, MSE:109.2956,
RMSE: 10.4545, IA: 0.8499, MAPE: 46.8423%


 22%|██▏       | 22/100 [00:40<02:13,  1.71s/it]

MAE: 7.2995, MSE:118.1203,
RMSE: 10.8683, IA: 0.8341, MAPE: 39.8673%


 23%|██▎       | 23/100 [00:41<02:09,  1.69s/it]

MAE: 7.2372, MSE:116.6355,
RMSE: 10.7998, IA: 0.8658, MAPE: 42.0125%


 24%|██▍       | 24/100 [00:43<02:08,  1.70s/it]

MAE: 7.7319, MSE:127.3432,
RMSE: 11.2846, IA: 0.8422, MAPE: 45.6598%


 25%|██▌       | 25/100 [00:45<02:06,  1.69s/it]

MAE: 7.3196, MSE:103.9078,
RMSE: 10.1935, IA: 0.8689, MAPE: 48.5369%


 26%|██▌       | 26/100 [00:47<02:03,  1.67s/it]

MAE: 7.5670, MSE:116.5841,
RMSE: 10.7974, IA: 0.8449, MAPE: 45.9902%


 27%|██▋       | 27/100 [00:48<02:02,  1.67s/it]

MAE: 7.0606, MSE:107.3892,
RMSE: 10.3629, IA: 0.8766, MAPE: 41.2892%


 28%|██▊       | 28/100 [00:50<02:00,  1.67s/it]

MAE: 7.0273, MSE:106.8561,
RMSE: 10.3371, IA: 0.8550, MAPE: 40.5892%


 29%|██▉       | 29/100 [00:52<02:00,  1.69s/it]

MAE: 7.0666, MSE:99.7993,
RMSE: 9.9900, IA: 0.8771, MAPE: 44.5806%


 30%|███       | 30/100 [00:53<01:58,  1.70s/it]

MAE: 7.1471, MSE:105.3953,
RMSE: 10.2662, IA: 0.8762, MAPE: 42.5020%


 31%|███       | 31/100 [00:55<01:57,  1.70s/it]

MAE: 7.2613, MSE:114.4430,
RMSE: 10.6978, IA: 0.8401, MAPE: 41.0248%


 32%|███▏      | 32/100 [00:57<01:55,  1.71s/it]

MAE: 7.1599, MSE:105.1749,
RMSE: 10.2555, IA: 0.8663, MAPE: 45.7336%


 33%|███▎      | 33/100 [00:58<01:53,  1.70s/it]

MAE: 7.1150, MSE:105.1386,
RMSE: 10.2537, IA: 0.8662, MAPE: 42.5114%


 34%|███▍      | 34/100 [01:00<01:52,  1.70s/it]

MAE: 7.0848, MSE:103.7898,
RMSE: 10.1877, IA: 0.8730, MAPE: 38.5387%


 35%|███▌      | 35/100 [01:02<01:50,  1.71s/it]

MAE: 7.1778, MSE:104.5217,
RMSE: 10.2236, IA: 0.8635, MAPE: 44.2940%


 36%|███▌      | 36/100 [01:04<01:48,  1.70s/it]

MAE: 7.1696, MSE:108.0492,
RMSE: 10.3947, IA: 0.8514, MAPE: 42.8307%


 37%|███▋      | 37/100 [01:05<01:46,  1.70s/it]

MAE: 6.9240, MSE:98.9492,
RMSE: 9.9473, IA: 0.8801, MAPE: 41.0226%


 38%|███▊      | 38/100 [01:07<01:44,  1.68s/it]

MAE: 7.5614, MSE:119.8934,
RMSE: 10.9496, IA: 0.8483, MAPE: 41.4809%


 39%|███▉      | 39/100 [01:09<01:42,  1.67s/it]

MAE: 7.1584, MSE:100.0308,
RMSE: 10.0015, IA: 0.8768, MAPE: 45.9351%


 40%|████      | 40/100 [01:10<01:39,  1.66s/it]

MAE: 7.2660, MSE:107.6284,
RMSE: 10.3744, IA: 0.8790, MAPE: 42.8673%


 41%|████      | 41/100 [01:12<01:38,  1.66s/it]

MAE: 7.0605, MSE:102.7305,
RMSE: 10.1356, IA: 0.8829, MAPE: 41.3544%


 42%|████▏     | 42/100 [01:13<01:36,  1.66s/it]

MAE: 7.0830, MSE:101.3980,
RMSE: 10.0697, IA: 0.8694, MAPE: 43.8712%


 43%|████▎     | 43/100 [01:15<01:34,  1.66s/it]

MAE: 7.1671, MSE:103.8577,
RMSE: 10.1911, IA: 0.8661, MAPE: 42.6691%


 44%|████▍     | 44/100 [01:17<01:32,  1.65s/it]

MAE: 7.0376, MSE:97.0893,
RMSE: 9.8534, IA: 0.8834, MAPE: 45.7660%


 45%|████▌     | 45/100 [01:18<01:30,  1.65s/it]

MAE: 6.9513, MSE:99.2328,
RMSE: 9.9616, IA: 0.8877, MAPE: 41.0588%


 46%|████▌     | 46/100 [01:20<01:29,  1.65s/it]

MAE: 7.3530, MSE:103.9163,
RMSE: 10.1939, IA: 0.8833, MAPE: 46.7231%


 47%|████▋     | 47/100 [01:22<01:27,  1.66s/it]

MAE: 7.3249, MSE:101.1528,
RMSE: 10.0575, IA: 0.8660, MAPE: 49.5905%


 48%|████▊     | 48/100 [01:23<01:25,  1.65s/it]

MAE: 6.9400, MSE:94.4043,
RMSE: 9.7162, IA: 0.8840, MAPE: 44.3758%


 49%|████▉     | 49/100 [01:25<01:23,  1.65s/it]

MAE: 6.9504, MSE:93.9161,
RMSE: 9.6910, IA: 0.8908, MAPE: 44.2551%


 50%|█████     | 50/100 [01:27<01:22,  1.65s/it]

MAE: 7.0281, MSE:102.4096,
RMSE: 10.1198, IA: 0.8650, MAPE: 41.9839%


 51%|█████     | 51/100 [01:28<01:20,  1.65s/it]

MAE: 6.8860, MSE:96.4504,
RMSE: 9.8209, IA: 0.8828, MAPE: 42.4219%


 52%|█████▏    | 52/100 [01:30<01:18,  1.64s/it]

MAE: 7.1186, MSE:102.9624,
RMSE: 10.1470, IA: 0.8649, MAPE: 42.9987%


 53%|█████▎    | 53/100 [01:32<01:17,  1.65s/it]

MAE: 6.8777, MSE:93.9088,
RMSE: 9.6907, IA: 0.8814, MAPE: 43.6218%


 54%|█████▍    | 54/100 [01:33<01:16,  1.66s/it]

MAE: 6.6621, MSE:92.2769,
RMSE: 9.6061, IA: 0.8821, MAPE: 40.6035%


 55%|█████▌    | 55/100 [01:35<01:14,  1.66s/it]

MAE: 7.4530, MSE:107.2202,
RMSE: 10.3547, IA: 0.8520, MAPE: 46.5452%


 56%|█████▌    | 56/100 [01:37<01:12,  1.66s/it]

MAE: 6.9310, MSE:92.0798,
RMSE: 9.5958, IA: 0.8850, MAPE: 44.1790%


 57%|█████▋    | 57/100 [01:38<01:11,  1.66s/it]

MAE: 7.2045, MSE:94.3164,
RMSE: 9.7117, IA: 0.8855, MAPE: 47.9041%


 58%|█████▊    | 58/100 [01:40<01:09,  1.65s/it]

MAE: 7.1721, MSE:98.8131,
RMSE: 9.9405, IA: 0.8922, MAPE: 46.0995%


 59%|█████▉    | 59/100 [01:42<01:08,  1.66s/it]

MAE: 7.1513, MSE:98.8364,
RMSE: 9.9416, IA: 0.8683, MAPE: 45.0317%


 60%|██████    | 60/100 [01:43<01:06,  1.65s/it]

MAE: 7.0392, MSE:96.7203,
RMSE: 9.8347, IA: 0.8794, MAPE: 44.0701%


 61%|██████    | 61/100 [01:45<01:04,  1.65s/it]

MAE: 6.9729, MSE:91.8939,
RMSE: 9.5861, IA: 0.8924, MAPE: 45.3798%


 62%|██████▏   | 62/100 [01:47<01:02,  1.66s/it]

MAE: 7.9756, MSE:108.2885,
RMSE: 10.4062, IA: 0.8708, MAPE: 58.5556%


 63%|██████▎   | 63/100 [01:48<01:01,  1.65s/it]

MAE: 6.9385, MSE:100.3921,
RMSE: 10.0196, IA: 0.8915, MAPE: 40.9368%


 64%|██████▍   | 64/100 [01:50<00:59,  1.66s/it]

MAE: 7.0168, MSE:97.8673,
RMSE: 9.8928, IA: 0.8735, MAPE: 41.6421%


 65%|██████▌   | 65/100 [01:51<00:58,  1.66s/it]

MAE: 7.2948, MSE:108.5150,
RMSE: 10.4171, IA: 0.8679, MAPE: 38.8252%


 66%|██████▌   | 66/100 [01:53<00:56,  1.65s/it]

MAE: 6.9050, MSE:91.2855,
RMSE: 9.5543, IA: 0.8758, MAPE: 44.6372%


 67%|██████▋   | 67/100 [01:55<00:54,  1.66s/it]

MAE: 7.5119, MSE:104.9614,
RMSE: 10.2451, IA: 0.8649, MAPE: 50.7137%


 68%|██████▊   | 68/100 [01:56<00:53,  1.66s/it]

MAE: 7.1840, MSE:98.8619,
RMSE: 9.9429, IA: 0.8946, MAPE: 47.9066%


 69%|██████▉   | 69/100 [01:58<00:51,  1.66s/it]

MAE: 6.9691, MSE:90.2831,
RMSE: 9.5017, IA: 0.8918, MAPE: 44.0645%


 70%|███████   | 70/100 [02:00<00:49,  1.65s/it]

MAE: 6.5700, MSE:88.2447,
RMSE: 9.3939, IA: 0.8943, MAPE: 39.2614%


 71%|███████   | 71/100 [02:01<00:48,  1.66s/it]

MAE: 7.3626, MSE:98.0507,
RMSE: 9.9021, IA: 0.8771, MAPE: 48.4524%


 72%|███████▏  | 72/100 [02:03<00:46,  1.66s/it]

MAE: 7.0047, MSE:102.7392,
RMSE: 10.1360, IA: 0.8772, MAPE: 39.3510%


 73%|███████▎  | 73/100 [02:05<00:44,  1.66s/it]

MAE: 7.1449, MSE:100.1133,
RMSE: 10.0057, IA: 0.8754, MAPE: 42.1613%


 74%|███████▍  | 74/100 [02:06<00:42,  1.65s/it]

MAE: 6.8332, MSE:91.0937,
RMSE: 9.5443, IA: 0.8897, MAPE: 41.5756%


 75%|███████▌  | 75/100 [02:08<00:41,  1.65s/it]

MAE: 6.7718, MSE:89.7189,
RMSE: 9.4720, IA: 0.8919, MAPE: 43.8498%


 76%|███████▌  | 76/100 [02:10<00:39,  1.66s/it]

MAE: 7.0850, MSE:104.3048,
RMSE: 10.2130, IA: 0.8643, MAPE: 41.1102%


 77%|███████▋  | 77/100 [02:11<00:38,  1.65s/it]

MAE: 6.9049, MSE:91.1386,
RMSE: 9.5467, IA: 0.8951, MAPE: 44.6395%


 78%|███████▊  | 78/100 [02:13<00:36,  1.66s/it]

MAE: 6.9967, MSE:97.5603,
RMSE: 9.8773, IA: 0.8744, MAPE: 41.3019%


 79%|███████▉  | 79/100 [02:15<00:34,  1.66s/it]

MAE: 6.8007, MSE:90.4165,
RMSE: 9.5088, IA: 0.8837, MAPE: 43.0249%


 80%|████████  | 80/100 [02:16<00:33,  1.66s/it]

MAE: 7.0711, MSE:92.0681,
RMSE: 9.5952, IA: 0.8809, MAPE: 46.8687%


 81%|████████  | 81/100 [02:18<00:31,  1.65s/it]

MAE: 6.8257, MSE:91.8930,
RMSE: 9.5861, IA: 0.8887, MAPE: 40.1205%


 82%|████████▏ | 82/100 [02:20<00:29,  1.66s/it]

MAE: 6.8483, MSE:90.5119,
RMSE: 9.5138, IA: 0.8863, MAPE: 43.7087%


 83%|████████▎ | 83/100 [02:21<00:28,  1.66s/it]

MAE: 6.8592, MSE:88.1675,
RMSE: 9.3898, IA: 0.8822, MAPE: 43.9294%


 84%|████████▍ | 84/100 [02:23<00:26,  1.65s/it]

MAE: 7.1169, MSE:93.8809,
RMSE: 9.6892, IA: 0.8904, MAPE: 44.7327%


 85%|████████▌ | 85/100 [02:25<00:24,  1.66s/it]

MAE: 6.8559, MSE:91.8701,
RMSE: 9.5849, IA: 0.8801, MAPE: 43.4282%


 86%|████████▌ | 86/100 [02:26<00:23,  1.66s/it]

MAE: 7.2975, MSE:97.6659,
RMSE: 9.8826, IA: 0.8831, MAPE: 45.4496%


 87%|████████▋ | 87/100 [02:28<00:21,  1.65s/it]

MAE: 6.6083, MSE:89.2836,
RMSE: 9.4490, IA: 0.8910, MAPE: 39.6516%


 88%|████████▊ | 88/100 [02:30<00:19,  1.66s/it]

MAE: 7.1625, MSE:114.2260,
RMSE: 10.6877, IA: 0.8290, MAPE: 42.5293%


 89%|████████▉ | 89/100 [02:31<00:18,  1.66s/it]

MAE: 6.6383, MSE:90.0401,
RMSE: 9.4889, IA: 0.8948, MAPE: 39.0238%


 90%|█████████ | 90/100 [02:33<00:16,  1.66s/it]

MAE: 6.7589, MSE:91.9087,
RMSE: 9.5869, IA: 0.8881, MAPE: 38.2955%


 91%|█████████ | 91/100 [02:35<00:14,  1.65s/it]

MAE: 6.7786, MSE:90.8100,
RMSE: 9.5294, IA: 0.8886, MAPE: 40.4522%


 92%|█████████▏| 92/100 [02:36<00:13,  1.65s/it]

MAE: 6.7829, MSE:93.0996,
RMSE: 9.6488, IA: 0.8875, MAPE: 39.3650%


 93%|█████████▎| 93/100 [02:38<00:11,  1.65s/it]

MAE: 6.8576, MSE:93.0006,
RMSE: 9.6437, IA: 0.8741, MAPE: 43.1651%


 94%|█████████▍| 94/100 [02:39<00:09,  1.66s/it]

MAE: 6.9542, MSE:93.5478,
RMSE: 9.6720, IA: 0.8881, MAPE: 43.3385%


 95%|█████████▌| 95/100 [02:41<00:08,  1.65s/it]

MAE: 6.7470, MSE:89.1128,
RMSE: 9.4400, IA: 0.8899, MAPE: 39.7848%


 96%|█████████▌| 96/100 [02:43<00:06,  1.65s/it]

MAE: 6.7438, MSE:91.6137,
RMSE: 9.5715, IA: 0.8861, MAPE: 40.0536%


 97%|█████████▋| 97/100 [02:44<00:04,  1.66s/it]

MAE: 6.6257, MSE:84.7505,
RMSE: 9.2060, IA: 0.8926, MAPE: 41.2884%


 98%|█████████▊| 98/100 [02:46<00:03,  1.65s/it]

MAE: 6.4774, MSE:86.2822,
RMSE: 9.2888, IA: 0.8892, MAPE: 35.9478%


 99%|█████████▉| 99/100 [02:48<00:01,  1.65s/it]

MAE: 7.0966, MSE:91.6264,
RMSE: 9.5722, IA: 0.8956, MAPE: 46.5617%


100%|██████████| 100/100 [02:49<00:00,  1.70s/it]

MAE: 6.6195, MSE:85.3354,
RMSE: 9.2377, IA: 0.8993, MAPE: 40.9363%
Accuracy on validation dataset: 47.74%





In [8]:
print(f'Accuracy on test set: {get_accuracy(model1, test_loader, device=device, pct_close=0.25)*100}%')
get_error_measures(model1, test_loader, device=device, print_e=True)

Accuracy on test set: 45.540000000000006%
MAE: 6.2182, MSE:70.0225,
RMSE: 8.3679, IA: 0.8697, MAPE: 43.0139%


(6.218238733389424,
 70.02248361721247,
 8.367943810591253,
 0.8697036546021201,
 43.01386963200978)

### SNN with Leaky layer

In [9]:
model2 = LeakySNN(num_inputs=8, num_hidden=3500, num_outputs=1, num_steps=25)
print(model2)

LeakySNN(
  (fc1): Linear(in_features=8, out_features=3500, bias=True)
  (lif): Leaky()
  (fc2): Linear(in_features=3500, out_features=1, bias=True)
)


In [10]:
training_loop(model2, train_loader, valid_loader, device, num_epochs=25, lr=0.001, validation=True)

  4%|▍         | 1/25 [00:05<02:17,  5.74s/it]

MAE: 7.9596, MSE:135.3226,
RMSE: 11.6328, IA: 0.7911, MAPE: 48.9932%


  8%|▊         | 2/25 [00:11<02:11,  5.70s/it]

MAE: 8.1084, MSE:142.2802,
RMSE: 11.9281, IA: 0.7991, MAPE: 43.1768%


 12%|█▏        | 3/25 [00:17<02:05,  5.72s/it]

MAE: 7.8372, MSE:119.7983,
RMSE: 10.9452, IA: 0.8362, MAPE: 53.3887%


 16%|█▌        | 4/25 [00:23<02:03,  5.86s/it]

MAE: 7.7385, MSE:118.2577,
RMSE: 10.8746, IA: 0.8365, MAPE: 52.4866%


 20%|██        | 5/25 [00:28<01:55,  5.77s/it]

MAE: 7.4897, MSE:117.7949,
RMSE: 10.8533, IA: 0.8415, MAPE: 46.0646%


 24%|██▍       | 6/25 [00:34<01:49,  5.74s/it]

MAE: 7.9943, MSE:119.1464,
RMSE: 10.9154, IA: 0.8338, MAPE: 56.9661%


 28%|██▊       | 7/25 [00:40<01:45,  5.83s/it]

MAE: 7.4092, MSE:109.7635,
RMSE: 10.4768, IA: 0.8545, MAPE: 49.0831%


 32%|███▏      | 8/25 [00:46<01:41,  5.97s/it]

MAE: 7.2637, MSE:110.8114,
RMSE: 10.5267, IA: 0.8517, MAPE: 42.8303%


 36%|███▌      | 9/25 [00:52<01:34,  5.90s/it]

MAE: 7.4019, MSE:115.0410,
RMSE: 10.7257, IA: 0.8527, MAPE: 42.1132%


 40%|████      | 10/25 [00:57<01:26,  5.76s/it]

MAE: 7.2120, MSE:107.8789,
RMSE: 10.3865, IA: 0.8563, MAPE: 42.2246%


 44%|████▍     | 11/25 [01:03<01:19,  5.71s/it]

MAE: 7.5932, MSE:108.8017,
RMSE: 10.4308, IA: 0.8603, MAPE: 52.0337%


 48%|████▊     | 12/25 [01:09<01:13,  5.69s/it]

MAE: 7.0247, MSE:105.4520,
RMSE: 10.2690, IA: 0.8611, MAPE: 40.8763%


 52%|█████▏    | 13/25 [01:14<01:07,  5.61s/it]

MAE: 7.0599, MSE:102.7502,
RMSE: 10.1366, IA: 0.8634, MAPE: 43.2953%


 56%|█████▌    | 14/25 [01:20<01:01,  5.55s/it]

MAE: 6.8971, MSE:99.1440,
RMSE: 9.9571, IA: 0.8679, MAPE: 42.4431%


 60%|██████    | 15/25 [01:25<00:55,  5.53s/it]

MAE: 7.1370, MSE:107.7829,
RMSE: 10.3819, IA: 0.8565, MAPE: 39.5473%


 64%|██████▍   | 16/25 [01:31<00:49,  5.56s/it]

MAE: 7.1223, MSE:97.2903,
RMSE: 9.8636, IA: 0.8762, MAPE: 47.5625%


 68%|██████▊   | 17/25 [01:36<00:44,  5.52s/it]

MAE: 7.0613, MSE:107.8935,
RMSE: 10.3872, IA: 0.8601, MAPE: 37.4443%


 72%|███████▏  | 18/25 [01:42<00:38,  5.52s/it]

MAE: 7.0431, MSE:97.7386,
RMSE: 9.8863, IA: 0.8696, MAPE: 45.8088%


 76%|███████▌  | 19/25 [01:47<00:33,  5.56s/it]

MAE: 7.1445, MSE:104.4307,
RMSE: 10.2191, IA: 0.8673, MAPE: 40.6073%


 80%|████████  | 20/25 [01:53<00:28,  5.63s/it]

MAE: 7.0035, MSE:96.2414,
RMSE: 9.8103, IA: 0.8858, MAPE: 44.9744%


 84%|████████▍ | 21/25 [01:59<00:22,  5.60s/it]

MAE: 6.9275, MSE:92.4863,
RMSE: 9.6170, IA: 0.8762, MAPE: 45.5676%


 88%|████████▊ | 22/25 [02:04<00:16,  5.54s/it]

MAE: 6.8346, MSE:96.7540,
RMSE: 9.8364, IA: 0.8771, MAPE: 39.8861%


 92%|█████████▏| 23/25 [02:10<00:11,  5.61s/it]

MAE: 6.8767, MSE:90.3762,
RMSE: 9.5066, IA: 0.8793, MAPE: 44.8334%


 96%|█████████▌| 24/25 [02:15<00:05,  5.51s/it]

MAE: 6.7413, MSE:89.1360,
RMSE: 9.4412, IA: 0.8875, MAPE: 43.4872%


100%|██████████| 25/25 [02:21<00:00,  5.64s/it]

MAE: 7.1766, MSE:91.3375,
RMSE: 9.5571, IA: 0.8907, MAPE: 50.4521%
Accuracy on validation dataset: 44.73%





In [11]:
print(f'Accuracy on test set: {get_accuracy(model2, test_loader, device=device, pct_close=0.25)*100}%')
get_error_measures(model2, test_loader, device=device, print_e=True)

Accuracy on test set: 41.6%
MAE: 6.7399, MSE:74.5174,
RMSE: 8.6323, IA: 0.8647, MAPE: 51.8016%


(6.739942054276118,
 74.51742005584002,
 8.632347308573724,
 0.8646636169928696,
 51.8015526732084)

### SNN with Synaptic layer

In [12]:
model3 = SynapticSNN(num_inputs=8, num_hidden=3500, num_outputs=1, num_steps=50)
print(model3)

SynapticSNN(
  (fc1): Linear(in_features=8, out_features=3500, bias=True)
  (lif): Synaptic()
  (fc2): Linear(in_features=3500, out_features=1, bias=True)
)


In [13]:
training_loop(model3, train_loader, valid_loader, device, num_epochs=25, lr=0.0005, validation=True)

  4%|▍         | 1/25 [00:11<04:27, 11.13s/it]

MAE: 9.8072, MSE:176.2153,
RMSE: 13.2746, IA: 0.7421, MAPE: 49.8143%


  8%|▊         | 2/25 [00:22<04:24, 11.50s/it]

MAE: 7.5273, MSE:115.3310,
RMSE: 10.7392, IA: 0.8072, MAPE: 42.3166%


 12%|█▏        | 3/25 [00:34<04:14, 11.58s/it]

MAE: 7.9439, MSE:108.2215,
RMSE: 10.4030, IA: 0.8204, MAPE: 56.5950%


 16%|█▌        | 4/25 [00:47<04:11, 11.96s/it]

MAE: 7.2483, MSE:108.8404,
RMSE: 10.4327, IA: 0.8243, MAPE: 38.2666%


 20%|██        | 5/25 [00:58<03:53, 11.67s/it]

MAE: 8.4624, MSE:139.5936,
RMSE: 11.8150, IA: 0.7873, MAPE: 43.1950%


 24%|██▍       | 6/25 [01:09<03:36, 11.42s/it]

MAE: 6.7910, MSE:92.3714,
RMSE: 9.6110, IA: 0.8409, MAPE: 41.5429%


 28%|██▊       | 7/25 [01:20<03:23, 11.30s/it]

MAE: 7.4596, MSE:95.1194,
RMSE: 9.7529, IA: 0.8417, MAPE: 52.6353%


 32%|███▏      | 8/25 [01:31<03:09, 11.17s/it]

MAE: 6.8743, MSE:95.6666,
RMSE: 9.7809, IA: 0.8441, MAPE: 37.7746%


 36%|███▌      | 9/25 [01:42<02:59, 11.19s/it]

MAE: 6.9872, MSE:86.8795,
RMSE: 9.3209, IA: 0.8502, MAPE: 48.3343%


 40%|████      | 10/25 [01:53<02:48, 11.22s/it]

MAE: 8.9302, MSE:121.4512,
RMSE: 11.0205, IA: 0.8164, MAPE: 69.2475%


 44%|████▍     | 11/25 [02:04<02:35, 11.13s/it]

MAE: 9.9621, MSE:161.7495,
RMSE: 12.7181, IA: 0.7818, MAPE: 53.0240%


 48%|████▊     | 12/25 [02:15<02:23, 11.05s/it]

MAE: 6.4402, MSE:77.9675,
RMSE: 8.8299, IA: 0.8776, MAPE: 41.7142%


 52%|█████▏    | 13/25 [02:26<02:13, 11.12s/it]

MAE: 6.4044, MSE:81.5060,
RMSE: 9.0281, IA: 0.8769, MAPE: 39.5731%


 56%|█████▌    | 14/25 [02:38<02:03, 11.26s/it]

MAE: 6.5148, MSE:82.7570,
RMSE: 9.0971, IA: 0.8740, MAPE: 39.1355%


 60%|██████    | 15/25 [02:49<01:51, 11.18s/it]

MAE: 6.9838, MSE:82.9067,
RMSE: 9.1053, IA: 0.8638, MAPE: 49.3741%


 64%|██████▍   | 16/25 [03:00<01:41, 11.24s/it]

MAE: 6.4959, MSE:81.9430,
RMSE: 9.0522, IA: 0.8791, MAPE: 36.3236%


 68%|██████▊   | 17/25 [03:12<01:30, 11.26s/it]

MAE: 7.1766, MSE:94.9227,
RMSE: 9.7428, IA: 0.8623, MAPE: 38.9896%


 72%|███████▏  | 18/25 [03:23<01:18, 11.23s/it]

MAE: 7.0126, MSE:85.6800,
RMSE: 9.2563, IA: 0.8779, MAPE: 51.0050%


 76%|███████▌  | 19/25 [03:34<01:06, 11.11s/it]

MAE: 7.0004, MSE:83.5911,
RMSE: 9.1428, IA: 0.8796, MAPE: 50.9350%


 80%|████████  | 20/25 [03:45<00:56, 11.29s/it]

MAE: 6.1666, MSE:74.3280,
RMSE: 8.6214, IA: 0.8856, MAPE: 37.2303%


 84%|████████▍ | 21/25 [03:58<00:46, 11.72s/it]

MAE: 6.2441, MSE:77.3142,
RMSE: 8.7929, IA: 0.8848, MAPE: 34.2480%


 88%|████████▊ | 22/25 [04:09<00:34, 11.52s/it]

MAE: 6.8726, MSE:86.5666,
RMSE: 9.3041, IA: 0.8810, MAPE: 38.5694%


 92%|█████████▏| 23/25 [04:21<00:23, 11.53s/it]

MAE: 6.2223, MSE:72.8626,
RMSE: 8.5360, IA: 0.8973, MAPE: 42.0917%


 96%|█████████▌| 24/25 [04:32<00:11, 11.49s/it]

MAE: 6.1031, MSE:72.4878,
RMSE: 8.5140, IA: 0.8937, MAPE: 37.7609%


100%|██████████| 25/25 [04:43<00:00, 11.35s/it]

MAE: 9.0045, MSE:121.9081,
RMSE: 11.0412, IA: 0.8434, MAPE: 70.2105%
Accuracy on validation dataset: 34.300000000000004%





In [14]:
print(f'Accuracy on test set: {get_accuracy(model3, test_loader, device=device, pct_close=0.25)*100}%')
get_error_measures(model3, test_loader, device=device, print_e=True)

Accuracy on test set: 30.48%
MAE: 9.3394, MSE:128.4582,
RMSE: 11.3339, IA: 0.8386, MAPE: 74.3396%


(9.339397490922435,
 128.4582106092612,
 11.33394064786212,
 0.8385622808506112,
 74.33962311672066)

### SNN with two Leaky layers

In [15]:
model4 = DoubleLeakySNN(num_inputs=8, num_hidden=3500, num_outputs=1)
print(model4)

DoubleLeakySNN(
  (lin1): Linear(in_features=8, out_features=3500, bias=True)
  (lif1): Leaky()
  (lif2): Leaky()
  (lin2): Linear(in_features=3500, out_features=1, bias=True)
)


In [16]:
training_loop(model4, train_loader, valid_loader, device, num_epochs=25, lr=0.002, validation=True)

  4%|▍         | 1/25 [00:10<04:06, 10.28s/it]

MAE: 20.7912, MSE:648.0209,
RMSE: 25.4563, IA: 0.4117, MAPE: 96.9298%


  8%|▊         | 2/25 [00:21<04:03, 10.60s/it]

MAE: 9.2240, MSE:158.7891,
RMSE: 12.6012, IA: 0.6237, MAPE: 61.6509%


 12%|█▏        | 3/25 [00:31<03:52, 10.57s/it]

MAE: 8.4034, MSE:137.2468,
RMSE: 11.7152, IA: 0.7221, MAPE: 54.9186%


 16%|█▌        | 4/25 [00:41<03:37, 10.34s/it]

MAE: 8.2320, MSE:134.0473,
RMSE: 11.5779, IA: 0.7350, MAPE: 50.8459%


 20%|██        | 5/25 [00:52<03:27, 10.40s/it]

MAE: 8.1703, MSE:130.1089,
RMSE: 11.4065, IA: 0.7500, MAPE: 51.8916%


 24%|██▍       | 6/25 [01:02<03:16, 10.33s/it]

MAE: 8.1002, MSE:131.0900,
RMSE: 11.4495, IA: 0.7350, MAPE: 49.8911%


 28%|██▊       | 7/25 [01:12<03:03, 10.20s/it]

MAE: 8.5063, MSE:131.3157,
RMSE: 11.4593, IA: 0.7453, MAPE: 58.2378%


 32%|███▏      | 8/25 [01:22<02:52, 10.14s/it]

MAE: 8.0309, MSE:128.6205,
RMSE: 11.3411, IA: 0.7503, MAPE: 49.2924%


 36%|███▌      | 9/25 [01:32<02:41, 10.11s/it]

MAE: 7.8740, MSE:128.7604,
RMSE: 11.3473, IA: 0.7494, MAPE: 46.6007%


 40%|████      | 10/25 [01:42<02:30, 10.04s/it]

MAE: 8.1021, MSE:129.0207,
RMSE: 11.3587, IA: 0.7521, MAPE: 51.3649%


 44%|████▍     | 11/25 [01:52<02:20, 10.04s/it]

MAE: 7.9705, MSE:128.4749,
RMSE: 11.3347, IA: 0.7518, MAPE: 49.5858%


 48%|████▊     | 12/25 [02:02<02:09,  9.99s/it]

MAE: 8.0970, MSE:126.5448,
RMSE: 11.2492, IA: 0.7533, MAPE: 52.7670%


 52%|█████▏    | 13/25 [02:11<01:59,  9.94s/it]

MAE: 8.1205, MSE:131.2229,
RMSE: 11.4553, IA: 0.7533, MAPE: 51.9379%


 56%|█████▌    | 14/25 [02:22<01:50, 10.04s/it]

MAE: 7.8519, MSE:125.5747,
RMSE: 11.2060, IA: 0.7593, MAPE: 48.5188%


 60%|██████    | 15/25 [02:32<01:40, 10.01s/it]

MAE: 7.7946, MSE:128.8636,
RMSE: 11.3518, IA: 0.7528, MAPE: 45.5502%


 64%|██████▍   | 16/25 [02:42<01:31, 10.20s/it]

MAE: 7.7812, MSE:126.7176,
RMSE: 11.2569, IA: 0.7570, MAPE: 46.8592%


 68%|██████▊   | 17/25 [02:52<01:21, 10.21s/it]

MAE: 7.8485, MSE:125.7920,
RMSE: 11.2157, IA: 0.7559, MAPE: 47.8383%


 72%|███████▏  | 18/25 [03:04<01:14, 10.60s/it]

MAE: 8.4004, MSE:129.4541,
RMSE: 11.3778, IA: 0.7528, MAPE: 58.1428%


 76%|███████▌  | 19/25 [03:14<01:02, 10.48s/it]

MAE: 8.1605, MSE:124.3647,
RMSE: 11.1519, IA: 0.7696, MAPE: 56.0116%


 80%|████████  | 20/25 [03:24<00:51, 10.40s/it]

MAE: 7.7097, MSE:122.0141,
RMSE: 11.0460, IA: 0.7621, MAPE: 47.8783%


 84%|████████▍ | 21/25 [03:35<00:41, 10.32s/it]

MAE: 7.8617, MSE:122.3035,
RMSE: 11.0591, IA: 0.7666, MAPE: 50.3110%


 88%|████████▊ | 22/25 [03:45<00:30, 10.23s/it]

MAE: 7.7820, MSE:118.6304,
RMSE: 10.8918, IA: 0.7740, MAPE: 50.6737%


 92%|█████████▏| 23/25 [03:55<00:20, 10.32s/it]

MAE: 7.7632, MSE:118.0005,
RMSE: 10.8628, IA: 0.7812, MAPE: 49.5265%


 96%|█████████▌| 24/25 [04:06<00:10, 10.53s/it]

MAE: 7.6525, MSE:118.2745,
RMSE: 10.8754, IA: 0.7762, MAPE: 47.5426%


100%|██████████| 25/25 [04:17<00:00, 10.29s/it]

MAE: 7.5988, MSE:115.5370,
RMSE: 10.7488, IA: 0.7869, MAPE: 47.7704%





Accuracy on validation dataset: 41.14%


In [17]:
print(f'Accuracy on test set: {get_accuracy(model4, test_loader, device=device, pct_close=0.25)*100}%')
get_error_measures(model4, test_loader, device=device, print_e=True)

Accuracy on test set: 41.6%
MAE: 7.2979, MSE:105.2202,
RMSE: 10.2577, IA: 0.8092, MAPE: 50.2140%


(7.2979026539715255,
 105.22018341455497,
 10.257688989950658,
 0.809150441530647,
 50.21399957549717)

### Repeated training on random datasets

In [5]:
err_model1 = [0., 0., 0., 0., 0., 0.]
err_model2 = [0., 0., 0., 0., 0., 0.]
err_model3 = [0., 0., 0., 0., 0., 0.]
err_model4 = [0., 0., 0., 0., 0., 0.]

In [6]:
trials = 2
for i in range(trials):
    # Prepare data loaders.
    train_dataset, test_dataset = torch.utils.data.random_split(all_data, (round(0.8 * len(all_data)), round(0.2 * len(all_data))))
    train_loader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=256, shuffle=False, drop_last=False)

    # Prepare models.
    model1 = MultiLayerANN(num_inputs=8, num_hidden=4000, num_outputs=1)
    model2 = LeakySNN(num_inputs=8, num_hidden=4000, num_outputs=1, num_steps=25)
    model3 = SynapticSNN(num_inputs=8, num_hidden=4000, num_outputs=1, num_steps=50)
    model4 = DoubleLeakySNN(num_inputs=8, num_hidden=4000, num_outputs=1, num_steps=20)

    # Train models.
    print(f'------------------------------------\nTraining {i+1}')
    training_loop(model1, train_loader, test_loader, device, num_epochs=100, lr=0.003, validation=False) #90
    training_loop(model2, train_loader, test_loader, device, num_epochs=60, lr=0.001, validation=False) #50
    training_loop(model3, train_loader, test_loader, device, num_epochs=30, lr=0.0001, validation=False) #25
    training_loop(model4, train_loader, test_loader, device, num_epochs=40, lr=0.001, validation=False) #35

    # Gather results.
    collect_stats(model1, err_model1, test_loader, device)
    collect_stats(model2, err_model2, test_loader, device)
    collect_stats(model3, err_model3, test_loader, device)
    collect_stats(model4, err_model4, test_loader, device)

------------------------------------
Training 1


100%|██████████| 100/100 [03:30<00:00,  2.10s/it]


Accuracy on validation dataset: 42.0%


100%|██████████| 60/60 [05:12<00:00,  5.21s/it]


Accuracy on validation dataset: 50.63999999999999%


100%|██████████| 30/30 [05:21<00:00, 10.71s/it]


Accuracy on validation dataset: 47.68%


100%|██████████| 40/40 [05:17<00:00,  7.94s/it]


Accuracy on validation dataset: 48.84%
------------------------------------
Training 2


 22%|██▏       | 22/100 [00:47<02:49,  2.17s/it]


KeyboardInterrupt: 

In [10]:
# Print average results
print('------Simple ANN ------')
print_average_stats(err_model1, trials)
print('------LIF SNN ------')
print_average_stats(err_model2, trials)
print('------Synaptic SNN ------')
print_average_stats(err_model3, trials)
print('------Double LIF SNN ------')
print_average_stats(err_model4, trials)

------Simple ANN ------
MAE: 5.733594991551628, MSE:63.580033061411996,
RMSE: 7.967944587224418, IA: 0.9164841564885322,
MAPE: 38.03590761235476 %, acc: 51.449999999999996 %

------LIF SNN ------
MAE: 6.2430069344668615, MSE:69.54496934137244,
RMSE: 8.339354083040947, IA: 0.9103193725694927,
MAPE: 44.6369944580173 %, acc: 48.089999999999996 %

------Synaptic SNN ------
MAE: 9.264562545325347, MSE:135.060263711052,
RMSE: 11.233460025005105, IA: 0.8579632085418573,
MAPE: 68.35216490927189 %, acc: 33.75 %

------Double LIF SNN ------
MAE: 6.685834302134199, MSE:80.7118603840073,
RMSE: 8.983535278879645, IA: 0.8816772396448564,
MAPE: 48.23070322089788 %, acc: 45.655 %

