In [19]:
import json
import os

import pandas as pd
from tqdm import tqdm

4999it [00:01, 4640.99it/s]


In [None]:
# Database wrapper from: https://towardsdatascience.com/create-a-graph-database-in-neo4j-using-python-4172d40f89c4
from neo4j import GraphDatabase


class Neo4jConnection:
    """Small wrapper around a Neo4j driver that runs each query in its own session.

    The driver is created in ``__init__``. If that fails, the error is printed
    and the instance is left without a driver; ``query`` then refuses to run.
    Instances can be used as context managers, in which case ``close`` is
    called when the ``with`` block exits.
    """

    def __init__(self, uri, user, pwd):
        """Create the underlying driver.

        Args:
            uri: Connection URI handed to ``GraphDatabase.driver``.
            user: User name used for authentication.
            pwd: Password used for authentication.
        """
        self.__uri = uri
        self.__user = user
        self.__pwd = pwd
        self.__driver = None
        try:
            self.__driver = GraphDatabase.driver(self.__uri, auth=(self.__user, self.__pwd))
        except Exception as e:
            # Best effort: report the problem and leave the driver unset.
            print("Failed to create the driver:", e)

    def __enter__(self):
        """Return the connection itself so it can be bound with ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the driver on leaving the ``with`` block; never suppresses exceptions."""
        self.close()
        return False

    def close(self):
        """Close the driver if one was created."""
        if self.__driver is not None:
            self.__driver.close()

    def query(self, query, parameters=None, db=None):
        """Run ``query`` in a fresh session and return its records as a list.

        Args:
            query: Query text passed to ``session.run``.
            parameters: Optional query parameters passed to ``session.run``.
            db: Optional database name. When ``None`` the session is opened
                without a ``database`` argument.

        Returns:
            ``list(session.run(...))``, or ``None`` if opening the session or
            running the query raised (the error is printed, not re-raised).

        Raises:
            AssertionError: If the driver was never created.
        """
        if self.__driver is None:
            # Raised explicitly instead of via ``assert`` so the check is still
            # performed when Python runs with -O (which strips assert statements).
            raise AssertionError("Driver not initialized!")
        session = None
        response = None
        try:
            session = self.__driver.session(database=db) if db is not None else self.__driver.session()
            # Materialise the result while the session is still open.
            response = list(session.run(query, parameters))
        except Exception as e:
            print("Query failed:", e)
        finally:
            if session is not None:
                session.close()
        return response

# Connection settings. Each value can be overridden through an environment
# variable (NEO4J_PORT, NEO4J_USER, NEO4J_PASSWORD) so real credentials do not
# have to be written into the notebook; the fallbacks are the values this
# notebook was originally written with.
port = int(os.environ.get("NEO4J_PORT", 7687)) # Check that this matches your server!

conn = Neo4jConnection(uri="bolt://localhost:"+str(port),
                       user=os.environ.get("NEO4J_USER", "driver"),
                       pwd=os.environ.get("NEO4J_PASSWORD", "driver"))

In [None]:
# Uniqueness constraints on the properties used to look nodes up.
# add_papers MERGEs papers on `paperid` and authors on `authorid`, so those are
# the properties constrained here (author names are not used as a merge key).
# NOTE(review): IF NOT EXISTS presumably leaves an already existing constraint
# with the same name untouched, so a constraint created by an earlier run on a
# different property may have to be dropped first - confirm on your server.
conn.query('CREATE CONSTRAINT papers IF NOT EXISTS ON (p:Paper) ASSERT p.paperid IS UNIQUE')
conn.query('CREATE CONSTRAINT authors IF NOT EXISTS ON (a:Author) ASSERT a.authorid IS UNIQUE')
conn.query('CREATE CONSTRAINT categories IF NOT EXISTS ON (c:Category) ASSERT c.category IS UNIQUE')

In [None]:
# Remove the uniqueness constraint on Author.name.
# If the statement fails, Neo4jConnection.query prints the error and returns
# None instead of raising, so this cell does not stop the notebook.
conn.query('DROP CONSTRAINT ON (a:Author) ASSERT a.name IS UNIQUE')

In [41]:
import time


