In [23]:
import torch
import torch.nn as nn
import torch.nn.functional
import torch_geometric.nn 
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data
import pandas as pd
import networkx as nx
import numpy as np
import tqdm
import random
from torch_geometric.nn import GATConv


In [11]:
pkl_graphs = pd.read_pickle('/Users/MathildeStouby/Desktop/P5 GitHub/5-semester/Momentum graphs.pkl')

In [12]:
unique_nodes = []
for graph in pkl_graphs.values():
    temp = [node for node in graph.nodes() if node not in unique_nodes]
    unique_nodes.extend(temp)

In [13]:
idx_to_pos = dict(enumerate(unique_nodes))
pos_to_idx = {pos : idx for idx, pos in idx_to_pos.items()}

In [14]:
pyg_data = []
thors_kamp = []


for idx, graph in pkl_graphs.items():

    closeness = nx.closeness_centrality(graph)
    betweenness = nx.closeness_centrality(graph)
    pagerank = nx.pagerank(graph, weight='weight')
    centrality_list = [closeness, betweenness, pagerank] 

    #adj_dict = nx.to_dict_of_dicts(graph)

    for node in list(graph.nodes()):
        #adj_vect = np.zeros((len(unique_nodes)))
        #players = adj_dict[node]
        #for key, value in players.items():
            #adj_vect[pos_to_idx[key]] = value['weight']
        #adj_vect = torch.from_numpy(adj_vect).float()
        centrality_vect = []
        for measure in centrality_list:
            if node in list(measure.keys()):
                centrality_vect.append(measure[node])
            else:
                centrality_vect.append(0)
        centrality_vect = torch.Tensor(centrality_vect).float()        
        #graph.nodes[node]['x'] = torch.cat((adj_vect, centrality_vect), -1)
        graph.nodes[node]['x'] = centrality_vect

   
    for node in unique_nodes:
        if node not in graph.nodes:
            graph.add_node(node) 
            #graph.nodes[node]['x'] = torch.from_numpy(np.zeros((len(unique_nodes)+3))).float()
            graph.nodes[node]['x'] = torch.from_numpy(np.zeros(3)).float()  
            

    data = from_networkx(graph)

    try:
        data['momentum']
        if idx.startswith('3895275'):
            thors_kamp.append(data)
        else:
            pyg_data.append(data)
    except:
        print(data)

Data(x=[23, 3], edge_index=[2, 22], weight=[22])


In [15]:
train_idx = random.sample(range(len(pyg_data)), int(len(pyg_data) * 0.8))
test_idx = [i for i in range(len(pyg_data)) if i not in train_idx]

In [21]:
class GAT(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1= GATConv(input_dim, hidden_dim, heads=num_heads, edge_dim=1)
        self.layer2= GATConv(hidden_dim * num_heads, hidden_dim, heads=num_heads, edge_dim=1)
        self.layer3 = GATConv(hidden_dim * num_heads, output_dim, heads=1, edge_dim=1, concat=False)
        self.activation_function = nn.ReLU()
     

    def forward(self, input, edge_index, edge_attr):
        output = self.activation_function(self.layer1(input, edge_index, edge_attr))
        output = self.activation_function(self.layer2(output, edge_index, edge_attr))
        output = self.layer3(output, edge_index, edge_attr)
        output = torch.mean(output)
        return output

In [24]:
input_dim = 3
lr = 0.01

gat = GAT(input_dim = input_dim, hidden_dim = 11, output_dim = 1, num_heads = 6)
optimizer = torch.optim.SGD(gat.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()
epochs_num = 100

In [25]:
for epoch in tqdm.tqdm(range(epochs_num)):
    
    epoch_loss = 0
    for idx in tqdm.tqdm(range(len(train_idx))):
        
        input = pyg_data[idx].x
        edge_idx = pyg_data[idx].edge_index
        edge_attr = pyg_data[idx].weight
        label = pyg_data[idx]['momentum']
    
        # Forward pass
        optimizer.zero_grad()
        
        output = gat(input, edge_idx, edge_attr)
        
        # Calculate loss
        loss = loss_fn(output, label)
        loss.backward()
        # optimizer.step()
        for p in gat.parameters():
            p.data.add_(p.grad.data, alpha=-lr)
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(train_idx):.4f}") 
gat.eval() 

100%|██████████| 421/421 [00:02<00:00, 175.47it/s]
  1%|          | 1/100 [00:02<03:57,  2.40s/it]

Epoch 1, Loss: 0.0067


100%|██████████| 421/421 [00:01<00:00, 225.63it/s]
  2%|▏         | 2/100 [00:04<03:24,  2.09s/it]

Epoch 2, Loss: 0.0066


100%|██████████| 421/421 [00:01<00:00, 230.40it/s]
  3%|▎         | 3/100 [00:06<03:11,  1.97s/it]

Epoch 3, Loss: 0.0066


100%|██████████| 421/421 [00:01<00:00, 226.60it/s]
  4%|▍         | 4/100 [00:07<03:05,  1.93s/it]

Epoch 4, Loss: 0.0066


100%|██████████| 421/421 [00:01<00:00, 235.57it/s]
  5%|▌         | 5/100 [00:09<02:58,  1.88s/it]

Epoch 5, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 236.94it/s]
  6%|▌         | 6/100 [00:11<02:53,  1.84s/it]

