In [None]:
import os
from typing import Optional
from dotenv import load_dotenv
from pydantic import BaseModel, Field

load_dotenv()


True

In [None]:
from langchain_core.tools import tool
from langchain_neo4j import Neo4jGraph
from langchain.output_parsers import OutputFixingParser
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser


In [5]:
graph = Neo4jGraph(refresh_schema=False, url="bolt://localhost:7687", username="neo4j", password="fraud_detection")


In [None]:
r = graph.query("CALL db.labels()");
print(r);


[{'label': 'Transaction'}, {'label': 'Account'}, {'label': 'Customer'}, {'label': 'Country'}]


In [19]:
r = graph.query("CALL db.relationshipTypes()");
print(r);


[{'relationshipType': 'OWNED_BY'}, {'relationshipType': 'SENT_FROM'}, {'relationshipType': 'RECEIVED_BY'}, {'relationshipType': 'LOCATED_IN'}]


In [20]:
r = graph.query("CALL db.schema.visualization()");
print(r);


[{'nodes': [{'name': 'Account', 'indexes': [], 'constraints': []}, {'name': 'Transaction', 'indexes': [], 'constraints': []}, {'name': 'Customer', 'indexes': [], 'constraints': []}, {'name': 'Country', 'indexes': [], 'constraints': []}], 'relationships': [({'name': 'Account', 'indexes': [], 'constraints': []}, 'OWNED_BY', {'name': 'Customer', 'indexes': [], 'constraints': []}), ({'name': 'Account', 'indexes': [], 'constraints': []}, 'LOCATED_IN', {'name': 'Country', 'indexes': [], 'constraints': []}), ({'name': 'Transaction', 'indexes': [], 'constraints': []}, 'SENT_FROM', {'name': 'Account', 'indexes': [], 'constraints': []}), ({'name': 'Transaction', 'indexes': [], 'constraints': []}, 'RECEIVED_BY', {'name': 'Account', 'indexes': [], 'constraints': []})]}]


In [3]:
llm = ChatGoogleGenerativeAI(
    api_key=os.getenv("GOOGLE_API_KEY"),
    model="gemini-1.5-flash",
)


In [10]:
@tool
async def cql_executor(cql: str)->str:
    """Takes CQL query as a parameter and Returns CQL execution results"""
    try:
        cql_result = graph.query(f"""{cql}""")
        return f"Query Results: {cql_result}"
    except Exception as e:
        return f"❌ Failed execution. Reason {str(e)}"


In [None]:
tools_list = [cql_executor]


In [31]:
tools = {}
for tl in tools_list:
    tools[tl.name] = {
        "description": tl.description,
        "parameters": tl.args,
        "fn": tl
    }
print(tools)


{'cql_executor': {'description': 'Takes CQL query as a parameter and Returns CQL execution results', 'parameters': {'cql': {'title': 'Cql', 'type': 'string'}}, 'fn': StructuredTool(name='cql_executor', description='Takes CQL query as a parameter and Returns CQL execution results', args_schema=<class 'langchain_core.utils.pydantic.cql_executor'>, coroutine=<function cql_executor at 0x0000029FCD6B2160>)}}


In [32]:
def escape_braces(obj):
    string_repr = str(obj)
    escaped = string_repr.replace('{', '{{').replace('}', '}}')
    return escaped


In [None]:
tools_definition = escape_braces(tools)
# tools_definition


In [35]:
class Response(BaseModel):
    step: str = Field(description="Represents the current step of a flow")
    content: str = Field(description="Represents the content according to the step. It can be empty as well based on step.", default="")
    tool_name: Optional[str] = Field(description="Represents the tool name that will be used."),
    tool_args: Optional[str] = Field(description="Represents the arguments that will be passed to the tool according to its parameters."),

parser = PydanticOutputParser(pydantic_object=Response)


