In [1]:
import networkx as nx
import torch
from torch_geometric.utils import from_networkx
import torch_geometric.transforms as T
import random
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.loader import LinkNeighborLoader
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool

In [2]:
G = nx.read_gml("/home/schoenstein/these/gnn_postbiblio/graph/graph_light.gml")
data = from_networkx(G)

In [None]:
train = "random"

if train == "random":
    transform = T.RandomLinkSplit(
        num_val = 0.1,  
        num_test = 0.1,  
        disjoint_train_ratio = 0,  
        neg_sampling_ratio = 1,
        is_undirected = True
    )
    train_data, val_data, test_data = transform(data)
    print(train_data)
    print(val_data)
    print(test_data)

elif train == "double split":
    cc = list(nx.connected_components(G))
    neg_inside = 0
    train_list = []
    val_list = []
    test_list = []
    for c in cc:
        G2 = G.subgraph(c).copy()
        data2 = from_networkx(G2)
        transform = T.RandomLinkSplit(
                num_val = 0.1,  
                num_test = 0.1,  
                disjoint_train_ratio = 0,  
                neg_sampling_ratio = 1.5,
                is_undirected = True
            )
        train_data2, val_data2, test_data2 = transform(data2)
        neg_inside = neg_inside + len(train_data2.edge_label)
        train_list.append(train_data2)
        val_list.append(val_data2)
        test_list.append(test_data2)
    ratio_neg_inside = neg_inside/(len(list(G.edges()))*2)
    print(ratio_neg_inside)
    transform = T.RandomLinkSplit(
            num_val = 0.1,  
            num_test = 0.1,  
            disjoint_train_ratio = 0,  
            neg_sampling_ratio = 1 - ratio_neg_inside,
            is_undirected = True
        )
    train_data3, val_data3, test_data3 = transform(data)

    #pos_train = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                             #for d in train_list], dim=1)
    pos_train = train_data3.edge_label_index[:, train_data3.edge_label == 1]
    #pos_val = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                           #for d in val_list], dim=1)
    pos_val = val_data3.edge_label_index[:, val_data3.edge_label == 1]
    #pos_test = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                            #for d in test_list], dim=1)
    pos_test = test_data3.edge_label_index[:, test_data3.edge_label == 1]
    neg_train1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                             for d in train_list], dim=1)
    neg_val1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                           for d in val_list], dim=1)
    neg_test1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                            for d in test_list], dim=1)
    neg_train2 = train_data3.edge_label_index[:, train_data3.edge_label==0]             
    neg_val2 = val_data3.edge_label_index[:, val_data3.edge_label==0] 
    neg_test2 = test_data3.edge_label_index[:, test_data3.edge_label==0]
    neg_train = torch.cat([neg_train1, neg_train2], dim=1)
    neg_val = torch.cat([neg_val1, neg_val2], dim=1)
    neg_test = torch.cat([neg_test1, neg_test2], dim=1)
    train_data = Data(
        edge_index=train_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_train, neg_train], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_train.size(1), dtype=torch.long),
            torch.zeros(neg_train.size(1), dtype=torch.long)
        ])
    )
    val_data = Data(
        edge_index=val_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_val, neg_val], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_val.size(1), dtype=torch.long),
            torch.zeros(neg_val.size(1), dtype=torch.long)
        ])
    )
    test_data = Data(
        edge_index=test_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_test, neg_test], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_test.size(1), dtype=torch.long),
            torch.zeros(neg_test.size(1), dtype=torch.long)
        ])
    )

    print(train_data)
    print(val_data)
    print(test_data)



0.8155513977282852
Data(edge_index=[2, 664450], num_nodes=116979, edge_label_index=[2, 728450], edge_label=[728450])
Data(edge_index=[2, 664450], num_nodes=116979, edge_label_index=[2, 85195], edge_label=[85195])
Data(edge_index=[2, 747506], num_nodes=116979, edge_label_index=[2, 90282], edge_label=[90282])


