In [None]:
import pandas as pd
import time
from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_community.graphs import Neo4jGraph
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Annoy


# from extract_keywords_from_input import extract_keywords_from_input_question


# os.environ["OPENAI_API_KEY"] = 'sk-'


# 1. Look up entities similar to a search term directly in the graph's vector index.
def get_similar_entities_in_existing_graph(existing_graph, search_entity, similar_threshold=0.9):
    """Return the graph entities whose similarity score exceeds a threshold.

    Args:
        existing_graph: Vector store exposing ``similarity_search_with_score``
            (e.g. the ``Neo4jVector`` built in ``__main__``).
        search_entity: Text to search for.
        similar_threshold: Scores must be strictly greater than this value to
            be kept, i.e. higher scores are treated as more similar.

    Returns:
        List of ``(page_content, score)`` tuples for the top-15 hits that pass
        the threshold, in the order the store returned them.
    """
    hits = existing_graph.similarity_search_with_score(search_entity, k=15)
    return [
        (doc.page_content, score)
        for doc, score in hits
        if score > similar_threshold
    ]


# get similar entities using HuggingFace BGE model

# Embedding models already created by _get_hf_embeddings, keyed by model name.
# get_similar_entities is called once per keyword, and loading a model such as
# BAAI/bge-large-en-v1.5 is expensive, so each model is built only once.
_HF_EMBEDDINGS_CACHE = {}


def _get_hf_embeddings(model_name):
    """Return a cached CPU ``HuggingFaceEmbeddings`` for *model_name* with normalised outputs."""
    hf = _HF_EMBEDDINGS_CACHE.get(model_name)
    if hf is None:
        hf = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )
        _HF_EMBEDDINGS_CACHE[model_name] = hf
    return hf


def get_similar_entities(search_entity,entity_list,model_name="BAAI/bge-large-en-v1.5",similar_threshold=0.7):
    """Find entries of *entity_list* that are semantically close to *search_entity*.

    Every candidate is embedded with a HuggingFace model and put into an
    in-memory Annoy index. The 100 nearest neighbours of the query embedding
    are then filtered by distance.

    Args:
        search_entity: Text to match (e.g. a keyword extracted from a question).
        entity_list: Sequence of candidate entity names (strings).
        model_name: HuggingFace embedding model to use.
        similar_threshold: Maximum distance a hit may have to be kept. Lower
            means closer. NOTE(review): with normalised vectors and the angular
            metric this is presumably Annoy's sqrt(2 * (1 - cos)), in [0, 2];
            confirm.

    Returns:
        List of unique ``(entity_name, distance)`` tuples in nearest-first
        order. Returns ``[]`` when *entity_list* is empty.
    """
    # langchain's Annoy wrapper derives the vector dimension from the first
    # embedding, so building an index over zero entities fails. Short-circuit.
    # len() rather than truthiness so pandas Series / numpy arrays also work.
    if len(entity_list) == 0:
        return []

    hf = _get_hf_embeddings(model_name)

    embeddings = hf.embed_documents(entity_list)
    # "metric" must be spelled exactly: unknown keywords are swallowed by
    # **kwargs and silently ignored. "angular" is also Annoy's default.
    vector_store = Annoy.from_embeddings(
        list(zip(entity_list, embeddings)), hf, n_jobs=3, metric="angular"
    )

    query_embedding = hf.embed_query(search_entity)
    hits = vector_store.similarity_search_with_score_by_vector(query_embedding, k=100)

    similar_entities = []
    seen = set()
    for doc, distance in hits:
        candidate = (doc.page_content, distance)
        # Keep close-enough hits once each; entity_list may contain duplicates.
        if distance <= similar_threshold and candidate not in seen:
            seen.add(candidate)
            similar_entities.append(candidate)
    return similar_entities




