In [1]:
import json
import pickle
import pandas as pd
import networkx as nx
import numpy as np
from tqdm import tqdm

In [2]:
data_path = '/Users/alex/Library/CloudStorage/OneDrive-MUNI/Dokumenty/Projects Data/NLP/SciDocs/'
paper_cite_file = data_path + 'paper_metadata_view_cite_read.json'
paper_cls_file = data_path + 'paper_metadata_mag_mesh.json'
paper_rec_file = data_path + 'paper_metadata_recomm.json'
user_activity_and_citations_embeddings_path = data_path + 'specter-embeddings/user-citation.jsonl'

## Look at the data

In [3]:
papers_data = {}
with open(paper_rec_file, 'r') as f:
    papers_data = json.load(f)

In [4]:
list(list(papers_data.values())[0].keys())

['abstract', 'authors', 'cited_by', 'paper_id', 'references', 'title', 'year']

In [5]:
papers_df = pd.DataFrame(papers_data).T

In [6]:
papers_df.head(3)

Unnamed: 0,abstract,authors,cited_by,paper_id,references,title,year
0002c0f45b3ef0f1491f91cbfefe9543e9af6163,In this paper we introduce an hp certified red...,"[3122778, 39921175, 1905947, 2848614]","[16728d6f0a225bb8a71ebe4d0acd2512ca775327, 48d...",0002c0f45b3ef0f1491f91cbfefe9543e9af6163,"[07e1f620f68c0be579fb05bf6d231fa06b0db7c3, 7f2...",An hp Certified Reduced Basis Method for Param...,2011
0004d38aa501306e2719e4d0413dcb5c788676b1,We test whether momentum strategies remain pro...,[],"[67d88e58812562f0270b7077408776157a542fbf, 946...",0004d38aa501306e2719e4d0413dcb5c788676b1,"[18cb63580217983f2fd4b54141b2f83b96819dd3, 360...",Are Momentum Profits Robust to Trading Costs,2003
000587e08c6ce8c3d4c74360e34abe7c543a0e98,OBJECTIVES\nThe death of a child in the pediat...,"[4526222, 6110925, 38791518, 3743376, 31914233...","[55bac4ab517f5174d3259e7631d4fc6fa58cbac7, 593...",000587e08c6ce8c3d4c74360e34abe7c543a0e98,"[95f677c6287e19a9afb8c25848b6f3437340e17a, ecb...","""I was able to still be her mom""--parenting at...",2012


In [7]:
#extra column
all(papers_df.index == papers_df.paper_id)

True

In [8]:
papers_df.shape

(36261, 7)

## Create graph

In [9]:
G = nx.DiGraph()
G.add_nodes_from(papers_data.keys())

In [10]:
for paper_id, paper_attrs in papers_data.items():
    for citing_id in paper_attrs['cited_by']:
        if citing_id in G:
            G.add_edge(citing_id, paper_id)
    for cited_id in paper_attrs['references']:
        if cited_id in G:
            G.add_edge(paper_id, cited_id)

In [11]:
len(G)

36261

In [12]:
conn_comp_sizes = []
for c in nx.weakly_connected_components(G):
    conn_comp_sizes.append(len(c))

In [13]:
sorted(conn_comp_sizes, reverse=True)[:10]

[4716, 187, 111, 97, 93, 84, 61, 56, 55, 46]

In [14]:
conn_comps = {len(c):c for c in nx.weakly_connected_components(G)}

In [15]:
largest_conn_comp = conn_comps[max(conn_comps.keys())]

In [16]:
len(largest_conn_comp)

4716

In [17]:
Glcc = G.subgraph(largest_conn_comp)

In [18]:
len(Glcc)

4716

In [93]:
#combine previous cells into functions
def get_graph(file_name):
    papers_data = {}
    with open(file_name, 'r') as f:
        papers_data = json.loads(f.read())
    G = nx.DiGraph()
    G.add_nodes_from(papers_data.keys())
    for paper_id, paper_attrs in papers_data.items():
        for citing_id in paper_attrs['cited_by']:
            if citing_id in G:
                G.add_edge(citing_id, paper_id)
        for cited_id in paper_attrs['references']:
            if cited_id in G:
                G.add_edge(paper_id, cited_id)
    return G

def get_LCC(G):
    conn_comps = {len(c):c for c in nx.weakly_connected_components(G)}
    largest_conn_comp = conn_comps[max(conn_comps.keys())]
    Glcc = G.subgraph(largest_conn_comp)
    return Glcc.copy()

In [96]:
G = get_graph(paper_rec_file)
G = get_LCC(G)

In [97]:
len(G)

4716

In [99]:
def load_embeddings_from_jsonl(embeddings_path, G):
    embeddings = {}
    with open(embeddings_path, 'r') as f:
        for line in tqdm(f, desc='reading embeddings from file...'):
            line_json = json.loads(line)
            if line_json['paper_id'] in G:
                embeddings[line_json['paper_id']] = np.array(line_json['embedding'], dtype=np.float32)
    return embeddings

embeddings = load_embeddings_from_jsonl(user_activity_and_citations_embeddings_path, G)

reading embeddings from file...: 142009it [00:21, 6482.55it/s]


In [100]:
Grec = get_LCC(G.subgraph(embeddings.keys()))
nx.set_node_attributes(Grec, {node_id: {"x":embedding} for node_id, embedding in embeddings.items()})

In [None]:
len(Grec), len(Grec.edges), len(embeddings)

(419, 679, 964)

In [103]:
with open(data_path + 'rec_graph.pkl', 'wb') as f:
    pickle.dump(Grec, f)

## Link prediction

In [76]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch_geometric
import torch_geometric.utils
from torch_geometric.utils import from_networkx

In [104]:
class LinkPredictionModel(torch.nn.Module):

    def __init__(self, layer_type, sz_in, num_layers=2, sz_hid=128, sz_out=64):
        super().__init__()

        # GNN layers with ReLU
        encoder = []
        encoder.append(layer_type(sz_in, sz_hid))
        encoder.append(nn.ReLU())
        for _ in range(num_layers-2):
            encoder.append(layer_type(sz_hid, sz_hid))
            encoder.append(nn.ReLU())
        encoder.append(layer_type(sz_hid, sz_out))
        self.encoder = nn.ModuleList(encoder)
    
    # Encoding: usual GNN propagation
    def encode(self, fts, adj):
        for l in self.encoder:
            if isinstance(l, nn.ReLU):
                fts = l(fts)
            else:
                fts = l(fts, adj)
        return fts
    
    # Decoding: dot(H[i], H[j]) for each edge in edge_index
    # Larger dot => the model is more confident that this edge should exist
    def decode(self, H, edge_index):
        return (H[edge_index[0]] * H[edge_index[1]]).sum(dim=1)

In [108]:
from torch_geometric.utils import train_test_split_edges

data = from_networkx(Grec)
data.train_mask = data.val_mask = data.test_mask = data.y = None
print(data)

data = train_test_split_edges(data, 0.2, 0.2)
print(data.x)
print(data)
print()
print(f'Train set: {data.train_pos_edge_index.shape[1]} positive edges, we will sample the same number of negative edges at runtime')
print(f'Val set: {data.val_pos_edge_index.shape[1]} positive edges, {data.val_neg_edge_index.shape[1]} negative edges')
print(f'Test set: {data.test_pos_edge_index.shape[1]} positive edges, {data.test_neg_edge_index.shape[1]} negative edges')

Data(x=[419, 768], edge_index=[2, 679])
tensor([[ 0.0922, -5.5931, -3.3866,  ...,  1.5318, -0.6453, -1.3122],
        [-4.3983, -0.8796, -0.8786,  ..., -2.4426, -0.3191, -1.0760],
        [-4.6795, -3.8367, -1.8152,  ..., -0.8004, -3.1326, -4.6144],
        ...,
        [-3.1228, -4.3944, -0.7400,  ..., -3.5155, -0.7761, -0.7719],
        [-1.0095,  1.1776, -0.4630,  ...,  2.4056, -3.0469, -1.7069],
        [-0.2017, -4.2765, -0.3223,  ..., -1.5384, -1.9732, -4.0335]])
Data(x=[419, 768], val_pos_edge_index=[2, 65], test_pos_edge_index=[2, 65], train_pos_edge_index=[2, 392], train_neg_adj_mask=[419, 419], val_neg_edge_index=[2, 65], test_neg_edge_index=[2, 65])

Train set: 392 positive edges, we will sample the same number of negative edges at runtime
Val set: 65 positive edges, 65 negative edges
Test set: 65 positive edges, 65 negative edges




In [109]:
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score

# Train the given model on the given dataset for num_epochs
def train(model, data, num_epochs):
    # Set up the loss and the optimizer
    loss_fn = nn.BCEWithLogitsLoss() # Binary classification
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # Prepare all edges/labels for val/test
    val_pos, val_neg = data.val_pos_edge_index, data.val_neg_edge_index
    val_edge_index = torch.cat([val_pos, val_neg], dim=1)
    val_labels = torch.cat([torch.ones(val_pos.shape[1]), torch.zeros(val_neg.shape[1])])
    
    test_pos, test_neg = data.test_pos_edge_index, data.test_neg_edge_index
    test_edge_index = torch.cat([test_pos, test_neg], dim=1)
    test_labels = torch.cat([torch.ones(test_pos.shape[1]), torch.zeros(test_neg.shape[1])])

    # A utility function to compute the ROC-AUC on given edges
    def get_roc_auc(model, data, edge_index, labels):
        with torch.no_grad():
            H = model.encode(data.x, data.train_pos_edge_index)
            z = model.decode(H, edge_index)
            s = z.sigmoid()
            return roc_auc_score(labels, s)

    best_auc_val = -1
    for epoch in range(num_epochs):
        # Sample negative edges
        pos_edge_index = data.train_pos_edge_index # T_+
        neg_edge_index = negative_sampling(
            edge_index=pos_edge_index, # edges to ignore
            num_nodes=data.num_nodes, # N
            num_neg_samples=pos_edge_index.shape[1] # number of edges to sample
        )
    
        # Zero grads -> encode to get node latents
        optimizer.zero_grad()
        H = model.encode(data.x, pos_edge_index)

        # Decode to get a score for all (positive and negative) edges
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1)
        z = model.decode(H, edge_index)

        # Construct the label vector and backprop the loss
        labels = torch.cat([torch.ones(pos_edge_index.shape[1]), torch.zeros(neg_edge_index.shape[1])])
        
        loss = loss_fn(z, labels)
        loss.backward()
        optimizer.step()

        # Compute accuracies, print only if this is the best result so far
        auc_val = get_roc_auc(model, data, val_edge_index, val_labels)
        auc_test = get_roc_auc(model, data, test_edge_index, test_labels)
        if auc_val > best_auc_val:
            best_auc_val = auc_val
            print(f'[Epoch {epoch+1}/{num_epochs}] Loss: {loss} | Val: {auc_val:.3f} | Test: {auc_test:.3f}')

In [110]:
model = LinkPredictionModel(torch_geometric.nn.GCNConv, data.x.shape[1])
print(model)
train(model, data, num_epochs=1000)

LinkPredictionModel(
  (encoder): ModuleList(
    (0): GCNConv(768, 128)
    (1): ReLU()
    (2): GCNConv(128, 64)
  )
)
[Epoch 1/1000] Loss: 165.4088134765625 | Val: 0.500 | Test: 0.500
[Epoch 3/1000] Loss: 119.71524047851562 | Val: 0.517 | Test: 0.584
