In [1]:
import networkx as nx
import torch
from torch_geometric.utils import from_networkx
import torch_geometric.transforms as T
import random
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.loader import LinkNeighborLoader
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt
from torch_geometric.nn import GraphSAGE

In [2]:
# Location of the GML file holding the graph.
# NOTE(review): this is an absolute, machine-specific path; consider moving it
# to a configurable data directory so the notebook runs elsewhere.
graph_path = "/home/schoenstein/these/gnn_postbiblio/graph/graph_light.gml"

# Parse the GML file into a NetworkX graph, then convert it to a PyG Data
# object (edge_index over integer node ids).
G = nx.read_gml(graph_path)
data = from_networkx(G)

In [3]:
# Link-split strategy used by this cell: "random" or "double split".
train = "random"

# Seed the Python, NumPy and PyTorch RNGs so that a Restart & Run All yields
# the same splits (and, downstream, the same initial weights and batches).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def _make_link_split(neg_sampling_ratio):
    """Return the RandomLinkSplit shared by every strategy in this cell.

    10% of the edges are held out for validation and 10% for test, the graph
    is treated as undirected, and only the negative sampling ratio varies
    between call sites.
    """
    return T.RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        disjoint_train_ratio=0,
        neg_sampling_ratio=neg_sampling_ratio,
        is_undirected=True,
    )


def _edges_with_label(split, label):
    """Return the columns of ``split.edge_label_index`` whose label equals ``label``."""
    return split.edge_label_index[:, split.edge_label == label]


def _labelled_split(message_edges, num_nodes, pos_edges, neg_edges):
    """Build a Data object whose supervision edges are positives then negatives.

    ``message_edges`` becomes ``edge_index``; ``pos_edges`` get label 1 and
    ``neg_edges`` get label 0, in that order.
    """
    return Data(
        edge_index=message_edges,
        num_nodes=num_nodes,
        edge_label_index=torch.cat([pos_edges, neg_edges], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_edges.size(1), dtype=torch.long),
            torch.zeros(neg_edges.size(1), dtype=torch.long),
        ]),
    )


if train == "random":
    # Plain random split with one sampled negative per positive edge.
    train_data, val_data, test_data = _make_link_split(1)(data)

elif train == "double split":
    # Negatives come from two sources: pairs sampled inside each connected
    # component, plus pairs sampled over the whole graph. Positives always
    # come from the whole-graph split.

    # from_networkx numbers nodes 0..n-1 following the graph's node iteration
    # order (assumed from the PyG implementation -- confirm for the installed
    # version). Per-component conversions therefore use *local* ids, which
    # must be mapped back to the ids used by `data` before being mixed with
    # whole-graph edges.
    global_index = {node: idx for idx, node in enumerate(G.nodes())}

    cc = list(nx.connected_components(G))
    neg_inside = 0
    train_list, val_list, test_list = [], [], []
    neg_train_parts, neg_val_parts, neg_test_parts = [], [], []
    for c in cc:
        G2 = G.subgraph(c).copy()
        if G2.number_of_edges() == 0:
            # A component without edges (isolated node) has nothing to split
            # and cannot provide an in-component negative pair; RandomLinkSplit
            # is expected to reject it, so skip it.
            continue
        data2 = from_networkx(G2)
        local_to_global = torch.tensor(
            [global_index[node] for node in G2.nodes()], dtype=torch.long
        )
        train_data2, val_data2, test_data2 = _make_link_split(1.5)(data2)
        # NOTE(review): this counts every training label of the component
        # (positives and negatives), not only negatives -- kept as in the
        # original heuristic; confirm this is the intended quantity.
        neg_inside = neg_inside + len(train_data2.edge_label)
        train_list.append(train_data2)
        val_list.append(val_data2)
        test_list.append(test_data2)
        # In-component negatives, translated from local to global node ids.
        neg_train_parts.append(local_to_global[_edges_with_label(train_data2, 0)])
        neg_val_parts.append(local_to_global[_edges_with_label(val_data2, 0)])
        neg_test_parts.append(local_to_global[_edges_with_label(test_data2, 0)])

    ratio_neg_inside = neg_inside / (G.number_of_edges() * 2)
    print(ratio_neg_inside)

    # The remaining share of negatives is sampled over the whole graph. The
    # ratio above can exceed 1 (e.g. many small components whose edges all end
    # up in the training split), so clamp at 0 instead of passing a negative
    # sampling ratio to RandomLinkSplit.
    global_neg_ratio = max(0.0, 1 - ratio_neg_inside)
    train_data3, val_data3, test_data3 = _make_link_split(global_neg_ratio)(data)

    train_data, val_data, test_data = (
        _labelled_split(
            global_split.edge_index,
            data.num_nodes,
            _edges_with_label(global_split, 1),
            torch.cat(inside_negs + [_edges_with_label(global_split, 0)], dim=1),
        )
        for global_split, inside_negs in (
            (train_data3, neg_train_parts),
            (val_data3, neg_val_parts),
            (test_data3, neg_test_parts),
        )
    )

