In [2]:
import getpass
import os
from dotenv import load_dotenv

def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("ANTHROPIC_API_KEY")
_set_env("TAVILY_API_KEY")
_set_env("GROQ_API_KEY")

load_dotenv()

True

In [3]:
import os
import shutil
import sqlite3

local_file = "insurance.db"
# The backup lets us restart for each tutorial section
backup_file = "insurance.backup.db"

if os.path.exists(local_file):
# Backup - we will use this to "reset" our DB in each section
    shutil.copy(local_file, backup_file)

db = local_file

### Tool

#### Lookup Company Policy

In [1]:
import re
import numpy as np
import os
import openai
from langchain_core.tools import tool
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

from langchain_community.document_loaders import UnstructuredMarkdownLoader

# Load data from MD file
# with open('Implement_Customer_Service_in_Insurance.md', 'r') as f:
#     faq_text = f.read()

# faq_docs = [{"page_content": txt} for txt in re.split(r"(?=\n##)", faq_text)]

docs = "./insurance_faq.md"
embeder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
FAISS_INDEX_PATH = "insurance_faiss_index"

# Load or create FAISS index
def load_or_create_faiss_index():
    embeddings = embeder
    if os.path.exists(FAISS_INDEX_PATH):
        Vdb = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization = True)
        print("Loaded vectors from FAISS index.")
    else:
        # Vectorise the sales response csv data
        loader = UnstructuredMarkdownLoader(
            docs,
            mode="single",
            strategy="fast",)
        documents = loader.load()
        Vdb = FAISS.from_documents(documents, embeddings)
        Vdb.save_local(FAISS_INDEX_PATH)
        print("Created and saved vectors to FAISS index.") 
    return Vdb

# Call the function to load or create the FAISS index
Vdb = load_or_create_faiss_index()

def retrieve_info(query):
    similar_response = Vdb.similarity_search(query, k=3)
    page_contents_array = "\n\n".join([doc.page_content for doc in similar_response])
    # print(page_contents_array)

    return page_contents_array

@tool
def lookup_policy(query: str) -> str:
    """Consult the company policies to check whether certain options are permitted.
    Use this before making any insurance changes performing other 'write' events."""
    policy_info = retrieve_info(query)
    return policy_info


  from tqdm.autonotebook import tqdm, trange


Loaded vectors from FAISS index.


In [4]:
from langchain_groq import ChatGroq
from langchain_core.prompts import PromptTemplate

llm = ChatGroq(model_name = "llama-3.1-8b-instant", temperature=1)

template = """
You are a world class question and answering agent and an intuitive researcher on 
insurance companies policy.
I will share a clients query with you and you will give the best answer based on the 
company policy and best practices.

Below is a message I received from the client:
{query}

Here is a list of best practies of how we normally respond to prospect in similar scenarios:
{company_policy}

Please write the best response for this client.
"""

prompt = PromptTemplate(
    input_variables=["query", "company_policy"],
    template=template
)

chain = prompt | llm

In [5]:
query = "I want to update my information"
policy_info = lookup_policy(query)

input_data = {
    "query": query,
    "company_policy": policy_info
}

# Generate the response using the LLM chain
response = chain.invoke(input= input_data)
print(response.content)

  policy_info = lookup_policy(query)


Dear valued client,

Thank you for reaching out to update your information. We're happy to assist you with any changes to your policy. To update your personal details, such as your address or contact information, you can log in to your account on our website. Simply navigate to the "Account Settings" section and make the required changes.

If you need to update more sensitive information like your name or ID, please contact our customer support team, and we will guide you through the process. We may need to verify the changes with the necessary documentation, so please be prepared to provide the required documents.

If you're unsure about what changes can be made or need assistance, our customer support team is always here to help. You can reach us by phone, email, or through our live chat feature on our website. We're available Monday through Friday from 8:00 AM to 6:00 PM.

To proceed with updating your information, please provide us with the specific changes you'd like to make, and 

#### Get Policy Info

In [28]:
import sqlite3
from typing import Optional, Dict
from langchain_core.runnables import RunnableConfig

# Connect to the SQLite database
def connect_db(db_file: str = "insurance.db") -> sqlite3.Connection:
    """Establishes a connection to the SQLite database."""
    try:
        conn = sqlite3.connect(db_file)
        return conn
    except sqlite3.Error as e:
        print(f"Error connecting to database: {e}")
        return None