In [4]:
train_loader = LinkNeighborLoader(
    data = train_data,
    num_neighbors = [25, 10], #normalement [-1, -1] pour obtenir tous les voisins à une distance de 2, mais doit allourdir fortement le process et réduire peut permettre de se focus plus sur des motifs locaux que l'évaluation de tout une composante connexe
    edge_label_index = train_data.edge_label_index,
    edge_label = train_data.edge_label,
    subgraph_type="induced",
    batch_size = 1,
    shuffle = True
)
val_loader = LinkNeighborLoader(
    data = val_data,
    num_neighbors = [25, 10], #same
    edge_label_index = val_data.edge_label_index,
    edge_label = val_data.edge_label,
    subgraph_type="induced",
    batch_size = 1,
    shuffle = True
)

In [None]:
def labelling(batch):
    edge_index = batch.edge_index
    u, v = batch.edge_label_index
    mask = ~(((edge_index[0] == u) & (edge_index[1] == v)) | ((edge_index[0] == v) & (edge_index[1] == u)))
    edge_index = edge_index[:, mask]
    G = nx.Graph()
    G.add_nodes_from(range(batch.num_nodes))
    G.add_edges_from(batch.edge_index.t().tolist())
    du = nx.single_source_shortest_path_length(G, int(u))
    dv = nx.single_source_shortest_path_length(G, int(v))
    labels = []
    for n in range(batch.num_nodes):
        if n == u or n == v:
            labels.append(1)
        else:
            dun = du.get(n, "unr")
            dvn = dv.get(n, "unr")
            if dun == "unr" or dvn == "unr":
                 labels.append(0)
            else:
                labels.append(1 + min(dun, dvn) + ((dun + dvn - 2)*(dun + dvn - 1)) // 2)
    return torch.tensor(labels).unsqueeze(-1).float(), edge_index  

In [6]:
class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x
    

class Predictor(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )

    def forward(self, x):
        return self.mlp(x).view(-1)
    

class Model(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.gnn = GNN(in_channels, hidden_channels)
        self.predictor = Predictor(hidden_channels)
    
    def forward(self, data):
        x = self.gnn(data.x, data.edge_index)
        x = global_mean_pool(x, data.batch)
        pred = self.predictor(x)
        return pred
    

model = Model(in_channels = 1, hidden_channels = 64)

In [7]:
#from torch_geometric.nn import GlobalAttention

#att_pool = GlobalAttention(
    #gate_nn=torch.nn.Sequential(
        #torch.nn.Linear(hidden_dim, 1),
        #torch.nn.Sigmoid()
    #)
#)
#x = att_pool(x, batch)

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) #tester plusieurs valeurs de lr

def train_epoch():
    model.train()
    total_loss = 0
    count = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch.x, batch.edge_index = labelling(batch)
        pred = model(batch)
        loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label.float())
        loss.backward()
        optimizer.step()
        total_loss = total_loss + loss.item()
        count = count + 1
    return total_loss / count

def evaluate():
    model.eval()
    y_truth = []
    y_pred = []
    for batch in val_loader:
        batch.x, batch.edge_index = labelling(batch)
        pred = model(batch)
        y_truth.append(batch.edge_label)
        y_pred.append(torch.sigmoid(pred))
    y_truth = torch.cat(y_truth).numpy()
    y_pred = torch.cat(y_pred).detach().numpy()
    auc = roc_auc_score(y_truth, y_pred)
    ap = average_precision_score(y_truth, y_pred)
    return auc, ap

In [10]:
best_val_auc = 0
limit = 6
count = 0
train_losses = []
val_aps = []
val_aucs = []
for epoch in range(1, 50):
    loss = train_epoch()
    train_losses.append(loss)
    val_auc, val_ap = evaluate()
    val_aps.append(val_ap)
    val_aucs.append(val_auc)
    print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}")
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        count = 0
    else:
        count =  count + 1
        if count >= limit:
            print("Early stop")
            break

Epoch 001, Loss: 0.0931, Val AUC: 0.9690, Val AP: 0.9488


