In [None]:
"""
This project implements an intelligent Q&A system for exploring a database schema stored in Neo4j.
It works as follows:
1. Load a JSON schema file describing tables, columns, and relationships.
2. Store this schema in Neo4j as nodes (:Table) and relationships (:FK for foreign keys).
3. Build a FAISS vector index from table and column descriptions for semantic search.
4. Use Groq LLM (via API) to:
   - Understand the user's natural language question.
   - Generate a safe Cypher query targeting only the schema graph (no row-level data).
5. Execute the generated Cypher query on Neo4j and retrieve results.
6. Analyze results into a plain-text summary + preview (relationships, table names, columns, etc.),
   then have the LLM write the final answer in the same language as the question (e.g. Arabic).
7. The workflow is managed with LangGraph, ensuring a step-by-step process:
   reasoning → Cypher generation → execution → analysis → final answer → memory update → output.
"""
!pip install -q langgraph groq neo4j sentence-transformers faiss-cpu tqdm requests python-dotenv pandas

In [None]:
import os, json, re, pickle, traceback
from getpass import getpass
from neo4j import GraphDatabase
import requests
from sentence_transformers import SentenceTransformer
import faiss
from tqdm import tqdm
from typing import List, Dict, Any, Tuple

# LangGraph
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END


In [None]:
# ----------------------------------------------------
# Step 1: Detect if running inside Google Colab
# ----------------------------------------------------
# - Try to import the Colab "files" module.
# - If successful → open a file picker UI to upload the schema JSON file.
# - Get the uploaded file name and set it as schema_path.
# - If running locally (import fails) → ask the user to manually type the path.
# Only a missing Colab module means "not in Colab"; an error raised *during* the upload
# (e.g. no file selected) must not be misreported as "Colab not detected".
try:
    from google.colab import files
except ImportError:
    files = None

if files is not None:
    print("Colab detected: use file upload UI…")
    uploaded = files.upload()
    if uploaded:
        # If several files were uploaded, the first one is used as the schema.
        schema_path = next(iter(uploaded))
    else:
        print("No file uploaded. Set local path to your schema JSON file:")
        schema_path = input("Path to schema.json: ").strip()
else:
    print("Colab not detected. Set local path to your schema JSON file:")
    schema_path = input("Path to schema.json: ").strip()

with open(schema_path, "r", encoding="utf-8") as f:
    schema = json.load(f)

print("Loaded schema keys:", list(schema.keys()))
if "tables" in schema:
    print("Tables found:", list(schema["tables"].keys()))
else:
    print("Warning: 'tables' key not found in schema; check file.")

Colab detected: use file upload UI…


Saving mock_db_schema (1).json to mock_db_schema (1) (10).json
Loaded schema keys: ['tables']
Tables found: ['users', 'addresses', 'categories', 'suppliers', 'products', 'orders', 'order_items', 'shipments', 'payments', 'reviews']


In [None]:
# This block checks if Neo4j and Groq API credentials are already set as environment variables.
# If not, it prompts the user to input them (securely for passwords and API keys)
# and stores them temporarily for the current session.
# After that, it tries to connect to the Neo4j database and prints a confirmation message if successful,
# or an error message if the connection fails
if not os.getenv("NEO4J_URI"):
    print("Enter Neo4j credentials (or leave blank to skip Neo4j steps)")
    neo_uri = input("Neo4j URI (bolt://localhost:7687 or neo4j+s://...): ").strip()
    if neo_uri:
        neo_user = input("Neo4j user (e.g. neo4j): ").strip()
        neo_pwd = getpass("Neo4j password (hidden): ")
        os.environ["NEO4J_URI"] = neo_uri
        os.environ["NEO4J_USER"] = neo_user
        os.environ["NEO4J_PASSWORD"] = neo_pwd

if not os.getenv("GROQ_API_KEY"):
    k = getpass("Enter GROQ_API_KEY (hidden) — leave blank to skip LLM calls: ")
    if k:
        os.environ["GROQ_API_KEY"] = k

# `driver` is always bound after this cell: None when Neo4j is skipped or unreachable,
# so later cells (the ingest step checks `if not driver`) see a clear "no driver"
# instead of a NameError.
driver = None
if not os.getenv("NEO4J_URI"):
    print("NEO4J_URI not set; skipping Neo4j connection.")
else:
    try:
        driver = GraphDatabase.driver(
            os.getenv("NEO4J_URI"),
            auth=(os.getenv("NEO4J_USER"), os.getenv("NEO4J_PASSWORD")),
        )
        # Cheap round-trip to prove the credentials and URI actually work.
        with driver.session() as sess:
            res = sess.run("RETURN 'Connected to Neo4j!' AS msg")
            print(res.single()["msg"])
    except Exception as e:
        print("Failed to connect to Neo4j:", e)
        if driver is not None:
            driver.close()  # release the connection pool of the unusable driver
        driver = None


Connected to Neo4j!


In [None]:
# This section defines file paths used to store and retrieve the FAISS index, metadata,
# and records for schema search functionality. All files are saved in the current directory.
DATA_PATH = "./"
FAISS_INDEX_PATH = os.path.join(DATA_PATH, "faiss_index.faiss")
META_PATH = os.path.join(DATA_PATH, "faiss_meta.pkl")
RECORDS_PATH = os.path.join(DATA_PATH, "records.pkl")
# NOTE: the schema-embedding cell further down reassigns FAISS_INDEX_PATH and META_PATH
# (to schema_faiss_index.faiss / schema_meta.pkl); those later values are the ones used.

In [None]:
import json
# This function ingests a JSON database schema into Neo4j.
# It creates a (Table) node for each table with a 'columns' property storing column metadata as JSON,
# and creates :FK relationships between tables, detected either from a column description that mentions
# "foreign key" together with a table name, or from a column name ending in '_id' that resembles a table name.
# The name-based rule skips columns whose description mentions "primary key".
import json

