In [None]:
import sys

# Put the parent directory on the import path (the `src` package imported in
# the next cell is presumably located there -- confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicates.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
from src.tool.execute_chyper import ExecuteCypherTool
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent

# Chat model reached through an OpenAI-style HTTP endpoint on localhost:1234.
model = ChatOpenAI(model="qwen3-8B", base_url="http://127.0.0.1:1234/v1", api_key="")

# The Cypher execution tool is the only tool handed to the agent.
tools = [ExecuteCypherTool()]

# Agent that is invoked by the cells below.
agent_executor = create_agent(model, tools=tools)

In [None]:
# Base conversation: a single system message that is prepended to every
# agent invocation in the cells below.
chat_history = [
    SystemMessage(content="""
        You are an expert Medical Fraud Detection AI Agent powered by a Knowledge Graph.
        Your goal is to validate insurance claims against medical rules.

        You have 1 tools to help you analyze the claim is Fraudlent or not.
        1. execute_cypher: Executes Cypher queries against the graph database

        Output Format for Bulk Validation: 
        List of claims with their fraud status.
        Fraud = [Claim_id1, Claim_id2, ...]
        Normal = [Claim_id3, Claim_id4, ...]
    """),
]

In [None]:
import re

def clean_llm_response(content: str) -> str:
    """Strip reasoning blocks, code fences and tag-like markup from an LLM reply.

    Args:
        content: Raw text of the model's reply.

    Returns:
        The reply with ``<think>...</think>`` sections removed, an enclosing
        Markdown code fence peeled off, every remaining ``<...>`` span deleted,
        and surrounding whitespace trimmed.
    """
    # 1. Remove <think>...</think> sections; DOTALL lets them span several lines.
    content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)

    # Trim BEFORE looking for fences. Removing a leading <think> block usually
    # leaves a newline in front of the opening fence, and models often emit a
    # newline after the closing fence; without this trim the startswith /
    # endswith checks below miss the fence and it leaks into the result.
    content = content.strip()

    # 2. Peel off a Markdown code fence wrapped around the payload.
    if content.startswith("```json"):
        content = content[len("```json"):]
    if content.startswith("```"):
        content = content[len("```"):]
    if content.endswith("```"):
        content = content[:-3]

    # 3. Remove any other XML-like tags.
    # NOTE: this deletes every "<...>" span, including placeholder-style text
    # such as "<FRAUD>", so callers must not rely on such text surviving.
    content = re.sub(r'<[^>]+>', '', content)

    return content.strip()

In [None]:
from langchain_core.callbacks import BaseCallbackHandler

# --- 2. Define the Custom Printer Class ---
class ToolExecutionPrinter(BaseCallbackHandler):
    """Callback handler that prints a log line when a tool starts and when it ends."""

    def on_tool_start(self, serialized, input_str, **kwargs):
        """Print the name of the tool being entered and the input it was given."""
        name_of_tool = serialized.get("name")
        print(
            f"\n[LOG] üõ†Ô∏è  Agent is entering tool: {name_of_tool}",
            f"[LOG]    Input args: {input_str}\n",
            sep="\n",
        )

    def on_tool_end(self, output, **kwargs):
        """Print a completion notice; the tool output itself is not printed."""
        print("[LOG] ‚úÖ Tool execution finished.\n")

In [None]:
# Input: ID of the hospital whose claims are validated in this run.
# The value is interpolated into the Cypher query built in the next cell and
# into the progress/report messages printed further down.
hospital_id = "H001"

In [None]:
# Cypher query to get all unvalidated claims by hospital ID.
# A claim counts as unvalidated when its status is NULL, the string 'NaN',
# or the empty string.
#
# hospital_id is spliced into a single-quoted Cypher string literal, so
# backslashes and single quotes are escaped first. For plain IDs such as
# "H001" the resulting query text is unchanged; for IDs containing a quote it
# stays a well-formed literal instead of terminating the string early.
_hospital_id_literal = str(hospital_id).replace("\\", "\\\\").replace("'", "\\'")
get_unvalidated_claims_by_hospital = f"""
MATCH (h:Hospital {{id: '{_hospital_id_literal}'}})<-[:SUBMITTED_AT]-(c:Claim)
WHERE c.status IS NULL OR c.status = 'NaN' OR c.status = ''
RETURN c.id as claim_id
"""

# Golden queries from original file.
# These are query TEMPLATES: tokens such as <claim_id>, <diagnosis_id>,
# <doctor_id> and <hospital_id> are placeholders, not valid Cypher. The
# templates are embedded verbatim in the validation prompt further down, so
# the agent is expected to fill the placeholders in before executing them.

# Full claim record: cost, status, patient/hospital/doctor/diagnosis names,
# de-duplicated primary and secondary procedure names, and the two diagnosis
# text fields of the clinical note. Every related node is an OPTIONAL MATCH,
# so missing relationships yield nulls / empty lists rather than no row.
golden_chyper_for_get_claim_data = """
MATCH (c:Claim {id: <claim_id>})
OPTIONAL MATCH (c)-[:HAS_PATIENT]->(patient:Patient)
OPTIONAL MATCH (c)-[:HAS_PRIMARY_PROCEDURE]->(pp:Procedure)
OPTIONAL MATCH (c)-[:HAS_SECONDARY_PROCEDURE]->(sp:Procedure)
OPTIONAL MATCH (c)-[:HAS_CLINICAL_NOTE]->(note:ClinicalNote)
OPTIONAL MATCH (c)-[:SUBMITTED_AT]->(hospital:Hospital)
OPTIONAL MATCH (c)-[:SUBMITTED_BY]->(doctor:Doctor)
OPTIONAL MATCH (c)-[:CODED_AS]->(diagnosis:Diagnosis)
RETURN c.id, c.total_cost, c.status,
       patient.name as patient_name,
       hospital.name as hospital_name,
       doctor.name as doctor_name,
       diagnosis.name as diagnosis_name,
       collect(DISTINCT pp.name) as primary_procedures,
       collect(DISTINCT sp.name) as secondary_procedures,
       note.primary_diagnosis_text,
       note.secondary_diagnosis_text
"""

# Cost reference for one claim: the claimed total next to the diagnosis
# avg_cost and the summed avg_cost of its primary and secondary procedures.
# Procedures are collected DISTINCT before summing, and COALESCE treats a
# missing avg_cost as 0.
golden_chyper_to_get_price_procedure_diagnose_based_on_claim_id = """
MATCH (c:Claim {id: <claim_id>})
OPTIONAL MATCH (c)-[:HAS_PRIMARY_PROCEDURE]->(pp:Procedure)
OPTIONAL MATCH (c)-[:HAS_SECONDARY_PROCEDURE]->(sp:Procedure)
OPTIONAL MATCH (c)-[:CODED_AS]->(diagnosis:Diagnosis)
WITH c, diagnosis, 
     collect(DISTINCT pp) as primary_procs,
     collect(DISTINCT sp) as secondary_procs
RETURN 
    c.id,
    c.total_cost,
    diagnosis.name,
    diagnosis.avg_cost as diagnosis_cost,
    reduce(total = 0, proc IN primary_procs | total + COALESCE(proc.avg_cost, 0)) as primary_procs_total,
    reduce(total = 0, proc IN secondary_procs | total + COALESCE(proc.avg_cost, 0)) as secondary_procs_total
"""

# Procedures linked to a diagnosis (looked up by its `code` property) through
# any outgoing relationship, with the relationship type and procedure avg_cost.
# The MATCH is not optional, so a diagnosis without such links returns no row.
golden_query_to_get_diagnose_and_procedure_relation = """
MATCH (d:Diagnosis {code: <diagnosis_id>})-[r]->(p:Procedure)
WITH d, count(r) AS procedure_count,
     collect({procedure: p.name, relationship: type(r), cost: p.avg_cost}) AS procedures
RETURN d.code AS ICD10_Code,
       d.name AS Diagnosis_Name,
       d.severity AS Severity,
       d.avg_cost AS Diagnosis_Cost,
       procedure_count AS Number_of_Procedures,
       procedures AS Associated_Procedures;
"""

# Id, name and specialization of a single doctor.
golden_query_get_specialisties_doctor = """
MATCH (d:Doctor {id: <doctor_id>})
RETURN d.id AS doctor_id,
       d.name AS doctor_name,
       d.specialization AS specialization;
"""

# Hospital profile: id, name, class, location and the de-duplicated names of
# its specialties and facilities (both optional, so they may be empty lists).
golden_query_get_specialties_and_facilities_hospital = """
MATCH (h:Hospital {id: <hospital_id>})
OPTIONAL MATCH (h)-[:HAS_SPECIALTY]->(s:Specialty)
OPTIONAL MATCH (h)-[:HAS_FACILITY]->(f:Facility)
RETURN h.id AS hospital_id,
       h.name AS hospital_name,
       h.class AS hospital_class,
       h.location AS location,
       collect(DISTINCT s.name) AS specialties,
       collect(DISTINCT f.name) AS facilities;
"""

In [None]:
# Step 1: Get all unvalidated claims for the hospital
print(f"Getting unvalidated claims for hospital: {hospital_id}")

# The listing query is handed to the agent as a user message; the agent is
# expected to run it through its Cypher tool and report the rows back.
claims_query_message = [
    HumanMessage(f"Execute this Cypher query to get all unvalidated claims for hospital {hospital_id}: {get_unvalidated_claims_by_hospital}")
]

claims_response = agent_executor.invoke(
    {"messages": [*chat_history, *claims_query_message]},
    config={"callbacks": [ToolExecutionPrinter()]},
)

# Only the agent's last message is used, after stripping think-blocks/fences.
claims_result = clean_llm_response(claims_response['messages'][-1].content)
print(f"Claims Query Result: {claims_result}")

In [None]:
# Parse claim IDs from the response (this might need adjustment based on actual response format)
import json
import re


def _extract_claim_ids(text):
    """Pull claim IDs out of the agent's textual reply.

    The reply is tried as JSON first (a list of IDs, a list of objects with a
    ``claim_id`` key, or a single such object). If it is not JSON, is malformed
    JSON, or yields no usable IDs, every ``C<digits>`` token in the text is
    taken instead.

    Args:
        text: Cleaned reply text from the agent.

    Returns:
        Claim IDs in first-seen order with duplicates removed, so that a claim
        mentioned twice in the reply is not validated and counted twice.
    """
    found = []

    if text.startswith(('[', '{')):
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Looked like JSON but is not valid; the regex scan below handles it.
            parsed = None

        if isinstance(parsed, dict):
            parsed = [parsed]
        if isinstance(parsed, list):
            for item in parsed:
                value = item.get('claim_id') if isinstance(item, dict) else item
                # Keep only scalar IDs: objects without a claim_id, nulls and
                # nested containers are not claim identifiers.
                if isinstance(value, (str, int)):
                    found.append(value)

    if not found:
        # Fallback: extract anything that looks like a claim ID.
        found = re.findall(r'C\d+', text)

    # dict.fromkeys de-duplicates while preserving order.
    return list(dict.fromkeys(found))


claim_ids = _extract_claim_ids(claims_result)

print(f"Found {len(claim_ids)} unvalidated claims: {claim_ids}")

In [None]:
# Step 2: Process each claim and collect results
normal_claims = []
fraud_claims = []
# Claims whose validation raised an exception. They are kept apart from
# normal_claims so that a failed run is never reported as a NORMAL verdict.
failed_claims = []

# Matches the requested final line "Status: FRAUD" / "Status: NORMAL",
# tolerating Markdown bold markers and an angle bracket before the word.
# The negative lookahead skips an echoed template "Status: <FRAUD/NORMAL>".
_VERDICT_PATTERN = re.compile(
    r'Status:\s*\**\s*<?\s*(FRAUD|NORMAL)\b(?!\s*/)', re.IGNORECASE
)


def _build_validation_messages(claim_id):
    """Return the two user messages that ask the agent to validate one claim."""
    return [
        HumanMessage(f"This is Claim ID: {claim_id}"),
        HumanMessage(f"""
        Please follow these steps to analyze the claim. Be objective and allow for reasonable operational variances.

        1. **Data Retrieval**: 
           Execute {golden_chyper_for_get_claim_data} to get all relevant data. 
           - If the status is already "NORMAL" or "FRAUD", return it directly with 100% confidence. (No further validation needed.)
           - If status is null/NaN, proceed to validation steps below.

        2. **Validation Logic (Apply these rules strictly in order)**:

           a. **Procedure Consistency**: 
              Check if the procedure is clinically appropriate for the diagnosis.
              - *Logic*: Use the relation from {golden_query_to_get_diagnose_and_procedure_relation}.
              - *Guidance*: If the procedure is a standard diagnostic tool for the diagnosis (e.g., MRI for Stroke), it is a MATCH.

           b. **Cost Analysis (The 20% Rule)**: 
              Compare the Claim's Total Cost vs. Ground Truth (Sum of Diagnosis Avg Cost + Procedure Avg Cost).
              - *Logic*: Calculate the deviation: `(Claim_Cost - Ground_Truth) / Ground_Truth`.
              - *Guidance*: 
                 - If deviation is **< 20%**: Consider this **NORMAL** operational variance (e.g., room upgrades, extra meds). Do NOT flag as fraud based on cost alone.
                 - If deviation is **> 20%**: Flag as **FRAUD** (Cost significantly inflated).
              - Ground Truth Query: {golden_chyper_to_get_price_procedure_diagnose_based_on_claim_id} (or calculate manually using averages).

           c. **Doctor Qualification (GP Exception)**: 
              Check if the doctor is qualified.
              - *Logic*: Use {golden_query_get_specialisties_doctor}.
              - *Guidance*: 
                 - **GPs (General Practitioners)** are VALID for initial diagnoses, consultations, and ordering standard scans (like MRI/CT), even for complex conditions like Stroke. 
                 - Flag as **FRAUD** only if there is a **hard contradiction** (e.g., a Pediatrician performing Major Surgery, or an Ophthalmologist treating Heart Attack).

           d. **Hospital Capability**: 
              Check if the hospital has relevant facilities.
              - *Logic*: Use {golden_query_get_specialties_and_facilities_hospital}.
              - *Guidance*: Look for broad keyword matches. For example, if Diagnosis is "Stroke", facilities like "ICU", "Neurology", or "Internal Medicine" are sufficient evidence of capability.

        3. **Final Verdict**:
           Based on the above, determine FRAUD or NORMAL.
           - Return ONLY: Claim ID: {claim_id} | Status: <FRAUD/NORMAL>
        """)
    ]


def _parse_verdict(raw_content):
    """Return "FRAUD" or "NORMAL" for one agent reply.

    The explicit "Status: ..." line is authoritative; when the reply contains
    several, the last one is used. It is read from the reply with only the
    <think> blocks removed, because clean_llm_response deletes every "<...>"
    span and would erase a verdict written as "<FRAUD>".

    Only when no status line is present does this fall back to checking whether
    the cleaned reply mentions "FRAUD" anywhere. That substring check alone
    would label replies such as "NORMAL - no fraud indicators" as fraud.
    """
    visible_text = re.sub(r'<think>.*?</think>', '', raw_content, flags=re.DOTALL)
    stated = _VERDICT_PATTERN.findall(visible_text)
    if stated:
        return stated[-1].upper()

    final_output = clean_llm_response(raw_content)
    return "FRAUD" if "FRAUD" in final_output.upper() else "NORMAL"


for claim_id in claim_ids:
    print(f"\nProcessing claim: {claim_id}")

    # Process this claim
    try:
        response = agent_executor.invoke({
            "messages": chat_history + _build_validation_messages(claim_id),
        },
        config={
            "callbacks": [ToolExecutionPrinter()]
        })

        verdict = _parse_verdict(response['messages'][-1].content)
    except Exception as e:
        # Best effort: one failing claim must not abort the whole batch, but it
        # must not be counted as NORMAL either.
        print(f"  ‚Üí Error processing {claim_id}: {str(e)}")
        failed_claims.append(claim_id)
        continue

    if verdict == "FRAUD":
        fraud_claims.append(claim_id)
        print(f"  ‚Üí {claim_id}: FRAUD")
    else:
        normal_claims.append(claim_id)
        print(f"  ‚Üí {claim_id}: NORMAL")

if failed_claims:
    print(f"\n[WARN] {len(failed_claims)} claim(s) could not be validated and were left unclassified: {failed_claims}")

In [None]:
# Step 3: Output Results
# Header framed by two 50-character rules.
separator = "=" * 50
print("\n" + separator)
print(f"BULK VALIDATION RESULTS FOR HOSPITAL: {hospital_id}")
print(separator)

# Each bucket is printed as a count line followed by the raw list of IDs.
print(f"\nNormal Claims ({len(normal_claims)}):")
print(normal_claims)

print(f"\nFraud Claims ({len(fraud_claims)}):")
print(fraud_claims)

# Total is the size of both verdict buckets combined.
print(f"\nTotal Claims Processed: {len(normal_claims) + len(fraud_claims)}")