In [1]:
import os, sys

In [9]:
from neo4j import GraphDatabase
from neo4j.data import Record
from dotenv import dotenv_values
from models import *

# Connection settings loaded from the local ".env" file.
# GraphDBDriver.__init__ reads the LOCAL_GRAPH_URI / LOCAL_GRAPH_USER / LOCAL_GRAPH_PWD
# keys (or their REMOTE_GRAPH_* counterparts) from this mapping.
config = dotenv_values(".env")

class GraphDBDriver:
    """
    Thin wrapper around a neo4j driver for querying and uploading nodes and edges.

    Main methods:
        query_node_dict: Query database by a node dictionary
        query / raw_query: makes a direct cypher query
        upload_nodes: Upload an iterable of nodes to the database
        upload_edges: Upload an iterable of edges to the database

    Connection settings are read from the module-level ``config`` mapping.
    """

    # Labels cannot be passed as Cypher parameters, so they are formatted in;
    # key values are always sent as query parameters.
    _NODE_MATCH = "MATCH (node:{} {{key: $key}}) RETURN node"
    _EDGE_MATCH = ("MATCH (a:{} {{key: $source_key}})-[edge:{}]->"
                   "(b:{} {{key: $dest_key}}) RETURN edge")

    def __init__(self, remote=False):
        """
        Create the underlying neo4j driver.

        Args:
            remote: when True the REMOTE_GRAPH_* settings are used,
                otherwise the LOCAL_GRAPH_* settings.

        A failure while creating the driver is printed and leaves
        ``self.driver`` set to None. A missing config key raises KeyError.
        """
        prefix = "REMOTE" if remote else "LOCAL"
        uri = config[prefix + "_GRAPH_URI"]
        user = config[prefix + "_GRAPH_USER"]
        password = config[prefix + "_GRAPH_PWD"]

        try:
            self.driver = GraphDatabase.driver(uri, auth=(user, password))
        except Exception as e:
            self.driver = None
            print("Failed to create the driver:", e)

    def close(self):
        """Close the driver if one was created."""
        if self.driver:
            self.driver.close()

    # Query Methods (return a list of neo4j.data.Records, or None if the query failed)
    def query_node_dict(self, node_dict):
        """Match nodes whose label is node_dict['type'] and whose properties equal node_dict."""
        return self.query("MATCH " + self._node_dict_to_cypher(node_dict) + " RETURN node")

    def query_node(self, node):
        """Match the node with label ``node.type`` and key ``node.key``."""
        return self.query(self._NODE_MATCH.format(node.type), parameters={"key": node.key})

    def query_edge(self, edge):
        """Match the ``edge.label`` relationship between the edge's source and dest nodes."""
        source_type, source_key, dest_type, dest_key = self._edge_endpoints(edge)
        return self.query(
            self._EDGE_MATCH.format(source_type, edge.label, dest_type),
            parameters={"source_key": source_key, "dest_key": dest_key},
        )

    def query(self, query, parse_node=False, parameters=None):
        """Run a cypher query; same contract as raw_query (kept as the short public name)."""
        return self.raw_query(query, parse_node=parse_node, parameters=parameters)

    def raw_query(self, query, parse_node=False, parameters=None):
        """
        Run a cypher query in its own session.

        Args:
            query: cypher query string.
            parse_node: when True, each record is converted with record_to_models
                and the list of its 'node' entries is returned.
            parameters: optional dict of cypher query parameters.

        Returns:
            A list of records (or of models when parse_node is True), or None
            when the query failed; the failure is printed, not raised.
        """
        assert self.driver, "Driver not initialized!"
        response = None
        try:
            with self.driver.session() as session:
                response = list(session.run(query, parameters))
        except Exception as e:
            print("Query failed:", e)
        if response is None:
            # Failed query: nothing to parse, report the failure as None in both modes.
            return None
        if parse_node:
            return [self.record_to_models(record)['node'] for record in response]
        return response

    # Upload Methods
    # Node Methods
    def upload_nodes(self, nodes):
        """
        Create every node that is not already in the database.

        Each node must expose ``type``, ``key`` and ``to_dict()``.
        Returns the list of created nodes; existing ones are skipped.
        """
        assert self.driver, "Driver not initialized!"
        with self.driver.session() as session:
            ret = []
            for node in nodes:
                exists = list(session.run(self._NODE_MATCH.format(node.type), {"key": node.key}))
                if exists:
                    print(node.key, "already exists in database")
                    continue
                ret.append(session.write_transaction(self._create_and_return_node, node.to_dict()))
                print("Uploaded", ret[-1])
            return ret

    @staticmethod
    def _create_and_return_node(tx, node_dict):
        """Create a node labelled node_dict['type'] with every dict entry set as a property."""
        cypherquery = ["CREATE (node:{})".format(node_dict['type'])]
        cypherquery.extend("SET node.{} = ${}".format(key, key) for key in node_dict)
        cypherquery.append("RETURN node")
        entry = tx.run(" ".join(cypherquery), node_dict).single()
        return entry['node'] if entry else None

    # Edges Methods
    def upload_edges(self, edges):
        """
        Create every edge that is not already in the database.

        Each edge must expose ``label``, ``to_dict()`` and its endpoints
        (see _edge_endpoints). Returns the list of created edges.
        """
        assert self.driver, "Driver not initialized!"
        with self.driver.session() as session:
            ret = []
            for edge in edges:
                source_type, source_key, dest_type, dest_key = self._edge_endpoints(edge)
                exists = list(session.run(
                    self._EDGE_MATCH.format(source_type, edge.label, dest_type),
                    {"source_key": source_key, "dest_key": dest_key},
                ))
                if exists:
                    print(str(edge), "already exists in database")
                    continue
                ret.append(session.write_transaction(
                    self._create_and_return_edge, edge.to_dict(),
                    source_key, dest_key, source_type, dest_type,
                ))
                print("Uploaded", ret[-1])
            return ret

    @staticmethod
    def _create_and_return_edge(tx, edge_dict, source_key, dest_key, source_type=None, dest_type=None):
        """
        Create an edge labelled edge_dict['label'] between the nodes with the given keys.

        source_type / dest_type are optional labels that restrict the endpoint
        match; without them any node with a matching key is used.
        Every entry of edge_dict is set as a property on the new edge.
        """
        a_pattern = "(a:{})".format(source_type) if source_type else "(a)"
        b_pattern = "(b:{})".format(dest_type) if dest_type else "(b)"
        cypherquery = [
            "MATCH {}, {}".format(a_pattern, b_pattern),
            "WHERE a.key = $match_source_key AND b.key = $match_dest_key",
            "CREATE (a)-[edge:{}]->(b)".format(edge_dict['label']),
        ]
        cypherquery.extend("SET edge.{} = ${}".format(key, key) for key in edge_dict)
        cypherquery.append("RETURN edge")
        # Copy so the caller's dict is untouched; the match_* names keep the endpoint
        # keys separate from edge properties such as source_key / dest_key.
        params = dict(edge_dict)
        params["match_source_key"] = source_key
        params["match_dest_key"] = dest_key
        entry = tx.run(" ".join(cypherquery), params).single()
        return entry['edge'] if entry else None

    # Helper Methods
    @staticmethod
    def _edge_endpoints(edge):
        """
        Return (source_type, source_key, dest_type, dest_key) for an edge model.

        Reads the flat attributes (edge.source_type, ...) when present and
        otherwise falls back to the nested nodes (edge.source.type, ...).
        """
        def pick(flat_name, node_attr, field):
            if hasattr(edge, flat_name):
                return getattr(edge, flat_name)
            return getattr(getattr(edge, node_attr), field)

        return (
            pick("source_type", "source", "type"),
            pick("source_key", "source", "key"),
            pick("dest_type", "dest", "type"),
            pick("dest_key", "dest", "key"),
        )

    @staticmethod
    def _node_dict_to_cypher(node, name="node"):
        """
        Converts a node in dictionary form to a cypher node pattern,
        e.g. (node:entity {type:"entity", key:"k1"}).

        Every value is rendered as a double-quoted string; backslashes and
        double quotes inside values are escaped so they cannot end the literal.
        """
        assert isinstance(node, dict)  # also covers defaultdict and other dict subclasses
        properties = []
        for key, value in node.items():
            escaped = str(value).replace("\\", "\\\\").replace("\"", "\\\"")
            properties.append("{}:\"{}\"".format(key, escaped))
        return "({}:{} {{".format(name, node['type']) + ", ".join(properties) + "})"

    @staticmethod
    def _node_to_model(node):
        """Build the model object matching node['type']; prints and returns None if unrecognized."""
        node_type = node['type']
        if node_type == 'entity':
            return Entity(node['key'], attrs=dict(node))
        if node_type == 'action':
            return Action(node['key'], attrs=dict(node))
        if node_type == 'source':
            return Source(node['key'], node['name'], node['source_type'], attrs=dict(node), date_processed=node['date_processed'])
        if node_type == 'document':
            return Document(node['key'], node['title'], node['doc_type'], attrs=dict(node), date_processed=node['date_processed'])
        print("ERROR: unrecognized node type:", node_type)
        return None

    @staticmethod
    def record_to_models(record):
        """
        Converts a returned record to a dictionary of Node or Edge from Soup Models.

        Returns a dict that may hold:
            'node': the model built from record['node']
            'edge': a (label, source_key, dest_key) tuple read from record['edge']
        """
        assert isinstance(record, Record)
        record_keys = record.keys()
        assert 'node' in record_keys or 'edge' in record_keys, "Neither node or edge found in record keys: " + str(record_keys)
        ret = dict()

        # Process Node
        if 'node' in record_keys:
            assert record['node']['type'] in ['entity', 'action', 'source', 'document'], "Error: unidentified node type " + record['node']['type']
            node = GraphDBDriver._node_to_model(record['node'])
            if node:
                ret['node'] = node

        # TODO: This part can probably be greatly optimized since I am building two new node objects for every edge
        # Process Edge
        if 'edge' in record_keys:
            assert record['edge']['label'] in ['contains', 'authored', 'interacts', 'references', 'involved'], "Error: unidentified edge type " + record['edge']['label']
            edge = record['edge']
            # Edge models are not built yet; the disabled sketch below shows the intended mapping.
            # source, dest = node_to_model(edge.nodes[0]), node_to_model(edge.nodes[1])
            # if edge['label'] == 'contains':
            #     ret['edge'] = Contains(source, dest)
            # elif edge['label'] == 'authored':
            #     ret['edge'] = Authored(source, dest)
            # elif edge['label'] == 'interacts':
            #     ret['edge'] = Interacts(source, dest, edge['interaction_type'])
            # elif edge['label'] == 'references':
            #     ret['edge'] = References(source, dest)
            # elif edge['label'] == 'involved':
            #     ret['edge'] = Involved(source, dest)
            ret['edge'] = (edge['label'], edge['source_key'], edge['dest_key'])

        return ret
            

