In [1]:
import torch
import torch.utils.data as data
import pandas as pd
from src.error_measures import get_accuracy, get_error_measures, collect_stats, print_average_stats
from src.models import TwoLayerPerceptron, SynapticSNN, LeakySNN, DoubleLeakySNN
from src.train_model import training_loop

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## Data

In [5]:
# Load the processed table and discard the "weekday" column in one chained
# expression, so the cell yields the same frame every time it is re-run.
data06 = pd.read_csv("../data/processed/target06.csv").drop(columns=["weekday"])

In [6]:
# Features are every column but the last; the last column is the regression target.
values = data06.values
features, targets = values[:, :-1], values[:, -1]

# Min-max scale each feature column to [0, 1]. The range is computed as
# max - min because the `ndarray.ptp` method was removed in NumPy 2.0.
feature_min = features.min(0)
feature_range = features.max(0) - feature_min
# A constant column has zero range; divide by 1 instead so it maps to 0 rather than NaN.
feature_range[feature_range == 0] = 1

all_data = data.TensorDataset(torch.from_numpy((features - feature_min) / feature_range).float(),
                              torch.from_numpy(targets).float())  # with normalization

# 75 % train / 15 % test / 10 % validation. The validation split takes the
# remainder: rounding all three shares independently can produce sizes that do
# not add up to len(all_data), which makes random_split raise a ValueError.
n_total = len(all_data)
n_train = round(0.75 * n_total)
n_test = round(0.15 * n_total)
n_valid = n_total - n_train - n_test
train_dataset, test_dataset, valid_dataset = torch.utils.data.random_split(all_data, (n_train, n_test, n_valid))

In [5]:
# Training batches are small, reshuffled on every pass, and the last partial batch is dropped.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)

# Both evaluation loaders share the same settings: large batches, fixed order, every sample kept.
eval_loader_kwargs = dict(batch_size=256, shuffle=False, drop_last=False)
test_loader = data.DataLoader(test_dataset, **eval_loader_kwargs)
valid_loader = data.DataLoader(valid_dataset, **eval_loader_kwargs)

### Simple regressor

In [6]:
# Baseline regressor: 8 input features, one hidden layer, a single output value.
# NOTE(review): the saved output of this cell reports 4000 hidden features, so it
# predates num_hidden=3500 -- re-run the cell to refresh it.
model1 = TwoLayerPerceptron(
    num_inputs=8,
    num_hidden=3500,
    num_outputs=1,
)
print(model1)

TwoLayerPerceptron(
  (lin1): Linear(in_features=8, out_features=4000, bias=True)
  (relu): LeakyReLU(negative_slope=0.01)
  (lin2): Linear(in_features=4000, out_features=1, bias=True)
)


In [7]:
# Fit the baseline for 50 epochs at lr=0.003.
# validation=True appears to enable the per-epoch error report on valid_loader -- confirm in src.train_model.
training_loop(
    model1,
    train_loader,
    valid_loader,
    device,
    num_epochs=50,
    lr=0.003,
    validation=True,
)

  2%|▏         | 1/50 [00:02<02:04,  2.54s/it]

MAE: 6.4051214022443475, MSE:79.3088850732351,
RMSE: 8.905553608464501, IA: 0.8785967859374101, MAPE: 42.524250340555184%


  4%|▍         | 2/50 [00:02<01:02,  1.30s/it]

MAE: 6.299213872755194, MSE:78.90802608016764,
RMSE: 8.883018973308998, IA: 0.8758917524999306, MAPE: 39.158839261370794%


  6%|▌         | 3/50 [00:03<00:41,  1.12it/s]

MAE: 6.271410387788894, MSE:80.03767952047497,
RMSE: 8.946378011266624, IA: 0.8772346629668094, MAPE: 36.46787920801433%


  8%|▊         | 4/50 [00:03<00:31,  1.45it/s]

MAE: 6.207237047129284, MSE:77.32932744069036,
RMSE: 8.793709538112477, IA: 0.8880625017216082, MAPE: 39.23330118028064%


 10%|█         | 5/50 [00:04<00:27,  1.64it/s]

