In [1]:
import sys
import torch
import numpy as np
from tqdm import tqdm
sys.path.append('../model/geometric_temporal/')
from recurrent import LSTMGCNModel
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal

In [2]:
dataset = DynamicGraphTemporalSignal(
    edge_indices = [np.array([[0, 1], [1, 0]]), np.array([[1, 0], [0, 1]])],
    edge_weights = [np.array([[4], [2.3]]), np.array([[3.2], [1.2]])],
    features = [np.array([[1, 2], [3, 1]]), np.array([[2, 3], [1, 1]])],
    targets = [np.array([[0], [0]]), np.array([[0], [0]])]  # goldstein
)

In [3]:
model = LSTMGCNModel(2, 10, 1, 8)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(10)):
    cost = 0
    for time, snapshot in enumerate(dataset):
        print(f"Epoch: {epoch}, Snapshot: {snapshot}")
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        
        for hat in y_hat:
            print(f"hat: {hat}")
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 139.43it/s]

Epoch: 0, Snapshot: Data(x=[2, 2], edge_index=[2, 2], edge_attr=[2, 1], y=[2, 1])
hat: tensor([-0.2187, -0.2692,  0.2752, -0.0446,  0.0047,  0.0592,  0.0882,  0.2112],
       grad_fn=<UnbindBackward0>)
hat: tensor([-0.1796, -0.2699,  0.2925, -0.1545,  0.0164,  0.0913,  0.0947,  0.1742],
       grad_fn=<UnbindBackward0>)
Epoch: 0, Snapshot: Data(x=[2, 2], edge_index=[2, 2], edge_attr=[2, 1], y=[2, 1])
hat: tensor([-0.2101, -0.2722,  0.2809, -0.0412,  0.0146,  0.0613,  0.0961,  0.2051],
       grad_fn=<UnbindBackward0>)
hat: tensor([-0.2143, -0.2639,  0.2539, -0.1118, -0.0168,  0.0659,  0.1045,  0.2120],
       grad_fn=<UnbindBackward0>)
Epoch: 1, Snapshot: Data(x=[2, 2], edge_index=[2, 2], edge_attr=[2, 1], y=[2, 1])
hat: tensor([-0.1997, -0.2563,  0.2617, -0.0283, -0.0063,  0.0469,  0.0828,  0.1941],
       grad_fn=<UnbindBackward0>)
hat: tensor([-0.1775, -0.2673,  0.2648, -0.1231,  0.0033,  0.0645,  0.0841,  0.1608],
       grad_fn=<UnbindBackward0>)
Epoch: 1, Snapshot: Data(x=[2, 2],


