In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["TAVILY_API_KEY"] = "<>"

import uuid
import json
from datetime import datetime

from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompts import ChatPromptTemplate

from database import Database
from policy import Policy
from online_search import PersianTavilySearchTool
from flight import FlightManager
from llm_translation import translate_to_persian


class ToolMessage(HumanMessage):
    """Ollama does not support `tool` role and `ToolMessage`"""

    def pretty_repr(self, html: bool = False) -> str:
        title = get_msg_title_repr("Tool" + " Message", bold=html)
        # TODO: handle non-string content.
        if self.name is not None:
            title += f"\nName: {self.name}"
        return f"{title}\n\n{self.content}"


def get_tool_description(tool):
    tool_params = [
        f"{name}: {info['type']} ({info['description']})"
        for name, info in tool.args.items()
    ]
    tool_params_string = ', '.join(tool_params)
    return (
        f"tool_name -> {tool.name}\n"
        f"tool_params -> {tool_params_string}\n"
        f"tool_description ->\n{tool.description}"
    )


llm = ChatOllama(model='llama3', num_ctx=8192, num_thread=8, temperature=0.0)

database = Database(data_dir="storage/database")

policy = Policy(data_dir="storage/policy", embedding=OllamaEmbeddings(model='llama3'))
policy_tools = policy.get_tools()

flight_manager = FlightManager(database)
flight_tools = flight_manager.get_tools()

tools = [
    PersianTavilySearchTool(max_results=3, llm=llm),
] + list(policy_tools.values()) + list(flight_tools.values())

tool_descs = '\n\n'.join([get_tool_description(tool) for tool in tools])

In [3]:
prompt_template = \
"""
You are a helpful Persian customer support assistant for Iran Airlines.
Use the provided tools to search for flights, company policies, and other information to assist the user's queries. 
When searching, be persistent. Expand your query bounds if the first search returns no results. 
If a search comes up empty, expand your search before giving up.

You have access to the following tools to get more information if needed:

{tool_descs}

You also have access to the history of privious messages.

Generate the response in the following json format:
{{
    "THOUGHT": "<you should always think about what to do>",
    "ACTION": "<the action to take, must be one tool_name from above tools>",
    "ACTION_PARAMS": "<the input parameters to the ACTION, it must be in json format complying with the tool_params>"
    "FINAL_ANSWER": "<a text containing the final answer to the original input question>",
}}
If you don't know the answer, you can take an action using one of the provided tools.
But if you do, don't take and action and leave the action-related attributes empty.
The values `ACTION` and `FINAL_ANSWER` can never ever be filled at the same time.

Always make sure that your output is a json complying with above format.
Do NOT add anything before or after the json response.

Current user:\n<User>\n{user_info}\n</User>
Current time: {time}.
"""

In [4]:
from typing import Dict, Annotated, Any, Optional, Sequence
from typing_extensions import TypedDict

import warnings

from langchain_core.messages import AIMessage, AnyMessage, ToolCall
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.utils import RunnableCallable

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from langgraph.prebuilt.tool_node import tools_condition, str_output
from langgraph.graph.message import AnyMessage, Messages, add_messages
from langchain_core.runnables import Runnable, RunnableConfig


class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]


class Assistant:
    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        while True:
            result = self.runnable.invoke(state['messages'], config)
            try:
                content_json = json.loads(result.content)
                break
            except ValueError:
                warnings.warn('BAD FORMAT: ' + result.content)
                state['messages'] += [result, HumanMessage("Respond with a json output!")]

        action = content_json.get('ACTION', '').replace(' ', '')
        action_params = content_json.get('ACTION_PARAMS') or {}
        if type(action_params) is str:
            action_params = json.loads(action_params)
        final_answer = content_json.get('FINAL_ANSWER')

        if action:
            tool_call = ToolCall(name=action, args=action_params, id=str(uuid.uuid4()))
            result.tool_calls.append(tool_call)
            return {'messages': result}

        if not final_answer:
            persian_final_answer = ""
        else:
            persian_final_answer = translate_to_persian(final_answer, self.runnable)

        final_result = AIMessage(persian_final_answer)

        return {'messages': [result, final_result]}