MAE: 6.416439289302495, MSE:77.38876063217583,
RMSE: 8.797088190542132, IA: 0.8874306332624241, MAPE: 44.518195396275374%


 12%|█▏        | 6/50 [00:04<00:24,  1.79it/s]

MAE: 6.186369993231889, MSE:75.29315086723453,
RMSE: 8.677162604632606, IA: 0.8904558747353989, MAPE: 40.52138969144555%


 14%|█▍        | 7/50 [00:05<00:22,  1.94it/s]

MAE: 6.927017348901384, MSE:82.460211362156,
RMSE: 9.080760505715146, IA: 0.8727929518006523, MAPE: 50.23427427352202%


 16%|█▌        | 8/50 [00:05<00:19,  2.14it/s]

MAE: 6.384739850160014, MSE:75.92615491317592,
RMSE: 8.713561551580153, IA: 0.8854421801885228, MAPE: 43.92370138075016%


 18%|█▊        | 9/50 [00:05<00:17,  2.30it/s]

MAE: 6.233028939693649, MSE:75.30660626974688,
RMSE: 8.677937904234328, IA: 0.8903034331183179, MAPE: 41.94448019268794%


 20%|██        | 10/50 [00:06<00:16,  2.38it/s]

MAE: 6.553861319536418, MSE:79.05237962765219,
RMSE: 8.891140513322922, IA: 0.8902806551388862, MAPE: 46.85617150957733%


 22%|██▏       | 11/50 [00:06<00:16,  2.40it/s]

MAE: 6.457421969540547, MSE:76.18995326178177,
RMSE: 8.72868565488423, IA: 0.8864801464053479, MAPE: 45.52513111430816%


 24%|██▍       | 12/50 [00:07<00:15,  2.40it/s]

MAE: 6.440971144361992, MSE:76.44063559338304,
RMSE: 8.743033546394697, IA: 0.8897201108490322, MAPE: 45.40229981502482%


 26%|██▌       | 13/50 [00:07<00:15,  2.40it/s]

MAE: 6.128716802045789, MSE:75.02778459311108,
RMSE: 8.661858033534784, IA: 0.8959283746025719, MAPE: 38.43615751836086%


 28%|██▊       | 14/50 [00:07<00:15,  2.32it/s]

MAE: 6.043179189263052, MSE:72.54032090451722,
RMSE: 8.517060578892064, IA: 0.8905255553831019, MAPE: 38.06660234617566%


 30%|███       | 15/50 [00:08<00:15,  2.28it/s]

MAE: 7.038173427747164, MSE:83.91367649029426,
RMSE: 9.160440845848756, IA: 0.8782723130066163, MAPE: 52.673451835245274%


 32%|███▏      | 16/50 [00:08<00:14,  2.32it/s]

MAE: 6.938670584783389, MSE:82.6312187824583,
RMSE: 9.090171548571474, IA: 0.8833129574822355, MAPE: 51.146407202257684%


 34%|███▍      | 17/50 [00:09<00:13,  2.40it/s]

MAE: 6.093248572377111, MSE:73.20169938224821,
RMSE: 8.555799166778531, IA: 0.8950210183297975, MAPE: 36.00949065373054%


 36%|███▌      | 18/50 [00:09<00:12,  2.47it/s]

MAE: 6.6051213939754945, MSE:77.92147595334659,
RMSE: 8.827314198177529, IA: 0.8883359240767152, MAPE: 47.14973068061488%


 38%|███▊      | 19/50 [00:09<00:12,  2.40it/s]

MAE: 6.07305204716721, MSE:72.7984036469645,
RMSE: 8.532198054837012, IA: 0.8977157853396898, MAPE: 39.94228902027145%


 40%|████      | 20/50 [00:10<00:13,  2.29it/s]

MAE: 6.225803497898785, MSE:72.8445479069054,
RMSE: 8.534901751450066, IA: 0.8923846101851428, MAPE: 42.57338082678387%


 42%|████▏     | 21/50 [00:10<00:13,  2.22it/s]

MAE: 6.101953764733551, MSE:74.3158715029691,
RMSE: 8.620665374724222, IA: 0.8854414239086464, MAPE: 36.364435610497246%


 44%|████▍     | 22/50 [00:11<00:13,  2.15it/s]

