In [1]:
# Import
import os
import sys
import uuid
from typing import Optional

# Change to the project root directory so `src.*` imports and relative data
# paths (e.g. 'data/Movie') resolve. Set the PROJECT_ROOT environment variable
# to override the author-specific default instead of editing this cell.
PROJECT_ROOT = os.environ.get("PROJECT_ROOT", "/home/alaa/repos/seez-assignment")
os.chdir(PROJECT_ROOT)
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Sanity check that the project package is importable from the new root.
try:
    from src.utils.tools import read_dialogue
    print("Import successful!")
except ImportError as e:
    # Only import problems are expected here; any other error should surface.
    print(f"Failed: {e}")

# Credentials: keep an OPENAI_API_KEY that is already exported in the
# environment, and only fall back to the project-local src/api_key.py module
# when it is unset (keep that module out of version control).
if not os.environ.get("OPENAI_API_KEY"):
    from src.api_key import OPENAI_API_KEY

    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

Import successful!


In [2]:
from typing import Annotated, Dict
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages

def _merge_unique(existing: list[str], new: list[str]) -> list[str]:
    """Reducer for ``completed_nodes``: concatenate and drop duplicates.

    Unlike ``list(set(...))``, this keeps first-seen order, so the merged list
    is deterministic across runs (set ordering of strings is not).
    """
    return list(dict.fromkeys(existing + new))


class State(TypedDict):
    """Shared graph state passed between the router and assistant nodes.

    TypedDict fields cannot carry default values (PEP 589). ``State()`` is just
    an empty dict, so readers should use ``state.get(...)`` for optional keys.
    """

    # Conversation history, merged with the `add_messages` reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    # Names of nodes that have finished; the Router excludes these from routing.
    completed_nodes: Annotated[list[str], _merge_unique]
    # Node name chosen by the most recent Router call.
    latest_router_decision: str
    # Written by fetch_user_information: {"user_data": ..., "user_exists": bool}.
    user_context: Dict
    # Per-session values merged into assistant system prompts.
    session_data: Dict
    # Company-level values merged into assistant system prompts
    # (read by AssistantNode.__call__).
    company_profile: Dict

In [3]:
from typing import List, Callable, Dict, Any

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig

model_name = "gpt-4o"  # default chat model for AssistantNode instances