In [None]:
prompt = ChatPromptTemplate.from_messages([
    ("system", """
        You are a Fraud Detection Expert whose work is find the fraud transactions from the day to day transaction.
        
        # Context
        You works like Honest Income Tax officer who wants to maintain the economy of its country.
        You works in a modern online way. You have your graph database that stores the transaction details.
        You are expert in writing cypher query language.
        You use Cypher Query Language(CQL) to interact with your graph database.
        Now your task is to use cql queries in a way that can help you to detect the vulnerable/fraud transactions.
        For e.g. = Detect when money goes from A → B → C → A
        
        # Instructions
        You works in a flow think -> plan -> action -> observe -> result
        You first think about the given data, then you plan how to make clever cql to find vulnerabilities, then you take action according to your plan, after completing your action you observe that, and then finally give the result according to your observation.
        You can do again planning if your observation doesn't show some satisfactory output, then again action after planning, and then again observation. i.e., plan -> action -> observe -> (plan(again, if needed) -> action(again, if needed)) | result.
        If there are no vulnerabilites according to your observation(s), then you can say in result that there are no fraud transaction yet.
        
        # Tools
        You use your tools whenever needed. Here is your tools list:
        {tools}
        
        # Rules
        You take only one step at a time.
        You perform only one step at a time of a flow.
        Your flow must always end at 'result' step.
        You can use your tool in-between any step of a flow.
        
        # Output Format
        Here are your output format instructions:
        {format_instructions}
        
        # Example
        - Input: Start to find fraud transactions from your database.
        - Output: {{ 'step': 'think', 'content': 'Ok, now I have to find fraud detections from me the stored transactions. So, let's first gather some information about the stored data like what entities are there, what are the relationships between them. To get this information, I will use CQL, i.e. 'CALL db.schema.visualization();'. Now I have to execute this, to execute this I will use my tool cql_executor.' }}
        - Output: {{ 'step': 'tool', 'content': '', 'tool_name': 'cql_executor', 'tool_args': 'CALL db.schema.visualization();' }}
        - Output: {{ 'step': 'tool_res', 'content': '[{{\\n  'nodes': [\\n    {{'id': 0, 'labels': ['Account']}},\\n    {{'id': 1, 'labels': ['Transaction']}}\\n  ],\\n  'relationships': [\\n    {{\\n      'id': 0,\\n      'type': 'SENT',\\n      'startNode': 0,\\n      'endNode': 1\\n    }},\\n    {{\\n      'id': 1,\\n      'type': 'RECEIVED',\\n      'startNode': 0,\\n      'endNode': 1\\n    }}\\n  ]\\n}}]' }}
        - Output: {{ 'step': 'plan', 'content': 'Now as I can see, there are two entites in the db Account and the transaction, and there are two relationships between them sent and the recieved one. Ok, so now I know the entites and relationships between them, I will find a way to find fraudlent transactions, so first let's try to find transactions that were transferring money to themselves by using some other accounts. }}
        - Output: {{ 'step': 'action', 'content': '' }}
    """),
    MessagesPlaceholder("history"),
])


In [37]:
new_parser = OutputFixingParser.from_llm(llm=llm, parser=JsonOutputParser())


In [38]:
chain = prompt | llm | new_parser


In [39]:
history = []


In [None]:
for i in range(0,2):
    results = chain.invoke(input={
        "history": history,
        "tools": tools_definition,
        "format_instructions": parser.get_format_instructions()
    })

    print(results)
    print("*"*100);
    
    if results == None:
        break;
    
    if results.get("step") == "result":
        break;
    
    history.append(("ai", escape_braces(results)));
    
    if results.get("step") == "tool":
        tool_name = results.get("tool_name")
        tool_args = results.get("tool_args")
        if tool_name in tools:
            tool_res = await tools.get(tool_name).get("fn").ainvoke(tool_args)
            tool_res = {"step": "tool_res", "content": tool_res}
            print(tool_res);
            print("*"*100);
            history.append(("ai", escape_braces(tool_res)))