MAE: 6.044092513233251, MSE:71.34312345813352,
RMSE: 8.44648586443697, IA: 0.8937687694397656, MAPE: 38.01182705784589%


 46%|████▌     | 23/50 [00:11<00:12,  2.23it/s]

MAE: 5.963398841350754, MSE:70.09424745763867,
RMSE: 8.372230733659856, IA: 0.9031498000628213, MAPE: 37.138666256753176%


 48%|████▊     | 24/50 [00:12<00:11,  2.33it/s]

MAE: 5.976450871456565, MSE:70.72509565467213,
RMSE: 8.409821380663928, IA: 0.8957687679344316, MAPE: 35.82601314713089%


 50%|█████     | 25/50 [00:12<00:10,  2.34it/s]

MAE: 5.917768751265686, MSE:68.34317621330864,
RMSE: 8.26699317849656, IA: 0.9020981853446919, MAPE: 37.490490852941306%


 52%|█████▏    | 26/50 [00:13<00:10,  2.40it/s]

MAE: 5.940843936887091, MSE:69.65158654010176,
RMSE: 8.345752604774585, IA: 0.9062034184124232, MAPE: 38.44228844784925%


 54%|█████▍    | 27/50 [00:13<00:09,  2.45it/s]

MAE: 6.022127167200077, MSE:70.32554838568777,
RMSE: 8.386032934927442, IA: 0.9032695311454151, MAPE: 40.09700729055178%


 56%|█████▌    | 28/50 [00:13<00:08,  2.48it/s]

MAE: 5.91191909740426, MSE:68.95417763018739,
RMSE: 8.303865222303852, IA: 0.9038616200557728, MAPE: 38.39156535560088%


 58%|█████▊    | 29/50 [00:14<00:08,  2.51it/s]

MAE: 6.0055491334441085, MSE:70.08765876172001,
RMSE: 8.371837239323279, IA: 0.9046336159254613, MAPE: 38.345783745508584%


 60%|██████    | 30/50 [00:14<00:07,  2.52it/s]

MAE: 5.88531794134592, MSE:67.25808692418497,
RMSE: 8.201102787075953, IA: 0.9044747538399378, MAPE: 37.54187244251287%


 62%|██████▏   | 31/50 [00:15<00:07,  2.53it/s]

MAE: 6.07835836335586, MSE:74.21209947324814,
RMSE: 8.61464447747254, IA: 0.8827565299747807, MAPE: 35.121052075934244%


 64%|██████▍   | 32/50 [00:15<00:07,  2.48it/s]

MAE: 5.956450861175625, MSE:68.71693450978437,
RMSE: 8.289567812002286, IA: 0.8959004825884973, MAPE: 38.339184339631274%


 66%|██████▌   | 33/50 [00:15<00:07,  2.38it/s]

MAE: 5.951028931071993, MSE:69.02643165722061,
RMSE: 8.308214709383757, IA: 0.8981497501913815, MAPE: 38.48412265409714%


 68%|██████▊   | 34/50 [00:16<00:06,  2.35it/s]

MAE: 5.853109810807112, MSE:68.53356590734404,
RMSE: 8.278500220894124, IA: 0.9002524084457939, MAPE: 35.23531513953105%


 70%|███████   | 35/50 [00:16<00:06,  2.16it/s]

MAE: 5.85172254543084, MSE:69.51884374617303,
RMSE: 8.337796096461764, IA: 0.907952420427697, MAPE: 34.1462946103782%


 72%|███████▏  | 36/50 [00:17<00:06,  2.05it/s]

MAE: 6.106023166393269, MSE:70.19687674146135,
RMSE: 8.378357639863635, IA: 0.9027933711351099, MAPE: 41.10135389400837%


 74%|███████▍  | 37/50 [00:17<00:06,  2.04it/s]

MAE: 6.01876299836043, MSE:69.10796952861593,
RMSE: 8.313120324439911, IA: 0.9030620303417666, MAPE: 39.154472614663%


 76%|███████▌  | 38/50 [00:18<00:05,  2.12it/s]

MAE: 5.955768815630433, MSE:70.08840862026327,
RMSE: 8.371882023790306, IA: 0.8907957677687555, MAPE: 36.578014794232416%


 78%|███████▊  | 39/50 [00:18<00:05,  2.17it/s]