Epoch 6, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 227.65it/s]
  7%|▋         | 7/100 [00:13<02:51,  1.85s/it]

Epoch 7, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 235.34it/s]
  8%|▊         | 8/100 [00:15<02:48,  1.83s/it]

Epoch 8, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 228.21it/s]
  9%|▉         | 9/100 [00:17<02:46,  1.83s/it]

Epoch 9, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 234.78it/s]
 10%|█         | 10/100 [00:18<02:44,  1.82s/it]

Epoch 10, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 235.77it/s]
 11%|█         | 11/100 [00:20<02:41,  1.81s/it]

Epoch 11, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 211.95it/s]
 12%|█▏        | 12/100 [00:22<02:44,  1.87s/it]

Epoch 12, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 211.20it/s]
 13%|█▎        | 13/100 [00:24<02:45,  1.91s/it]

Epoch 13, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 223.00it/s]
 14%|█▍        | 14/100 [00:26<02:43,  1.90s/it]

Epoch 14, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 233.45it/s]
 15%|█▌        | 15/100 [00:28<02:39,  1.87s/it]

Epoch 15, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 232.47it/s]
 16%|█▌        | 16/100 [00:30<02:35,  1.85s/it]

Epoch 16, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 240.01it/s]
 17%|█▋        | 17/100 [00:31<02:31,  1.83s/it]

Epoch 17, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 235.15it/s]
 18%|█▊        | 18/100 [00:33<02:28,  1.82s/it]

Epoch 18, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 231.88it/s]
 19%|█▉        | 19/100 [00:35<02:27,  1.82s/it]

Epoch 19, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 233.93it/s]
 20%|██        | 20/100 [00:37<02:24,  1.81s/it]

Epoch 20, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 236.60it/s]
 21%|██        | 21/100 [00:39<02:22,  1.80s/it]

Epoch 21, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 240.51it/s]
 22%|██▏       | 22/100 [00:40<02:19,  1.79s/it]

Epoch 22, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 234.82it/s]
 23%|██▎       | 23/100 [00:42<02:17,  1.79s/it]

Epoch 23, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 229.71it/s]
 24%|██▍       | 24/100 [00:44<02:17,  1.80s/it]

Epoch 24, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 233.20it/s]
 25%|██▌       | 25/100 [00:46<02:15,  1.81s/it]

Epoch 25, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 232.81it/s]
 26%|██▌       | 26/100 [00:48<02:13,  1.81s/it]

Epoch 26, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 237.64it/s]
 27%|██▋       | 27/100 [00:49<02:11,  1.80s/it]

Epoch 27, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 232.35it/s]
 28%|██▊       | 28/100 [00:51<02:09,  1.80s/it]

Epoch 28, Loss: 0.0065


100%|██████████| 421/421 [00:01<00:00, 233.40it/s]
 29%|██▉       | 29/100 [00:53<02:08,  1.80s/it]

Epoch 29, Loss: 0.0065


100%|██████████| 421/421 [00:02<00:00, 201.21it/s]
 30%|███       | 30/100 [00:55<02:12,  1.89s/it]

Epoch 30, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 232.82it/s]
 31%|███       | 31/100 [00:57<02:08,  1.87s/it]

Epoch 31, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 223.35it/s]
 32%|███▏      | 32/100 [00:59<02:07,  1.87s/it]

Epoch 32, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 232.95it/s]
 33%|███▎      | 33/100 [01:01<02:04,  1.85s/it]

Epoch 33, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 236.53it/s]
 34%|███▍      | 34/100 [01:02<02:00,  1.83s/it]

Epoch 34, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 235.91it/s]
 35%|███▌      | 35/100 [01:04<01:58,  1.82s/it]