# Retrieve policy information based on policy number or customer ID
# @tool
# def get_policy_info(config: RunnableConfig) -> Dict:
def view_policy_info(customer_id: Optional[int] = None) -> Dict:
    """Fetches the policy details from the database.
    
    Returns:
    A list of dictionaries where each dictionary contains the policy details of the user.
    
    """
    
    # configuration = config.get("configurable", {})
    # policy_number = configuration.get("policy_number", None)
    # customer_id = configuration.get("customer_id", None)

    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}

    query = "SELECT * FROM customer_policy WHERE "
    params = ()

    if customer_id:
        query += "customer_id = ?"
        params = (customer_id,)
    else:
        return {"error": "No policy number or customer ID provided"}

    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        rows = cursor.fetchall()

        if rows:
            column_names = [column[0] for column in cursor.description]
            policy_info = [dict(zip(column_names, row)) for row in rows]
            # policy_info = {
            #     "policy_id": policy[0],
            #     "policy_number": policy[1],
            #     "customer_id": policy[2],
            #     "policy_type": policy[3],
            #     "coverage_details": policy[4],
            #     "start_date": policy[5],
            #     "end_date": policy[6],
            #     "status": policy[7],
            # }
            return policy_info
        else:
            return {"error": "Policy not found"}
    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()

# Example of how to call the function
policy_data = view_policy_info(customer_id="C001")
print(policy_data)


[{'customer_policy_id': 1, 'customer_id': 'C001', 'policy_id': 'P1001', 'start_date': '2023-01-01', 'end_date': '2024-12-26', 'monthly_premium': 100.0, 'policy_status': 'active', 'policy_type': 'Health'}]


#### Claims Management Tool

In [26]:
import sqlite3
from typing import Optional, Dict, List

# Create a new claim
# @tool
# def create_claim(config: RunnableConfig) -> Dict:
def create_claim(policy_number: str, customer_id: str, claim_date: str, claim_amount: float, description: str, claim_type: str) -> Dict:
    """Creates a new claim in the database."""
    
    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}
    
    try:
        cursor = conn.cursor()
        query = """
            INSERT INTO claims (policy_id, customer_id, claim_date, claim_amount, claim_status, description, claim_type)
            VALUES (?, ?, ?, ?, 'Pending', ?, ?)
        """
        cursor.execute(query, (policy_number, customer_id, claim_date, claim_amount, description, claim_type))
        conn.commit()

        return {"success": "Claim created successfully", "claim_id": cursor.lastrowid}
    
    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()

# Update claim status
# Use RunnableConfig
def update_claim_status(claim_id: int, new_status: str) -> Dict:
    """Updates the status of a claim."""
    
    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}
    
    try:
        cursor = conn.cursor()
        query = "UPDATE claims SET claim_status = ? WHERE claim_id = ?"
        cursor.execute(query, (new_status, claim_id))
        conn.commit()

        if cursor.rowcount == 0:
            return {"error": "Claim not found or status not updated"}
        
        return {"success": "Claim status updated successfully"}
    
    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()

# Retrieve claim details by claim_id or policy_number
def get_claim_info(claim_id: Optional[int] = None, customer_id: Optional[str] = None) -> List[Dict]:
    """Fetches claim details from the database."""
    
    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}
    
    query = "SELECT * FROM claims WHERE "
    params = ()

    if claim_id:
        query += "claim_id = ?"
        params = (claim_id,)
    elif customer_id:
        query += "customer_id = ?"
        params = (customer_id,)
    else:
        return {"error": "No claim_id or customer_id provided"}

    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        claims = cursor.fetchall()

        if claims:
            
            column_names = [column[0] for column in cursor.description]
            claim_list = [dict(zip(column_names, claim)) for  claim in claims]
            # claim_list = []
            # for claim in claims:
            #     claim_info = {
            #         "claim_id": claim[0],
            #         "policy_number": claim[1],
            #         "customer_id": claim[2],
            #         "claim_date": claim[3],
            #         "claim_amount": claim[4],
            #         "claim_status": claim[5],
            #         "description": claim[6],
            #     }
            #     claim_list.append(claim_info)
            return claim_list
        else:
            return {"error": "Claim not found"}
    
    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()


In [48]:
# Example usage
# Creating a new claim
# new_claim = create_claim(policy_number="P1002", customer_id= "C002", claim_date="2023-09-16", claim_amount=5000.0, description="Car accident")
# print(new_claim)

# # Updating a claim status
# status_update = update_claim_status(claim_id=1, new_status="Approved")
# print(status_update)

