In [1]:
import networkx as nx
import torch
from torch_geometric.utils import from_networkx
import torch_geometric.transforms as T
import random
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.loader import LinkNeighborLoader
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import Node2Vec

In [2]:
# Input graph in GML format; kept as a named constant so the location is easy to change.
GRAPH_PATH = "/home/schoenstein/these/gnn_postbiblio/graph/graph_light.gml"

G = nx.read_gml(GRAPH_PATH)
# Convert the NetworkX graph into a PyTorch Geometric `Data` object.
data = from_networkx(G)

In [3]:
# Edge-split strategy for link prediction:
#   "random"       -> one RandomLinkSplit over the whole graph
#                     (10% val, 10% test, one negative per positive).
#   "double split" -> positives come from one global RandomLinkSplit; negatives are the
#                     union of negatives sampled inside each connected component
#                     (ratio 1.5) and negatives from the global split.
train = "random"

# Seed the RNGs so Restart & Run All reproduces the same split (NumPy and `random`
# are seeded defensively in case a sampler draws from them).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def _link_split(neg_sampling_ratio):
    """RandomLinkSplit shared by every split in this cell (10% val / 10% test, undirected)."""
    return T.RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        disjoint_train_ratio=0,
        neg_sampling_ratio=neg_sampling_ratio,
        is_undirected=True,
    )


def _edges_with_label(split, label):
    """Columns of `split.edge_label_index` whose `edge_label` equals `label` (1 pos, 0 neg)."""
    return split.edge_label_index[:, split.edge_label == label]


def _inside_negatives(splits, mappings):
    """Concatenate the negatives of per-component splits, mapped to `data`'s node indices."""
    return torch.cat(
        [mapping[_edges_with_label(split, 0)] for split, mapping in zip(splits, mappings)],
        dim=1,
    )


def _labelled_split(split, pos, neg):
    """Data reusing `split`'s message-passing edges, supervised by `pos` (1) then `neg` (0)."""
    return Data(
        edge_index=split.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos, neg], dim=1),
        edge_label=torch.cat([
            torch.ones(pos.size(1), dtype=torch.long),
            torch.zeros(neg.size(1), dtype=torch.long),
        ]),
    )


if train == "random":
    transform = _link_split(1)
    train_data, val_data, test_data = transform(data)
    print(train_data)
    print(val_data)
    print(test_data)

elif train == "double split":
    # Index of each node in `data`: from_networkx(G) numbers nodes in G.nodes() order.
    node_index = {node: i for i, node in enumerate(G.nodes())}

    # 1) Per-component splits; only their negatives are used.
    neg_inside = 0
    train_list, val_list, test_list = [], [], []
    local_to_global = []  # per component: 1-D tensor, local node id -> index in `data`
    for c in nx.connected_components(G):
        G2 = G.subgraph(c).copy()
        # from_networkx relabels G2's nodes 0..k-1 in G2.nodes() order, so the component's
        # edge indices must be mapped back before being mixed with `data`'s indices.
        local_to_global.append(
            torch.tensor([node_index[node] for node in G2.nodes()], dtype=torch.long)
        )
        train_data2, val_data2, test_data2 = _link_split(1.5)(from_networkx(G2))
        # Count negatives only (label 0); len(edge_label) would also count the positives.
        neg_inside += int((train_data2.edge_label == 0).sum())
        train_list.append(train_data2)
        val_list.append(val_data2)
        test_list.append(test_data2)
    ratio_neg_inside = neg_inside / (G.number_of_edges() * 2)
    print(ratio_neg_inside)

    # 2) Global split: supplies positives, message-passing edges and the remaining
    #    negatives. Clamp so a large inside ratio cannot give a negative sampling ratio.
    transform = _link_split(max(0.0, 1 - ratio_neg_inside))
    train_data3, val_data3, test_data3 = transform(data)

    # 3) Assemble each split: global positives, then inside negatives + global negatives.
    train_data, val_data, test_data = (
        _labelled_split(
            global_split,
            _edges_with_label(global_split, 1),
            torch.cat(
                [_inside_negatives(parts, local_to_global), _edges_with_label(global_split, 0)],
                dim=1,
            ),
        )
        for global_split, parts in (
            (train_data3, train_list),
            (val_data3, val_list),
            (test_data3, test_list),
        )
    )

    print(train_data)
    print(val_data)
    print(test_data)

