In [1]:
import torch
import os
import matplotlib.pyplot as plt
# Adjust the working directory and the import search path before the
# `epilearn` imports that follow.
#
# The run-once flag keeps this cell idempotent: without it, every re-execution
# would climb one more directory level and append another sys.path entry.
import sys  # use `sys` directly instead of reaching it through `os.sys`

if "NOTEBOOK_BOOTSTRAPPED" not in globals():
    # Step up one level from the directory the kernel was started in.
    os.chdir("..")
    # abspath('') is evaluated *after* the chdir, so it names the new working
    # directory and the appended entry is that directory's parent.
    # NOTE(review): if the intent was to add the new working directory itself,
    # the trailing '..' is one level too many - confirm against the repo layout.
    sys.path.append(os.path.join(os.path.abspath(''), '..'))
    NOTEBOOK_BOOTSTRAPPED = True

from epilearn.models.SpatialTemporal.STGCN import STGCN

from epilearn.models.Spatial.GCN import GCN
from epilearn.models.Spatial.SAGE import SAGE
from epilearn.models.Spatial.GAT import GAT
from epilearn.models.Spatial.GIN import GIN

from epilearn.data import UniversalDataset
from epilearn.utils import utils, transforms
from epilearn.utils import simulation
from epilearn.tasks.detection import Detection

### Configs

In [2]:
# Experiment-wide settings; everything a reader may want to tune lives here.

# Reproducibility and hardware
torch.manual_seed(7)             # seed torch's global random number generator
device = torch.device('cpu')

# Window sizes: `lookback` is the input size, `horizon` the prediction size
# (the original author also treats `horizon` as the number of classes).
lookback, horizon = 1, 2

# Training hyper-parameters.
# NOTE(review): the training call visible further down passes `epochs=5`
# literally and never reads `batch_size`, so these two values are currently
# unused by the cells shown here.
epochs, batch_size = 50, 25

### Initialize Dataset

In [3]:
# Create an empty dataset object, then populate it via load_toy_dataset().
dataset = UniversalDataset()
dataset.load_toy_dataset()

# Display the shape of `dataset.x` as the cell output.
# NOTE(review): the recorded output is torch.Size([539, 47, 4]); the meaning of
# each axis is not shown in this notebook - confirm against the dataset docs.
toy_features = dataset.x
toy_features.shape

torch.Size([539, 47, 4])

### Initialize model and task

In [4]:
# Create the detection task around a GCN prototype.
# `dataset=None` here: the dataset is handed over later, in the train_model call.
task = Detection(
    prototype=GCN,
    dataset=None,
    lookback=lookback,
    horizon=horizon,
    # Reuse the device chosen in the config cell instead of a second hard-coded
    # literal; str(torch.device('cpu')) is exactly 'cpu', so the value passed
    # is unchanged for the current configuration.
    device=str(device),
)

### Add transformations

In [5]:

# Preprocessing pipeline, one list of transforms per dataset component.
# Going by the transform names: features get a time embedding (dim 4,
# non-Fourier), are converted to the frequency domain with an FFT and are then
# normalised; both the static and the dynamic graph get adjacency
# normalisation; states are left untouched.
#
# To run without any preprocessing, pass an empty list for all four keys.
transformation = transforms.Compose({
    'features': [
        transforms.add_time_embedding(embedding_dim=4, fourier=False),
        transforms.convert_to_frequency(ftype="fft"),
        transforms.normalize_feat(),
    ],
    'graph': [transforms.normalize_adj()],
    'dynamic_graph': [transforms.normalize_adj()],
    'states': [],
})

# Attach the pipeline to the dataset object.
dataset.transforms = transformation

### Train model

In [6]:
# No config object is used; individual parameters are passed directly to
# train_model instead.
# NOTE(review): `epochs=5` is a literal and does not pick up `epochs = 50` from
# the config cell - confirm which of the two is intended.
# NOTE(review): the recorded output ends with "Test ACC: 2.19...", which is
# above 1; the metric is presumably not normalised - verify before quoting it.
config = None
result = task.train_model(
    dataset=dataset,
    config=config,
    loss='ce',
    epochs=5,
)

spatial model loaded!


 20%|██        | 1/5 [00:26<01:45, 26.28s/it]

######### epoch:0
Training loss: 0.21719585359096527
Validation loss: 0.1288830041885376


 40%|████      | 2/5 [01:02<01:36, 32.29s/it]

######### epoch:1
Training loss: 0.11564651876688004
Validation loss: 0.10498964786529541


 60%|██████    | 3/5 [02:04<01:31, 45.77s/it]

######### epoch:2
Training loss: 0.10384580492973328
Validation loss: 0.09909678250551224


 80%|████████  | 4/5 [02:46<00:44, 44.28s/it]

######### epoch:3
Training loss: 0.10048463940620422
Validation loss: 0.09528712928295135


100%|██████████| 5/5 [03:18<00:00, 39.70s/it]

######### epoch:4
Training loss: 0.09822500497102737
Validation loss: 0.09225907176733017

Final Training loss: 0.09822500497102737
Final Validation loss: 0.09225907176733017
Best Epoch: 4
Best Training loss: 0.09822500497102737
Best Validation loss: 0.09225907176733017






Predicting Progress...


100%|██████████| 108/108 [00:38<00:00,  2.84it/s]


Test ACC: 2.194444417953491