else:
    # Fail here rather than with a NameError on train_data in a later cell.
    raise ValueError(f"Unknown split strategy: {train!r}")

print(train_data)
print(val_data)
print(test_data)

Data(edge_index=[2, 664450], num_nodes=116979, edge_label=[664450], edge_label_index=[2, 664450])
Data(edge_index=[2, 664450], num_nodes=116979, edge_label=[83056], edge_label_index=[2, 83056])
Data(edge_index=[2, 747506], num_nodes=116979, edge_label=[83056], edge_label_index=[2, 83056])


In [4]:
# Every node receives the same constant one-dimensional feature (a single 1.0).
constant_features = torch.ones((train_data.num_nodes, 1))
train_data.x = constant_features

# Validation and test splits get their own copies of the feature matrix.
for held_out_split in (val_data, test_data):
    held_out_split.x = constant_features.clone()

In [5]:
# GraphSAGE encoder configuration:
#   - in_channels=1 matches the single constant feature assigned to each node,
#   - two message-passing layers with 64 hidden channels,
#   - neighbour messages are aggregated with the "lstm" aggregator.
# No out_channels is given, so the embeddings presumably keep the hidden size
# (64) -- confirm against the PyG GraphSAGE documentation.
sage_config = {
    "in_channels": 1,
    "hidden_channels": 64,
    "num_layers": 2,
    "aggr": "lstm",
}
model = GraphSAGE(**sage_config)


def predictor(z, edge_label_index):
    """Score candidate edges by the cosine similarity of their endpoints.

    Args:
        z: Node embedding matrix, indexed by node id along the first axis.
        edge_label_index: Index tensor whose row 0 holds the source node ids
            and row 1 the destination node ids of the edges to score.

    Returns:
        One score per candidate edge: the dot product of the L2-normalised
        source and destination embeddings, hence a value in [-1, 1].
    """
    src_ids, dst_ids = edge_label_index[0], edge_label_index[1]
    # Normalising along the feature axis turns the dot product below into a
    # cosine similarity.
    unit_src = F.normalize(z[src_ids], dim=-1)
    unit_dst = F.normalize(z[dst_ids], dim=-1)
    return (unit_src * unit_dst).sum(dim=-1)

In [6]:
def _link_loader(split):
    """Build the mini-batch link loader used for both training and validation.

    Each batch holds 128 supervision edges taken from ``split.edge_label_index``
    (with their labels), served in shuffled order. ``num_neighbors=[25, 10]``
    has one entry per GNN layer; presumably up to 25 first-hop and 10
    second-hop neighbours are sampled -- see the LinkNeighborLoader docs.
    """
    return LinkNeighborLoader(
        data=split,
        num_neighbors=[25, 10],
        edge_label_index=split.edge_label_index,
        edge_label=split.edge_label,
        batch_size=128,
        shuffle=True,
    )


train_loader = _link_loader(train_data)
val_loader = _link_loader(val_data)

In [7]:
# Adam over every parameter of the GraphSAGE model, with a fixed step size.
learning_rate = 0.01
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)