# Retrieving claim info by policy number
claim_info = get_claim_info(customer_id="C002")
print(claim_info)

[{'claim_id': 1, 'policy_id': 'P1002', 'customer_id': 'C002', 'claim_date': '2023-09-16', 'claim_amount': 5000.0, 'claim_status': 'Pending', 'description': 'Car accident'}]


#### Policy Renewal and Cancellation

In [7]:
from datetime import datetime, timedelta


# Policy renewal
def renew_policy(customer_id: str, policy_id: str, additional_months: int = 1) -> dict:
    """Renews a policy by extending its end date."""
    
    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}

    try:    
        cursor = conn.cursor()

        # Fetch the current policy details

        query = "SELECT start_date, end_date, policy_status FROM customer_policy WHERE customer_id = ? AND policy_id = ? "
        cursor.execute(query, (customer_id, policy_id,))
        policy = cursor.fetchone()

        if policy:
            start_date, end_date, policy_status = policy

            # Parse the end_date to a datetime object
            end_date = datetime.strptime(end_date, '%Y-%m-%d') if end_date else datetime.now()

            if policy_status == 'active' or policy_status == 'expired':
                # Extend the end date by the specified number of months
                new_end_date_obj = end_date + timedelta(days=30 * additional_months)

                # Update the policy details
                update_query = "UPDATE customer_policy SET end_date = ?, policy_status = 'active' WHERE customer_id = ? AND policy_id = ?"
                new_end_date = new_end_date_obj.strftime("%Y-%m-%d")

                cursor.execute(update_query, (new_end_date, customer_id, policy_id,))

                conn.commit()
                print(f"Policy {policy_id} for customer {customer_id} has been renewed until {new_end_date}.")
            else:
                return{"error": "Cannot renew a canceled policy."}
        else:
            return {"error": "Policy not found for the customer."}

    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()

# Policy cancellation
def cancel_policy(customer_id: str, policy_id: str):
    """Cancels a policy by updating its status to 'Cancelled'."""


    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}

    try:    
        cursor = conn.cursor()

        query = "SELECT policy_status FROM customer_policy WHERE customer_id = ? AND policy_id = ?"
        # Check if the policy exists and is active
        cursor.execute(query, (customer_id, policy_id))
        policy_status = cursor.fetchone()

        if policy_status and policy_status[0] == 'active':
            # Cancel the policy by updating the status and end_date
            current_date = datetime.now().strftime('%Y-%m-%d')
            update_query = "UPDATE customer_policy SET policy_status = 'canceled', end_date = ? WHERE customer_id = ? AND policy_id = ?"
            cursor.execute(update_query, (current_date, customer_id, policy_id))

            conn.commit()
            return{"error": f"Policy {policy_id} for customer {customer_id} has been canceled."}
        else:
            return{"error": "Policy either doesn't exist or is not active."}

    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()

In [13]:
# Renewing a policy for customer 1, policy 101, extending it by 6 months
renew_policy(customer_id="C001", policy_id="P1001", additional_months=6)

# # Canceling a policy for customer 1, policy 101
# cancel_policy(customer_id=1, policy_id=101)

Policy P1001 for customer C001 has been renewed until 2024-12-26.


#### Customer Profile Management

In [16]:

def view_customer_profile(customer_id: str) -> dict:
    """Fetch and return a customer's profile details."""
    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}
    
    query = "SELECT * FROM customers WHERE "
    params = ()

    if customer_id:
        query += "customer_id = ?"
        params = (customer_id,)
    else:
        return {"error": "No customer ID provided"}

    try:
        cursor = conn.cursor() 
        cursor.execute(query, params)
        customer = cursor.fetchone() # Fetch only one result
        
        if customer:
            column_names = [column[0] for column in cursor.description]
            customer_info = dict(zip(column_names, customer)) # Return single customer info as dict
            return customer_info
        else:
            return {"error": "Customer not found."}

    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()

def update_customer_profile(customer_id: str, name: str = None, email: str = None, phone: str = None, address: str = None) -> dict:
    """Update a customer's profile details."""
    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}

    try:
        cursor = conn.cursor()
        # Fetch current details
        current_profile = view_customer_profile(customer_id)
        if "error" in current_profile:
            return current_profile
        
        # Update only the provided fields, keep others the same
        update_query = """
            UPDATE customers 
            SET name = ?, email = ?, phone = ?, address = ? 
            WHERE customer_id = ?
        """
        updated_name = name if name else current_profile["name"]
        updated_email = email if email else current_profile["email"]
        updated_phone = phone if phone else current_profile["phone"]
        updated_address = address if address else current_profile["address"]

        cursor.execute(update_query, (updated_name, updated_email, updated_phone, updated_address, customer_id))
        conn.commit()
        
        return {"success": f"Customer {customer_id}'s profile has been updated."}

    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()