MAE: 5.799433559075945, MSE:65.57300077616117,
RMSE: 8.09771577521471, IA: 0.9071932612573939, MAPE: 35.879388123436875%


 80%|████████  | 40/50 [00:19<00:04,  2.27it/s]

MAE: 5.9918381453938565, MSE:67.9416039833256,
RMSE: 8.242669712133662, IA: 0.8951590283977784, MAPE: 38.93895263635732%


 82%|████████▏ | 41/50 [00:19<00:03,  2.29it/s]

MAE: 6.084023120775388, MSE:70.99212976839394,
RMSE: 8.425682747907967, IA: 0.8861882714109898, MAPE: 39.83621081338568%


 84%|████████▍ | 42/50 [00:20<00:03,  2.26it/s]

MAE: 5.806092500600512, MSE:65.80006234747617,
RMSE: 8.111723759317508, IA: 0.9095232330463755, MAPE: 36.76927082923337%


 86%|████████▌ | 43/50 [00:20<00:03,  2.30it/s]

MAE: 5.904693645893494, MSE:67.76200486026109,
RMSE: 8.231768027602643, IA: 0.8979244502542412, MAPE: 37.26804160513908%


 88%|████████▊ | 44/50 [00:20<00:02,  2.38it/s]

MAE: 6.365549131349332, MSE:71.61426052263313,
RMSE: 8.462520931887443, IA: 0.8960675047123042, MAPE: 45.1336128659809%


 90%|█████████ | 45/50 [00:21<00:02,  2.46it/s]

MAE: 5.806161829019557, MSE:66.98748198405791,
RMSE: 8.184588076626575, IA: 0.9013521903398554, MAPE: 34.78563383385102%


 92%|█████████▏| 46/50 [00:21<00:01,  2.51it/s]

MAE: 5.883595385096665, MSE:66.91385602243292,
RMSE: 8.180088998442066, IA: 0.9050182169223281, MAPE: 37.63576827509417%


 94%|█████████▍| 47/50 [00:22<00:01,  2.55it/s]

MAE: 6.072982614026594, MSE:68.02183902119981,
RMSE: 8.247535330097095, IA: 0.9050917200300418, MAPE: 42.060774062262155%


 96%|█████████▌| 48/50 [00:22<00:00,  2.54it/s]

MAE: 6.004404685676443, MSE:69.19556829842571,
RMSE: 8.318387361648032, IA: 0.9096788244576383, MAPE: 38.21854005419703%


 98%|█████████▊| 49/50 [00:22<00:00,  2.38it/s]

MAE: 5.89169941168989, MSE:65.71339050484514,
RMSE: 8.106379617612609, IA: 0.9057495091706491, MAPE: 39.466435195674585%


100%|██████████| 50/50 [00:23<00:00,  2.13it/s]

MAE: 6.081988444907128, MSE:67.35857775491974,
RMSE: 8.207227166036025, IA: 0.9048248260874595, MAPE: 42.347686686804984%
Accuracy on validation dataset: 49.94%





In [8]:
# Score the baseline on the held-out test split: first the accuracy with
# pct_close=0.25, then the full set of error measures (printed via print_e=True).
model1_test_accuracy = get_accuracy(model1, test_loader, device=device, pct_close=0.25)
print(f'Accuracy on test set: {model1_test_accuracy*100}%')
get_error_measures(model1, test_loader, device=device, print_e=True)

Accuracy on test set: 51.23%
MAE: 6.275693345510721, MSE:76.93435649748643,
RMSE: 8.771223204176623, IA: 0.9031492589612802, MAPE: 43.58376729999973%


(6.275693345510721,
 76.93435649748643,
 8.771223204176623,
 0.9031492589612802,
 43.58376729999973)

### SNN with Leaky layer

In [9]:
# Spiking network with one Leaky layer, sized like the baseline (8 -> 3500 -> 1).
# NOTE(review): the saved output of this cell reports 4000 hidden features, so it
# predates num_hidden=3500 -- re-run the cell to refresh it.
model2 = LeakySNN(
    num_inputs=8,
    num_hidden=3500,
    num_outputs=1,
)
print(model2)

