In [None]:
import json
import logging
import sys
import time
from typing import Literal

from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langdetect import DetectorFactory, LangDetectException, detect
from pydantic import BaseModel, Field

from csagent.configuration import Configuration
from csagent.react_agent.graph import react_agent_graph
from csagent.router_agent.graph import router_graph
from csagent.supervisor.graph import supervisor_graph
from utils import run_langsmith_eval

# Configure the root logger at INFO level; every cell below logs through it.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Attach the stdout handler only when the root logger has none yet, so that
# re-running this cell does not stack duplicate handlers.
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(handler)

# Load variables from a .env file into the process environment
# (presumably the API keys used by the models and LangSmith - TODO confirm).
load_dotenv()

In [None]:
# Language code for this run. It is stored on the Configuration below and is
# also passed as `split_name` to `run_langsmith_eval` in the last cell.
language = "id"  # alternative: "en"

# Shared configuration for the run:
# - `model` is forwarded to `run_langsmith_eval` in the last cell,
# - `model_small` is the model used by the LLM judge (see llm_alignment_evaluator),
# - `model_medium` is not read in this notebook (presumably used inside the
#   graphs - TODO confirm).
configuration = Configuration(
    model="LongCat-Flash-Chat",
    model_medium="GLM-4.6V-Flash",
    model_small="google_genai:gemma-3-12b-it",
    language=language,
)


def create_target_function(graph, graph_name: str, delay_seconds: float = 15):
    """
    Factory function to create target functions for any graph.

    The returned ``eval_graph(config)`` binds a configuration and returns
    ``invoke_graph(inputs)``, which streams the graph in debug mode and
    collects the messages reported by the top-level graph together with every
    tool call observed on a task named "tools".

    Args:
        graph: The LangGraph graph to invoke
        graph_name: Name of the graph for logging purposes
        delay_seconds: Pause after each successful run, in seconds. Defaults
            to 15, the value that was previously hard-coded (presumably to
            stay under provider rate limits - TODO confirm). Values <= 0
            disable the pause.

    Returns:
        ``eval_graph``: called with a configuration, it returns
        ``invoke_graph(inputs) -> dict``. That dict holds "tools_called"
        (a list of ``{tool_name: args}`` dicts) and, when any were seen,
        "messages". If anything raises while streaming, the error is logged
        and ``{}`` is returned.
    """

    def eval_graph(config: "Configuration"):
        def invoke_graph(inputs: dict) -> dict:
            """
            Evaluate the graph with given inputs.
            """
            final_result = {}
            tools_called = []
            try:
                for namespace, chunk in graph.stream(
                    inputs,
                    subgraphs=True,
                    stream_mode="debug",
                    context=config,
                ):
                    # Treat a missing or None payload as empty in BOTH checks
                    # below, so one odd event cannot abort the whole run.
                    payload = chunk.get("payload") or {}

                    # Empty namespace: with subgraphs=True this is an event of
                    # the top-level graph rather than of a subgraph.
                    if not namespace:
                        result = payload.get("result") or {}
                        messages = result.get("messages")
                        if messages:
                            # Later events overwrite earlier ones, so the last
                            # one seen is kept as the final response.
                            final_result["messages"] = messages

                    if chunk.get("type") == "task" and payload.get("name") == "tools":
                        # Record tool calls as {tool_name: args}
                        tool_call = (payload.get("input") or {}).get("tool_call")
                        if tool_call:
                            tools_called.append(
                                {
                                    tool_call.get("name", "tool_name"): tool_call.get(
                                        "args", {}
                                    )
                                }
                            )

                final_result["tools_called"] = tools_called
                if delay_seconds > 0:
                    time.sleep(delay_seconds)
                return final_result

            except Exception:
                # Best effort: a failing example must not stop the evaluation.
                logger.exception("Error in %s", graph_name)
                return {}

        return invoke_graph

    return eval_graph


# Usage: build one target factory per graph under evaluation. Each factory is
# later called with a Configuration to obtain the function that runs a single
# example (see the final cell).
target_function = create_target_function(supervisor_graph, "supervisor_graph")
target_function_router = create_target_function(router_graph, "router_graph")
target_function_react = create_target_function(react_agent_graph, "react_graph")

In [None]:
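# Manual smoke test (disabled). To use it, uncomment exactly ONE of the three
# `response = ...(` lines below, one inputs dict, and the closing parenthesis.
# response = target_function(config=Configuration())(
# response = target_function_router(config=Configuration())(
# response = target_function_react(config=Configuration())(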
#     # {"messages": [{"role": "user", "content": "Tell me about orchid"}]}
#     # {"messages": [{"role": "user", "content": "Where can I find a product in penang?"}]}
#     {
#         "messages": [
#             {
#                 "role": "user",
#                 "content": "What is Orchid? Where can I find a product in penang?",
#             }
#         ]
#     }
# )

In [None]:
# Schema that llm_judge parses the judge model's reply into.
# NOTE(review): the class docstring and the Field descriptions are presumably
# embedded in the parser's format instructions that are sent to the judge
# model - confirm before rewording any of them.
class EvaluationModel(BaseModel):
    """Structured output of the evaluation."""

    # Free-text justification for the verdict.
    rationale: str = Field(
        description="Rationale that explains the alignment between the AI output and the ground truth."
    )
    # Binary verdict; the Literal type only admits the values 0 and 1.
    score: Literal[0, 1] = Field(
        description="0 means the AI output is completely wrong and not related to the ground truth, 1 means the AI output is completely correct and aligned with the ground truth."
    )


def llm_judge(inputs: str, ai_output: str, ground_truth: str, model: str):
    """Ask a chat model whether ``ai_output`` agrees with ``ground_truth``.

    Args:
        inputs: The human question the answer responds to.
        ai_output: The answer under evaluation.
        ground_truth: The reference answer.
        model: Model identifier handed to ``init_chat_model``.

    Returns:
        An ``EvaluationModel``. If creating the model, invoking it, or parsing
        its reply raises, the error is logged and a fallback verdict with
        ``score=0`` is returned instead.
    """

    judge_template = """
        You are given a human question and a pair consisting of a ground truth and an AI-generated output. Your task is to evaluate how well the AI output aligns with the ground truth in the context of the human question.
        1. Provide a brief reasoning (1-2 sentences) explaining the degree of alignment between the AI output and the ground truth.
        2. Assign a binary score:
            - 1 if the AI output aligns with the ground truth.
            - 0 if the AI output does not align with the ground truth.
        Keep your reasoning concise, objective, and focused only on the alignment. Do not add extra commentary, suggestions, or subjective opinions.

        Human Question: {inputs}
        AI output: {ai_output}
        Ground truth: {ground_truth}

        Format Instruction:
        {format_instructions}
    """

    verdict_parser = PydanticOutputParser(pydantic_object=EvaluationModel)
    base_prompt = ChatPromptTemplate.from_messages([("human", judge_template)])
    # Every template variable is pre-filled, so the prompt is later invoked
    # with an empty mapping.
    filled_prompt = base_prompt.partial(
        inputs=inputs,
        ai_output=ai_output,
        ground_truth=ground_truth,
        format_instructions=verdict_parser.get_format_instructions(),
    )

    try:
        judge = init_chat_model(model, temperature=0)
        reply = judge.invoke(filled_prompt.invoke({}))
        # The parser works on the reply's content, not on the message object.
        return verdict_parser.parse(reply.content)
    except Exception:
        logger.exception("Failed to judge AI output.")
        return EvaluationModel(rationale="Error. Failed to judge AI output", score=0)

In [None]:
def llm_alignment_evaluator(
    inputs: dict, outputs: dict, reference_outputs: dict
) -> list:
    """LLM-as-judge alignment evaluator.

    Args:
        inputs: Example inputs; the last entry of ``inputs["messages"]`` is
            read as a mapping with a "content" key.
        outputs: Target outputs; the last entry of ``outputs["messages"]`` is
            read through its ``.content`` attribute.
        reference_outputs: Reference data; ``reference_outputs["content"]`` is
            used as the ground truth.

    Returns:
        Two feedback entries: "alignment_score" (score) and
        "alignment_reasoning" (value). When the expected fields cannot be
        read, the score is 0 and the reasoning carries the error.
    """

    try:
        input_content = inputs["messages"][-1]["content"]
        output_content = outputs["messages"][-1].content
        reference_content = reference_outputs["content"]
    except (KeyError, IndexError, AttributeError, TypeError) as e:
        # AttributeError / TypeError cover a message without `.content` and a
        # non-subscriptable container, so they score 0 like the other
        # structural failures instead of escaping the evaluator.
        logger.exception("Invalid data structure in evaluator.")
        return [
            {"key": "alignment_score", "score": 0},
            {
                "key": "alignment_reasoning",
                "value": f"Error: Invalid data structure - {e}",
            },
        ]

    response_alignment = llm_judge(
        input_content,
        output_content,
        reference_content,
        configuration.model_small,
    )
    return [
        {
            "key": "alignment_score",
            "score": response_alignment.score,
        },
        {"key": "alignment_reasoning", "value": response_alignment.rationale},
    ]

In [None]:
def language_evaluator(inputs: dict, outputs: dict, reference_outputs: dict) -> dict:
    """Check whether AI output and reference outputs are in the same language.

    Args:
        inputs: Example inputs (not used).
        outputs: Target outputs; the last entry of ``outputs["messages"]`` is
            read through its ``.content`` attribute.
        reference_outputs: Reference data; ``reference_outputs["content"]`` is
            the reference text.

    Returns:
        A feedback dict with key "is_same_language". The score is True when
        both texts are detected as the same language. It is False, with an
        explanatory "comment", when the expected fields are missing or the
        language cannot be detected.
    """

    try:
        output_text = outputs["messages"][-1].content
        reference_text = reference_outputs["content"]
    except (KeyError, IndexError, AttributeError, TypeError) as e:
        # The target function returns {} when a run fails, which lands here.
        logger.exception("Invalid data structure in language_evaluator.")
        return {
            "key": "is_same_language",
            "score": False,
            "comment": f"Error: Invalid data structure - {e}",
        }

    DetectorFactory.seed = 0  # Makes results reproducible
    try:
        output_language = detect(output_text)
        reference_language = detect(reference_text)
    except (LangDetectException, TypeError) as e:
        # detect() raises LangDetectException when the text has no usable
        # features (e.g. it is empty or only digits/punctuation).
        logger.exception("Language detection failed in language_evaluator.")
        return {
            "key": "is_same_language",
            "score": False,
            "comment": f"Error: Language detection failed - {e}",
        }

    return {
        "key": "is_same_language",
        "score": output_language == reference_language,
    }

In [None]:
def check_tools_called(expected_tools: list, actual_tools: list) -> dict:
    """
    Compare expected tools with actual tools called.

    Each tool call is a one-entry dict ``{tool_name: params}``. Matching is
    order-independent: every expected call is paired with the first not yet
    paired actual call that has the same name and the same parameters
    (string values are compared case-insensitively, and parameters the
    expectation does not list count as a mismatch).

    Parameter mismatches are reported only for expected calls that end up
    without a match. A same-named candidate that is rejected on the way to an
    exact match (e.g. two calls to one tool in a different order) is not an
    error.

    Args:
        expected_tools: List of expected tool calls with parameters
            (None or a single dict are accepted as well)
        actual_tools: List of actual tool calls made
            (None or a single dict are accepted as well)

    Returns:
        dict: Results with match status ("is_match"), "missing_tools",
        "extra_tools", "parameter_mismatches" and "details" (one
        ``{"tool", "status": "matched"}`` entry per matched expected call)
    """

    # Edge case: Handle None or empty inputs
    if expected_tools is None:
        expected_tools = []
    if actual_tools is None:
        actual_tools = []

    # Normalize to lists if single dict provided
    if isinstance(expected_tools, dict):
        expected_tools = [expected_tools]
    if isinstance(actual_tools, dict):
        actual_tools = [actual_tools]

    results = {
        "is_match": True,
        "missing_tools": [],
        "extra_tools": [],
        "parameter_mismatches": [],
        "details": [],
    }

    def normalize_tool(tool):
        """Extract (tool name, parameters); non-dict parameters become {}."""
        tool_name = next(iter(tool))
        params = tool[tool_name]
        return (tool_name, params if isinstance(params, dict) else {})

    def diff_params(tool_name, exp_params, act_params):
        """List every parameter difference; an empty list means a match."""
        mismatches = []

        # Check all expected parameters
        for key, expected_value in exp_params.items():
            actual_value = act_params.get(key)
            if isinstance(expected_value, str) and isinstance(actual_value, str):
                # Strings are compared case-insensitively
                differs = expected_value.lower() != actual_value.lower()
            else:
                differs = expected_value != actual_value
            if differs:
                mismatches.append(
                    {
                        "tool": tool_name,
                        "parameter": key,
                        "expected": expected_value,
                        "actual": actual_value,
                    }
                )

        # Check for unexpected parameters
        for key in act_params:
            if key not in exp_params:
                mismatches.append(
                    {
                        "tool": tool_name,
                        "parameter": key,
                        "expected": None,
                        "actual": act_params[key],
                        "type": "unexpected_parameter",
                    }
                )

        return mismatches

    # Falsy entries (None, {}) are skipped on both sides
    expected_normalized = [normalize_tool(t) for t in expected_tools if t]
    actual_normalized = [normalize_tool(t) for t in actual_tools if t]

    # Track which actual tools have been matched
    matched_actual = [False] * len(actual_normalized)

    for exp_tool_name, exp_params in expected_normalized:
        found_match = False
        # Differences against rejected same-name candidates; they only matter
        # if this expected call stays unmatched.
        rejected_mismatches = []

        for i, (act_tool_name, act_params) in enumerate(actual_normalized):
            if matched_actual[i] or exp_tool_name != act_tool_name:
                continue

            mismatches = diff_params(exp_tool_name, exp_params, act_params)
            if not mismatches:
                found_match = True
                matched_actual[i] = True
                results["details"].append(
                    {"tool": exp_tool_name, "status": "matched"}
                )
                break
            rejected_mismatches.extend(mismatches)

        if not found_match:
            results["missing_tools"].append({exp_tool_name: exp_params})
            results["parameter_mismatches"].extend(rejected_mismatches)

    # Actual calls that no expected call claimed are extras
    for i, (act_tool_name, act_params) in enumerate(actual_normalized):
        if not matched_actual[i]:
            results["extra_tools"].append({act_tool_name: act_params})

    # Final match determination
    results["is_match"] = not (
        results["missing_tools"]
        or results["extra_tools"]
        or results["parameter_mismatches"]
    )

    return results


def is_correct_tool_called(
    inputs: dict, outputs: dict, reference_outputs: dict
) -> list:
    """Check whether the tools the agent called match the expected tool calls.

    Args:
        inputs: Example inputs (not used).
        outputs: Target outputs; ``outputs["tools_called"]`` is the list of
            tool calls that were made.
        reference_outputs: Reference data;
            ``reference_outputs["expected_tools_called"]`` is the list of
            expected tool calls.

    Returns:
        A list of feedback entries. On success: "is_correct_tool_called"
        (score) plus "missing_tools", "extra_tools" and "parameter_mismatches"
        (JSON strings). When the data cannot be read or compared, a
        one-element list with "is_correct_tool_called" scored False. The
        error path uses the same return type and metric key as the success
        path.
    """

    try:
        output_tools_called = outputs["tools_called"]
        reference_tools_called = reference_outputs["expected_tools_called"]
        result = check_tools_called(reference_tools_called, output_tools_called)
    except (KeyError, IndexError, AttributeError, TypeError):
        # AttributeError / TypeError cover malformed tool entries inside
        # check_tools_called and non-subscriptable outputs.
        logger.exception("Invalid data structure in is_correct_tool_called.")
        return [{"key": "is_correct_tool_called", "score": False}]

    return [
        {
            "key": "is_correct_tool_called",
            "score": result["is_match"],
        },
        {
            "key": "missing_tools",
            "value": json.dumps(result["missing_tools"]),
        },
        {
            "key": "extra_tools",
            "value": json.dumps(result["extra_tools"]),
        },
        {
            "key": "parameter_mismatches",
            "value": json.dumps(result["parameter_mismatches"]),
        },
    ]

In [None]:
# Select which graph to evaluate
EVAL_TARGET = "react"  # or "router"  # or "supervisor"
# Maps each accepted EVAL_TARGET value to the target factory built for that
# graph with create_target_function.
target_mapping = {
    "react": target_function_react,
    "router": target_function_router,
    "supervisor": target_function,
}
# Raises KeyError when EVAL_TARGET is not one of the keys above.
target = target_mapping[EVAL_TARGET]


# Run the evaluation. `target(config=configuration)` is the per-example
# function, followed by the three evaluators defined in the cells above.
# NOTE(review): the string argument is presumably the LangSmith dataset name
# (the commented-out line being an alternative one) and `configuration.model`
# a label for the run - confirm against `utils.run_langsmith_eval`.
run_langsmith_eval(
    target(config=configuration),
    # "CS Agent Evaluation",
    "CS Agent Test Evaluation",
    [llm_alignment_evaluator, language_evaluator, is_correct_tool_called],
    configuration.model,
    split_name=language,
)