Epoch 35, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 219.08it/s]
 36%|███▌      | 36/100 [01:06<01:58,  1.85s/it]

Epoch 36, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 227.74it/s]
 37%|███▋      | 37/100 [01:08<01:56,  1.85s/it]

Epoch 37, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 198.29it/s]
 38%|███▊      | 38/100 [01:10<01:59,  1.93s/it]

Epoch 38, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 199.47it/s]
 39%|███▉      | 39/100 [01:12<02:01,  1.99s/it]

Epoch 39, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 201.51it/s]
 40%|████      | 40/100 [01:14<02:01,  2.02s/it]

Epoch 40, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 223.60it/s]
 41%|████      | 41/100 [01:16<01:56,  1.98s/it]

Epoch 41, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 208.83it/s]
 42%|████▏     | 42/100 [01:18<01:55,  1.99s/it]

Epoch 42, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 233.85it/s]
 43%|████▎     | 43/100 [01:20<01:50,  1.93s/it]

Epoch 43, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 210.33it/s]
 44%|████▍     | 44/100 [01:22<01:49,  1.96s/it]

Epoch 44, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 193.69it/s]
 45%|████▌     | 45/100 [01:24<01:51,  2.02s/it]

Epoch 45, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 202.38it/s]
 46%|████▌     | 46/100 [01:26<01:50,  2.04s/it]

Epoch 46, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 200.81it/s]
 47%|████▋     | 47/100 [01:28<01:49,  2.06s/it]

Epoch 47, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 201.89it/s]
 48%|████▊     | 48/100 [01:30<01:47,  2.07s/it]

Epoch 48, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 232.52it/s]
 49%|████▉     | 49/100 [01:32<01:41,  1.99s/it]

Epoch 49, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 236.82it/s]
 50%|█████     | 50/100 [01:34<01:36,  1.93s/it]

Epoch 50, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 203.17it/s]
 51%|█████     | 51/100 [01:36<01:36,  1.97s/it]

Epoch 51, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 216.85it/s]
 52%|█████▏    | 52/100 [01:38<01:34,  1.96s/it]

Epoch 52, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 206.67it/s]
 53%|█████▎    | 53/100 [01:40<01:33,  1.99s/it]

Epoch 53, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 202.85it/s]
 54%|█████▍    | 54/100 [01:42<01:32,  2.01s/it]

Epoch 54, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 187.31it/s]
 55%|█████▌    | 55/100 [01:44<01:33,  2.09s/it]

Epoch 55, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 187.70it/s]
 56%|█████▌    | 56/100 [01:47<01:33,  2.13s/it]

Epoch 56, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 232.76it/s]
 57%|█████▋    | 57/100 [01:48<01:27,  2.04s/it]

Epoch 57, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 229.13it/s]
 58%|█████▊    | 58/100 [01:50<01:23,  1.98s/it]

Epoch 58, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 218.17it/s]
 59%|█████▉    | 59/100 [01:52<01:20,  1.96s/it]

Epoch 59, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 191.10it/s]
 60%|██████    | 60/100 [01:54<01:21,  2.04s/it]

Epoch 60, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 197.36it/s]
 61%|██████    | 61/100 [01:57<01:20,  2.07s/it]

Epoch 61, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 189.63it/s]
 62%|██████▏   | 62/100 [01:59<01:20,  2.11s/it]

Epoch 62, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 206.49it/s]
 63%|██████▎   | 63/100 [02:01<01:17,  2.09s/it]

Epoch 63, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 221.49it/s]
 64%|██████▍   | 64/100 [02:03<01:13,  2.04s/it]

Epoch 64, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 227.26it/s]
 65%|██████▌   | 65/100 [02:05<01:09,  1.98s/it]

Epoch 65, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 217.21it/s]
 66%|██████▌   | 66/100 [02:07<01:06,  1.97s/it]

Epoch 66, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 224.84it/s]
 67%|██████▋   | 67/100 [02:08<01:04,  1.94s/it]

Epoch 67, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 224.58it/s]
 68%|██████▊   | 68/100 [02:10<01:01,  1.92s/it]

Epoch 68, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 222.44it/s]
 69%|██████▉   | 69/100 [02:12<00:59,  1.91s/it]

Epoch 69, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 224.08it/s]
 70%|███████   | 70/100 [02:14<00:57,  1.90s/it]

Epoch 70, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 223.24it/s]
 71%|███████   | 71/100 [02:16<00:55,  1.90s/it]