LeakySNN(
  (fc1): Linear(in_features=8, out_features=4000, bias=True)
  (lif): Leaky()
  (fc2): Linear(in_features=4000, out_features=1, bias=True)
)


In [10]:
# Fit the Leaky SNN for 25 epochs at lr=0.0005.
# NOTE(review): the saved run of this cell was stopped by a KeyboardInterrupt
# after 4 epochs, so model2 was never trained to completion in this session.
training_loop(
    model2,
    train_loader,
    valid_loader,
    device,
    num_epochs=25,
    lr=0.0005,
    validation=True,
)

  4%|▍         | 1/25 [00:05<02:07,  5.33s/it]

MAE: 11.933502882305596, MSE:322.06663314020034,
RMSE: 17.946215008747675, IA: 0.40790693234759356, MAPE: 55.7255608113032%


  8%|▊         | 2/25 [00:10<01:57,  5.10s/it]

MAE: 9.149780340690834, MSE:170.43816668488344,
RMSE: 13.055196922485829, IA: 0.5967077899677402, MAPE: 55.389432357148706%


 12%|█▏        | 3/25 [00:15<01:55,  5.25s/it]

MAE: 7.715248570139009, MSE:115.35275177845116,
RMSE: 10.740239838032071, IA: 0.7777177457687532, MAPE: 49.156550814343106%


 16%|█▌        | 4/25 [00:21<01:54,  5.47s/it]

MAE: 7.0630519858674505, MSE:94.63635594455802,
RMSE: 9.728121912504902, IA: 0.8337309889253345, MAPE: 46.09033376554385%


 16%|█▌        | 4/25 [00:25<02:13,  6.35s/it]


KeyboardInterrupt: 

In [None]:
# Score the Leaky SNN on the held-out test split: accuracy with pct_close=0.25,
# followed by the full set of error measures (printed via print_e=True).
model2_test_accuracy = get_accuracy(model2, test_loader, device=device, pct_close=0.25)
print(f'Accuracy on test set: {model2_test_accuracy*100}%')
get_error_measures(model2, test_loader, device=device, print_e=True)

### SNN with Synaptic layer

In [14]:
# Spiking network with one Synaptic layer, sized like the baseline (8 -> 3500 -> 1).
model3 = SynapticSNN(
    num_inputs=8,
    num_hidden=3500,
    num_outputs=1,
)
print(model3)

SynapticSNN(
  (fc1): Linear(in_features=8, out_features=3500, bias=True)
  (lif): Synaptic()
  (fc2): Linear(in_features=3500, out_features=1, bias=True)
)


In [15]:
# Fit the Synaptic SNN for 25 epochs at lr=0.0005.
# validation=True appears to enable the per-epoch error report on valid_loader -- confirm in src.train_model.
training_loop(
    model3,
    train_loader,
    valid_loader,
    device,
    num_epochs=25,
    lr=0.0005,
    validation=True,
)

  4%|▍         | 1/25 [00:20<08:16, 20.69s/it]

MAE: 8.232242779373433, MSE:126.25745839541996,
RMSE: 11.236434416460586, IA: 0.6070276439054336, MAPE: 54.66549650346819%


  8%|▊         | 2/25 [00:42<08:05, 21.11s/it]

MAE: 6.978520220276938, MSE:88.20087729500598,
RMSE: 9.391532212317966, IA: 0.805832818532277, MAPE: 47.826281028873815%


 12%|█▏        | 3/25 [01:02<07:34, 20.65s/it]

MAE: 6.3158844137467405, MSE:79.77807400592347,
RMSE: 8.931857254005097, IA: 0.8379919477141724, MAPE: 38.47185910984739%


 16%|█▌        | 4/25 [01:22<07:10, 20.50s/it]

MAE: 6.232161886292386, MSE:76.52209639538542,
RMSE: 8.747690917915735, IA: 0.8504049042767974, MAPE: 38.26191885784741%


 20%|██        | 5/25 [01:42<06:46, 20.33s/it]

MAE: 6.281028938293457, MSE:73.75185154431658,
RMSE: 8.587889819060127, IA: 0.8602826226614384, MAPE: 41.80323604445604%


 24%|██▍       | 6/25 [02:02<06:23, 20.17s/it]

