Different alternatives for building paths:
* https://neo4j.com/labs/apoc/4.1/graph-querying/expand-spanning-tree/
````
graph.query("""MATCH (p:Drug {id: "DB00295"})
CALL apoc.path.spanningTree(p, {
        minLevel: 1,
        maxLevel: 1
    })
    YIELD path
    RETURN path;
    """)
````
* Query with relations and distances.

````
graph.query("""MATCH (p:Admission {id: "130248"})
-[*1..1]- (c) 
RETURN c;
    """)
````
* Query with expand.
````
graph.query("""
match (c:ICD {id :'ICD9CM/427.31'})
call apoc.path.expand(c,'*','*',0,3) yield path as pp
return pp;
""")
````
* Subgraph.
````
call apoc.path.subgraphNodes(startNode <id>Node/list, {maxLevel, relationshipFilter, labelFilter, bfs:true, filterStartNode:true, limit, optional:false, endNodes, terminatorNodes, sequence, beginSequenceAtStart:true}) yield node
````

Retrieving direct neighbours is easy, but expanding the subgraph becomes computationally expensive when nodes have more than 10k relationships.

In [None]:
import os
import pandas as pd
from tqdm.notebook import tqdm
import time

In [None]:
# Directory prefix used by this notebook for every pickle/CSV it reads or writes.
path_dir = "./"

In [None]:
# Export the OpenAI API key through the environment; the value is left blank
# here and must be filled in before the LLM chain is called.
os.environ.update({"OPENAI_API_KEY": ""})

In [None]:
from langchain.graphs import Neo4jGraph

# Connection settings for the Neo4j instance (password intentionally blank here).
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = ""

# Graph handle shared by the rest of the notebook.
graph = Neo4jGraph(url=NEO4J_URI, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)

In [None]:
from sentence_transformers import SentenceTransformer
# Sentence encoder used to embed the clinical notes before the vector search.
model = SentenceTransformer(
    'all-MiniLM-L6-v2',
)

In [None]:
# TODO: decide what is most useful to return from these queries.

# Names of the vector indexes to search, in the order they are queried.
_VECTOR_INDEX_NAMES = (
    'acts_as_index',
    'is_known_as_index',
    'is_indicated_for_index',
    'can_be_described_as_index',
    'is_known_as_icd_index',
)

# Cypher shared by every index: look up the $k nearest nodes to $embedding
# and return them ordered by descending score.
_VECTOR_SEARCH_TEMPLATE = """WITH $embedding AS e
                                CALL db.index.vector.queryNodes('{index_name}', $k, e) yield node, score
                                RETURN node AS result
                                ORDER BY score DESC
                            """

# One query string per vector index.
search_cypher_queries = [
    _VECTOR_SEARCH_TEMPLATE.format(index_name=index_name)
    for index_name in _VECTOR_INDEX_NAMES
]

In [None]:
# Load the base queries and keep only the columns this notebook works with.
query_name = '_df_base_queries_human_adhoc.pickle'
df_queries = pd.read_pickle(path_dir + query_name)

columns_to_keep = {'number', 'topics', 'description', 'query_summary'}
df_queries = df_queries[[col for col in df_queries.columns if col in columns_to_keep]]
df_queries

In [None]:
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.chains.question_answering.stuff_prompt import CHAT_PROMPT
from langchain.callbacks.manager import CallbackManagerForChainRun

from typing import Any, Dict, List
from pydantic import Field

