In [1]:
import networkx as nx
import torch
from torch_geometric.utils import from_networkx
import torch_geometric.transforms as T
import random
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.loader import LinkNeighborLoader
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import matplotlib.pyplot as plt
from torch_geometric.nn import GATv2Conv

In [2]:
G = nx.read_gml("/home/schoenstein/these/gnn_postbiblio/graph/graph_light.gml")
data = from_networkx(G)
print(data)

Data(edge_index=[2, 830562], num_nodes=116979)


In [3]:
train = "double split"

if train == "random":
    transform = T.RandomLinkSplit(
        num_val = 0.1,  
        num_test = 0.1,  
        disjoint_train_ratio = 0,  
        neg_sampling_ratio = 1,
        is_undirected = True
    )
    train_data, val_data, test_data = transform(data)
    print(train_data)
    print(val_data)
    print(test_data)

elif train == "double split":
    cc = list(nx.connected_components(G))
    neg_inside = 0
    train_list = []
    val_list = []
    test_list = []
    for c in cc:
        G2 = G.subgraph(c).copy()
        data2 = from_networkx(G2)
        transform = T.RandomLinkSplit(
                num_val = 0.1,  
                num_test = 0.1,  
                disjoint_train_ratio = 0,  
                neg_sampling_ratio = 1.5,
                is_undirected = True
            )
        train_data2, val_data2, test_data2 = transform(data2)
        neg_inside = neg_inside + len(train_data2.edge_label)
        train_list.append(train_data2)
        val_list.append(val_data2)
        test_list.append(test_data2)
    ratio_neg_inside = neg_inside/(len(list(G.edges()))*2)
    print(ratio_neg_inside)
    transform = T.RandomLinkSplit(
            num_val = 0.1,  
            num_test = 0.1,  
            disjoint_train_ratio = 0,  
            neg_sampling_ratio = 1 - ratio_neg_inside,
            is_undirected = True
        )
    train_data3, val_data3, test_data3 = transform(data)

    #pos_train = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                             #for d in train_list], dim=1)
    pos_train = train_data3.edge_label_index[:, train_data3.edge_label == 1]
    #pos_val = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                           #for d in val_list], dim=1)
    pos_val = val_data3.edge_label_index[:, val_data3.edge_label == 1]
    #pos_test = torch.cat([d.edge_label_index[:, d.edge_label==1] 
                            #for d in test_list], dim=1)
    pos_test = test_data3.edge_label_index[:, test_data3.edge_label == 1]
    neg_train1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                             for d in train_list], dim=1)
    neg_val1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                           for d in val_list], dim=1)
    neg_test1 = torch.cat([d.edge_label_index[:, d.edge_label==0] 
                            for d in test_list], dim=1)
    neg_train2 = train_data3.edge_label_index[:, train_data3.edge_label==0]             
    neg_val2 = val_data3.edge_label_index[:, val_data3.edge_label==0] 
    neg_test2 = test_data3.edge_label_index[:, test_data3.edge_label==0]
    neg_train = torch.cat([neg_train1, neg_train2], dim=1)
    neg_val = torch.cat([neg_val1, neg_val2], dim=1)
    neg_test = torch.cat([neg_test1, neg_test2], dim=1)
    train_data = Data(
        edge_index=train_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_train, neg_train], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_train.size(1), dtype=torch.long),
            torch.zeros(neg_train.size(1), dtype=torch.long)
        ])
    )
    val_data = Data(
        edge_index=val_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_val, neg_val], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_val.size(1), dtype=torch.long),
            torch.zeros(neg_val.size(1), dtype=torch.long)
        ])
    )
    test_data = Data(
        edge_index=test_data3.edge_index,
        num_nodes=data.num_nodes,
        edge_label_index=torch.cat([pos_test, neg_test], dim=1),
        edge_label=torch.cat([
            torch.ones(pos_test.size(1), dtype=torch.long),
            torch.zeros(neg_test.size(1), dtype=torch.long)
        ])
    )
    print(train_data)
    print(val_data)
    print(test_data)



0.8155513977282852
Data(edge_index=[2, 664450], num_nodes=116979, edge_label_index=[2, 728450], edge_label=[728450])
Data(edge_index=[2, 664450], num_nodes=116979, edge_label_index=[2, 85195], edge_label=[85195])
Data(edge_index=[2, 747506], num_nodes=116979, edge_label_index=[2, 90282], edge_label=[90282])


In [4]:
train_data.x = torch.ones((train_data.num_nodes, 1))
val_data.x = train_data.x.clone()
test_data.x = train_data.x.clone()