MAE: 6.125375750436948, MSE:74.60001507867761,
RMSE: 8.63713002557433, IA: 0.8602089681803686, MAPE: 37.130535355743135%


 28%|██▊       | 7/25 [02:22<06:02, 20.13s/it]

MAE: 6.076936477594982, MSE:73.50648667037058,
RMSE: 8.573592401693153, IA: 0.8618027968762119, MAPE: 36.53092419724583%


 32%|███▏      | 8/25 [02:42<05:41, 20.10s/it]

MAE: 6.294612753873616, MSE:73.00640616357931,
RMSE: 8.544378629460384, IA: 0.8639477661359274, MAPE: 42.65831862131437%


 36%|███▌      | 9/25 [03:02<05:21, 20.08s/it]

MAE: 6.053930707060533, MSE:70.94127430110697,
RMSE: 8.422664323188178, IA: 0.8677999872388572, MAPE: 39.04185037964455%


 40%|████      | 10/25 [03:23<05:04, 20.29s/it]

MAE: 6.002647443727262, MSE:71.22436021176593,
RMSE: 8.439452601428954, IA: 0.866418961747955, MAPE: 36.85926874398883%


 44%|████▍     | 11/25 [03:44<04:48, 20.63s/it]

MAE: 5.993225437781715, MSE:70.36395805736726,
RMSE: 8.38832272014896, IA: 0.8680020485014478, MAPE: 36.93761026130568%


 48%|████▊     | 12/25 [04:05<04:30, 20.78s/it]

MAE: 5.948774593000467, MSE:71.02430112678105,
RMSE: 8.427591656385651, IA: 0.8670318241747071, MAPE: 35.576342620932536%


 52%|█████▏    | 13/25 [04:27<04:13, 21.09s/it]

MAE: 5.975514470359493, MSE:69.36702207929919,
RMSE: 8.32868669595028, IA: 0.870724981009557, MAPE: 37.33083914809793%


 56%|█████▌    | 14/25 [04:48<03:52, 21.15s/it]

MAE: 5.894092531149098, MSE:68.53197434676252,
RMSE: 8.27840409419367, IA: 0.8731405914815575, MAPE: 36.77631637002442%


 60%|██████    | 15/25 [05:09<03:28, 20.87s/it]

MAE: 5.933433555316374, MSE:68.42408935190038,
RMSE: 8.271885477441064, IA: 0.8739766340143968, MAPE: 37.73758726169585%


 64%|██████▍   | 16/25 [05:29<03:05, 20.60s/it]

MAE: 5.92102890951785, MSE:68.86421897722748,
RMSE: 8.298446781008328, IA: 0.8719813332644545, MAPE: 36.418616113804894%


 68%|██████▊   | 17/25 [05:49<02:44, 20.54s/it]

MAE: 6.035075180516767, MSE:68.76163673996312,
RMSE: 8.292263668019917, IA: 0.873995518924632, MAPE: 39.60706223538828%


 72%|███████▏  | 18/25 [06:11<02:25, 20.85s/it]

MAE: 5.863618543106696, MSE:68.92748172990086,
RMSE: 8.302257628494846, IA: 0.8746410940082734, MAPE: 35.10357105166152%


 76%|███████▌  | 19/25 [06:32<02:05, 20.96s/it]

MAE: 5.8843005808791675, MSE:67.03613780394473,
RMSE: 8.18755994200621, IA: 0.8769188140218168, MAPE: 36.48175569601958%


 80%|████████  | 20/25 [06:54<01:46, 21.23s/it]

MAE: 5.833306401313385, MSE:66.89378679768737,
RMSE: 8.178862194565168, IA: 0.8766195585491136, MAPE: 35.261914671867814%


 84%|████████▍ | 21/25 [07:15<01:25, 21.30s/it]

MAE: 5.87375724108922, MSE:67.4497528427378,
RMSE: 8.212779848671081, IA: 0.8756766168725385, MAPE: 36.56461557059894%


 88%|████████▊ | 22/25 [07:36<01:03, 21.20s/it]

