In [1]:
from sentence_transformers import SentenceTransformer
# Sentence-embedding model whose encode() output is stored as node "features" below.
# Catalogue of available pretrained models: https://www.sbert.net/docs/pretrained_models.html
# MPNet paper: https://arxiv.org/pdf/2004.09297.pdf
EMBEDDING_MODEL_NAME = 'all-mpnet-base-v2'
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

In [1]:
import networkx as nx
from tqdm.notebook import tqdm
import pickle, json, os

def compute_vertex_edges(G):
    """Attach per-vertex edge lookups to ``G``.

    After the call:

    * ``G.vertex_edges_from[v]`` is the set of ``(u, w)`` edge tuples with ``u == v``
    * ``G.vertex_edges_to[v]`` is the set of ``(u, w)`` edge tuples with ``w == v``

    Every vertex yielded by ``G.nodes()`` gets an entry in both dicts (an empty
    set when it has no matching edges). Existing attributes of the same name
    are overwritten. Nothing is returned.
    """
    outgoing = {}
    incoming = {}
    for vertex in G.nodes():
        outgoing[vertex] = set()
        incoming[vertex] = set()

    G.vertex_edges_from = outgoing
    G.vertex_edges_to = incoming

    # Each edge is filed once under its source and once under its target.
    for source, target in G.edges():
        edge = (source, target)
        outgoing[source].add(edge)
        incoming[target].add(edge)

DBLP

In [None]:
import os, json
from collections import Counter

data_dir = "dblp/"

# Directed citation graph: an edge (a, b) means paper a lists b in its references.
G = nx.DiGraph()

# os.listdir returns entries in arbitrary order; sorting makes the node/edge
# insertion order (and the progress log) the same on every run.
for conference_file in sorted(os.listdir(data_dir)):
    with open(os.path.join(data_dir, conference_file), encoding="utf8", errors='ignore') as f:
        conference_data = json.load(f)
    for paper in tqdm(conference_data):
        # Named paper_id rather than id so the builtin id() is not shadowed
        # for the rest of the notebook session.
        paper_id = paper["id"]
        title = paper["title"]

        # One node per paper, carrying the raw title and its embedding
        # (stored as a tuple of the values returned by model.encode).
        G.add_node(paper_id, text=title, features=tuple(model.encode(title)))

        # Papers without a "references" key contribute no edges. add_edge also
        # creates the referenced node if it does not exist yet; such nodes have
        # no "text"/"features" unless the paper itself appears in the data.
        for reference in paper.get("references", []):
            G.add_edge(paper_id, reference)
    print(conference_file)

In [None]:
# Size of the graph as built, including nodes that exist only as citation targets.
print("Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())

Nodes: 131626 Edges: 358282


In [5]:
# Drop nodes that only entered the graph as citation targets and therefore
# never received a "features" attribute. The list is built in full before any
# removal, so the node view is not mutated while it is being iterated.
G.remove_nodes_from(
    [node for node, attrs in G.nodes(data=True) if 'features' not in attrs]
)

In [6]:
# Build the per-vertex edge lookups (G.vertex_edges_from / G.vertex_edges_to).
# This only sets attributes on G; it does not add or remove nodes or edges,
# so the counts printed next are unaffected by it.
compute_vertex_edges(G)
print("Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())

Nodes: 30581 Edges: 70972


In [12]:
# Persist the pruned DBLP graph (node attributes plus the vertex-edge lookups
# stored on the graph object) for later use.
with open("dblp_graph.pkl", mode="wb") as graph_file:
    pickle.dump(G, graph_file)

Reddit

In [34]:
# Load the previously pickled Reddit graph.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open("reddit_graph.pkl", "rb") as graph_file:
    G = pickle.load(graph_file)

print("Nodes:", len(G.nodes), "Edges:", len(G.edges))
compute_vertex_edges(G)

# Embed every node's "text" and store the result as a tuple under "features".
# nodes(data=True) yields the live attribute dict, so assigning into it
# updates the graph itself.
for node, attrs in G.nodes(data=True):
    attrs["features"] = tuple(model.encode(attrs["text"]))

# Write the augmented graph back over the file it was loaded from.
with open("reddit_graph.pkl", "wb") as graph_file:
    pickle.dump(G, graph_file)

Nodes: 74778 Edges: 74777
