In [1]:
import torch
import torch.nn as nn
import torch.nn.functional
import torch_geometric.nn 
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data
import pandas as pd
import networkx as nx
import numpy as np
import tqdm
import random
from torch_geometric.nn import GATv2Conv


In [2]:
pkl_graphs = pd.read_pickle('/Users/MathildeStouby/Desktop/P5 GitHub/5-semester/Momentum graphs.pkl')

In [3]:
unique_nodes = []
for graph in pkl_graphs.values():
    temp = [node for node in graph.nodes() if node not in unique_nodes]
    unique_nodes.extend(temp)

In [4]:
idx_to_pos = dict(enumerate(unique_nodes))
pos_to_idx = {pos : idx for idx, pos in idx_to_pos.items()}

In [5]:
pyg_data = []

#add node attributes
for graph in pkl_graphs.values():
    filtered_edges = [(u, v) for u, v, d in graph.edges(data=True) if d['weight'] > 3]
    filtered_graph = graph.edge_subgraph(filtered_edges)

    closeness = nx.closeness_centrality(filtered_graph)
    betweenness = nx.closeness_centrality(filtered_graph)
    pagerank = nx.pagerank(graph, weight='weight')
    centrality_list = [closeness, betweenness, pagerank] 

    adj_dict = nx.to_dict_of_dicts(graph)
    
    for node in list(graph.nodes()):
        adj_vect = np.zeros((len(unique_nodes)))
        players = adj_dict[node]
        for key, value in players.items():
            adj_vect[pos_to_idx[key]] = value['weight']
        adj_vect = torch.from_numpy(adj_vect).float()
        centrality_vect = []
        for measure in centrality_list:
            if node in list(measure.keys()):
                centrality_vect.append(measure[node])
            else:
                centrality_vect.append(0)
        centrality_vect = torch.Tensor(centrality_vect).float()        
        graph.nodes[node]['x'] = torch.cat((adj_vect, centrality_vect), -1)

    for node in unique_nodes:
        if node not in graph.nodes:
            graph.add_node(node) 
            graph.nodes[node]['x'] = torch.from_numpy(np.zeros((len(unique_nodes)+3))).float()  
            

    data = from_networkx(graph)
    try:
        data.momentum
        pyg_data.append(data)
    except:
        print(data)
   

Data(x=[23, 26], edge_index=[2, 92], weight=[92])


In [6]:
train_idx = random.sample(range(len(pyg_data)), int(len(pyg_data) * 0.8))
test_idx = [i for i in range(len(pyg_data)) if i not in train_idx]

In [7]:
class GAT(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, v2 = True):
        super(GAT, self).__init__()
        self.layer1= GATv2Conv(input_dim, hidden_dim, heads=num_heads)
        self.layer2= GATv2Conv(hidden_dim * num_heads, hidden_dim, heads=num_heads)
        self.layer3 = GATv2Conv(hidden_dim * num_heads, output_dim, heads=1, concat=False)
        self.activation_function = nn.Tanh()
     

    def forward(self, input, edge_index):
        output = self.activation_function(self.layer1(input, edge_index))
        output = self.activation_function(self.layer2(output, edge_index))
        output = self.layer3(output, edge_index)
        output = output.mean(dim=0) 
        return output

In [8]:
input_dim = len(unique_nodes)+3
lr = 0.01

gat = GAT(input_dim = input_dim, hidden_dim = 5, output_dim = 1, num_heads = 8)
optimizer = torch.optim.SGD(gat.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()
epochs_num = 20

In [9]:
for epoch in tqdm.tqdm(range(epochs_num)):
    
    epoch_loss = 0
    for idx in tqdm.tqdm(range(len(train_idx))):
        
        input = pyg_data[idx].x
        edge_idx = pyg_data[idx].edge_index
        label = pyg_data[idx]['momentum']
    
        # Forward pass
        optimizer.zero_grad()
        
        output = gat(input, edge_idx)
        
        # Calculate loss
        loss = loss_fn(output, label)
        loss.backward()
        # optimizer.step()
        for p in gat.parameters():
            p.data.add_(p.grad.data, alpha=-lr)
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(train_idx):.4f}") 
gat.eval() 

  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 434/434 [00:02<00:00, 178.33it/s]
  5%|▌         | 1/20 [00:02<00:46,  2.44s/it]

Epoch 1, Loss: 0.0108