MAE: 5.875225478376267, MSE:65.82757989442118,
RMSE: 8.113419741047617, IA: 0.8804476737484646, MAPE: 37.94435736982202%


 92%|█████████▏| 23/25 [07:57<00:42, 21.15s/it]

MAE: 5.758878638427381, MSE:64.9163121407723,
RMSE: 8.05706597594759, IA: 0.8827841413658402, MAPE: 35.91404247108064%


 96%|█████████▌| 24/25 [08:17<00:20, 20.84s/it]

MAE: 5.9485664869319494, MSE:65.26394239364875,
RMSE: 8.078610177106501, IA: 0.8792537322671139, MAPE: 39.64529049023867%


100%|██████████| 25/25 [08:37<00:00, 20.71s/it]

MAE: 5.7148208403173895, MSE:63.94562185040342,
RMSE: 7.996600643423643, IA: 0.8838679519829968, MAPE: 35.31117355672974%





Accuracy on validation dataset: 51.559999999999995%


In [16]:
# Score the Synaptic SNN on the held-out test split: accuracy with pct_close=0.25,
# followed by the full set of error measures (printed via print_e=True).
model3_test_accuracy = get_accuracy(model3, test_loader, device=device, pct_close=0.25)
print(f'Accuracy on test set: {model3_test_accuracy*100}%')
get_error_measures(model3, test_loader, device=device, print_e=True)

Accuracy on test set: 54.620000000000005%
MAE: 5.715600851320523, MSE:68.73188526984967,
RMSE: 8.290469544594544, IA: 0.9142479890834956, MAPE: 37.17047633429125%


(5.715600851320523,
 68.73188526984967,
 8.290469544594544,
 0.9142479890834956,
 37.17047633429125)

### SNN with two Leaky layers

In [None]:
# Spiking network with two Leaky layers, built with the same 8 / 3500 / 1 sizing as the other models.
model4 = DoubleLeakySNN(
    num_inputs=8,
    num_hidden=3500,
    num_outputs=1,
)
print(model4)

In [None]:
# Fit the double-Leaky SNN for 25 epochs at lr=0.002.
# validation=True appears to enable the per-epoch error report on valid_loader -- confirm in src.train_model.
training_loop(
    model4,
    train_loader,
    valid_loader,
    device,
    num_epochs=25,
    lr=0.002,
    validation=True,
)

In [None]:
# Score the double-Leaky SNN on the held-out test split: accuracy with pct_close=0.25,
# followed by the full set of error measures (printed via print_e=True).
model4_test_accuracy = get_accuracy(model4, test_loader, device=device, pct_close=0.25)
print(f'Accuracy on test set: {model4_test_accuracy*100}%')
get_error_measures(model4, test_loader, device=device, print_e=True)

### Repeated training on random datasets

In [7]:
# One six-slot accumulator per model, each a separate list starting at zero.
# Presumably the slots hold MAE, MSE, RMSE, IA, MAPE and accuracy, matching
# what print_average_stats reports -- confirm in src.error_measures.
err_model1, err_model2, err_model3, err_model4 = ([0.] * 6 for _ in range(4))

In [8]:
# Train all four models `trials` times, each time on a fresh random split, and
# pass each trained model with its err_model* list to collect_stats.
# NOTE(review): collect_stats presumably adds to err_model* in place, so this cell
# is not idempotent -- re-run the cell that zeroes the accumulators before it.
trials = 3

# 75 % train / 15 % test / 10 % validation. The sizes are the same for every
# trial, so they are computed once. The validation split takes the remainder:
# rounding all three shares independently can produce sizes that do not add up
# to len(all_data), which makes random_split raise a ValueError.
n_total = len(all_data)
n_train = round(0.75 * n_total)
n_test = round(0.15 * n_total)
split_sizes = (n_train, n_test, n_total - n_train - n_test)

