In [1]:
import sys

sys.path.append('..')

In [2]:
from src.tool.execute_chyper import ExecuteCypherTool
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent

In [3]:
model = ChatOpenAI(
    model="qwen3-8B", 
    base_url="http://127.0.0.1:1234/v1", 
    api_key=""
)

tools = [ExecuteCypherTool()]

agent_executor = create_agent(model, tools)

In [4]:
chat_history = [
    SystemMessage(
    """
        You are an expert Medical Fraud Detection AI Agent powered by a Knowledge Graph.
        Your goal is to validate insurance claim form data against medical rules.

        You have 1 tool to help you analyze if the claim form data is Fraudulent or not.
        1. execute_cypher: Executes Cypher queries against the graph database

        Note: You are validating NEW form data that has not been processed before.
        There is no existing status to check - you must perform full validation.

        Output Format: 
        Form Data Summary: <summary of input data>
        Validation Result: <FRAUD/NORMAL>
        Confidence Score: <0-100%>
        Detail Analysis: <detailed analysis of validation>
        Explanation: <detailed explanation of the validation>
    """,
    ),
]

In [5]:
import re

def clean_llm_response(content: str) -> str:
    """Clean LLM response by removing thinking tags and unwanted content."""
    
    # 1. Remove <think>...</think> tags and their content
    content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)
    
    # 2. Remove markdown code blocks
    # Note: It's safer to use replace or regex than hard slicing indices
    if content.startswith("```json"):
        content = content.replace("```json", "", 1)
    if content.startswith("```"):
        content = content.replace("```", "", 1)
    if content.endswith("```"):
        content = content[:-3]
        
    # 3. Remove any other XML-like tags
    content = re.sub(r'<[^>]+>', '', content)
    
    return content.strip()

In [6]:
from langchain_core.callbacks import BaseCallbackHandler

# --- 2. Define the Custom Printer Class ---
class ToolExecutionPrinter(BaseCallbackHandler):
    def on_tool_start(self, serialized, input_str, **kwargs):
        """Run when a tool starts running."""
        tool_name = serialized.get("name")
        print(f"\n[LOG] üõ†Ô∏è  Agent is entering tool: {tool_name}")
        print(f"[LOG]    Input args: {input_str}\n")

    def on_tool_end(self, output, **kwargs):
        """Run when a tool ends running."""
        print(f"[LOG] ‚úÖ Tool execution finished.\n")

In [7]:
# Form input data (replace with actual form data)
form_input = {
    "hospital_id": "H001",
    "doctor_id": "D001", 
    "diagnosa_id": "I63",
    "total_cost": 15000000,
    "primary_procedure": "CT Scan",
    "secondary_procedure": "MRI",
    "diagnosis_text": "Cerebral infarction"
}

In [8]:
# Cypher queries for validation (adapted for form data)

golden_query_to_get_diagnose_and_procedure_relation = """
MATCH (d:Diagnosis {code: '<diagnosis_id>'})-[r]->(p:Procedure)
WITH d, count(r) AS procedure_count,
     collect({procedure: p.name, relationship: type(r), cost: p.avg_cost}) AS procedures
RETURN d.code AS ICD10_Code,
       d.name AS Diagnosis_Name,
       d.severity AS Severity,
       d.avg_cost AS Diagnosis_Cost,
       procedure_count AS Number_of_Procedures,
       procedures AS Associated_Procedures;
"""

golden_query_get_specialisties_doctor = """
MATCH (d:Doctor {id: '<doctor_id>'})
RETURN d.id AS doctor_id,
       d.name AS doctor_name,
       d.specialization AS specialization;
"""

golden_query_get_specialties_and_facilities_hospital = """
MATCH (h:Hospital {id: '<hospital_id>'})
OPTIONAL MATCH (h)-[:HAS_SPECIALTY]->(s:Specialty)
OPTIONAL MATCH (h)-[:HAS_FACILITY]->(f:Facility)
RETURN h.id AS hospital_id,
       h.name AS hospital_name,
       h.class AS hospital_class,
       h.location AS location,
       collect(DISTINCT s.name) AS specialties,
       collect(DISTINCT f.name) AS facilities;
"""

golden_query_get_procedure_costs = """
MATCH (p:Procedure)
WHERE p.name IN ['<primary_procedure>', '<secondary_procedure>']
RETURN p.name AS procedure_name,
       p.avg_cost AS avg_cost
"""

golden_query_get_diagnosis_cost = """
MATCH (d:Diagnosis {code: '<diagnosis_id>'})
RETURN d.code AS diagnosis_code,
       d.name AS diagnosis_name,
       d.avg_cost AS avg_cost
"""

In [9]:
# Format form input for display
form_summary = f"""
Hospital ID: {form_input['hospital_id']}
Doctor ID: {form_input['doctor_id']}
Diagnosis ID: {form_input['diagnosa_id']}
Total Cost: {form_input['total_cost']:,}
Primary Procedure: {form_input['primary_procedure']}
Secondary Procedure: {form_input['secondary_procedure']}
Diagnosis Text: {form_input['diagnosis_text']}
"""