100%|██████████| 434/434 [00:02<00:00, 161.05it/s]
 10%|█         | 2/20 [00:05<00:46,  2.59s/it]

Epoch 2, Loss: 0.0075


100%|██████████| 434/434 [00:02<00:00, 155.49it/s]
 15%|█▌        | 3/20 [00:07<00:45,  2.68s/it]

Epoch 3, Loss: 0.0071


100%|██████████| 434/434 [00:02<00:00, 200.84it/s]
 20%|██        | 4/20 [00:10<00:39,  2.48s/it]

Epoch 4, Loss: 0.0068


100%|██████████| 434/434 [00:02<00:00, 185.60it/s]
 25%|██▌       | 5/20 [00:12<00:36,  2.43s/it]

Epoch 5, Loss: 0.0067


100%|██████████| 434/434 [00:02<00:00, 160.41it/s]
 30%|███       | 6/20 [00:15<00:35,  2.52s/it]

Epoch 6, Loss: 0.0066


100%|██████████| 434/434 [00:02<00:00, 199.34it/s]
 35%|███▌      | 7/20 [00:17<00:31,  2.41s/it]

Epoch 7, Loss: 0.0065


100%|██████████| 434/434 [00:02<00:00, 205.48it/s]
 40%|████      | 8/20 [00:19<00:27,  2.32s/it]

Epoch 8, Loss: 0.0064


100%|██████████| 434/434 [00:02<00:00, 201.52it/s]
 45%|████▌     | 9/20 [00:21<00:24,  2.27s/it]

Epoch 9, Loss: 0.0063


100%|██████████| 434/434 [00:02<00:00, 195.49it/s]
 50%|█████     | 10/20 [00:23<00:22,  2.25s/it]

Epoch 10, Loss: 0.0063


100%|██████████| 434/434 [00:02<00:00, 180.17it/s]
 55%|█████▌    | 11/20 [00:26<00:20,  2.30s/it]

Epoch 11, Loss: 0.0063


100%|██████████| 434/434 [00:02<00:00, 207.13it/s]
 60%|██████    | 12/20 [00:28<00:17,  2.24s/it]

Epoch 12, Loss: 0.0062


100%|██████████| 434/434 [00:02<00:00, 193.89it/s]
 65%|██████▌   | 13/20 [00:30<00:15,  2.24s/it]

Epoch 13, Loss: 0.0062


100%|██████████| 434/434 [00:02<00:00, 201.23it/s]
 70%|███████   | 14/20 [00:32<00:13,  2.22s/it]

Epoch 14, Loss: 0.0062


100%|██████████| 434/434 [00:02<00:00, 197.38it/s]
 75%|███████▌  | 15/20 [00:34<00:11,  2.21s/it]

Epoch 15, Loss: 0.0062


100%|██████████| 434/434 [00:02<00:00, 191.97it/s]
 80%|████████  | 16/20 [00:37<00:08,  2.23s/it]

Epoch 16, Loss: 0.0061


100%|██████████| 434/434 [00:02<00:00, 210.33it/s]
 85%|████████▌ | 17/20 [00:39<00:06,  2.18s/it]

Epoch 17, Loss: 0.0061


100%|██████████| 434/434 [00:02<00:00, 198.34it/s]
 90%|█████████ | 18/20 [00:41<00:04,  2.18s/it]

Epoch 18, Loss: 0.0061


100%|██████████| 434/434 [00:02<00:00, 180.64it/s]
 95%|█████████▌| 19/20 [00:43<00:02,  2.25s/it]

Epoch 19, Loss: 0.0061


100%|██████████| 434/434 [00:02<00:00, 188.47it/s]
100%|██████████| 20/20 [00:46<00:00,  2.31s/it]

Epoch 20, Loss: 0.0061





GAT(
  (layer1): GATv2Conv(26, 5, heads=8)
  (layer2): GATv2Conv(40, 5, heads=8)
  (layer3): GATv2Conv(40, 1, heads=1)
  (activation_function): Tanh()
)

In [12]:
y_pred  = []
y_true = []

with torch.no_grad():
    for idx in test_idx:
        output = gat(pyg_data[idx].x, pyg_data[idx].edge_index)
        y_pred.append(output)
        y_true.append(pyg_data[idx].momentum)
 

In [13]:
from sklearn.metrics import mean_absolute_error

mean_absolute_error(y_true, y_pred)

0.05980244