In [None]:
import time
import os
from itertools import product
import math
from collections import Counter

from lark import Lark, Transformer, v_args

from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics.pairwise import cosine_similarity


import networkx as nx
import numpy as np
import pandas as pd

from gensim.models import Word2Vec
from stellargraph.data import UniformRandomWalk, BiasedRandomWalk, UniformRandomMetaPathWalk
from stellargraph import StellarGraph, StellarDiGraph, datasets
from neo4j import GraphDatabase

import matplotlib.pyplot as plt
from IPython.display import display, HTML

# Regular Expressions

A regular-expression-style grammar for writing succinct path patterns over a graph. Every pattern must contain a node identifier, which should match a node label in the Neo4j database.

Examples of valid grammar statements:

<i>"chemical_substance treats> disease"</i> -> matches any node labelled <b>chemical_substance</b> that is connected to a node labelled <b>disease</b> by an outgoing edge of type <i>treats</i>.

<i>"chemical_substance ? disease"</i> -> matches any node labelled <b>chemical_substance</b> that is connected to a node labelled <b>disease</b> by an edge of any type, in either direction.


In [None]:
# Lark grammar for the path-pattern "regular expressions" used in this notebook.
# - `start` is a single node, a chain `node (edge node)+`, or a node followed by
#   a parenthesised group of edge/node alternatives (`path`).
# - `|` builds alternatives for nodes (node_or), edges (edge_or) and
#   edge-node sequences (path_or).
# - An edge label may carry a trailing `>` (edge_right) or `<` (edge_left).
# - `?` (NULL) stands for an unlabeled node or edge.
# - `*` / `+` after a node label map to the rep_from_0 / rep_from_1 aliases.
# - Labels are runs of letters, digits, `_` and `-`; inline whitespace is ignored.
# Each `-> name` alias selects the CalculateTree method applied to that match.
# NOTE(review): the imported NAME terminal is not referenced by any rule here.
regex_grammar = """
    start: node 
         | node (edge node)+
         | node "(" path ")"

       
    ?path: edge_node
         | path "|" edge_node   -> path_or 
    
    ?edge_node: edge node
         | "(" edge_node+ ")" 
         

         
    ?edge: attm
         | edge "|" attm        -> edge_or
         
    ?attm: EDGE_LABEL           -> edge
         | attm ">"             -> edge_right
         | attm "<"             -> edge_left
         | NULL                 -> edge_no_label
         | "(" edge ")"     
    
    ?node: atom
        | node "|" atom         -> node_or

    ?atom: NODE_LABEL           -> node
         | NODE_LABEL "*"       -> rep_from_0
         | NODE_LABEL "+"       -> rep_from_1
         | NULL                 -> node_no_label
         | "(" node ")"

    EDGE_LABEL: LABEL_STRING
    NODE_LABEL: LABEL_STRING
    NULL: "?"
    LCASE_LETTER: "a".."z"
    UCASE_LETTER: "A".."Z"
    DIGIT: "0".."9"

    LETTER: UCASE_LETTER | LCASE_LETTER | DIGIT | "_" | "-"
    LABEL_STRING: LETTER+ | "_" 

    %import common.CNAME -> NAME
    %import common.WS_INLINE
    %ignore WS_INLINE
"""


Lark transformer. Converts a parsed subgraph regular expression into Cypher pattern fragments. Each method is invoked for the grammar rule alias of the same name defined above.