def _fk_name_candidates(stem: str) -> List[str]:
    """Return plausible table names for a `<stem>_id` column, in priority order.

    Covers plain (-s), sibilant (-es) and consonant+y (-ies) plurals plus an already-plural
    stem, e.g. 'category' -> 'categories', 'address' -> 'addresses', 'users' -> 'user'.
    """
    candidates = [stem, stem + "s", stem + "es"]
    if stem.endswith("y"):
        candidates.append(stem[:-1] + "ies")
    if stem.endswith("s"):
        # Drop exactly ONE trailing 's' (str.rstrip("s") would turn 'address' into 'addre').
        candidates.append(stem[:-1])
    return candidates


def _table_named_in(text: str, table_names_lower: Dict[str, str]):
    """Return the original-case name of the longest table name occurring in `text`, else None.

    Preferring the longest hit stops e.g. 'order_items' from resolving to a table named
    'order' just because 'order' is a substring of it.
    """
    hits = [cand for cand in table_names_lower if cand and cand in text]
    return table_names_lower[max(hits, key=len)] if hits else None


def ingest_schema_to_neo4j_compact(schema, driver):
    """Rebuild the schema graph in Neo4j from a JSON schema dict.

    Creates one (:Table {name, columns}) node per table, where `columns` is a JSON string of
    {name, data_type, description} dicts, then adds (:Table)-[:FK {column}]->(:Table) edges.
    An FK target is resolved from a column description containing "foreign key" (longest
    table name mentioned wins), otherwise from a `*_id` column name whose stem matches a table
    name (singular/plural variants tried); the name rule skips columns described as
    "primary key". Each (source, target, column) edge is created at most once.

    Args:
        schema: dict with a "tables" mapping (or the mapping itself) of table name ->
            {"columns": [{"name", "data_type", "description"}, ...]}.
        driver: Neo4j driver; when falsy the ingest is skipped with a message.
    """
    if not driver:
        print("No Neo4j driver; skipping schema ingest.")
        return

    tables = schema.get("tables", schema)
    table_names_lower = {name.lower(): name for name in tables}
    print("Ingesting schema to Neo4j: tables count =", len(tables))

    with driver.session() as sess:
        # Remove only the graph this function owns (Table nodes and their FK edges);
        # deleting every node would destroy unrelated data in a shared database.
        sess.run("MATCH (t:Table) DETACH DELETE t")

        sess.run("CREATE CONSTRAINT IF NOT EXISTS FOR (t:Table) REQUIRE t.name IS UNIQUE")

        for tname, tinfo in tables.items():
            cols_data = [
                {
                    "name": col.get("name"),
                    "data_type": col.get("data_type"),
                    "description": col.get("description")
                }
                for col in tinfo.get("columns", [])
            ]
            sess.run("""
                MERGE (t:Table {name:$tname})
                SET t.columns = $columns_json
            """, tname=tname, columns_json=json.dumps(cols_data, ensure_ascii=False))

        seen_links = set()
        for tname, tinfo in tables.items():
            for col in tinfo.get("columns", []):
                cname = (col.get("name") or "").lower()
                desc = (col.get("description") or "").lower()

                target_table = None

                # 1) Explicit hint: description says "foreign key" and names a table.
                if "foreign key" in desc:
                    target_table = _table_named_in(desc, table_names_lower)

                # 2) Naming convention: `<table-ish>_id`, excluding primary keys.
                if not target_table and cname.endswith("_id") and "primary key" not in desc:
                    target_table = next(
                        (table_names_lower[cand]
                         for cand in _fk_name_candidates(cname[:-3])
                         if cand in table_names_lower),
                        None,
                    )

                if not target_table:
                    continue

                key = (tname, target_table, cname)
                if key in seen_links:
                    continue
                seen_links.add(key)

                sess.run("""
                    MATCH (src:Table {name:$src}), (dst:Table {name:$dst})
                    MERGE (src)-[:FK {column:$col}]->(dst)
                """, src=tname, dst=target_table, col=cname)

    print("Schema ingest completed.")



# Load the parsed schema into Neo4j (no-op with a message when no driver is available).
ingest_schema_to_neo4j_compact(schema=schema, driver=driver)


Ingesting schema to Neo4j: tables count = 10
Schema ingest completed.


In [None]:
# This function prints a quick overview of the database schema from Neo4j.
# It lists all tables (Table nodes) in alphabetical order, then shows all foreign key (:FK) relationships
# between tables in the format "TableA --[column_name]-> TableB". If no relationships are found, it states that explicitly.
def show_schema_in_console(driver):
    """Print every :Table name (alphabetically), then every :FK edge of the schema graph.

    Edges are printed as "src --[column]-> dst"; "(self-reference)" is appended when an
    edge points back at its own table. A placeholder line is printed when there are no edges.
    """
    with driver.session() as session:
        print(" Tables in schema:")
        table_rows = session.run("MATCH (t:Table) RETURN t.name AS name ORDER BY name")
        table_names = [row["name"] for row in table_rows]
        for table_name in table_names:
            print(" -", table_name)

        print("\n Foreign Key Relationships:")
        fk_rows = session.run("""
            MATCH (a:Table)-[r:FK]->(b:Table)
            RETURN a.name AS src, r.column AS col, b.name AS dst
            ORDER BY src, dst, col
        """)

        # peek() looks at the first record without consuming it.
        if not fk_rows.peek():
            print(" (No FK relationships found)")
            return

        for row in fk_rows:
            src, col, dst = row["src"], row["col"], row["dst"]
            suffix = " (self-reference)" if src == dst else ""
            print(f"  {src} --[{col}]-> {dst}{suffix}")


