In [78]:
from langchain.chains import GraphCypherQAChain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.callbacks import get_openai_callback
from dotenv import load_dotenv
import os
import openai
import pandas as pd
from neo4j.exceptions import CypherSyntaxError


## Choose the LLM

In [79]:
# Model identifier for the chat LLM; get_neo4j_cypher_rag_chain() passes it
# to ChatOpenAI as the `engine` argument.
LLM_MODEL = 'gpt-4-32k'


## Load test data

In [80]:
# Test set for the comparison; the run loop further down reads its
# 'question' column. The path is relative to the notebook's working directory.
data = pd.read_csv('../data/rag_comparison_data.csv')



## Custom function for the Neo4j RAG chain

In [81]:
def get_neo4j_cypher_rag_chain():
    """Build the Cypher-generating question-answering chain over the Neo4j graph.

    Neo4j connection settings are read from ``~/.spoke_neo4j_config.env``
    (``NEO4J_USER``, ``NEO4J_PSW``, ``NEO4J_URI``, ``NEO4J_DB``) and the
    OpenAI settings from ``~/.gpt_config.env`` (``API_KEY``, ``API_VERSION``,
    ``RESOURCE_ENDPOINT``). The chat model is selected by the module-level
    ``LLM_MODEL`` constant.

    Side effects:
        Loads both env files into the process environment and overwrites the
        global ``openai`` module settings (``api_type``, ``api_key``,
        ``api_base``, ``api_version``).

    Returns:
        The chain produced by ``GraphCypherQAChain.from_llm``, created with
        ``verbose=True``, ``validate_cypher=True`` and
        ``return_intermediate_steps=True``.
    """
    home_dir = os.path.expanduser('~')

    # Graph connection: credentials come from the SPOKE Neo4j env file.
    load_dotenv(os.path.join(home_dir, '.spoke_neo4j_config.env'))
    graph = Neo4jGraph(
        url=os.environ.get('NEO4J_URI'),
        username=os.environ.get('NEO4J_USER'),
        password=os.environ.get('NEO4J_PSW'),
        database=os.environ.get('NEO4J_DB'),
    )

    # LLM access: the openai module is configured globally for Azure.
    load_dotenv(os.path.join(home_dir, '.gpt_config.env'))
    API_KEY = os.environ.get('API_KEY')
    openai.api_type = "azure"
    openai.api_key = API_KEY
    openai.api_base = os.environ.get('RESOURCE_ENDPOINT')
    openai.api_version = os.environ.get('API_VERSION')

    # NOTE(review): the recorded output of the cell that calls this function
    # shows the warning "engine was transferred to model_kwargs" -- confirm
    # that `engine` is the intended keyword for the installed langchain version.
    chat_model = ChatOpenAI(
        openai_api_key=API_KEY,
        engine=LLM_MODEL,
        temperature=0,
    )

    return GraphCypherQAChain.from_llm(
        chat_model,
        graph=graph,
        verbose=True,
        validate_cypher=True,
        return_intermediate_steps=True,
    )

## Initialize the Neo4j RAG chain

In [82]:
%%time
# Build the chain once; the run loop below reuses it for every question.
neo4j_rag_chain = get_neo4j_cypher_rag_chain()


                    engine was transferred to model_kwargs.
                    Please confirm that engine is what you intended.


CPU times: user 61.2 ms, sys: 5.99 ms, total: 67.2 ms
Wall time: 36.5 s


## Run on test data

In [83]:
%%time

# Send every test question through the chain, keeping one answer and one
# token count per question (in the same order as the rows of `data`).
neo4j_rag_answer = []
total_tokens_used = []

for question in data['question']:
    with get_openai_callback() as cb:
        try:
            answer = neo4j_rag_chain.run(query=question, return_final_only=True, verbose=False)
        except ValueError:
            # A run that raises ValueError is recorded as a missing answer so
            # the remaining questions are still processed.
            answer = None
        neo4j_rag_answer.append(answer)
    # The token total is read from the callback after its context has closed.
    total_tokens_used.append(cb.total_tokens)

# Attach the collected results to the test set as two columns.
data.loc[:, 'neo4j_rag_answer'] = neo4j_rag_answer
data.loc[:, 'total_tokens_used'] = total_tokens_used




[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'childhood-onset asthma'})-[r:ASSOCIATES_DaG]-(g:Gene {name: 'RORA'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 2e-37}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: "SHANK2"})-[r:ASSOCIATES_DaG]-(d:Disease {name: "skin benign neoplasm"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 5e-08}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'hypertrophic cardiomyopathy'}), (g:Gene {name: 'AMBRA1'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': True}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'l


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: "systemic lupus erythematosus"})-[r:ASSOCIATES_DaG]-(g:Gene {name: "ELF3"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 4e-09}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'GMDS'}), (d:Disease {name: 'hemorrhoid'}) 
RETURN EXISTS((g)-[:ASSOCIATES_DaG]->(d)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': False}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'rheumatoid arthritis'})-[r:ASSOCIATES_DaG]-(g:Gene {name: 'DPP4'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 2e-21}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gen

Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'pancreatic carcinoma'})-[r:ASSOCIATES_DaG]-(g:Gene {name: 'CCDC68'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 5e-09}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'primary biliary cholangitis'}), (g:Gene {name: 'HLA-DPB1'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': True}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'atopic dermatitis'}), (g:Gene {name: 'GABPB2'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': True}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'HIP1'})-[r:ASSOCIATES_DaG]-(d:Disease {

Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'type 1 diabetes mellitus'}), (g:Gene {name: 'DGKQ'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': True}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: "SIX6"})-[r:ASSOCIATES_DaG]-(d:Disease {name: "refractive error"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 2e-16}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: "MAP4K4"})-[r:ASSOCIATES_DaG]-(d:Disease {name: "Parkinson's disease"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 8e-13}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'autism spectrum disorder'}), (g:Gene {name: 'DCC'})
RETURN EXISTS(


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'cleft lip'})-[r:ASSOCIATES_DaG]->(g:Gene {name: 'SFI1'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 3e-08}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: "multiple sclerosis"})-[r:ASSOCIATES_DaG]->(g:Gene {name: "INAVA"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 6e-21}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'esophageal carcinoma'}), (g:Gene {name: 'CASP8'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': True}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {nam

## Save the result

In [84]:
# Write the annotated test set to ../data/results, creating the folder if needed.
save_path = '../data/results'
os.makedirs(save_path, exist_ok=True)
output_file = os.path.join(save_path, 'neo4j_rag_output.csv')
data.to_csv(output_file, index=False)



In [85]:
# Count the questions the chain did not answer: answers containing 'sorry'
# plus questions with no answer recorded at all (None/NaN).
# Only a missing *answer* excludes a row here; a bare dropna() would also drop
# rows with a missing value in any other column and undercount the 'sorry'
# answers.
data_no_nan = data.dropna(subset=['neo4j_rag_answer'])
n_sorry_answers = data_no_nan['neo4j_rag_answer'].str.contains('sorry').sum()
n_missing_answers = data['neo4j_rag_answer'].isna().sum()
# Convert to a plain int so the cell displays a bare number.
int(n_sorry_answers + n_missing_answers)


15