In [2]:
import torch
import torch.nn as nn
import torch.nn.functional
import torch_geometric.nn 
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data
import pandas as pd
import networkx as nx
import numpy as np
import tqdm
import random
from torch_geometric.nn import GATv2Conv, global_mean_pool
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader


In [3]:
pkl_graphs = pd.read_pickle('/Users/MathildeStouby/Desktop/P5 GitHub/5-semester/Momentum graphs.pkl')

In [4]:
unique_nodes = []
for value in pkl_graphs.values():
    graph = value['graph']

    temp = [node for node in graph.nodes() if node not in unique_nodes]
    unique_nodes.extend(temp)

In [5]:
idx_to_pos = dict(enumerate(unique_nodes))
pos_to_idx = {pos : idx for idx, pos in idx_to_pos.items()}

In [6]:
pyg_data = []

#add node attributes
for value in pkl_graphs.values():
    momentum = value['momentum']

    graph = value['graph']

    closeness = nx.closeness_centrality(graph)
    betweenness = nx.closeness_centrality(graph)
    pagerank = nx.pagerank(graph, weight='weight')
    centrality_list = [closeness, betweenness, pagerank] 

    adj_dict = nx.to_dict_of_dicts(graph)

    for node in list(graph.nodes()):
        adj_vect = np.zeros((len(unique_nodes)))
        players = adj_dict[node]
        for key, value in players.items():
            adj_vect[pos_to_idx[key]] = value['weight']
        adj_vect = torch.from_numpy(adj_vect).float()
        centrality_vect = []
        for measure in centrality_list:
            if node in list(measure.keys()):
                centrality_vect.append(measure[node])
            else:
                centrality_vect.append(0)
        centrality_vect = torch.Tensor(centrality_vect).float()        
        #graph.nodes[node]['x'] = torch.cat((adj_vect, centrality_vect), -1)
        graph.nodes[node]['x'] = centrality_vect

   
    for node in unique_nodes:
        if node not in graph.nodes:
            graph.add_node(node) 
            #graph.nodes[node]['x'] = torch.from_numpy(np.zeros((len(unique_nodes)+3))).float()
            graph.nodes[node]['x'] = torch.from_numpy(np.zeros(3)).float()  
            

    data = from_networkx(graph)
    data['y'] = momentum
    try:
        data['weight']
        pyg_data.append(data)
    except:
        print(data)
   

Data(x=[23, 3], edge_index=[2, 0], y=-0.0141746799999999)
Data(x=[23, 3], edge_index=[2, 0], y=0.00802374)


In [7]:
train_idx = random.sample(range(len(pyg_data)), int(len(pyg_data) * 0.8))
test_idx = [i for i in range(len(pyg_data)) if i not in train_idx]

In [8]:
train_data = [pyg_data[idx] for idx in train_idx]
dataloader = DataLoader(train_data, batch_size=64, shuffle=True)

In [9]:
class GAT(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, v2 = True):
        super(GAT, self).__init__()
        self.layer1= GATv2Conv(input_dim, hidden_dim, heads=num_heads)
        self.layer2= GATv2Conv(hidden_dim * num_heads, hidden_dim, heads=num_heads)
        self.layer3= GATv2Conv(hidden_dim * num_heads, hidden_dim, heads=num_heads)
        self.layer4= GATv2Conv(hidden_dim * num_heads, hidden_dim, heads=num_heads)
        self.layer5 = GATv2Conv(hidden_dim * num_heads, output_dim, heads=1, concat=False)
        self.activation_function = nn.ReLU()
     

    def forward(self, input, edge_index, batch):
        output = self.activation_function(self.layer1(input, edge_index))
        output = self.activation_function(self.layer2(output, edge_index))
        output = self.activation_function(self.layer3(output, edge_index))
        output = self.activation_function(self.layer4(output, edge_index))
        output = self.layer5(output, edge_index)
        output = global_mean_pool(output, batch)
        return output

In [10]:
#input_dim = len(unique_nodes)+3
input_dim = 3
lr = 0.01

gat = GAT(input_dim = input_dim, hidden_dim = 11, output_dim = 1, num_heads = 6)
optimizer = torch.optim.SGD(gat.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()
epochs_num = 100

In [11]:
for epoch in tqdm.tqdm(range(epochs_num)):
    
    epoch_loss = 0
    for batch in tqdm.tqdm(dataloader):
        # Forward pass
        optimizer.zero_grad()
        
        output = gat(batch.x, batch.edge_index, batch.batch)
        
        # Calculate loss
        loss = loss_fn(output, batch.y)
        loss.backward()

        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(dataloader):.4f}") 
