In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from dataset.dataset import NoCDataset
from model.vanilla import VanillaModel

from tqdm import tqdm

In [2]:
dataset = NoCDataset()
print(f"#Samples = {len(dataset)}")
print(dataset[0])

# TODO: use advanced dgl.dataloader instead of manually dividing dataset
num_training = 0.9 * len(dataset)


#Samples = 559
(Graph(num_nodes={'packet': 10, 'router': 12},
      num_edges={('packet', 'pass', 'router'): 46, ('router', 'backpressure', 'router'): 16, ('router', 'connect', 'router'): 16, ('router', 'transfer', 'packet'): 46},
      metagraph=[('packet', 'router', 'pass'), ('router', 'router', 'backpressure'), ('router', 'router', 'connect'), ('router', 'packet', 'transfer')]), tensor([1057.8928,  887.8571]))


In [5]:
device = "cpu"
# device = "cuda:0"

model = VanillaModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
losses = []

epoches = 100
for e in tqdm(range(epoches)):
    for i, data in enumerate(dataset):
        if i > num_training:
            break

        g, congestion = data
        g = g.to(device)

        pred = model(g)
        loss = F.mse_loss(pred, torch.log(congestion+1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"iteration: {i}; loss = {loss}")
            losses.append(loss.item())
            # print(f"pred = {pred}")
            # print(f"congestion = {congestion}")
            
plt.ylim(0, 20)
plt.plot(losses)


  0%|          | 0/100 [00:00<?, ?it/s]

iteration: 0; loss = 90.5848388671875
iteration: 100; loss = 24.426420211791992
iteration: 200; loss = 15.197343826293945
iteration: 300; loss = 7.454904079437256
iteration: 400; loss = 6.675256729125977


  1%|          | 1/100 [00:03<06:22,  3.86s/it]

iteration: 500; loss = 7.523663520812988
iteration: 0; loss = 8.845500946044922
iteration: 100; loss = 10.212442398071289
iteration: 200; loss = 7.753923416137695
iteration: 300; loss = 5.9914140701293945
iteration: 400; loss = 6.324829578399658


  2%|▏         | 2/100 [00:07<06:13,  3.81s/it]

iteration: 500; loss = 7.800098419189453
iteration: 0; loss = 8.957642555236816
iteration: 100; loss = 9.726606369018555
iteration: 200; loss = 8.005084037780762
iteration: 300; loss = 5.8300676345825195
iteration: 400; loss = 6.038851737976074


  3%|▎         | 3/100 [00:11<06:10,  3.82s/it]

iteration: 500; loss = 7.665209770202637
iteration: 0; loss = 8.66726303100586
iteration: 100; loss = 9.440008163452148
iteration: 200; loss = 8.116987228393555
iteration: 300; loss = 5.464106559753418
iteration: 400; loss = 6.070445537567139


  4%|▍         | 4/100 [00:15<06:09,  3.85s/it]

iteration: 500; loss = 8.028295516967773
iteration: 0; loss = 8.924420356750488
iteration: 100; loss = 9.13081169128418
iteration: 200; loss = 8.409296989440918
iteration: 300; loss = 5.2625412940979
iteration: 400; loss = 5.9106035232543945


  5%|▌         | 5/100 [00:19<06:04,  3.83s/it]

iteration: 500; loss = 7.869128704071045
iteration: 0; loss = 8.663934707641602
iteration: 100; loss = 8.759931564331055
iteration: 200; loss = 8.702604293823242
iteration: 300; loss = 4.867667198181152
iteration: 400; loss = 5.9379167556762695


  6%|▌         | 6/100 [00:23<06:01,  3.84s/it]

iteration: 500; loss = 7.718590259552002
iteration: 0; loss = 8.416864395141602
iteration: 100; loss = 11.096772193908691
iteration: 200; loss = 8.926416397094727
iteration: 300; loss = 4.440793037414551
iteration: 400; loss = 6.315412998199463


  7%|▋         | 7/100 [00:26<05:59,  3.86s/it]

iteration: 500; loss = 7.389350891113281
iteration: 0; loss = 8.050125122070312
iteration: 100; loss = 7.678874969482422
iteration: 200; loss = 8.922239303588867
iteration: 300; loss = 3.969238758087158
iteration: 400; loss = 6.543637275695801


  8%|▊         | 8/100 [00:30<05:58,  3.89s/it]

iteration: 500; loss = 6.475584983825684
iteration: 0; loss = 7.0108489990234375
iteration: 100; loss = 6.38168478012085
iteration: 200; loss = 8.961115837097168
iteration: 300; loss = 3.191039800643921
iteration: 400; loss = 6.8591413497924805


  9%|▉         | 9/100 [00:34<05:54,  3.90s/it]

iteration: 500; loss = 5.360538482666016
iteration: 0; loss = 5.771203517913818
iteration: 100; loss = 4.945820331573486
iteration: 200; loss = 9.13500690460205
iteration: 300; loss = 2.3567795753479004
iteration: 400; loss = 7.028533935546875


 10%|█         | 10/100 [00:38<05:49,  3.88s/it]

iteration: 500; loss = 3.963198661804199
iteration: 0; loss = 4.375500202178955
iteration: 100; loss = 3.294252872467041
iteration: 200; loss = 9.02314281463623
iteration: 300; loss = 1.7167831659317017
iteration: 400; loss = 7.054058074951172


 11%|█         | 11/100 [00:42<05:45,  3.88s/it]

iteration: 500; loss = 2.9288854598999023
iteration: 0; loss = 3.0526437759399414
iteration: 100; loss = 2.660667896270752
iteration: 200; loss = 8.753861427307129
iteration: 300; loss = 1.2726259231567383
iteration: 400; loss = 6.638981819152832


 12%|█▏        | 12/100 [00:46<05:40,  3.87s/it]

iteration: 500; loss = 2.4212327003479004
iteration: 0; loss = 2.4468154907226562
iteration: 100; loss = 2.287660837173462
iteration: 200; loss = 8.56445598602295
iteration: 300; loss = 1.049000859260559
iteration: 400; loss = 6.491892337799072


 13%|█▎        | 13/100 [00:50<05:37,  3.88s/it]

iteration: 500; loss = 2.254035472869873
iteration: 0; loss = 2.2340099811553955
iteration: 100; loss = 2.0648574829101562
iteration: 200; loss = 8.511449813842773
iteration: 300; loss = 0.9562273621559143
iteration: 400; loss = 6.287265777587891


 14%|█▍        | 14/100 [00:54<05:34,  3.89s/it]

iteration: 500; loss = 2.184464454650879
iteration: 0; loss = 2.135382890701294
iteration: 100; loss = 1.93131685256958
iteration: 200; loss = 8.381985664367676
iteration: 300; loss = 0.8499190807342529
iteration: 400; loss = 6.386030197143555


 15%|█▌        | 15/100 [00:58<05:30,  3.89s/it]

iteration: 500; loss = 2.251161575317383
iteration: 0; loss = 2.1765263080596924
iteration: 100; loss = 1.658082127571106
iteration: 200; loss = 8.375616073608398
iteration: 300; loss = 0.8569464683532715
iteration: 400; loss = 6.131328582763672


 16%|█▌        | 16/100 [01:01<05:27,  3.89s/it]

iteration: 500; loss = 2.2458486557006836
iteration: 0; loss = 2.2028536796569824
iteration: 100; loss = 1.6291404962539673
iteration: 200; loss = 8.295654296875
iteration: 300; loss = 0.836912989616394
iteration: 400; loss = 6.476940631866455


 17%|█▋        | 17/100 [01:05<05:22,  3.88s/it]

iteration: 500; loss = 2.2653183937072754
iteration: 0; loss = 2.222468137741089
iteration: 100; loss = 1.7787351608276367
iteration: 200; loss = 8.267036437988281
iteration: 300; loss = 0.8060801029205322
iteration: 400; loss = 6.345046043395996


 18%|█▊        | 18/100 [01:09<05:17,  3.88s/it]

iteration: 500; loss = 2.032024383544922
iteration: 0; loss = 1.9355263710021973
iteration: 100; loss = 1.7555478811264038
iteration: 200; loss = 8.132591247558594
iteration: 300; loss = 0.724175214767456
iteration: 400; loss = 6.342360496520996


 19%|█▉        | 19/100 [01:13<05:14,  3.88s/it]

iteration: 500; loss = 2.16489839553833
iteration: 0; loss = 2.061400890350342
iteration: 100; loss = 1.6041103601455688
iteration: 200; loss = 8.27747631072998
iteration: 300; loss = 0.6633046865463257
iteration: 400; loss = 6.290240287780762


 20%|██        | 20/100 [01:17<05:09,  3.87s/it]

iteration: 500; loss = 2.2036867141723633
iteration: 0; loss = 2.107423782348633
iteration: 100; loss = 1.5277632474899292
iteration: 200; loss = 8.135238647460938
iteration: 300; loss = 0.6450411081314087
iteration: 400; loss = 6.184232234954834


 21%|██        | 21/100 [01:21<05:06,  3.88s/it]

iteration: 500; loss = 2.17785906791687
iteration: 0; loss = 2.0751919746398926
iteration: 100; loss = 1.6716526746749878
iteration: 200; loss = 8.21340274810791
iteration: 300; loss = 0.6206797361373901
iteration: 400; loss = 6.245696067810059


 22%|██▏       | 22/100 [01:25<05:01,  3.87s/it]

iteration: 500; loss = 2.383596420288086
iteration: 0; loss = 2.2958855628967285
iteration: 100; loss = 1.597424864768982
iteration: 200; loss = 8.24795150756836
iteration: 300; loss = 0.5408060550689697
iteration: 400; loss = 6.000288009643555


 23%|██▎       | 23/100 [01:29<04:58,  3.88s/it]

iteration: 500; loss = 2.368208169937134
iteration: 0; loss = 2.2687196731567383
iteration: 100; loss = 1.6061583757400513
iteration: 200; loss = 8.35826301574707
iteration: 300; loss = 0.5173569321632385
iteration: 400; loss = 5.979513168334961


 24%|██▍       | 24/100 [01:32<04:55,  3.88s/it]

iteration: 500; loss = 2.3392558097839355
iteration: 0; loss = 2.238534927368164
iteration: 100; loss = 1.7006912231445312
iteration: 200; loss = 8.521159172058105
iteration: 300; loss = 0.4458984136581421
iteration: 400; loss = 5.66026496887207


 25%|██▌       | 25/100 [01:36<04:52,  3.91s/it]

iteration: 500; loss = 2.3267292976379395
iteration: 0; loss = 2.221757411956787
iteration: 100; loss = 1.6734309196472168
iteration: 200; loss = 8.519609451293945
iteration: 300; loss = 0.48365849256515503
iteration: 400; loss = 5.720831394195557


 26%|██▌       | 26/100 [01:40<04:47,  3.89s/it]

iteration: 500; loss = 2.218323230743408
iteration: 0; loss = 2.0980019569396973
iteration: 100; loss = 1.5693167448043823
iteration: 200; loss = 8.897041320800781
iteration: 300; loss = 0.45329707860946655
iteration: 400; loss = 5.924520015716553


 27%|██▋       | 27/100 [01:44<04:43,  3.88s/it]

iteration: 500; loss = 2.24137544631958
iteration: 0; loss = 2.1119091510772705
iteration: 100; loss = 1.7445149421691895
iteration: 200; loss = 8.863329887390137
iteration: 300; loss = 0.4704750180244446
iteration: 400; loss = 5.239340782165527


 28%|██▊       | 28/100 [01:48<04:39,  3.89s/it]

iteration: 500; loss = 2.255312919616699
iteration: 0; loss = 2.15344500541687
iteration: 100; loss = 1.7391767501831055
iteration: 200; loss = 9.311077117919922
iteration: 300; loss = 0.5450315475463867
iteration: 400; loss = 5.220911502838135


 29%|██▉       | 29/100 [01:52<04:37,  3.91s/it]

iteration: 500; loss = 1.9269769191741943
iteration: 0; loss = 1.8354120254516602
iteration: 100; loss = 1.708784818649292
iteration: 200; loss = 9.868010520935059
iteration: 300; loss = 0.3840988278388977
iteration: 400; loss = 4.727683067321777


 30%|███       | 30/100 [01:56<04:33,  3.90s/it]

iteration: 500; loss = 1.9040393829345703
iteration: 0; loss = 1.7807788848876953
iteration: 100; loss = 1.773486852645874
iteration: 200; loss = 9.966211318969727
iteration: 300; loss = 0.44074547290802
iteration: 400; loss = 4.797552108764648


 31%|███       | 31/100 [02:00<04:27,  3.88s/it]

iteration: 500; loss = 1.88968825340271
iteration: 0; loss = 1.8169686794281006
iteration: 100; loss = 1.6929643154144287
iteration: 200; loss = 10.258719444274902
iteration: 300; loss = 0.46417444944381714
iteration: 400; loss = 4.7477335929870605


 32%|███▏      | 32/100 [02:03<04:22,  3.86s/it]

iteration: 500; loss = 1.8101134300231934
iteration: 0; loss = 1.7377846240997314
iteration: 100; loss = 1.9472274780273438
iteration: 200; loss = 10.415637016296387
iteration: 300; loss = 0.3988122045993805
iteration: 400; loss = 4.79780912399292


 33%|███▎      | 33/100 [02:07<04:18,  3.85s/it]

iteration: 500; loss = 1.7916889190673828
iteration: 0; loss = 1.698393702507019
iteration: 100; loss = 1.6852480173110962
iteration: 200; loss = 10.789766311645508
iteration: 300; loss = 0.29944437742233276
iteration: 400; loss = 4.814255714416504


 34%|███▍      | 34/100 [02:11<04:16,  3.89s/it]

iteration: 500; loss = 1.6647353172302246
iteration: 0; loss = 1.5838909149169922
iteration: 100; loss = 1.8741238117218018
iteration: 200; loss = 10.825170516967773
iteration: 300; loss = 0.3037537634372711
iteration: 400; loss = 4.715598106384277


 35%|███▌      | 35/100 [02:15<04:13,  3.91s/it]

iteration: 500; loss = 1.569090723991394
iteration: 0; loss = 1.4995472431182861
iteration: 100; loss = 1.9873175621032715
iteration: 200; loss = 11.314051628112793
iteration: 300; loss = 0.24737876653671265
iteration: 400; loss = 4.988848686218262


 36%|███▌      | 36/100 [02:19<04:09,  3.89s/it]

iteration: 500; loss = 1.396768569946289
iteration: 0; loss = 1.341522455215454
iteration: 100; loss = 1.968074083328247
iteration: 200; loss = 11.04177474975586
iteration: 300; loss = 0.07307561486959457
iteration: 400; loss = 4.899199485778809


 37%|███▋      | 37/100 [02:23<04:05,  3.90s/it]

iteration: 500; loss = 1.3602477312088013
iteration: 0; loss = 1.29351806640625
iteration: 100; loss = 1.937696933746338
iteration: 200; loss = 10.865859985351562
iteration: 300; loss = 0.18870225548744202
iteration: 400; loss = 5.651426315307617


 38%|███▊      | 38/100 [02:27<04:04,  3.94s/it]

iteration: 500; loss = 1.143810510635376
iteration: 0; loss = 1.0766334533691406
iteration: 100; loss = 1.8835645914077759
iteration: 200; loss = 10.968644142150879
iteration: 300; loss = 0.10381119698286057
iteration: 400; loss = 5.253266334533691


 39%|███▉      | 39/100 [02:31<03:59,  3.92s/it]

iteration: 500; loss = 1.1727163791656494
iteration: 0; loss = 1.1034274101257324
iteration: 100; loss = 2.009164571762085
iteration: 200; loss = 11.167764663696289
iteration: 300; loss = 0.03613678365945816
iteration: 400; loss = 5.468672275543213


 40%|████      | 40/100 [02:35<03:53,  3.89s/it]

iteration: 500; loss = 1.1036512851715088
iteration: 0; loss = 1.0282841920852661
iteration: 100; loss = 2.2225594520568848
iteration: 200; loss = 11.220817565917969
iteration: 300; loss = 0.03474850952625275
iteration: 400; loss = 5.668136119842529


 41%|████      | 41/100 [02:39<03:49,  3.89s/it]

iteration: 500; loss = 1.0813233852386475
iteration: 0; loss = 0.9982860088348389
iteration: 100; loss = 2.335916757583618
iteration: 200; loss = 11.778331756591797
iteration: 300; loss = 0.02597595751285553
iteration: 400; loss = 6.105434894561768


 42%|████▏     | 42/100 [02:42<03:44,  3.87s/it]

iteration: 500; loss = 1.0205752849578857
iteration: 0; loss = 0.9381512999534607
iteration: 100; loss = 2.6216795444488525
iteration: 200; loss = 11.824182510375977
iteration: 300; loss = 0.030441807582974434
iteration: 400; loss = 5.859167098999023


 43%|████▎     | 43/100 [02:46<03:41,  3.88s/it]

iteration: 500; loss = 0.9462982416152954
iteration: 0; loss = 0.8588389158248901
iteration: 100; loss = 2.4937920570373535
iteration: 200; loss = 11.469802856445312
iteration: 300; loss = 0.009451648220419884
iteration: 400; loss = 5.911890983581543


 44%|████▍     | 44/100 [02:50<03:37,  3.88s/it]

iteration: 500; loss = 0.7304590940475464
iteration: 0; loss = 0.651516318321228
iteration: 100; loss = 2.6815969944000244
iteration: 200; loss = 11.606452941894531
iteration: 300; loss = 0.0037793810479342937
iteration: 400; loss = 6.340647220611572


 45%|████▌     | 45/100 [02:54<03:32,  3.86s/it]

iteration: 500; loss = 0.6566365957260132
iteration: 0; loss = 0.5807203054428101
iteration: 100; loss = 2.6407880783081055
iteration: 200; loss = 11.375301361083984
iteration: 300; loss = 0.005803077016025782
iteration: 400; loss = 6.371323585510254


 46%|████▌     | 46/100 [02:58<03:27,  3.84s/it]

iteration: 500; loss = 0.7516423463821411
iteration: 0; loss = 0.6612533330917358
iteration: 100; loss = 2.7354319095611572
iteration: 200; loss = 11.540628433227539
iteration: 300; loss = 0.031105246394872665
iteration: 400; loss = 6.681894302368164


 47%|████▋     | 47/100 [03:02<03:25,  3.87s/it]

iteration: 500; loss = 0.6968153715133667
iteration: 0; loss = 0.5701179504394531
iteration: 100; loss = 2.966048240661621
iteration: 200; loss = 11.661005020141602
iteration: 300; loss = 0.05613144114613533
iteration: 400; loss = 6.714278221130371


 48%|████▊     | 48/100 [03:06<03:23,  3.90s/it]

iteration: 500; loss = 0.667086660861969
iteration: 0; loss = 0.5623563528060913
iteration: 100; loss = 2.731074333190918
iteration: 200; loss = 11.668315887451172
iteration: 300; loss = 0.09074877202510834
iteration: 400; loss = 6.855045318603516


 49%|████▉     | 49/100 [03:10<03:18,  3.89s/it]

iteration: 500; loss = 0.5750209093093872
iteration: 0; loss = 0.49738970398902893
iteration: 100; loss = 2.983428478240967
iteration: 200; loss = 11.775721549987793
iteration: 300; loss = 0.21331867575645447
iteration: 400; loss = 6.849189758300781


 50%|█████     | 50/100 [03:14<03:14,  3.88s/it]

iteration: 500; loss = 0.5946893095970154
iteration: 0; loss = 0.45115548372268677
iteration: 100; loss = 3.0958781242370605
iteration: 200; loss = 11.917928695678711
iteration: 300; loss = 0.20128780603408813
iteration: 400; loss = 6.80536413192749


 51%|█████     | 51/100 [03:17<03:09,  3.87s/it]

iteration: 500; loss = 0.4988340437412262
iteration: 0; loss = 0.36728715896606445
iteration: 100; loss = 3.40805721282959
iteration: 200; loss = 11.950552940368652
iteration: 300; loss = 0.19663292169570923
iteration: 400; loss = 6.7039690017700195


 52%|█████▏    | 52/100 [03:21<03:07,  3.90s/it]

iteration: 500; loss = 0.49745088815689087
iteration: 0; loss = 0.3600567877292633
iteration: 100; loss = 3.424872398376465
iteration: 200; loss = 11.994743347167969
iteration: 300; loss = 0.21268656849861145
iteration: 400; loss = 6.9080095291137695


 53%|█████▎    | 53/100 [03:25<03:04,  3.93s/it]

iteration: 500; loss = 0.436740517616272
iteration: 0; loss = 0.30783915519714355
iteration: 100; loss = 3.642343044281006
iteration: 200; loss = 12.18592643737793
iteration: 300; loss = 0.25019997358322144
iteration: 400; loss = 6.215667724609375


 54%|█████▍    | 54/100 [03:29<02:59,  3.91s/it]

iteration: 500; loss = 0.37448349595069885
iteration: 0; loss = 0.2513349652290344
iteration: 100; loss = 3.7123615741729736
iteration: 200; loss = 12.16972827911377
iteration: 300; loss = 0.21233484148979187
iteration: 400; loss = 6.8337626457214355


 55%|█████▌    | 55/100 [03:33<02:56,  3.91s/it]

iteration: 500; loss = 0.38207465410232544
iteration: 0; loss = 0.25271719694137573
iteration: 100; loss = 3.840796709060669
iteration: 200; loss = 12.236946105957031
iteration: 300; loss = 0.19437971711158752
iteration: 400; loss = 6.899389266967773


 56%|█████▌    | 56/100 [03:37<02:51,  3.90s/it]

iteration: 500; loss = 0.3379564881324768
iteration: 0; loss = 0.21056392788887024
iteration: 100; loss = 3.673654556274414
iteration: 200; loss = 12.229216575622559
iteration: 300; loss = 0.2005370557308197
iteration: 400; loss = 7.118080139160156


 57%|█████▋    | 57/100 [03:41<02:47,  3.89s/it]

iteration: 500; loss = 0.3128776550292969
iteration: 0; loss = 0.19160565733909607
iteration: 100; loss = 3.9372847080230713
iteration: 200; loss = 12.085371971130371
iteration: 300; loss = 0.16628560423851013
iteration: 400; loss = 6.762489318847656


 58%|█████▊    | 58/100 [03:45<02:42,  3.87s/it]

iteration: 500; loss = 0.29650598764419556
iteration: 0; loss = 0.1692863404750824
iteration: 100; loss = 4.068973064422607
iteration: 200; loss = 12.144269943237305
iteration: 300; loss = 0.1523376852273941
iteration: 400; loss = 7.021800994873047


 59%|█████▉    | 59/100 [03:49<02:38,  3.87s/it]

iteration: 500; loss = 0.2560061812400818
iteration: 0; loss = 0.13529321551322937
iteration: 100; loss = 3.8161110877990723
iteration: 200; loss = 12.541552543640137
iteration: 300; loss = 0.14697343111038208
iteration: 400; loss = 7.064723968505859


 60%|██████    | 60/100 [03:52<02:34,  3.85s/it]

iteration: 500; loss = 0.22259238362312317
iteration: 0; loss = 0.10774417966604233
iteration: 100; loss = 3.731062173843384
iteration: 200; loss = 12.459709167480469
iteration: 300; loss = 0.1875934600830078
iteration: 400; loss = 7.508918762207031


 61%|██████    | 61/100 [03:56<02:30,  3.86s/it]

iteration: 500; loss = 0.2408868670463562
iteration: 0; loss = 0.11774246394634247
iteration: 100; loss = 3.6843483448028564
iteration: 200; loss = 12.28622055053711
iteration: 300; loss = 0.15346962213516235
iteration: 400; loss = 7.3636980056762695


 62%|██████▏   | 62/100 [04:00<02:26,  3.85s/it]

iteration: 500; loss = 0.1758301705121994
iteration: 0; loss = 0.07161357998847961
iteration: 100; loss = 3.8486804962158203
iteration: 200; loss = 12.442198753356934
iteration: 300; loss = 0.16910725831985474
iteration: 400; loss = 7.344019889831543


 63%|██████▎   | 63/100 [04:04<02:22,  3.85s/it]

iteration: 500; loss = 0.1519872546195984
iteration: 0; loss = 0.054769594222307205
iteration: 100; loss = 3.8048036098480225
iteration: 200; loss = 12.59544563293457
iteration: 300; loss = 0.16139231622219086
iteration: 400; loss = 7.514104843139648


 64%|██████▍   | 64/100 [04:08<02:18,  3.85s/it]

iteration: 500; loss = 0.16840288043022156
iteration: 0; loss = 0.06102832406759262
iteration: 100; loss = 3.9542574882507324
iteration: 200; loss = 12.659746170043945
iteration: 300; loss = 0.1292160451412201
iteration: 400; loss = 7.588276386260986


 65%|██████▌   | 65/100 [04:12<02:14,  3.85s/it]

iteration: 500; loss = 0.14305312931537628
iteration: 0; loss = 0.04333237558603287
iteration: 100; loss = 3.8866779804229736
iteration: 200; loss = 12.667289733886719
iteration: 300; loss = 0.1009996235370636
iteration: 400; loss = 7.6655964851379395


 66%|██████▌   | 66/100 [04:15<02:10,  3.84s/it]

iteration: 500; loss = 0.1338338702917099
iteration: 0; loss = 0.037391699850559235
iteration: 100; loss = 3.8800933361053467
iteration: 200; loss = 12.73422622680664
iteration: 300; loss = 0.13153252005577087
iteration: 400; loss = 7.603968620300293


 67%|██████▋   | 67/100 [04:19<02:07,  3.86s/it]

iteration: 500; loss = 0.12025253474712372
iteration: 0; loss = 0.029721707105636597
iteration: 100; loss = 3.703580617904663
iteration: 200; loss = 12.73393726348877
iteration: 300; loss = 0.10371480882167816
iteration: 400; loss = 7.638673305511475


 68%|██████▊   | 68/100 [04:23<02:03,  3.86s/it]

iteration: 500; loss = 0.1353093683719635
iteration: 0; loss = 0.03398757800459862
iteration: 100; loss = 3.673828601837158
iteration: 200; loss = 12.830680847167969
iteration: 300; loss = 0.09548327326774597
iteration: 400; loss = 7.590639114379883


 69%|██████▉   | 69/100 [04:27<01:59,  3.85s/it]

iteration: 500; loss = 0.11864344030618668
iteration: 0; loss = 0.02717379666864872
iteration: 100; loss = 3.690321922302246
iteration: 200; loss = 12.832741737365723
iteration: 300; loss = 0.1037289947271347
iteration: 400; loss = 7.720757484436035


 70%|███████   | 70/100 [04:31<01:55,  3.84s/it]

iteration: 500; loss = 0.11358429491519928
iteration: 0; loss = 0.022184422239661217
iteration: 100; loss = 3.7653160095214844
iteration: 200; loss = 12.999534606933594
iteration: 300; loss = 0.1007683277130127
iteration: 400; loss = 7.588344573974609


 71%|███████   | 71/100 [04:35<01:51,  3.83s/it]

iteration: 500; loss = 0.11523805558681488
iteration: 0; loss = 0.023661376908421516
iteration: 100; loss = 3.7590041160583496
iteration: 200; loss = 12.60775375366211
iteration: 300; loss = 0.07995879650115967
iteration: 400; loss = 7.67667293548584


 72%|███████▏  | 72/100 [04:38<01:47,  3.83s/it]

iteration: 500; loss = 0.08349348604679108
iteration: 0; loss = 0.012586266733705997
iteration: 100; loss = 3.5710716247558594
iteration: 200; loss = 12.563961029052734
iteration: 300; loss = 0.0491805337369442
iteration: 400; loss = 7.91780948638916


 73%|███████▎  | 73/100 [04:42<01:43,  3.84s/it]

iteration: 500; loss = 0.09958681464195251
iteration: 0; loss = 0.021198445931077003
iteration: 100; loss = 3.187225341796875
iteration: 200; loss = 13.185932159423828
iteration: 300; loss = 0.05168923735618591
iteration: 400; loss = 7.643224716186523


 74%|███████▍  | 74/100 [04:46<01:41,  3.89s/it]

iteration: 500; loss = 0.07525631040334702
iteration: 0; loss = 0.013396669179201126
iteration: 100; loss = 3.04948091506958
iteration: 200; loss = 13.719514846801758
iteration: 300; loss = 0.07630482316017151
iteration: 400; loss = 7.821937561035156


 75%|███████▌  | 75/100 [04:50<01:37,  3.90s/it]

iteration: 500; loss = 0.06262525916099548
iteration: 0; loss = 0.008461183868348598
iteration: 100; loss = 3.2706804275512695
iteration: 200; loss = 13.697368621826172
iteration: 300; loss = 0.05290762335062027
iteration: 400; loss = 7.869093418121338


 76%|███████▌  | 76/100 [04:54<01:33,  3.88s/it]

iteration: 500; loss = 0.07093361765146255
iteration: 0; loss = 0.01137813925743103
iteration: 100; loss = 3.143610715866089
iteration: 200; loss = 13.594944953918457
iteration: 300; loss = 0.039649542421102524
iteration: 400; loss = 8.277748107910156


 77%|███████▋  | 77/100 [04:58<01:29,  3.89s/it]

iteration: 500; loss = 0.05696626007556915
iteration: 0; loss = 0.007130442652851343
iteration: 100; loss = 3.0460731983184814
iteration: 200; loss = 13.748403549194336
iteration: 300; loss = 0.07912320643663406
iteration: 400; loss = 8.074652671813965


 78%|███████▊  | 78/100 [05:02<01:25,  3.90s/it]

iteration: 500; loss = 0.06403346359729767
iteration: 0; loss = 0.009208917617797852
iteration: 100; loss = 2.8729326725006104
iteration: 200; loss = 13.782964706420898
iteration: 300; loss = 0.03313906863331795
iteration: 400; loss = 7.939350605010986


 79%|███████▉  | 79/100 [05:06<01:22,  3.92s/it]

iteration: 500; loss = 0.03403111547231674
iteration: 0; loss = 0.006465349346399307
iteration: 100; loss = 2.889078140258789
iteration: 200; loss = 13.264022827148438
iteration: 300; loss = 0.02529199793934822
iteration: 400; loss = 7.786128520965576


 80%|████████  | 80/100 [05:10<01:17,  3.90s/it]

iteration: 500; loss = 0.04752393439412117
iteration: 0; loss = 0.00449218088760972
iteration: 100; loss = 2.744858741760254
iteration: 200; loss = 13.449665069580078
iteration: 300; loss = 0.02309071086347103


 80%|████████  | 80/100 [05:13<01:18,  3.92s/it]


KeyboardInterrupt: 

In [6]:
# test accuracy
# we use relative error to measure

mae_losses = []

for i, data in enumerate(dataset):
    if i <= num_training:
        continue

    g, congestion = data
    g = g.to(device)

    pred = model(g)
    loss = F.mse_loss(pred, torch.log(congestion+1))
    pred_congestion = torch.exp(pred) - 1

    print(f"iteration: {i}; loss = {loss.item()}")
    print(f"pred = {pred_congestion.detach().numpy()}")
    print(f"congestion = {congestion.numpy()}")

    mae_losses.append(F.l1_loss(pred_congestion, congestion))

print("*" * 30)
mae_losses = torch.stack(mae_losses).mean()
print(f"average mae loss: {mae_losses}")



iteration: 504; loss = 0.6641845703125
pred = [479.1408  428.94254]
congestion = [1118.0714  939.2857]
iteration: 505; loss = 11.055566787719727
pred = [40.299778 16.727673]
congestion = [0. 0.]
iteration: 506; loss = 10.733789443969727
pred = [31.033901 30.370367]
congestion = [838.12085 838.4623 ]
iteration: 507; loss = 14.412928581237793
pred = [-0.33196092 -0.5860933 ]
congestion = [132.3   0. ]
iteration: 508; loss = 1.6523786783218384
pred = [2.1711516 3.5320368]
congestion = [0.15441176 0.        ]
iteration: 509; loss = 15.886566162109375
pred = [27.665281 12.184709]
congestion = [4305.2856    0.    ]
iteration: 510; loss = 0.40594416856765747
pred = [-0.35482025 -0.5449295 ]
congestion = [0. 0.]
iteration: 511; loss = 1.9137054681777954
pred = [34.94242  14.880779]
congestion = [88.588036 88.588036]
iteration: 512; loss = 4.459554672241211
pred = [35.454514 23.467356]
congestion = [247.46443 240.16798]
iteration: 513; loss = 3.804619789123535
pred = [7.46966   4.7252836]
conge