In [None]:
@v_args(inline=True)    # pass each rule's children as separate positional args
class CalculateTree(Transformer):
    """Turn a parsed path pattern into Cypher pattern fragments.

    Lark calls the method whose name matches each rule alias in
    ``regex_grammar``. Fragments produced:

    * labelled node ``(nK:Label)`` / unlabeled node ``(nK)``
    * labelled edge ``-[rK:TYPE]-`` / unlabeled edge ``-[rK]-``; a direction
      appends ``>`` or prepends ``<``
    * alternatives (``a|b``, ``X*``, ``X+``) as a FLAT list of fragments,
      which ``extractPathways`` expands with a Cartesian product.

    ``node_idx`` / ``edge_idx`` number the Cypher variables so every node and
    edge in one parse gets a distinct name.
    """

    # Class-level starting values; the first ``self.x += 1`` creates a
    # per-instance counter, so each new transformer starts numbering at 1.
    node_idx = 0
    edge_idx = 0

    # Helpers start with "_" so the class-level v_args decorator leaves them alone.
    @staticmethod
    def _alternatives(*options):
        """Merge alternatives into one flat list.

        The grammar's ``|`` rules are left-recursive, so ``a|b|c`` arrives as
        ``(['a', 'b'], 'c')``; flattening avoids a nested list that would
        otherwise be stringified into the Cypher query.
        """
        flat = []
        for option in options:
            if isinstance(option, list):
                flat.extend(option)
            else:
                flat.append(option)
        return flat

    @staticmethod
    def _each(fragment, fn):
        """Apply ``fn`` to a fragment, or to every alternative in a list."""
        if isinstance(fragment, list):
            return [fn(alt) for alt in fragment]
        return fn(fragment)

    def _repeat(self, name):
        """Return [one node, two linked nodes] for a repeated node label."""
        self.node_idx += 1
        first = self.node_idx
        single = "(n{}:{})".format(first, str(name))
        pair = single + "--" + "(n{}:{})".format(first + 1, str(name))
        # Reserve the second variable name used by `pair`.
        self.node_idx += 1
        return [single, pair]

    def node(self, name):
        self.node_idx += 1
        return "(n{}:{})".format(self.node_idx, str(name))

    def edge(self, name):
        self.edge_idx += 1
        return "-[r{}:{}]-".format(self.edge_idx, str(name))

    def node_no_label(self, name):
        self.node_idx += 1
        return "(n{})".format(self.node_idx)

    def edge_no_label(self, name):
        self.edge_idx += 1
        return "-[r{}]-".format(self.edge_idx)

    def edge_node(self, name1, name2):
        return name1 + name2

    def edge_right(self, name):
        # Works on a single edge or on each alternative of `(a|b)>`.
        return self._each(name, lambda edge: edge + ">")

    def edge_left(self, name):
        # Works on a single edge or on each alternative of `(a|b)<`.
        return self._each(name, lambda edge: "<" + edge)

    def rep_from_0(self, name):
        # '' marks "node omitted"; extractPathways removes it with its edge.
        return [''] + self._repeat(name)

    def rep_from_1(self, name):
        return self._repeat(name)

    def node_or(self, name1, name2):
        return self._alternatives(name1, name2)

    def edge_or(self, name1, name2):
        return self._alternatives(name1, name2)

    def path_or(self, name1, name2):
        return self._alternatives(name1, name2)

In [None]:
def extractPathways(parse_tree):
    """Expand a transformed parse tree into every concrete pathway.

    Each child of ``parse_tree`` is either a Cypher fragment string or a list
    of alternative fragments (from ``|``, ``*`` or ``+``). The Cartesian
    product of those choices gives one pathway per combination.

    An empty string marks an optional node that was omitted. It is removed
    together with the edge that linked it: the preceding edge normally, or the
    following edge when the omitted node starts the pathway. Every omitted
    node is handled, not just the first.

    Args:
        parse_tree: object with a ``children`` sequence (a lark ``Tree``).

    Returns:
        list of pathways, each a list of fragment strings. Pathways left
        empty after removing omitted nodes are dropped.
    """
    # A plain string is a single forced choice; a list is a set of alternatives.
    choices = [[child] if isinstance(child, str) else child
               for child in parse_tree.children]

    pathways = []
    for combo in product(*choices):
        path = list(combo)
        while '' in path:
            idx = path.index('')
            del path[idx]
            if idx > 0:
                del path[idx - 1]   # edge leading into the omitted node
            elif path:
                del path[0]         # omitted first node: drop its outgoing edge
        if path:
            pathways.append(path)
    return pathways