# Print the tables and FK edges that were just written to Neo4j.
show_schema_in_console(driver=driver)


📋 Tables in schema:
 - addresses
 - categories
 - order_items
 - orders
 - payments
 - products
 - reviews
 - shipments
 - suppliers
 - users

🔗 Foreign Key Relationships:
  addresses --[user_id]-> users
  order_items --[order_id]-> orders
  order_items --[product_id]-> products
  orders --[payment_id]-> payments
  orders --[shipment_id]-> shipments
  orders --[user_id]-> users
  payments --[order_id]-> orders
  products --[category_id]-> categories
  products --[supplier_id]-> suppliers
  reviews --[product_id]-> products
  reviews --[user_id]-> users
  shipments --[address_id]-> addresses
  shipments --[order_id]-> orders
  users --[address_id]-> addresses


In [None]:
# This section processes the JSON schema to prepare it for semantic search.
# 1. Extracts all table names and their column details into a text list (`texts`) and a metadata list (`meta`).
# 2. Generates embeddings for each text entry using the 'all-MiniLM-L6-v2' SentenceTransformer model.
# 3. Builds a FAISS L2 index for efficient similarity search across schema elements.
# 4. Saves both the FAISS index and the associated metadata to disk for later retrieval.

DATA_PATH = "./"
FAISS_INDEX_PATH = os.path.join(DATA_PATH, "schema_faiss_index.faiss")
META_PATH = os.path.join(DATA_PATH, "schema_meta.pkl")

# One text entry per table and per column; `meta[i]` describes `texts[i]`.
texts, meta = [], []
for table_name, table_info in schema.get("tables", {}).items():
    texts.append(f"Table: {table_name}")
    meta.append({"type": "table", "name": table_name})

    for col in table_info.get("columns", []):
        # `or ""` also covers keys present with a JSON null, which `.get(key, "")` would pass
        # through as None and embed as the literal word "None".
        col_name = col.get("name") or ""
        col_desc = col.get("description") or ""
        full_text = f"Column: {col_name} - {col_desc} (Table: {table_name})"
        texts.append(full_text)
        meta.append({
            "type": "column",
            "table": table_name,
            "name": col_name,
            "description": col_desc
        })

print(f"Extracted {len(texts)} schema entries for embeddings.")
if not texts:
    # encode([]) yields a 1-D empty array, so `.shape[1]` below would fail with an opaque IndexError.
    raise ValueError("No tables/columns found under schema['tables']; nothing to index.")

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
d = embeddings.shape[1]
index = faiss.IndexFlatL2(d)  # exact L2 search; scores are distances (lower = closer)
index.add(embeddings)
print("Built FAISS index. n_items:", index.ntotal)

faiss.write_index(index, FAISS_INDEX_PATH)
with open(META_PATH, "wb") as f:
    pickle.dump({"meta": meta, "texts": texts}, f)
print("Saved FAISS index & metadata.")


Extracted 111 schema entries for embeddings.


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Built FAISS index. n_items: 111
Saved FAISS index & metadata.


In [None]:
# This section handles loading and searching the schema's FAISS index.
# 1. `load_schema_faiss()` loads the FAISS index and metadata (`meta` and `texts`) from disk if available.
# 2. After loading, it confirms availability of the index and metadata.
# 3. `search_schema(query, top_k)` encodes the query into an embedding, searches the FAISS index for the top-k nearest schema entries,
#    and returns results containing the score (an L2 distance, so lower means more similar), the matched text, and metadata.

def load_schema_faiss():
    """Load the persisted schema FAISS index and its metadata from disk.

    Returns:
        (index, meta, texts). `index` is None when FAISS_INDEX_PATH does not exist;
        `meta` and `texts` are None when META_PATH is missing or its pickle is not a dict
        holding both "meta" and "texts".
    """
    loaded_index = None
    if os.path.exists(FAISS_INDEX_PATH):
        loaded_index = faiss.read_index(FAISS_INDEX_PATH)

    if not os.path.exists(META_PATH):
        return loaded_index, None, None

    # NOTE(review): pickle.load can execute arbitrary code; only load files this notebook wrote.
    with open(META_PATH, "rb") as handle:
        payload = pickle.load(handle)

    has_both = isinstance(payload, dict) and "meta" in payload and "texts" in payload
    if not has_both:
        return loaded_index, None, None
    return loaded_index, payload["meta"], payload["texts"]

# Reload from disk so the search helpers below work even when the build cell was skipped.
index, meta, texts = load_schema_faiss()
if index is not None:
    print("  Loaded Schema FAISS index")
if meta is not None:
    print(" Loaded schema metadata length:", len(meta))

def search_schema(query, top_k=5):
    """Return up to `top_k` schema entries nearest to `query` in the module-level FAISS index.

    Each hit is a dict with "score" (L2 distance; lower means closer), "text", "meta" and
    "meta_idx". FAISS pads with index -1 when fewer than `top_k` items exist; those slots
    are dropped. Prints a notice and returns [] when the index or metadata is not loaded.
    """
    if any(part is None for part in (index, meta, texts)):
        print(" Schema index or metadata not loaded.")
        return []
    query_vec = model.encode([query], convert_to_numpy=True)
    distances, positions = index.search(query_vec, top_k)
    return [
        {"score": float(dist), "text": texts[pos], "meta": meta[pos], "meta_idx": int(pos)}
        for dist, pos in zip(distances[0], positions[0])
        if pos >= 0
    ]