class Neo4jVectorChain(Chain):
    """Chain that answers a prompt using context retrieved from Neo4j vector indexes.

    The medical note is embedded, every Cypher query in ``indexes`` is run with
    that embedding, and selected properties of the returned nodes (optionally
    together with their direct neighbours) are handed to ``qa_chain`` as context.
    """

    graph: Neo4jGraph = Field(exclude=True)
    input_key: str = "prompt"  #: :meta private:
    input_medical_key: str = 'medical_note' #: :meta private:
    output_key: str = "result"  #: :meta private:
    # Encoder used to embed the medical note before the vector search.
    embeddings: SentenceTransformer = model
    # LLM chain that receives the final question plus the retrieved context.
    qa_chain: LLMChain = LLMChain(llm=ChatOpenAI(model_name="gpt-3.5-turbo",temperature=0), prompt=CHAT_PROMPT)
    # Cypher vector-search queries; each one is executed on every call.
    indexes = search_cypher_queries
    # Keyed by the ``expand_path`` flag; not read by ``_call``.
    cache_context = {}
    cache_context[True] = {}
    cache_context[False] = {}

    # Maps a node id to the neighbour rows already fetched for that node.
    cache_nodes = {}

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.
        :meta private:
        """
        return [self.input_key, self.input_medical_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.
        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    def _call(self, inputs: Dict[str, str], run_manager, k=3) -> Dict[str, Any]:
        """Embed the medical note, run the vector searches and query the LLM.

        Args:
            inputs: Must contain ``input_key`` (the instruction prompt) and
                ``input_medical_key`` (the note to embed). An optional
                ``'expand_path'`` entry, when truthy, also pulls in the direct
                neighbours of every retrieved node.
            run_manager: Callback manager supplied by the chain machinery;
                not used here.
            k: Number of nodes requested from each vector index.

        Returns:
            A dict mapping ``output_key`` to the output of ``qa_chain``.
        """
        # The instruction is kept separate from the text that is embedded to
        # search the graph.
        prompt = inputs[self.input_key]
        medical_note = inputs[self.input_medical_key]
        expand_path = inputs.get('expand_path', False)
        embedding = self.embeddings.encode(medical_note)

        retrieved_nodes = []
        for index_query in self.indexes:
            rows = self.graph.query(index_query, {'embedding': embedding, 'k': k})
            retrieved_nodes.extend(row['result'] for row in rows)

        add_nodes = []
        if expand_path:  # add the neighbours of every retrieved node
            print('Expanding set of retrieved nodes...')
            for node_id in tqdm({node['id'] for node in retrieved_nodes}):
                # Raw neighbour rows are cached per node id, which leaves
                # freedom in how the nodes are used afterwards.
                if node_id in self.cache_nodes:
                    print('In cache')
                    neighbour_rows = self.cache_nodes[node_id]
                else:
                    print('Not in cache')
                    # Use the graph this chain was built with rather than a
                    # module-level global.
                    neighbour_rows = self.graph.query("""
                        UNWIND $data AS x
                        MATCH (p {id: x["id"]})
                              -[*1..1]- (c) 
                              RETURN c;
                              """, {'data': [{'id': node_id}]})
                    self.cache_nodes[node_id] = neighbour_rows

                add_nodes.extend(row['c'] for row in neighbour_rows)

        seen_ids = set()
        context = set()
        for node in retrieved_nodes:
            node_id = node['id']
            if node_id in seen_ids:
                continue
            seen_ids.add(node_id)
            if node_id[0] == 'D':  # drug: only the indication text is used
                if 'is_indicated_for' in node:
                    context.add(node['is_indicated_for'])
            elif node_id[0] == 'I':  # icd
                context.add(node['is_known_as'])

        # For the additional (neighbour) nodes only the diagnosis names are added.
        for node in add_nodes:
            if node['id'][0] == 'I':  # icd
                context.add(node['is_known_as'])

        # Join once so the question and the context field carry the same text.
        context_text = ' '.join(context)
        result = self.qa_chain(
            {"question": prompt + medical_note + 'Context: ' + context_text, "context": context_text},
        )
        final_result = result[self.qa_chain.output_key]
        return {self.output_key: final_result}

In [None]:
# Chain instance reused by all the query-generation cells below.
chain = Neo4jVectorChain(
    graph=graph,
    embeddings=model,
    verbose=True,
)

In [None]:
# Keyword extraction: one LLM call per clinical note, with neighbour expansion.
expand_path = True

prompt = """You are a helpful medical assistant who needs to retrieve relevant medical documents for your patient. 
Extract the most relevant keywords from the clinical note using the provided context. 
Response format: Do not repeat keywords. Keywords should be returned in a comma-separated list that will be used for search.
Do not explain.
Clinical note: """

# Start from an empty list and process every row, so that `result` has exactly
# one entry per row of `df_queries` when it is assigned as a column below.
result = []
for note in tqdm(df_queries['description'].values):
    kk = chain({'medical_note': note, 'prompt': prompt, 'expand_path': expand_path})
    print(kk)
    result.append(kk)
    time.sleep(10)  # pause between calls (presumably for API rate limits)

df_queries['query_KG_extract'] = result
df_queries.to_pickle(path_dir + f'df_all_queries_KG_{str(expand_path)}.pickle')
df_queries

In [None]:
# Keyword summary: one LLM call per clinical note (no 'expand_path' input is
# passed, so the chain falls back to its default for that flag).
prompt = """You are a helpful medical assistant that needs to retrieve relevant medical documents for your patient. 
Using the provided context, summarize the clinical note by extracting its most relevant keywords.
Response format: Do not repeat keywords. Keywords should be returned in a comma-separated list that will be used for search.
Do not explain.
Clinical note: """

result = []
for note in tqdm(df_queries['description'].values):
    kk = chain({'medical_note': note, 'prompt': prompt})
    print(kk)
    result.append(kk)
    time.sleep(20)  # pause between calls

df_queries['query_KG_summary'] = result
# The file name reuses `expand_path`, which is defined in an earlier cell.
output_path = path_dir + f'df_all_queries_KG_{str(expand_path)}.pickle'
df_queries.to_pickle(output_path)
df_queries

In [None]:
# Keyword expansion: one LLM call per clinical note.
prompt = """You are a helpful medical assistant that needs to retrieve relevant medical documents for your patient. 
Using the provided context, expand the clinical note with keywords referring to medical concepts, symptoms, diseases, synonyms or other information relevant. Keywords will be used to retrieve medical documents. 
Response format: Do not repeat keywords. Only keywords should be returned in a comma-separated list.
Clinical note: """

# Reset the accumulator: reusing a `result` list left over from another cell
# would make it longer than `df_queries` and break the column assignment below.
result = []
for note in tqdm(df_queries['description'].values):
    kk = chain({'medical_note': note, 'prompt': prompt})
    print(kk)
    result.append(kk)
    time.sleep(20)  # pause between calls (presumably for API rate limits)

df_queries['query_KG_extend'] = result
# The file name reuses `expand_path`, which is defined in an earlier cell.
df_queries.to_pickle(path_dir + f'df_all_queries_KG_{str(expand_path)}.pickle')
df_queries

In [None]:
# Combined extraction + expansion: one LLM call per clinical note; the
# two-part response is split into separate columns in a later cell.
prompt = """You are a helpful medical assistant that needs to create queries for retrieving relevant medical documents for your patient. 
1. Using the provided context, extract the most relevant keywords from the clinical note that fully describe its content.
Response format: Do not repeate keywords. Only keywords should be returned in a comma-separated list.
2. Use the provided context and only the provided context, expand the extracted list of keywords with at most 20 additional relevant keywords. 
Response format: Do not repeat keywords. Only keywords should be returned in a comma-separated list.
Clinical note: """

result = []
notes = df_queries['description'].values
for i in tqdm(range(len(notes))):
    kk = chain({'medical_note': notes[i], 'prompt': prompt})
    print(kk)
    result.append(kk)
    time.sleep(20)  # pause between calls

df_queries['query_KG_combine'] = result
# The file name reuses `expand_path`, which is defined in an earlier cell.
output_path = path_dir + f'df_all_queries_KG_{str(expand_path)}.pickle'
df_queries.to_pickle(output_path)
df_queries

In [None]:
# Rename the note column and replace each stored chain output (a dict) with
# just the text under its 'result' key.
df_queries = df_queries.rename(columns={'description': 'query_description'})
for column in ('query_KG_extract', 'query_KG_summary', 'query_KG_extend'):
    df_queries[column] = [output['result'] for output in df_queries[column]]

df_queries

In [None]:
def _clean_keyword_line(line):
    """Strip list numbering, labels, brackets and bullets from one response line."""
    if line and line[0].isdigit():
        line = line[3:]  # drop a leading "1. " style prefix

    # The line may start with a "Keywords: ..." or "Expanded ...: ..." label.
    if line.startswith(('Keywords', 'Expanded')):
        line = ' '.join(line.split(':')[1:]).strip()

    if line:
        if line[0] == '[' and line[-1] == ']':
            line = line.replace('[', '').replace(']', '')
        if line.startswith('-'):
            line = line.replace('- ', ',')  # bullet list -> comma-separated
    return line


def _merge_keywords(*keyword_lists):
    """Union comma-separated keyword strings, keeping first-seen order."""
    merged = {}
    for keywords in keyword_lists:
        for keyword in keywords.split(','):
            keyword = keyword.strip()
            if keyword:
                merged.setdefault(keyword, None)
    return ', '.join(merged)


ex_ex = []
ex_ep = []
ex_un = []

for x in df_queries['query_KG_combine']:
    # The first non-empty line is taken as the extracted keywords and the
    # second as the expanded keywords.
    zz = [_clean_keyword_line(y) for y in x['result'].split('\n') if y]
    # Pad so a response with fewer than two lines yields empty strings
    # instead of raising IndexError.
    zz += [''] * (2 - len(zz))

    ex_ex.append(zz[0])
    ex_ep.append(zz[1])
    # Keywords are stripped before de-duplication so "a" and " a" count once,
    # and no character is sliced off the joined string.
    ex_un.append(_merge_keywords(zz[0], zz[1]))

df_queries['query_KG_combine_extract'] = ex_ex
df_queries['query_KG_combine_expand'] = ex_ep
df_queries['query_KG_combine_union'] = ex_un

df_queries = df_queries.drop(columns=['query_KG_combine'])

df_queries

In [None]:
# Write the final table in both CSV and pickle form.
output_stem = path_dir + 'df_all_queries_adhoc__KG'
df_queries.to_csv(output_stem + '.csv')
df_queries.to_pickle(output_stem + '.pickle')