gat.eval() 

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 25/25 [00:01<00:00, 20.34it/s]
  1%|          | 1/100 [00:01<02:02,  1.23s/it]

Epoch 1, Loss: 0.0002


100%|██████████| 25/25 [00:00<00:00, 32.28it/s]
  2%|▏         | 2/100 [00:02<01:34,  1.04it/s]

Epoch 2, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.41it/s]
  3%|▎         | 3/100 [00:02<01:28,  1.09it/s]

Epoch 3, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.30it/s]
  4%|▍         | 4/100 [00:03<01:22,  1.16it/s]

Epoch 4, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.92it/s]
  5%|▌         | 5/100 [00:04<01:20,  1.17it/s]

Epoch 5, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.93it/s]
  6%|▌         | 6/100 [00:05<01:18,  1.20it/s]

Epoch 6, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.96it/s]
  7%|▋         | 7/100 [00:06<01:16,  1.21it/s]

Epoch 7, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.88it/s]
  8%|▊         | 8/100 [00:06<01:14,  1.23it/s]

Epoch 8, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.55it/s]
  9%|▉         | 9/100 [00:07<01:14,  1.21it/s]

Epoch 9, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.58it/s]
 10%|█         | 10/100 [00:08<01:12,  1.24it/s]

Epoch 10, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.40it/s]
 11%|█         | 11/100 [00:09<01:12,  1.23it/s]

Epoch 11, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.68it/s]
 12%|█▏        | 12/100 [00:10<01:10,  1.25it/s]

Epoch 12, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.81it/s]
 13%|█▎        | 13/100 [00:10<01:08,  1.27it/s]

Epoch 13, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.85it/s]
 14%|█▍        | 14/100 [00:11<01:08,  1.26it/s]

Epoch 14, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.85it/s]
 15%|█▌        | 15/100 [00:12<01:08,  1.24it/s]

Epoch 15, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.98it/s]
 16%|█▌        | 16/100 [00:13<01:07,  1.25it/s]

Epoch 16, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.02it/s]
 17%|█▋        | 17/100 [00:14<01:08,  1.22it/s]

Epoch 17, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.62it/s]
 18%|█▊        | 18/100 [00:14<01:06,  1.24it/s]

Epoch 18, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.81it/s]
 19%|█▉        | 19/100 [00:15<01:05,  1.24it/s]

Epoch 19, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.75it/s]
 20%|██        | 20/100 [00:16<01:04,  1.23it/s]

Epoch 20, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.20it/s]
 21%|██        | 21/100 [00:17<01:03,  1.24it/s]

Epoch 21, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.90it/s]
 22%|██▏       | 22/100 [00:18<01:01,  1.26it/s]

Epoch 22, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 28.90it/s]
 23%|██▎       | 23/100 [00:18<01:02,  1.22it/s]

Epoch 23, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 28.12it/s]
 24%|██▍       | 24/100 [00:19<01:03,  1.19it/s]

Epoch 24, Loss: 0.0001


100%|██████████| 25/25 [00:01<00:00, 22.40it/s]
 25%|██▌       | 25/100 [00:20<01:09,  1.08it/s]

Epoch 25, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.15it/s]
 26%|██▌       | 26/100 [00:21<01:06,  1.11it/s]

Epoch 26, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 28.74it/s]
 27%|██▋       | 27/100 [00:22<01:05,  1.12it/s]

Epoch 27, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 25.58it/s]
 28%|██▊       | 28/100 [00:23<01:06,  1.09it/s]

Epoch 28, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 26.93it/s]
 29%|██▉       | 29/100 [00:24<01:05,  1.08it/s]

Epoch 29, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.72it/s]
 30%|███       | 30/100 [00:25<01:02,  1.12it/s]

Epoch 30, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.09it/s]
 31%|███       | 31/100 [00:26<01:00,  1.14it/s]

Epoch 31, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.50it/s]
 32%|███▏      | 32/100 [00:27<00:57,  1.19it/s]

Epoch 32, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.46it/s]
 33%|███▎      | 33/100 [00:27<00:56,  1.18it/s]

Epoch 33, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.36it/s]
 34%|███▍      | 34/100 [00:28<00:54,  1.21it/s]

Epoch 34, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.74it/s]
 35%|███▌      | 35/100 [00:29<00:53,  1.20it/s]

