In [1]:
import pandas as pd
import numpy as np
import sys
sys.path.append("../src/utils")
from utils import SimpleGraphVoltDatasetLoader_Lazy
from torch_geometric_temporal.signal import temporal_signal_split
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import A3TGCN
from tqdm import tqdm

In [2]:
class TemporalGNN(torch.nn.Module):
    """A3TGCN cell followed by ReLU and a linear single-shot prediction head.

    Args:
        node_features: Number of input features per node; forwarded to
            ``A3TGCN`` as ``in_channels``.
        periods: Width of the linear head, i.e. how many values are emitted
            per node in one shot.
        out_channels: Width of the hidden representation produced by the
            recurrent cell. Defaults to 32, the value that used to be
            hard-coded.
        in_periods: Value forwarded to ``A3TGCN`` as ``periods``. Defaults to
            ``periods`` so existing callers and saved checkpoints keep working.
            NOTE(review): A3TGCN appears to iterate over ``periods`` slices of
            the last axis of ``x``; when the input window is longer than the
            prediction horizon, pass the input window length here so no input
            steps are ignored - confirm against the library documentation.
    """

    def __init__(self, node_features, periods, out_channels=32, in_periods=None):
        super().__init__()
        if in_periods is None:
            # Historical behaviour: the same number drives both the recurrent
            # cell and the size of the output layer.
            in_periods = periods
        # Attention Temporal Graph Convolutional Cell
        self.tgnn = A3TGCN(in_channels=node_features,
                           out_channels=out_channels,
                           periods=in_periods)
        # Single-shot prediction: one linear layer emits all `periods` outputs.
        self.linear = torch.nn.Linear(out_channels, periods)

    def forward(self, x, edge_index):
        """Run the recurrent cell, apply ReLU and project to ``periods`` outputs.

        Args:
            x: Node features for T time steps.
            edge_index: Graph edge indices.

        Returns:
            The output of the linear head applied to the activated hidden state.
        """
        h = self.tgnn(x, edge_index)
        h = F.relu(h)
        return self.linear(h)

In [3]:
def train_test(model, device, loader, train_dataset, test_dataset, optimizer, loss_fn, epochs, now):
    """Train ``model`` and score it on ``test_dataset`` after every epoch.

    Args:
        model: Module called as ``model(snapshot.x, snapshot.edge_index)``.
        device: Device handed to ``snapshot.to``.
        loader: Object whose ``get_snapshot(i)`` returns the snapshot for an
            entry of the dataset iterables.
        train_dataset: Iterable of snapshot identifiers used for optimisation.
        test_dataset: Iterable of snapshot identifiers scored without gradients.
        optimizer: Optimiser already bound to ``model``'s parameters.
        loss_fn: Loss *factory* (e.g. ``torch.nn.L1Loss``); it is instantiated
            once with no arguments.
        epochs: Number of passes over ``train_dataset``.
        now: String inserted into the checkpoint file name.

    Returns:
        ``(epoch_losses_train, epoch_losses_test)``: two lists of Python floats,
        one entry per epoch. Each entry is the *sum* of the per-snapshot losses
        of that epoch, not their mean.

    Side effects:
        Writes ``model.state_dict()`` to ``../models/A3TGCN_{now}.pt`` whenever
        the epoch's test loss is at least as low as the best one seen so far,
        and prints one summary line per epoch.
    """
    # Build the loss module once instead of once per snapshot.
    criterion = loss_fn()
    checkpoint_path = f"../models/A3TGCN_{now}.pt"

    epoch_losses_train = []
    epoch_losses_test = []
    # A running best keeps checkpointing alive even if one epoch yields NaN
    # (min() over a list containing NaN can return NaN and never match again).
    best_test_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        epoch_loss_train = 0.0
        for snapshot_i in tqdm(train_dataset, desc="Training epoch {}".format(epoch)):
            snapshot = loader.get_snapshot(snapshot_i)
            # The return value is discarded, so this relies on .to() moving the
            # snapshot's tensors in place - TODO confirm for the loader's type.
            snapshot.to(device)
            optimizer.zero_grad()
            out = model(snapshot.x, snapshot.edge_index)
            loss = criterion(out, snapshot.y)
            loss.backward()
            optimizer.step()
            # .item() gives a Python float, so the running sum is accumulated
            # in double precision regardless of NumPy's promotion rules.
            epoch_loss_train += loss.item()
        epoch_losses_train.append(epoch_loss_train)

        model.eval()
        epoch_loss_test = 0.0
        with torch.no_grad():
            for snapshot_j in tqdm(test_dataset, desc="Testing epoch {}".format(epoch)):
                snapshot = loader.get_snapshot(snapshot_j)
                snapshot.to(device)
                out = model(snapshot.x, snapshot.edge_index)
                epoch_loss_test += criterion(out, snapshot.y).item()
        epoch_losses_test.append(epoch_loss_test)

        # "<=" keeps the original tie behaviour: an equal loss re-saves.
        if epoch_loss_test <= best_test_loss:
            best_test_loss = epoch_loss_test
            torch.save(model.state_dict(), checkpoint_path)
        print("Epoch: {}, Train Loss: {:.7f}, Test Loss: {:.7f}".format(epoch, epoch_loss_train, epoch_loss_test))

    return epoch_losses_train, epoch_losses_test