# if __name__ == "__main__":  
#     # greeter = HelloWorldExample("bolt://localhost:7687", "neo4j", "neo4j")
#     print("Testing Graph Driver...")
#     driver = GraphDBDriver()
#     ret = driver.query("MATCH (node)-[edge:interacts]->() RETURN node, edge")
#     res = driver.record_to_models(ret[0])
#     print(res)
#     # print("Querying all nodes")
#     # print(driver.query("MATCH (n) return n"))
#     # print("Uploading one node")
#     # tweet1 = Document("tweet1", "title", "tweet")
#     # source = Source("source1", "title", "source")
#     # driver.upload_nodes([tweet1, source])
#     # print("Adding connecting Edge")
#     # edge = Interacts(source, tweet1, "follows")
#     # driver.upload_edges([edge])
#     # # driver.close()
    
#     # print("Querying single node")
#     # result = driver.query_node(tweet1)
#     # print(result[0]['node'].keys())
#     # result1 = driver.query_edge(edge)
#     # print(result1[0]['edge'].keys())
#     # # print("Uploading several nodes with pre-existing in database")
#     # # tweets = [tweet, Document("test2", "title", "tweet"), Document("test3", "title", "tweet")]
#     # # # driver.upload_nodes(tweets)
#     # driver.close()
#     # print("Finished")