Epoch 35, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.61it/s]
 36%|███▌      | 36/100 [00:30<00:52,  1.22it/s]

Epoch 36, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.21it/s]
 37%|███▋      | 37/100 [00:31<00:51,  1.22it/s]

Epoch 37, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.98it/s]
 38%|███▊      | 38/100 [00:31<00:49,  1.24it/s]

Epoch 38, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 28.55it/s]
 39%|███▉      | 39/100 [00:32<00:50,  1.21it/s]

Epoch 39, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.93it/s]
 40%|████      | 40/100 [00:33<00:48,  1.23it/s]

Epoch 40, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.15it/s]
 41%|████      | 41/100 [00:34<00:48,  1.22it/s]

Epoch 41, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 35.57it/s]
 42%|████▏     | 42/100 [00:35<00:45,  1.27it/s]

Epoch 42, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.69it/s]
 43%|████▎     | 43/100 [00:35<00:45,  1.24it/s]

Epoch 43, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 34.63it/s]
 44%|████▍     | 44/100 [00:36<00:43,  1.28it/s]

Epoch 44, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.44it/s]
 45%|████▌     | 45/100 [00:37<00:42,  1.29it/s]

Epoch 45, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 34.34it/s]
 46%|████▌     | 46/100 [00:38<00:41,  1.31it/s]

Epoch 46, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.45it/s]
 47%|████▋     | 47/100 [00:39<00:41,  1.28it/s]

Epoch 47, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.80it/s]
 48%|████▊     | 48/100 [00:39<00:40,  1.29it/s]

Epoch 48, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.00it/s]
 49%|████▉     | 49/100 [00:40<00:40,  1.26it/s]

Epoch 49, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 26.67it/s]
 50%|█████     | 50/100 [00:41<00:41,  1.19it/s]

Epoch 50, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.33it/s]
 51%|█████     | 51/100 [00:42<00:41,  1.19it/s]

Epoch 51, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.68it/s]
 52%|█████▏    | 52/100 [00:43<00:39,  1.22it/s]

Epoch 52, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 28.35it/s]
 53%|█████▎    | 53/100 [00:44<00:39,  1.19it/s]

Epoch 53, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.90it/s]
 54%|█████▍    | 54/100 [00:44<00:38,  1.19it/s]

Epoch 54, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.43it/s]
 55%|█████▌    | 55/100 [00:45<00:37,  1.20it/s]

Epoch 55, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 33.90it/s]
 56%|█████▌    | 56/100 [00:46<00:35,  1.24it/s]

Epoch 56, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.89it/s]
 57%|█████▋    | 57/100 [00:47<00:34,  1.25it/s]

Epoch 57, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.86it/s]
 58%|█████▊    | 58/100 [00:48<00:33,  1.25it/s]

Epoch 58, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.70it/s]
 59%|█████▉    | 59/100 [00:48<00:33,  1.23it/s]

Epoch 59, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.93it/s]
 60%|██████    | 60/100 [00:49<00:32,  1.24it/s]

Epoch 60, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.88it/s]
 61%|██████    | 61/100 [00:50<00:31,  1.24it/s]

Epoch 61, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 33.21it/s]
 62%|██████▏   | 62/100 [00:51<00:30,  1.26it/s]

Epoch 62, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.53it/s]
 63%|██████▎   | 63/100 [00:52<00:29,  1.26it/s]

Epoch 63, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.25it/s]
 64%|██████▍   | 64/100 [00:52<00:28,  1.26it/s]

Epoch 64, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.02it/s]
 65%|██████▌   | 65/100 [00:53<00:27,  1.25it/s]

Epoch 65, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 35.31it/s]
 66%|██████▌   | 66/100 [00:54<00:26,  1.29it/s]

Epoch 66, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.18it/s]
 67%|██████▋   | 67/100 [00:55<00:26,  1.26it/s]

Epoch 67, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.22it/s]
 68%|██████▊   | 68/100 [00:55<00:25,  1.26it/s]

Epoch 68, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.41it/s]
 69%|██████▉   | 69/100 [00:56<00:24,  1.24it/s]

Epoch 69, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.75it/s]
 70%|███████   | 70/100 [00:57<00:23,  1.25it/s]

Epoch 70, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.24it/s]
 71%|███████   | 71/100 [00:58<00:23,  1.24it/s]

Epoch 71, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 34.11it/s]
 72%|███████▏  | 72/100 [00:59<00:22,  1.27it/s]