def delete_customer_profile(customer_id: str) -> dict:
    """Delete a customer's profile from the database."""
    conn = connect_db()
    if conn is None:
        return {"error": "Failed to connect to the database"}
    
    try:
        cursor = conn.cursor()

        # Check if customer exists
        existing_customer = view_customer_profile(customer_id)
        if "error" in existing_customer:
            return existing_customer
        
        # Proceed with deletion
        delete_query = "DELETE FROM customers WHERE customer_id = ?"
        cursor.execute(delete_query, (customer_id,))
        conn.commit()
        
        return {"success": f"Customer {customer_id}'s profile has been deleted."}
    
    except sqlite3.Error as e:
        return {"error": f"Database error: {e}"}
    finally:
        conn.close()

In [18]:
# Example usage
# View a customer profile
# view_customer_profile(customer_id = "C001")

# # Update a customer profile
print(update_customer_profile(customer_id="C001", email="newemail@example.com", phone="555-123-4567"))

# # Delete a customer profile
# print(delete_customer_profile("C001"))


{'success': "Customer C001's profile has been updated."}


#### Utilities

In [19]:
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda

# from langchain_core.messages.modifier import RemoveMessage

from langgraph.prebuilt import ToolNode


def handle_tool_error(state) -> dict:
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }


def create_tool_node_with_fallback(tools: list) -> dict:
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


def _print_event(event: dict, _printed: set, max_length=1500):
    current_state = event.get("dialog_state")
    if current_state:
        print("Currently in: ", current_state[-1])
    message = event.get("messages")
    if message:
        if isinstance(message, list):
            message = message[-1]
        if message.id not in _printed:
            msg_repr = message.pretty_repr(html=True)
            if len(msg_repr) > max_length:
                msg_repr = msg_repr[:max_length] + " ... (truncated)"
            print(msg_repr)
            _printed.add(message.id)

### Specialised Workflows

In [25]:
from typing import Annotated, Literal, Optional
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages

# Define the state management for the insurance dialog
def update_dialog_stack(left: list[str], right: Optional[str]) -> list[str]:
    """Push or pop the state in the customer service automation."""
    if right is None:
        return left
    if right == "pop":
        return left[:-1]
    return left + [right]


class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]  # Holds the conversation history
    user_info: str  # Holds the customer's information
    dialog_state: Annotated[
        list[
            Literal[
                "assistant",
                "file_claim",
                "update_policy",
                "update_customer_info",
            ]
        ],
        update_dialog_stack,
    ]  # The state of the current interaction



In [24]:
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableConfig
# from langchain_groq import ChatGroq

# llm = ChatAnthropic(model="claude-3-sonnet-20240229", temperature=1)

# llm = ChatGroq(model_name = "llama-3.1-8b-instant", temperature=1)

class Assistant:
    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        while True:
            result = self.runnable.invoke(state)

            if not result.tool_calls and (
                not result.content
                or isinstance(result.content, list)
                and not result.content[0].get("text")
            ):
                messages = state["messages"] + [("user", "Respond with a real output.")]
                state = {**state, "messages": messages}
                messages = state["messages"] + [("user", "Respond with a real output.")]
                state = {**state, "messages": messages}
            else:
                break
        return {"messages": result}


class CompleteOrEscalate(BaseModel):
    """A tool to mark the current task as completed and/or to escalate control of the dialog to the main assistant,
    who can re-route the dialog based on the user's needs."""

    cancel: bool = True
    reason: str

    class Config:
        schema_extra = {
            "example": {
                "cancel": True,
                "reason": "User changed their mind about the current task.",
            },
            "example 2": {
                "cancel": True,
                "reason": "I have fully completed the task.",
            },
            "example 3": {
                "cancel": False,
                "reason": "I need to search the user's emails or calendar for more information.",
            },
        }