KeyboardInterrupt: 

In [2]:
from torch_geometric.loader import DataLoader

G = nx.read_gml("/home/schoenstein/these/gnn_postbiblio/graph/graph_light.gml")
data = from_networkx(G)

train = "random"

if train == "random":
    transform = T.RandomLinkSplit(
        num_val = 0.1,  
        num_test = 0.1,  
        disjoint_train_ratio = 0,  
        neg_sampling_ratio = 1,
        is_undirected = True
    )
    train_data, val_data, test_data = transform(data)
    print(train_data)
    print(val_data)
    print(test_data)

elif train == "double split":
    cc = list(nx.connected_components(G))
    neg_inside = 0
    train_list = []
    val_list = []
    test_list = []
    for c in cc:
        G2 = G.subgraph(c).copy()
        data2 = from_networkx(G2)
        transform = T.RandomLinkSplit(
                num_val = 0.1,  
                num_test = 0.1,  
                disjoint_train_ratio = 0,  
                neg_sampling_ratio = 1.5,
                is_undirected = True
            )
        train_data2, val_data2, test_data2 = transform(data2)
        neg_inside = neg_inside + len(train_data2.edge_label)
        train_list.append(train_data2)
        val_list.append(val_data2)
        test_list.append(test_data2)
    ratio_neg_inside = neg_inside/(len(list(G.edges()))*2)
    print(ratio_neg_inside)
    transform = T.RandomLinkSplit(
            num_val = 0.1,  
            num_test = 0.1,  
            disjoint_train_ratio = 0,  
            neg_sampling_ratio = 1 - ratio_neg_inside,
            is_undirected = True
        )
    train_data3, val_data3, test_data3 = transform(data)

    #pos_train = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                             #for d in train_list], dim=1)
    pos_train = train_data3.edge_label_index[:, train_data3.edge_label == 1]
    #pos_val = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                           #for d in val_list], dim=1)
    pos_val = val_data3.edge_label_index[:, val_data3.edge_label == 1]
    #pos_test = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                            #for d in test_list], dim=1)
    pos_test = test_data3.edge_label_index[:, test_data3.edge_label == 1]
    neg_train1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                             for d in train_list], dim=1)
    neg_val1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                           for d in val_list], dim=1)
    neg_test1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                            for d in test_list], dim=1)
    neg_train2 = train_data3.edge_label_index[:, train_data3.edge_label==0]             
    neg_val2 = val_data3.edge_label_index[:, val_data3.edge_label==0] 
    neg_test2 = test_data3.edge_label_index[:, test_data3.edge_label==0]
    neg_train = torch.cat([neg_train1, neg_train2], dim=1)
    neg_val = torch.cat([neg_val1, neg_val2], dim=1)
    neg_test = torch.cat([neg_test1, neg_test2], dim=1)
    train_data = Data(
        edge_index=train_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_train, neg_train], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_train.size(1), dtype=torch.long),
            torch.zeros(neg_train.size(1), dtype=torch.long)
        ])
    )
    val_data = Data(
        edge_index=val_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_val, neg_val], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_val.size(1), dtype=torch.long),
            torch.zeros(neg_val.size(1), dtype=torch.long)
        ])
    )
    test_data = Data(
        edge_index=test_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_test, neg_test], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_test.size(1), dtype=torch.long),
            torch.zeros(neg_test.size(1), dtype=torch.long)
        ])
    )
    print(train_data)
    print(val_data)
    print(test_data)

