In [1]:
from langchain_neo4j import Neo4jGraph
import os
from dotenv import load_dotenv, find_dotenv, set_key
import getpass
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
import google.generativeai as genai

from langchain_core.output_parsers import StrOutputParser
from langchain.chains import LLMChain
from langchain.retrievers.multi_query import MultiQueryRetriever

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate

from typing import List, Dict, TypedDict

from langchain_chroma import Chroma

from langchain_core.output_parsers import BaseOutputParser

from langgraph.graph import StateGraph, END


For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


In [2]:
## find the path of the .env file 

dotenv_path = find_dotenv(filename='.env', usecwd=True, raise_error_if_not_found=False)

load_dotenv(dotenv_path=dotenv_path, override=True)

True

In [22]:
# set_key(dotenv_path, "GOOGLE_AI_KEY", "AIzaSyAhZAUn2S2DcWC9SSYEuduMvto--rCSMe0")

In [3]:
NEO4J_URI = os.getenv("server_NEO4J_URI_37")
NEO4J_USERNAME = os.getenv("server_NEO4J_USERNAME_37")
NEO4J_PASSWORD = os.getenv("server_NEO4J_PASSWORD_37")
AURA_INSTANCEID = os.getenv("AURA_INSTANCEID")
AURA_INSTANCENAME = os.getenv("AURA_INSTANCENAME")
api_key = os.getenv('GROQ_API_KEY')
api_key_openai = os.getenv("OPENAI_API_KEY")
api_key_google = os.getenv("GOOGLE_AI_KEY")

In [4]:
graph = Neo4jGraph(
    url=NEO4J_URI, 
    username=NEO4J_USERNAME, 
    password=NEO4J_PASSWORD
)

In [5]:
graph.refresh_schema()
print(graph.schema)

Node properties:
Gene {name: STRING, type: STRING}
Biological_Process {label: STRING, name: STRING, type: STRING}
Transcription_Factor {label: STRING, type: STRING, name: STRING}
Cell_Type {type: STRING, name: STRING, label: STRING}
Disease {name: STRING, type: STRING}
Phenotype {type: STRING, label: STRING, name: STRING}
Pathway {name: STRING, type: STRING}
Protein_Domain {name: STRING, type: STRING}
Pathway_Identifier {type: STRING, name: STRING}
Drug {type: STRING, name: STRING}
Cell_Line {organ: STRING, description: STRING, name: STRING, type: STRING, label: STRING}
lncrna_Deg_Gene {type: STRING, lncrna: STRING, name: STRING}
Tissue {type: STRING, label: STRING}
Gene_Expression_State {type: STRING, name: STRING}
Kinase_Enzymes {name: STRING, type: STRING}
Relationship properties:
gene_bioprocess {relation: STRING}
gene_cell_line {relation: STRING}
gene_cell_line_fitness_increase {relation: STRING}
gene_cell_line_fitness_decrease {relation: STRING}
gene_tf_interaction {relation: STR

In [6]:
# # Match all modes in the graph

# cypher = """
# MATCH (n) RETURN count(n); 
# """

# graph.query(cypher)

In [70]:
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-pro",
    google_api_key=api_key_google,
    temperature=0.0,
    # max_output_tokens=2048,  # Parameter for Google models is `max_output_tokens`
    # convert_system_message_to_human=True # Often helpful for complex agentic prompts
)

In [71]:
embeddings_google = GoogleGenerativeAIEmbeddings(
    model="models/text-embedding-004", # Google's latest and recommended embedding model
    google_api_key=api_key_google,
    show_progress_bar=True
)

