In [1]:
import torch
import os
import matplotlib.pyplot as plt
os.chdir("..")

from epilearn.models.SpatialTemporal.STGCN import STGCN

from epilearn.models.Spatial.GCN import GCN
from epilearn.models.Spatial.SAGE import SAGE
from epilearn.models.Spatial.GAT import GAT
from epilearn.models.Spatial.GIN import GIN

from epilearn.data import UniversalDataset
from epilearn.utils import utils, transforms
from epilearn.utils import simulation
from epilearn.tasks.detection import Detection

### Configs

In [2]:
# initial settings
device = torch.device('cpu')
torch.manual_seed(7)

lookback = 1 # inputs size
horizon = 2 # predicts size; also seen as number of classes

epochs = 50 # training epochs
batch_size = 25 # training batch size

### Initialize Dataset

In [3]:
# load toy dataset
dataset = UniversalDataset()
dataset.load_toy_dataset()

### Initialize model and task

In [4]:
task = Detection(prototype=GCN, dataset=None, lookback=lookback, horizon=horizon, device='cpu')

### Add transformations

In [5]:
# transformation = transforms.Compose({
#                                  'features':[transforms.normalize_feat()], 
#                                  'graph': [transforms.normalize_adj()], 
#                                  'dynamic_graph': [transforms.normalize_adj()], 
#                                  'states': []
#                                  })
transformation = transforms.Compose({
                                 'features':[], 
                                 'graph': [], 
                                 'dynamic_graph': [], 
                                 'states': []
                                 })
dataset.transforms = transformation

### Train model

In [6]:
config = None
result = task.train_model(dataset=dataset, config=config, loss='ce', epochs=5) # instead of config, we can also dircetly input some parameters

spatial model loaded!


 20%|██        | 1/5 [00:24<01:37, 24.48s/it]

######### epoch:0
Training loss: 1.8975600004196167
Validation loss: 2.861781120300293


 40%|████      | 2/5 [00:49<01:13, 24.61s/it]

######### epoch:1
Training loss: 1.2397278547286987
Validation loss: 1.442754864692688


 60%|██████    | 3/5 [01:17<00:52, 26.16s/it]

######### epoch:2
Training loss: 0.857474684715271
Validation loss: 1.2778254747390747


 80%|████████  | 4/5 [01:50<00:29, 29.06s/it]

######### epoch:3
Training loss: 0.5894216299057007
Validation loss: 0.9511571526527405


100%|██████████| 5/5 [02:12<00:00, 26.42s/it]

######### epoch:4
Training loss: 0.5351558327674866
Validation loss: 0.6486267447471619

Final Training loss: 0.5351558327674866
Final Validation loss: 0.6486267447471619
Best Epoch: 4
Best Training loss: 0.5351558327674866
Best Validation loss: 0.6486267447471619






Predicting Progress...


100%|██████████| 108/108 [00:37<00:00,  2.89it/s]


Test ACC: 2.194444417953491


### Train on Simulated dataset

In [7]:
# Simulation Process
from epilearn.models.SpatialTemporal.NetworkSIR import NetSIR

# generate 10 samples
num_nodes = 25
# generate random static graph: 25 nodes
initial_graph = simulation.get_random_graph(num_nodes=num_nodes, connect_prob=0.15)
initial_states = torch.zeros(num_nodes,3) # [S,I,R]
initial_states[:, 0] = 1

graph = initial_graph
x = []
y = []
for i in range(100): 
    # set infected individual
    idx = torch.randint(0,num_nodes, (1,))
    initial_states[idx.item(), 0] = 0
    initial_states[idx.item(), 1] = 1

    model = NetSIR(num_nodes=initial_graph.shape[0], horizon=100, infection_rate=0.01, recovery_rate=0.0384) # infection_rate, recover_rate, fixed_population
    preds = model(initial_states, initial_graph, steps = None)
    x.append(torch.nn.functional.one_hot(preds[-1].argmax(1)))
    y.append(initial_states.argmax(1))
x = torch.stack(x)
y = torch.stack(y)

In [8]:
dataset = UniversalDataset(x=x,y=y,graph=initial_graph)
dataset.transforms = transformation
task = Detection(prototype=GCN, dataset=dataset, lookback=lookback, horizon=horizon, device='cpu')

In [9]:
result = task.train_model(dataset=dataset, loss='ce', epochs=5)

spatial model loaded!


 20%|██        | 1/5 [00:10<00:41, 10.43s/it]

######### epoch:0
Training loss: 0.6597769856452942
Validation loss: 0.5772261023521423


 40%|████      | 2/5 [00:15<00:22,  7.42s/it]

######### epoch:1
Training loss: 0.6493467688560486
Validation loss: 0.5620183944702148


 60%|██████    | 3/5 [00:21<00:13,  6.54s/it]

######### epoch:2
Training loss: 0.644253671169281
Validation loss: 0.5452414751052856


 80%|████████  | 4/5 [00:26<00:06,  6.05s/it]

######### epoch:3
Training loss: 0.6366496682167053
Validation loss: 0.5283515453338623


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]

######### epoch:4
Training loss: 0.6305637955665588
Validation loss: 0.5115376114845276

Final Training loss: 0.6305637955665588
Final Validation loss: 0.5115376114845276
Best Epoch: 4
Best Training loss: 0.6305637955665588
Best Validation loss: 0.5115376114845276






Predicting Progress...


100%|██████████| 20/20 [00:12<00:00,  1.61it/s]

Test ACC: 24.0



