In [11]:
from langchain.chains import GraphCypherQAChain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.callbacks import get_openai_callback
from dotenv import load_dotenv
import os
import openai
import pandas as pd
from neo4j.exceptions import CypherSyntaxError


## Choose the LLM

In [12]:
# Azure OpenAI deployment name; the RAG chain builder uses it as the chat
# model's `engine` (deployment id).
LLM_MODEL = "gpt-4-32k"


## Load test data

In [13]:
# Test questions (one row per disease/gene pair) used to evaluate the chain.
# NOTE(review): the frame displayed at the end of this notebook contains an
# `Unnamed: 0` column, which suggests the CSV was saved with its index —
# consider `index_col=0` if that column is not meant to be data.
data = pd.read_csv(filepath_or_buffer='../data/rag_comparison_data.csv')



## Helper function: build the Neo4j Cypher RAG chain

In [35]:
def get_neo4j_cypher_rag_chain(llm_model=None, verbose=True):
    """Build a GraphCypherQAChain over the Neo4j graph configured in the user's home dir.

    Settings are read from two dotenv files in the home directory:
    ``.mate_neo4j_config.env`` (MATE_USR, MATE_PSW, MATE_URI, DB_NAME) and
    ``.gpt_config.env`` (API_KEY, API_VERSION, RESOURCE_ENDPOINT).

    Side effect: configures the global ``openai`` module for Azure
    (``api_type``, ``api_key``, ``api_base``, ``api_version``).

    Parameters
    ----------
    llm_model : str, optional
        Azure OpenAI deployment name used as the chat model's engine.
        Defaults to the notebook-level ``LLM_MODEL``.
    verbose : bool, default True
        Passed to ``GraphCypherQAChain.from_llm``; controls whether the chain
        logs its generated Cypher while running.

    Returns
    -------
    GraphCypherQAChain
        Chain with Cypher validation enabled that also returns intermediate steps.
    """
    if llm_model is None:
        llm_model = LLM_MODEL

    home_dir = os.path.expanduser('~')

    # Neo4j connection settings.
    load_dotenv(os.path.join(home_dir, '.mate_neo4j_config.env'))
    graph = Neo4jGraph(
        url=os.environ.get('MATE_URI'),
        username=os.environ.get('MATE_USR'),
        password=os.environ.get('MATE_PSW'),
        database=os.environ.get('DB_NAME'),
    )

    # Azure OpenAI settings, applied to the global `openai` module.
    load_dotenv(os.path.join(home_dir, '.gpt_config.env'))
    api_key = os.environ.get('API_KEY')
    openai.api_type = "azure"
    openai.api_key = api_key
    openai.api_base = os.environ.get('RESOURCE_ENDPOINT')
    openai.api_version = os.environ.get('API_VERSION')

    chat_model = ChatOpenAI(
        openai_api_key=api_key,
        # Give the Azure deployment explicitly via model_kwargs. Passing
        # `engine=` directly makes ChatOpenAI move it there anyway and emit
        # an "engine was transferred to model_kwargs" warning.
        model_kwargs={'engine': llm_model},
        temperature=0,
    )
    return GraphCypherQAChain.from_llm(
        chat_model,
        graph=graph,
        verbose=verbose,
        validate_cypher=True,
        return_intermediate_steps=True,
    )

## Initialize the Neo4j RAG chain

In [36]:
%%time
neo4j_rag_chain = get_neo4j_cypher_rag_chain()


                    engine was transferred to model_kwargs.
                    Please confirm that engine is what you intended.


CPU times: user 13.4 ms, sys: 5.35 ms, total: 18.7 ms
Wall time: 71.1 ms


## Run the chain on the test questions

In [37]:
%%time

neo4j_rag_answer = []
total_tokens_used = []

for index, row in data.iterrows():
    question = row['question']
    with get_openai_callback() as cb:
        try:
            neo4j_rag_answer.append(neo4j_rag_chain.run(query=question, return_final_only=True, verbose=False))
        except ValueError as e:
            neo4j_rag_answer.append(None)
    total_tokens_used.append(cb.total_tokens)    
    

data.loc[:,'neo4j_rag_answer'] = neo4j_rag_answer
data.loc[:, 'total_tokens_used'] = total_tokens_used




[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mThe provided schema does not include a property for GWAS p-value or any nodes or relationships that would represent an association between a disease and a gene. Therefore, it is not possible to construct a Cypher statement to answer this question based on the provided schema.[0m


ValueError: Length of values (1) does not match length of index (100)

In [38]:
# Per-question token counts, one entry appended per loop iteration above.
total_tokens_used

[473]

## Save the results

In [120]:
# Persist the test set together with the answer and token-count columns.
save_path = '../data/results'
os.makedirs(name=save_path, exist_ok=True)
output_file = os.path.join(save_path, 'neo4j_rag_output.csv')
data.to_csv(output_file, index=False)



In [8]:
# NOTE(review): this cell's execution count (In [8]) is lower than the imports
# cell's (In [11]), so the output below comes from an earlier kernel state —
# restart and run top to bottom to refresh it.
data

Unnamed: 0,disease_name,gene_name,gwas_pvalue,question,question_perturbed,neo4j_rag_answer,total_tokens_used
0,childhood-onset asthma,RORA,2.000000e-37,What is the GWAS p-value for the association b...,What is the GWAS p-value for the association b...,,10993
1,skin benign neoplasm,SHANK2,5.000000e-08,What is the GWAS p-value for the association b...,What is the GWAS p-value for the association b...,"I'm sorry, but I don't know the answer.",11138
2,hypertrophic cardiomyopathy,AMBRA1,1.000000e-16,Is hypertrophic cardiomyopathy associated with...,Is hypertrophic cardiomyopathy associated with...,"Yes, hypertrophic cardiomyopathy is associated...",11468
3,lung adenocarcinoma,CYP2A6,8.000000e-11,What is the GWAS p-value for the association b...,What is the GWAS p-value for the association b...,"I'm sorry, but I don't have the information to...",11150
4,idiopathic generalized epilepsy,RYR2,3.000000e-09,Is idiopathic generalized epilepsy associated ...,Is idiopathic generalized epilepsy associated ...,"No, idiopathic generalized epilepsy is not ass...",11129
...,...,...,...,...,...,...,...
95,lung squamous cell carcinoma,BRCA2,1.000000e-15,Is lung squamous cell carcinoma associated wit...,Is lung squamous cell carcinoma associated wit...,"Yes, lung squamous cell carcinoma is associate...",11129
96,systemic lupus erythematosus,HLA-DRA,2.000000e-60,What is the GWAS p-value for the association b...,What is the GWAS p-value for the association b...,"I'm sorry, but I don't have the information to...",11152
97,type 2 diabetes mellitus,UBE2E2,2.000000e-42,Is type 2 diabetes mellitus associated with UB...,Is type 2 diabetes mellitus associated with ub...,"No, type 2 diabetes mellitus is not associated...",11139
98,allergic rhinitis,HLA-DQA1,1.000000e-43,What is the GWAS p-value for the association b...,What is the GWAS p-value for the association b...,"I'm sorry, but I don't know the answer.",11143