# 2. get the retrieved info from the scholar graph
def _escape_cypher_string(value):
    """Backslash-escape *value* so it can sit inside a quoted Cypher string literal.

    Backslashes and both quote characters are escaped. An entity name such as
    ``O'Neil`` therefore cannot terminate the literal early and break (or
    alter) the generated query.
    """
    return (
        str(value)
        .replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace('"', '\\"')
    )


def retrieve_info(existing_graph,similar_entities,retrieval_template):
    """Run the retrieval query once per entity and collect every returned value.

    Args:
        existing_graph: Object exposing ``query(cypher: str)`` that returns a
            list of record dicts (e.g. the Neo4j store built in ``__main__``).
        similar_entities: Iterable of entity names to look up. Duplicates are
            queried only once.
        retrieval_template: Cypher text with exactly four positional ``{}``
            placeholders, each filled with the escaped entity name. The
            escaping assumes the placeholders sit inside quoted string
            literals, as in the template used in ``__main__``.

    Returns:
        A ``set`` of all non-null values from all returned records. ``None``
        values are dropped so they do not end up in the LLM context. An
        ``OPTIONAL MATCH`` that finds nothing yields one such null row.
    """
    retrieved_info = set()
    # dict.fromkeys de-duplicates while keeping first-seen order.
    for entity in dict.fromkeys(similar_entities):
        escaped = _escape_cypher_string(entity)
        result = existing_graph.query(retrieval_template.format(escaped, escaped, escaped, escaped))
        for record in result or []:
            retrieved_info.update(val for val in record.values() if val is not None)
    return retrieved_info



# 3. Produce the final answer from the question plus the retrieved graph facts.
def generate_answer(question,llm, retrieved_info, answer_template):
    """Answer *question* with *llm*, grounding it on *retrieved_info*.

    Args:
        question: User question. It is forwarded unchanged as ``{question}``.
        llm: Chat model runnable that receives the formatted prompt.
        retrieved_info: Context object (e.g. the set returned by
            ``retrieve_info``). It is injected as ``{context}`` regardless of
            the question.
        answer_template: ``ChatPromptTemplate`` template text that uses the
            ``{context}`` and ``{question}`` variables.

    Returns:
        The model's reply parsed to a plain string.
    """
    def _fixed_context(_ignored):
        return retrieved_info

    prompt = ChatPromptTemplate.from_template(answer_template)
    chain_inputs = {"context": _fixed_context, "question": RunnablePassthrough()}
    chain = chain_inputs | prompt | llm | StrOutputParser()
    return chain.invoke(question)



def retrieve_knowledgd_from_scholar_graph(input_question,entity_list,existing_graph,retrieval_template,llm,answer_template,similar_threshold=0.9):
    """End-to-end pipeline: keywords -> similar entities -> graph facts -> LLM answer.

    Args:
        input_question: Natural-language question from the user.
        entity_list: Candidate entity names passed to ``get_similar_entities``.
        existing_graph: Graph handle passed to ``retrieve_info``.
        retrieval_template: Cypher template passed to ``retrieve_info``.
        llm: Chat model passed to ``generate_answer``.
        answer_template: Prompt template passed to ``generate_answer``.
        similar_threshold: NOTE(review): currently unused. It is not forwarded
            to ``get_similar_entities``, which therefore always applies its own
            default threshold. Confirm whether it should be wired through.

    Returns:
        The generated answer string, or a fixed message if no keywords exist.
    """
    # TODO: replace this with the real NER result, e.g.
    # ner_res = extract_keywords_from_input_question(input_question)
    ner_res = {"Task": ["text-to-image generation"]}
    keywords = [kw for group in ner_res.values() for kw in group]

    if not keywords:
        return "No keywords found in the input question."

    print("Start to get similar entities...")
    start_time = time.time()
    matched_entities = []
    for keyword in keywords:
        names = [name for name, _distance in get_similar_entities(keyword, entity_list)]
        matched_entities += names
        # Elapsed time is cumulative since start_time, not per keyword.
        print("Time used to get similar entities: ", time.time() - start_time)
        print(f"similar entities for {keyword}: {names}")

    retrieved_info = retrieve_info(existing_graph, matched_entities, retrieval_template)
    return generate_answer(input_question, llm, retrieved_info, answer_template)