def assemble_schema_context(faiss_results, neo4j_nodes=None, max_chunks=5):
    """Build a plain-text context block from schema search hits and optional graph nodes.

    Args:
        faiss_results: hits as returned by `search_schema` (dicts with "score" and "text").
        neo4j_nodes: optional list of JSON-serialisable node dicts; each is rendered as JSON
            truncated to 400 characters.
        max_chunks: maximum number of search hits to include.

    Returns:
        A newline-joined string: a "Retrieved schema entries" header with one line per hit,
        then a "Graph nodes summary" header (always present) followed by the node previews.
    """
    hit_lines = [
        f"[Result {rank}] score={hit['score']:.4f} → {hit['text']}"
        for rank, hit in enumerate(faiss_results[:max_chunks], start=1)
    ]
    node_lines = [json.dumps(node, ensure_ascii=False)[:400] for node in (neo4j_nodes or [])]
    return "\n".join([
        "=== Retrieved schema entries ===",
        *hit_lines,
        "\n=== Graph nodes summary ===",
        *node_lines,
    ])


  Loaded Schema FAISS index
 Loaded schema metadata length: 111


In [None]:
def _parse_column_names(cols_json) -> List[str]:
    """Extract the non-empty column names from a Table node's JSON `columns` string ([] if unusable)."""
    if not isinstance(cols_json, str) or not cols_json.strip():
        return []
    try:
        cols = json.loads(cols_json)
    except json.JSONDecodeError:
        return []
    if not isinstance(cols, list):
        return []
    # Skip entries without a name: a None here would later break ", ".join(...) in render_schema_text.
    return [str(c["name"]) for c in cols if isinstance(c, dict) and c.get("name")]


def get_schema_snapshot(driver) -> Dict[str, Any]:
    """Read the schema graph back out of Neo4j.

    Returns:
        dict with
          - "tables": names of all :Table nodes, ordered by name;
          - "table_columns": table name -> column names parsed from the node's JSON `columns`
            property ([] when missing, unparseable, or not a list);
          - "fks": one dict per :FK edge with keys start_table, fk_column, end_table.
    """
    snap = {"tables": [], "fks": [], "table_columns": {}}
    with driver.session() as sess:
        res = sess.run("MATCH (t:Table) RETURN t.name AS name, t.columns AS columns ORDER BY name")
        for r in res:
            tname = r["name"]
            snap["tables"].append(tname)
            snap["table_columns"][tname] = _parse_column_names(r.get("columns"))

        res_fk = sess.run("""
            MATCH (t1:Table)-[r:FK]->(t2:Table)
            RETURN t1.name AS start_table, r.column AS fk_column, t2.name AS end_table
            ORDER BY start_table, end_table, fk_column
        """)
        snap["fks"] = res_fk.data()
    return snap

def render_schema_text(snapshot: Dict[str, Any], max_cols: int = 12) -> str:
    """Render a `get_schema_snapshot` result as prompt-ready text.

    Args:
        snapshot: dict with "tables", "table_columns" and "fks" (as built by get_schema_snapshot).
        max_cols: how many columns to list per table before summarising the rest as
            "(+N more)"; negative values are treated as 0.

    Returns:
        A "TABLES:" section (one " - table: col, col, ..." line per table) followed by a
        "FOREIGN KEYS" section with one " - src --[column]-> dst" line per edge
        ("?" when the edge has no column).
    """
    max_cols = max(max_cols, 0)
    lines = ["TABLES:"]
    for t in snapshot.get("tables", []):
        # Drop empty/None names so the join below cannot fail on non-string entries.
        cols = [str(c) for c in (snapshot.get("table_columns", {}).get(t) or []) if c]
        shown = ", ".join(cols[:max_cols])
        extra = f" (+{len(cols) - max_cols} more)" if len(cols) > max_cols else ""
        lines.append(f" - {t}: {shown}{extra}")
    lines.append("\nFOREIGN KEYS (Table --[column]-> Table):")
    for fk in snapshot.get("fks", []):
        lines.append(f" - {fk['start_table']} --[{fk.get('fk_column') or '?'}]-> {fk['end_table']}")
    return "\n".join(lines)


In [None]:
import re

# Write / admin operations that must never reach the schema graph. REMOVE (drops properties
# or labels) and the write-capable APOC families are included alongside the classic clauses.
FORBIDDEN = re.compile(
    r"\b(CREATE|MERGE|DELETE|SET|REMOVE|DROP|CALL\s+dbms|CALL\s+db\.|"
    r"apoc\.(?:[\w.]*write|periodic|refactor|trigger|do|cypher\.doit)|LOAD\s+CSV)\b",
    re.I
)

# A line that opens a read query. The lookahead requires whitespace or "(" after the keyword,
# so prose such as "Return:" is not mistaken for a query.
_CYPHER_START = re.compile(
    r"^[ \t]*(?:OPTIONAL\s+MATCH|MATCH|RETURN|WITH|CALL|UNWIND)(?=[\s(])",
    re.I | re.M,
)


def extract_cypher(text: str) -> str:
    """Extract a single read-only Cypher query from a raw LLM reply.

    Steps: drop a "Here is the Cypher query...:" preamble, unwrap a ``` / ```cypher fence,
    drop a leading "I apologize..." line, then start the query at the first line that begins
    with a read clause (MATCH, OPTIONAL MATCH, RETURN, WITH, CALL, UNWIND), so prose the model
    writes before the query is ignored. Trailing semicolons/periods are removed (the model
    sometimes ends the query with "." like a sentence, which Neo4j rejects).

    Raises:
        ValueError: if the reply is empty, no line starts with a read clause, or the query
            matches FORBIDDEN.
    """
    if not text or not text.strip():
        raise ValueError("Empty LLM response: no Cypher query found.")

    text = re.sub(r"(?i)^here is the cypher query[^\n]*:\n*", "", text.strip())

    m = re.search(r"```(?:cypher)?\s*([\s\S]+?)```", text, flags=re.I)
    if m:
        text = m.group(1).strip()

    text = re.sub(r"(?i)^i apologize[^\n]*\n*", "", text)

    start = _CYPHER_START.search(text)
    if not start:
        raise ValueError(f"Invalid Cypher start: {text[:50]}")
    text = text[start.start():].strip()
    text = re.sub(r"[\s;.]+$", "", text)

    if FORBIDDEN.search(text):
        raise ValueError(f"Forbidden Cypher command detected: {text}")

    return text