class AssistantNode:
    """LangGraph node that replies with an LLM using a context-filled system prompt.

    Each call formats ``system_prompt`` with the merged ``company_profile``,
    ``user_context`` and ``session_data`` dicts from the state. It then runs the
    (optionally tool-bound) LLM on the conversation and returns the reply as a
    ``{"messages": ...}`` update.
    """

    def __init__(
            self,
            name: str, system_prompt: str,
            tools: Optional[List[Callable]] = None,
            completion_tool: str = "",
            llm_chain: Optional[Runnable] = None,
            model_name=model_name,
            max_empty_retries: int = 3,
            ) -> None:
        """
        Args:
            name: Node name, shown in the progress output.
            system_prompt: ``str.format`` template. Its placeholders are filled
                from the merged state context on every call.
            tools: Tools bound to the LLM with ``bind_tools``. ``None`` (the
                default) or an empty list means no tools. The list is copied,
                so instances never share a mutable default.
            completion_tool: Name of the tool that, once triggered, marks this
                node as completed (its ``name`` is appended to the State's
                ``completed_nodes``). NOTE(review): this class only stores the
                value; ``__call__`` does not update ``completed_nodes`` itself,
                so that presumably happens elsewhere in the graph. Confirm.
            llm_chain: Chat model used instead of a new
                ``ChatOpenAI(model=model_name)``. Must support ``bind_tools``
                when ``tools`` are given.
            model_name: Model name used when ``llm_chain`` is not supplied.
            max_empty_retries: How many times to re-prompt after an empty reply
                (no text and no tool calls) before returning it anyway.
        """
        self.name = name
        self.system_prompt = system_prompt
        self.tools = list(tools) if tools else []
        self.completion_tool = completion_tool
        self.llm = llm_chain or ChatOpenAI(model=model_name)
        self.max_empty_retries = max_empty_retries

    def update_system_prompt(self, prompt_config: dict) -> str:
        """Fill ``self.system_prompt`` placeholders from ``prompt_config``.

        Returns the raw template when ``prompt_config`` is empty or not a dict.
        It also returns the raw template when formatting fails: a missing key
        (KeyError), a positional field (IndexError) or an unbalanced brace
        (ValueError).
        """
        print(f"Updating system prompt for {self.name}...")
        if not (prompt_config and isinstance(prompt_config, dict)):
            print(f"Using default prompt for {self.name},no prompt configs!")
            return self.system_prompt
        try:
            formatted_prompt = self.system_prompt.format_map(prompt_config)
        except (KeyError, IndexError, ValueError) as e:
            print(f"Error formatting prompt for {self.name}: {e}")
            return self.system_prompt
        print(f"Using prompt config for {self.name}")
        return formatted_prompt

    @staticmethod
    def _prompt_context(state) -> Dict[str, Any]:
        """Merge the context dicts from ``state``; later sources win on key clashes.

        Missing or falsy entries are skipped, because TypedDict state keys are
        only present once some node has written them.
        """
        context: Dict[str, Any] = {}
        for key in ("company_profile", "user_context", "session_data"):
            context.update(state.get(key) or {})
        return context

    @staticmethod
    def _is_empty_response(response) -> bool:
        """True when the reply has no tool calls and no usable text content.

        List content is assumed to hold content-block dicts carrying a "text"
        key. Non-dict first elements are treated as non-empty rather than
        crashing.
        """
        if response.tool_calls:
            return False
        content = response.content
        if not content:
            return True
        if isinstance(content, list):
            first = content[0]
            return isinstance(first, dict) and not first.get("text")
        return False

    def __call__(self, state: State, config: RunnableConfig) -> Dict[str, Any]:
        """Run the LLM on the conversation and return ``{"messages": reply}``.

        ``config`` is part of the LangGraph node signature but is not used here.
        """
        print(f"Current Node: {self.name.title()}\n")

        system_prompt = self.update_system_prompt(self._prompt_context(state))
        # After formatting, the prompt can contain literal braces (e.g. dict
        # reprs or dialogue text from user_data). Escape them so
        # ChatPromptTemplate does not parse them as template variables.
        escaped_prompt = system_prompt.replace("{", "{{").replace("}", "}}")
        assistant_prompt = ChatPromptTemplate.from_messages(
            [("system", escaped_prompt), ("placeholder", "{messages}")]
        )
        llm = self.llm.bind_tools(self.tools) if self.tools else self.llm
        llm_chain = assistant_prompt | llm

        response = llm_chain.invoke(state)
        retries = 0
        # Bounded re-prompting: each retry is a paid model call.
        while self._is_empty_response(response) and retries < self.max_empty_retries:
            messages = list(state["messages"]) + [("user", "Your last response was empty. Please provide a correct response.")]
            state = {**state, "messages": messages}
            response = llm_chain.invoke(state)
            retries += 1
        if self._is_empty_response(response):
            print(f"{self.name}: response still empty after {retries} retries; returning it as-is.")

        return {"messages": response}

In [4]:
import random

def fetch_user_information(
    state, 
    config
):
    print(); print("Current Node: Fetch User Info")
    configuration = config.get("configurable", {})
    user_id = configuration.get("user_id", None)

    user_data = {}
    
    if not user_id:
        return {
            "user_context": {
                "user_data": user_data, 
                "user_exists": False,
            }
        }
    
    try:
        from src.utils.tools import read_dialogue, read_user_data, read_jsonl, read_json, get_conversation_by_id

        path = 'data/Movie'
        final_data_path = '{}/final_data.jsonl'.format(path)
        Conversation_path = '{}/Conversation.txt'.format(path)
        item_map_path = '{}/item_map.json'.format(path)
        
        final_data = read_jsonl(final_data_path)
        item_map = read_json(item_map_path)
        Conversation = read_dialogue(Conversation_path)

        user_information = read_user_data(final_data_path, user_id)
        history_interaction = user_information['history_interaction']
        user_might_likes = user_information['user_might_like']
    
        user_data = {
                "history_interaction": [item_map[history_interaction[k]] for k in range(len(history_interaction))],
                "user_might_like": [item_map[user_might_likes[k]] for k in range(len(user_might_likes))],
                "Conversation": {}
            }
        Conversation_info = user_information['Conversation']
        for j in range(len(Conversation_info)):
            per_conversation_info = Conversation_info[j]['conversation_{}'.format(j + 1)]
            user_likes_id = per_conversation_info['user_likes']
            user_dislikes_id = per_conversation_info['user_dislikes']
            rec_item_id = per_conversation_info['rec_item']
            conversation_id = per_conversation_info['conversation_id']
            dialogue = get_conversation_by_id(Conversation, conversation_id)
            user_data['Conversation']["conversation_{}".format(j + 1)] = {
                "user_likes": [item_map[user_likes_id[k]] for k in range(len(user_likes_id))],
                "user_dislikes": [item_map[user_dislikes_id[k]] for k in range(len(user_dislikes_id))],
                "rec_item": [item_map[rec_item_id[k]] for k in range(len(rec_item_id))],
                "conversation_id": conversation_id,
                "dialogue": dialogue
                }
    
        return {
            "user_context": {
                "user_data": user_data, 
                "user_exists": True,
            }
        }

    except Exception as e:
        print(e)
        return {
            "user_context": {
                "user_data": user_data, 
                "user_exists": False,
            }
        }

