# Imports

In [None]:
# Colab only: mount Google Drive so the challenge archive can be read from it.
# NOTE(review): the relative mount point "MyDrive" ends up at /content/MyDrive
# (see the %cd output), which is why the working-directory path repeats MyDrive.
from google.colab import drive
drive.mount("MyDrive")

Mounted at MyDrive


In [None]:
# Work from the challenge folder: every relative path below (edgelist.txt, authors.txt, ...) resolves here.
%cd MyDrive/MyDrive/Challenge/

/content/MyDrive/MyDrive/Challenge


In [None]:
# Extract test.txt, authors.txt, edgelist.txt, abstracts.txt and the baseline scripts into the working directory.
!unzip data_challenge_2021.zip

Archive:  data_challenge_2021.zip
mapname:  conversion of  failed
 extracting: test.txt                
 extracting: authors.txt             
 extracting: edgelist.txt            
 extracting: abstracts.txt           
 extracting: text_baseline.py        
 extracting: graph_baseline.py       


In [None]:
!pip install transformers
!pip install -U sentence-transformers --quiet
from google.colab import drive
import networkx as nx
import csv
import numpy as np
from tqdm import tqdm
import torch
import transformers
from transformers import DistilBertTokenizerFast
from transformers import DistilBertModel, DistilBertTokenizerFast
from random import randint, sample
import random

Collecting transformers
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 7.5 MB/s 
Collecting tokenizers!=0.11.3,>=0.10.1
  Downloading tokenizers-0.11.5-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.8 MB)
[K     |████████████████████████████████| 6.8 MB 53.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 7.2 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 65.2 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 61.0 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found exis


# Read Graph

In this section we read the graph and build the train/test datasets. To prevent the model from training on the ground truth, we drop the edges of the pairs used in the datasets.\
First, 15% of the edges are removed from the graph; these pairs form the validation set, together with the same number of randomly sampled pairs that are not edges.\
Next, we sample 200k of the remaining edges (again paired with the same number of random non-edges) to use as the training set; each of these edges is also removed with probability 0.6 (`EDGE_DROP_TRAIN_RATE`). The resulting graph is denoted ```G_train```.

In [None]:
# Undirected paper graph; every line of edgelist.txt is "<paper_id>,<paper_id>".
G = nx.read_edgelist('edgelist.txt', delimiter=',', create_using=nx.Graph(), nodetype=int)
nodes = list(G.nodes())
n, m = G.number_of_nodes(), G.number_of_edges()
for label, value in (('Number of nodes:', n), ('Number of edges:', m)):
    print(label, value)

Number of nodes: 138499
Number of edges: 1091955


In [None]:
# Both RNGs that drive this cell are seeded: `random` for edge/negative sampling,
# numpy for the random dropping of training edges (previously unseeded, so the
# resulting G_train differed on every run).
random.seed(43)
np.random.seed(43)
EDGE_DROP_TRAIN_RATE = 0.6  # probability of removing each sampled training edge from G_train
TRAIN_SIZE = 200000         # number of positive training pairs (same number of negatives)


def _sample_non_edge(graph, num_nodes):
    """Draw a random pair (u, v) with u != v that is NOT an edge of `graph`.

    Ids are drawn uniformly from [0, num_nodes - 1]; this assumes paper ids are the
    contiguous integers 0..n-1 (TODO confirm against edgelist.txt).
    """
    while True:
        u = randint(0, num_nodes - 1)
        v = randint(0, num_nodes - 1)
        if u != v and not graph.has_edge(u, v):
            return u, v


# --- validation split: hold out 15% of the edges --------------------------
G_train = G.copy()
test_size = int(m * 0.15)
test_pairs = []
test_labels = []
for u, v in tqdm(sample(list(G_train.edges()), test_size)):
    # Negatives are checked against the FULL graph G: checking against the
    # shrinking G_train could label an already held-out true edge as negative.
    neg_u, neg_v = _sample_non_edge(G, n)
    test_pairs.append([u, v])
    test_labels.append(1)
    test_pairs.append([neg_u, neg_v])
    test_labels.append(0)
    G_train.remove_edge(u, v)  # hide the held-out edge from all training features

# --- training pairs: sampled from the remaining edges ---------------------
train_pairs = []
train_labels = []
train_abstracts = {}  # NOTE(review): never filled or read in this notebook
for u, v in tqdm(sample(list(G_train.edges()), TRAIN_SIZE)):
    neg_u, neg_v = _sample_non_edge(G, n)
    train_pairs.append([u, v])
    train_labels.append(1)
    train_pairs.append([neg_u, neg_v])
    train_labels.append(0)
    # Drop the positive edge with some probability so the model cannot simply
    # read the answer off G_train for every training pair.
    if np.random.rand() < EDGE_DROP_TRAIN_RATE:
        G_train.remove_edge(u, v)

100%|██████████| 163793/163793 [00:01<00:00, 97862.74it/s]
100%|██████████| 200000/200000 [00:04<00:00, 47023.55it/s] 


In [None]:
# Cache switches: True -> download precomputed artefacts, False -> recompute them.
LOAD_TEXT_FEATURES, LOAD_AUTHOR_FEATURES = True, True
LOAD_PAIR_FEATURES = False  # NOTE(review): not read anywhere in this notebook

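# BERT model

We use a pre-trained sentence-transformer (`all-MiniLM-L6-v2`, a BERT-style encoder) to obtain abstract embeddings, or download precomputed ones when `LOAD_TEXT_FEATURES` is set.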

In [None]:
import torch.nn as nn
from sentence_transformers import SentenceTransformer

In [None]:
# NOTE(review): this download failed in the recorded run ("Access denied"), and the file it fetches is not loaded by name in any later cell — confirm whether this cell is still needed.
!gdown --id 1Z7rU8crJcgZJoScxPcRre9mdScioxDO1

Access denied with the following error:

 	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses. 

You may still be able to access the file from the browser:

	 https://drive.google.com/uc?id=1Z7rU8crJcgZJoScxPcRre9mdScioxDO1 



In [None]:
if not LOAD_TEXT_FEATURES:
    model = SentenceTransformer('all-MiniLM-L6-v2')
    #model = SentenceTransformer('average_word_embeddings_glove.6B.300d')

    encodings = model.encode(list(abstracts.values()), device='cuda', show_progress_bar=True, batch_size=256, convert_to_numpy=True)

    #np.save("embeddings.npy", encodings)
else:
    !gdown --id 1233WIYYiiKavkFHIXI9R8niky1V-bBNm
    embeddings = np.load("embeddings.npy")

Access denied with the following error:

 	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses. 

You may still be able to access the file from the browser:

	 https://drive.google.com/uc?id=1233WIYYiiKavkFHIXI9R8niky1V-bBNm 



# Process Authors

In this block we extract all the information about authors needed to build the empirical features and the author embeddings.

In [None]:
def process_line_of_authors(line):
    """Parse one authors.txt line of the form "<paper_id>|--|<a1>,<a2>,...".

    Returns:
        (paper_id as int, list of author name strings).
    """
    raw_id, raw_authors = line.rstrip('\n').split('|--|')
    return int(raw_id), raw_authors.split(',')

# Find the n least common words (among sufficiently frequent ones)
def less_common(n_words, words_list, min_occurence=9):
    """Return the `n_words` least frequent words among those that occur at least
    `min_occurence` times in `words_list`.

    The result follows `Counter.most_common()` order (descending count), so it runs
    from the more frequent to the rarest of the selected words. Returns [] when
    `n_words <= 0` (the former `[-0:]` slice returned every word instead).
    """
    # `Counter` was used here without ever being imported anywhere in the notebook.
    from collections import Counter

    if n_words <= 0:
        return []
    big_count = Counter(words_list)
    # Keep words seen at least `min_occurence` times (inclusive threshold).
    frequent = Counter({word: count for word, count in big_count.items() if count >= min_occurence})
    return [word for word, _ in frequent.most_common()[-n_words:]]

# Overlap of rare words between two paper abstracts
def intersection_less_common(node1, node2, abstracts_list, less_common_list):
    """Dot product of the occurrence-count vectors of `less_common_list` in the two
    abstracts `abstracts_list[node1]` and `abstracts_list[node2]`.

    Counting uses `.count`, so it counts substrings if the abstracts are strings and
    exact tokens if they are lists of words.
    """
    text_a = abstracts_list[node1]
    text_b = abstracts_list[node2]
    counts_a = [text_a.count(word) for word in less_common_list]
    counts_b = [text_b.count(word) for word in less_common_list]
    return np.dot(counts_a, counts_b)

def len_shortest_p(graph, src, target, no_path_value=10000):
    """Length of the shortest path between `src` and `target` in `graph`.

    Returns `no_path_value` (default 10000, a sentinel larger than any real path
    length) when no path exists or either node is missing from the graph. Any other
    error now propagates: the former bare `except:` also swallowed unrelated failures
    and even KeyboardInterrupt during the long feature loops.
    """
    try:
        return nx.shortest_path_length(graph, src, target)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return no_path_value

In [None]:
# paper_id -> list of author names
with open('authors.txt', 'r', encoding='utf-8') as f:
    paper_to_authors = dict(map(process_line_of_authors, f))

# Set of all distinct authors
All_authors = {author for authors in paper_to_authors.values() for author in authors}

# author name -> integer id (ids start at 1).
# Sorted so the mapping is identical in every session. Enumerating a raw set of str
# depends on per-process hash randomisation, so ids (and therefore the node2vec
# vectors cached under those ids) would not match between runs.
author_to_authorid = {author: i + 1 for i, author in enumerate(sorted(All_authors))}

# paper_id -> list of author ids
paper_to_authorsid = {paper: [author_to_authorid[a] for a in authors]
                      for paper, authors in paper_to_authors.items()}

In this cell two author graphs are created: a citation graph (authors are linked when their papers are linked) and a cooperation graph (authors are linked when they co-authored a paper).

In [None]:

def add_edge_or_increase_weight(G, u, v):
    """Add edge (u, v) with weight 1, or increment its "weight" if it already exists."""
    if G.has_edge(u, v):
        G[u][v]["weight"] += 1
    else:
        G.add_edge(u, v, weight=1)


# Author citation graph: each linked pair of papers links every author of the first
# paper to every author of the second. Built from G_train rather than G so the
# held-out validation edges do not leak into the author embeddings (use G instead
# for a final submission-only run).
AUTHOR_CITE_SOURCE_GRAPH = G_train
A_cite_graph = nx.Graph()
for p1, p2 in AUTHOR_CITE_SOURCE_GRAPH.edges():
    for a1 in paper_to_authorsid[p1]:
        for a2 in paper_to_authorsid[p2]:
            add_edge_or_increase_weight(A_cite_graph, a1, a2)


def _build_coop_graph(paper_to_ids):
    """Cooperation graph: one weight increment per co-author pair per paper."""
    from itertools import combinations

    coop = nx.Graph()
    for a_list in paper_to_ids.values():
        for a1, a2 in combinations(a_list, 2):  # same (i < j) pairs as the nested loops
            add_edge_or_increase_weight(coop, a1, a2)
    return coop


A_coop_graph = _build_coop_graph(paper_to_authorsid)

# Node2Vec

Here we run Node2Vec on the two author graphs to learn a vector representation of each author.

In [None]:
!pip install fastnode2vec
from fastnode2vec import Node2Vec, Graph
import gensim

Collecting fastnode2vec
  Downloading fastnode2vec-0.0.5-py3-none-any.whl (7.1 kB)
Installing collected packages: fastnode2vec
Successfully installed fastnode2vec-0.0.5


In [None]:
if not LOAD_AUTHOR_FEATURES
    edges_list = [(str(e[0]), str(e[1]), A_cite_graph[e[0]][e[1]]["weight"])
                  for e in A_cite_graph.edges]
    g = Graph(edges_list, directed=False, weighted=True)
    node2vec = Node2Vec(g, dim=128, walk_length=15,
                        context=10, p=1, q=0.5, workers=10)
    node2vec.train(epochs=100)
    node2vec.wv.save_word2vec_format('a_cite.nodevectors')

    edges_list_coop = [(str(e[0]), str(e[1]), A_coop_graph[e[0]][e[1]]["weight"])
              for e in A_coop_graph.edges]
    g = Graph(edges_list, directed=False, weighted=True)
    a_coop_model = Node2Vec(g, dim=128, walk_length=15,
                        context=10, p=1, q=0.5, workers=10)
    a_coop_model.train(epochs=100)
    a_coop_model.wv.save_word2vec_format('a_coop.nodevectors')

else:
    ! gdown --id 1-00RDGpf6bvC7De3K1I2aAP5gErgEO9Y
    !gdown --id 1G-uItGZdwtH5E6vNBOlGlefvL4qEbTAO

Reading graph: 100%|██████████| 8050051/8050051 [00:09<00:00, 864594.65it/s]
Training: 100%|██████████| 14968200/14968200 [46:01<00:00, 5419.69it/s]


In [None]:
# Load the author node2vec vectors (word2vec text format) written or downloaded by the
# previous cell; keys are author ids converted to str.
a_coop_model = gensim.models.KeyedVectors.load_word2vec_format("a_coop.nodevectors")
a_cite_model = gensim.models.KeyedVectors.load_word2vec_format("a_cite.nodevectors")

In [None]:
def _mean_author_vector(kv, author_ids):
    """Mean of the vectors in KeyedVectors `kv` for the given author ids.

    Uses `key in kv` / `kv[key]`, which work in both gensim 3.x and 4.x. The former
    `.wv`/`.vocab` access on a KeyedVectors object is deprecated in 3.x and removed in 4.x.
    Returns a zero vector when none of the authors has an embedding: an author who
    never appears in that graph has no vector, and `np.mean([])` would yield a scalar
    NaN that breaks the hstack below.
    """
    vectors = [kv[str(a)] for a in author_ids if str(a) in kv]
    if not vectors:
        return np.zeros(kv.vector_size, dtype=np.float32)
    return np.mean(vectors, axis=0)


# node -> [mean coop vector | mean citation vector] over the paper's authors
node_a_embeddings = {}
for node in G_train.nodes():
    a_ids = paper_to_authorsid[int(node)]
    node_a_embeddings[node] = np.hstack([_mean_author_vector(a_coop_model, a_ids),
                                         _mean_author_vector(a_cite_model, a_ids)])

  after removing the cwd from sys.path.
  """


# Pair features

In [None]:
NB_PAIR_FEATURES = 12

# Hand-crafted features for the training pairs, computed on G_train.
# Columns: 0 degree sum, 1 |degree diff|, 2 #shared authors, 3 author Jaccard,
# 4 #common neighbours, 5 common neighbours / (|union| + 1), 6 resource allocation,
# 7 Adamic-Adar, 8 total #authors, 9 dispersion, 10 has_path, 11 shortest path < 11.
# Author/neighbour sets are built once per pair (they were recomputed up to 6x).
X_pairs_train = torch.zeros((len(train_pairs), NB_PAIR_FEATURES))
for i, (u, v) in enumerate(tqdm(train_pairs)):
    authors_u, authors_v = set(paper_to_authorsid[u]), set(paper_to_authorsid[v])
    neigh_u, neigh_v = set(G_train.neighbors(u)), set(G_train.neighbors(v))
    deg_u, deg_v = G_train.degree(u), G_train.degree(v)
    shared_authors = len(authors_u & authors_v)
    common_neigh = len(neigh_u & neigh_v)
    X_pairs_train[i] = torch.tensor([
        deg_u + deg_v,
        abs(deg_u - deg_v),
        shared_authors,
        shared_authors / len(authors_u | authors_v),
        common_neigh,
        common_neigh / (len(neigh_u | neigh_v) + 1),
        list(nx.resource_allocation_index(G_train, [(u, v)]))[0][-1],
        list(nx.adamic_adar_index(G_train, [(u, v)]))[0][-1],
        len(authors_u) + len(authors_v),
        nx.algorithms.centrality.dispersion(G_train, u=u, v=v),
        float(nx.has_path(G_train, u, v)),
        float(len_shortest_p(G_train, u, v) < 11),
    ], dtype=X_pairs_train.dtype)

    

In [None]:
# Same 12 features as the training pairs (see column list there), computed on G_train
# for the validation pairs. Author/neighbour sets are built once per pair.
X_pairs_test = torch.zeros((len(test_pairs), NB_PAIR_FEATURES))
for i, (u, v) in enumerate(tqdm(test_pairs)):
    authors_u, authors_v = set(paper_to_authorsid[u]), set(paper_to_authorsid[v])
    neigh_u, neigh_v = set(G_train.neighbors(u)), set(G_train.neighbors(v))
    deg_u, deg_v = G_train.degree(u), G_train.degree(v)
    shared_authors = len(authors_u & authors_v)
    common_neigh = len(neigh_u & neigh_v)
    X_pairs_test[i] = torch.tensor([
        deg_u + deg_v,
        abs(deg_u - deg_v),
        shared_authors,
        shared_authors / len(authors_u | authors_v),
        common_neigh,
        common_neigh / (len(neigh_u | neigh_v) + 1),
        list(nx.resource_allocation_index(G_train, [(u, v)]))[0][-1],
        list(nx.adamic_adar_index(G_train, [(u, v)]))[0][-1],
        len(authors_u) + len(authors_v),
        nx.algorithms.centrality.dispersion(G_train, u=u, v=v),
        float(nx.has_path(G_train, u, v)),
        float(len_shortest_p(G_train, u, v) < 11),
    ], dtype=X_pairs_test.dtype)

In [None]:
# Pairs to score for the submission, one "<paper_id>,<paper_id>" per line.
sub_pairs = list()
with open('test.txt', 'r') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        src, dst = line.split(',')[:2]
        sub_pairs.append((int(src), int(dst)))

# Same 12 features as for train/validation, but computed on the full graph G.
X_sub = torch.zeros((len(sub_pairs), NB_PAIR_FEATURES))
for i, (u, v) in enumerate(tqdm(sub_pairs)):
    authors_u, authors_v = set(paper_to_authorsid[u]), set(paper_to_authorsid[v])
    neigh_u, neigh_v = set(G.neighbors(u)), set(G.neighbors(v))
    deg_u, deg_v = G.degree(u), G.degree(v)
    shared_authors = len(authors_u & authors_v)
    common_neigh = len(neigh_u & neigh_v)
    X_sub[i] = torch.tensor([
        deg_u + deg_v,
        abs(deg_u - deg_v),
        shared_authors,
        shared_authors / len(authors_u | authors_v),
        common_neigh,
        # "+ 1" matches the train/validation definition of column 5; it was missing
        # here, so the model saw a differently-defined feature at submission time.
        common_neigh / (len(neigh_u | neigh_v) + 1),
        list(nx.resource_allocation_index(G, [(u, v)]))[0][-1],
        list(nx.adamic_adar_index(G, [(u, v)]))[0][-1],
        len(authors_u) + len(authors_v),
        nx.algorithms.centrality.dispersion(G, u=u, v=v),
        float(nx.has_path(G, u, v)),
        float(len_shortest_p(G, u, v) < 11),
    ], dtype=X_sub.dtype)


# GNN (GraphSAGE)

In this block we use our GNN-based architecture as described in the report

In [None]:
!pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html

Looking in links: https://data.dgl.ai/wheels/repo.html
Collecting dgl-cu111
  Downloading https://data.dgl.ai/wheels/dgl_cu111-0.7.2-cp37-cp37m-manylinux1_x86_64.whl (165.0 MB)
[K     |████████████████████████████████| 165.0 MB 54 kB/s 
Installing collected packages: dgl-cu111
Successfully installed dgl-cu111-0.7.2


In [None]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp
from dgl.nn import SAGEConv, GATConv

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [None]:


# Three-layer GraphSAGE encoder ('gcn' aggregator) producing node embeddings
class GraphSAGE(nn.Module):
    """GraphSAGE node encoder.

    Three SAGEConv layers; LeakyReLU followed by dropout(0.5) after the first two,
    nothing after the last one.

    Args:
        in_feats: size of the input node features.
        h_feats: size of the hidden and output node embeddings.
    """

    def __init__(self, in_feats, h_feats):
        super().__init__()
        # Creation order kept (conv1, conv2, conv3) so seeded initialisation and
        # state_dict keys stay the same.
        self.conv1 = SAGEConv(in_feats, h_feats, 'gcn')
        self.conv2 = SAGEConv(h_feats, h_feats, 'gcn')
        self.conv3 = SAGEConv(h_feats, h_feats, 'gcn')
        self.relu = nn.LeakyReLU()
        self.dropout = nn.Dropout(p=.5)

    def forward(self, g, in_feat):
        h = in_feat
        for conv in (self.conv1, self.conv2):
            h = self.dropout(self.relu(conv(g, h)))
        return self.conv3(g, h)

In [None]:
class MLPPredictor(nn.Module):
    """Two-layer MLP scoring a pair from its concatenated node embeddings and pair features.

    Input width is 2 * h_feats + NB_PAIR_FEATURES; output is one logit per pair.
    NOTE(review): not instantiated anywhere in this notebook (MLPPredictorSE is used).
    """

    def __init__(self, h_feats):
        super().__init__()
        self.fc1 = nn.Linear(h_feats * 2 + NB_PAIR_FEATURES, h_feats)
        self.fc2 = nn.Linear(h_feats, 1)

    def forward(self, h, X):
        # h[:, 0, :] / h[:, 1, :] are the two endpoints of each pair.
        left, right = h[:, 0, :], h[:, 1, :]
        joined = torch.cat((left, right, X), dim=1)
        return self.fc2(F.relu(self.fc1(joined)))


class MLPPredictorSE(nn.Module):
    """Three-layer MLP (h_feats -> 128 -> 64 -> 1) with ReLU, returning one logit per row."""

    def __init__(self, h_feats):
        super().__init__()
        self.fc1 = nn.Linear(h_feats, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

In [None]:
# Move text embeddings, labels and pair features to the GPU when available.
# NOTE(review): not idempotent — re-running this cell calls torch.tensor on the
# already-converted label tensors; restart from the split cell instead.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Row k of X is used as the features of node k (indexed by paper id below).
X = torch.tensor(embeddings).to(device)

# Float labels, as required by binary_cross_entropy_with_logits.
test_labels = torch.tensor(test_labels).float().to(device)
train_labels = torch.tensor(train_labels).float().to(device)

X_pairs_test = X_pairs_test.to(device)
X_pairs_train = X_pairs_train.to(device)

In [None]:
# Author-embedding matrix aligned with X: row k holds the author features of node k
# (rows of nodes missing from node_a_embeddings stay zero).
# The width is taken from any stored vector instead of assuming node 0 exists, and
# the loop no longer rebinds the global `i` (a later cell used to slice X_sub[:i+1]
# with it, getting this loop's last node id instead of the submission row count).
author_dim = len(next(iter(node_a_embeddings.values())))
X_author = torch.zeros(X.shape[0], author_dim)
for node_id, vec in node_a_embeddings.items():
    X_author[node_id] = torch.as_tensor(vec)

In [None]:
# Node features = [text embedding | author embedding].
# NOTE(review): not idempotent — re-running appends X_author a second time.
X = torch.cat([X, X_author.to(device)], dim=1)
# DGL graph of the training graph; assumes DGL node index == paper id (ids 0..n-1) — TODO confirm.
g_train = dgl.from_networkx(G_train).to(device)
# Add a self-loop on every node.
g_train = dgl.add_self_loop(g_train)


# 150-d node embeddings; the MLP input is the 150-d squared embedding difference
# concatenated with the NB_PAIR_FEATURES hand-crafted pair features.
gnn_model = GraphSAGE(X.shape[1], 150).to(device)
mlp_model = MLPPredictorSE(150 + NB_PAIR_FEATURES).to(device)
# Both models are optimised jointly.
params = list(gnn_model.parameters()) + list(mlp_model.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.0005)

In [None]:
epochs = 1000
# Reduce the LR when the validation loss plateaus.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.4, verbose=True, patience=25)

# Pair index tensors are built once instead of converting the Python lists every epoch.
train_idx = torch.tensor(train_pairs, device=device)
test_idx = torch.tensor(test_pairs, device=device)


def _pair_logits(h, pair_idx, pair_feats):
    """Score pairs: squared difference of the two node embeddings + pair features -> MLP logit."""
    h_pairs = h[pair_idx]  # (num_pairs, 2, hidden)
    se_vec = (h_pairs[:, 0, :] - h_pairs[:, 1, :]) ** 2
    se_vec = torch.cat([se_vec, pair_feats], dim=-1)  # add empirical features
    return mlp_model(se_vec).squeeze(-1)


for e in range(epochs):
    # --- training step (dropout active) ---
    gnn_model.train()
    mlp_model.train()
    optimizer.zero_grad()
    h = gnn_model(g_train, X)  # full-graph GNN forward pass
    scores = _pair_logits(h, train_idx, X_pairs_train)
    loss = F.binary_cross_entropy_with_logits(scores, train_labels)
    loss.backward()
    optimizer.step()

    # --- validation: a separate eval-mode pass (no dropout, updated weights).
    # It previously reused the training forward pass, so validation scores were
    # computed with dropout on and with the weights from before optimizer.step().
    gnn_model.eval()
    mlp_model.eval()
    with torch.no_grad():
        scores_test = _pair_logits(gnn_model(g_train, X), test_idx, X_pairs_test)
        test_loss = F.binary_cross_entropy_with_logits(scores_test, test_labels)
    lr_scheduler.step(test_loss)

    print("epoch {},train_loss = {}, test_loss = {}".format(e, loss.item(), test_loss.item()))

epoch 0,train_loss = 0.6477522850036621, test_loss = 0.646271288394928
epoch 1,train_loss = 0.6253641843795776, test_loss = 0.624096155166626
epoch 2,train_loss = 0.6159651279449463, test_loss = 0.6152353882789612
epoch 3,train_loss = 0.6041077971458435, test_loss = 0.6041202545166016
epoch 4,train_loss = 0.588280200958252, test_loss = 0.5892348289489746
epoch 5,train_loss = 0.571135938167572, test_loss = 0.5734578967094421
epoch 6,train_loss = 0.5546502470970154, test_loss = 0.5586272478103638
epoch 7,train_loss = 0.5370312333106995, test_loss = 0.5436577200889587
epoch 8,train_loss = 0.5151461958885193, test_loss = 0.5258376002311707
epoch 9,train_loss = 0.4904625713825226, test_loss = 0.5070428252220154
epoch 10,train_loss = 0.47099119424819946, test_loss = 0.4977725148200989
epoch 11,train_loss = 0.4657295048236847, test_loss = 0.506120502948761
epoch 12,train_loss = 0.47196632623672485, test_loss = 0.5240391492843628
epoch 13,train_loss = 0.4735821783542633, test_loss = 0.52965176

In [None]:
from sklearn.metrics import f1_score, accuracy_score

In [None]:
# Validation probabilities -> hard labels, with a decision threshold of 0.3.
y_test_np = test_labels.cpu().numpy()
test_prob = scores_test.sigmoid().cpu().numpy()
pred_labels = (test_prob > 0.3).astype(float)

In [None]:
# Validation F1, then accuracy.
for metric in (f1_score, accuracy_score):
    print(metric(y_test_np, pred_labels))

0.9286927095350502
0.9264132166820316


In [None]:
# Score the submission pairs. eval() disables the GraphSAGE dropout; the models were
# otherwise still in train mode here, making the predictions noisy.
gnn_model.eval()
mlp_model.eval()
with torch.no_grad():
    h = gnn_model(g_train, X)
    h_sub = h[sub_pairs]  # (num_pairs, 2, hidden)
    se_vec = (h_sub[:, 0, :] - h_sub[:, 1, :]) ** 2
    # Use all of X_sub. The former X_sub[:i+1] relied on a leaked loop variable `i`,
    # which an intermediate cell had overwritten with an unrelated node id.
    se_vec = torch.cat([se_vec, X_sub.to(device)], dim=-1)

    scores = mlp_model(se_vec).squeeze(-1)

In [None]:
# Submission logits -> probabilities as a NumPy array (one per line of test.txt).
probs = torch.sigmoid(scores).cpu().detach().numpy()

In [None]:
# Lazy (row index, probability) pairs; single-use iterator, consume it only once.
predictions = enumerate(probs)