In [21]:
def eval(model, loader, eval_dataset, device, loss_fn, std):
    """Score ``model`` on ``eval_dataset`` and rescale the losses by ``std``.

    Note: this function shadows Python's builtin ``eval`` inside the notebook.

    Args:
        model: Module called as ``model(snapshot.x, snapshot.edge_index)``.
        loader: Object whose ``get_snapshot(i)`` returns the snapshot for an
            entry of ``eval_dataset``.
        eval_dataset: Iterable of snapshot identifiers.
        device: Device handed to ``snapshot.to``.
        loss_fn: Loss *factory* (e.g. ``torch.nn.L1Loss``). It is instantiated
            once with no arguments and once with ``reduction="none"``.
        std: Factor the averaged losses are multiplied by (presumably the
            standard deviation used to normalise the target - confirm).

    Returns:
        ``(loss_all, loss_elementwise)``: the reduced loss and the unreduced
        loss array, each summed over all snapshots, divided by the number of
        snapshots and multiplied by ``std``.

    An empty ``eval_dataset`` leaves the snapshot count at 0, so the final
    division fails or yields non-finite values depending on the type of ``std``.
    """
    # Build both loss modules once instead of twice per snapshot.
    criterion_reduced = loss_fn()
    criterion_elementwise = loss_fn(reduction="none")

    model.eval()
    loss_all = 0.0
    loss_elementwise = 0
    steps = 0
    with torch.no_grad():
        for snapshot_i in tqdm(eval_dataset, desc="Evaluating"):
            snapshot = loader.get_snapshot(snapshot_i)
            # Moves the snapshot's tensors to the device the model lives on.
            # The return value is discarded, so this relies on .to() working
            # in place - TODO confirm for the loader's snapshot type.
            snapshot.to(device)
            out = model(snapshot.x, snapshot.edge_index)
            loss_all += criterion_reduced(out, snapshot.y).item()
            loss_elementwise += criterion_elementwise(out, snapshot.y).cpu().numpy()
            steps += 1

    # Average over snapshots and rescale by std.
    scale = std / steps
    loss_all *= scale
    loss_elementwise *= scale
    return loss_all, loss_elementwise

In [5]:
# Ask PyTorch to release cached, currently unused GPU memory before the run.
# NOTE(review): the training cell further down selects the CPU device, so this
# call is likely unnecessary for this notebook - confirm before removing.
torch.cuda.empty_cache() 

In [6]:
# Experiment configuration: every value a reader might want to tune.
trafo_id = "T1330"  # first argument of the dataset loader (presumably a transformer station ID - confirm)
epochs = 1  # number of passes over the training snapshots
num_timesteps_in = 12  # handed to the dataset loader as its second argument
num_timesteps_out = 4  # handed to the dataset loader and used as the model's `periods`
train_ratio = 0.7  # ratio for the first split (train vs. the remainder)
test_ratio_vs_eval_ratio = 0.5  # ratio for the second split (test vs. eval) of the remainder
learning_rate = 0.01  # learning rate of the Adam optimizer

In [7]:
# Current date-time as a YYYYmmddHHMMSS string; used to name this run's checkpoint file.
now = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")

In [8]:
# Build the lazy dataset loader for `trafo_id`; individual snapshots are
# fetched later, one at a time, through loader.get_snapshot(...).
print("Loading data...", end="")
loader = SimpleGraphVoltDatasetLoader_Lazy(trafo_id, num_timesteps_in, num_timesteps_out)
print(" done")
# The loader's snapshot index is what gets split into train/test/eval below.
loader_data_index = loader.snapshot_index
# Unused alternative, kept for reference:
# loader_data = loader.get_dataset(num_timesteps_in=num_timesteps_in, num_timesteps_out=num_timesteps_out)

Loading data... done


