In [1]:
from entsoe import load_data
import numpy as np
from torch_geometric_temporal.nn.attention.stgcn import STConv
import torch
import torch.nn as nn
from tqdm import tqdm

In [2]:
data = load_data()

torch.Size([43729, 10, 24]) torch.Size([43729, 2, 32]) torch.Size([43729, 32, 1]) torch.Size([43729, 32, 1])


In [None]:
class FullyConnLayer(nn.Module):
    def __init__(self, c):
        super(FullyConnLayer, self).__init__()
        self.conv = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        return self.conv(x)


# adapted from Hao Wei
class OutputLayer(nn.Module):
    def __init__(self, c, T, n):
        super(OutputLayer, self).__init__()
        self.tconv1 = nn.Conv2d(c, c, (T, 1), 1, dilation=1, padding=(0, 0))
        self.ln = nn.LayerNorm([n, c])
        self.tconv2 = nn.Conv2d(c, c, (1, 1), 1, dilation=1, padding=(0, 0))
        self.fc = FullyConnLayer(c)

    def forward(self, x):
        x_t1 = self.tconv1(x)
        x_ln = self.ln(x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x_t2 = self.tconv2(x_ln)
        return self.fc(x_t2)


class Model(nn.Module):
    def __init__(self, window_size, num_nodes):
        super(Model, self).__init__()
        self.stconv1 = STConv(
            num_nodes=num_nodes,
            in_channels=24,
            hidden_channels=16,
            out_channels=64,
            kernel_size=3,
            K=1,
        )
        self.stconv2 = STConv(
            num_nodes=num_nodes,
            in_channels=64,
            hidden_channels=16,
            out_channels=64,
            kernel_size=3,
            K=1,
        )
        # window_size - 2 * num_layers * (kernel_size - 1) = 24 - 2 * 2 * (3 - 1) = 16
        T = window_size - 2 * 2 * (3 - 1)
        self.output_layer = OutputLayer(64, T, 10)

    def forward(self, x, edge_index, edge_weight):
        x = self.stconv1(x, edge_index, edge_weight)
        x = self.stconv2(x, edge_index, edge_weight)
        x = x.permute(0, 3, 1, 2)
        x = self.output_layer(x)
        x = x.permute(0, 2, 3, 1)
        return x

In [12]:
class RMSELoss(torch.nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()

    def forward(self, x, y):
        criterion = nn.MSELoss()
        loss = torch.sqrt(criterion(x, y))
        return loss

In [14]:
# len_data = len(data)
len_data = 730
window_size = 24
future = 1
model = Model(window_size, 10)
model.train()
edge_index = data[0].edge_index  # static graph
# print(edge_index.shape)
criteron_mse = nn.MSELoss()
criteron_mae = nn.L1Loss()
criteron_rmse = RMSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(100):
    epoch_loss_mse = 0
    epoch_loss_mae = 0
    epoch_loss_rmse = 0
    for i in tqdm(range(0, len_data - window_size - future)):
        optimizer.zero_grad()
        window_data = data[i : i + window_size]
        X = [d.x for d in window_data]
        X = torch.stack(X, dim=0)
        # Add a 1 dimension in front as batch dimension
        X = X.unsqueeze(0)
        # print(X.shape)
        edge_weights = [d.edge_attr for d in window_data]
        avg_edge_weight = np.mean(edge_weights, axis=0)
        edge_weights = torch.tensor(avg_edge_weight, dtype=torch.float32)

        y = data[i + window_size + future - 1].x[:, 0]  # first feature of a node is DAP
        # print(y.shape)
        y_hat = model(X, edge_index, edge_weights)
        # print(y_hat)
        # print(y_hat.shape)
        y_hat = y_hat.squeeze(0, 1, 3)
        # print(y_hat.shape)
        loss_mse = criteron_mse(y_hat, y)
        loss_mae = criteron_mae(y_hat, y)
        loss_rmse = criteron_rmse(y_hat, y)

        epoch_loss_mse += loss_mse.item()
        epoch_loss_mae += loss_mae.item()
        epoch_loss_rmse += loss_rmse.item()
        loss_mse.backward()
        optimizer.step()
    print(f"Epoch {epoch}")
    print(epoch_loss_mse / (len_data - window_size - 1))
    print(epoch_loss_mae / (len_data - window_size - 1))
    print(epoch_loss_rmse / (len_data - window_size - 1))

100%|██████████| 705/705 [00:11<00:00, 62.59it/s]


Epoch 0
0.4651564080413457
0.41686361099177216
0.5418906099724431


100%|██████████| 705/705 [00:11<00:00, 63.92it/s]


Epoch 1
0.328510635304562
0.3586041071524857
0.4662670336610882


100%|██████████| 705/705 [00:10<00:00, 64.22it/s]


Epoch 2
0.2883502390255478
0.3385823177424728
0.44215883043006804


100%|██████████| 705/705 [00:10<00:00, 64.14it/s]


Epoch 3
0.26086981131080933
0.32050047128441483
0.42158352863175647


100%|██████████| 705/705 [00:10<00:00, 64.28it/s]


Epoch 4
0.2420447624849935
0.31138869859013996
0.41000772784364986


100%|██████████| 705/705 [00:10<00:00, 64.79it/s]


Epoch 5
0.2340144933432552
0.30273689357101496
0.4006259535644071


100%|██████████| 705/705 [00:10<00:00, 64.95it/s]


Epoch 6
0.22534082337606248
0.29930796746877913
0.39498217559241233


100%|██████████| 705/705 [00:10<00:00, 65.09it/s]


Epoch 7
0.21858872190804768
0.2949796254647539
0.38969945502830733


100%|██████████| 705/705 [00:10<00:00, 64.85it/s]


Epoch 8
0.21430887482228114
0.29307285196920657
0.3865715480865316


100%|██████████| 705/705 [00:10<00:00, 65.06it/s]


Epoch 9
0.20832057470217982
0.2864161861967956
0.3808286116689655


100%|██████████| 705/705 [00:10<00:00, 64.73it/s]


Epoch 10
0.20562951533348425
0.284603876615247
0.3770460229076392


100%|██████████| 705/705 [00:10<00:00, 65.52it/s]


Epoch 11
0.19990270809032656
0.2793612911557475
0.3719758002259207


100%|██████████| 705/705 [00:10<00:00, 65.33it/s]


Epoch 12
0.19500218527523655
0.27652578553092394
0.36798695808818155


100%|██████████| 705/705 [00:10<00:00, 65.66it/s]


Epoch 13
0.18921231267815258
0.27215910652216446
0.36235530242218195


100%|██████████| 705/705 [00:10<00:00, 65.50it/s]


Epoch 14
0.18948326433103876
0.2731287543919492
0.3624985419055249


100%|██████████| 705/705 [00:10<00:00, 65.27it/s]


Epoch 15
0.18419810559944058
0.26856981972430616
0.3586002022873425


100%|██████████| 705/705 [00:10<00:00, 65.53it/s]


Epoch 16
0.1811256961124216
0.2657665762706851
0.354303318515737


100%|██████████| 705/705 [00:10<00:00, 65.23it/s]


Epoch 17
0.17851688210647684
0.2641184505841411
0.3513310448180699


100%|██████████| 705/705 [00:10<00:00, 66.22it/s]


Epoch 18
0.1732006325829008
0.2624959357981141
0.34845955971073594


100%|██████████| 705/705 [00:10<00:00, 65.87it/s]


Epoch 19
0.16719603932310714
0.259276419349596
0.3428864759756318


100%|██████████| 705/705 [00:10<00:00, 65.56it/s]


Epoch 20
0.1705093072377251
0.26287098037647016
0.347503806566093


100%|██████████| 705/705 [00:10<00:00, 65.70it/s]


Epoch 21
0.16734973945595483
0.258061453985407
0.3424915741627098


100%|██████████| 705/705 [00:10<00:00, 65.87it/s]


Epoch 22
0.1593165763749589
0.2531305442993523
0.33588887467346296


100%|██████████| 705/705 [00:11<00:00, 59.37it/s]


Epoch 23
0.16462364161968338
0.2578886071101148
0.34166205674409866


100%|██████████| 705/705 [00:11<00:00, 60.29it/s]


Epoch 24
0.1598681131235806
0.25537815217113663
0.3383085430778087


100%|██████████| 705/705 [00:11<00:00, 62.36it/s]


Epoch 25
0.1508971453571679
0.24850018817587946
0.3306150624620999


100%|██████████| 705/705 [00:11<00:00, 62.19it/s]


Epoch 26
0.15349411232568694
0.2502057119992608
0.3316282274769553


100%|██████████| 705/705 [00:11<00:00, 63.04it/s]


Epoch 27
0.1508787900156586
0.24950726457098696
0.3306027957839323


100%|██████████| 705/705 [00:11<00:00, 62.54it/s]


Epoch 28
0.1500255690920596
0.24973592724360472
0.3287900951420162


100%|██████████| 705/705 [00:11<00:00, 62.63it/s]


Epoch 29
0.14004239757736842
0.2425646296206941
0.3216882123909098


100%|██████████| 705/705 [00:11<00:00, 62.84it/s]


Epoch 30
0.1436790132571461
0.24553247275927387
0.3249793494529758


100%|██████████| 705/705 [00:11<00:00, 63.00it/s]


Epoch 31
0.14049760376942072
0.24535512799489584
0.32370811081524437


100%|██████████| 705/705 [00:11<00:00, 62.60it/s]


Epoch 32
0.1447413946726485
0.24605564893560206
0.32525072554324536


100%|██████████| 705/705 [00:11<00:00, 62.11it/s]


Epoch 33
0.13645389131897837
0.24137630405049798
0.3189098134953925


100%|██████████| 705/705 [00:11<00:00, 62.45it/s]


Epoch 34
0.1460320331820069
0.24614053836739655
0.3254952939108331


100%|██████████| 705/705 [00:11<00:00, 62.97it/s]


Epoch 35
0.14372976588626915
0.2456398256492953
0.3256773710250854


100%|██████████| 705/705 [00:11<00:00, 62.86it/s]


Epoch 36
0.13402635401101612
0.2415128305237344
0.3190703419717491


100%|██████████| 705/705 [00:11<00:00, 62.97it/s]


Epoch 37
0.13195348990988648
0.240431327087448
0.3173540712354031


100%|██████████| 705/705 [00:11<00:00, 62.74it/s]


Epoch 38
0.14509837506735263
0.24546818908647441
0.325765867341072


100%|██████████| 705/705 [00:11<00:00, 62.29it/s]


Epoch 39
0.12950020913061097
0.23650530297177058
0.313151641824144


100%|██████████| 705/705 [00:11<00:00, 61.11it/s]


Epoch 40
0.1270471750150917
0.2377388467497014
0.3127177580768335


100%|██████████| 705/705 [00:10<00:00, 64.80it/s]


Epoch 41
0.12443353916676914
0.23153606200366156
0.3063119554614767


100%|██████████| 705/705 [00:10<00:00, 65.18it/s]


Epoch 42
0.12132508610210098
0.2307315170764923
0.3041707808138631


100%|██████████| 705/705 [00:10<00:00, 65.23it/s]


Epoch 43
0.1317619024921235
0.23494155538525988
0.31138318161486733


100%|██████████| 705/705 [00:10<00:00, 65.11it/s]


Epoch 44
0.11664793831197189
0.230270821527175
0.30029596064953096


100%|██████████| 705/705 [00:10<00:00, 65.44it/s]


Epoch 45
0.11435141617445448
0.2263129898419617
0.29637766036581487


100%|██████████| 705/705 [00:10<00:00, 65.36it/s]


Epoch 46
0.12086232736680314
0.229942117821663
0.30483317396319504


100%|██████████| 705/705 [00:10<00:00, 65.22it/s]


Epoch 47
0.1356898557100815
0.24095471591180098
0.3203052649208417


100%|██████████| 705/705 [00:10<00:00, 65.24it/s]


Epoch 48
0.10748153872497326
0.2238207536122055
0.29049911212625235


100%|██████████| 705/705 [00:10<00:00, 65.27it/s]


Epoch 49
0.10864077103753568
0.22289477028216875
0.29046479804401704


100%|██████████| 705/705 [00:10<00:00, 65.17it/s]


Epoch 50
0.11063548008455558
0.22361432379441903
0.2925961762455338


100%|██████████| 705/705 [00:10<00:00, 65.38it/s]


Epoch 51
0.11157735901792262
0.22186793984039455
0.2910456263638557


100%|██████████| 705/705 [00:10<00:00, 65.65it/s]


Epoch 52
0.107327906298286
0.21958493035524448
0.2866466281314691


100%|██████████| 705/705 [00:10<00:00, 65.48it/s]


Epoch 53
0.10768923694532075
0.22124251297391054
0.2887636069710373


100%|██████████| 705/705 [00:10<00:00, 64.68it/s]


Epoch 54
0.09685703032445295
0.21448936447606864
0.2773959132690802


100%|██████████| 705/705 [00:10<00:00, 64.97it/s]


Epoch 55
0.11376260796112689
0.22361777229829038
0.2905046650915281


100%|██████████| 705/705 [00:10<00:00, 65.31it/s]


Epoch 56
0.10497177504861714
0.21732624272505444
0.28619693424684783


100%|██████████| 705/705 [00:10<00:00, 65.45it/s]


Epoch 57
0.10736948744730747
0.22058885397640526
0.28711355751287854


100%|██████████| 705/705 [00:10<00:00, 65.13it/s]


Epoch 58
0.09981695144232494
0.2161337159838237
0.27976801316683175


100%|██████████| 705/705 [00:10<00:00, 65.09it/s]


Epoch 59
0.09079559083156129
0.2106086292799483
0.27018164657320537


100%|██████████| 705/705 [00:10<00:00, 65.28it/s]


Epoch 60
0.09291309469135095
0.2098075176331591
0.2712793998473079


100%|██████████| 705/705 [00:10<00:00, 65.18it/s]


Epoch 61
0.08914140495537708
0.2074724351630566
0.26757983966922083


100%|██████████| 705/705 [00:10<00:00, 65.23it/s]


Epoch 62
0.08927467594729037
0.20846319948422148
0.2685418295205062


100%|██████████| 705/705 [00:10<00:00, 65.37it/s]


Epoch 63
0.08986364302311278
0.20901207857309503
0.26800220966973204


100%|██████████| 705/705 [00:10<00:00, 65.21it/s]


Epoch 64
0.08747161628401026
0.2046323298139775
0.26491363441690485


100%|██████████| 705/705 [00:10<00:00, 65.09it/s]


Epoch 65
0.09056052024284682
0.20651556060035178
0.26763830395243693


100%|██████████| 705/705 [00:10<00:00, 65.08it/s]


Epoch 66
0.08884262999598967
0.20390931720852007
0.2645927150938528


100%|██████████| 705/705 [00:10<00:00, 65.35it/s]


Epoch 67
0.08725983444271042
0.20217473199392888
0.26235197647243524


100%|██████████| 705/705 [00:10<00:00, 65.16it/s]


Epoch 68
0.08751816915031126
0.20537512047705075
0.2650428906188789


100%|██████████| 705/705 [00:10<00:00, 65.37it/s]


Epoch 69
0.08382895744953912
0.20270728502713198
0.26090342619106277


100%|██████████| 705/705 [00:10<00:00, 65.18it/s]


Epoch 70
0.08706468416655318
0.2029658447343407
0.26397443934746667


100%|██████████| 705/705 [00:10<00:00, 65.22it/s]


Epoch 71
0.07975246856015519
0.19756151348986525
0.2534412738701976


100%|██████████| 705/705 [00:10<00:00, 65.45it/s]


Epoch 72
0.07814785401594448
0.19597635496909738
0.2522588065330018


100%|██████████| 705/705 [00:10<00:00, 65.10it/s]


Epoch 73
0.08159400278319281
0.19937706012675102
0.2571377503639417


100%|██████████| 705/705 [00:10<00:00, 65.07it/s]


Epoch 74
0.08631316353335765
0.20202385464035874
0.2610480267422419


100%|██████████| 705/705 [00:10<00:00, 65.57it/s]


Epoch 75
0.07708611781345615
0.19241134904272167
0.24865881894708525


100%|██████████| 705/705 [00:10<00:00, 65.20it/s]


Epoch 76
0.07606660971444444
0.19371986084796014
0.24848759895098124


100%|██████████| 705/705 [00:11<00:00, 62.13it/s]


Epoch 77
0.07735842692183581
0.195361812846035
0.25107983086548796


100%|██████████| 705/705 [00:12<00:00, 55.49it/s]


Epoch 78
0.07558187433908171
0.19479472640346973
0.24869644505546448


100%|██████████| 705/705 [00:11<00:00, 62.30it/s]


Epoch 79
0.08109727071809536
0.19551649889413347
0.25437715920150705


100%|██████████| 705/705 [00:11<00:00, 61.21it/s]


Epoch 80
0.07634152304447184
0.19299129415171365
0.24868407008495735


100%|██████████| 705/705 [00:11<00:00, 63.52it/s]


Epoch 81
0.09146742149064621
0.20297990604072597
0.2663719421688546


100%|██████████| 705/705 [00:10<00:00, 64.74it/s]


Epoch 82
0.07700153197564766
0.19302380637708286
0.2494898891195338


100%|██████████| 705/705 [00:10<00:00, 64.30it/s]


Epoch 83
0.08254592862204774
0.19833709451746434
0.2562904193786019


100%|██████████| 705/705 [00:11<00:00, 63.40it/s]


Epoch 84
0.08819710411963627
0.2026225450032569
0.26480727822433975


100%|██████████| 705/705 [00:10<00:00, 64.40it/s]


Epoch 85
0.07777253263596948
0.19440546569460673
0.25154145451936316


100%|██████████| 705/705 [00:10<00:00, 65.68it/s]


Epoch 86
0.06792530811305904
0.1847958726081865
0.2364024340472323


100%|██████████| 705/705 [00:10<00:00, 65.50it/s]


Epoch 87
0.0668318880530378
0.18305808570473753
0.23454692005477054


100%|██████████| 705/705 [00:10<00:00, 64.90it/s]


Epoch 88
0.06636853420591735
0.18186479276376413
0.2324382813264292


100%|██████████| 705/705 [00:10<00:00, 65.56it/s]


Epoch 89
0.06989640785737875
0.1851211593664707
0.23721054569415168


100%|██████████| 705/705 [00:10<00:00, 65.18it/s]


Epoch 90
0.06596770137152139
0.182113323545625
0.23196568833598008


100%|██████████| 705/705 [00:10<00:00, 64.63it/s]


Epoch 91
0.07271619703495175
0.18867484837347734
0.24300672513373356


100%|██████████| 705/705 [00:10<00:00, 64.69it/s]


Epoch 92
0.07662377983681752
0.191887141130072
0.24705674906677388


100%|██████████| 705/705 [00:10<00:00, 64.48it/s]


Epoch 93
0.07020938455503671
0.18534691036275938
0.23673521714641693


100%|██████████| 705/705 [00:10<00:00, 64.14it/s]


Epoch 94
0.06832717873935476
0.18254645573543318
0.23457556831921247


100%|██████████| 705/705 [00:10<00:00, 65.34it/s]


Epoch 95
0.06179117552922867
0.17695767705110793
0.22599712157925816


100%|██████████| 705/705 [00:11<00:00, 63.48it/s]


Epoch 96
0.06131014500658774
0.17689477444117796
0.22531223990921434


100%|██████████| 705/705 [00:10<00:00, 64.80it/s]


Epoch 97
0.06606375700232725
0.1831419752527636
0.23435598268893593


100%|██████████| 705/705 [00:10<00:00, 64.86it/s]


Epoch 98
0.06315289997290952
0.17947952106383674
0.22846645895682327


100%|██████████| 705/705 [00:10<00:00, 65.45it/s]

Epoch 99
0.06284585300674464
0.179225962096495
0.22769988079865774