def getQueries(source_node_name, regexes):
    """Build one Cypher query returning every path that matches ``regexes``.

    The pattern is parsed with ``regex_grammar``/``CalculateTree`` and expanded
    into concrete pathways by ``extractPathways``. Each pathway becomes one
    ``MATCH`` clause whose first node is pinned to ``source_node_name`` through
    its ``name`` property; the clauses are combined with ``UNION``.

    Args:
        source_node_name: value of the ``name`` property of the start node.
        regexes: path pattern string in the ``regex_grammar`` syntax.

    Returns:
        str: the complete Cypher query ('' if the pattern yields no pathway).
    """
    regex_parser = Lark(regex_grammar, parser='lalr', transformer=CalculateTree())
    parsed = regex_parser.parse(regexes)

    # Escape for a single-quoted Cypher string literal so names containing a
    # quote or backslash neither break nor alter the query.
    # NOTE(review): a query parameter would be safer, but callers pass the
    # returned string straight to session.run().
    escaped = str(source_node_name).replace("\\", "\\\\").replace("'", "\\'")
    quoted_name = "'%s'" % escaped

    queries = []
    for path in extractPathways(parsed):
        # "(n1:Label)" / "(n1)" -> "n1"
        start_var = path[0].split(":")[0].split("(")[1].split(")")[0]
        query = "MATCH p1=" + ''.join(str(elem) for elem in path)
        query += " WHERE %s.name = %s" % (start_var, quoted_name)

        # When a node label repeats along the pathway, require the repeat to
        # be a different node from the first occurrence of that label.
        add_where = ""
        seen = []  # parallel to `path`: label part of each element, or "?"
        for elem in path:
            if ":" in elem:
                label = elem.split(":")[1]
                if label in seen and ")" in elem:  # repeated label on a node
                    prev = path[seen.index(label)].split(":")[0].split("(")[1]
                    current = elem.split(":")[0].split("(")[1]
                    add_where += " AND %s <> %s" % (prev, current)
                seen.append(label)
            else:
                seen.append("?")

        query += add_where
        query += " WITH collect(p1) as nodez UNWIND nodez as c RETURN c"
        queries.append(query)

    return " UNION ".join(queries)