In [9]:
# Two-stage split: first train vs. the remainder (train_ratio), then the
# remainder into test vs. eval (test_ratio_vs_eval_ratio).
# With 0.7 and 0.5 the recorded run shows 49045 / 10510 / 10510 snapshots.
train_dataset, test_eval_dataset = loader.temporal_signal_split_lazy(loader_data_index, train_ratio)
test_dataset, eval_dataset = loader.temporal_signal_split_lazy(test_eval_dataset, test_ratio_vs_eval_ratio)

In [10]:
print("Running training...")
# Training is pinned to the CPU in this notebook.
device = torch.device('cpu')
# NOTE(review): `periods` is set to the output horizon (num_timesteps_out)
# although the loader was built with num_timesteps_in input steps - confirm
# that this is the number of time slices A3TGCN is meant to iterate over.
model = TemporalGNN(node_features=loader.num_features, periods=num_timesteps_out).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The loss *class* is passed on, not an instance: train_test and eval
# instantiate it themselves.
loss_fn = torch.nn.L1Loss
# Debug aid, disabled:
# print('HERE', loader.get_snapshot(0).edge_index)
# `losses` is the (train losses, test losses) tuple returned by train_test.
losses = train_test(model, device, loader,train_dataset, test_dataset, optimizer, loss_fn, epochs=epochs, now=now)

Running training...


Training epoch 0:   0%|          | 0/49045 [00:00<?, ?it/s]

Training epoch 0: 100%|██████████| 49045/49045 [07:05<00:00, 115.40it/s]
Testing epoch 0: 100%|██████████| 10510/10510 [00:39<00:00, 263.20it/s]

Epoch: 0, Train Loss: 3948.4641827, Test Loss: 6298.1004346





In [11]:
# Show the per-epoch (train, test) loss lists returned by train_test;
# each value is a sum over snapshots, not a mean.
print(losses)

([3948.4641827493906], [6298.100434599444])


In [12]:
# Second entry of the loader's mean_and_std record for measurement voltage -
# presumably the standard deviation used for normalisation (confirm against
# the loader). It is used to rescale the evaluation losses.
std = loader.mean_and_std["measurements"][1]["voltage"]

In [13]:
# Reload the checkpoint written by train_test for this run (the file name is
# keyed on `now`), so evaluation uses the saved best-on-test weights.
model.load_state_dict(torch.load(f"../models/A3TGCN_{now}.pt"))

<All keys matched successfully>

In [22]:
# Score the reloaded model on the held-out eval split; both results are
# averaged over snapshots and multiplied by `std` inside eval().
loss_all, loss_elementwise = eval(model, loader, eval_dataset, device, loss_fn, std)

Evaluating: 100%|██████████| 10510/10510 [00:45<00:00, 230.52it/s]


In [23]:
# Report the scalar loss and the full unreduced loss array.
# NOTE(review): the array is printed in full, which produces a very long output.
print("Loss all: {:.7f}".format(loss_all))
print("Loss elementwise: {}".format(loss_elementwise))

Loss all: 0.5544011
Loss elementwise: [[7.72035569e-02 1.65244229e-02 2.74929013e-02 4.66753319e-02]
 [9.96873621e-03 6.80640861e-02 6.90582842e-02 5.86075615e-03]
 [9.96873621e-03 6.80640861e-02 6.90582842e-02 5.86075615e-03]
 [9.96873621e-03 6.80640861e-02 6.90582842e-02 5.86075615e-03]
 [9.96873621e-03 6.80640861e-02 6.90582842e-02 5.86075615e-03]
 [9.96873621e-03 6.80640861e-02 6.90582842e-02 5.86075615e-03]
 [9.96873621e-03 6.80640861e-02 6.90582842e-02 5.86075615e-03]
 [9.96873621e-03 6.80640861e-02 6.90582842e-02 5.86075615e-03]
 [1.53142288e-01 1.20520316e-01 1.35455996e-01 1.11073531e-01]
 [3.91059881e-03 2.71374453e-03 8.96254554e-03 9.86321829e-03]
 [4.27991778e-01 4.26802188e-01 4.33077931e-01 4.14291739e-01]
 [4.94034868e-03 2.16047410e-02 1.69631932e-02 2.57839058e-02]
 [1.28978491e+00 1.23827112e+00 1.36381400e+00 1.21866238e+00]
 [4.19146001e-01 4.45768923e-01 4.41035032e-01 4.49969023e-01]
 [1.17976535e-02 2.08381265e-02 5.90093341e-03 3.03031877e-02]
 [4.19146001e-01 