In [None]:
def generate_cypher_from_groq(question: str, schema_text: str, memory_text: str, model: str = "llama3-8b-8192") -> str:
    """Ask a Groq-hosted LLM to translate `question` into one read-only Cypher query.

    The system prompt holds the graph conventions, few-shot examples, the conversation
    memory (`memory_text`), the rendered schema (`schema_text`) and the question. The reply
    is cleaned and validated by `extract_cypher`.

    Raises:
        RuntimeError: if GROQ_API_KEY is unset or the API reply contains no choices.
        requests.HTTPError: on a non-2xx response.
        ValueError: from `extract_cypher` when no valid/safe query can be extracted.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY not set in environment")

    url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    # NOTE: few-shot queries must be valid Cypher verbatim — the model copies them closely
    # (e.g. a trailing "." after Example 7 gets reproduced and breaks the query).
    system_prompt = f"""
You are an AI assistant generating valid Cypher queries for Neo4j.
The database stores the schema as a graph:
- Each table is a node with label :Table and property "name".
- Each table node also has a property "columns" (JSON list of columns with descriptions).
- Relationships between tables are stored as :FK edges with property "column".

STRICT RULES:
- You MUST NOT put {{}} property filters inside a MATCH relationship pattern. This is FORBIDDEN.
- If you need to filter by relationship property, you MUST use a WHERE clause AFTER the MATCH.
- You MUST always use the label :Table for all nodes.
- You MUST filter specific table names using WHERE t.name = "<table_name>".
- Never use labels like :users, :orders, :products. They do not exist.
- Always access properties as t.name or t.columns (never bare `name`).
- Output must start with MATCH, RETURN, WITH, or CALL.
- Return ONLY the Cypher query. Do not include explanations or comments.
- When matching two :Table nodes (even if they are the same table),
  you MUST use different variables (t1, t2).
  Example of self-join:
  MATCH (t1:Table)-[r:FK]->(t2:Table)
  WHERE t1.name = "categories" AND t2.name = "categories" AND r.column = "parent_category_id"
  RETURN t1.name, r.column, t2.name

IMPORTANT:
- Your ENTIRE output MUST be ONLY one valid Cypher query.
- Do NOT include apologies, explanations, markdown, or text outside the Cypher.
- If unsure, still output a best-effort Cypher query (never natural language).
INTENTS:

1. **Relations (foreign key relationships)**
   Pattern:
   MATCH (t1:Table)-[r:FK]->(t2:Table)
   [Optional WHERE r.column = "<column>"]
   RETURN
     t1.name AS `Table 1`,
     r.column AS `FK Name`,
     t2.name AS `Table 2`

2. **Columns of a table**
   Pattern:
   MATCH (t:Table)
   WHERE t.name = "<table_name>"
   RETURN t.columns AS Columns

3. **Primary key of a table**
   Same query as columns:
   MATCH (t:Table)
   WHERE t.name = "<table_name>"
   RETURN t.columns AS Columns
   (the analysis step will extract the column whose description contains "Primary Key")

4. **Foreign key columns of a table**
   Same query as columns:
   MATCH (t:Table)
   WHERE t.name = "<table_name>"
   RETURN t.columns AS Columns
   (the analysis step will extract the columns whose description contains "Foreign Key")

5. **List all tables**
   MATCH (t:Table)
   RETURN t.name AS Table

6. **Find tables with a specific column name**
   MATCH (t:Table)
   WHERE any(col IN apoc.convert.fromJsonList(t.columns) WHERE col.name = "<column_name>")
   RETURN t.name AS Table

Few-shot examples:

Example 1:
Question: "Show all foreign key relationships."
Cypher:
MATCH (t1:Table)-[r:FK]->(t2:Table)
RETURN
  t1.name AS `Table 1`,
  r.column AS `FK Name`,
  t2.name AS `Table 2`

Example 2:
Question: "Show orders linked to shipments."
Cypher:
MATCH (t1:Table)-[r:FK]->(t2:Table)
WHERE r.column = "shipment_id"
RETURN
  t1.name AS `Table 1`,
  r.column AS `FK Name`,
  t2.name AS `Table 2`

Example 3:
Question: "What are the columns in the users table?"
Cypher:
MATCH (t:Table)
WHERE t.name = "users"
RETURN t.columns AS Columns

Example 4:
Question: "What is the primary key of the categories table?"
Cypher:
MATCH (t:Table)
WHERE t.name = "categories"
RETURN t.columns AS Columns

Example 5:
Question: "Which columns in the orders table are foreign keys?"
Cypher:
MATCH (t:Table)
WHERE t.name = "orders"
RETURN t.columns AS Columns

Example 6:
Question: "List all tables in the database."
Cypher:
MATCH (t:Table)
RETURN t.name AS Table

Example 7:
Question: "Which tables contain a column named created_at?"
Cypher:
MATCH (t:Table)
WHERE any(col IN apoc.convert.fromJsonList(t.columns) WHERE col.name = "created_at")
RETURN t.name AS Table

- If the user refers to "previous table", "previous column", or any ambiguous reference,
  you MUST resolve it using the last conversation memory (memory_buffer).
  Replace the ambiguous phrases with the actual table or column name mentioned previously.

Conversation Memory:
{memory_text}

Filtered Schema:
{schema_text}

User Question:
{question}

