In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from dataset.dataset import NoCDataset
from model.vanilla import VanillaModel

from tqdm import tqdm

In [2]:
dataset = NoCDataset()
print(f"#Samples = {len(dataset)}")
print(dataset[0])

# TODO: use advanced dgl.dataloader instead of manually dividing dataset
# Split threshold for the manual 90/10 train/test division: the loops below
# train on samples with index <= num_training and evaluate on the rest.
# Integer arithmetic keeps this an exact int (usable for slicing/range/Subset)
# instead of a float such as 503.1.
num_training = len(dataset) * 9 // 10


#Samples = 559
(Graph(num_nodes={'packet': 10, 'router': 12},
      num_edges={('packet', 'pass', 'router'): 46, ('router', 'backpressure', 'router'): 16, ('router', 'connect', 'router'): 16, ('router', 'transfer', 'packet'): 46},
      metagraph=[('packet', 'router', 'pass'), ('router', 'router', 'backpressure'), ('router', 'router', 'connect'), ('router', 'packet', 'transfer')]), tensor([1057.8928,  887.8571]))


In [3]:
device = "cpu"
# device = "cuda:0"

model = VanillaModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Loss of every 100th training sample, accumulated over all epochs (plotted below).
losses = []

epoches = 100
for e in tqdm(range(epoches)):
    for i, data in enumerate(dataset):
        # Samples past the split threshold are held out for evaluation.
        if i > num_training:
            break

        g, congestion = data
        g = g.to(device)
        # The target has to live on the same device as the prediction;
        # otherwise mse_loss raises a device mismatch once device is a GPU.
        congestion = congestion.to(device)

        pred = model(g)
        # The model is trained against log(1 + congestion); the +1 keeps the
        # target finite for samples whose congestion is 0.
        loss = F.mse_loss(pred, torch.log(congestion + 1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            # Store a plain float so no tensor is kept alive in the history.
            loss_value = loss.item()
            print(f"iteration: {i}; loss = {loss_value}")
            losses.append(loss_value)

plt.ylim(0, 20)
plt.plot(losses)
plt.xlabel("logged step (every 100th training sample)")
plt.ylabel("MSE loss on log(1 + congestion)")


  0%|          | 0/100 [00:00<?, ?it/s]

iteration: 0; loss = 74.66802978515625
iteration: 100; loss = 32.879417419433594
iteration: 200; loss = 14.707178115844727
iteration: 300; loss = 6.074963569641113
iteration: 400; loss = 6.498138427734375


  1%|          | 1/100 [00:02<03:33,  2.16s/it]

iteration: 500; loss = 7.965472221374512
iteration: 0; loss = 8.72354793548584
iteration: 100; loss = 13.729049682617188
iteration: 200; loss = 7.677690505981445
iteration: 300; loss = 5.306740760803223
iteration: 400; loss = 5.947522163391113


  2%|▏         | 2/100 [00:04<03:32,  2.16s/it]

iteration: 500; loss = 8.174074172973633
iteration: 0; loss = 9.350984573364258
iteration: 100; loss = 12.21688175201416
iteration: 200; loss = 7.843695640563965
iteration: 300; loss = 5.785938262939453
iteration: 400; loss = 5.68498420715332


  3%|▎         | 3/100 [00:06<03:31,  2.18s/it]

iteration: 500; loss = 8.220603942871094
iteration: 0; loss = 9.39470100402832
iteration: 100; loss = 11.13741683959961
iteration: 200; loss = 7.850262641906738
iteration: 300; loss = 6.107549667358398
iteration: 400; loss = 5.713761806488037


  4%|▍         | 4/100 [00:08<03:27,  2.16s/it]

iteration: 500; loss = 8.182330131530762
iteration: 0; loss = 9.40207290649414
iteration: 100; loss = 10.653310775756836
iteration: 200; loss = 7.69388484954834
iteration: 300; loss = 6.081240653991699
iteration: 400; loss = 5.890822410583496


  5%|▌         | 5/100 [00:10<03:24,  2.15s/it]

iteration: 500; loss = 8.016251564025879
iteration: 0; loss = 9.155292510986328
iteration: 100; loss = 10.127199172973633
iteration: 200; loss = 7.596438407897949
iteration: 300; loss = 6.101991653442383
iteration: 400; loss = 6.068156719207764


  6%|▌         | 6/100 [00:12<03:22,  2.16s/it]

iteration: 500; loss = 7.977252960205078
iteration: 0; loss = 9.049442291259766
iteration: 100; loss = 9.792426109313965
iteration: 200; loss = 7.629509925842285
iteration: 300; loss = 5.650356769561768
iteration: 400; loss = 6.347714424133301


  7%|▋         | 7/100 [00:15<03:18,  2.13s/it]

iteration: 500; loss = 8.010160446166992
iteration: 0; loss = 9.100605964660645
iteration: 100; loss = 9.779108047485352
iteration: 200; loss = 7.741888046264648
iteration: 300; loss = 5.680500030517578
iteration: 400; loss = 6.305400371551514


  8%|▊         | 8/100 [00:17<03:20,  2.18s/it]

iteration: 500; loss = 7.421998023986816
iteration: 0; loss = 8.350349426269531
iteration: 100; loss = 8.64661979675293
iteration: 200; loss = 7.884604454040527
iteration: 300; loss = 5.264969825744629
iteration: 400; loss = 6.276812553405762


  9%|▉         | 9/100 [00:19<03:18,  2.18s/it]

iteration: 500; loss = 6.621750831604004
iteration: 0; loss = 7.436712741851807
iteration: 100; loss = 7.157182693481445
iteration: 200; loss = 8.039262771606445
iteration: 300; loss = 4.064287185668945
iteration: 400; loss = 6.218685150146484


 10%|█         | 10/100 [00:21<03:15,  2.17s/it]

iteration: 500; loss = 5.112462043762207
iteration: 0; loss = 5.739530563354492
iteration: 100; loss = 5.103909969329834
iteration: 200; loss = 7.938422203063965
iteration: 300; loss = 2.8777542114257812
iteration: 400; loss = 6.356319427490234


 11%|█         | 11/100 [00:23<03:13,  2.18s/it]

iteration: 500; loss = 3.470053195953369
iteration: 0; loss = 3.81467604637146
iteration: 100; loss = 3.476485252380371
iteration: 200; loss = 7.836202621459961
iteration: 300; loss = 1.9684264659881592
iteration: 400; loss = 6.287155628204346


 12%|█▏        | 12/100 [00:26<03:13,  2.20s/it]

iteration: 500; loss = 2.7794857025146484
iteration: 0; loss = 2.976999282836914
iteration: 100; loss = 2.4467835426330566
iteration: 200; loss = 7.701805114746094
iteration: 300; loss = 1.511405348777771
iteration: 400; loss = 6.068869113922119


 13%|█▎        | 13/100 [00:28<03:11,  2.20s/it]

iteration: 500; loss = 2.173912286758423
iteration: 0; loss = 2.280041217803955
iteration: 100; loss = 2.068937301635742
iteration: 200; loss = 7.574592590332031
iteration: 300; loss = 1.2386897802352905
iteration: 400; loss = 5.9718732833862305


 14%|█▍        | 14/100 [00:30<03:09,  2.20s/it]

iteration: 500; loss = 2.078094959259033
iteration: 0; loss = 2.1419665813446045
iteration: 100; loss = 1.874377965927124
iteration: 200; loss = 7.661510944366455
iteration: 300; loss = 1.3182251453399658
iteration: 400; loss = 6.087156295776367


 15%|█▌        | 15/100 [00:32<03:06,  2.19s/it]

iteration: 500; loss = 2.3360061645507812
iteration: 0; loss = 2.390929698944092
iteration: 100; loss = 1.74078369140625
iteration: 200; loss = 7.90366268157959
iteration: 300; loss = 1.2211347818374634
iteration: 400; loss = 6.250230312347412


 16%|█▌        | 16/100 [00:34<03:02,  2.17s/it]

iteration: 500; loss = 1.968357801437378
iteration: 0; loss = 2.00325870513916
iteration: 100; loss = 1.6264803409576416
iteration: 200; loss = 7.707530975341797
iteration: 300; loss = 1.2155730724334717
iteration: 400; loss = 6.507170677185059


 17%|█▋        | 17/100 [00:37<03:01,  2.19s/it]

iteration: 500; loss = 2.2115156650543213
iteration: 0; loss = 2.245647668838501
iteration: 100; loss = 1.6110703945159912
iteration: 200; loss = 7.751492500305176
iteration: 300; loss = 1.137721300125122
iteration: 400; loss = 6.23148250579834


 18%|█▊        | 18/100 [00:39<03:00,  2.20s/it]

iteration: 500; loss = 2.1680097579956055
iteration: 0; loss = 2.192265748977661
iteration: 100; loss = 1.5097343921661377
iteration: 200; loss = 7.790816783905029
iteration: 300; loss = 1.156538486480713
iteration: 400; loss = 6.023856163024902


 19%|█▉        | 19/100 [00:41<02:57,  2.19s/it]

iteration: 500; loss = 1.937652349472046
iteration: 0; loss = 1.9562559127807617
iteration: 100; loss = 1.5520813465118408
iteration: 200; loss = 7.787765026092529
iteration: 300; loss = 1.036571979522705
iteration: 400; loss = 6.389894485473633


 20%|██        | 20/100 [00:43<02:55,  2.19s/it]

iteration: 500; loss = 2.066721200942993
iteration: 0; loss = 2.084704875946045
iteration: 100; loss = 1.4634582996368408
iteration: 200; loss = 8.062045097351074
iteration: 300; loss = 1.2401000261306763
iteration: 400; loss = 6.627157211303711


 21%|██        | 21/100 [00:45<02:52,  2.18s/it]

iteration: 500; loss = 1.9854164123535156
iteration: 0; loss = 2.008690118789673
iteration: 100; loss = 1.3510832786560059
iteration: 200; loss = 8.117444038391113
iteration: 300; loss = 1.150022029876709
iteration: 400; loss = 6.21605110168457


 22%|██▏       | 22/100 [00:47<02:50,  2.19s/it]

iteration: 500; loss = 1.9478216171264648
iteration: 0; loss = 1.936093807220459
iteration: 100; loss = 1.3057258129119873
iteration: 200; loss = 8.459281921386719
iteration: 300; loss = 0.9973969459533691
iteration: 400; loss = 6.002358436584473


 23%|██▎       | 23/100 [00:50<02:48,  2.18s/it]

iteration: 500; loss = 2.158632755279541
iteration: 0; loss = 2.215073823928833
iteration: 100; loss = 1.20076584815979
iteration: 200; loss = 8.7433500289917
iteration: 300; loss = 0.8980568647384644
iteration: 400; loss = 5.4437408447265625


 24%|██▍       | 24/100 [00:52<02:46,  2.19s/it]

iteration: 500; loss = 2.135741949081421
iteration: 0; loss = 2.094743013381958
iteration: 100; loss = 1.1810176372528076
iteration: 200; loss = 8.38262939453125
iteration: 300; loss = 1.2701270580291748
iteration: 400; loss = 4.808590412139893


 25%|██▌       | 25/100 [00:54<02:45,  2.20s/it]

iteration: 500; loss = 2.222769260406494
iteration: 0; loss = 2.202582359313965
iteration: 100; loss = 1.1632530689239502
iteration: 200; loss = 8.519831657409668
iteration: 300; loss = 1.416170597076416
iteration: 400; loss = 5.1512274742126465


 26%|██▌       | 26/100 [00:56<02:43,  2.21s/it]

iteration: 500; loss = 2.0890817642211914
iteration: 0; loss = 2.034370183944702
iteration: 100; loss = 1.1566789150238037
iteration: 200; loss = 9.094696044921875
iteration: 300; loss = 0.9068517684936523
iteration: 400; loss = 5.672717094421387


 27%|██▋       | 27/100 [00:58<02:40,  2.20s/it]

iteration: 500; loss = 2.1098058223724365
iteration: 0; loss = 2.0524349212646484
iteration: 100; loss = 1.1371124982833862
iteration: 200; loss = 9.038568496704102
iteration: 300; loss = 0.7593222856521606
iteration: 400; loss = 4.681394577026367


 28%|██▊       | 28/100 [01:01<02:38,  2.20s/it]

iteration: 500; loss = 1.5369555950164795
iteration: 0; loss = 1.463232159614563
iteration: 100; loss = 1.269476294517517
iteration: 200; loss = 9.500162124633789
iteration: 300; loss = 0.884758472442627
iteration: 400; loss = 4.675144195556641


 29%|██▉       | 29/100 [01:03<02:36,  2.20s/it]

iteration: 500; loss = 1.619052767753601
iteration: 0; loss = 1.5065772533416748
iteration: 100; loss = 1.1818861961364746
iteration: 200; loss = 8.999977111816406
iteration: 300; loss = 0.7212434411048889
iteration: 400; loss = 5.803919792175293


 30%|███       | 30/100 [01:05<02:33,  2.20s/it]

iteration: 500; loss = 1.55571448802948
iteration: 0; loss = 1.4525847434997559
iteration: 100; loss = 1.2500709295272827
iteration: 200; loss = 9.362701416015625
iteration: 300; loss = 0.787432074546814
iteration: 400; loss = 4.577017307281494


 31%|███       | 31/100 [01:07<02:29,  2.17s/it]

iteration: 500; loss = 1.4769532680511475
iteration: 0; loss = 1.3258411884307861
iteration: 100; loss = 1.2902731895446777


 31%|███       | 31/100 [01:08<02:32,  2.21s/it]

iteration: 200; loss = 9.56175422668457





KeyboardInterrupt: 

In [4]:
# test accuracy
# we report the mean absolute error (MAE) between the predicted and the true
# congestion, after mapping the prediction back from log(1 + x) space

per_sample_mae = []

# Evaluation only: switch the model to inference mode and disable autograd so
# that no computation graph is built or retained by the stored losses.
model.eval()
with torch.no_grad():
    for i, data in enumerate(dataset):
        # Skip the samples that were used for training.
        if i <= num_training:
            continue

        g, congestion = data
        g = g.to(device)
        # Keep the target on the same device as the prediction.
        congestion = congestion.to(device)

        pred = model(g)
        loss = F.mse_loss(pred, torch.log(congestion + 1))
        # Invert the log(1 + x) transform applied to the training target.
        pred_congestion = torch.exp(pred) - 1

        print(f"iteration: {i}; loss = {loss.item()}")
        # .cpu() makes the numpy conversion work for GPU tensors as well.
        print(f"pred = {pred_congestion.cpu().numpy()}")
        print(f"congestion = {congestion.cpu().numpy()}")

        per_sample_mae.append(F.l1_loss(pred_congestion, congestion))

print("*" * 30)
# mae_losses ends up as a scalar tensor: the mean of the per-sample MAE values.
mae_losses = torch.stack(per_sample_mae).mean()
print(f"average mae loss: {mae_losses}")



iteration: 504; loss = 2.2492332458496094
pred = [224.05904 233.71704]
congestion = [1118.0714  939.2857]
iteration: 505; loss = 4.427812099456787
pred = [10.185406   4.6935515]
congestion = [0. 0.]
iteration: 506; loss = 3.8778021335601807
pred = [112.22044 120.26068]
congestion = [838.12085 838.4623 ]
iteration: 507; loss = 10.677654266357422
pred = [ 0.35073853 -0.4048471 ]
congestion = [132.3   0. ]
iteration: 508; loss = 24.919662475585938
pred = [149.622   164.65103]
congestion = [0.15441176 0.        ]
iteration: 509; loss = 16.004436492919922
pred = [222.99495 123.44561]
congestion = [4305.2856    0.    ]
iteration: 510; loss = 11.50244426727295
pred = [40.69868  19.383211]
congestion = [0. 0.]
iteration: 511; loss = 6.040027618408203
pred = [9.872671 4.655108]
congestion = [88.588036 88.588036]
iteration: 512; loss = 2.320106029510498
pred = [51.274147 53.530777]
congestion = [247.46443 240.16798]
iteration: 513; loss = 4.5643157958984375
pred = [8.413722 6.577396]
congestion 