if __name__=="__main__":

    # TODO: build this store in a different way.
    # Connects to an existing Neo4j vector index named 'vector'. The store is
    # passed to retrieve_knowledgd_from_scholar_graph, which runs Cypher
    # through it.
    existing_graph = Neo4jVector.from_existing_index(
        embedding=OpenAIEmbeddings(
            openai_api_key='hk-',  # placeholder key for the openai-hk proxy
            openai_api_base='https://api.openai-hk.com/v1/'
        ),
        # NOTE(review): an empty url presumably makes langchain fall back to
        # the NEO4J_URI environment variable; confirm for the installed version.
        url = "",
        username = "neo4j",
        # Never commit database credentials: read the password from the
        # environment instead. NOTE(review): treat the password that used to
        # be hard-coded here as leaked and rotate it.
        password = os.environ.get("NEO4J_PASSWORD"),
        index_name='vector',
        text_node_property='name',
        node_label=["Task"],
    )
    
    # Cypher run once per matched entity by retrieve_info. Each '{}' is filled
    # with the entity name. Every relationship touching a node whose name
    # contains it is returned as "entity -> REL_TYPE -> neighbour".
    retrieval_template='''
    OPTIONAL MATCH p=(n1)-[r]-(n2)
    WHERE n1.name contains '{}' or n2.name contains '{}'
    RETURN
    CASE
        WHEN n1.name CONTAINS '{}' THEN n1.name +' -> '+type(r)+' -> '+n2.name
        WHEN n2.name CONTAINS '{}' THEN n2.name +' -> '+type(r)+' -> '+n1.name
    END AS result
    '''

    # Chat model behind a local OpenAI-compatible server. The empty key
    # suggests that server does not check it.
    llm = ChatOpenAI(
        openai_api_key='',
        base_url='http://127.0.0.1:8080/v1',
        model='gpt-3.5-turbo',
    )

    answer_template = """
    You are an assistant for NER tasks.
    Use the following pieces of retrieved context to answer the question.
    If you don't know the answer, just say that you don't know.
    Use three sentences maximum and keep the answer concise.
    Question: {question}
    Context: {context}
    Answer:
    """

    # Candidate entity names for similarity matching. Empty cells are dropped
    # and values are coerced to str, because the embedding model only encodes
    # strings. A NaN float would presumably crash embed_documents.
    entity_df=pd.read_csv("data/fused_entity_id_0613_llm.csv")
    entity_list=entity_df["entity"].dropna().astype(str).tolist()

    input_question = "what kind of task include text-to-image generation?"
    answer=retrieve_knowledgd_from_scholar_graph(input_question,entity_list,existing_graph,retrieval_template,llm,answer_template,similar_threshold=0.7)
    print(answer)


In [None]:
from langchain.prompts import PromptTemplate  
import pandas as pd
import time
from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_community.graphs import Neo4jGraph
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Annoy