In [None]:
# Insurance Claim Assistant
claim_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a specialized assistant for managing insurance claims. "
            "The primary assistant delegates work to you whenever the user needs help with filing or managing an insurance claim. "
            "Guide the user through the process of submitting, checking the status, updating, or canceling a claim. "
            "Be persistent in collecting relevant claim details and provide clear instructions to the customer."
            " Escalate back to the primary assistant if the user requires additional support, such as understanding policy terms."
            " If you need more information or the customer changes their mind, escalate the task back to the main assistant."
            " Remember that a claim isn't successfully submitted until confirmation is received from the claims system."
            "\nCurrent time: {time}."
            '\n\nIf the user needs help, and none of your tools are appropriate for it, then "CompleteOrEscalate" the dialog to the host assistant.'
            " Do not waste the user's time. Do not make up invalid tools or functions."
            "\n\nSome examples for which you should CompleteOrEscalate:\n"
            " - 'How do I change my policy?' \n"
            " - 'What's covered under my policy?' \n"
            " - 'I want to know more about my deductible' \n"
            " - 'Nevermind, I’ll check my policy documents first'\n"
            " - 'Claim successfully filed'",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now())

# Tools
claim_safe_tools = [get_claim_info]
claim_sensitive_tools = [create_claim, update_claim_status] # Cancel/Delete claim
claim_tools = claim_safe_tools + claim_sensitive_tools

claim_runnable = claim_assistant_prompt | llm.bind_tools(
    claim_tools + [CompleteOrEscalate]
)


# Insurance Policy Assistant
policy_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a specialized assistant for managing insurance policies. "
            "The primary assistant delegates tasks to you whenever the user needs help with managing or reviewing their insurance policies. "
            "You are responsible for providing policy details, assisting with policy renewals, cancellations, and answering policy-related queries. "
            "If there are no results or you need additional information, escalate the task back to the main assistant."
            " Remember that policy management actions are not completed until the relevant system confirms them."
            "\nCurrent time: {time}."
            '\n\nIf the user asks for help, and your tools are not applicable, then "CompleteOrEscalate" the dialog back to the host assistant.'
            " Do not waste the user's time. Do not make up invalid tools or functions."
            "\n\nSome examples for which you should CompleteOrEscalate:\n"
            " - 'Can you help me understand coverage for accidents?'\n"
            " - 'I want to know more about policy premiums'\n"
            " - 'I'm thinking of switching insurance companies'\n"
            " - 'How do I file a claim?' \n"
            " - 'Policy successfully renewed or canceled'",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now())

# Tools
policy_safe_tools = [view_policy_info] # view policy details / list available policies
policy_sensitive_tools = [renew_policy, cancel_policy] # update_policy, create policy
policy_tools = policy_safe_tools + policy_sensitive_tools

policy_runnable = policy_assistant_prompt | llm.bind_tools(
    policy_tools + [CompleteOrEscalate]
)


# Insurance Customer Assistant
customer_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a specialized assistant for managing customer profiles and related tasks for an insurance company. "
            "The primary assistant delegates tasks to you whenever the user needs help with updating or reviewing customer profile details. "
            "You are responsible for managing customer information, such as personal details, contact information, and their enrolled policies. "
            "If more information is needed or the customer requests something outside your tools, escalate the task back to the main assistant."
            " Remember that profile updates aren't complete until they have been confirmed by the relevant system."
            "\nCurrent time: {time}."
            '\n\nIf the user asks for help, and none of your tools are applicable, then "CompleteOrEscalate" the dialog to the host assistant.'
            " Do not waste the user's time. Do not make up invalid tools or functions."
            "\n\nSome examples for which you should CompleteOrEscalate:\n"
            " - 'What's the best insurance policy for my family?'\n"
            " - 'I want to book an appointment with an agent.'\n"
            " - 'I need help choosing between Auto and Home insurance.'\n"
            " - 'I'll update my profile details later'\n"
            " - 'How do I cancel my claim?'",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now())

# Tools
customer_safe_tools = [view_customer_profile] # list_customer_policies
customer_sensitive_tools = [update_customer_profile, delete_customer_profile]
customer_tools = customer_safe_tools + customer_sensitive_tools

customer_runnable = customer_assistant_prompt | llm.bind_tools(
    customer_tools + [CompleteOrEscalate]
)





In [None]:
#Primary Assistant