In [11]:
# Connect with the default (local) settings; `driver` is reused by the cells below.
driver = GraphDBDriver()
# parse_node=True converts each returned record into its model object via record_to_models.
ret = driver.raw_query("MATCH (node:source) RETURN node LIMIT 10", parse_node=True)
# Show the number of results, the type of the first one, and the full list.
# NOTE(review): ret[0] assumes at least one source node exists in the database.
len(ret), type(ret[0]), ret

(10,
 models.Source,
 [<models.Source at 0x1bb61a02c88>,
  <models.Source at 0x1bb61a02cc0>,
  <models.Source at 0x1bb61a1aac8>,
  <models.Source at 0x1bb61a1a748>,
  <models.Source at 0x1bb61a1a048>,
  <models.Source at 0x1bb6014db00>,
  <models.Source at 0x1bb624e2240>,
  <models.Source at 0x1bb62432780>,
  <models.Source at 0x1bb62432a90>,
  <models.Source at 0x1bb624325c0>])

In [8]:
# record_to_models only accepts raw neo4j Records, while `ret` from the cell above
# already holds model objects (parse_node=True), so fetch an unparsed record here.
# NOTE(review): the "document" label is chosen to match the recorded output; confirm.
raw_records = driver.raw_query("MATCH (node:document) RETURN node LIMIT 1")
driver.record_to_models(raw_records[0])

{'node': <models.Document at 0x1bb6014d8d0>}

In [None]:
# Make `key` unique per node label so duplicate uploads are rejected by the database.
# Uses raw_query, the query method GraphDBDriver actually defines.
# NOTE(review): this "ON ... ASSERT" constraint syntax depends on the Neo4j server version - confirm.
for node_type in ["document", "entity", "action", "source"]:
    querystring = "CREATE CONSTRAINT unique_key_{} IF NOT EXISTS ON (n:{}) ASSERT n.key IS UNIQUE".format(node_type, node_type)
    print(driver.raw_query(querystring))

In [None]:
# Look up a single node by its type and key.
# NOTE(review): `tweet` is not defined in any cell shown here - define it before running.
driver.query_node(tweet)

In [None]:
# Upload a batch of nodes; ones whose key already exists are skipped.
# NOTE(review): `tweets` is not defined in any cell shown here - define it before running.
driver.upload_nodes(tweets)