In [6]:
import pandas as pd
import numpy as np
import sys
sys.path.append("../src/utils")
from utils import SimpleGraphVoltDatasetLoader, read_and_prepare_data
from torch_geometric_temporal.signal import temporal_signal_split
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import A3TGCN
from tqdm import tqdm


In [2]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

  nonzero_finite_vals = torch.masked_select(


tensor([1.], device='mps:0')


In [7]:
class TemporalGNN(torch.nn.Module):
    def __init__(self, node_features, periods):
        super(TemporalGNN, self).__init__()
        # Attention Temporal Graph Convolutional Cell
        out_channels = 32
        self.tgnn = A3TGCN(in_channels=node_features, 
                           out_channels=out_channels, 
                           periods=periods)
        # Equals single-shot prediction
        self.linear = torch.nn.Linear(out_channels, periods)

    def forward(self, x, edge_index, edge_weights):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        edge_weights = Graph edge weights
        """
        h = self.tgnn(x, edge_index, edge_weights)
        h = F.relu(h)
        h = self.linear(h)
        return h


In [22]:
def train_test(model, device, train_dataset, test_dataset, optimizer, loss_fn, epochs, now):
    """
    Definition of the training loop.
    """
    epoch_losses_train = []
    epoch_losses_test = []
    
    for epoch in range(epochs):
        model.train()
        epoch_loss_train = 0

        subset = 100
        step=0

        for snapshot in tqdm(train_dataset, desc="Training epoch {}".format(epoch)):
            snapshot.to(device)
            optimizer.zero_grad()
            out = model(snapshot.x, snapshot.edge_index,snapshot.edge_weight)
            loss = loss_fn()(out, snapshot.y)
            loss.backward()
            optimizer.step()
            epoch_loss_train += loss.detach().cpu().numpy()

            step+=1

            if step > subset:
                break


        epoch_losses_train.append(epoch_loss_train)
        model.eval()
        epoch_loss_test = 0
        with torch.no_grad():

            subset = 100
            step=0

            for snapshot in tqdm(test_dataset, desc="Testing epoch {}".format(epoch)):
                snapshot.to(device)
                out = model(snapshot.x, snapshot.edge_index,snapshot.edge_weight)
                loss = loss_fn()(out, snapshot.y).cpu().numpy()
                epoch_loss_test += loss

                step+=1
                if step > subset:
                    break

            epoch_losses_test.append(epoch_loss_test)
            if min(epoch_losses_test) == epoch_loss_test:
                torch.save(model.state_dict(), f"../models/A3TGCN_{now}_{trafo_id}_epochs-{epochs}_in-{num_timesteps_in}_out-{num_timesteps_out}_train-ratio-{train_ratio}_lr-{learning_rate}.pt")
            print("Epoch: {}, Train Loss: {:.7f}, Test Loss: {:.7f}".format(epoch, epoch_loss_train, epoch_loss_test))
        
        
    return epoch_losses_train, epoch_losses_test
            

In [25]:
def eval(model, eval_dataset, device, loss_fn, std):
    with torch.no_grad():
        model.eval()
        loss_all = 0
        loss_elementwise = 0
        
        steps = 0
        for snapshot in tqdm(eval_dataset, desc="Evaluating"):
            steps += 1
            snapshot.to(device)
            out = model(snapshot.x, snapshot.edge_index,snapshot.edge_weight)
            loss_all += loss_fn()(out, snapshot.y).cpu().numpy()
            loss_elementwise += loss_fn(reduction="none")(out, snapshot.y).cpu().numpy()

            if steps > 1000:
                break

        loss_all *= std/steps
        loss_elementwise *= std/steps
    return loss_all, loss_elementwise

In [10]:
trafo_id = "T1330"
epochs = 25
num_timesteps_in = 12
num_timesteps_out = 4
train_ratio = 0.7
test_ratio_vs_eval_ratio = 0.5
learning_rate = 0.01
device_str = 'mps'

#----------------------
if device_str == 'mps':
    torch.cuda.empty_cache()

#get dateime string of now
now = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")

In [11]:
print("Loading data...")
loader = SimpleGraphVoltDatasetLoader(trafo_id)
loader_data = loader.get_dataset(num_timesteps_in=num_timesteps_in, num_timesteps_out=num_timesteps_out)

train_dataset, test_eval_dataset = temporal_signal_split(loader_data, train_ratio=train_ratio)
test_dataset, eval_dataset = temporal_signal_split(test_eval_dataset, train_ratio=test_ratio_vs_eval_ratio)

Loading data...
Voltage index: 5
Voltage index: 5


In [23]:
print("Running training...")
device = torch.device(device_str)
model = TemporalGNN(node_features=train_dataset[0].x.shape[1], periods=train_dataset[0].y.shape[1]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.L1Loss
losses = train_test(model, device, train_dataset, test_dataset, optimizer, loss_fn, epochs=epochs, now=now)

Running training...


Training epoch 0: 0it [00:00, ?it/s]

Training epoch 0: 100it [00:12,  7.85it/s]
Testing epoch 0: 100it [00:08, 12.23it/s]


Epoch: 0, Train Loss: 28.3214764, Test Loss: 53.3878801


Training epoch 1: 100it [00:13,  7.62it/s]
Testing epoch 1: 100it [00:07, 13.02it/s]


Epoch: 1, Train Loss: 28.8309276, Test Loss: 53.1604500


Training epoch 2: 100it [00:12,  7.79it/s]
Testing epoch 2: 100it [00:09, 11.07it/s]


Epoch: 2, Train Loss: 28.7793804, Test Loss: 53.9404162


Training epoch 3: 100it [00:13,  7.41it/s]
Testing epoch 3: 100it [00:08, 12.24it/s]


Epoch: 3, Train Loss: 27.7316262, Test Loss: 54.0300387


Training epoch 4: 100it [00:12,  7.91it/s]
Testing epoch 4: 100it [00:07, 12.56it/s]


Epoch: 4, Train Loss: 29.1348240, Test Loss: 54.0341364


Training epoch 5: 100it [00:12,  7.82it/s]
Testing epoch 5: 100it [00:08, 12.14it/s]


Epoch: 5, Train Loss: 28.5673506, Test Loss: 53.4555289


Training epoch 6: 100it [00:14,  7.10it/s]
Testing epoch 6: 100it [00:09, 10.91it/s]


Epoch: 6, Train Loss: 28.2594076, Test Loss: 52.4400500


Training epoch 7: 100it [00:14,  7.13it/s]
Testing epoch 7: 100it [00:08, 12.01it/s]


Epoch: 7, Train Loss: 29.7012399, Test Loss: 53.1489950


Training epoch 8: 100it [00:13,  7.52it/s]
Testing epoch 8: 100it [00:07, 12.76it/s]


Epoch: 8, Train Loss: 28.7672603, Test Loss: 52.6429931


Training epoch 9: 100it [00:12,  8.03it/s]
Testing epoch 9: 100it [00:08, 11.80it/s]


Epoch: 9, Train Loss: 29.7285584, Test Loss: 55.0801048


Training epoch 10: 100it [00:12,  8.00it/s]
Testing epoch 10: 100it [00:08, 12.50it/s]


Epoch: 10, Train Loss: 28.2766436, Test Loss: 55.7554107


Training epoch 11: 100it [00:13,  7.49it/s]
Testing epoch 11: 100it [00:08, 11.75it/s]


Epoch: 11, Train Loss: 28.3380174, Test Loss: 55.5847510


Training epoch 12: 100it [00:12,  7.99it/s]
Testing epoch 12: 100it [00:07, 12.73it/s]


Epoch: 12, Train Loss: 27.0127028, Test Loss: 53.8705847


Training epoch 13: 100it [00:14,  7.13it/s]
Testing epoch 13: 100it [00:08, 12.40it/s]


Epoch: 13, Train Loss: 24.5309881, Test Loss: 53.1018706


Training epoch 14: 100it [00:13,  7.46it/s]
Testing epoch 14: 100it [00:08, 12.12it/s]


Epoch: 14, Train Loss: 26.3508633, Test Loss: 55.3308822


Training epoch 15: 100it [00:13,  7.35it/s]
Testing epoch 15: 100it [00:08, 11.30it/s]


Epoch: 15, Train Loss: 26.2706847, Test Loss: 56.8740396


Training epoch 16: 100it [00:12,  8.05it/s]
Testing epoch 16: 100it [00:07, 12.64it/s]


Epoch: 16, Train Loss: 26.8184510, Test Loss: 56.3903770


Training epoch 17: 100it [00:12,  8.04it/s]
Testing epoch 17: 100it [00:07, 13.11it/s]


Epoch: 17, Train Loss: 30.3111977, Test Loss: 54.0671284


Training epoch 18: 100it [00:12,  7.78it/s]
Testing epoch 18: 100it [00:08, 12.17it/s]


Epoch: 18, Train Loss: 29.5686524, Test Loss: 54.1220540


Training epoch 19: 100it [00:12,  7.93it/s]
Testing epoch 19: 100it [00:08, 11.89it/s]


Epoch: 19, Train Loss: 29.0328945, Test Loss: 55.5623541


Training epoch 20: 100it [00:12,  7.88it/s]
Testing epoch 20: 100it [00:07, 12.52it/s]


Epoch: 20, Train Loss: 28.5183823, Test Loss: 53.8280083


Training epoch 21: 100it [00:12,  7.79it/s]
Testing epoch 21: 100it [00:08, 12.37it/s]


Epoch: 21, Train Loss: 26.6497205, Test Loss: 53.1427727


Training epoch 22: 100it [00:12,  7.84it/s]
Testing epoch 22: 100it [00:07, 12.76it/s]


Epoch: 22, Train Loss: 25.3046938, Test Loss: 63.7640471


Training epoch 23: 100it [00:12,  7.92it/s]
Testing epoch 23: 100it [00:08, 11.71it/s]


Epoch: 23, Train Loss: 29.5406842, Test Loss: 54.1442893


Training epoch 24: 100it [00:12,  7.78it/s]
Testing epoch 24: 100it [00:08, 11.78it/s]

Epoch: 24, Train Loss: 25.5905694, Test Loss: 65.5652723





In [26]:
print(losses)

std = loader.mean_and_std["measurements"][1]["voltage"]

#read saved model
model.load_state_dict(torch.load(f"../models/A3TGCN_{now}_{trafo_id}_epochs-{epochs}_in-{num_timesteps_in}_out-{num_timesteps_out}_train-ratio-{train_ratio}_lr-{learning_rate}.pt"))

loss_all, loss_elementwise = eval(model, eval_dataset, device, loss_fn, std)

print("Loss all: {:.7f}".format(loss_all))
print("Loss elementwise: {}".format(loss_elementwise))

([28.32147641479969, 28.830927565693855, 28.77938039600849, 27.73162615299225, 29.134823963046074, 28.567350551486015, 28.25940763950348, 29.70123991370201, 28.767260268330574, 29.72855842113495, 28.27664363384247, 28.33801744878292, 27.012702763080597, 24.53098814189434, 26.350863330066204, 26.27068467438221, 26.818450957536697, 30.3111976608634, 29.568652421236038, 29.032894492149353, 28.51838231086731, 26.649720519781113, 25.304693818092346, 29.54068424552679, 25.590569399297237], [53.387880086898804, 53.16044998168945, 53.94041621685028, 54.03003865480423, 54.03413638472557, 53.455528885126114, 52.44004997611046, 53.14899501204491, 52.64299312233925, 55.08010482788086, 55.75541067123413, 55.584750950336456, 53.87058472633362, 53.10187056660652, 55.33088222146034, 56.87403964996338, 56.39037698507309, 54.06712844967842, 54.12205395102501, 55.56235411763191, 53.82800829410553, 53.14277270436287, 63.764047145843506, 54.14428931474686, 65.5652723312378])


Evaluating: 1000it [01:29, 11.17it/s]

Loss all: 1.7337504
Loss elementwise: [[1.8264438 1.910499  1.9089893 2.34431  ]
 [1.2350519 1.1781158 1.1964612 1.2133894]
 [1.2350519 1.1781158 1.1964612 1.2133894]
 [1.2350519 1.1781158 1.1964612 1.2133894]
 [1.2350519 1.1781158 1.1964612 1.2133894]
 [1.2350519 1.1781158 1.1964612 1.2133894]
 [1.2350519 1.1781158 1.1964612 1.2133894]
 [1.2459556 1.1895136 1.20825   1.2275467]
 [1.9770947 2.144172  2.302701  2.4157524]
 [1.193372  1.2228383 1.2419189 1.2864403]
 [1.1142725 1.182351  1.2504691 1.3010877]
 [1.2235099 1.3269868 1.44387   1.5313154]
 [1.9988277 2.2544105 2.422819  2.6233253]
 [1.2865878 1.3851172 1.5153363 1.6147095]
 [2.251656  2.511316  2.7086565 2.8444064]
 [1.2837005 1.3759886 1.5044025 1.6025378]
 [1.3216758 1.4354705 1.5749004 1.6808898]
 [1.5365175 1.4983807 1.5241097 1.6575949]
 [1.7785938 2.0137718 2.0231822 2.3108385]
 [0.9720015 0.9641293 1.0103028 1.0649099]
 [2.4104152 2.6584172 2.850805  2.9960706]
 [2.4188344 2.6657777 2.8554735 3.005189 ]
 [2.376253  2.66