def insert_data(query, rows, batch_size = 10000, connection = None):
    """Send ``rows`` to Neo4j in batches and report progress after each batch.

    Each batch is converted with ``to_dict('records')`` and passed to the
    query as the ``rows`` parameter. The query must return one record with a
    ``total`` field; those values are summed across batches.

    Args:
        query: Cypher text that consumes a ``$rows`` parameter.
        rows: DataFrame whose rows are sent to the database.
        batch_size: Maximum number of rows per query; must be positive.
        connection: Object with a ``query(query, parameters=...)`` method.
            Defaults to the notebook-level ``conn``.

    Returns:
        A dict with ``total_inserted``, ``batches_done`` and ``total_time``
        (seconds) for the last batch, or ``None`` when ``rows`` is empty.

    Raises:
        ValueError: If ``batch_size`` is not positive (the loop could never advance).
        RuntimeError: If a batch query returns no records, e.g. because the
            connection reported a failure by returning ``None``.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    if connection is None:
        # Fall back to the connection created earlier in the notebook.
        connection = conn

    total = 0
    result = None
    start = time.time()

    for batch, offset in enumerate(range(0, len(rows), batch_size), start=1):
        records = rows[offset:offset + batch_size].to_dict('records')
        res = connection.query(query, parameters = {'rows': records})
        if not res:
            # Without this check a failed query surfaces as an opaque
            # "'NoneType' object is not subscriptable" error.
            raise RuntimeError(
                "Batch %d (rows %d-%d) returned no result; see the query error printed above"
                % (batch, offset, offset + len(records) - 1))
        total += res[0]['total']
        result = {"total_inserted":total,
                  "batches_done":batch,
                  "total_time":time.time()-start}
        print(result)

    return result

def add_papers(rows, batch_size=5000):
    """Insert papers, their authors and their references into the graph.

    For every row this merges a ``Paper`` node keyed on ``paperid``, merges an
    ``Author`` node (keyed on ``authorid``) with an ``AUTHORED`` relationship
    for each entry of ``authors``, and adds a ``references`` relationship to
    every referenced paper that is already present in the graph.

    Authors and references are handled in separate aggregating subqueries so
    that each paper stays exactly one row in the outer query. Chaining the two
    UNWINDs instead multiplies the rows (authors x references), skips the
    references of papers without authors, and makes the returned count the
    number of surviving combinations rather than the number of papers.

    Args:
        rows: DataFrame with the columns ``id``, ``title``, ``year``,
            ``n_citation``, ``doi``, ``authors`` and ``references``.
        batch_size: Number of rows sent per query.

    Returns:
        The progress dict returned by ``insert_data``; ``total_inserted`` is
        the number of paper rows processed.
    """
    # The correlated CALL { WITH ... } subqueries need Neo4j 4.1+, which the
    # `IF NOT EXISTS` constraint syntax used elsewhere in this notebook already
    # presupposes - NOTE(review): confirm against the target server version.
    # NOTE(review): references to papers that have not been loaded yet are
    # skipped by the MATCH and are not revisited later.
    query = '''
    // Create papers
    UNWIND $rows AS paper
    MERGE (p:Paper {paperid: paper.id})
    ON CREATE SET
        p.title = paper.title,
        p.year = paper.year,
        p.n_citation = paper.n_citation,
        p.doi = paper.doi

    // Link authors; the aggregation yields exactly one row per paper
    WITH paper, p
    CALL {
        WITH paper, p
        UNWIND paper.authors AS author
        MERGE (a:Author {authorid: author.id})
        ON CREATE SET a.name = author.name
        MERGE (a)-[:AUTHORED]->(p)
        RETURN count(a) AS n_authors
    }

    // Link references to papers that already exist
    WITH paper, p
    CALL {
        WITH paper, p
        UNWIND paper.references AS refid
        MATCH (r:Paper {paperid: refid})
        MERGE (p)-[:references]->(r)
        RETURN count(r) AS n_references
    }
    RETURN count(p) AS total
    '''

    return insert_data(query, rows, batch_size)

In [None]:
file = "data/dblp_papers_v11.txt"

subset = ["id", "title", "year", "n_citation", "doi", "authors", "references"]

# Number of papers parsed into one DataFrame before handing it to add_papers.
chunk_size = 100000

# TODO: Might be possible to speed up inserts by using just the name for matching instead of ID https://stackoverflow.com/a/23609143/9994398
# The file holds one JSON document per line; it is read as UTF-8 explicitly so
# the result does not depend on the platform's default encoding.
with open(file, 'r', encoding='utf-8') as f:
    while True:
        try:
            rows = []
            # Iterating the file object resumes where the previous chunk stopped.
            for line in tqdm(f):
                if not line.strip():
                    continue  # tolerate blank lines instead of failing in json.loads
                rows.append(json.loads(line))
                if len(rows) == chunk_size:
                    break
            if not rows:
                # End of file: stop explicitly rather than relying on a
                # KeyError from selecting columns of an empty frame.
                break
            # reindex keeps the import going when no paper in this chunk has
            # one of the optional fields; the missing column is filled with NaN.
            df = pd.DataFrame(rows).reindex(columns=subset)
            add_papers(df, 5000)
        except Exception as e:
            # Best effort: report the failure and stop importing.
            print(e)
            break

99999it [00:19, 5109.87it/s] 


In [None]:
def get_stack_exchange_df(path="data/Posts.xml"):
    """Parse a Stack Exchange posts XML file into a DataFrame.

    Args:
        path: Location of the XML file; defaults to ``data/Posts.xml``.

    Returns:
        The DataFrame produced by ``pd.read_xml`` with its default settings.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # The file handle is passed to pandas directly instead of the file's text:
    # literal XML strings are deprecated as read_xml input, and binary mode
    # leaves decoding to the XML parser. Read errors now propagate as
    # themselves rather than ending in an unbound-variable error.
    with open(path, 'rb') as f:
        return pd.read_xml(f)

# Load the posts file and preview its first rows.
# Note: this rebinds `df`, which the paper-import cell also assigns.
df = get_stack_exchange_df()
df.head()