In [1]:
import boto3
from strands import Agent, tool
from strands.models import BedrockModel
from strands_tools import current_time
from pydantic import BaseModel
from dotenv import load_dotenv
import os

In [2]:
from neo4j import GraphDatabase
from typing import List, Dict, Any, Literal, Optional

In [3]:
# Load variables from a local .env file into the process environment so the
# os.getenv(...) lookups in the following cells can read them.
load_dotenv()

True

#### Creating Database Client

In [4]:
class Neo4jClient:
    def __init__(self, uri: str, auth : tuple, database : str):
        self.driver = GraphDatabase.driver(
            uri= uri, auth=auth)
        self.database = database

    def close(self):
        self.driver.close()

    def run_cypher(self, query: str, params: Dict[str, Any] | None = None) -> List[Dict]:
        with self.driver.session() as session:
            result = session.run(query, params or {})
            return [record.data() for record in result]

In [5]:
### Adding global vars for connectivity
# All values come from the environment (populated by load_dotenv() above);
# each is None when the corresponding variable is not set.
DB_URI = os.getenv('DB_URI')
DB_USER = os.getenv('DB_USER')
DB_PASSWORD = os.getenv('DB_PASSWORD')
TARGET_DB = os.getenv('TARGET_DB')
# Fix: the auth tuple is (user, password). It was previously built from
# (DB_URI, TARGET_DB), which are the connection URI and the database name.
AUTH = (DB_USER, DB_PASSWORD)

#### Creating connection to AWS services: Bedrock

In [6]:
# Build a boto3 session from credentials and region held in the environment,
# then derive a Bedrock runtime client from it.
# NOTE(review): the variable names are read in lower case exactly as written
# here, so the .env file must define them in that form — confirm.
session = boto3.Session(
    region_name=os.getenv('region_name'),
    aws_access_key_id=os.getenv("aws_access_key_id"),
    aws_secret_access_key=os.getenv("aws_secret_access_key"),
)
bedrock_client = session.client(service_name='bedrock-runtime')

#### Creating Cypher Execution tool

