In [1]:
import boto3
from strands import Agent, tool
from strands.models import BedrockModel
from strands_tools import current_time
from pydantic import BaseModel
from dotenv import load_dotenv
import os, re

In [2]:
from neo4j import GraphDatabase
from typing import List, Dict, Any, Literal, Optional

In [3]:
# Load variables from a local .env file into os.environ so the os.getenv()
# calls below can read the Neo4j and AWS settings.
load_dotenv()

True

#### Creating Database Client

In [8]:
class Neo4jClient:
    def __init__(self, uri: str, auth : tuple, database : str):
        self.driver = GraphDatabase.driver(
            uri= uri, auth=auth)
        self.database = database

    def close(self):
        self.driver.close()

    def run_cypher(self, query: str, params: Dict[str, Any] | None = None) -> List[Dict]:
        with self.driver.session(database=self.database) as session:
            print(query) # adding for llm visibility
            result = session.run(query, params or {})
            return [record.data() for record in result]

In [5]:
# Neo4j connection settings, read from environment variables
# (typically populated from .env by the load_dotenv() call earlier).
DB_URI = os.getenv('DB_URI')
DB_USER = os.getenv('DB_USER')
DB_PASSWORD = os.getenv('DB_PASSWORD')
TARGET_DB = os.getenv('TARGET_DB')
# Basic-auth tuple for GraphDatabase.driver: (username, password).
AUTH = (DB_USER, DB_PASSWORD)

#### Creating connection to AWS services: Bedrock

In [6]:
# AWS session built from credentials and region stored in the environment.
# NOTE(review): these env var names are lowercase, unlike the conventional
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY names boto3 reads by default.
session = boto3.Session(
    region_name=os.getenv('region_name'),
    aws_access_key_id=os.getenv("aws_access_key_id"),
    aws_secret_access_key=os.getenv("aws_secret_access_key"),
)
# Low-level Bedrock runtime client (not referenced by the other cells shown).
bedrock_client = session.client('bedrock-runtime')

#### Creating Cypher Execution tool