else:
    # Fail fast instead of leaving train_data / val_data / test_data undefined.
    raise ValueError(f"Unknown split strategy: {train!r}")

Data(edge_index=[2, 664450], num_nodes=116979, edge_label=[664450], edge_label_index=[2, 664450])
Data(edge_index=[2, 664450], num_nodes=116979, edge_label=[83056], edge_label_index=[2, 83056])
Data(edge_index=[2, 747506], num_nodes=116979, edge_label=[83056], edge_label_index=[2, 83056])


In [4]:
# Constant placeholder node features: one column of 1.0 per node. The Node2Vec cells
# below only consume `edge_index`, so these are not used for the embeddings.
train_data.x = torch.ones(train_data.num_nodes, 1)
# Val/test get their own copies of the same values.
val_data.x, test_data.x = train_data.x.clone(), train_data.x.clone()

In [5]:
# Node2Vec embeddings are learned from the training message-passing edges only.
model_n2v = Node2Vec(
    edge_index=train_data.edge_index,
    embedding_dim=64,  # must match Predictor(hidden_channels=...) used downstream
    walk_length=20,
    context_size=10,
    walks_per_node=10,
    p=1.0,
    q=1.0,
    num_negative_samples=1,
    # Size the embedding table from the full node set. If inferred from edge_index,
    # high-id nodes with no training edge (isolated, or all their edges in val/test)
    # would get no row and indexing embeddings with val/test edges could go out of range.
    num_nodes=train_data.num_nodes,
    sparse=True,
)
# sparse=True embeddings produce sparse gradients, hence SparseAdam.
optimizer = torch.optim.SparseAdam(list(model_n2v.parameters()), lr=0.01)
n2v_loader = model_n2v.loader(batch_size=128, shuffle=True)

In [6]:
def train_epoch():
    """Run one epoch of Node2Vec training over `n2v_loader`.

    Each batch is a (pos, neg) pair of random-walk samples; Node2Vec's `loss` is
    minimised with the sparse `optimizer`.

    Returns:
        Mean batch loss as a float (NaN if the loader produced no batches).
    """
    model_n2v.train()
    total_loss = 0.0
    num_batches = 0
    for pos_rw, neg_rw in n2v_loader:
        optimizer.zero_grad()
        batch_loss = model_n2v.loss(pos_rw, neg_rw)
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()
        num_batches += 1
    # Avoid ZeroDivisionError on an empty loader.
    return total_loss / num_batches if num_batches else float("nan")

for epoch in range(1, 10):  # 9 epochs
    loss = train_epoch()
    print(f"Epoch {epoch:03d} | Loss: {loss:.4f}")

model_n2v.eval()
# Freeze the embeddings for the downstream predictor. `.detach()` is needed even in eval
# mode: the returned tensor may still require grad (it can be the embedding parameter
# itself), and every predictor backward pass would then push gradients into Node2Vec.
z = model_n2v().detach()

Epoch 001 | Loss: 2.8980
Epoch 002 | Loss: 1.2316
Epoch 003 | Loss: 0.9485
Epoch 004 | Loss: 0.8530
Epoch 005 | Loss: 0.8115
Epoch 006 | Loss: 0.7893
Epoch 007 | Loss: 0.7761
Epoch 008 | Loss: 0.7677
Epoch 009 | Loss: 0.7620


