In [1]:
import torch
import os
import matplotlib.pyplot as plt
os.chdir("..")

from epilearn.models.SpatialTemporal.STGCN import STGCN
from epilearn.models.SpatialTemporal.MepoGNN import MepoGNN
from epilearn.models.SpatialTemporal.EpiGNN import EpiGNN
from epilearn.models.SpatialTemporal.DASTGN import DASTGN
from epilearn.models.SpatialTemporal.ColaGNN import ColaGNN
from epilearn.models.SpatialTemporal.EpiColaGNN import EpiColaGNN
from epilearn.models.SpatialTemporal.CNNRNN_Res import CNNRNN_Res
from epilearn.models.SpatialTemporal.ATMGNN import MPNN_LSTM, ATMGNN

from epilearn.models.Temporal.Dlinear import DlinearModel
from epilearn.models.Temporal.LSTM import LSTMModel
from epilearn.models.Temporal.GRU import GRUModel

from epilearn.data import UniversalDataset
from epilearn.utils import utils, transforms
from epilearn.tasks.forecast import Forecast

### Configs

In [2]:
# initial settings
device = torch.device('mps')
torch.manual_seed(7)

lookback = 12 # inputs size
horizon = 3 # predicts size

# permutation is True when using STGCN
permute = False

epochs = 10 # training epochs
batch_size = 50 # training batch size

### Initialize dataset

In [3]:
# load toy dataset
dataset = UniversalDataset()
dataset.load_toy_dataset()

### Initialize model and task
* prototype supports all models imported at the first cell

In [4]:
task = Forecast(prototype=EpiGNN, dataset=None, lookback=lookback, horizon=horizon)

### Add transformations to dataset

In [5]:
transformation = transforms.Compose({"features":[transforms.normalize_feat()], 
                                 'graph': [transforms.normalize_adj()], 
                                 'dynamic_graph': [transforms.normalize_adj()], 
                                 'states': []
                                 })
dataset.transforms = transformation

### Train Model
* for epicolagnn, loss='epi_cola' else loss='mse
* for STGCN, permute_dataset=True

In [6]:
config = None
result = task.train_model(dataset=dataset, config=config, loss='mse', epochs=5, permute_dataset=permute, device=device) # instead of config, we can also dircetly input some parameters

spatial-temporal model loaded!


100%|██████████| 5/5 [00:03<00:00,  1.54it/s]



Final Training loss: 9522.398098853326
Final Validation loss: 23821.14453125
Test MSE: 299950.25
Test MAE: 190.66786193847656
Test RMSE: 547.6771240234375





### Evaluate model

In [7]:
evaluation = task.evaluate_model()

Test MSE: 299950.25
Test MAE: 190.66786193847656
Test RMSE: 547.6771240234375


### Try more datasets

In [8]:
# load other datasets
datasets = [dataset]
raw_data = torch.load("datasets/covid_static.pt")
for name in ['Brazil', 'Austria', 'China']:
    data = raw_data[name]
    dataset = UniversalDataset()
    dataset.x = data['features']
    dataset.y = data['features'][:,:,0]
    dataset.graph = data['graph']
    dataset.states = data['features']
    dataset.dynamic_graph = None

    dataset.transforms = transformation
    datasets.append(dataset)


  raw_data = torch.load("datasets/covid_static.pt")


In [9]:
for i, dataset in enumerate(datasets):
    print(f"dataset {i}")
    model = task.train_model(dataset=dataset, config=config, loss='mse', epochs=50, batch_size=50, permute_dataset=permute) # instead of config, we can also dircetly input some parameters

dataset 0
spatial-temporal model loaded!


100%|██████████| 50/50 [00:06<00:00,  7.67it/s]




Final Training loss: 8732.541434151786
Final Validation loss: 25009.900390625
Test MSE: 293662.21875
Test MAE: 187.66256713867188
Test RMSE: 541.9061279296875
dataset 1
spatial-temporal model loaded!


100%|██████████| 50/50 [00:02<00:00, 19.72it/s]




Final Training loss: 908052.6875
Final Validation loss: 601284.0625
Test MSE: 232661.734375
Test MAE: 259.1087341308594
Test RMSE: 482.3502197265625
dataset 2
spatial-temporal model loaded!


