In [1]:
import torch
import os
import matplotlib.pyplot as plt
os.chdir("..")
os.sys.path.append(os.path.join(os.path.abspath(''), '..'))

from epilearn.models.SpatialTemporal.STGCN import STGCN

from epilearn.models.Spatial.GCN import GCN
from epilearn.models.Spatial.SAGE import SAGE
from epilearn.models.Spatial.GAT import GAT
from epilearn.models.Spatial.GIN import GIN

from epilearn.data import UniversalDataset
from epilearn.utils import utils, transforms
from epilearn.utils import simulation
from epilearn.tasks.detection import Detection

### Configs

In [2]:
# initial settings
device = torch.device('cpu')
torch.manual_seed(7)

lookback = 1 # inputs size
horizon = 2 # predicts size; also seen as number of classes

epochs = 50 # training epochs
batch_size = 25 # training batch size

### Initialize Dataset

In [3]:
# load toy dataset
dataset = UniversalDataset()
dataset.load_toy_dataset()
dataset.x.shape

torch.Size([539, 47, 4])

### Initialize model and task

In [4]:
task = Detection(prototype=GCN, dataset=None, lookback=lookback, horizon=horizon, device='cpu')

### Add transformations

In [5]:

transformation = transforms.Compose({
                                 'features':[transforms.add_time_embedding(embedding_dim=4, fourier=False),
                                             transforms.normalize_feat()], 
                                 'graph': [transforms.normalize_adj()], 
                                 'dynamic_graph': [transforms.normalize_adj()], 
                                 'states': []
                                 })

'''transformation = transforms.Compose({
                                 'features':[], 
                                 'graph': [], 
                                 'dynamic_graph': [], 
                                 'states': []
                                 })'''
dataset.transforms = transformation

### Train model

In [6]:
config = None
result = task.train_model(dataset=dataset, config=config, loss='ce', epochs=5) # instead of config, we can also dircetly input some parameters

spatial model loaded!


 20%|██        | 1/5 [00:17<01:09, 17.48s/it]

######### epoch:0
Training loss: 0.13716816902160645
Validation loss: 0.14406348764896393


 40%|████      | 2/5 [00:38<00:58, 19.61s/it]

######### epoch:1
Training loss: 0.11759451776742935
Validation loss: 0.13170084357261658


 60%|██████    | 3/5 [00:59<00:40, 20.11s/it]

######### epoch:2
Training loss: 0.10998421907424927
Validation loss: 0.12396952509880066


 80%|████████  | 4/5 [01:20<00:20, 20.46s/it]

######### epoch:3
Training loss: 0.10743874311447144
Validation loss: 0.11801645904779434


100%|██████████| 5/5 [01:43<00:00, 20.68s/it]

######### epoch:4
Training loss: 0.1055581346154213
Validation loss: 0.11123505234718323

Final Training loss: 0.1055581346154213
Final Validation loss: 0.11123505234718323
Best Epoch: 4
Best Training loss: 0.1055581346154213
Best Validation loss: 0.11123505234718323






Predicting Progress...


100%|██████████| 108/108 [00:21<00:00,  5.10it/s]


Test ACC: 2.194444417953491
