In [1]:
import networkx as nx
import torch
from torch_geometric.utils import from_networkx
import torch_geometric.transforms as T
import random
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.loader import LinkNeighborLoader
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import Node2Vec

In [2]:
G = nx.read_gml("/home/schoenstein/these/gnn_postbiblio/graph/graph_light.gml")
data = from_networkx(G)

In [3]:
train = "random"

if train == "random":
    transform = T.RandomLinkSplit(
        num_val = 0.1,  
        num_test = 0.1,  
        disjoint_train_ratio = 0,  
        neg_sampling_ratio = 1,
        is_undirected = True
    )
    train_data, val_data, test_data = transform(data)
    print(train_data)
    print(val_data)
    print(test_data)

elif train == "double split":
    cc = list(nx.connected_components(G))
    neg_inside = 0
    train_list = []
    val_list = []
    test_list = []
    for c in cc:
        G2 = G.subgraph(c).copy()
        data2 = from_networkx(G2)
        transform = T.RandomLinkSplit(
                num_val = 0.1,  
                num_test = 0.1,  
                disjoint_train_ratio = 0,  
                neg_sampling_ratio = 1.5,
                is_undirected = True
            )
        train_data2, val_data2, test_data2 = transform(data2)
        neg_inside = neg_inside + len(train_data2.edge_label)
        train_list.append(train_data2)
        val_list.append(val_data2)
        test_list.append(test_data2)
    ratio_neg_inside = neg_inside/(len(list(G.edges()))*2)
    print(ratio_neg_inside)
    transform = T.RandomLinkSplit(
            num_val = 0.1,  
            num_test = 0.1,  
            disjoint_train_ratio = 0,  
            neg_sampling_ratio = 1 - ratio_neg_inside,
            is_undirected = True
        )
    train_data3, val_data3, test_data3 = transform(data)

    #pos_train = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                             #for d in train_list], dim=1)
    pos_train = train_data3.edge_label_index[:, train_data3.edge_label == 1]
    #pos_val = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                           #for d in val_list], dim=1)
    pos_val = val_data3.edge_label_index[:, val_data3.edge_label == 1]
    #pos_test = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                            #for d in test_list], dim=1)
    pos_test = test_data3.edge_label_index[:, test_data3.edge_label == 1]
    neg_train1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                             for d in train_list], dim=1)
    neg_val1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                           for d in val_list], dim=1)
    neg_test1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                            for d in test_list], dim=1)
    neg_train2 = train_data3.edge_label_index[:, train_data3.edge_label==0]             
    neg_val2 = val_data3.edge_label_index[:, val_data3.edge_label==0] 
    neg_test2 = test_data3.edge_label_index[:, test_data3.edge_label==0]
    neg_train = torch.cat([neg_train1, neg_train2], dim=1)
    neg_val = torch.cat([neg_val1, neg_val2], dim=1)
    neg_test = torch.cat([neg_test1, neg_test2], dim=1)
    train_data = Data(
        edge_index=train_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_train, neg_train], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_train.size(1), dtype=torch.long),
            torch.zeros(neg_train.size(1), dtype=torch.long)
        ])
    )
    val_data = Data(
        edge_index=val_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_val, neg_val], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_val.size(1), dtype=torch.long),
            torch.zeros(neg_val.size(1), dtype=torch.long)
        ])
    )
    test_data = Data(
        edge_index=test_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_test, neg_test], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_test.size(1), dtype=torch.long),
            torch.zeros(neg_test.size(1), dtype=torch.long)
        ])
    )

    print(train_data)
    print(val_data)
    print(test_data)

Data(edge_index=[2, 664450], num_nodes=116979, edge_label=[664450], edge_label_index=[2, 664450])
Data(edge_index=[2, 664450], num_nodes=116979, edge_label=[83056], edge_label_index=[2, 83056])
Data(edge_index=[2, 747506], num_nodes=116979, edge_label=[83056], edge_label_index=[2, 83056])


In [4]:
train_data.x = torch.ones((train_data.num_nodes, 1))
val_data.x = train_data.x.clone()
test_data.x = train_data.x.clone()

In [5]:
model_n2v = Node2Vec(
    edge_index = train_data.edge_index,
    embedding_dim = 64,
    walk_length = 20,
    context_size = 10,
    walks_per_node = 10,
    p = 1.0,
    q = 1.0,
    num_negative_samples = 1,
    sparse=False  #True : seuls les noeuds vus sont mis à jour, False : toute la matrice reçoit un gradient donc moins opti mais nécessaire pour entrainer les 2 en même temps
)
n2v_loader = model_n2v.loader(batch_size = 128, shuffle =  True)