100%|██████████| 50/50 [00:02<00:00, 20.06it/s]




Final Training loss: 17236.51611328125
Final Validation loss: 703.6200561523438
Test MSE: 218.26226806640625
Test MAE: 9.065763473510742
Test RMSE: 14.773701667785645
dataset 3
spatial-temporal model loaded!


100%|██████████| 50/50 [00:01<00:00, 33.36it/s]




Final Training loss: 7147.6953125
Final Validation loss: 48518.01953125
Test MSE: 915881.4375
Test MAE: 107.65332794189453
Test RMSE: 957.0169677734375


### Try temporal models

In [10]:
task = Forecast(prototype=LSTMModel, dataset=None, lookback=lookback, horizon=horizon, device='cpu')
num_nodes = 47
mae_list=[]
rmse_list=[]
for region in range(num_nodes):
    print("region", region)
    result = task.train_model(dataset=datasets[-1], config=config, loss='mse', epochs=50, batch_size=50, region_idx=1, permute_dataset=False)
    mae_list.append(result['mae'])
    rmse_list.append(result['rmse'])

mae = torch.FloatTensor(mae_list)
rmse = torch.FloatTensor(rmse_list)
print(f"mae:{mae.mean()} {mae.std()}")
print(f"rmse:{rmse.mean()} {rmse.std()}")

region 0
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 151.10it/s]




Final Training loss: 5.020066738128662
Final Validation loss: 9.171232223510742
Test MSE: 15.183773040771484
Test MAE: 2.93493390083313
Test RMSE: 3.8966360092163086
region 1
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 165.08it/s]




Final Training loss: 4.789176940917969
Final Validation loss: 8.993369102478027
Test MSE: 14.404995918273926
Test MAE: 2.9321441650390625
Test RMSE: 3.795391321182251
region 2
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 147.28it/s]




Final Training loss: 4.766205787658691
Final Validation loss: 8.980277061462402
Test MSE: 14.452549934387207
Test MAE: 2.9339582920074463
Test RMSE: 3.8016510009765625
region 3
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 162.94it/s]




Final Training loss: 4.881421089172363
Final Validation loss: 9.153989791870117
Test MSE: 14.480252265930176
Test MAE: 2.9459407329559326
Test RMSE: 3.8052926063537598
region 4
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 117.46it/s]




Final Training loss: 4.865485668182373
Final Validation loss: 9.005175590515137
Test MSE: 15.18845272064209
Test MAE: 2.939289093017578
Test RMSE: 3.8972365856170654
region 5
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 118.50it/s]




Final Training loss: 4.777289867401123
Final Validation loss: 9.09062671661377
Test MSE: 15.423748970031738
Test MAE: 2.9389092922210693
Test RMSE: 3.9273080825805664
region 6
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 131.07it/s]




Final Training loss: 4.863633155822754
Final Validation loss: 9.053088188171387
Test MSE: 15.548702239990234
Test MAE: 2.9466707706451416
Test RMSE: 3.9431843757629395
region 7
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 150.90it/s]




Final Training loss: 4.791506767272949
Final Validation loss: 9.070149421691895
Test MSE: 15.066975593566895
Test MAE: 2.9334380626678467
Test RMSE: 3.881620168685913
region 8
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 152.98it/s]




Final Training loss: 4.834712505340576
Final Validation loss: 9.19892406463623
Test MSE: 15.548991203308105
Test MAE: 2.957782030105591
Test RMSE: 3.943220853805542
region 9
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 151.34it/s]




Final Training loss: 4.815269947052002
Final Validation loss: 9.074174880981445
Test MSE: 14.39230728149414
Test MAE: 2.952535390853882
Test RMSE: 3.793719530105591
region 10
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 151.37it/s]




Final Training loss: 4.7548956871032715
Final Validation loss: 9.043291091918945
Test MSE: 15.07220458984375
Test MAE: 2.929243326187134
Test RMSE: 3.882293701171875
region 11
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 148.02it/s]




Final Training loss: 4.882945537567139
Final Validation loss: 9.143513679504395
Test MSE: 15.661742210388184
Test MAE: 2.9458630084991455
Test RMSE: 3.957491874694824
region 12
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 136.99it/s]