In [5]:
"""G_train = nx.Graph()
G_train.add_nodes_from(range(train_data.num_nodes))
G_train.add_edges_from(train_data.edge_index.t().tolist())
degree = dict(G_train.degree())
degree_norm = {n: d for n,d in degree.items()}
clustering = nx.clustering(G_train)
degree_tensor = torch.tensor(list(degree_norm.values()), dtype=torch.float32)
clustering_tensor = torch.tensor(list(clustering.values()), dtype=torch.float32)
train_data.x = torch.stack([degree_tensor, clustering_tensor], dim=-1)
val_data.x = train_data.x.clone()
test_data.x = train_data.x.clone()

print(train_data)
print(val_data)
print(test_data)"""

'G_train = nx.Graph()\nG_train.add_nodes_from(range(train_data.num_nodes))\nG_train.add_edges_from(train_data.edge_index.t().tolist())\ndegree = dict(G_train.degree())\ndegree_norm = {n: d for n,d in degree.items()}\nclustering = nx.clustering(G_train)\ndegree_tensor = torch.tensor(list(degree_norm.values()), dtype=torch.float32)\nclustering_tensor = torch.tensor(list(clustering.values()), dtype=torch.float32)\ntrain_data.x = torch.stack([degree_tensor, clustering_tensor], dim=-1)\nval_data.x = train_data.x.clone()\ntest_data.x = train_data.x.clone()\n\nprint(train_data)\nprint(val_data)\nprint(test_data)'

In [6]:
class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GATv2Conv(in_channels, hidden_channels)
        self.conv2 = GATv2Conv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x


class Predictor(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )
    
    def forward(self, x, edge_label_index):
        edge_emb_src = x[edge_label_index[0]]
        edge_emb_dst = x[edge_label_index[1]]
        edge_emb = torch.cat([edge_emb_src, edge_emb_dst], dim=-1)
        return self.mlp(edge_emb).view(-1)


class Model(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.gnn = GNN(in_channels, hidden_channels)
        self.predictor = Predictor(hidden_channels)
    
    def forward(self, data):
        x = self.gnn(data.x, data.edge_index)
        pred = self.predictor(x, data.edge_label_index)
        return pred


model = Model(in_channels=train_data.x.shape[1], hidden_channels=64)

In [7]:
train_loader = LinkNeighborLoader(
    data = train_data,
    num_neighbors = [25, 10],
    edge_label_index = train_data.edge_label_index,
    edge_label = train_data.edge_label,
    batch_size = 128,
    shuffle = True
)
val_loader = LinkNeighborLoader(
    data = val_data,
    num_neighbors = [25, 10], 
    edge_label_index = val_data.edge_label_index,
    edge_label = val_data.edge_label,
    batch_size = 128,
    shuffle = True
)

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) #testet plusieurs valeurs de lr
#optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=0.01) #si un MLP est utilisé pour la prédiction


def train_epoch():
    model.train()
    total_loss = 0
    count = 0
    for batch in train_loader:
        optimizer.zero_grad()
        pred = model(batch)
        loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label.float())
        loss.backward()
        optimizer.step()
        total_loss = total_loss + loss.item()
        count = count + 1
    return total_loss / count

def evaluate():
    model.eval()
    y_truth = []
    y_pred = []
    total_loss_val = 0
    count = 0
    for batch in val_loader:
        pred = model(batch)
        loss_val = F.binary_cross_entropy_with_logits(pred, batch.edge_label.float())
        y_truth.append(batch.edge_label)
        y_pred.append(torch.sigmoid(pred))
        total_loss_val = total_loss_val + loss_val.item()
        count = count + 1
    y_truth = torch.cat(y_truth).numpy()
    y_pred = torch.cat(y_pred).detach().numpy()
    auc = roc_auc_score(y_truth, y_pred)
    ap = average_precision_score(y_truth, y_pred)
    return total_loss_val/count, auc, ap
    
#ROC-AUC : la probabilité qu’un positif ait un score plus haut qu’un négatif
#AP : aproxime l'air sous la courbe Precisoion/Recall, plus les positifs sont rares plus un bon score AP est difficile à obtenir

In [9]:
best_val_auc = 0
limit = 10
count = 0
train_losses = []
val_aps = []
val_aucs = []
for epoch in range(1, 50):
    loss = train_epoch()
    train_losses.append(loss)
    val_loss, val_auc, val_ap = evaluate()
    val_aps.append(val_ap)
    val_aucs.append(val_auc)
    print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Val Loss : {val_loss:.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}")
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        count = 0
    else:
        count =  count + 1
        if count >= limit:
            print("Early stop")
            break

Epoch 001, Loss: 0.6894, Val Loss : 0.6950, Val AUC: 0.5000, Val AP: 0.4874
Epoch 002, Loss: 0.6893, Val Loss : 0.6952, Val AUC: 0.5000, Val AP: 0.4874
Epoch 003, Loss: 0.6893, Val Loss : 0.6945, Val AUC: 0.5000, Val AP: 0.4874


KeyboardInterrupt: 