In [7]:
def neo4j_cypher_tool(
    cypher: str,
    params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
    """
    Execute a Cypher query against Neo4j and return results.
    """
    client = Neo4jClient(
        uri=DB_URI,
        auth=(DB_USER, DB_PASSWORD),
        database=TARGET_DB
    )

    try:
        print('sucessfully connected to database')
        records = client.run_cypher(cypher, params)
        return {
            "row_count": len(records),
            "records": records
        }
    finally:
        client.close()


#### Defining Tool for Strands

In [8]:
# Cypher keywords (and the APOC namespace) that must not appear in a query the
# agent is allowed to run.
FORBIDDEN_KEYWORDS = {
    "CREATE", "MERGE", "DELETE", "SET", "DROP",
    "REMOVE", "CALL", "LOAD CSV", "APOC"
}

def assert_read_only(cypher: str):
    """Reject a query that contains any forbidden keyword.

    Keywords are matched case-insensitively as whole words, so identifiers
    that merely contain one (``dataset``, ``created_at``, ``recall``) are no
    longer rejected. Multi-word keywords such as ``LOAD CSV`` match with any
    amount of whitespace between the words.

    The check is still purely textual and deliberately conservative: a
    keyword inside a string literal or used as a property name is rejected.

    Args:
        cypher: Cypher text to inspect.

    Raises:
        ValueError: If a forbidden keyword is present. When several are
            present the alphabetically first one is reported.
    """
    import re  # local import, as validate_labels does in this notebook

    # sorted() makes the reported keyword deterministic (set order is not).
    for kw in sorted(FORBIDDEN_KEYWORDS):
        # Fix: whole-word match instead of a raw substring test.
        pattern = r"\b" + r"\s+".join(re.escape(part) for part in kw.split()) + r"\b"
        if re.search(pattern, cypher, flags=re.IGNORECASE):
            raise ValueError(f"Forbidden Cypher keyword: {kw}")


In [9]:
# Relationship-type aliases mapped to the canonical relationship type.
RELATIONSHIP_CANONICAL = {
    "WROTE": "WORK_AUTHORED_BY",
    "AUTHORED": "WORK_AUTHORED_BY",
    "AUTHORED_BY": "WORK_AUTHORED_BY"
}

def normalize_relationships(cypher: str) -> str:
    """Rewrite aliased relationship types to their canonical name.

    Every ``:<alias>`` whose alias is a key of RELATIONSHIP_CANONICAL is
    replaced by ``:<canonical>``. Matching is case-sensitive and only applies
    to complete identifiers, so ``:WROTE_DRAFT`` is left untouched.

    Args:
        cypher: Cypher text to normalise.

    Returns:
        The query with aliases replaced; text without aliases is returned
        unchanged.
    """
    import re  # local import, as validate_labels does in this notebook

    if not RELATIONSHIP_CANONICAL:
        return cypher

    # Fix: a single regex pass instead of chained str.replace calls. The old
    # code replaced ":AUTHORED" first, turning ":AUTHORED_BY" into
    # ":WORK_AUTHORED_BY_BY". Longest alias first plus a trailing word
    # boundary guarantees each identifier is matched as a whole.
    aliases = sorted(RELATIONSHIP_CANONICAL, key=len, reverse=True)
    pattern = re.compile(r":(" + "|".join(re.escape(a) for a in aliases) + r")\b")
    return pattern.sub(lambda m: ":" + RELATIONSHIP_CANONICAL[m.group(1)], cypher)


In [10]:
# Node labels a query is allowed to mention.
VALID_LABELS = {"Author", "Work"}

def validate_labels(cypher: str):
    """Check every ``:Identifier`` token in the query against the allow-lists.

    An identifier passes when it is a member of VALID_LABELS or is one of the
    canonical relationship names (the values of RELATIONSHIP_CANONICAL).

    Raises:
        ValueError: For the first identifier, in query order, that is in
            neither allow-list.
    """
    import re

    for found in re.finditer(r":([A-Za-z_][A-Za-z0-9_]*)", cypher):
        name = found.group(1)
        if name in VALID_LABELS or name in RELATIONSHIP_CANONICAL.values():
            continue
        raise ValueError(f"Unknown label or relationship: {name}")


In [11]:
def prepare_cypher(cypher: str) -> str:
    """Validate and normalise a query before it is executed.

    Steps, in order: the read-only check on the text as given, relationship
    alias normalisation, then label validation on the normalised text.

    Returns:
        The normalised query. Any exception raised by a step propagates.
    """
    assert_read_only(cypher)
    normalized = normalize_relationships(cypher)
    validate_labels(normalized)
    return normalized

In [12]:
@tool(
    name="neo4j_query_tool",
    description="Execute a READ-ONLY Cypher query against Neo4j"
)
def neo4j_query_tool(cypher: str, **kwargs) -> dict:
    # Agent-facing tool: validate/normalise the query, run it, return the rows.
    # **kwargs is accepted but not read anywhere in this function.

    # Validation failures are returned as a payload rather than raised, so the
    # caller receives the message together with the query it submitted.
    try:
        validated = prepare_cypher(cypher)
    except Exception as exc:
        return dict(
            error="cypher_validation_error",
            message=str(exc),
            original_cypher=cypher,
        )

    # One short-lived client per call; it is closed even if the query raises.
    db = Neo4jClient(uri=DB_URI, auth=(DB_USER, DB_PASSWORD), database=TARGET_DB)
    try:
        rows = db.run_cypher(validated)
    finally:
        db.close()

    return {"row_count": len(rows), "records": rows}


#### Creating Valid Contracts for the Strands Model to Follow When Building Cypher Queries

In [13]:
# Single-value Literal aliases: each part of a query pattern is pinned to the
# one value listed here. They are used as the field types of MatchPattern.
AuthorLabel = Literal["Author"]
RelationshipType = Literal["WORK_AUTHORED_BY"]
WorkLabel = Literal["Work"]
Direction = Literal["OUT"]

In [14]:
# Allowed node labels (same members as VALID_LABELS defined earlier).
VALID_NODE_LABELS = {"Author", "Work"}
# Relationship type -> the labels at each end and the direction.
# NOTE(review): neither constant is read by any function in the visible
# cells; presumably kept as a reference description of the schema — confirm.
VALID_RELATIONSHIPS = {
    "WORK_AUTHORED_BY": {
        "from": "Author",
        "to": "Work",
        "direction": "OUT"
    }
}


In [15]:
# The graph pattern of a query plan: start label, relationship type, end label.
# Every field is typed with a single-value Literal, so only the
# Author / WORK_AUTHORED_BY / Work pattern validates.
class MatchPattern(BaseModel):
    start_label: AuthorLabel
    relationship: RelationshipType
    end_label: WorkLabel
    direction: Direction = "OUT"  # "OUT" is the only permitted value and the default


In [16]:
# One aggregate expression; render_cypher emits it as
# FUNCTION(variable) AS alias.
class Aggregation(BaseModel):
    function: Literal["count"]  # only "count" is permitted
    variable: Literal["w"]  # only "w" is permitted (render_cypher binds the end node as w)
    alias: str  # free text; interpolated into the query as written


In [17]:
# Sort specification for a query plan.
class OrderBy(BaseModel):
    field: str  # free text; interpolated into ORDER BY as written
    direction: Literal["ASC", "DESC"] = "DESC"  # descending unless stated otherwise

In [18]:
# A single comparison: field, operator, value.
# NOTE(review): Filter is not referenced by CypherQueryPlan or render_cypher in
# the visible cells, so plans currently cannot carry a filter — confirm intent.
class Filter(BaseModel):
    field: str
    op: Literal["=", ">", "<"]
    value: str | int


In [19]:
# Adding full query plan
# Structured description of one query; render_cypher turns it into Cypher text.
class CypherQueryPlan(BaseModel):
    match: MatchPattern  # the single MATCH pattern
    aggregations: List[Aggregation] = []  # appended to the RETURN list after return_fields
    return_fields: List[str]  # RETURN expressions, emitted as written
    order_by: Optional[OrderBy] = None  # no ORDER BY when None
    limit: Optional[int] = None  # no LIMIT when None


#### Rendering the Cypher

In [20]:
def render_cypher(plan: CypherQueryPlan) -> str:
    """Render a query plan as a Cypher string.

    The output has the shape
    ``MATCH (a:<start>)-[:<rel>]->(w:<end>) RETURN <items>[ ORDER BY ...][ LIMIT n]``.
    The start node is always bound to ``a`` and the end node to ``w``; the
    arrow is always drawn left to right (``plan.match.direction`` is not
    consulted).

    Args:
        plan: The plan to render.

    Returns:
        The Cypher text.

    Raises:
        ValueError: If the plan has neither return fields nor aggregations,
            because a bare ``RETURN`` is not a valid query.

    Note:
        ``return_fields``, aggregation aliases and ``order_by.field`` are
        interpolated verbatim; this function does not validate or escape them.
    """
    pattern = plan.match
    match_clause = (
        f"MATCH (a:{pattern.start_label})"
        f"-[:{pattern.relationship}]->"
        f"(w:{pattern.end_label})"
    )

    # Plain fields first, then aggregates, preserving the order given in the plan.
    returns = list(plan.return_fields)
    returns.extend(
        f"{agg.function.upper()}({agg.variable}) AS {agg.alias}"
        for agg in plan.aggregations
    )
    # Fix: fail loudly instead of emitting an invalid "RETURN " with no items.
    if not returns:
        raise ValueError("Query plan must have at least one return field or aggregation")
    return_clause = "RETURN " + ", ".join(returns)

    order_clause = ""
    if plan.order_by:
        order_clause = f" ORDER BY {plan.order_by.field} {plan.order_by.direction}"

    # Fix: compare against None so an explicit limit of 0 is rendered; the
    # previous truthiness test silently dropped it and returned every row.
    limit_clause = f" LIMIT {plan.limit}" if plan.limit is not None else ""

    return f"{match_clause} {return_clause}{order_clause}{limit_clause}"


### Creating the Strands Agent Backed by Bedrock

In [32]:
# Creating Bedrock Model Instance
# Model id and temperature are fixed here; the agent below is built on this model.
# NOTE(review): the boto3 `session` / `bedrock_client` created in the earlier
# cell are not passed in, so presumably BedrockModel resolves credentials and
# region on its own — confirm this is intended.
bedrock_model = BedrockModel(
    model_id='amazon.nova-lite-v1:0',
    temperature=0.0
)

In [37]:
# Creating an agent using BedrockModel instance
#
# The Schema and Safety sections of the prompt are rendered from the same
# constants the tool's validators use (VALID_LABELS, RELATIONSHIP_CANONICAL,
# FORBIDDEN_KEYWORDS). Previously the prompt listed only some of the forbidden
# keywords (REMOVE and LOAD CSV were missing) and never named the labels or
# relationship types that neo4j_query_tool accepts, so the model could only
# learn them from rejected tool calls.
_allowed_labels = ", ".join(sorted(VALID_LABELS))
_allowed_relationships = ", ".join(sorted(set(RELATIONSHIP_CANONICAL.values())))
_forbidden_keywords = ", ".join(sorted(FORBIDDEN_KEYWORDS))

first_agent = Agent(model=bedrock_model,
                    tools=[neo4j_query_tool],
                    system_prompt=f"""
                    You are an assistant that can query a Neo4j database.

Rules:
- Call the tool `neo4j_query_tool` if data is required.
- Output ONLY the tool call.
- Do NOT include any other text, thoughts, explanations, or commentary
  in the same message as a tool call.
- After the tool result is returned, you may explain the answer.

Schema:
- Node labels: {_allowed_labels}
- Relationship types: {_allowed_relationships}
- Queries that use any other label or relationship type are rejected by the tool.

Safety:
- Only generate READ-ONLY Cypher.
- Never use any of: {_forbidden_keywords}.

                    """,
)

In [42]:
# Natural-language question that is sent to the agent in the next cell.
question = """
    Return 5 Authors who are cited more than 5 times.
 """

In [43]:
# Invoke the agent with the question; the returned value is kept in `response`.
response = first_agent(question)


<thinking>The previous query did not return any authors who are cited more than 5 times. I need to execute a new query to return 5 authors who are cited more than 5 times. This query will filter authors by the 'cited_by_count' property.</thinking> 
Tool #3: neo4j_query_tool
Here are 5 authors who are cited more than 5 times:

1. **Larsson Karl-Henrik**
   - **Cited by Count**: 13
   - **Works Count**: 3,201,122

2. **Geoscience Australia**
   - **Cited by Count**: 101
   - **Works Count**: 3,169,852

3. **Sarah Leichty**
   - **Cited by Count**: 33
   - **Works Count**: 102,977

4. **Joshua Adkins**
   - **Cited by Count**: 16,511
   - **Works Count**: 99,448

5. **George M Garrity**
   - **Cited by Count**: 48,178
   - **Works Count**: 96,201