Final Training loss: 4.7580485343933105
Final Validation loss: 9.16529369354248
Test MSE: 14.453628540039062
Test MAE: 2.931640625
Test RMSE: 3.801792860031128
region 13
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 142.98it/s]




Final Training loss: 4.829558372497559
Final Validation loss: 9.071805953979492
Test MSE: 14.732754707336426
Test MAE: 2.9783573150634766
Test RMSE: 3.838327169418335
region 14
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 136.52it/s]




Final Training loss: 4.678573131561279
Final Validation loss: 9.123279571533203
Test MSE: 14.917170524597168
Test MAE: 2.925870180130005
Test RMSE: 3.8622753620147705
region 15
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 154.06it/s]




Final Training loss: 4.8779988288879395
Final Validation loss: 9.107592582702637
Test MSE: 15.087157249450684
Test MAE: 2.9326012134552
Test RMSE: 3.88421893119812
region 16
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 143.96it/s]




Final Training loss: 4.727482318878174
Final Validation loss: 9.098685264587402
Test MSE: 15.128108024597168
Test MAE: 2.9348390102386475
Test RMSE: 3.889486789703369
region 17
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 159.55it/s]




Final Training loss: 4.87184476852417
Final Validation loss: 9.138330459594727
Test MSE: 15.01755142211914
Test MAE: 2.9303176403045654
Test RMSE: 3.875248670578003
region 18
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 156.73it/s]




Final Training loss: 4.848287582397461
Final Validation loss: 9.152745246887207
Test MSE: 14.95703125
Test MAE: 2.9299182891845703
Test RMSE: 3.867432117462158
region 19
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 157.18it/s]




Final Training loss: 4.911590576171875
Final Validation loss: 9.17906665802002
Test MSE: 14.49023723602295
Test MAE: 2.926523208618164
Test RMSE: 3.8066043853759766
region 20
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 160.05it/s]




Final Training loss: 4.807560443878174
Final Validation loss: 9.016924858093262
Test MSE: 14.326336860656738
Test MAE: 2.9675095081329346
Test RMSE: 3.7850148677825928
region 21
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 155.93it/s]




Final Training loss: 4.816884994506836
Final Validation loss: 9.147095680236816
Test MSE: 15.190625190734863
Test MAE: 2.9358537197113037
Test RMSE: 3.897515296936035
region 22
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 155.45it/s]




Final Training loss: 4.883336067199707
Final Validation loss: 9.114570617675781
Test MSE: 15.029906272888184
Test MAE: 2.934427261352539
Test RMSE: 3.8768422603607178
region 23
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 159.12it/s]




Final Training loss: 5.0753021240234375
Final Validation loss: 9.036827087402344
Test MSE: 14.91317367553711
Test MAE: 2.9275810718536377
Test RMSE: 3.861757755279541
region 24
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 157.55it/s]




Final Training loss: 4.922022342681885
Final Validation loss: 9.1727876663208
Test MSE: 14.881866455078125
Test MAE: 2.9266910552978516
Test RMSE: 3.8577022552490234
region 25
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 144.00it/s]




Final Training loss: 4.809100151062012
Final Validation loss: 8.988743782043457
Test MSE: 14.50797176361084
Test MAE: 2.9446334838867188
Test RMSE: 3.8089332580566406
region 26
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 157.91it/s]




Final Training loss: 4.915351867675781
Final Validation loss: 9.062477111816406
Test MSE: 15.264854431152344
Test MAE: 2.932987928390503
Test RMSE: 3.9070262908935547
region 27
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 157.63it/s]




Final Training loss: 4.945450305938721
Final Validation loss: 8.999634742736816
Test MSE: 14.449067115783691
Test MAE: 2.955000877380371
Test RMSE: 3.8011927604675293
region 28
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 157.98it/s]




Final Training loss: 4.847881317138672
Final Validation loss: 9.045326232910156
Test MSE: 14.427529335021973
Test MAE: 2.943922281265259
Test RMSE: 3.798358678817749
region 29
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 160.92it/s]




Final Training loss: 4.8126420974731445
Final Validation loss: 9.115446090698242
Test MSE: 14.97359848022461
Test MAE: 2.9312326908111572
Test RMSE: 3.8695733547210693
region 30
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 157.40it/s]