Return:
"""

    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "Follow the instructions from the system message exactly when generating the Cypher query."}
        ],
        "max_tokens": 600,
        "temperature": 0.0
    }

    resp = requests.post(url, headers=headers, json=payload, timeout=120)
    resp.raise_for_status()
    choices = resp.json().get("choices") or []
    if not choices:
        raise RuntimeError("Groq response contained no choices")
    # `content` can be null for some responses; treat it as empty so extract_cypher reports it.
    text = (choices[0].get("message", {}).get("content") or "").strip()

    return extract_cypher(text)


In [None]:
def is_safe_schema_query(q: str) -> Tuple[bool, str]:
    """Check that a generated Cypher query is read-only and targets the schema graph.

    The query is whitespace-flattened, then rejected if it is empty, matches FORBIDDEN
    (write/management clauses), or never references the :Table label.

    Returns:
        (True, "") when safe, otherwise (False, reason).
    """
    q_flat = " ".join((q or "").strip().split())
    if not q_flat:
        return False, "Query is empty."
    if FORBIDDEN.search(q_flat):
        return False, "Query contains forbidden write/management clauses."
    # Cypher allows whitespace after the label colon, so `(t: Table)` is legal; `\b` rejects
    # look-alike labels such as `:Tables` that a plain substring test would accept.
    if not re.search(r":\s*Table\b", q_flat):
        return False, "Query must target :Table nodes (schema graph)."
    return True, ""

In [None]:

def llm(prompt: str, model: str = "llama3-8b-8192", debug: bool = False) -> str:
    """Send `prompt` to a Groq-hosted chat model and return the stripped reply text.

    Args:
        prompt: user message content.
        model: Groq model id (default unchanged from the original hard-coded value).
        debug: when True, print the raw JSON response (off by default to keep cell output clean).

    Raises:
        RuntimeError: if GROQ_API_KEY is not set.
        requests.HTTPError: on a non-2xx response.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY not set in environment")

    url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a professional assistant generating clear answers from database results."},
            {"role": "user", "content": prompt}
        ],
        "max_tokens": 600,
        "temperature": 0.0
    }

    resp = requests.post(url, headers=headers, json=payload, timeout=120)
    resp.raise_for_status()
    result = resp.json()
    if debug:
        print("DEBUG LLM RESPONSE:", result)
    return result["choices"][0]["message"]["content"].strip()


def generate_answer_from_cypher_result(question: str, cypher_result: list) -> str:
    """Turn Cypher result rows into a natural-language answer via `llm`.

    Args:
        question: the user's original question; the answer is requested in its language.
        cypher_result: list of row dicts returned by Neo4j.

    Returns:
        The LLM's answer, or a fixed "no data" message when `cypher_result` is empty/None
        (in that case no LLM call is made).
    """
    if not cypher_result:
        return "There is no data matching the question."

    data_text = "\n".join(", ".join(f"{k}: {v}" for k, v in row.items()) for row in cypher_result)

    # The instructions explicitly forbid narrating the language detection or appending a
    # translation: asking the model to "detect the language" made it do both.
    prompt = f"""
You have the following user question:
"{question}"

You also have the following data obtained from a Neo4j database after executing a Cypher query:
{data_text}

Your task:
1. Write the answer in the same language as the user's question (for example, an Arabic question gets an Arabic answer).
2. Output ONLY the final answer: do not state which language you detected, do not add a translation, and do not add a preamble such as "Here's the answer".
3. Focus on conveying the key information directly without showing tables or Cypher queries.
4. If the data shows relationships or links between elements, describe them smoothly and clearly.
5. Use professional and fluent language with a simple and comprehensible style.

Write the final answer so that it is ready to be presented directly to the user.
"""

    return llm(prompt)


In [None]:
# =========================
# 10) LangGraph: State + Nodes
# =========================

class NewState(TypedDict):
    """Shared LangGraph state for one schema Q&A turn plus the running conversation memory."""

    # Input
    query: str

    # Planning and Cypher stage
    reasoning_plan: Dict[str, Any]
    cypher_query: str
    cypher_result: List[Dict[str, Any]]
    cypher_error: str
    errors: List[str]

    # Answer stage
    analysis: Dict[str, Any]
    final_answer: str

    # Conversation memory (recent Q/A pairs plus a summary of older ones)
    memory_buffer: List[Dict[str, str]]
    memory_summary: str


# -------- Reasoning Node --------
# Substring keywords (matched against the lower-cased question). Singular forms also cover
# plurals ("column" ⊂ "columns", "relation" ⊂ "relationships"); Arabic variants include the
# common hamza-less spelling "اعمدة".
RELATION_KEYWORDS = ("علاقات", "علاقة", "relation", "foreign key", "fk")
COLUMN_KEYWORDS = ("column", "أعمدة", "اعمدة", "عمود")


def node_reasoning(state: "NewState") -> Dict[str, Any]:
    """Classify the question's intent with simple keyword matching.

    Returns {"reasoning_plan": {"intent": ...}} where intent is 'relations', 'columns' or
    'tables_or_misc' (relations take precedence). The intent is guidance only; it is not
    used to build the Cypher query.
    """
    q = (state.get("query") or "").lower()

    if any(word in q for word in RELATION_KEYWORDS):
        intent = "relations"
    elif any(word in q for word in COLUMN_KEYWORDS):
        intent = "columns"
    else:
        intent = "tables_or_misc"

    return {"reasoning_plan": {"intent": intent}}