Epoch 72, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.59it/s]
 73%|███████▎  | 73/100 [00:59<00:21,  1.27it/s]

Epoch 73, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 33.43it/s]
 74%|███████▍  | 74/100 [01:00<00:20,  1.29it/s]

Epoch 74, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.14it/s]
 75%|███████▌  | 75/100 [01:01<00:19,  1.26it/s]

Epoch 75, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.25it/s]
 76%|███████▌  | 76/100 [01:02<00:18,  1.27it/s]

Epoch 76, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.21it/s]
 77%|███████▋  | 77/100 [01:03<00:18,  1.25it/s]

Epoch 77, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 34.19it/s]
 78%|███████▊  | 78/100 [01:03<00:17,  1.28it/s]

Epoch 78, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.75it/s]
 79%|███████▉  | 79/100 [01:04<00:16,  1.26it/s]

Epoch 79, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.96it/s]
 80%|████████  | 80/100 [01:05<00:15,  1.27it/s]

Epoch 80, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.78it/s]
 81%|████████  | 81/100 [01:06<00:15,  1.24it/s]

Epoch 81, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.23it/s]
 82%|████████▏ | 82/100 [01:07<00:14,  1.25it/s]

Epoch 82, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.53it/s]
 83%|████████▎ | 83/100 [01:07<00:13,  1.24it/s]

Epoch 83, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.49it/s]
 84%|████████▍ | 84/100 [01:08<00:12,  1.23it/s]

Epoch 84, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.14it/s]
 85%|████████▌ | 85/100 [01:09<00:12,  1.21it/s]

Epoch 85, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.58it/s]
 86%|████████▌ | 86/100 [01:10<00:11,  1.24it/s]

Epoch 86, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.90it/s]
 87%|████████▋ | 87/100 [01:11<00:10,  1.22it/s]

Epoch 87, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.79it/s]
 88%|████████▊ | 88/100 [01:12<00:09,  1.24it/s]

Epoch 88, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.60it/s]
 89%|████████▉ | 89/100 [01:12<00:08,  1.23it/s]

Epoch 89, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.73it/s]
 90%|█████████ | 90/100 [01:13<00:07,  1.25it/s]

Epoch 90, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.71it/s]
 91%|█████████ | 91/100 [01:14<00:07,  1.23it/s]

Epoch 91, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 31.97it/s]
 92%|█████████▏| 92/100 [01:15<00:06,  1.24it/s]

Epoch 92, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.88it/s]
 93%|█████████▎| 93/100 [01:16<00:05,  1.24it/s]

Epoch 93, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 34.53it/s]
 94%|█████████▍| 94/100 [01:16<00:04,  1.28it/s]

Epoch 94, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 30.35it/s]
 95%|█████████▌| 95/100 [01:17<00:03,  1.26it/s]

Epoch 95, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.46it/s]
 96%|█████████▌| 96/100 [01:18<00:03,  1.27it/s]

Epoch 96, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 28.80it/s]
 97%|█████████▋| 97/100 [01:19<00:02,  1.23it/s]

Epoch 97, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.02it/s]
 98%|█████████▊| 98/100 [01:20<00:01,  1.24it/s]

Epoch 98, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 29.55it/s]
 99%|█████████▉| 99/100 [01:20<00:00,  1.22it/s]

Epoch 99, Loss: 0.0001


100%|██████████| 25/25 [00:00<00:00, 32.10it/s]
100%|██████████| 100/100 [01:21<00:00,  1.22it/s]

Epoch 100, Loss: 0.0001





GAT(
  (layer1): GATv2Conv(3, 11, heads=6)
  (layer2): GATv2Conv(66, 11, heads=6)
  (layer3): GATv2Conv(66, 11, heads=6)
  (layer4): GATv2Conv(66, 11, heads=6)
  (layer5): GATv2Conv(66, 1, heads=1)
  (activation_function): ReLU()
)

In [12]:
from sklearn.metrics import mean_absolute_percentage_error

y_pred  = []
y_true = []
batch = torch.from_numpy(np.zeros(23)).long()

with torch.no_grad():
    for idx in test_idx:
        batch = torch.zeros(pyg_data[idx].x.size(0), dtype=torch.long)
        output = gat(pyg_data[idx].x, pyg_data[idx].edge_index, batch)
        y_pred.append(output.item())  # Convert to numpy for MAPE
        y_true.append(pyg_data[idx].y)

mean_absolute_percentage_error(y_true, y_pred)

1.0931291884958338