In [5]:
import uuid

# Smoke-test the user-info node outside the graph, starting from an empty
# state (State() is a plain empty dict). The result is kept in `state_` and
# not displayed, since it contains user history and dialogue text.
state = State()
config = {
    "configurable": {
        "thread_id": str(uuid.uuid4()),
        "user_id": "A30Q8X8B1S3GGT",
    }
}

state_ = fetch_user_information(state, config)


Current Node: Fetch User Info


In [None]:
from typing import Dict, List, Any, Optional, Union, Tuple
from langchain_core.runnables import RunnableConfig
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage, AIMessage, ToolMessage
from langchain_core.messages.tool import ToolCall
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

class RouterResponse(BaseModel):
    # Structured-output schema for the Router's LLM call (used with
    # `with_structured_output`). Documentation is kept in comments on purpose.
    # NOTE(review): a class docstring may be forwarded to the model as the
    # schema description, which would change the prompt.

    # Chosen agent; Router.__call__ re-prompts when it is not an available agent.
    agent_name: str = Field(description="Name of the specialized agent best suited to handle the user's current request.")
    # NOTE(review): the description says "(optional)", but the field has no
    # default, so the model must always supply it.
    request: str = Field(description="Concise summary of user's request and relevant context that helps the selected agent understand and address the need effectively (optional).")