def labelling(batch):
    edge_index = batch.edge_index
    u, v = batch.edge_label_index
    mask = ~(((edge_index[0] == u) & (edge_index[1] == v)) | ((edge_index[0] == v) & (edge_index[1] == u)))
    edge_index = edge_index[:, mask]
    G = nx.Graph()
    G.add_nodes_from(range(batch.num_nodes))
    G.add_edges_from(edge_index.t().tolist())
    du = nx.single_source_shortest_path_length(G, int(u))
    dv = nx.single_source_shortest_path_length(G, int(v))
    labels = []
    for n in range(batch.num_nodes):
        if n == u or n == v:
            labels.append(1)
        else:
            dun = du.get(n, "unr")
            dvn = dv.get(n, "unr")
            if dun == "unr" or dvn == "unr":
                 labels.append(0)
            else:
                labels.append(1 + min(dun, dvn) + ((dun + dvn - 2)*(dun + dvn - 1)) // 2)
    return torch.tensor(labels).unsqueeze(-1).float(), edge_index  

train_loader = LinkNeighborLoader(
    data = train_data,
    num_neighbors = [25, 10],
    edge_label_index = train_data.edge_label_index,
    edge_label = train_data.edge_label,
    subgraph_type="induced",
    batch_size = 1,
    shuffle = True
)
val_loader = LinkNeighborLoader(
    data = val_data,
    num_neighbors = [25, 10], 
    edge_label_index = val_data.edge_label_index,
    edge_label = val_data.edge_label,
    subgraph_type="induced",
    batch_size = 1,
    shuffle = True
)

train_subgraphs_data = []
for batch in train_loader:
    batch.x, batch.edge_index = labelling(batch)
    data = Data(
        x = batch.x,
        edge_index = batch.edge_index,  
        edge_label = batch.edge_label,      
        edge_label_index = batch.edge_label_index  
    )
    train_subgraphs_data.append(data)
#torch.save(subgraphs_data, "subgraphs/train_subgraphs_data.pt")
val_subgraphs_data = []
for batch in val_loader:
    batch.x, batch.edge_index = labelling(batch)
    data = Data(
        x = batch.x,
        edge_index = batch.edge_index,  
        edge_label = batch.edge_label,      
        edge_label_index = batch.edge_label_index  
    )
    val_subgraphs_data.append(data)
#torch.save(subgraphs_data, "subgraphs/train_subgraphs_data.pt")

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x
    

class Predictor(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )

    def forward(self, x):
        return self.mlp(x).view(-1)
    

class Model(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.gnn = GNN(in_channels, hidden_channels)
        self.predictor = Predictor(hidden_channels)
    
    def forward(self, data):
        x = self.gnn(data.x, data.edge_index)
        x = global_mean_pool(x, data.batch)
        pred = self.predictor(x)
        return pred
    

model = Model(in_channels = 1, hidden_channels = 64)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

train_reloader = DataLoader(train_subgraphs_data, batch_size=16, shuffle=True)
val_reloader = DataLoader(val_subgraphs_data, batch_size=16)

def train_epoch():
    model.train()
    total_loss = 0
    count = 0
    for batch in train_reloader:
        optimizer.zero_grad()
        pred = model(batch)
        loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label.float())
        loss.backward()
        optimizer.step()
        total_loss = total_loss + loss.item()
        count = count + 1
    return total_loss / count

def evaluate():
    model.eval()
    y_truth = []
    y_pred = []
    for batch in val_reloader:
        pred = model(batch)
        y_truth.append(batch.edge_label)
        y_pred.append(torch.sigmoid(pred))
    y_truth = torch.cat(y_truth).numpy()
    y_pred = torch.cat(y_pred).detach().numpy()
    auc = roc_auc_score(y_truth, y_pred)
    ap = average_precision_score(y_truth, y_pred)
    return auc, ap


Data(edge_index=[2, 664450], num_nodes=116979, edge_label=[664450], edge_label_index=[2, 664450])
Data(edge_index=[2, 664450], num_nodes=116979, edge_label=[83056], edge_label_index=[2, 83056])
Data(edge_index=[2, 747506], num_nodes=116979, edge_label=[83056], edge_label_index=[2, 83056])


In [3]:
best_val_auc = 0
limit = 6
count = 0
train_losses = []
val_aps = []
val_aucs = []
for epoch in range(1, 50):
    loss = train_epoch()
    train_losses.append(loss)
    val_auc, val_ap = evaluate()
    val_aps.append(val_ap)
    val_aucs.append(val_auc)
    print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}")
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        count = 0
    else:
        count =  count + 1
        if count >= limit:
            print("Early stop")
            break