class ToolNode(RunnableCallable):

    def __init__(
        self,
        tools: Sequence[BaseTool],
        *,
        name: str = 'tools',
        tags: Optional[list[str]] = None,
    ) -> None:
        super().__init__(self._func, None, name=name, tags=tags, trace=False)
        self.tools_by_name = {tool.name: tool for tool in tools}

    def _func(
        self, input: dict[str, Any], config: RunnableConfig
    ) -> Any:
        message: AnyMessage = input['messages'][-1]
        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")

        def run_one(call: ToolCall):
            output = self.tools_by_name[call["name"]].invoke(call["args"], config)
            tool_prompt = (
                "Here is the tool results:\n\n" +
                str_output(output)
            )
            return ToolMessage(
                content=tool_prompt, name=call["name"], tool_call_id=call["id"]
            )

        return {"messages": run_one(message.tool_calls[0])}


builder = StateGraph(State)

builder.add_node('assistant', Assistant(llm))
builder.add_node('action', ToolNode(tools))
builder.set_entry_point('assistant')
builder.add_conditional_edges(
    'assistant',
    tools_condition,
    {'action': 'action', END: END},
)
builder.add_edge('action', 'assistant')

memory = SqliteSaver.from_conn_string(':memory:')
graph = builder.compile(checkpointer=memory)

In [7]:
def _print_event(event: dict, _printed: set, ignore_first_system_message: bool = True):
    current_state = event.get('dialog_state')
    if current_state:
        print(f"Currently in: ", current_state[-1])
    messages = event.get('messages')
    if messages:
        if ignore_first_system_message:
            if isinstance(messages[0], SystemMessage):
                messages = messages[1:]
        for message in messages:
            if message.id not in _printed:
                msg_repr = message.pretty_repr(html=True)
                print(msg_repr)
                _printed.add(message.id)


tutorial_questions = [
    # "ساعت الان چنده؟",
    # "قوانین تهیه بلیط هواپیما چیه؟",
    # "سلام. اطلاعات های مربوط به پروازم رو میخواستم دریافت کنم.",
    # "اطلاعات پروازهای فردا رو میخواستم",
    # "لیست پروازهای به مقصد BSL رو میخواستم",
    # "پروازم رو به هفته آینده انتقال بده",
    # "هزینه اضافه بار برای مسافرت خارجی چقدر است؟",
    "امکانش هست تیکتم رو به شماره ۷۲۴۰۰۰۵۴۳۲۹۰۶۵۶۹ کنسل کنم؟",  # number isuue
    # "قیمت بلیط برای کودکان چقدره؟",  # confusion user/tool
    # "میخواستم تیکتم رو کنسل کنم",  # ???
    # "چه مدت قبل از پرواز هواپیما میشه بلیط تهیه کرد؟",  # hallucinate
    # "حداکثر برای چند نفر میتوان بلیط تهیه کرد؟",  # update/cancel ticket
]

config = {
    'configurable': {
        'passenger_id': '3442 587242',
        'thread_id': str(uuid.uuid4()),
    }
}

database.reset_and_prepare()


_printed = set()
for question in tutorial_questions:
    system_message = SystemMessage(prompt_template.format(
        tool_descs=tool_descs, time=datetime.now(),
        user_info=f"passenger_id: {config['configurable']['passenger_id']}"
    ))
    user_message = HumanMessage(question)

    events = graph.stream(
        {'messages': [system_message, user_message]}, config, stream_mode='values'
    )
    for event in events:
        _print_event(event, _printed)


امکانش هست تیکتم رو به شماره ۷۲۴۰۰۰۵۴۳۲۹۰۶۵۶۹ کنسل کنم؟

{ "THOUGHT": "Let's check if the ticket can be cancelled and what are the possible actions.", 
"ACTION": "cancel_ticket_tool", 
"ACTION_PARAMS": {"ticket_no": "72400005432906569"}, 
"FINAL_ANSWER": "" }
Tool Calls:
  cancel_ticket_tool (a8c73f45-d081-4ec4-82ee-f5464fe0ec44)
 Call ID: a8c73f45-d081-4ec4-82ee-f5464fe0ec44
  Args:
    ticket_no: 72400005432906569
Name: cancel_ticket_tool

Here is the tool results:

No existing ticket found for the given `ticket_no`.

{ "THOUGHT": "The ticket does not exist, so it cannot be cancelled.", 
"ACTION": "", 
"FINAL_ANSWER": "There is no ticket with the number 72400005432906569 to cancel." }

هیچ بلیت با شماره ۷۲۴۰۰۰۰۵۳۲۹۰۶۵۶۹ برای لغو وجود ندارد.