# ----------------------------------------------------------------------
# Prompt for the Cypher-writing LLM of GraphCypherQAChain (passed as
# cypher_prompt below). {schema} and {question} are PromptTemplate variables.
# PromptTemplate uses f-string formatting by default, so any literal curly
# brace added to this text must be written as {{ or }}.
cypher_generation_template = """
Task:
You are a database query operator who needs to extract relevant entities and relationship points from user questions, and construct Cypher statements to query relevant information in the Neo4j graph database.
Schema:
{schema}  
Requirements:
The information stored in academic graphs is in English, so the entities extracted from input questions need to be converted to English, consistent with the entities and relationship types defined in the graph.
Only use the relationship types and attributes provided in the schema.
Do not use any other relationship types or attributes that are not provided.
Try to extract potential entities and query targets in the field of artificial intelligence.
To avoid missing matches due to case differences, use lowercase for matching uniformly. To ensure the broadness of the query, use fuzzy matching for ambiguous entity names, or match nodes that contain specific keywords.
To increase the query hit rate, multiple synonyms can be generated for the extracted entities for joint querying.
The returned results must include relevant papers.
be careful:
Do not include any explanation or apology in your answer.
Do not answer any questions that may require you to construct any text other than Cypher statements.
Ensure that the relationship direction in the query is correct.
Ensure that aliases are set correctly for entities and relationships.
Do not run any queries that add or remove content to the database.
Make sure to alias all subsequent statements as the 'with' statement.
If division is required, make sure to filter the denominator to non-zero values.
The input problem is:
{question}  
"""

# Prompt for the answer-writing LLM (passed as qa_prompt below). {context}
# holds the Cypher query results and {question} the user's question. The same
# f-string brace rule applies.
qa_generation_template = """
You are an expert in the field of artificial intelligence, generating reasonable and logically correct academic answers based on user questions and the results of Neo4j Cypher queries.
The query results section contains the results of Cypher queries generated based on user natural language questions.
The information provided is authoritative, and you must not doubt it or attempt to correct it using internal knowledge.
Make the answer sound like an answer to the problem.
Query results:
{context}  
Question:
{question}  
If the information provided is empty, please answer conservatively, or if it prompts that there is no relevant information in the graph, please re-enter the question.
Empty information looks like this: []
If the information is not empty, you must use the results to provide an answer, keep the answer concise, do not explain too much, and list the relevant article titles at the end. Remove additional quotes in title list.
If the question involves time duration, please assume that the query results are in days unless otherwise specified.
Please answer the input questions:
"""

# Wrap the raw templates. input_variables lists the placeholders that each
# template uses.
cypher_generation_prompt = PromptTemplate(  
    input_variables=["schema", "question"], template=cypher_generation_template
)
qa_generation_prompt = PromptTemplate(  
    input_variables=["context", "question"], template=qa_generation_template
)  
  
from langchain_community.graphs import Neo4jGraph  
from langchain.chains import GraphCypherQAChain  
from langchain_openai import ChatOpenAI  
 


# Connection to the local Neo4j instance holding the academic graph.
# Never commit database credentials: the password is read from the
# NEO4J_PASSWORD environment variable. NOTE(review): treat the password that
# used to be hard-coded here as leaked and rotate it.
graph = Neo4jGraph(
    url="bolt://localhost:7687",
    username="neo4j",
    password=os.environ.get("NEO4J_PASSWORD"),
)

# Re-read node labels / relationship types / properties so the chain sees the
# current graph structure when filling the {schema} placeholder.
graph.refresh_schema()

# Use OPENAI_API_KEY when it is set. A literal placeholder passed here would
# override the environment variable and make every request fail to authenticate.
gpt_llm = ChatOpenAI(
    openai_api_key=os.environ.get("OPENAI_API_KEY", "sk-"),
    base_url="https://api.openai.com/v1/",
    model="gpt-4o-2024-08-06",
)

# One model both writes the Cypher query (cypher_prompt) and phrases the final
# answer from the query results (qa_prompt). validate_cypher asks the chain to
# check the generated query against the schema before running it. top_k caps
# how many result rows are handed to the QA step.
ai_cypher_chain = GraphCypherQAChain.from_llm(
    cypher_llm=gpt_llm,
    qa_llm=gpt_llm,
    graph=graph,
    verbose=True,
    qa_prompt=qa_generation_prompt,
    cypher_prompt=cypher_generation_prompt,
    validate_cypher=True,
    top_k=100,
)

problem = "which fields has reinforcement learning been applied to?"

response = ai_cypher_chain.invoke(problem)
print(response['result'])