In [10]:
message = [
    HumanMessage(f"This is a medical claim form with the following data: {form_summary}"),
    HumanMessage(f"""
    Please validate this claim form data following these steps. Be objective and allow for reasonable operational variances.

    **Form Data Details:**
    - Hospital ID: {form_input['hospital_id']}
    - Doctor ID: {form_input['doctor_id']}
    - Diagnosis ID: {form_input['diagnosa_id']}
    - Total Cost: {form_input['total_cost']:,}
    - Primary Procedure: {form_input['primary_procedure']}
    - Secondary Procedure: {form_input['secondary_procedure']}
    - Diagnosis Text: {form_input['diagnosis_text']}

    **Validation Steps (Apply these rules strictly in order)**:

    1. **Procedure Consistency**: 
       Check if the procedures are clinically appropriate for the diagnosis.
       - *Logic*: Use the relation from {golden_query_to_get_diagnose_and_procedure_relation} (replace '<diagnosis_id>' with '{form_input['diagnosa_id']}').
       - *Guidance*: If the procedures are standard diagnostic tools for the diagnosis (e.g., CT Scan/MRI for Stroke), it is a MATCH.

    2. **Cost Analysis (The 20% Rule)**: 
       Compare the Form's Total Cost vs. Ground Truth (Sum of Diagnosis Avg Cost + Procedure Avg Costs).
       - *Logic*: Calculate the deviation: `(Form_Cost - Ground_Truth) / Ground_Truth`.
       - *Guidance*: 
          - If deviation is **< 20%**: Consider this **NORMAL** operational variance (e.g., room upgrades, extra meds). Do NOT flag as fraud based on cost alone.
          - If deviation is **> 20%**: Flag as **FRAUD** (Cost significantly inflated).
       - Use these queries to get ground truth:
         * Diagnosis cost: {golden_query_get_diagnosis_cost} (replace '<diagnosis_id>' with '{form_input['diagnosa_id']}')
         * Procedure costs: {golden_query_get_procedure_costs} (replace '<primary_procedure>' with '{form_input['primary_procedure']}' and '<secondary_procedure>' with '{form_input['secondary_procedure']}')

    3. **Doctor Qualification (GP Exception)**: 
       Check if the doctor is qualified.
       - *Logic*: Use {golden_query_get_specialisties_doctor} (replace '<doctor_id>' with '{form_input['doctor_id']}').
       - *Guidance*: 
          - **GPs (General Practitioners)** are VALID for initial diagnoses, consultations, and ordering standard scans (like MRI/CT), even for complex conditions like Stroke. 
          - Flag as **FRAUD** only if there is a **hard contradiction** (e.g., a Pediatrician performing Major Surgery, or an Ophthalmologist treating Heart Attack).

    4. **Hospital Capability**: 
       Check if the hospital has relevant facilities.
       - *Logic*: Use {golden_query_get_specialties_and_facilities_hospital} (replace '<hospital_id>' with '{form_input['hospital_id']}').
       - *Guidance*: Look for broad keyword matches. For example, if Diagnosis is "Stroke", facilities like "ICU", "Neurology", or "Internal Medicine" are sufficient evidence of capability.

    5. **Final Verdict**:
       Based on the above, determine FRAUD or NORMAL.
       - Provide a confidence score (0-100%).
       - Provide the Form Data Summary.
       - **Explanation**: You MUST explicitly state the cost deviation percentage in your explanation (e.g., "Cost is 4.5% higher, which is within the acceptable 20% variance").
    """)
]

In [11]:
final_messages = chat_history + message

response = agent_executor.invoke({
    "messages": final_messages,
},
config={
        "callbacks": [ToolExecutionPrinter()]
    }
)

raw_content = response['messages'][-1].content

# 3. Apply the Cleaning Function
final_output = clean_llm_response(raw_content)

# 4. Print the Cleaned Response
print(f"Response: {final_output}")


[LOG] üõ†Ô∏è  Agent is entering tool: execute_cypher
[LOG]    Input args: {'cypher_query': "MATCH (d:Diagnosis {code: 'I63'})-[r]->(p:Procedure) WITH d, count(r) AS procedure_count, collect({procedure: p.name, relationship: type(r), cost: p.avg_cost}) AS procedures RETURN d.code AS ICD10_Code, d.name AS Diagnosis_Name, d.severity AS Severity, d.avg_cost AS Diagnosis_Cost, procedure_count AS Number_of_Procedures, procedures AS Associated_Procedures;"}

[EXECUTE_CYPHER] Executing query: MATCH (d:Diagnosis {code: 'I63'})-[r]->(p:Procedure) WITH d, count(r) AS procedure_count, collect({procedure: p.name, relationship: type(r), cost: p.avg_cost}) AS procedures RETURN d.code AS ICD10_Code, d.name AS Diagnosis_Name, d.severity AS Severity, d.avg_cost AS Diagnosis_Cost, procedure_count AS Number_of_Procedures, procedures AS Associated_Procedures;

[LOG] üõ†Ô∏è  Agent is entering tool: execute_cypher
[LOG]    Input args: {'cypher_query': "MATCH (d:Diagnosis {code: 'I63'}) RETURN d.code AS di