def train_epoch():
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Relies on the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``predictor``. For every mini-batch the sampled subgraph is encoded, the
    supervision edges are scored and a binary cross-entropy-with-logits loss
    against ``batch.edge_label`` is back-propagated.

    Returns:
        float: the loss averaged over all supervision edges seen in the epoch.
        Each batch contributes proportionally to its number of edges, so a
        short final batch does not weigh as much as a full one.

    Raises:
        ZeroDivisionError: if the loader yields no supervision edge.
    """
    model.train()
    loss_sum = 0.0
    num_examples = 0
    for batch in train_loader:
        optimizer.zero_grad()
        z = model(batch.x, batch.edge_index)
        pred = predictor(z, batch.edge_label_index)
        loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label.float())
        loss.backward()
        optimizer.step()
        # `loss` is the mean over this batch; scale it back to a sum so the
        # epoch average is taken over edges rather than over batches.
        batch_size = batch.edge_label.numel()
        loss_sum += loss.item() * batch_size
        num_examples += batch_size
    return loss_sum / num_examples

def evaluate(loader=None):
    """Score a link loader with the current model and return (ROC-AUC, AP).

    Relies on the module-level ``model`` and ``predictor``.

    Args:
        loader: Iterable of mini-batches exposing ``x``, ``edge_index``,
            ``edge_label_index`` and ``edge_label``. Defaults to the
            module-level ``val_loader``, so ``evaluate()`` keeps its original
            meaning; pass another loader (e.g. one built on the test split)
            to score it with the same code.

    Returns:
        tuple: ``(auc, ap)`` as computed by ``roc_auc_score`` and
        ``average_precision_score`` on the sigmoid of the predictor scores.
    """
    if loader is None:
        loader = val_loader
    model.eval()
    y_truth = []
    y_pred = []
    # Inference only: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for batch in loader:
            z = model(batch.x, batch.edge_index)
            pred = predictor(z, batch.edge_label_index)
            y_truth.append(batch.edge_label)
            y_pred.append(torch.sigmoid(pred))
    # .cpu() is a no-op on CPU tensors and keeps .numpy() valid on other devices.
    y_truth = torch.cat(y_truth).cpu().numpy()
    y_pred = torch.cat(y_pred).cpu().numpy()
    auc = roc_auc_score(y_truth, y_pred)
    ap = average_precision_score(y_truth, y_pred)
    return auc, ap
    
# ROC-AUC: the probability that a positive edge receives a higher score than a negative one.
# AP: approximates the area under the Precision/Recall curve; the rarer the positives, the harder it is to obtain a good AP score.

In [8]:
# Training with early stopping on the validation ROC-AUC.
best_val_auc = 0
best_state = None  # snapshot of the weights that achieved best_val_auc
limit = 6          # patience: epochs without improvement before stopping
count = 0          # consecutive epochs without improvement
train_losses = []
val_aps = []
val_aucs = []
# NOTE(review): range(1, 50) runs at most 49 epochs; use range(1, 51) if 50
# were intended.
for epoch in range(1, 50):
    loss = train_epoch()
    train_losses.append(loss)
    val_auc, val_ap = evaluate()
    val_aps.append(val_ap)
    val_aucs.append(val_auc)
    print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}")
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        # state_dict() tensors share storage with the live parameters, so
        # clone them to freeze the best weights.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        count = 0
    else:
        count = count + 1
        if count >= limit:
            print("Early stop")
            break

# Early stopping ends several epochs after the best one; restore the best
# weights so that later cells use the model that achieved best_val_auc.
if best_state is not None:
    model.load_state_dict(best_state)

Epoch 001, Loss: 0.6313, Val AUC: 0.8624, Val AP: 0.8480
Epoch 002, Loss: 0.5347, Val AUC: 0.8676, Val AP: 0.8500
Epoch 003, Loss: 0.5300, Val AUC: 0.8330, Val AP: 0.8275
Epoch 004, Loss: 0.5286, Val AUC: 0.8690, Val AP: 0.8736
Epoch 005, Loss: 0.5275, Val AUC: 0.8412, Val AP: 0.8654
Epoch 006, Loss: 0.5280, Val AUC: 0.8553, Val AP: 0.8720
Epoch 007, Loss: 0.5265, Val AUC: 0.8656, Val AP: 0.8741
Epoch 008, Loss: 0.5270, Val AUC: 0.8584, Val AP: 0.8709
Epoch 009, Loss: 0.5257, Val AUC: 0.8613, Val AP: 0.8740
Epoch 010, Loss: 0.5252, Val AUC: 0.8562, Val AP: 0.8668
Early stop