Epoch 71, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 222.60it/s]
 72%|███████▏  | 72/100 [02:18<00:53,  1.90s/it]

Epoch 72, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 208.77it/s]
 73%|███████▎  | 73/100 [02:20<00:52,  1.93s/it]

Epoch 73, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 201.45it/s]
 74%|███████▍  | 74/100 [02:22<00:51,  1.98s/it]

Epoch 74, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 194.99it/s]
 75%|███████▌  | 75/100 [02:24<00:50,  2.04s/it]

Epoch 75, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 185.21it/s]
 76%|███████▌  | 76/100 [02:26<00:50,  2.11s/it]

Epoch 76, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 184.88it/s]
 77%|███████▋  | 77/100 [02:29<00:49,  2.16s/it]

Epoch 77, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 188.89it/s]
 78%|███████▊  | 78/100 [02:31<00:47,  2.18s/it]

Epoch 78, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 188.33it/s]
 79%|███████▉  | 79/100 [02:33<00:46,  2.20s/it]

Epoch 79, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 208.08it/s]
 80%|████████  | 80/100 [02:35<00:42,  2.15s/it]

Epoch 80, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 192.40it/s]
 81%|████████  | 81/100 [02:37<00:41,  2.16s/it]

Epoch 81, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 168.21it/s]
 82%|████████▏ | 82/100 [02:40<00:40,  2.26s/it]

Epoch 82, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 164.54it/s]
 83%|████████▎ | 83/100 [02:42<00:40,  2.35s/it]

Epoch 83, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 209.28it/s]
 84%|████████▍ | 84/100 [02:44<00:36,  2.25s/it]

Epoch 84, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 211.51it/s]
 85%|████████▌ | 85/100 [02:46<00:32,  2.17s/it]

Epoch 85, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 202.61it/s]
 86%|████████▌ | 86/100 [02:49<00:30,  2.15s/it]

Epoch 86, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 208.95it/s]
 87%|████████▋ | 87/100 [02:51<00:27,  2.11s/it]

Epoch 87, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 212.01it/s]
 88%|████████▊ | 88/100 [02:53<00:24,  2.07s/it]

Epoch 88, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 204.99it/s]
 89%|████████▉ | 89/100 [02:55<00:22,  2.07s/it]

Epoch 89, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 196.08it/s]
 90%|█████████ | 90/100 [02:57<00:20,  2.09s/it]

Epoch 90, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 191.23it/s]
 91%|█████████ | 91/100 [02:59<00:19,  2.13s/it]

Epoch 91, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 195.98it/s]
 92%|█████████▏| 92/100 [03:01<00:17,  2.13s/it]

Epoch 92, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 202.95it/s]
 93%|█████████▎| 93/100 [03:03<00:14,  2.12s/it]

Epoch 93, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 218.16it/s]
 94%|█████████▍| 94/100 [03:05<00:12,  2.06s/it]

Epoch 94, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 220.28it/s]
 95%|█████████▌| 95/100 [03:07<00:10,  2.02s/it]

Epoch 95, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 225.06it/s]
 96%|█████████▌| 96/100 [03:09<00:07,  1.97s/it]

Epoch 96, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 214.56it/s]
 97%|█████████▋| 97/100 [03:11<00:05,  1.97s/it]

Epoch 97, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 206.29it/s]
 98%|█████████▊| 98/100 [03:13<00:03,  1.99s/it]

Epoch 98, Loss: 0.0064


100%|██████████| 421/421 [00:01<00:00, 218.38it/s]
 99%|█████████▉| 99/100 [03:15<00:01,  1.97s/it]

Epoch 99, Loss: 0.0064


100%|██████████| 421/421 [00:02<00:00, 190.13it/s]
100%|██████████| 100/100 [03:17<00:00,  1.98s/it]

Epoch 100, Loss: 0.0064





GAT(
  (layer1): GATConv(3, 11, heads=6)
  (layer2): GATConv(66, 11, heads=6)
  (layer3): GATConv(66, 1, heads=1)
  (activation_function): ReLU()
)

In [26]:
y_pred  = []
y_true = []

with torch.no_grad():
    for idx in test_idx:
        batch = torch.zeros(pyg_data[idx].x.size(0), dtype=torch.long)
        output = gat(pyg_data[idx].x, pyg_data[idx].edge_index, pyg_data[idx].edge_attr)
        y_pred.append(output.item())
        y_true.append(pyg_data[idx].momentum)
        
from sklearn.metrics import r2_score
r2_score(y_true, y_pred)

0.01789867722752314