for i in range(trials):
    # Prepare data loaders.
    train_dataset, test_dataset, valid_dataset = torch.utils.data.random_split(all_data, split_sizes)
    train_loader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=256, shuffle=False, drop_last=False)
    valid_loader = data.DataLoader(valid_dataset, batch_size=256, shuffle=False, drop_last=False)

    # Prepare models. All four are built before any training starts, so the
    # order in which random numbers are consumed is unchanged.
    model1 = TwoLayerPerceptron(num_inputs=8, num_hidden=3500, num_outputs=1)
    model2 = LeakySNN(num_inputs=8, num_hidden=3500, num_outputs=1)
    model3 = SynapticSNN(num_inputs=8, num_hidden=3500, num_outputs=1)
    model4 = DoubleLeakySNN(num_inputs=8, num_hidden=3500, num_outputs=1)

    # (model, accumulator, learning rate) for each model in this trial.
    trial_runs = (
        (model1, err_model1, 0.005),
        (model2, err_model2, 0.0005),
        (model3, err_model3, 0.0005),
        (model4, err_model4, 0.0005),
    )

    # Train models.
    print(f'------------------------------------\nTraining {i+1}')
    for trial_model, _, trial_lr in trial_runs:
        training_loop(trial_model, train_loader, valid_loader, device, num_epochs=25, lr=trial_lr, validation=False)

    # Gather results only after every model has been trained.
    for trial_model, trial_err, _ in trial_runs:
        collect_stats(trial_model, trial_err, test_loader, device)

------------------------------------
Training 1


100%|██████████| 25/25 [00:11<00:00,  2.10it/s]


Accuracy on validation dataset: 52.6%


100%|██████████| 25/25 [02:11<00:00,  5.26s/it]


Accuracy on validation dataset: 51.449999999999996%


100%|██████████| 25/25 [08:46<00:00, 21.06s/it]


Accuracy on validation dataset: 52.6%


100%|██████████| 25/25 [04:12<00:00, 10.08s/it]


Accuracy on validation dataset: 53.64%
------------------------------------
Training 2


100%|██████████| 25/25 [00:09<00:00,  2.55it/s]


Accuracy on validation dataset: 50.980000000000004%


100%|██████████| 25/25 [02:11<00:00,  5.27s/it]


Accuracy on validation dataset: 52.49%


100%|██████████| 25/25 [08:32<00:00, 20.51s/it]


Accuracy on validation dataset: 54.339999999999996%


100%|██████████| 25/25 [04:05<00:00,  9.82s/it]


Accuracy on validation dataset: 51.910000000000004%
------------------------------------
Training 3


100%|██████████| 25/25 [00:09<00:00,  2.64it/s]


Accuracy on validation dataset: 53.76%


100%|██████████| 25/25 [02:08<00:00,  5.12s/it]


Accuracy on validation dataset: 53.53%


100%|██████████| 25/25 [08:33<00:00, 20.53s/it]


Accuracy on validation dataset: 53.87%


100%|██████████| 25/25 [04:07<00:00,  9.92s/it]


Accuracy on validation dataset: 53.18000000000001%


In [9]:
# Print the averaged results for each model, under its own banner line.
# NOTE(review): the saved "acc" values (~16-18 %) are far below the ~51-55 %
# accuracy obtained with pct_close=0.25 in the single runs; presumably
# collect_stats uses a different tolerance -- confirm in src.error_measures.
averaged_reports = (
    ('------Simple ANN ------', err_model1),
    ('------LIF SNN ------', err_model2),
    ('------Synaptic SNN ------', err_model3),
    ('------Double LIF SNN ------', err_model4),
)
for banner, accumulated_errors in averaged_reports:
    print(banner)
    print_average_stats(accumulated_errors, trials)

------Simple ANN ------
MAE: 6.223348741906082, MSE:77.87496346251066,
RMSE: 8.822432196005215, IA: 0.8731686590293771,
MAPE: 38.26240906891394 %, acc: 16.153333333333332 %

------LIF SNN ------
MAE: 5.9599948740642255, MSE:70.97580849449467,
RMSE: 8.424293170580057, IA: 0.8906145868669072,
MAPE: 36.826144396834216 %, acc: 17.36 %

------Synaptic SNN ------
MAE: 5.89480741050345, MSE:71.08278226297476,
RMSE: 8.430690187721938, IA: 0.891052049095577,
MAPE: 35.88559691869337 %, acc: 17.926666666666666 %

------Double LIF SNN ------
MAE: 6.093222893074224, MSE:72.02550909860128,
RMSE: 8.482812043114984, IA: 0.8892542744735851,
MAPE: 39.67798868867746 %, acc: 16.64 %