class ToClaimsAssistant(BaseModel):
    """Transfers work to a specialized assistant to handle insurance claims management."""

    customer_id: str = Field(
        description="The ID of the customer filing or inquiring about a claim."
    )
    policy_number: str = Field(
        description="The ID of the policy under which the claim is being filed or managed."
    )
    claim_type: str = Field(
        description="The type of claim (e.g., auto accident, health issue, property damage)."
    )
    claim_date: str = Field(description="The date the incident related to the claim occurred.")
    claim_details: str = Field(
        description="Any additional information or requests regarding the claim."
    )

    class Config:
        schema_extra = {
            "example": {
                "customer_id": "C001",
                "policy_id": "P1002",
                "claim_type": "Auto Accident",
                "claim_date": "2024-09-15",
                "claim_details": "Car accident involving rear-end collision. I need assistance with filing the claim.",
            }
        }


class ToPolicyAssistant(BaseModel):
    """Transfers work to a specialized assistant to handle policy-related tasks such as renewals, cancellations, or inquiries."""

    customer_id: str = Field(
        description="The ID of the customer making the inquiry or request regarding the policy."
    )
    policy_id: str = Field(
        description="The ID of the policy in question."
    )
    request_type: str = Field(
        description="The type of request the customer is making (e.g., renewal, cancellation, inquiry)."
    )
    request_details: str = Field(
        description="Additional details or specific requests related to the policy (optional).", 
        default=None
    )

    class Config:
        schema_extra = {
            "example": {
                "customer_id": "C001",
                "policy_id": "P1003",
                "request_type": "Renewal",
                "request_details": "I would like to renew my auto insurance policy for another 6 months.",
            }
        }


class ToCustomerAssistant(BaseModel):
    """Transfers work to a specialized assistant to handle customer profile management, updates, or inquiries."""

    customer_id: str = Field(
        description="The ID of the customer making the request or inquiry."
    )
    request_type: str = Field(
        description="The type of request the customer is making (e.g., profile update, contact information change, inquiry)."
    )
    request_details: str = Field(
        description="Additional details or specific requests related to the customer's profile (optional).", 
        default=None
    )

    class Config:
        schema_extra = {
            "example": {
                "customer_id": "C001",
                "request_type": "Profile Update",
                "request_details": "I want to update my phone number and address.",
            }
        }


primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful customer support assistant for an insurance company. "
            "Your primary role is to search for policy information, claims status, and customer profile data to answer customer queries. "
            "If a customer requests to update or cancel their policy, file a claim, or update their profile information, "
            "delegate the task to the appropriate specialized assistant by invoking the corresponding tool. You are not able to make these types of changes yourself. "
            "Only the specialized assistants are given permission to do this for the user. "
            "The user is not aware of the different specialized assistants, so do not mention them; just quietly delegate through function calls. "
            "Provide detailed information to the customer, and always double-check the database before concluding that information is unavailable. "
            "When searching, be persistent. Expand your query bounds if the first search returns no results. "
            "If a search comes up empty, expand your search before giving up."
            "\n\nCurrent user policy and claims information:\n<Policies>\n{user_info}\n</Policies>"
            "\nCurrent time: {time}.",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now())

primary_assistant_tools = [
    search_policies,  # To search for available insurance policies.
    lookup_policy,
]

assistant_runnable = primary_assistant_prompt | llm.bind_tools(
    primary_assistant_tools
    + [
        ToPolicyAssistant,  # Delegate to the policy assistant for renewals or cancellations.
        ToClaimsAssistant,  # Delegate to the claims assistant for filing or updating claims.
        ToCustomerAssistant,  # Delegate to the customer assistant for profile updates.
    ]
)


### Utility

In [None]:
from typing import Callable

from langchain_core.messages import ToolMessage


def create_insurance_entry_node(assistant_name: str, new_dialog_state: str) -> Callable:
    """Creates an entry node for the insurance assistant, handling specialized tasks such as policies, claims, or customer management."""
    def entry_node(state: dict) -> dict:
        tool_call_id = state["messages"][-1].tool_calls[0]["id"]
        return {
            "messages": [
                ToolMessage(
                    content=(
                        f"The assistant is now the {assistant_name}. Review the conversation between the host assistant and the user."
                        f" The user's insurance-related request has not yet been satisfied. As {assistant_name}, use the appropriate tools to assist the user."
                        " Remember, you are handling insurance tasks (policy management, claims processing, or customer profile updates),"
                        " and the request is not complete until after the correct tool has been successfully invoked."
                        " If the user changes their mind or needs help for tasks outside your expertise, call the CompleteOrEscalate function to let the primary assistant take over."
                        " Do not mention who you are - just act as the proxy for the insurance assistant."
                    ),
                    tool_call_id=tool_call_id,
                )
            ],
            "dialog_state": new_dialog_state,
        }

    return entry_node