In [7]:
class Predictor(nn.Module):
    """MLP link scorer: concatenates the two endpoint embeddings and maps them to one logit.

    Args:
        hidden_channels: size of each node embedding; the MLP input is twice this size.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        in_dim = hidden_channels * 2
        layers = [
            nn.Linear(in_dim, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, edge_label_index):
        """Score candidate edges.

        Args:
            x: node embeddings, indexed by node id along dim 0.
            edge_label_index: index tensor whose rows 0 and 1 hold source / destination ids.

        Returns:
            Flattened raw logits (no sigmoid), one per candidate edge for 2-D `x`.
        """
        src_ids, dst_ids = edge_label_index[0], edge_label_index[1]
        pair = torch.cat((x[src_ids], x[dst_ids]), dim=-1)
        logits = self.mlp(pair)
        return logits.reshape(-1)


predictor = Predictor(hidden_channels=64)

In [8]:
optimizer_mlp = torch.optim.Adam(predictor.parameters(), lr=0.01)


def train_epoch2():
    """Take one full-batch Adam step on the predictor using the training supervision edges.

    `z` is used as the input features; `optimizer_mlp` only holds predictor parameters,
    so only the predictor is updated.

    Returns:
        Training BCE-with-logits loss as a detached 0-d tensor.
    """
    predictor.train()
    optimizer_mlp.zero_grad()
    pred = predictor(z, train_data.edge_label_index)
    loss = F.binary_cross_entropy_with_logits(pred, train_data.edge_label.float())
    loss.backward()
    optimizer_mlp.step()
    # Detach so the returned value does not keep this step's autograd graph alive.
    return loss.detach()


@torch.no_grad()
def evaluate(split=None):
    """Evaluate the predictor on a labelled split without building an autograd graph.

    Args:
        split: object with `edge_label_index` and `edge_label`; defaults to `val_data`.

    Returns:
        (loss, auc, ap): BCE-with-logits loss (float), ROC-AUC and average precision.
    """
    if split is None:
        split = val_data
    predictor.eval()
    logits = predictor(z, split.edge_label_index)
    labels = split.edge_label
    loss = F.binary_cross_entropy_with_logits(logits, labels.float()).item()
    y_true = labels.cpu().numpy()
    y_score = logits.detach().cpu().numpy()
    auc = roc_auc_score(y_true, y_score)
    ap = average_precision_score(y_true, y_score)
    return loss, auc, ap


best_val_auc = 0
limit = 6  # early-stopping patience: epochs allowed without a new best val AUC
count = 0
val_aps = []
val_aucs = []
best_state = None  # predictor weights at the best validation AUC
for epoch in range(1, 50):
    loss = train_epoch2()
    val_loss, val_auc, val_ap = evaluate()
    val_aucs.append(val_auc)
    val_aps.append(val_ap)
    print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Val Loss : {val_loss:.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}")
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        best_state = {k: v.detach().clone() for k, v in predictor.state_dict().items()}
        count = 0
    else:
        count += 1
        if count >= limit:
            print("Early stop")
            break

# Restore the best-validation weights so later cells use that model, not the last epoch's.
if best_state is not None:
    predictor.load_state_dict(best_state)


Epoch 001, Loss: 0.6925, Val Loss : 0.6844, Val AUC: 0.6570, Val AP: 0.6637
Epoch 002, Loss: 0.6836, Val Loss : 0.6757, Val AUC: 0.7378, Val AP: 0.7447
Epoch 003, Loss: 0.6743, Val Loss : 0.6660, Val AUC: 0.7829, Val AP: 0.7875
Epoch 004, Loss: 0.6640, Val Loss : 0.6556, Val AUC: 0.8096, Val AP: 0.8117
Epoch 005, Loss: 0.6529, Val Loss : 0.6447, Val AUC: 0.8285, Val AP: 0.8279
Epoch 006, Loss: 0.6413, Val Loss : 0.6335, Val AUC: 0.8448, Val AP: 0.8414
Epoch 007, Loss: 0.6294, Val Loss : 0.6220, Val AUC: 0.8609, Val AP: 0.8553
Epoch 008, Loss: 0.6172, Val Loss : 0.6100, Val AUC: 0.8777, Val AP: 0.8705
Epoch 009, Loss: 0.6045, Val Loss : 0.5973, Val AUC: 0.8952, Val AP: 0.8874
Epoch 010, Loss: 0.5910, Val Loss : 0.5838, Val AUC: 0.9128, Val AP: 0.9053
Epoch 011, Loss: 0.5766, Val Loss : 0.5694, Val AUC: 0.9292, Val AP: 0.9229
Epoch 012, Loss: 0.5612, Val Loss : 0.5541, Val AUC: 0.9431, Val AP: 0.9384
Epoch 013, Loss: 0.5448, Val Loss : 0.5381, Val AUC: 0.9539, Val AP: 0.9508
Epoch 014, L