In [72]:
# --- 2. DEFINE THE STATE FOR OUR GRAPH ---
# The state object is passed between nodes and updated at each step.

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: The user's original question.
        entities: A list of extracted biological entities.
        vector_context: Context retrieved from the vector store.
        cypher_query: The generated Cypher query.
        graph_context: The raw data retrieved from the knowledge graph.
        final_answer: The final, synthesized answer.
        error: A field to capture any errors.
    """
    question: str
    entities: List[str]
    vector_context: List[str]
    cypher_query: str
    graph_context: List[Dict]
    final_answer: str
    error: str

In [73]:
graph.query(
    "CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]"
    )

class Entities(BaseModel):
    """Identifying information about entities."""
    names: List[str] = Field(
        ...,
        description="All names of valid biological entities that appear in the text."
    )

def extract_entities_node(state: GraphState) -> Dict[str, any]:
    """
    Extracts entities from the user's question, grounded in the specific graph schema.
    """
    print("---NODE: EXTRACTING ENTITIES (SCHEMA-GROUNDED)---")
    question = state["question"]

    # This is the significantly improved, schema-aware prompt.
    prompt = ChatPromptTemplate.from_messages([
        (
            "system",
            """You are a highly specialized biological entity extraction system.
Your task is to identify and extract entities from a user's question that STRICTLY match the entity types defined in our knowledge graph schema.

**1. VALID ENTITY TYPES (Node Labels):**
- Gene
- Biological_Process
- Transcription_Factor
- Cell_Type
- Disease
- Phenotype
- Pathway
- Protein_Domain
- Pathway_Identifier
- Drug
- Cell_Line
- lncrna_Deg_Gene
- Tissue
- Gene_Expression_State
- Kinase_Enzymes

**2. CRITICAL EXTRACTION RULES:**
- You MUST only extract the name of the entity, not its type. For "Gene AKNA", you must extract "AKNA".
- If you find multiple entities, extract all of them.
- If no entities from the list above are found in the question, you MUST return an empty list `[]`. Do not guess or hallucinate.
- you should not extract the type means - "pathways", "gene", just need to take the NAME.

**3. EXAMPLES:**
-   User Question: "What diseases are associated with the gene AKNA?"
-   Extracted Entities: ["AKNA"]

-   User Question: "Compare Imatinib and Dasatinib for Chronic Myeloid Leukemia."
-   Extracted Entities: ["Imatinib", "Dasatinib", "Chronic Myeloid Leukemia"]

-   User Question: "What is the role of the Wnt signaling pathway in SH-SY5Y cells?"
-   Extracted Entities: ["Wnt signaling pathway", "SH-SY5Y"]

-   User Question: "Which transcription factors interact with the gene STAT3?"
-   Extracted Entities: ["STAT3"]