Epoch 001, Loss: 0.0375, Val AUC: 0.9988, Val AP: 0.9989
Epoch 002, Loss: 0.0343, Val AUC: 0.9988, Val AP: 0.9988
Epoch 003, Loss: 0.0338, Val AUC: 0.9989, Val AP: 0.9988
Epoch 004, Loss: 0.0332, Val AUC: 0.9989, Val AP: 0.9988
Epoch 005, Loss: 0.0325, Val AUC: 0.9989, Val AP: 0.9988
Epoch 006, Loss: 0.0328, Val AUC: 0.9989, Val AP: 0.9988
Epoch 007, Loss: 0.0323, Val AUC: 0.9988, Val AP: 0.9987
Epoch 008, Loss: 0.0322, Val AUC: 0.9989, Val AP: 0.9987
Epoch 009, Loss: 0.0320, Val AUC: 0.9989, Val AP: 0.9988
Epoch 010, Loss: 0.0321, Val AUC: 0.9989, Val AP: 0.9988
Epoch 011, Loss: 0.0326, Val AUC: 0.9989, Val AP: 0.9987
Epoch 012, Loss: 0.0434, Val AUC: 0.9988, Val AP: 0.9986
Epoch 013, Loss: 0.0323, Val AUC: 0.9989, Val AP: 0.9987
Epoch 014, Loss: 0.0352, Val AUC: 0.9989, Val AP: 0.9988
Epoch 015, Loss: 0.0324, Val AUC: 0.9989, Val AP: 0.9989
Epoch 016, Loss: 0.0338, Val AUC: 0.9989, Val AP: 0.9988
Epoch 017, Loss: 0.0324, Val AUC: 0.9989, Val AP: 0.9989
Epoch 018, Loss: 0.0325, Val AU

In [5]:
from torch_geometric.loader import DataLoader

G = nx.read_gml("/home/schoenstein/these/gnn_postbiblio/graph/graph_light.gml")
data = from_networkx(G)

train = "double split"

if train == "random":
    transform = T.RandomLinkSplit(
        num_val = 0.1,  
        num_test = 0.1,  
        disjoint_train_ratio = 0,  
        neg_sampling_ratio = 1,
        is_undirected = True
    )
    train_data, val_data, test_data = transform(data)
    print(train_data)
    print(val_data)
    print(test_data)

elif train == "double split":
    cc = list(nx.connected_components(G))
    neg_inside = 0
    train_list = []
    val_list = []
    test_list = []
    for c in cc:
        G2 = G.subgraph(c).copy()
        data2 = from_networkx(G2)
        transform = T.RandomLinkSplit(
                num_val = 0.1,  
                num_test = 0.1,  
                disjoint_train_ratio = 0,  
                neg_sampling_ratio = 1.5,
                is_undirected = True
            )
        train_data2, val_data2, test_data2 = transform(data2)
        neg_inside = neg_inside + len(train_data2.edge_label)
        train_list.append(train_data2)
        val_list.append(val_data2)
        test_list.append(test_data2)
    ratio_neg_inside = neg_inside/(len(list(G.edges()))*2)
    print(ratio_neg_inside)
    transform = T.RandomLinkSplit(
            num_val = 0.1,  
            num_test = 0.1,  
            disjoint_train_ratio = 0,  
            neg_sampling_ratio = 1 - ratio_neg_inside,
            is_undirected = True
        )
    train_data3, val_data3, test_data3 = transform(data)

    #pos_train = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                             #for d in train_list], dim=1)
    pos_train = train_data3.edge_label_index[:, train_data3.edge_label == 1]
    #pos_val = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                           #for d in val_list], dim=1)
    pos_val = val_data3.edge_label_index[:, val_data3.edge_label == 1]
    #pos_test = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                            #for d in test_list], dim=1)
    pos_test = test_data3.edge_label_index[:, test_data3.edge_label == 1]
    neg_train1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                             for d in train_list], dim=1)
    neg_val1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                           for d in val_list], dim=1)
    neg_test1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                            for d in test_list], dim=1)
    neg_train2 = train_data3.edge_label_index[:, train_data3.edge_label==0]             
    neg_val2 = val_data3.edge_label_index[:, val_data3.edge_label==0] 
    neg_test2 = test_data3.edge_label_index[:, test_data3.edge_label==0]
    neg_train = torch.cat([neg_train1, neg_train2], dim=1)
    neg_val = torch.cat([neg_val1, neg_val2], dim=1)
    neg_test = torch.cat([neg_test1, neg_test2], dim=1)
    train_data = Data(
        edge_index=train_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_train, neg_train], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_train.size(1), dtype=torch.long),
            torch.zeros(neg_train.size(1), dtype=torch.long)
        ])
    )
    val_data = Data(
        edge_index=val_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_val, neg_val], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_val.size(1), dtype=torch.long),
            torch.zeros(neg_val.size(1), dtype=torch.long)
        ])
    )
    test_data = Data(
        edge_index=test_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_test, neg_test], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_test.size(1), dtype=torch.long),
            torch.zeros(neg_test.size(1), dtype=torch.long)
        ])
    )
    print(train_data)
    print(val_data)
    print(test_data)