# -------- Cypher Generation Node --------
def node_generate_cypher(state: NewState) -> Dict[str, Any]:
    """Generate a Cypher query for `state["query"]` from the live schema snapshot plus memory.

    Returns a partial state update:
      - on success: the query with `cypher_error` / `errors` explicitly cleared;
      - on a schema-read or LLM/extraction failure: an empty query plus the error;
      - when the query fails `is_safe_schema_query`: the query plus an "Unsafe query" error.
    """
    try:
        schema_text = render_schema_text(get_schema_snapshot(driver))
    except Exception as e:
        return {
            "cypher_query": "",
            "cypher_error": f"Schema snapshot error: {e}",
            "errors": [str(e)]
        }

    memory_parts = []
    if state.get("memory_summary"):
        memory_parts.append(f"\n Summary of previous conversations:\n{state['memory_summary']}\n")
    if state.get("memory_buffer"):
        memory_parts.append("\n Last 3 conversations:\n")
        memory_parts.extend(f"Q: {m['question']}\nA: {m['answer']}\n" for m in state["memory_buffer"])
    memory_text = "".join(memory_parts)

    try:
        cypher = generate_cypher_from_groq(state["query"], schema_text, memory_text)
    except Exception as e:
        return {
            "cypher_query": "",
            "cypher_error": f"Groq error: {e}",
            "errors": [str(e)]
        }

    ok, reason = is_safe_schema_query(cypher)
    if not ok:
        return {
            "cypher_query": cypher,
            "cypher_error": f"Unsafe query: {reason}",
            "errors": [reason]
        }

    # Clear error fields explicitly: the state dict is reused across questions, so an
    # error from an earlier turn would otherwise survive into this successful one.
    return {"cypher_query": cypher, "cypher_error": "", "errors": []}


# -------- Cypher Execution Node --------
def node_run_cypher(state: "NewState") -> Dict[str, Any]:
    """Execute the generated Cypher on Neo4j in a read-only transaction.

    Returns:
      - no query: empty result, keeping the generation error if one was recorded;
      - query fails `is_safe_schema_query`: empty result plus an "Unsafe query" error
        (re-checked here so a query flagged by the generator is never executed);
      - success: the rows (list of dicts) with `cypher_error` cleared;
      - driver/Cypher failure: empty result plus the exception text.
    """
    q = (state.get("cypher_query") or "").strip()
    if not q:
        return {"cypher_result": [], "cypher_error": state.get("cypher_error") or "No valid Cypher generated."}

    ok, reason = is_safe_schema_query(q)
    if not ok:
        return {"cypher_result": [], "cypher_error": f"Unsafe query: {reason}", "errors": [reason]}

    try:
        with driver.session() as sess:
            # Read transaction: the server rejects any write, backing up the regex screen.
            rows = sess.execute_read(lambda tx: tx.run(q).data())
        return {"cypher_result": rows, "cypher_error": ""}
    except Exception as e:
        return {"cypher_result": [], "cypher_error": str(e), "errors": [str(e)]}


def analyze_and_answer_ar(cypher_result: list) -> str:
    """Format Cypher result rows as plain text: one "key: value, key: value" line per row.

    Despite the `_ar` suffix no translation happens; an empty/None result yields the fixed
    English message "There is no data matching the question.".
    """
    if not cypher_result:
        return "There is no data matching the question."

    def _row_to_line(record):
        return ", ".join(f"{key}: {val}" for key, val in record.items())

    return "\n".join(map(_row_to_line, cypher_result))

# -------- Analysis Node --------
def node_analyze_results(state: NewState) -> Dict[str, Any]:
    """Build the `analysis` entry: a text `summary` plus a `preview` of up to 5 result rows.

    When `cypher_error` is set the summary reports that error and the preview is empty;
    otherwise the summary is `analyze_and_answer_ar(cypher_result)` (plain formatting, not
    a translation).
    """
    error_text = state.get("cypher_error")
    if error_text:
        analysis = {"summary": f"Error during execution: {error_text}", "preview": []}
    else:
        rows = state.get("cypher_result", [])
        analysis = {"summary": analyze_and_answer_ar(rows), "preview": rows[:5]}
    return {"analysis": analysis}


def node_generate_final_answer(state: NewState) -> Dict[str, Any]:
    """Have the LLM phrase the Cypher rows as the user-facing `final_answer`.

    Any exception from answer generation is caught and stored as an (Arabic) error message
    in `final_answer`, so this node never raises.
    """
    try:
        reply = generate_answer_from_cypher_result(
            state.get("query", ""),
            state.get("cypher_result", []),
        )
    except Exception as err:
        reply = f"حدث خطأ أثناء توليد الإجابة: {err}"
    return {"final_answer": reply}



def node_memory(state: "NewState") -> Dict[str, Any]:
    """Maintain short-term conversation memory.

    Appends the current (query, final_answer) pair to `memory_buffer`, keeps only the last 3
    entries, and condenses the evicted ones with `llm` into text appended to
    `memory_summary`. If the LLM call fails, the first 300 characters of the evicted Q/A text
    are appended instead.

    Returns:
        {"memory_buffer": new list, "memory_summary": str}; the list in `state` is not mutated.
    """
    # Copy rather than append in place: the incoming list may be the caller's own object
    # (ask_schema_agent passes its persistent agent_state), so mutation would be a hidden side effect.
    buffer = list(state.get("memory_buffer") or [])
    summary = state.get("memory_summary") or ""

    if state.get("query") and state.get("final_answer"):
        buffer.append({
            "question": state["query"],
            "answer": state["final_answer"]
        })

    if len(buffer) > 3:
        old, buffer = buffer[:-3], buffer[-3:]

        old_text = "\n".join([f"Q: {m['question']} → A: {m['answer']}" for m in old])
        summary_prompt = f"""
        You are a conversation summarizer.

        Your task:
        - Summarize the following conversations briefly and clearly.
        - The summary should capture the main topics and answers without unnecessary details.
        - Support both Arabic and English (use the same language style of the input).
        - Do not rewrite the full Q/A, only the essence.
        - Keep it short and coherent (3–4 sentences max).

        Conversations to summarize:
        {old_text}
        """

        try:
            new_summary = llm(summary_prompt)
        except Exception:
            # Best effort: keep a truncated raw transcript rather than failing the turn.
            new_summary = old_text[:300]

        summary = (summary + "\n" + new_summary).strip()

    return {"memory_buffer": buffer, "memory_summary": summary}