class Router:
    """Graph node that picks the next specialized agent with a structured-output LLM.

    Agents listed in ``state["completed_nodes"]`` are excluded from the choice.
    The decision is returned as ``latest_router_decision``.
    """

    def __init__(
        self,
        nodes: Union[List[str], Dict[str, str]],
        system_prompt: str,
        llm: Optional[ChatOpenAI] = None,
        model: str = "",
        default_node: Optional[str] = None,
        max_preview: int = 100,
        max_history: int = 10,
        k_latest_messages: int = 6,
        max_assistant_preview: int = 20,
        default_model: str = "gpt-4o",
        max_retries: int = 3,
    ) -> None:
        """
        Initialize Router node.

        Args:
            nodes: List of node names, or dict mapping node name -> description
                (descriptions are shown to the routing LLM).
            system_prompt: System prompt for the router.
            llm: Chat model to route with. If it supports
                ``with_structured_output``, it is wrapped to return
                ``RouterResponse``. Otherwise it is used as-is and must already
                return objects with ``agent_name``/``request``.
            model: Model name used when ``llm`` is not given. Only names
                containing "gpt-4o" are honoured; others fall back to
                ``default_model``.
            default_node: Node used when no agent is available or routing fails.
                Defaults to the first entry of ``nodes``.
            max_preview: Maximum characters of each assistant message shown in
                the routing history.
            max_history, k_latest_messages, max_assistant_preview: Stored but not
                used by this class. NOTE(review): presumably meant to limit the
                routing history; confirm before relying on them.
            default_model: Fallback model name.
            max_retries: Extra LLM attempts allowed after an invalid agent name
                before falling back to ``default_node``.

        Raises:
            ValueError: If ``nodes`` is empty and no ``default_node`` is given.
        """
        self.model = model if "gpt-4o" in model else default_model
        if llm is None:
            llm = ChatOpenAI(model=self.model)
        # __call__ reads `response.agent_name`, so a caller-supplied raw chat
        # model must be wrapped too, not only the default one.
        self.llm = (
            llm.with_structured_output(RouterResponse)
            if hasattr(llm, "with_structured_output")
            else llm
        )
        self.nodes = nodes
        self.default_node = default_node or self._get_default_node()
        self.system_prompt = system_prompt
        self.max_preview = max_preview
        self.max_history = max_history
        self.k_latest_messages = k_latest_messages
        self.max_assistant_preview = max_assistant_preview
        self.max_retries = max_retries

    def _get_default_node(self) -> str:
        """Return the first configured node name (list order / dict insertion order)."""
        if not self.nodes:
            raise ValueError("Router requires at least one node")
        return next(iter(self.nodes))

    def _format_message(self, message: AnyMessage) -> Optional[str]:
        """Render a user or non-empty assistant message as one history line.

        Assistant text is cut to ``max_preview`` characters and suffixed with
        "...(preview)". All other message types (e.g. tool messages) yield None.
        """
        if isinstance(message, HumanMessage):
            return f"User: {message.content}"
        elif isinstance(message, AIMessage) and message.content:
            preview = message.content[:self.max_preview] + "...(preview)"
            return f"Assistant: {preview}"
        return None

    def _prepare_message_history(self, messages: List[AnyMessage]) -> str:
        """Join the formatted history lines of all ``messages`` with newlines."""
        formatted_messages = [
            line for line in map(self._format_message, messages)
            if line is not None
        ]
        return "\n".join(formatted_messages)

    def _get_nodes_description(self, state: State) -> Tuple[List[str], str]:
        """Return the not-yet-completed node names and a prompt-ready description.

        List-configured nodes are described as one comma-separated line.
        Dict-configured nodes get one "- name: description" line each.
        """
        completed_nodes = state.get("completed_nodes") or []

        if isinstance(self.nodes, list):
            available_nodes = [node for node in self.nodes if node not in completed_nodes]
            return available_nodes, ", ".join(available_nodes)

        available = {
            node: desc
            for node, desc in self.nodes.items()
            if node not in completed_nodes
        }
        nodes_description = "\n".join(f"- {node}: {desc}" for node, desc in available.items())
        return list(available.keys()), nodes_description

    def _create_router_messages(self,
                              state: State,
                              nodes_description: str) -> List[Union[SystemMessage, HumanMessage]]:
        """Build the system + user messages sent to the routing LLM."""
        system_prompt = (
            f"{self.system_prompt}\n\n"
            f"Available agents:\n{nodes_description}\n\n"
            "Only return agent name"
        )
        latest_messages = self._prepare_message_history(state.get("messages", []))
        user_msg = f"Chat History:\n{latest_messages}\n\nAgents Name:"

        return [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_msg)
        ]

    def _add_router_tool(self, router_response: RouterResponse) -> List[AnyMessage]:
        """Build the synthetic Router tool-call / tool-result message pair.

        The pair records the routing decision in the conversation and tells the
        next agent to continue as the assistant.
        """
        routing_message = (
            f"The assistant is now the {router_response.agent_name}. "
            "Reflect on the above conversation between the host assistant and the user. "
            "Do not mention who you are - just act as the proxy for the assistant."
            )
        # The ToolMessage is paired with its AIMessage tool call by id. Use a
        # fresh id per routing instead of an empty id shared by every Router turn.
        call_id = f"call_{uuid.uuid4().hex}"
        return [
            AIMessage(content="", tool_calls=[ToolCall(name="Router", args={"agent_name": router_response.agent_name, "request": router_response.request}, id=call_id)]),
            ToolMessage(content=routing_message, name="Router", tool_call_id=call_id),
            ]

    def __call__(self, state: State, config: RunnableConfig) -> Dict[str, Any]:
        """Route to the appropriate node based on the conversation state.

        Returns ``{"latest_router_decision": name}``. When the LLM makes a valid
        choice, the Router tool-call messages are also returned under
        ``"messages"``. ``default_node`` is returned (without messages) in three
        cases: every agent is completed, the LLM call raises, or no valid agent
        name is produced within ``max_retries`` extra attempts.
        ``config`` is part of the node signature but is not used.
        """
        print("\nCurrent Node: Router")

        available_nodes, nodes_description = self._get_nodes_description(state)

        if not available_nodes:
            print(f"Router selected default node: {self.default_node}")
            return {"latest_router_decision": self.default_node}

        messages = self._create_router_messages(state, nodes_description)

        # Bounded: each attempt is a paid model call.
        for _ in range(self.max_retries + 1):
            try:
                response = self.llm.invoke(messages)
                node_name = response.agent_name
            except Exception as e:
                print(f"Error in router: {e}")
                return {"latest_router_decision": self.default_node}

            if node_name and node_name in available_nodes:
                print(f"Router selected: {node_name}")
                return {"messages": self._add_router_tool(response), "latest_router_decision": node_name}

            messages.append(
                HumanMessage(
                    content="Not a valid node! Please try again. "
                    f"Your output should be an agent name from the available agents {available_nodes}."
                )
            )

        print(f"Router selected default node: {self.default_node}")
        return {"latest_router_decision": self.default_node}