G_train = nx.Graph()
G_train.add_nodes_from(range(train_data.num_nodes))
G_train.add_edges_from(train_data.edge_index.t().tolist())
degree = dict(G_train.degree())
max_degree = max(degree.values())
degree_norm = {n: d/max_degree for n, d in degree.items()}
clustering = nx.clustering(G_train)
degree_tensor = torch.tensor(list(degree_norm.values()), dtype=torch.float32)
clustering_tensor = torch.tensor(list(clustering.values()), dtype=torch.float32)
train_data.x = torch.stack([degree_tensor, clustering_tensor], dim=-1)
val_data.x = train_data.x.clone()
test_data.x = train_data.x.clone()


def labelling(batch):
    edge_index = batch.edge_index
    u, v = batch.edge_label_index
    mask = ~(((edge_index[0] == u) & (edge_index[1] == v)) | ((edge_index[0] == v) & (edge_index[1] == u)))
    edge_index = edge_index[:, mask]
    G = nx.Graph()
    G.add_nodes_from(range(batch.num_nodes))
    G.add_edges_from(edge_index.t().tolist())
    du = nx.single_source_shortest_path_length(G, int(u))
    dv = nx.single_source_shortest_path_length(G, int(v))
    labels = []
    for n in range(batch.num_nodes):
        if n == u or n == v:
            labels.append(1)
        else:
            dun = du.get(n, "unr")
            dvn = dv.get(n, "unr")
            if dun == "unr" or dvn == "unr":
                 labels.append(0)
            else:
                labels.append(1 + min(dun, dvn) + ((dun + dvn - 2)*(dun + dvn - 1)) // 2)
    return torch.tensor(labels).unsqueeze(-1).float(), edge_index  

train_loader = LinkNeighborLoader(
    data = train_data,
    num_neighbors = [25, 10],
    edge_label_index = train_data.edge_label_index,
    edge_label = train_data.edge_label,
    subgraph_type="induced",
    batch_size = 1,
    shuffle = True
)
val_loader = LinkNeighborLoader(
    data = val_data,
    num_neighbors = [25, 10], 
    edge_label_index = val_data.edge_label_index,
    edge_label = val_data.edge_label,
    subgraph_type="induced",
    batch_size = 1,
    shuffle = True
)

train_subgraphs_data = []
for batch in train_loader:
    x_label, batch.edge_index = labelling(batch)
    x_stats = batch.x
    batch.x = torch.cat([x_label, x_stats], dim = 1)
    data = Data(
        x = batch.x,
        edge_index = batch.edge_index,  
        edge_label = batch.edge_label,      
        edge_label_index = batch.edge_label_index  
    )
    train_subgraphs_data.append(data)
#torch.save(subgraphs_data, "subgraphs/train_subgraphs_data.pt")
val_subgraphs_data = []
for batch in val_loader:
    x_label, batch.edge_index = labelling(batch)
    x_stats = batch.x
    batch.x = torch.cat([x_label, x_stats], dim = 1)
    data = Data(
        x = batch.x,
        edge_index = batch.edge_index,  
        edge_label = batch.edge_label,      
        edge_label_index = batch.edge_label_index  
    )
    val_subgraphs_data.append(data)
#torch.save(subgraphs_data, "subgraphs/train_subgraphs_data.pt")

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x
    

class Predictor(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )

    def forward(self, x):
        return self.mlp(x).view(-1)
    

class Model(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.gnn = GNN(in_channels, hidden_channels)
        self.predictor = Predictor(hidden_channels)
    
    def forward(self, data):
        x = self.gnn(data.x, data.edge_index)
        x = global_mean_pool(x, data.batch)
        pred = self.predictor(x)
        return pred
    

model = Model(in_channels = 3, hidden_channels = 64)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

train_reloader = DataLoader(train_subgraphs_data, batch_size=16, shuffle=True)
val_reloader = DataLoader(val_subgraphs_data, batch_size=16)

def train_epoch():
    model.train()
    total_loss = 0
    count = 0
    for batch in train_reloader:
        optimizer.zero_grad()
        pred = model(batch)
        loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label.float())
        loss.backward()
        optimizer.step()
        total_loss = total_loss + loss.item()
        count = count + 1
    return total_loss / count

def evaluate():
    model.eval()
    y_truth = []
    y_pred = []
    for batch in val_reloader:
        pred = model(batch)
        y_truth.append(batch.edge_label)
        y_pred.append(torch.sigmoid(pred))
    y_truth = torch.cat(y_truth).numpy()
    y_pred = torch.cat(y_pred).detach().numpy()
    auc = roc_auc_score(y_truth, y_pred)
    ap = average_precision_score(y_truth, y_pred)
    return auc, ap

best_val_auc = 0
limit = 6
count = 0
train_losses = []
val_aps = []
val_aucs = []
for epoch in range(1, 50):
    loss = train_epoch()
    train_losses.append(loss)
    val_auc, val_ap = evaluate()
    val_aps.append(val_ap)
    val_aucs.append(val_auc)
    print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}")
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        count = 0
    else:
        count =  count + 1
        if count >= limit:
            print("Early stop")
            break