-   User Question: "Tell me about your system."
-   Extracted Entities: []
"""
        ),
        (
            "human",
            "Based on the provided schema, rules, and examples, extract all valid entities from the following question: {question}",
        ),
    ])

    entity_chain = prompt | llm.with_structured_output(Entities)

    try:
        entities_result = entity_chain.invoke({"question": question})

        # Defensive programming check remains crucial
        if not entities_result or not entities_result.names:
            print("  > No entities matching the schema were extracted.")
            return {"entities": []}

        print(f"  > Extracted entities: {entities_result.names}")
        return {"entities": entities_result.names}

    except Exception as e:
        print(f"  > An unexpected error occurred during entity extraction: {e}")
        return {"entities": [], "error": "Failed to extract entities due to an unexpected error."}

In [74]:
# vector_db = Chroma(
#     embedding_function=embeddings_google,
#     persist_directory="../vector_db"

# )

# client = vector_db._client  # Accessing the Chroma client directly

# # Get all the collections
# collections = client.list_collections()

# # Delete all collections
# for collection in collections:
#     client.delete_collection(collection.name)

In [75]:
vector_db = Chroma(
    embedding_function=embeddings_google,
    persist_directory="../vector_db",
    collection_name="langraph_new"
)

In [76]:
QUERY_PROMPT_MULTI = PromptTemplate(
    input_variables=["question"],
    template="""You are an AI language model specializing in generating multiple perspectives
    on a user's question to improve document retrieval. Rephrase the user's question in
    four different ways, focusing on its core semantic meaning.

    Original question: {question}

    1.
    2.
    3.
    4.
    """
)

In [77]:
class LineListOutputParser(BaseOutputParser[List[str]]):
    def parse(self, text: str) -> List[str]:
        return [line.strip() for line in text.strip().split("\n") if line.strip()]

In [78]:
llm_chain = LLMChain(llm=llm, prompt=QUERY_PROMPT_MULTI, output_parser=LineListOutputParser())

multi_query_retriever = MultiQueryRetriever(
    retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 4}),
    llm_chain=llm_chain
)

In [79]:
def retrieve_from_vectorstore_node(state: GraphState) -> Dict[str, any]:
    """Retrieves context from the vector store."""
    print("---NODE: RETRIEVING FROM VECTOR STORE---")
    question = state["question"]
    try:
        docs = multi_query_retriever.get_relevant_documents(question)
        vector_context = [doc.page_content for doc in docs]
        print(f"  > Retrieved {len(vector_context)} documents.")
        return {"vector_context": vector_context}
    except Exception as e:
        print(f"  > Error retrieving from vector store: {e}")
        return {"error": "Failed to retrieve from vector store."}

In [80]:
# --- NODE: Cypher Query Generator (Using Your Expert Prompt) ---
def generate_cypher_node(state: GraphState) -> Dict[str, any]:
    """
    Generates a Cypher query to search the knowledge graph using a highly specialized,
    domain-expert prompt.
    """
    print("---NODE: GENERATING CYPHER QUERY (EXPERT PROMPT)---")
    question = state["question"]
    entities = state["entities"]

    # Guardrail: If no entities were extracted, it's impossible to generate a good query.
    if not entities:
        print("  > No entities found. Skipping Cypher generation.")
        return {"cypher_query": None}

    # HERE, WE INTEGRATE YOUR POWERFUL, DETAILED TEMPLATE
    CYPHER_GENERATION_TEMPLATE = """Task: Generate a Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema below.
Do not use any other relationship types or properties that are not provided.

### Schema:
{schema}

### Core Task:
Based on the user's question and a list of identified entities, write a single, efficient Cypher query.

### Additional Guidelines for Query Construction:
1.  **Anchor the Query on Entities:** You MUST use the provided entities to start your search. Find nodes that match these names: `{entities}`.
2.  **Understanding and Interpreting Relationships:**
    - Relationships of similar types should be grouped logically. For example, any query about drug effects should consider `gene_drug_up`, `gene_drug_down`, `gene_drug_treatment_up_reg`, and `gene_drug_treatment_down_reg`.
3.  **Dynamic & Transitive Exploration:**
    - Use `OPTIONAL MATCH` for querying multiple relationships to ensure that all available data is fetched.
    - When a query asks for related entities (e.g., "find all entities connected to a disease"), explore connections up to a depth of 3 using `-[*1..3]-`.
4.  **Query Syntax and Rules:**
    - Always alias return variables with meaningful names (e.g., `g.name as geneName`).
    - Limit query results to 100 for efficient execution.
    - Strive to create syntactically correct Cypher statements.
    - For questions requiring multiple pieces of information, craft a single query that fetches all relevant data.
5.  **Failure Condition:**
    - If the user’s query cannot be answered using the provided schema and entities, you MUST respond with only the string "NO_QUERY".

Now, generate the Cypher query for the following task.