In [None]:
def parsing(regexes):
    """Return the labels of every pathway as (node, edge, node) windows.

    Each pathway produced by ``extractPathways`` is reduced to its labels:
    the node label of a ``(...)`` fragment, the edge type of a ``[...]``
    fragment, or '?' for a fragment without a label. The label list is then
    cut into windows ``labels[2k:2k+3]`` for ``k in range(len(labels) // 2)``,
    i.e. consecutive hops that share their boundary node.
    """
    tree = Lark(regex_grammar, parser='lalr', transformer=CalculateTree()).parse(regexes)

    def labels_of(fragment):
        if ':' not in fragment:
            return ['?']
        found = []
        if '(' in fragment:
            found.append(fragment.split(':')[1].split(')')[0])
        if '[' in fragment:
            found.append(fragment.split(':')[1].split(']')[0])
        return found

    triples = []
    for pathway in extractPathways(tree):
        labels = [lab for fragment in pathway for lab in labels_of(fragment)]
        triples.extend(labels[2 * k:2 * k + 3] for k in range(len(labels) // 2))
    return triples




# Retrieve regex-defined subgraphs from a Neo4j database

In [None]:
def _labelsInPattern(regexes):
    """Return the distinct node/edge labels mentioned in a path pattern."""
    return list({label for triple in parsing(regexes) for label in triple})


def _pickNodeType(labels, user_labels, compared_labels=None):
    """Choose the single type used for a (possibly multi-labelled) node.

    Preference: a label named in the query pattern, then one listed in
    ``compared_labels``, then the node's first label. The first matching label
    wins, so a later label can no longer overwrite an earlier match.
    """
    labels = list(labels)
    for label in labels:
        if label in user_labels:
            return label
    if compared_labels is not None:
        for label in labels:
            if label in compared_labels:
                return label
    return labels[0]


def getSubgraph_neo4j(graph_uri, source_node_name, regexes, compared_labels = None):
    """Fetch the subgraph matched by ``regexes`` from ``source_node_name``.

    Builds the Cypher query with ``getQueries`` and delegates to
    ``querySubgraph``.

    Args:
        graph_uri: Neo4j connection URI (e.g. a ``bolt://`` address).
        source_node_name: ``name`` property of the start node.
        regexes: path pattern in the ``regex_grammar`` syntax.
        compared_labels: optional labels used to pick a type for nodes that
            carry several labels, none of which appear in the pattern.

    Returns:
        StellarDiGraph of the matched nodes and relationships.
    """
    queryStr = getQueries(source_node_name, regexes)
    return querySubgraph(graph_uri, regexes, queryStr, compared_labels)


# Used e.g. to fetch the union subgraph for two drugs from a hand-written query.
def querySubgraph(G, regexes, queryStr, compared_labels = None):
    """Run ``queryStr`` against Neo4j and wrap the result as a StellarDiGraph.

    Nodes are keyed by their ``name`` property (the first occurrence of a name
    decides its type); relationships become typed edges.

    Args:
        G: Neo4j connection URI, passed to ``GraphDatabase.driver``.
        regexes: path pattern whose labels decide the type of multi-labelled
            nodes.
        queryStr: Cypher query to execute.
        compared_labels: optional fallback labels for multi-labelled nodes.

    Returns:
        StellarDiGraph with one node-type DataFrame per chosen label and an
        edge DataFrame with ``source``/``target``/``label`` columns.
    """
    user_labels = _labelsInPattern(regexes)

    driver = GraphDatabase.driver(G)
    try:
        with driver.session() as session:
            graph = session.run(queryStr).graph()

            nodes_by_type = {}
            seen_names = set()
            for node in graph.nodes:
                node_name = node['name']
                if node_name in seen_names:
                    continue
                seen_names.add(node_name)
                node_type = _pickNodeType(node.labels, user_labels, compared_labels)
                nodes_by_type.setdefault(node_type, set()).add(node_name)

            rels = {(rel.start_node["name"], rel.end_node["name"], rel.type)
                    for rel in graph.relationships}
    finally:
        # Release the connection pool even if the query fails.
        driver.close()

    edges = pd.DataFrame.from_records(list(rels), columns=["source", "target", "label"])
    data_frames = {
        node_type: pd.DataFrame({"name": list(names)}).set_index("name")
        for node_type, names in nodes_by_type.items()
    }
    return StellarDiGraph(data_frames, edges=edges, edge_type_column="label")

def buildSubgraphDictonaryForNodes(node_list, G, semantic_query, compared_labels):
    """Map each name in ``node_list`` to its semantic subgraph.

    Each name is used as the source node of
    ``getSubgraph_neo4j(G, name, semantic_query, compared_labels)``. Names are
    processed in list order; a repeated name keeps the subgraph built last.
    """
    return {
        source: getSubgraph_neo4j(G, source, semantic_query, compared_labels)
        for source in node_list
    }

In [None]:
def infoDict(subG):
    """Parse ``subG.info()`` into a ``{type name: count}`` dictionary.

    Every line of the info text that contains ``[`` is split on ``:``; the
    stripped text before the first colon becomes the key and the text between
    the first ``[`` and ``]`` after it becomes the value (kept as a string).
    """
    counts = {}
    bracketed = (line for line in subG.info().split('\n') if '[' in line)
    for line in bracketed:
        parts = line.split(':')
        key = parts[0].strip()
        counts[key] = parts[1].split('[')[1].split(']')[0]
    return counts

# Generate Random Walks 

In [None]:
def compactWalks(subgraph_dict, node_list, method, l, r, metapath = None,
                 p = 0.25, q = 0.25, seed = None):
    """Generate random walks rooted at each node, on that node's own subgraph.

    Args:
        subgraph_dict: maps each name in ``node_list`` to a StellarGraph.
        node_list: root nodes; walks for ``x`` run on ``subgraph_dict[x]``.
        method: 'deepwalk' (UniformRandomWalk), 'node2vec' (BiasedRandomWalk)
            or 'metapath2vec' (UniformRandomMetaPathWalk).
        l: maximum length of each walk.
        r: number of walks per root node.
        metapath: metapaths for 'metapath2vec' (ignored otherwise).
        p: node2vec return parameter (ignored otherwise).
        q: node2vec in-out parameter (ignored otherwise).
        seed: optional seed forwarded to the walker's ``run``; None keeps the
            walker's default behaviour.

    Returns:
        list of walks (lists of node ids), concatenated in ``node_list`` order.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    if method not in ('deepwalk', 'node2vec', 'metapath2vec'):
        raise ValueError(
            "unknown method %r; expected 'deepwalk', 'node2vec' or 'metapath2vec'" % (method,))
    # Only pass `seed` when requested so default calls match the walkers' defaults.
    seed_kwargs = {} if seed is None else {'seed': seed}

    Walks = []
    for node in node_list:
        subG = subgraph_dict[node]
        if method == 'deepwalk':
            walks = UniformRandomWalk(subG).run(
                nodes=[node], length=l, n=r, **seed_kwargs)
        elif method == 'node2vec':
            walks = BiasedRandomWalk(subG).run(
                nodes=[node], length=l, n=r, p=p, q=q, **seed_kwargs)
        else:  # metapath2vec
            walks = UniformRandomMetaPathWalk(subG).run(
                nodes=[node], length=l, n=r, metapaths=metapath, **seed_kwargs)
        Walks.extend(walks)

    return Walks

def sematicRatio_walks(regexes, Walks, subGs):
    """Fraction of walk steps whose node/edge label appears in the pattern.

    Every node in every walk, and every pair of consecutive nodes (an edge),
    is one step. A step counts as a match when its label occurs in any label
    window returned by ``parsing(regexes)``. The label is the node type, or
    the edge type found in either direction. A step whose node or edge cannot
    be found is counted as a non-match.

    Args:
        regexes: path pattern in the ``regex_grammar`` syntax.
        Walks: iterable of walks, each a sequence of node ids.
        subGs: one StellarGraph, or a dict of StellarGraphs searched in
            iteration order (the first graph containing the node/edge wins).

    Returns:
        float: matches / steps, rounded to 4 decimals.

    Raises:
        ValueError: if the walks contain no steps.
    """
    llink = parsing(regexes)
    print(llink)

    graphs = list(subGs.values()) if isinstance(subGs, dict) else [subGs]
    graph_nodes = [graph.nodes() for graph in graphs]

    # Build each graph's (source, target) -> edge-type map once, instead of
    # scanning the edge list for every step. The first listed edge wins, which
    # matches the original list.index lookup.
    edge_types = []
    for graph in graphs:
        mapping = {}
        for edge in graph.edges(include_edge_type=True):
            mapping.setdefault((edge[0], edge[1]), edge[2])
        edge_types.append(mapping)

    def node_label(node):
        for graph, ids in zip(graphs, graph_nodes):
            if node in ids:
                return graph.node_type(node)
        return None

    def edge_label(a, b):
        for mapping in edge_types:
            if (a, b) in mapping:
                return mapping[(a, b)]
            if (b, a) in mapping:
                return mapping[(b, a)]
        return None

    def in_pattern(label):
        # None means "not found"; never let it reuse a previous step's label.
        return label is not None and any(label in window for window in llink)

    num = 0
    den = 0
    for walk in Walks:
        for node in walk:
            if in_pattern(node_label(node)):
                num += 1
            den += 1
        for a, b in zip(walk, walk[1:]):
            if in_pattern(edge_label(a, b)):
                num += 1
            den += 1

    print(num, den)
    if den == 0:
        raise ValueError("no walk steps to score")
    return round(num / den, 4)

# Machine Learning on Semantic Subgraphs

In [None]:
def buildModel(Walks, dimensions=128, window=10, epochs=5, workers=2):
    """Train a skip-gram Word2Vec model on random walks.

    Node ids are converted to strings, since Word2Vec works on string tokens.

    Args:
        Walks: iterable of walks, each a sequence of node ids.
        dimensions: embedding size.
        window: context window size.
        epochs: number of training iterations over the walks.
        workers: number of worker threads.

    Returns:
        the trained gensim ``Word2Vec`` model.
    """
    str_walks = [[str(n) for n in walk] for walk in Walks]
    # Keyword names follow the gensim 3.x API (`size`/`iter`); min_count=0
    # keeps every node id in the vocabulary, sg=1 selects skip-gram.
    return Word2Vec(str_walks, size=dimensions, window=window, min_count=0,
                    sg=1, workers=workers, iter=epochs)

def evaluate(model, subgraph_dict, node_list1, node_list2, topn=20000):
    """HIT@k / MRR for recovering each query node's paired target.

    Pairs are taken position-wise from ``node_list1`` (queries) and
    ``node_list2`` (targets). For each query, the ``topn`` most similar tokens
    are scanned in order. Only tokens from ``node_list1 + node_list2`` are
    counted, and the target's position among those counted tokens is its rank.

    Args:
        model: trained model exposing ``wv.most_similar(word, topn=...)``.
        subgraph_dict: unused; kept so existing callers keep working.
        node_list1: query nodes.
        node_list2: target nodes, aligned by position with ``node_list1``.
        topn: number of neighbours requested from ``most_similar``.

    Returns:
        dict with 'HIT@1', 'HIT@3', 'HIT@5' and 'MRR', rounded to 4 decimals.

    Raises:
        ValueError: if the lists are empty or differ in length.
    """
    if not node_list1 or len(node_list1) != len(node_list2):
        raise ValueError("node_list1 and node_list2 must be non-empty and of equal length")

    candidates = set(node_list1) | set(node_list2)
    hits = {1: 0, 3: 0, 5: 0}
    reciprocal_rank_sum = 0
    num_in_list = 0

    for query, target in zip(node_list1, node_list2):
        print("==", query, "==")
        num_in_list = 0
        for token, _score in model.wv.most_similar(query, topn=topn):
            if token in candidates:
                num_in_list += 1
            if token == target:
                print('drugs* ', num_in_list)
                # The target is itself a candidate, so rank >= 1.
                rank = num_in_list
                for k in hits:
                    if rank <= k:
                        hits[k] += 1
                reciprocal_rank_sum += 1 / rank

    n_queries = len(node_list1)
    evaluate_dict = {
        'HIT@1': round(hits[1] / n_queries, 4),
        'HIT@3': round(hits[3] / n_queries, 4),
        'HIT@5': round(hits[5] / n_queries, 4),
        'MRR': round(reciprocal_rank_sum / n_queries, 4),
    }

    print('compute only in list:')
    # Count of listed nodes seen for the last query only.
    print("num of Compound*: ", num_in_list)
    print("HIT@1 = ", evaluate_dict['HIT@1'])
    print("HIT@3 = ", evaluate_dict['HIT@3'])
    print("HIT@5 = ", evaluate_dict['HIT@5'])
    print("MRR = ", evaluate_dict['MRR'])

    return evaluate_dict

<h1>Example Run</h1>

In [None]:
# Example: build one subgraph per drug from the Neo4j instance at
# neo4j.het.io, train DeepWalk embeddings, then score how well each drug
# retrieves its partner. node_list stores (query, partner) pairs back to back.
node_list = ['Canagliflozin', 'Dapagliflozin', 'Dexamethasone', 'Betamethasone', 'Lapatinib', 'Afatinib',
             'Captopril', 'Enalapril', 'Losartan', 'Valsartan', 'Nifedipine', 'Felodipine',
             'Simvastatin', 'Atorvastatin', 'Alendronate', 'Incadronate', 'Citalopram', 'Escitalopram']

semantic_query = "Compound BINDS_CbG Gene ASSOCIATES_DaG< Disease"

subgraph_dict = buildSubgraphDictonaryForNodes(
    node_list,
    G="bolt://neo4j.het.io",
    semantic_query=semantic_query,
    compared_labels=None,
)

Walks = compactWalks(subgraph_dict, node_list, method='deepwalk', l=80, r=5)
model = buildModel(Walks)

# Embedding matrix and the node id for each of its rows
node_ids = model.wv.index2word
node_embeddings = model.wv.vectors

# Even positions are the query drugs, odd positions their expected partners.
query_drugs = node_list[0::2]
partner_drugs = node_list[1::2]
evaluate(model, subgraph_dict, query_drugs, partner_drugs)