def node_output(state: NewState) -> Dict[str, Any]:
    """Terminal graph node; contributes no state update (placeholder for final formatting)."""
    no_update: Dict[str, Any] = dict()
    return no_update




In [None]:
wf2 = StateGraph(NewState)

# Register each node under the name the edges below refer to.
for _node_name, _node_fn in (
    ("memory", node_memory),
    ("reasoning", node_reasoning),
    ("generate_cypher", node_generate_cypher),
    ("run_cypher", node_run_cypher),
    ("analyze_results", node_analyze_results),
    ("generate_final_answer", node_generate_final_answer),
    ("output", node_output),
):
    wf2.add_node(_node_name, _node_fn)

# Strictly linear pipeline: consecutive stages are chained START → ... → END.
_stage_order = [
    START, "reasoning", "generate_cypher", "run_cypher", "analyze_results",
    "generate_final_answer", "memory", "output", END,
]
for _src, _dst in zip(_stage_order, _stage_order[1:]):
    wf2.add_edge(_src, _dst)

runner2 = wf2.compile()



def _fresh_turn_fields() -> Dict[str, Any]:
    """Per-question state fields at their empty defaults (new list/dict objects on every call)."""
    return {
        "final_answer": "",
        "query": "",
        "reasoning_plan": {},
        "cypher_query": "",
        "cypher_result": [],
        "analysis": {},
        "cypher_error": "",
        "errors": [],
    }


# Persistent agent state: per-turn fields are reset on every question, while
# memory_buffer / memory_summary carry the conversation across calls.
agent_state = {**_fresh_turn_fields(), "memory_buffer": [], "memory_summary": ""}


def ask_schema_agent(question: str):
    """Run one question through the LangGraph pipeline, print a report, and return the final state.

    Per-turn fields (query, Cypher, results, errors, analysis, answer) are reset before
    invoking `runner2`, so a `cypher_error` or rows left over from the previous question cannot
    leak into this one; the conversation memory fields are preserved.
    """
    agent_state.update(_fresh_turn_fields())
    agent_state["query"] = question

    out: NewState = runner2.invoke(agent_state)

    agent_state.update(out)

    print("─"*70)
    print("  السؤال:", question)
    print("─"*70)
    print("\n الإجابة النهائية:\n", out.get("final_answer"))
    print("─"*70)
    print("\n الملخص:\n", (out.get("analysis") or {}).get("summary",""))
    print("─"*70)
    print("\n الخطة:", out.get("reasoning_plan"))
    print("─"*70)
    print("\n الكويري:\n", out.get("cypher_query",""))
    print("─"*70)
    print("\n عينة نتائج:", out.get("cypher_result", [])[:5])
    print("─"*70)
    print("\n  آخر 3 محادثات:", out.get("memory_buffer", []))
    print("\n  ملخص المحادثات الأقدم:", out.get("memory_summary", ""))

    return out


In [None]:
# Seed the conversation first: the follow-up refers to "the last table we talked about",
# which is only resolvable on a fresh kernel (Restart & Run All) if memory already names a table.
ask_schema_agent("What are the columns in the addresses table?")
ask_schema_agent("هاتلي أسماء الأعمدة في آخر جدول اتكلمنا عنه")

DEBUG LLM RESPONSE: {'id': 'chatcmpl-ca4a285a-f046-404e-a322-fddea62ddae9', 'object': 'chat.completion', 'created': 1755544512, 'model': 'llama3-8b-8192', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': "Based on the user's question, I detect that the language is Arabic.\n\nHere's the answer:\n\nالاسماء الأعمدة في آخر جدول اتكلمنا عنه هي:\n\n* address_id: هو رقم فريد ل каждый عنوان (رقم رئيسي)\n* user_id: هو رقم فريد يرتبط بالجدول المستخدمين\n* street: هو عنوان الشارع\n* city: هو اسم المدينة\n* state: هو اسم الولاية أو المنطقة\n* zip_code: هو رمز البريد أو الرمز البريدي\n* country: هو اسم البلد\n* created_at: هو توقيت عندما تم إنشاء العنوان\n* updated_at: هو توقيت عندما تم تحديث العنوان\n* address_type: هو نوع العنوان (مثل الشحن أو الفاتورة)\n\nTranslation:\n\nThe column names in the last table we discussed are:\n\n* address_id: is a unique identifier for each address (primary key)\n* user_id: is a foreign key linking to the users table\n* street: is the street add

{'query': 'هاتلي أسماء الأعمدة في آخر جدول اتكلمنا عنه',
 'reasoning_plan': {'intent': 'columns'},
 'cypher_query': 'MATCH (t:Table)\nWHERE t.name = "addresses"\nRETURN t.columns AS Columns',
 'final_answer': "Based on the user's question, I detect that the language is Arabic.\n\nHere's the answer:\n\nالاسماء الأعمدة في آخر جدول اتكلمنا عنه هي:\n\n* address_id: هو رقم فريد ل каждый عنوان (رقم رئيسي)\n* user_id: هو رقم فريد يرتبط بالجدول المستخدمين\n* street: هو عنوان الشارع\n* city: هو اسم المدينة\n* state: هو اسم الولاية أو المنطقة\n* zip_code: هو رمز البريد أو الرمز البريدي\n* country: هو اسم البلد\n* created_at: هو توقيت عندما تم إنشاء العنوان\n* updated_at: هو توقيت عندما تم تحديث العنوان\n* address_type: هو نوع العنوان (مثل الشحن أو الفاتورة)\n\nTranslation:\n\nThe column names in the last table we discussed are:\n\n* address_id: is a unique identifier for each address (primary key)\n* user_id: is a foreign key linking to the users table\n* street: is the street address\n* city: i