**Question:** {question}
**Entities:** {entities}
**Cypher Query:**"""

    cypher_prompt = PromptTemplate(
        template=CYPHER_GENERATION_TEMPLATE,
        input_variables=["schema", "question", "entities"],
    )

    # This "manual chain" is exactly what you wanted.
    # It combines your prompt, the LLM, and a simple string output parser.
    cypher_chain = cypher_prompt | llm | StrOutputParser()

    try:
        # Get the live graph schema
        schema = graph.get_schema
        # Format entities for the prompt
        entities_str = ", ".join(f"'{e}'" for e in entities)

        query = cypher_chain.invoke({
            "schema": schema,
            "question": question,
            "entities": entities_str
        })

        if "NO_QUERY" in query or not query.strip():
            print("  > LLM decided no suitable Cypher query can be generated.")
            return {"cypher_query": None}

        # Clean up the generated query
        cleaned_query = query.strip().replace("cypher\n", "").replace("```", "").replace("`", "")
        print(f"  > Generated Cypher: {cleaned_query}")
        return {"cypher_query": cleaned_query}
    except Exception as e:
        print(f"  > Error generating Cypher query: {e}")
        return {"error": "Failed to generate Cypher query."}

In [81]:
# --- NODE: Graph Executor ---
def execute_cypher_node(state: GraphState) -> Dict[str, any]:
    """Executes the Cypher query and retrieves data from the graph."""
    print("---NODE: EXECUTING CYPHER QUERY---")
    query = state["cypher_query"]
    try:
        context = graph.query(query)
        print(f"  > Retrieved {len(context)} results from the graph.")
        return {"graph_context": context}
    except Exception as e:
        print(f"  > Error executing Cypher query: {e}")
        return {"error": f"Failed to execute Cypher query: {e}"}

In [82]:
# --- NODE: Final Response Synthesizer ---
def generate_response_node(state: GraphState) -> Dict[str, any]:
    """Synthesizes the final answer from all retrieved context."""
    print("---NODE: SYNTHESIZING FINAL RESPONSE---")
    question = state["question"]
    vector_context = state["vector_context"]
    graph_context = state["graph_context"]

    # This prompt is key to solving your original problem.
    # It explicitly tells the LLM how to use the structured and unstructured data.
    synthesis_prompt = PromptTemplate(
        template="""You are an AI assistant specializing in biological and medical information.
Your task is to provide a clear and concise answer to the user's question based *only* on the provided context.

**User's Question:**
{question}

**Context from Vector Search (general information):**
---
{vector_context}
---

**Structured Data from Knowledge Graph (specific facts and relationships):**
---
{graph_context}
---

**Instructions:**
1.  Synthesize the information from both the vector search context and the structured graph data.
2.  Prioritize the structured data from the knowledge graph as the primary source of truth for specific facts.
3.  Use the vector search context to provide additional background or explanation if necessary.
4.  If the context does not contain an answer to the question, state that you could not find the information.
5.  Do not use any information outside of the provided context.
6.  Present the answer in a clear, easy-to-read format.