In [7]:
def neo4j_cypher_tool(
    cypher: str,
    params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
    """
    Execute a Cypher query against Neo4j and return results.

    Unlike ``neo4j_query_tool``, this helper performs no read-only, label or
    property validation: ``cypher`` is sent to the database as-is.

    Args:
        cypher: Cypher text to execute.
        params: Optional query parameters forwarded to ``run_cypher``.

    Returns:
        ``{"row_count": <number of records>, "records": <list of record dicts>}``.
    """
    client = Neo4jClient(
        uri=DB_URI,
        auth=(DB_USER, DB_PASSWORD),
        database=TARGET_DB
    )

    try:
        # NOTE(review): the driver is documented to connect lazily, so this is
        # printed before any round-trip to the server has happened.
        print('successfully connected to database')
        records = client.run_cypher(cypher, params)
        return {
            "row_count": len(records),
            "records": records
        }
    finally:
        # Release the driver even if the query raises.
        client.close()


#### Defining Tool for Strands

In [9]:
FORBIDDEN_KEYWORDS = {
    "CREATE", "MERGE", "DELETE", "SET", "DROP",
    "REMOVE", "CALL", "LOAD CSV", "APOC"
}

def assert_read_only(cypher: str):
    """
    Raise ``ValueError`` if ``cypher`` contains a forbidden keyword.

    Keywords from ``FORBIDDEN_KEYWORDS`` are matched case-insensitively as whole
    words. Multi-word keywords such as ``LOAD CSV`` accept any whitespace
    between their words. Identifiers that merely contain a keyword
    (``created_at``, ``dataset``) are therefore not rejected. Keywords are
    checked in sorted order so the reported keyword is deterministic.

    This is a lexical check: keywords inside string literals or comments are
    still rejected.

    Raises:
        ValueError: ``"Forbidden Cypher keyword: <KW>"`` for the first match.
    """
    upper = cypher.upper()
    for kw in sorted(FORBIDDEN_KEYWORDS):
        pattern = r"\b" + r"\s+".join(re.escape(word) for word in kw.split()) + r"\b"
        if re.search(pattern, upper):
            raise ValueError(f"Forbidden Cypher keyword: {kw}")


In [10]:
# Known property names per node label. validate_properties() accepts a
# `var.prop` reference when `prop` appears under any label here.
# "publication_date" is the canonical Work date property (see aliases below).
SCHEMA = {
    "Author": {"id", "name", "display_name"},
    "Work": {"id", "title", "type", "publication_date"},
    "Topic": {"id", "display_name", "score"},
}

# Commonly hallucinated Work date property names, all mapped to the canonical
# "publication_date" (used by normalize_properties()).
PROPERTY_ALIASES = dict.fromkeys(
    ("publication_year", "pub_year", "year"),
    "publication_date",
)


In [11]:
def validate_properties(cypher: str):
    """
    Fail fast if the query references properties
    not present in the schema.
    """
    matches = re.findall(r"(\w+)\.(\w+)", cypher)

    for var, prop in matches:
        # we can't reliably infer labels from vars,
        # so we check against ALL known properties
        if not any(prop in props for props in SCHEMA.values()):
            raise ValueError(
                f"Unknown property `{prop}` referenced in Cypher."
            )


In [12]:
def normalize_properties(cypher: str) -> str:
    """
    Replace known hallucinated property names with canonical ones.
    """
    for bad, good in PROPERTY_ALIASES.items():
        # match `.publication_year` safely
        cypher = re.sub(
            rf"\.{bad}\b",
            f".{good}",
            cypher
        )
    return cypher

In [13]:
RELATIONSHIP_CANONICAL = {
    "WROTE": "WORK_AUTHORED_BY",
    "AUTHORED": "WORK_AUTHORED_BY",
    "AUTHORED_BY": "WORK_AUTHORED_BY",
    "HAS_TOPIC": "WORK_HAS_TOPIC",
    "TOPIC_IN": "WORK_HAS_TOPIC"
}

def normalize_relationships(cypher: str) -> str:
    """
    Rewrite hallucinated relationship types to their canonical names.

    Each ``:ALIAS`` token is replaced only when the alias is a whole identifier.
    For example, ``:AUTHORED_BY`` becomes ``:WORK_AUTHORED_BY``. It is not
    partially rewritten through its ``:AUTHORED`` prefix, which would yield
    ``:WORK_AUTHORED_BY_BY``.

    Returns:
        The rewritten Cypher string.
    """
    for bad, good in RELATIONSHIP_CANONICAL.items():
        cypher = re.sub(rf":{re.escape(bad)}\b", f":{good}", cypher)
    return cypher


In [14]:
VALID_LABELS = {"Author", "Work", "Topic"}

def validate_labels(cypher: str):
    """
    Raise ``ValueError`` if the query uses an unknown node label or
    relationship type.

    Every ``:Name`` token is checked against ``VALID_LABELS`` and the canonical
    relationship types (the values of ``RELATIONSHIP_CANONICAL``). String
    literals are blanked out first, so text such as ``'Re:Search'`` is not
    mistaken for a label.

    Raises:
        ValueError: ``"Unknown label or relationship: <name>"``.
    """
    string_literal = re.compile(r"'(?:[^'\\]|\\.)*'" r'|"(?:[^"\\]|\\.)*"')
    without_strings = string_literal.sub("''", cypher)
    labels = re.findall(r":([A-Za-z_][A-Za-z0-9_]*)", without_strings)
    for label in labels:
        if label not in VALID_LABELS and label not in RELATIONSHIP_CANONICAL.values():
            raise ValueError(f"Unknown label or relationship: {label}")


In [15]:
def prepare_cypher(cypher: str) -> str:
    """
    Run the pre-execution checks and rewrites on model-generated Cypher.

    Steps, in order:
    1. The read-only check runs on the raw text.
    2. Relationship aliases are canonicalised.
    3. Labels are validated on the rewritten query.

    Returns:
        The normalised Cypher string.

    Raises:
        ValueError: from the read-only or label checks.
    """
    assert_read_only(cypher)
    normalized = normalize_relationships(cypher)
    validate_labels(normalized)
    return normalized

In [16]:
@tool(
    name="neo4j_query_tool",
    description="Execute a READ-ONLY Cypher query against Neo4j"
)
def neo4j_query_tool(cypher: str, **kwargs) -> dict:
    """Validate and run a read-only Cypher query on behalf of the agent.

    Args:
        cypher: Cypher text produced by the model.

    Returns:
        On success, ``{"row_count": int, "records": list[dict]}``.
        On failure, a dict with ``error``, ``message`` and ``original_cypher``
        so the model can correct itself. ``error`` is one of:
        - ``"cypher_validation_error"``: the query failed the local checks.
        - ``"cypher_execution_error"``: the Neo4j server rejected the query.
        Driver or connectivity errors are not caught and still raise.
    """
    from neo4j.exceptions import Neo4jError

    try:
        safe_cypher = prepare_cypher(cypher)
        validate_properties(safe_cypher)
    except Exception as e:
        return {
            "error": "cypher_validation_error",
            "message": str(e),
            "original_cypher": cypher
        }

    client = Neo4jClient(
        uri=DB_URI,
        auth=(DB_USER, DB_PASSWORD),
        database=TARGET_DB
    )

    try:
        records = client.run_cypher(safe_cypher)
    except Neo4jError as e:
        # Server-side failures (syntax errors, unsupported constructs) are
        # reported in the same structured shape as validation failures.
        return {
            "error": "cypher_execution_error",
            "message": str(e),
            "original_cypher": cypher
        }
    finally:
        client.close()

    return {
        "row_count": len(records),
        "records": records
    }


#### Defining Validated Query-Plan Contracts for the Strands Model to Follow When Building Cypher Queries

In [17]:
# Single-value Literal aliases used to pin MatchPattern fields to the only
# labels, relationship type and direction currently accepted.
AuthorLabel = Literal["Author"]
WorkLabel = Literal["Work"]
TopicLabel = Literal["Topic"]

# Only WORK_AUTHORED_BY is accepted here; WORK_HAS_TOPIC is not.
RelationshipType = Literal["WORK_AUTHORED_BY"]
Direction = Literal["OUT"]

In [18]:
# Node labels the query-plan contracts may reference.
VALID_NODE_LABELS = {"Author", "Topic", "Work"}

# Relationship catalogue: type -> endpoint labels and arrow direction.
VALID_RELATIONSHIPS = {
    "WORK_AUTHORED_BY": {"from": "Author", "to": "Work", "direction": "OUT"},
    "WORK_HAS_TOPIC": {"from": "Work", "to": "Topic", "direction": "OUT"},
}


In [19]:
# One hop of the query: (a:start_label)-[:relationship]->(w:end_label).
# The Literal field types currently allow only
# (Author)-[:WORK_AUTHORED_BY]->(Work) with direction "OUT";
# Topic / WORK_HAS_TOPIC are not accepted by this contract.
class MatchPattern(BaseModel):
    start_label: AuthorLabel
    relationship: RelationshipType
    end_label: WorkLabel
    direction: Direction = "OUT"


In [20]:
# An aggregate over the end-node variable `w`; only count is allowed, so it
# renders as `COUNT(w) AS <alias>`.
# NOTE(review): `alias` is an unconstrained str that ends up verbatim in the
# Cypher text; consider restricting it to an identifier pattern.
class Aggregation(BaseModel):
    function: Literal["count"]
    variable: Literal["w"]
    alias: str


In [21]:
# ORDER BY specification; sorts descending unless told otherwise.
# NOTE(review): `field` is free text inserted verbatim into the query —
# presumably a return field or aggregation alias; confirm and validate.
class OrderBy(BaseModel):
    field: str
    direction: Literal["ASC", "DESC"] = "DESC"

In [22]:
# A simple comparison predicate: `field op value`.
# NOTE(review): not referenced by CypherQueryPlan in the code shown, so filters
# are currently never part of a plan.
class Filter(BaseModel):
    field: str
    op: Literal["=", ">", "<"]
    value: str | int


In [23]:
# Full structured plan for a single-hop read query: the MATCH pattern, the
# fields and aggregates to RETURN, and optional ORDER BY / LIMIT.
# `limit=None` means no LIMIT clause.
class CypherQueryPlan(BaseModel):
    match: MatchPattern
    aggregations: List[Aggregation] = []
    return_fields: List[str]
    order_by: Optional[OrderBy] = None
    limit: Optional[int] = None


#### Rendering the Cypher

In [24]:
def render_cypher(plan: CypherQueryPlan) -> str:
    """
    Render a query plan into a single-line Cypher query string.

    The start node is bound to ``a`` and the end node to ``w``;
    ``return_fields``, ``order_by.field`` and aggregation variables are
    expected to refer to those names.

    Args:
        plan: Validated query plan.

    Returns:
        ``MATCH ... RETURN ...``, optionally followed by ``ORDER BY`` and
        ``LIMIT`` clauses.

    NOTE(review): ``return_fields``, aggregation aliases and ``order_by.field``
    are free-form strings interpolated verbatim (no escaping or validation).
    """
    m = plan.match

    # Honour the plan's direction rather than always emitting an outgoing arrow.
    if m.direction == "OUT":
        rel = f"-[:{m.relationship}]->"
    else:
        rel = f"<-[:{m.relationship}]-"

    match_clause = f"MATCH (a:{m.start_label}){rel}(w:{m.end_label})"

    returns = list(plan.return_fields)
    returns.extend(
        f"{agg.function.upper()}({agg.variable}) AS {agg.alias}"
        for agg in plan.aggregations
    )
    return_clause = "RETURN " + ", ".join(returns)

    order_clause = ""
    if plan.order_by is not None:
        order_clause = f" ORDER BY {plan.order_by.field} {plan.order_by.direction}"

    # Compare against None so an explicit limit of 0 still renders "LIMIT 0".
    limit_clause = f" LIMIT {plan.limit}" if plan.limit is not None else ""

    return f"{match_clause} {return_clause}{order_clause}{limit_clause}"


### Creating the Strands Agent Backed by Bedrock

In [25]:
# Creating Bedrock Model Instance.
# Pass the boto3 `session` built from the .env credentials/region so the model
# uses that account and region; without it BedrockModel creates its own
# default boto3 session and ignores those settings.
bedrock_model = BedrockModel(
    model_id='anthropic.claude-3-5-sonnet-20240620-v1:0',
    boto_session=session,
    temperature=0.0,
)

In [26]:
# Creating an agent using the BedrockModel instance.
# The system prompt spells out the graph schema (mirroring SCHEMA and
# VALID_RELATIONSHIPS; keep them in sync). This stops the model from guessing
# label, relationship or property names that the tool's validators reject.
# The forbidden-keyword list matches FORBIDDEN_KEYWORDS.
first_agent = Agent(
    model=bedrock_model,
    tools=[neo4j_query_tool],
    system_prompt="""
You are an assistant that can query a Neo4j database.

Graph schema (use ONLY these names):
- Nodes and their properties:
  - (:Author {id, name, display_name})
  - (:Work {id, title, type, publication_date})
  - (:Topic {id, display_name, score})
- Relationships:
  - (:Author)-[:WORK_AUTHORED_BY]->(:Work)
  - (:Work)-[:WORK_HAS_TOPIC]->(:Topic)

Rules:
- Call the tool `neo4j_query_tool` if data is required.
- Output ONLY the tool call.
- Do NOT include any other text, thoughts, explanations, or commentary
  in the same message as a tool call.
- If the tool result contains an "error", fix the query using its "message"
  and call the tool again.
- After the tool result is returned, you may explain the answer.

Safety:
- Only generate READ-ONLY Cypher.
- Never use CREATE, MERGE, DELETE, SET, REMOVE, DROP, CALL, LOAD CSV, or APOC.
""",
)

In [27]:
# Natural-language request for the agent, with a fallback: up to 10 works by
# Josh Adkins, topped up with works by David Moher if fewer than 10 exist.
question = """
    Return 10 Works authored by Josh Adkins. If there are not 10 Works, then return the remaining amount of works authored by David Moher
 """

In [28]:
# Run the agent on the question. Its text, tool calls and the Cypher printed by
# the tool appear in the cell output below.
response = first_agent(question)


To fulfill this request, we need to query the Neo4j database. I'll use the `neo4j_query_tool` to execute a Cypher query that will fetch the required information.
Tool #1: neo4j_query_tool
I apologize for the error in my previous query. It seems that the relationship `:HAS_AUTHOR` is not defined in the database. Let me modify the query to use the correct relationship type. I'll try using `:AUTHORED` instead, which is a common relationship type for connecting authors to their works.
Tool #2: neo4j_query_tool
MATCH (w:Work)<-[:WORK_AUTHORED_BY]-(a:Author)
WHERE a.name = 'Josh Adkins'
WITH w
LIMIT 10
RETURN w.title AS title, 'Josh Adkins' AS author
UNION
MATCH (w:Work)<-[:WORK_AUTHORED_BY]-(a:Author)
WHERE a.name = 'David Moher'
WITH w
LIMIT 10 - COUNT {
  MATCH (w:Work)<-[:WORK_AUTHORED_BY]-(a:Author)
  WHERE a.name = 'Josh Adkins'
  RETURN w
}
RETURN w.title AS title, 'David Moher' AS author
I apologize for the continued errors. It seems that the Neo4j version being used doesn't support 