In [6]:
class Predictor(nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1)
        )

    def forward(self, x, edge_label_index):
        edge_emb_src = x[edge_label_index[0]]
        edge_emb_dst = x[edge_label_index[1]]
        edge_emb = torch.cat([edge_emb_src, edge_emb_dst], dim=-1)
        return self.mlp(edge_emb).view(-1)

predictor = Predictor(hidden_channels=64)

In [7]:
#optimizer_n2v = torch.optim.SparseAdam(model_n2v.parameters(), lr=0.01)
optimizer = torch.optim.Adam(list(model_n2v.parameters()) + list(predictor.parameters()),lr=0.01)
#optimizer_mlp = torch.optim.Adam(predictor.parameters(), lr=0.01)

In [None]:
def train_epoch():
    model_n2v.train()
    predictor.train()
    #optimizer_n2v.zero_grad()
    #optimizer_mlp.zero_grad()
    optimizer.zero_grad()
    loss_n2v = 0
    count = 0
    for pos, neg in n2v_loader:
        loss = model_n2v.loss(pos, neg)
        loss_n2v = loss_n2v + loss.item()
        count = count + 1
    loss_n2v = loss_n2v / count
    z = model_n2v()   
    pred = predictor(z, train_data.edge_label_index)
    loss_lp = F.binary_cross_entropy_with_logits(pred, train_data.edge_label.float())
    loss_total = loss_n2v + loss_lp
    loss_total.backward()
    #optimizer_n2v.step()
    #optimizer_mlp.step()
    optimizer.step()
    return loss_n2v, loss_lp

def evaluate():
    model_n2v.eval()
    predictor.eval()
    z = model_n2v()
    y_pred = predictor(z, val_data.edge_label_index)
    loss = F.binary_cross_entropy_with_logits(y_pred, val_data.edge_label.float()).item()
    auc = roc_auc_score(val_data.edge_label.detach().numpy(),y_pred.detach().numpy())
    ap = average_precision_score(val_data.edge_label.detach().numpy(),y_pred.detach().numpy())
    return loss, auc, ap

In [9]:
best_val_auc = 0
limit = 6
count = 0
for epoch in range(1, 50):
    loss_n2v, loss_lp = train_epoch()
    val_loss, val_auc, val_ap = evaluate()
    print(f"Epoch {epoch:03d}, N2V Loss: {loss_n2v:.4f}, LP Loss : {loss_lp:.4f}, Val Loss : {val_loss:.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}")
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        count = 0
    else:
        count += 1
        if count >= limit:
            print("Early stop")
            break


Epoch 001, N2V Loss: 6.3733, LP Loss : 0.6971, Val Loss : 0.6954, Val AUC: 0.5038, Val AP: 0.5025
Epoch 002, N2V Loss: 6.3732, LP Loss : 0.6944, Val Loss : 0.6946, Val AUC: 0.5080, Val AP: 0.5067
Epoch 003, N2V Loss: 6.3721, LP Loss : 0.6933, Val Loss : 0.6941, Val AUC: 0.5123, Val AP: 0.5109
Epoch 004, N2V Loss: 6.3723, LP Loss : 0.6924, Val Loss : 0.6933, Val AUC: 0.5162, Val AP: 0.5154
Epoch 005, N2V Loss: 6.3714, LP Loss : 0.6912, Val Loss : 0.6930, Val AUC: 0.5196, Val AP: 0.5196
Epoch 006, N2V Loss: 6.3733, LP Loss : 0.6905, Val Loss : 0.6926, Val AUC: 0.5228, Val AP: 0.5234
Epoch 007, N2V Loss: 6.3704, LP Loss : 0.6896, Val Loss : 0.6922, Val AUC: 0.5255, Val AP: 0.5266
Epoch 008, N2V Loss: 6.3716, LP Loss : 0.6888, Val Loss : 0.6921, Val AUC: 0.5277, Val AP: 0.5291
Epoch 009, N2V Loss: 6.3712, LP Loss : 0.6881, Val Loss : 0.6919, Val AUC: 0.5298, Val AP: 0.5314
Epoch 010, N2V Loss: 6.3733, LP Loss : 0.6873, Val Loss : 0.6916, Val AUC: 0.5323, Val AP: 0.5340
Epoch 011, N2V Loss: