<a href="https://colab.research.google.com/github/ar7emiy/GoogleAICampProjects/blob/main/GoogleAI_TPA_AI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install json_repair langchain_google_genai langgraph

In [None]:
from typing import Annotated, Literal, Optional, List, Dict, Any
from typing_extensions import TypedDict
import datetime
import json
import uuid
import sqlite3
from json_repair import repair_json

from langgraph.graph.message import add_messages


class ClaimState(TypedDict):
    """State representing a workers' compensation claim process.

    This serves as the central data structure passed between all agent nodes
    in the workflow, tracking the complete claim state.

    NOTE: every key a node writes and later reads back must be declared here.
    LangGraph builds its state channels from the declared keys, so updates to
    an undeclared key are not carried over to the next node.
    """

    # Core identifiers and status
    claim_id: str  # Unique claim identifier (stored in a TEXT column in SQLite)
    status: str  # Current status in the workflow (new, review, decision, etc.)
    created_at: str  # ISO-8601 timestamp
    updated_at: str  # ISO-8601 timestamp

    # The chat conversation history (merged with LangGraph's add_messages reducer)
    messages: Annotated[list, add_messages]

    # Structured claim information
    claimant_info: Dict[str, Any]  # Personal details and injury info
    employer_info: Dict[str, Any]  # Company and supervisor details
    medical_info: Dict[str, Any]   # Diagnosis, treatment, provider info

    # Process tracking
    verification_status: Dict[str, Any]  # Results of verification process
    medical_assessment: Dict[str, Any]   # Adjuster's medical review (set by claims_adjuster)
    decision_status: Dict[str, Any]      # Approval/denial details
    payment_info: Dict[str, Any]         # Payment history and amounts
    recovery_status: Dict[str, Any]      # Recovery progress tracking

    # Conversation records organized by stakeholder
    conversations: Dict[str, List[Dict[str, Any]]]  # Keyed by stakeholder type

    # Flag for workflow completion
    finished: bool


def initialize_database(db_path: str = "claims.db"):
    """Initialize the SQLite database with required tables.

    Creates the ``claims``, ``conversations`` and ``payments`` tables if they
    do not already exist, so calling this more than once is harmless.

    Args:
        db_path: Path of the SQLite database file. Defaults to ``"claims.db"``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Create claims table
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS claims (
            claim_id TEXT PRIMARY KEY,
            status TEXT,
            created_at TEXT,
            updated_at TEXT,
            claimant_info TEXT,  -- JSON string
            employer_info TEXT,  -- JSON string
            medical_info TEXT,   -- JSON string
            verification_status TEXT,  -- JSON string
            decision_status TEXT,      -- JSON string
            payment_info TEXT,         -- JSON string
            recovery_status TEXT       -- JSON string
        )
        """)

        # Create conversations table
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS conversations (
            conversation_id TEXT PRIMARY KEY,
            claim_id TEXT,
            stakeholder_type TEXT,  -- claimant, employer, medical
            timestamp TEXT,
            conversation_text TEXT,
            extracted_data TEXT,  -- JSON string of structured data
            FOREIGN KEY (claim_id) REFERENCES claims (claim_id)
        )
        """)

        # Create payments table for detailed payment tracking
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS payments (
            payment_id TEXT PRIMARY KEY,
            claim_id TEXT,
            amount REAL,
            payment_type TEXT,
            payment_date TEXT,
            approved_by TEXT,  -- adjuster who authorized
            processed_by TEXT,  -- payment processor
            FOREIGN KEY (claim_id) REFERENCES claims (claim_id)
        )
        """)

        conn.commit()
    finally:
        # Release the connection even if one of the statements fails.
        conn.close()
    print("Database initialized successfully")


def create_initial_state(claim_details: Optional[Dict[str, Any]] = None) -> ClaimState:
    """Create initial state for a new claim and persist it.

    Args:
        claim_details: Optional dictionary with initial claim information.
            Only keys already present in the default state are applied; other
            keys are ignored. Supplying ``"claim_id"`` here overrides the
            generated ID (e.g. for reproducible test runs).

    Returns:
        A ClaimState object with default values and any provided claim details
    """
    # A fresh string ID per claim. A constant ID would make every new claim
    # overwrite the same row, because save_claim_state updates an existing
    # claim_id instead of inserting. ClaimState also declares claim_id as str.
    claim_id = str(uuid.uuid4())
    timestamp = datetime.datetime.now().isoformat()

    # Create default initial state
    initial_state = ClaimState(
        claim_id=claim_id,
        status="new",
        created_at=timestamp,
        updated_at=timestamp,
        messages=[],
        claimant_info={},
        employer_info={},
        medical_info={},
        verification_status={},
        decision_status={},
        payment_info={"payments": []},
        recovery_status={},
        conversations={
            "claimant": [],
            "employer": [],
            "medical": []
        },
        finished=False
    )

    # Update with any provided claim details (known keys only)
    if claim_details:
        initial_state.update(
            {key: value for key, value in claim_details.items() if key in initial_state}
        )

    # Save the initial state to database
    save_claim_state(initial_state)

    return initial_state


def save_claim_state(state: ClaimState, db_path: str = "claims.db"):
    """Save the current claim state to the database.

    Inserts a row into ``claims`` the first time a claim_id is seen and
    updates that row on later calls. Nested dict fields are stored as JSON
    text. ``updated_at`` is stamped with the current time on every save;
    ``created_at`` is written only on insert.

    Args:
        state: The current ClaimState to save
        db_path: Path of the SQLite database file (defaults to ``"claims.db"``).
    """
    # Serialize before opening the connection, so a non-serializable value
    # cannot leave a connection open. json.dumps already yields valid JSON;
    # repair_json is kept as a defensive normalization step.
    claim_data = {
        "claim_id": state["claim_id"],
        "status": state["status"],
        "updated_at": datetime.datetime.now().isoformat(),
        "claimant_info": repair_json(json.dumps(state["claimant_info"])),
        "employer_info": repair_json(json.dumps(state["employer_info"])),
        "medical_info": repair_json(json.dumps(state["medical_info"])),
        "verification_status": repair_json(json.dumps(state["verification_status"])),
        "decision_status": repair_json(json.dumps(state["decision_status"])),
        "payment_info": repair_json(json.dumps(state["payment_info"])),
        "recovery_status": repair_json(json.dumps(state["recovery_status"]))
    }

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Check if claim exists
        cursor.execute("SELECT claim_id FROM claims WHERE claim_id = ?", (claim_data["claim_id"],))
        if cursor.fetchone():
            # Update existing claim
            query = """
            UPDATE claims SET
                status = ?, updated_at = ?, claimant_info = ?, employer_info = ?,
                medical_info = ?, verification_status = ?, decision_status = ?,
                payment_info = ?, recovery_status = ?
            WHERE claim_id = ?
            """
            cursor.execute(query, (
                claim_data["status"], claim_data["updated_at"],
                claim_data["claimant_info"], claim_data["employer_info"],
                claim_data["medical_info"], claim_data["verification_status"],
                claim_data["decision_status"], claim_data["payment_info"],
                claim_data["recovery_status"], claim_data["claim_id"]
            ))
        else:
            # Insert new claim
            query = """
            INSERT INTO claims (
                claim_id, status, created_at, updated_at, claimant_info, employer_info,
                medical_info, verification_status, decision_status, payment_info, recovery_status
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """
            cursor.execute(query, (
                claim_data["claim_id"], claim_data["status"],
                state["created_at"], claim_data["updated_at"],
                claim_data["claimant_info"], claim_data["employer_info"],
                claim_data["medical_info"], claim_data["verification_status"],
                claim_data["decision_status"], claim_data["payment_info"],
                claim_data["recovery_status"]
            ))

        conn.commit()
    finally:
        # Release the connection even if a statement fails.
        conn.close()


def save_conversation(claim_id: str, stakeholder_type: str, conversation_text: str, extracted_data: Dict[str, Any], db_path: str = "claims.db"):
    """Save a conversation record to the database.

    Each call inserts a new row into ``conversations`` with a fresh UUID and
    the current timestamp.

    Args:
        claim_id: The ID of the associated claim
        stakeholder_type: Type of stakeholder (claimant, employer, medical)
        conversation_text: The text of the conversation
        extracted_data: Structured data extracted from the conversation
        db_path: Path of the SQLite database file (defaults to ``"claims.db"``).
    """
    conversation_id = str(uuid.uuid4())
    timestamp = datetime.datetime.now().isoformat()

    # Serialize before connecting; json.dumps already yields valid JSON and
    # repair_json is kept as a defensive normalization step.
    extracted_data_json = repair_json(json.dumps(extracted_data))

    query = """
    INSERT INTO conversations (
        conversation_id, claim_id, stakeholder_type,
        timestamp, conversation_text, extracted_data
    ) VALUES (?, ?, ?, ?, ?, ?)
    """

    conn = sqlite3.connect(db_path)
    try:
        conn.execute(query, (
            conversation_id, claim_id, stakeholder_type,
            timestamp, conversation_text, extracted_data_json
        ))
        conn.commit()
    finally:
        # Release the connection even if the insert fails.
        conn.close()

## Importing Libraries

In [None]:
# !pip install json_repair langchain_google_genai langgraph
import os
from google import genai
from google.genai import types
from google import generativeai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.tools import tool
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
from json_repair import repair_json

In [None]:
# Global parameters
# The Gemini API key is read from the environment instead of being hard-coded.
# A key committed to a notebook is effectively public and should be revoked.
# Set GOOGLE_API_KEY (e.g. from Colab secrets) before running the cells below.
api_key = os.environ.get("GOOGLE_API_KEY")

In [None]:
def initialize_model(model: str = "gemini-2.0-flash"):
    """Initialize the Gemini chat model.

    The API key is taken from the module-level ``api_key`` global.

    Args:
        model: Gemini model name. Defaults to ``"gemini-2.0-flash"``.

    Returns:
        A configured LangChain ChatGoogleGenerativeAI model
    """
    return ChatGoogleGenerativeAI(model=model, google_api_key=api_key, callbacks=None)


# Tool functions for agents to use
# NOTE: @tool sends this function's docstring to the LLM as the tool
# description, so it is deliberately kept as-is. Parameters defaulting to None
# are typed Optional so an explicit null from the model passes schema validation.
@tool
def extract_claimant_data(
    name: str,
    contact_info: str,
    injury_date: str,
    injury_description: str,
    how_it_occurred: str,
    witnesses: Optional[list[str]] = None,
    medical_attention: bool = False,
    medical_provider: Optional[str] = None
) -> dict:
    """Extract structured claimant data from conversation.

    Args:
        name: Claimant's full name
        contact_info: Contact information (phone, email)
        injury_date: Date of injury (ISO format)
        injury_description: Description of the injury
        how_it_occurred: How the injury occurred
        witnesses: Names of witnesses (optional)
        medical_attention: Whether medical attention was sought
        medical_provider: Medical provider if applicable

    Returns:
        Structured claimant data
    """
    return {
        "name": name,
        "contact_info": contact_info,
        "injury_date": injury_date,
        "injury_description": injury_description,
        "how_it_occurred": how_it_occurred,
        "witnesses": witnesses or [],  # normalize a missing list to []
        "medical_attention": medical_attention,
        "medical_provider": medical_provider
    }


# NOTE: the docstring below is the LLM-facing tool description and is kept
# unchanged; report_date is Optional because it defaults to None.
@tool
def extract_employer_data(
    company_name: str,
    supervisor_name: str,
    contact_info: str,
    employment_duration: str,
    job_title: str,
    incident_reported: bool,
    report_date: Optional[str] = None
) -> dict:
    """Extract structured employer data from conversation.

    Args:
        company_name: Employer's company name
        supervisor_name: Supervisor's name
        contact_info: Contact information
        employment_duration: How long claimant has been employed
        job_title: Claimant's job title
        incident_reported: Whether incident was reported to employer
        report_date: Date incident was reported (ISO format)

    Returns:
        Structured employer data
    """
    return {
        "company_name": company_name,
        "supervisor_name": supervisor_name,
        "contact_info": contact_info,
        "employment_duration": employment_duration,
        "job_title": job_title,
        "incident_reported": incident_reported,
        "report_date": report_date
    }


# NOTE: the docstring below is the LLM-facing tool description and is kept
# unchanged; follow_up_date is Optional because it defaults to None.
@tool
def extract_medical_data(
    provider_name: str,
    facility: str,
    diagnosis: str,
    treatment_plan: str,
    work_restrictions: str,
    estimated_recovery: str,
    follow_up_date: Optional[str] = None
) -> dict:
    """Extract structured medical data from conversation.

    Args:
        provider_name: Medical provider's name
        facility: Medical facility name
        diagnosis: Medical diagnosis
        treatment_plan: Recommended treatment plan
        work_restrictions: Work restrictions if any
        estimated_recovery: Estimated recovery time
        follow_up_date: Follow-up appointment date (ISO format)

    Returns:
        Structured medical data
    """
    return {
        "provider_name": provider_name,
        "facility": facility,
        "diagnosis": diagnosis,
        "treatment_plan": treatment_plan,
        "work_restrictions": work_restrictions,
        "estimated_recovery": estimated_recovery,
        "follow_up_date": follow_up_date,
        "specialty_matches_injury": True  # Simplified check for simulation
    }


# NOTE: @tool sends the docstring below to the LLM as the tool description,
# so it must list only the parameters the function actually accepts.
@tool
def simulate_conversation(role: str, context: dict) -> dict:
    """Simulate a conversation with a stakeholder.

    Args:
        role: The stakeholder role (claimant, employer, medical)
        context: Relevant information about the claim

    Returns:
        A dictionary with the conversation text and extracted data
    """
    # Mock transcripts stand in for an LLM-simulated conversation.
    if role == "claimant":
        conversation_text = "Adjuster: Hello, I'm calling about your workers compensation claim. Can you tell me about your injury?\n\nClaimant: Yes, I injured my back while lifting a heavy box at work on June 1st.\n\nAdjuster: I'm sorry to hear that. Can you describe how it happened?\n\nClaimant: I was moving inventory in the warehouse when I felt a sharp pain in my lower back."
        extracted_data = {
            "name": "John Smith",
            "contact_info": "john.smith@example.com",
            "injury_date": "2025-06-01T00:00:00Z",
            "injury_description": "Lower back strain",
            "how_it_occurred": "Lifting heavy box in warehouse",
            "witnesses": ["Jane Doe"],
            "medical_attention": True,
            "medical_provider": "City General Hospital"
        }
    elif role == "employer":
        conversation_text = "Adjuster: Hello, I'm calling to verify employment for John Smith regarding a workers compensation claim. Can you confirm his employment?\n\nEmployer: Yes, John has been with us for 3 years. He works in our warehouse department.\n\nAdjuster: Was the incident reported to you?\n\nEmployer: Yes, John reported it to his supervisor immediately after it happened."
        extracted_data = {
            "company_name": "ABC Logistics",
            "supervisor_name": "Mike Johnson",
            "contact_info": "mike.johnson@abclogistics.com",
            "employment_duration": "3 years",
            "job_title": "Warehouse Associate",
            "incident_reported": True,
            "report_date": "2025-06-01T14:30:00Z"
        }
    elif role == "medical":
        conversation_text = "Adjuster: Hello, I'm calling regarding John Smith. He's filed a workers compensation claim and listed you as his provider. Can you confirm his diagnosis?\n\nDoctor: Yes, John has an acute lumbar strain. Based on his description, it's consistent with a lifting injury.\n\nAdjuster: What's the treatment plan and expected recovery time?\n\nDoctor: We've prescribed rest, physical therapy twice weekly, and anti-inflammatory medication. Recovery should take 4-6 weeks."
        extracted_data = {
            "provider_name": "Dr. Sarah Williams",
            "facility": "City General Hospital",
            "diagnosis": "Acute lumbar strain",
            "treatment_plan": "Rest, physical therapy twice weekly, anti-inflammatory medication",
            "work_restrictions": "No lifting over 10 pounds for 2 weeks",
            "estimated_recovery": "4-6 weeks",
            "follow_up_date": "2025-06-15T00:00:00Z"
        }
    else:
        conversation_text = "No conversation simulated"
        extracted_data = {}

    # Persist the conversation only when the context identifies a claim.
    if context.get("claim_id"):
        save_conversation(context["claim_id"], role, conversation_text, extracted_data)

    return {
        "conversation_text": conversation_text,
        "extracted_data": extracted_data
    }


# The docstring is the LLM-facing tool description; the body simply echoes
# the model-supplied verification fields plus a timestamp.
@tool
def verify_employment(
    employment_verified: bool,
    verification_method: str,
    employment_duration: str,
    eligible_for_benefits: bool,
    notes: str = ""
) -> dict:
    """Verify employment status and eligibility.

    Args:
        employment_verified: Whether employment is verified
        verification_method: Method used for verification
        employment_duration: Duration of employment
        eligible_for_benefits: Whether employee is eligible for benefits
        notes: Any additional verification notes

    Returns:
        Employment verification status
    """
    verified_at = datetime.datetime.now().isoformat()
    return dict(
        employment_verified=employment_verified,
        verification_method=verification_method,
        employment_duration=employment_duration,
        eligible_for_benefits=eligible_for_benefits,
        notes=notes,
        verification_timestamp=verified_at,
    )


# The docstring is the LLM-facing tool description; the body records the
# model's assessment flags together with the time of the assessment.
@tool
def assess_medical_treatment(
    diagnosis_confirmed: bool,
    provider_qualified: bool,
    treatment_appropriate: bool,
    notes: str = ""
) -> dict:
    """Assess if medical treatment is appropriate and from qualified provider.

    Args:
        diagnosis_confirmed: Whether diagnosis is confirmed
        provider_qualified: Whether provider is qualified for this injury
        treatment_appropriate: Whether treatment is appropriate
        notes: Additional assessment notes

    Returns:
        Medical assessment results
    """
    assessed_at = datetime.datetime.now().isoformat()
    return dict(
        diagnosis_confirmed=diagnosis_confirmed,
        provider_qualified=provider_qualified,
        treatment_appropriate=treatment_appropriate,
        notes=notes,
        assessment_timestamp=assessed_at,
    )


# NOTE: the docstring below is the LLM-facing tool description and is kept
# unchanged. benefit_type / denial_reason default to None, so they are Optional.
@tool
def make_claim_decision(
    approved: bool,
    reasoning: str,
    benefit_type: Optional[str] = None,
    denial_reason: Optional[str] = None
) -> dict:
    """Make decision on claim approval or denial.

    Args:
        approved: Whether claim is approved
        reasoning: Detailed reasoning for decision
        benefit_type: Type of benefits approved (if applicable)
        denial_reason: Reason for denial (if applicable)

    Returns:
        Decision details
    """
    return {
        "approved": approved,
        "decision_date": datetime.datetime.now().isoformat(),
        "reasoning": reasoning,
        "benefit_type": benefit_type,
        "denial_reason": denial_reason
    }


# NOTE: the docstring below is the LLM-facing tool description and is kept
# unchanged. payment_period defaults to None, so it is Optional.
@tool
def process_payment(
    payment_amount: float,
    payment_type: str,
    payment_period: Optional[str] = None,
    notes: str = ""
) -> dict:
    """Process a payment for an approved claim.

    Args:
        payment_amount: Payment amount
        payment_type: Type of payment (initial, recurring, final)
        payment_period: Period covered by payment
        notes: Additional payment notes

    Returns:
        Payment details
    """
    # Each payment gets its own UUID so repeated payments stay distinguishable.
    payment = {
        "payment_id": str(uuid.uuid4()),
        "payment_amount": payment_amount,
        "payment_type": payment_type,
        "payment_date": datetime.datetime.now().isoformat(),
        "payment_period": payment_period,
        "notes": notes
    }

    # TODO: persist to the `payments` table; currently only returned to the caller.
    return payment


# NOTE: the docstring below is the LLM-facing tool description and is kept
# unchanged. next_medical_review defaults to None, so it is Optional.
@tool
def assess_recovery(
    recovery_status: str,
    work_status: str,
    fully_recovered: bool,
    continuing_treatment: bool,
    next_medical_review: Optional[str] = None,
    notes: str = ""
) -> dict:
    """Assess recovery progress.

    Args:
        recovery_status: Current recovery status (improving, stable, etc.)
        work_status: Current work status (off work, light duty, full duty)
        fully_recovered: Whether claimant has fully recovered
        continuing_treatment: Whether claimant is still receiving treatment
        next_medical_review: Date of next medical review
        notes: Additional notes on recovery

    Returns:
        Recovery assessment
    """
    return {
        "recovery_status": recovery_status,
        "work_status": work_status,
        "fully_recovered": fully_recovered,
        "continuing_treatment": continuing_treatment,
        "next_medical_review": next_medical_review,
        "notes": notes,
        "assessment_date": datetime.datetime.now().isoformat()
    }


# Define system prompts
# Role instructions for each agent. These strings are sent to the model
# verbatim, so wording changes alter model behavior.

# Intake stage: intake_agent passes this prompt to the model ahead of its
# claimant-information request.
INTAKE_SYSTEM_PROMPT = """You are the Intake Agent for a workers' compensation claim system. Your role is to:
1. Gather initial information from claimants about their injury
2. Collect employer information for verification
3. Record initial medical information
4. Structure this information for further processing

Important rules:
- Document all information thoroughly
- Only collect information relevant to the claim
- Be empathetic but focused on gathering facts
- Maintain confidentiality of sensitive information
"""

# Claims adjuster role description.
# NOTE(review): not referenced by the agent functions in this section
# (claims_adjuster currently decides with fixed rules) — confirm intended use.
ADJUSTER_SYSTEM_PROMPT = """You are the Claims Adjuster for a workers' compensation claim system. Your role is to:
1. Verify employment and injury details
2. Review medical information and ensure provider is qualified for the injury type
3. Ensure compliance with regulations
4. Make approval/denial decisions based on evidence
5. Monitor recovery progress and authorize additional treatment

Important rules:
- Verify all medical providers are qualified for the treatment
- Only share medical information with authorized parties
- Document reasoning for all decisions
- Ensure recovery checks are scheduled at appropriate intervals
- Authorize payments but do not process them
"""

# Payment processing role description.
# NOTE(review): likewise not referenced by payment_processor in this section —
# confirm intended use.
PAYMENT_SYSTEM_PROMPT = """You are the Payment Processing Agent for a workers' compensation claim system. Your role is to:
1. Calculate payment amounts based on approved claims
2. Process payments for medical procedures and lost wages
3. Track payment history for each claim
4. Ensure payments are within policy limits

Important rules:
- Only process payments authorized by the Claims Adjuster
- Document all payment details including date, amount, and recipient
- Track running totals against policy limits
- Process multiple payments for the same claim as needed
"""

# Define agent nodes
def _gather_stakeholder_info(state: ClaimState, role: str, info_key: str, context: Dict[str, Any]) -> None:
    """Run one simulated stakeholder conversation and record it in ``state``.

    The extracted data is stored under ``state[info_key]``, and the full
    conversation record is appended to ``state["conversations"][role]``.
    """
    convo = simulate_conversation.invoke({"role": role, "context": context})
    print(convo)
    state[info_key] = convo["extracted_data"]
    state["conversations"].setdefault(role, []).append(convo)


def intake_agent(state: ClaimState) -> ClaimState:
    """The intake agent gathers initial information about the claim.

    This agent simulates conversations with the claimant, employer, and
    medical provider, filling in whichever of ``claimant_info``,
    ``employer_info`` and ``medical_info`` are still empty. Employer and
    medical data are gathered only once claimant data exists. When all three
    are present, the status moves to ``review``. The state is always
    timestamped and saved.

    Args:
        state: The current claim state

    Returns:
        Updated claim state with intake information
    """
    # Initialize LLM with tools
    llm = initialize_model()
    print(100* '-')
    print('Intake Agent initialized')
    print(llm)
    tools = [
        simulate_conversation,
        extract_claimant_data,
        extract_employer_data,
        extract_medical_data
    ]
    llm_with_tools = llm.bind_tools(tools)

    claim_id = state["claim_id"]

    # Gather claimant information first; the other stakeholders depend on it.
    if not state["claimant_info"]:
        user_message = f"I need to gather initial information for a new workers' compensation claim (ID: {state['claim_id']}). Please simulate a conversation with the claimant to gather injury details."

        # Role-tagged tuples: LangChain treats a bare string in a message list
        # as a human message, so the system prompt must be tagged explicitly.
        response = llm_with_tools.invoke([
            ("system", INTAKE_SYSTEM_PROMPT),
            ("human", user_message),
        ])
        print(response)
        print(response.content)

        # The model response is only logged; the simulated conversation is
        # obtained by calling the tool directly.
        _gather_stakeholder_info(state, "claimant", "claimant_info", {"claim_id": claim_id})

    # Employer verification
    if not state["employer_info"] and state["claimant_info"]:
        _gather_stakeholder_info(
            state, "employer", "employer_info",
            {"claim_id": claim_id, "claimant_info": state["claimant_info"]},
        )

    # Medical provider details
    if not state["medical_info"] and state["claimant_info"]:
        _gather_stakeholder_info(
            state, "medical", "medical_info",
            {"claim_id": claim_id, "claimant_info": state["claimant_info"]},
        )

    # Update status to move to next stage
    if state["claimant_info"] and state["employer_info"] and state["medical_info"]:
        state["status"] = "review"
    # Update timestamp
    state["updated_at"] = datetime.datetime.now().isoformat()
    # Save updated state to database
    save_claim_state(state)
    return state


def claims_adjuster(state: ClaimState) -> ClaimState:
    """The claims adjuster verifies information and makes decisions.

    In ``review`` status it verifies employment and assesses the medical
    treatment. Once both are done, it approves the claim (status ->
    ``payment``) or denies it (status -> ``denied``). In ``recovery_check``
    status it records a recovery assessment and sends the claim back to
    ``review``, or closes it when the claimant has fully recovered. In every
    case the state is timestamped and saved.

    Args:
        state: The current claim state

    Returns:
        Updated claim state with adjuster decisions
    """
    print(100* '-')
    print('Adjuster Agent initialized')
    # NOTE: the model is bound to the adjuster tools, but the logic below is
    # currently rule-based and does not call it.
    llm = initialize_model()
    tools = [
        verify_employment,
        assess_medical_treatment,
        make_claim_decision,
        simulate_conversation
    ]
    llm_with_tools = llm.bind_tools(tools)

    if state["status"] == "review":
        # Verify employment (once)
        if not state.get("verification_status") or "employment_verified" not in state["verification_status"]:
            verification_result = verify_employment.invoke({
                "employment_verified": True,
                "verification_method": "Employer conversation",
                # .get(): an incomplete employer record must not abort the
                # adjuster with a KeyError.
                "employment_duration": state["employer_info"].get("employment_duration", "unknown"),
                "eligible_for_benefits": True,
                "notes": "Employment verified through conversation with supervisor"
            })
            print(verification_result)
            state["verification_status"] = verification_result

        # Assess medical treatment (once, and only if medical data exists)
        if not state.get("medical_assessment") and state["medical_info"]:
            medical_assessment = assess_medical_treatment.invoke({
                "diagnosis_confirmed": True,
                "provider_qualified": True,  # Simplified check
                "treatment_appropriate": True,
                "notes": "Diagnosis and treatment plan are appropriate for the reported injury"
            })
            print(medical_assessment)
            state["medical_assessment"] = medical_assessment

        # Make claim decision if verification and medical assessment are complete
        if state.get("verification_status") and state.get("medical_assessment"):
            if not state.get("decision_status") or "approved" not in state["decision_status"]:
                employment_verified = state["verification_status"].get("employment_verified", False)
                provider_qualified = state["medical_assessment"].get("provider_qualified", False)
                treatment_appropriate = state["medical_assessment"].get("treatment_appropriate", False)

                if employment_verified and provider_qualified and treatment_appropriate:
                    decision = make_claim_decision.invoke({
                        "approved": True,
                        "reasoning": "Employment verified and medical treatment is appropriate",
                        "benefit_type": "Medical and temporary disability"
                    })
                else:
                    # Report the first failed requirement as the denial reason.
                    if not employment_verified:
                        denial_reason = "Employment could not be verified"
                    elif not provider_qualified:
                        denial_reason = "Medical provider not qualified for treatment"
                    else:
                        denial_reason = "Treatment not appropriate for reported injury"

                    # Tools must be run via .invoke(); calling the tool object
                    # with keyword arguments raises TypeError.
                    decision = make_claim_decision.invoke({
                        "approved": False,
                        "reasoning": "Claim does not meet requirements for approval",
                        "denial_reason": denial_reason
                    })
                print(decision)

                state["decision_status"] = decision
                state["status"] = "payment" if decision["approved"] else "denied"

    elif state["status"] == "recovery_check":
        # Simulated recovery assessment (a real implementation would talk to
        # the claimant and medical provider first).
        recovery_assessment = assess_recovery.invoke({
            "recovery_status": "improving",
            "work_status": "light duty",
            "fully_recovered": False,  # Not yet fully recovered
            "continuing_treatment": True,
            "next_medical_review": (datetime.datetime.now() + datetime.timedelta(days=14)).isoformat(),
            "notes": "Patient showing improvement but needs continued therapy"
        })
        print(recovery_assessment)

        state["recovery_status"] = recovery_assessment

        if not recovery_assessment["fully_recovered"]:
            # Not fully recovered: back to review for additional treatment
            state["status"] = "review"
        else:
            # Fully recovered: close the claim
            state["status"] = "closed"
            state["finished"] = True

    # Update timestamp
    state["updated_at"] = datetime.datetime.now().isoformat()

    # Save updated state to database
    save_claim_state(state)

    return state


def payment_processor(state: ClaimState) -> ClaimState:
    """Issue a payment for an approved claim and advance it to recovery checks.

    The processor only acts when the claim sits in "payment" status and the
    adjuster's decision is marked approved. Otherwise the state passes
    through with just its timestamp refreshed. Either way the state is
    persisted via save_claim_state before it is returned.

    Args:
        state: The current claim state

    Returns:
        Updated claim state with payment information
    """
    # A tool-bound model is created, but the payment itself is issued by
    # invoking the tool directly below.
    model = initialize_model()
    payment_tools = [process_payment]
    model.bind_tools(payment_tools)

    ready_to_pay = (
        state["status"] == "payment"
        and state["decision_status"].get("approved", False)
    )

    if ready_to_pay:
        # Simulation: a flat amount stands in for a real benefit calculation.
        payment = process_payment.invoke({
            "payment_amount": 1000.00,
            "payment_type": "initial",
            "payment_period": "First two weeks",
            "notes": "Initial payment for approved claim",
        })
        print(payment)

        # Append to the payment history, creating it on the first payment.
        history = state["payment_info"].setdefault("payments", [])
        history.append(payment)
        state["payment_info"]["last_payment_date"] = payment["payment_date"]

        # After paying, the claim moves on to a recovery check.
        state["status"] = "recovery_check"

    state["updated_at"] = datetime.datetime.now().isoformat()
    save_claim_state(state)
    return state


def recovery_check_agent(state: ClaimState) -> ClaimState:
    """Check on the claimant's recovery and decide whether the claim can close.

    A simulated claimant conversation is run and printed. A recovery outcome
    is then drawn at random as a simulation stand-in for a real assessment.
    A fully recovered claimant closes the claim and marks the workflow
    finished; otherwise the claim goes back to "review" for further
    treatment. The state is persisted via save_claim_state before it is
    returned.

    Args:
        state: The current claim state

    Returns:
        Updated claim state with recovery information
    """
    import random

    print('-' * 100)
    print('Recovery Check Agent initialized')

    model = initialize_model()
    model.bind_tools([simulate_conversation, assess_recovery])

    # Talk to the claimant about recovery. The transcript is only printed;
    # it does not yet feed the assessment below.
    conversation = simulate_conversation.invoke({
        "role": "claimant",
        "context": {"claim_id": state["claim_id"], "medical_info": state["medical_info"]},
    })
    print(conversation)

    # Simulation only: coin-flip the recovery outcome.
    recovered = random.choice([True, False])

    if recovered:
        assessment_args = {
            "recovery_status": "recovered",
            "work_status": "full duty",
            "fully_recovered": recovered,
            "continuing_treatment": not recovered,
            "next_medical_review": None,
            "notes": "Patient's status assessed through conversation",
        }
    else:
        follow_up = datetime.datetime.now() + datetime.timedelta(days=14)
        assessment_args = {
            "recovery_status": "improving",
            "work_status": "light duty",
            "fully_recovered": recovered,
            "continuing_treatment": not recovered,
            "next_medical_review": follow_up.isoformat(),
            "notes": "Patient's status assessed through conversation",
        }

    state["recovery_status"] = assess_recovery.invoke(assessment_args)

    if recovered:
        state["status"] = "closed"
        state["finished"] = True
    else:
        # Still recovering: hand the claim back for another adjuster review.
        state["status"] = "review"

    state["updated_at"] = datetime.datetime.now().isoformat()
    save_claim_state(state)
    return state


# Routing functions
def route_after_intake(state: ClaimState) -> Literal["adjuster", END]:
    """Pick the next node once intake has run.

    Args:
        state: The current claim state

    Returns:
        "adjuster" when the claimant, employer and medical sections are all
        non-empty; END otherwise (intake did not complete properly).
    """
    # Generator keeps the original left-to-right short-circuit evaluation.
    intake_complete = all(
        state[section]
        for section in ("claimant_info", "employer_info", "medical_info")
    )
    return "adjuster" if intake_complete else END


def route_after_adjuster(state: ClaimState) -> Literal["payment", "adjuster", END]:
    """Route after adjuster processing.

    Every value this function can return must be listed in its ``Literal``
    return annotation. The graph registers this router with
    ``add_conditional_edges`` and no explicit path map, so LangGraph derives
    the edge's allowed destinations from that annotation. Any unlisted return
    value (such as ``"adjuster"``) fails at runtime with ``KeyError``.

    Args:
        state: The current claim state

    Returns:
        "payment" when the adjuster approved the claim, "adjuster" when the
        claim is still in "review" (e.g. it came back from a recovery
        check), otherwise END ("closed", "denied", or any other status).
    """
    status = state["status"]
    if status == "payment":
        return "payment"
    if status == "review":
        # Run the adjuster again so it can make a decision.
        # NOTE(review): if the adjuster leaves the claim in "review" without
        # deciding, this self-loop repeats until the caller's recursion limit.
        return "adjuster"
    # "closed", "denied" and unrecognised statuses all end the workflow.
    return END


def route_after_payment(state: ClaimState) -> Literal["recovery_check", END]:
    """Pick the next node once the payment processor has run.

    Args:
        state: The current claim state

    Returns:
        "recovery_check" when a payment was issued (the processor moves the
        claim to that status after paying); END otherwise.
    """
    payment_issued = state["status"] == "recovery_check"
    return "recovery_check" if payment_issued else END


def route_after_recovery(state: ClaimState) -> Literal["adjuster", END]:
    """Pick the next node once the recovery check has run.

    Args:
        state: The current claim state

    Returns:
        "adjuster" when the claimant needs further treatment (status
        "review"); END when the claim is "closed" or in any other status.
    """
    needs_more_treatment = state["status"] == "review"
    return "adjuster" if needs_more_treatment else END


# Main workflow definition
def create_workflow():
    """Assemble and compile the claims-processing LangGraph.

    The graph enters at the intake node. Each node's outgoing edge is chosen
    by its matching ``route_after_*`` router.

    Returns:
        A compiled LangGraph StateGraph
    """
    print('-' * 100)
    print('Workflow initialized')

    builder = StateGraph(ClaimState)
    print(builder)

    # Node name -> agent function (dict order fixes registration order).
    agents = {
        "intake": intake_agent,
        "adjuster": claims_adjuster,
        "payment": payment_processor,
        "recovery_check": recovery_check_agent,
    }
    for node_name, agent in agents.items():
        builder.add_node(node_name, agent)

    builder.add_edge(START, "intake")

    # Source node -> router deciding where the claim goes next.
    routers = {
        "intake": route_after_intake,
        "adjuster": route_after_adjuster,
        "payment": route_after_payment,
        "recovery_check": route_after_recovery,
    }
    for source_node, router in routers.items():
        builder.add_conditional_edges(source_node, router)

    return builder.compile()


# Execution functions
def process_claim(initial_state):
    """Run a single claim through a freshly compiled claims workflow.

    Args:
        initial_state: The initial claim state

    Returns:
        Final state after processing
    """
    compiled_graph = create_workflow()

    # The adjuster -> payment -> recovery_check -> adjuster cycle can repeat,
    # so allow up to 100 graph steps for one claim.
    run_config = {"recursion_limit": 100}

    return compiled_graph.invoke(initial_state, run_config)


In [None]:
def main():
    """Create one demo claim, run it through the workflow, and print the outcome.

    For an approved and closed claim, any recorded payments are listed. For a
    denied claim, the denial reason is shown.

    Returns:
        The final claim state produced by the workflow.
    """
    # Build the model once up front (the returned instance is discarded),
    # then initialize the SQLite claims database.
    initialize_model()
    initialize_database()

    # Create a new claim
    initial_state = create_initial_state()
    print(f"Processing claim {initial_state['claim_id']}...")

    result = process_claim(initial_state)
    print(f"Claim processing complete. Final status: {result['status']}")

    # `or {}` also tolerates a decision_status explicitly stored as None.
    decision = result.get("decision_status") or {}

    if result["status"] == "closed" and decision.get("approved", False):
        print("Claim was approved and is now closed.")

        payments = (result.get("payment_info") or {}).get("payments") or []
        if payments:
            print("\nPayments:")
            for payment in payments:
                print(f"  {payment['payment_date']}: ${payment['payment_amount']:.2f} ({payment['payment_type']})")

    elif result["status"] == "denied":
        # `or` also covers a decision recorded with denial_reason=None / "",
        # which would otherwise print "Reason: None".
        denial_reason = decision.get("denial_reason") or "No reason specified"
        print(f"Claim was denied. Reason: {denial_reason}")

    return result


# Run the application if executed directly. This also fires when this notebook
# cell is run, because the kernel's __name__ is "__main__".
if __name__ == "__main__":
    main()

Database initialized successfully
Processing claim 1...
----------------------------------------------------------------------------------------------------
Workflow initialized
<langgraph.graph.state.StateGraph object at 0x7b8dfdae7610>
----------------------------------------------------------------------------------------------------
Intake Agent initialized
model='models/gemini-2.0-flash' google_api_key=SecretStr('**********') client=<google.ai.generativelanguage_v1beta.services.generative_service.client.GenerativeServiceClient object at 0x7b8df056ab90> default_metadata=()
content='' additional_kwargs={'function_call': {'name': 'simulate_conversation', 'arguments': '{"context": "New workers\' compensation claim (ID: 1)", "role": "claimant"}'}} response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'gemini-2.0-flash', 'safety_ratings': []} id='run-fcb30e1f-7e8f-404c-b97f-f1e26daa30ae-0' tool_calls=[{'name': 'simulate_c

KeyError: 'adjuster'

In [None]:
# # Import required packages
# import os
# from typing import Annotated, Literal
# from typing_extensions import TypedDict
# import datetime
# import json
# import uuid
# import sqlite3

# from google import genai
# from google.genai import types

# from langgraph.graph import StateGraph, END, START
# from langgraph.graph.message import add_messages
# from langgraph.prebuilt import ToolNode
# from langchain_core.tools import tool
# from langchain_google_genai import ChatGoogleGenerativeAI

# # Define the state object for our claim
# class ClaimState(TypedDict):
#     """State representing a workers' compensation claim process."""

#     # Tracking information
#     claim_id: str
#     status: str  # Current status in the workflow

#     # The chat conversation history. This preserves the conversation history
#     # between nodes.
#     messages: Annotated[list, add_messages]

#     # Claim details
#     claimant_info: dict  # Personal and injury details
#     employer_info: dict  # Employment verification
#     medical_info: dict   # Medical treatment and diagnosis

#     # Processing information
#     verification_status: dict  # Results of verification
#     decision_status: dict      # Approval/denial
#     payment_info: dict         # Payment details and history
#     recovery_status: dict      # Recovery check results

#     # Flag to indicate completion
#     finished: bool

# # Define our system prompts
# INTAKE_SYSTEM_PROMPT = (
#     "system",
#     """You are the Intake Agent for a workers' compensation claim system. Your role is to:
#     1. Gather initial information from claimants about their injury
#     2. Collect employer information for verification
#     3. Record initial medical information
#     4. Structure this information for further processing

#     Important rules:
#     - Document all information thoroughly
#     - Only collect information relevant to the claim
#     - Be empathetic but focused on gathering facts
#     - Maintain confidentiality of sensitive information
#     """
# )

# ADJUSTER_SYSTEM_PROMPT = (
#     "system",
#     """You are the Claims Adjuster for a workers' compensation claim system. Your role is to:
#     1. Verify employment and injury details
#     2. Review medical information and ensure provider is qualified for the injury type
#     3. Ensure compliance with regulations
#     4. Make approval/denial decisions based on evidence
#     5. Monitor recovery progress and authorize additional treatment

#     Important rules:
#     - Verify all medical providers are qualified for the treatment
#     - Only share medical information with authorized parties
#     - Document reasoning for all decisions
#     - Ensure recovery checks are scheduled at appropriate intervals
#     - Authorize payments but do not process them
#     """
# )

# PAYMENT_SYSTEM_PROMPT = (
#     "system",
#     """You are the Payment Processing Agent for a workers' compensation claim system. Your role is to:
#     1. Calculate payment amounts based on approved claims
#     2. Process payments for medical procedures and lost wages
#     3. Track payment history for each claim
#     4. Ensure payments are within policy limits

#     Important rules:
#     - Only process payments authorized by the Claims Adjuster
#     - Document all payment details including date, amount, and recipient
#     - Track running totals against policy limits
#     - Process multiple payments for the same claim as needed
#     """
# )

# # Initialize model
# def initialize_model():
#     """Initialize the Gemini model with API key."""
#     # Setup API key
#     api_key = os.environ.get("GOOGLE_API_KEY")  # read the key from the environment; never hard-code (and revoke any key that was committed)
#     genai.configure(api_key=api_key)

#     # Create the model
#     llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
#     return llm

# # Database functions
# def initialize_database():
#     """Initialize the SQLite database with required tables."""
#     conn = sqlite3.connect("claims.db")
#     cursor = conn.cursor()

#     # Create tables
#     cursor.execute("""
#     CREATE TABLE IF NOT EXISTS claims (
#         claim_id TEXT PRIMARY KEY,
#         status TEXT,
#         created_at TEXT,
#         updated_at TEXT
#     )
#     """)

#     cursor.execute("""
#     CREATE TABLE IF NOT EXISTS conversations (
#         conversation_id TEXT PRIMARY KEY,
#         claim_id TEXT,
#         stakeholder_type TEXT,
#         timestamp TEXT,
#         conversation_text TEXT,
#         FOREIGN KEY (claim_id) REFERENCES claims (claim_id)
#     )
#     """)

#     # Additional tables as needed

#     conn.commit()
#     conn.close()

# # Function tools for conversation simulation
# @tool
# def simulate_conversation(role: str, context: dict, goal: str) -> dict:
#     """Simulate a conversation with a stakeholder.

#     Args:
#         role: The stakeholder role (claimant, employer, medical)
#         context: Relevant information about the claim
#         goal: The goal of the conversation

#     Returns:
#         A dictionary with the conversation text and extracted data
#     """
#     # This would call the LLM to simulate both sides of the conversation
#     pass

# @tool
# def extract_claimant_data(
#     name: str,
#     contact_info: str,
#     injury_date: str,
#     injury_description: str,
#     how_it_occurred: str,
#     witnesses: list[str] = None,
#     medical_attention: bool = False,
#     medical_provider: str = None
# ) -> dict:
#     """Extract structured claimant data from conversation."""
#     return {
#         "name": name,
#         "contact_info": contact_info,
#         "injury_date": injury_date,
#         "injury_description": injury_description,
#         "how_it_occurred": how_it_occurred,
#         "witnesses": witnesses or [],
#         "medical_attention": medical_attention,
#         "medical_provider": medical_provider
#     }

# # Similar extraction tools for other data types

# # Define agent nodes
# def intake_agent(state: ClaimState) -> ClaimState:
#     """The intake agent gathers initial information about the claim."""
#     # Implementation with function calling to simulate conversations
#     pass

# def claims_adjuster(state: ClaimState) -> ClaimState:
#     """The claims adjuster verifies information and makes decisions."""
#     # Implementation with function calling for verification, review, and decisions
#     pass

# def payment_processor(state: ClaimState) -> ClaimState:
#     """The payment processor calculates and issues payments."""
#     # Implementation with function calling for payment processing
#     pass

# # Routing functions
# def route_after_intake(state: ClaimState) -> Literal["adjuster", END]:
#     """Route after intake is complete."""
#     # Logic to determine next step
#     pass

# def route_after_adjuster(state: ClaimState) -> Literal["payment", "denied", END]:
#     """Route after adjuster processing."""
#     # Logic based on approval/denial
#     pass

# def route_after_payment(state: ClaimState) -> Literal["recovery_check", END]:
#     """Route after payment processing."""
#     # Logic to move to recovery check
#     pass

# def route_after_recovery(state: ClaimState) -> Literal["adjuster", END]:
#     """Route after recovery check."""
#     # Logic to determine if further review needed
#     pass

# # Main workflow definition
# def create_workflow():
#     """Create the LangGraph workflow for claims processing."""
#     # Initialize graph
#     graph_builder = StateGraph(ClaimState)

#     # Add agent nodes
#     graph_builder.add_node("intake", intake_agent)
#     graph_builder.add_node("adjuster", claims_adjuster)
#     graph_builder.add_node("payment", payment_processor)
#     graph_builder.add_node("recovery_check", recovery_check_agent)

#     # Define routing
#     graph_builder.add_edge(START, "intake")
#     graph_builder.add_conditional_edges("intake", route_after_intake)
#     graph_builder.add_conditional_edges("adjuster", route_after_adjuster)
#     graph_builder.add_conditional_edges("payment", route_after_payment)
#     graph_builder.add_conditional_edges("recovery_check", route_after_recovery)

#     # Compile the graph
#     return graph_builder.compile()

# # Execution functions
# def process_claim(initial_state):
#     """Process a claim through the workflow."""
#     workflow = create_workflow()
#     result = workflow.invoke(initial_state)
#     return result

# def create_initial_state(claim_details=None):
#     """Create initial state for a new claim."""
#     claim_id = str(uuid.uuid4())
#     timestamp = datetime.datetime.now().isoformat()

#     initial_state = ClaimState(
#         claim_id=claim_id,
#         status="new",
#         messages=[],
#         claimant_info={},
#         employer_info={},
#         medical_info={},
#         verification_status={},
#         decision_status={},
#         payment_info={},
#         recovery_status={},
#         finished=False
#     )

#     # If claim details provided, update initial state
#     if claim_details:
#         for key, value in claim_details.items():
#             if key in initial_state:
#                 initial_state[key] = value

#     return initial_state