In [55]:
import torch
import torch.nn.functional
import torch_geometric.nn 
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data
import pandas as pd
import networkx as nx
import numpy as np
import tqdm
import random
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_mean_pool
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import init
from torch_geometric.data import DataLoader


In [56]:
pkl_graphs = pd.read_pickle('/Users/morten/Desktop/p5 kode/5-semester/Momentum graphs Bayern only.pkl')

In [57]:
unique_nodes = []
delete_idx = []
for idx, graph in pkl_graphs.items():
    if idx.endswith('45'):
        delete_idx.append(idx)
        continue
    temp = [node for node in graph.nodes() if node not in unique_nodes]
    unique_nodes.extend(temp)
for idx in delete_idx:
    del pkl_graphs[idx]


In [58]:
idx_to_pos = dict(enumerate(unique_nodes))
pos_to_idx = {pos : idx for idx, pos in idx_to_pos.items()}

In [59]:
pyg_data = []

#add node attributes
for graph in pkl_graphs.values():
    filtered_edges = [(u, v) for u, v, d in graph.edges(data=True) if d['weight'] > 3]
    filtered_graph = graph.edge_subgraph(filtered_edges)

    closeness = nx.closeness_centrality(filtered_graph)
    betweenness = nx.closeness_centrality(filtered_graph)
    pagerank = nx.pagerank(graph, weight='weight')
    centrality_list = [closeness, betweenness, pagerank] 

    adj_dict = nx.to_dict_of_dicts(graph)
    
    for node in list(graph.nodes()):
        adj_vect = np.zeros((len(unique_nodes)))
        players = adj_dict[node]
        for key, value in players.items():
            adj_vect[pos_to_idx[key]] = value['weight']
        adj_vect = torch.from_numpy(adj_vect).float()
        centrality_vect = []
        for measure in centrality_list:
            if node in list(measure.keys()):
                centrality_vect.append(measure[node])
            else:
                centrality_vect.append(0)
        centrality_vect = torch.Tensor(centrality_vect).float()        
        graph.nodes[node]['x'] = torch.cat((adj_vect, centrality_vect), -1)

    for node in unique_nodes:
        if node not in graph.nodes:
            graph.add_node(node) 
            graph.nodes[node]['x'] = torch.from_numpy(np.zeros((len(unique_nodes)+3))).float()  
            

    data = from_networkx(graph)
    try:
        data.momentum
        pyg_data.append(data)
    except:
        print(data)


Data(x=[23, 26], edge_index=[2, 22], weight=[22])


In [60]:
momentum_values = [data.momentum for data in pyg_data]

momentum_min = min(momentum_values)
momentum_max = max(momentum_values)

for data in pyg_data:
    normalized_momentum = (data.momentum - momentum_min) / (momentum_max - momentum_min)
    
    scaled_momentum = 2 * normalized_momentum - 1
    data.momentum = scaled_momentum

In [61]:
train_idx = random.sample(range(len(pyg_data)), int(len(pyg_data) * 0.8))
test_idx = [i for i in range(len(pyg_data)) if i not in train_idx]

In [62]:
class GAT(torch.nn.Module):
    def __init__(self, alpha, input_dim, hidden_dim, output_dim, num_heads, dropout=0.4):
        super(GAT, self).__init__()
        self.dropout_rate = dropout
        
        self.layer1 = GATConv(input_dim, hidden_dim, heads=num_heads, dropout=dropout)
        self.layer2 = GATConv(hidden_dim * num_heads, hidden_dim, heads=num_heads, dropout=dropout)  # Reduced hidden_dim
        self.layer3 = GATConv(hidden_dim * num_heads, hidden_dim, heads=num_heads, dropout=dropout)  # Further reduction
        self.layer4 = GATConv(hidden_dim * num_heads, output_dim, heads=1, concat=False, dropout=dropout)
        
        self.activation_function = nn.ELU(alpha=alpha)
        self.final_activation = nn.Tanh()


    def forward(self, input, edge_index):
        output = self.layer1(input, edge_index)
        output = self.activation_function(output)
        output = F.dropout(output, p=self.dropout_rate, training=self.training)
        
        output = self.layer2(output, edge_index)
        output = self.activation_function(output)
        output = F.dropout(output, p=self.dropout_rate, training=self.training)
        
        output = self.layer3(output, edge_index)
        output = self.activation_function(output)
        output = F.dropout(output, p=self.dropout_rate, training=self.training)
        
        output = self.layer4(output, edge_index)
        
        output = output.mean(dim=0)
        output = self.final_activation(output)
        
        return output


In [63]:
input_dim = len(unique_nodes)+3
lr = 0.001

gat = GAT(alpha=0.005, input_dim = input_dim, hidden_dim = 100, output_dim = 1, num_heads = 5)
optimizer = torch.optim.SGD(gat.parameters(), lr=lr, weight_decay=1e-4)
loss_fn = torch.nn.MSELoss()
epochs_num = 100

In [None]:
for epoch in tqdm.tqdm(range(epochs_num)):
    
    epoch_loss = 0
    for idx in tqdm.tqdm(range(len(train_idx))):
        
        input = pyg_data[idx].x
        edge_idx = pyg_data[idx].edge_index
        label = pyg_data[idx]['momentum']
        
        # Forward pass
        optimizer.zero_grad()
        
        output = gat(input, edge_idx)
        

        loss = loss_fn(output, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(gat.parameters(), max_norm=1.0)

        for p in gat.parameters():
            p.data.add_(p.grad.data, alpha=-lr)
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(train_idx):.4f}") 
gat.eval()

  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 431/431 [00:01<00:00, 215.78it/s]
  1%|          | 1/100 [00:01<03:17,  2.00s/it]

Epoch 1, Loss: 0.1036


 95%|█████████▍| 409/431 [00:01<00:00, 230.78it/s]
  1%|          | 1/100 [00:03<06:13,  3.77s/it]


KeyboardInterrupt: 

23

In [41]:
y_pred  = []
y_true = []

with torch.no_grad():
    for idx in test_idx:
        output = gat(pyg_data[idx].x, pyg_data[idx].edge_index)
        y_pred.append(output.numpy())
        y_true.append(pyg_data[idx].momentum.numpy())



In [42]:
from sklearn.metrics import mean_absolute_error

mean_absolute_error(y_true, y_pred)


0.28043282