Final Training loss: 4.917630195617676
Final Validation loss: 9.156950950622559
Test MSE: 14.507359504699707
Test MAE: 2.945985794067383
Test RMSE: 3.8088526725769043
region 31
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 146.29it/s]




Final Training loss: 4.868114948272705
Final Validation loss: 9.086874961853027
Test MSE: 14.331466674804688
Test MAE: 2.9357261657714844
Test RMSE: 3.7856924533843994
region 32
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 155.18it/s]




Final Training loss: 4.803954124450684
Final Validation loss: 9.119641304016113
Test MSE: 14.615890502929688
Test MAE: 2.9730589389801025
Test RMSE: 3.823073387145996
region 33
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 156.80it/s]




Final Training loss: 4.799283981323242
Final Validation loss: 9.05598258972168
Test MSE: 14.515494346618652
Test MAE: 2.9279181957244873
Test RMSE: 3.8099205493927
region 34
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 155.56it/s]




Final Training loss: 4.883335590362549
Final Validation loss: 9.02387523651123
Test MSE: 14.531270027160645
Test MAE: 2.9421956539154053
Test RMSE: 3.811990261077881
region 35
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 156.32it/s]




Final Training loss: 4.881418228149414
Final Validation loss: 9.002523422241211
Test MSE: 14.553837776184082
Test MAE: 2.941279649734497
Test RMSE: 3.8149492740631104
region 36
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 132.29it/s]




Final Training loss: 4.798800945281982
Final Validation loss: 8.973701477050781
Test MSE: 15.333586692810059
Test MAE: 2.95531964302063
Test RMSE: 3.9158124923706055
region 37
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 129.28it/s]




Final Training loss: 4.763924598693848
Final Validation loss: 9.034420013427734
Test MSE: 15.184524536132812
Test MAE: 2.9366445541381836
Test RMSE: 3.8967325687408447
region 38
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 155.28it/s]




Final Training loss: 4.838878154754639
Final Validation loss: 8.993476867675781
Test MSE: 15.515599250793457
Test MAE: 2.9425716400146484
Test RMSE: 3.9389846324920654
region 39
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 148.45it/s]




Final Training loss: 4.9681243896484375
Final Validation loss: 9.09703540802002
Test MSE: 15.483543395996094
Test MAE: 2.9386990070343018
Test RMSE: 3.934913396835327
region 40
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 155.34it/s]




Final Training loss: 4.826498508453369
Final Validation loss: 9.043079376220703
Test MSE: 15.543010711669922
Test MAE: 2.9436800479888916
Test RMSE: 3.94246244430542
region 41
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 155.81it/s]




Final Training loss: 4.774615287780762
Final Validation loss: 9.103507995605469
Test MSE: 15.188115119934082
Test MAE: 2.930997610092163
Test RMSE: 3.897193193435669
region 42
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 156.56it/s]




Final Training loss: 4.905758380889893
Final Validation loss: 9.197304725646973
Test MSE: 15.080506324768066
Test MAE: 2.9308154582977295
Test RMSE: 3.8833627700805664
region 43
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 153.55it/s]




Final Training loss: 4.809247970581055
Final Validation loss: 9.024809837341309
Test MSE: 14.871265411376953
Test MAE: 2.922654151916504
Test RMSE: 3.856328010559082
region 44
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 156.67it/s]




Final Training loss: 4.923417568206787
Final Validation loss: 9.143959045410156
Test MSE: 14.35190200805664
Test MAE: 2.943540334701538
Test RMSE: 3.7883903980255127
region 45
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 157.56it/s]




Final Training loss: 4.7872209548950195
Final Validation loss: 9.052711486816406
Test MSE: 14.852538108825684
Test MAE: 2.9296348094940186
Test RMSE: 3.8538990020751953
region 46
temporal model loaded!


100%|██████████| 50/50 [00:00<00:00, 157.28it/s]



Final Training loss: 4.817765235900879
Final Validation loss: 9.118547439575195
Test MSE: 14.401469230651855
Test MAE: 2.9464633464813232
Test RMSE: 3.794926881790161
mae:2.9397404193878174 0.012147119268774986
rmse:3.858975410461426 0.052178412675857544