Final Answer:""",
        input_variables=["question", "vector_context", "graph_context"],
    )

    synthesis_chain = synthesis_prompt | llm | StrOutputParser()

    # Format context for the prompt
    formatted_vector = "\n".join(vector_context) if vector_context else "No information found."
    formatted_graph = "\n".join(map(str, graph_context)) if graph_context else "No information found."

    final_answer = synthesis_chain.invoke({
        "question": question,
        "vector_context": formatted_vector,
        "graph_context": formatted_graph
    })
    print(f"  > Final Answer Generated.")
    return {"final_answer": final_answer}

In [83]:
# --- 4. DEFINE THE EDGES (CONTROL FLOW) ---

def route_after_cypher_generation(state: GraphState) -> str:
    """Determines the next step after attempting to generate a Cypher query."""
    print("---ROUTING: DECIDING NEXT STEP---")
    if state.get("error"):
        print("  > Error detected. Ending.")
        return "end"
    if state.get("cypher_query"):
        print("  > Cypher query exists. Executing against graph.")
        return "execute_cypher"
    else:
        print("  > No Cypher query. Proceeding to final synthesis.")
        return "generate_response"

In [84]:
workflow = StateGraph(GraphState)

In [85]:
# Add the nodes
workflow.add_node("extract_entities", extract_entities_node)
workflow.add_node("retrieve_from_vectorstore", retrieve_from_vectorstore_node)
workflow.add_node("generate_cypher", generate_cypher_node)
workflow.add_node("execute_cypher", execute_cypher_node)
workflow.add_node("generate_response", generate_response_node)

<langgraph.graph.state.StateGraph at 0x7f90fc584650>

In [86]:
# Define the workflow edges
workflow.set_entry_point("extract_entities")
workflow.add_edge("extract_entities", "retrieve_from_vectorstore")
workflow.add_edge("retrieve_from_vectorstore", "generate_cypher")

<langgraph.graph.state.StateGraph at 0x7f90fc584650>

In [87]:
# Add the conditional routing
workflow.add_conditional_edges(
    "generate_cypher",
    route_after_cypher_generation,
    {
        "execute_cypher": "execute_cypher",
        "generate_response": "generate_response",
        "end": END
    }
)

<langgraph.graph.state.StateGraph at 0x7f90fc584650>

In [63]:
workflow.add_edge("execute_cypher", "generate_response")
workflow.add_edge("generate_response", END)

<langgraph.graph.state.StateGraph at 0x7f90fc5d9700>

In [64]:
app = workflow.compile()

In [91]:
inputs = {"question": "Genes associated with the Breast cancer ?"}

In [92]:
for s in app.stream(inputs, stream_mode="values"):
    # The key will be the name of the node that just finished running
    node_that_just_ran = list(s.keys())[0]
    print(f"\n--- After Node: {node_that_just_ran} ---")
    # Print the current state
    print(s[node_that_just_ran])

# Or get the final result directly
final_state = app.invoke(inputs)
print("\n\n" + "="*50)
print("              FINAL ANSWER")
print("="*50)
print(final_state['final_answer'])


--- After Node: question ---
Genes associated with the Breast cancer ?
---NODE: EXTRACTING ENTITIES (SCHEMA-GROUNDED)---
  > Extracted entities: ['Breast cancer']

--- After Node: question ---
Genes associated with the Breast cancer ?
---NODE: RETRIEVING FROM VECTOR STORE---
  > Retrieved 0 documents.

--- After Node: question ---
Genes associated with the Breast cancer ?
---NODE: GENERATING CYPHER QUERY (EXPERT PROMPT)---
  > Generated Cypher: MATCH (d:Disease {name: 'Breast cancer'})-[r:gene_disease]->(g:Gene)
RETURN g.name AS geneName
LIMIT 100

---ROUTING: DECIDING NEXT STEP---
  > Cypher query exists. Executing against graph.

--- After Node: question ---
Genes associated with the Breast cancer ?
---NODE: EXECUTING CYPHER QUERY---
  > Retrieved 100 results from the graph.

--- After Node: question ---
Genes associated with the Breast cancer ?
---NODE: SYNTHESIZING FINAL RESPONSE---
  > Final Answer Generated.

--- After Node: question ---
Genes associated with the Breast cancer ?

KeyboardInterrupt: 

In [69]:
final_state

{'question': 'Gene Genes are associated with the pathway Apoptosis ?',
 'entities': ['Apoptosis'],
 'vector_context': [],
 'cypher_query': "MATCH (p:Pathway {name: 'Apoptosis'})-[:gene_pathway]->(g1:Gene)-[:gene_gene]-(g2:Gene)\nRETURN g1.name AS gene, g2.name AS associated_gene, p.name AS pathway\nLIMIT 100\n",
 'graph_context': [{'gene': 'PRF1',
   'associated_gene': 'GZMB',
   'pathway': 'Apoptosis'},
  {'gene': 'PRF1', 'associated_gene': 'GNLY', 'pathway': 'Apoptosis'},
  {'gene': 'PRF1', 'associated_gene': 'SRGN', 'pathway': 'Apoptosis'},
  {'gene': 'PRF1', 'associated_gene': 'CALR', 'pathway': 'Apoptosis'},
  {'gene': 'TP63', 'associated_gene': 'EP300', 'pathway': 'Apoptosis'},
  {'gene': 'TP63', 'associated_gene': 'WWP1', 'pathway': 'Apoptosis'},
  {'gene': 'TP63', 'associated_gene': 'HDAC1', 'pathway': 'Apoptosis'},
  {'gene': 'TP63', 'associated_gene': 'YAP1', 'pathway': 'Apoptosis'},
  {'gene': 'TP63', 'associated_gene': 'RELA', 'pathway': 'Apoptosis'},
  {'gene': 'TP63', 'as