0.8155513977282852
Data(edge_index=[2, 664450], num_nodes=116979, edge_label_index=[2, 728450], edge_label=[728450])
Data(edge_index=[2, 664450], num_nodes=116979, edge_label_index=[2, 85195], edge_label=[85195])
Data(edge_index=[2, 747506], num_nodes=116979, edge_label_index=[2, 90282], edge_label=[90282])
Epoch 001, Loss: 0.1053, Val AUC: 0.9789, Val AP: 0.9577
Epoch 002, Loss: 0.1071, Val AUC: 0.9789, Val AP: 0.9578
Epoch 003, Loss: 0.1067, Val AUC: 0.9790, Val AP: 0.9580
Epoch 004, Loss: 0.1072, Val AUC: 0.9790, Val AP: 0.9580
Epoch 005, Loss: 0.1073, Val AUC: 0.9791, Val AP: 0.9581
Epoch 006, Loss: 0.1073, Val AUC: 0.9791, Val AP: 0.9581
Epoch 007, Loss: 0.1073, Val AUC: 0.9789, Val AP: 0.9577
Epoch 008, Loss: 0.1067, Val AUC: 0.9790, Val AP: 0.9579
Epoch 009, Loss: 0.1076, Val AUC: 0.9790, Val AP: 0.9580
Epoch 010, Loss: 0.1081, Val AUC: 0.9791, Val AP: 0.9581
Epoch 011, Loss: 0.1072, Val AUC: 0.9786, Val AP: 0.9573
Epoch 012, Loss: 0.1076, Val AUC: 0.9791, Val AP: 0.9581
Early s