# A* Chat

As your application grows, you may want to create more than one agent, each with a different set of tools to handle a different part of the problem. Orchestrating a conversation between these agents can be challenging, especially when it is difficult to determine which agent should be called at each timestep.

If we view multi-agent chat as a path-finding problem, where each message is analogous to a node and the message history is the path, the A* algorithm can be used to find the optimal path to the end of the conversation.

A* is a widely used path-finding algorithm. It extends Dijkstra's algorithm, which finds the shortest path between two nodes in a graph, by adding a heuristic function that guides the search towards the goal. The heuristic function estimates the distance between the current node and the goal. At each step, the algorithm expands the node with the lowest total cost, which is the sum of the distance from the start node to that node and its heuristic estimate.
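To make the description concrete, here is a minimal, generic sketch of A* on a small weighted graph. It is a textbook version written only for illustration, not the code used by A* Chat; the graph, the heuristic table `h`, and the function name `a_star` are all assumptions made up for this example.

```python
import heapq
from typing import Callable, Dict, Iterable, List, Optional, Tuple


def a_star(
    start: str,
    goal: str,
    neighbors: Callable[[str], Iterable[Tuple[str, float]]],
    heuristic: Callable[[str], float],
) -> Optional[List[str]]:
    """Return the cheapest path from start to goal, or None if unreachable."""
    # Each frontier entry is (g + h, g, node); heapq pops the smallest g + h first.
    frontier: List[Tuple[float, float, str]] = [(heuristic(start), 0.0, start)]
    came_from: Dict[str, Optional[str]] = {start: None}
    cost_so_far: Dict[str, float] = {start: 0.0}

    while frontier:
        _, g, node = heapq.heappop(frontier)
        if node == goal:
            # Walk the parent links back to the start, then reverse.
            path: List[str] = []
            current: Optional[str] = node
            while current is not None:
                path.append(current)
                current = came_from[current]
            return path[::-1]
        if g > cost_so_far[node]:
            continue  # stale entry: a cheaper route to this node was found later
        for nxt, step in neighbors(node):
            new_g = g + step
            if nxt not in cost_so_far or new_g < cost_so_far[nxt]:
                cost_so_far[nxt] = new_g
                came_from[nxt] = node
                heapq.heappush(frontier, (new_g + heuristic(nxt), new_g, nxt))
    return None


# Hypothetical toy graph: node -> list of (neighbor, edge cost).
graph = {
    'A': [('B', 1.0), ('C', 4.0)],
    'B': [('C', 1.0), ('D', 5.0)],
    'C': [('D', 1.0)],
    'D': [],
}
# Hypothetical heuristic: never overestimates the true remaining distance to 'D'.
h = {'A': 3.0, 'B': 2.0, 'C': 1.0, 'D': 0.0}

path = a_star('A', 'D', lambda n: graph[n], lambda n: h[n])
print(path)  # ['A', 'B', 'C', 'D']
```

The direct edge from `A` to `C` is ignored because going through `B` is cheaper, and the heuristic keeps the search from expanding the expensive `B` to `D` edge first.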

With A* Chat, instead of manually programming agent behaviours, you define a heuristic function that estimates how *close* the message history is to the goal, and the algorithm automatically orchestrates the conversation between the agents to reach it.
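The idea can be sketched without any LLM at all. In the toy example below, a search state is a message history, expanding a state means asking every agent for one reply, and each reply costs 1. The stub agents `writer` and `reviewer`, the heuristic, and the function `astar_over_histories` are hypothetical stand-ins invented for this illustration; this is not how `agentx` implements `astar_chat`.

```python
import heapq
from typing import Callable, Dict, List, Optional, Tuple

History = Tuple[str, ...]  # a message history; each entry is "speaker: text"


# Hypothetical stand-ins for LLM agents: each maps a history to one reply.
def writer(history: History) -> str:
    drafts = sum(m.startswith('writer:') for m in history)
    return f'writer: draft v{drafts + 1}'


def reviewer(history: History) -> str:
    if history and history[-1].startswith('writer:'):
        return 'reviewer: approved ' + history[-1].split(': ', 1)[1]
    return 'reviewer: nothing to review'


AGENTS: Dict[str, Callable[[History], str]] = {'writer': writer, 'reviewer': reviewer}


def heuristic(history: History) -> float:
    """Estimated number of replies still needed before a draft is approved."""
    if history and history[-1].startswith('reviewer: approved'):
        return 0.0
    if history and history[-1].startswith('writer:'):
        return 1.0
    return 2.0


def astar_over_histories(
    start: History, threshold: float = 0.0, max_iteration: int = 20
) -> Optional[History]:
    # Frontier entries are (cost + heuristic, cost, history).
    frontier: List[Tuple[float, float, History]] = [(heuristic(start), 0.0, start)]
    for _ in range(max_iteration):
        if not frontier:
            break
        _, cost, history = heapq.heappop(frontier)
        if heuristic(history) <= threshold:
            return history  # close enough to the goal
        for agent in AGENTS.values():
            nxt = history + (agent(history),)
            new_cost = cost + 1.0  # one reply costs one unit
            heapq.heappush(frontier, (new_cost + heuristic(nxt), new_cost, nxt))
    return None


found = astar_over_histories(('user: write a question',))
print(found)
# ('user: write a question', 'writer: draft v1', 'reviewer: approved draft v1')
```

Nothing in the loop says "the writer goes first, then the reviewer"; that order falls out of the heuristic, because histories that end in a reviewed draft score lower than the alternatives.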

## Example 0: A* Chat to write test problems and solutions according to Bloom’s Taxonomy of Educational Objectives
We will use A* Chat to orchestrate a conversation between two agents:
- test_writer: writes a test question and answer about a subject
- reviewer: applies [Bloom’s Taxonomy of Educational Objectives](https://en.wikipedia.org/wiki/Bloom%27s_taxonomy#:~:text=Bloom's%20taxonomy%20is%20a%20set,cognitive%2C%20affective%20and%20psychomotor%20domains.) to evaluate the quality of the test question, using scoring tools backed by the bloom_expert agent

In [7]:
from agentx.agent import Agent
from agentx.schema import GenerationConfig, Message, Content
from agentx.groupchat import astar_chat, reconstruct_path
from agentx.tool import Tool
from pydantic import BaseModel, Field
from typing import Dict, List, Literal
from functools import partial
from dotenv import load_dotenv
from rich import print as rich_print
import os

load_dotenv()

generation_config = GenerationConfig(
    api_type='azure',
    api_key=os.environ.get('AZURE_OPENAI_KEY'),
    base_url=os.environ.get('AZURE_OPENAI_ENDPOINT'),
    azure_deployment='gpt-35',
)

# this agent will write the question / answer pair
test_writer = Agent(
    name='test_writer',
    generation_config=generation_config,
    system_prompt='''According to the user request and reviewer feedback, write a test question / answer pair.
You must respond in the following format:
Question: <question>
Answer: <answer>
''',
)

class TestQuestion(BaseModel):
    question:str
    answer:str
    explanation:str
    bloom_objective:Literal['knowledge', 'comprehension', 'application', 'analysis', 'synthesis', 'evaluation']

# Define agents and tools for reviewing the test question
bloom = Agent(
    name='bloom_expert',
    generation_config=generation_config,
    system_prompt="You are an education expert and highly knowledgeable about Bloom's taxonomy of educational objectives.",
)

class BloomScore(BaseModel):
    score:float = Field(0, ge=0, le=10)
    improvement_suggestion:str
    objective:Literal['knowledge', 'comprehension', 'application', 'analysis', 'synthesis', 'evaluation']

class BloomScorer(Tool):
    def __init__(
        self,
        agent:Agent,
        bloom_objective:Literal['knowledge', 'comprehension', 'application', 'analysis', 'synthesis', 'evaluation'],
        **kwargs
    ):
        super().__init__(input_model=TestQuestion, **kwargs)
        self.agent = agent
        self.bloom_objective = bloom_objective

    def run(self, **kwargs) -> str:
        raise NotImplementedError()

    async def a_run(self, **kwargs) -> str:
        test_question = self.input_model(**kwargs)
        response = await self.agent.a_generate_response(
            messages=[
                Message(
                    role='user',
                    content=Content(
                        text='''Please give a score out of 10 to represent the test question's quality at the {bloom_objective} level of Bloom's Taxonomy.
The test question:
{test_question}

You must reply with a JSON object.'''.format(
    bloom_objective=self.bloom_objective,
    test_question=test_question.model_dump()
),
                    ),
                )
            ],
            output_model=BloomScore
        )
        return response[-1].content.text

bloom_scoring_tools = [
    BloomScorer(
        bloom,
        bloom_objective,
        name='{bloom_objective}_scorer'.format(bloom_objective=bloom_objective),
        description='''Give a score out of 10 for how well the test question fits the {bloom_objective} level of Bloom's Taxonomy and for the question's quality.'''.format(bloom_objective=bloom_objective),
    ) for bloom_objective in ['knowledge', 'comprehension', 'application', 'analysis', 'synthesis', 'evaluation']
]

# this agent will review the question / answer pair
reviewer = Agent(
    name='reviewer',
    generation_config=generation_config,
    system_prompt='''Use the tools you have been provided to review the question. Critically assess if the test question is at the right level and quality of Bloom's Taxonomy. Then, give feedback on how to improve the question.''',
    tools=bloom_scoring_tools
)

# At each timestep, A* minimizes heuristic + cost
# heuristic: an estimate of the distance between the current state and the goal state
# cost: the distance between the start state and the current state

# The heuristic is the sum of the absolute differences between the current and target Bloom scores,
# scaled by the sum of the target scores
class BloomReport(BaseModel):
    knowledge:float = Field(0, ge=0, le=10)
    comprehension:float = Field(0, ge=0, le=10)
    application:float = Field(0, ge=0, le=10)
    analysis:float = Field(0, ge=0, le=10)
    synthesis:float = Field(0, ge=0, le=10)
    evaluation:float = Field(0, ge=0, le=10)

extractor = Agent(
    name='extractor',
    generation_config=generation_config,
    system_prompt='''Extract the latest Bloom scores from the message history. You must reply with a JSON object.''',
)

def heuristic(
    messages:List[Message], 
    target:Dict[Literal['knowledge', 'comprehension', 'application', 'analysis', 'synthesis', 'evaluation'], float]
) -> float:
    if 'test_writer' not in [message.name for message in messages]:
        # no test question has been written
        return 10
    if 'reviewer' not in [message.name for message in messages]:
        # no review has been made
        return 10
    if len(messages) < 2 or messages[-1].name != 'reviewer' or not messages[-2].name.endswith('_scorer'):
        # the last message is not a reviewer reply that follows a scorer tool call,
        # so there is no new score to extract for this state
        return None
    
    bloom_report = extractor.generate_response(
        messages=[
            Message(
                role='user',
                content=Content(
                    text='Based on this chat history: {history}'.format(
                        history=[message.model_dump_json(
                            exclude_unset=True,
                            exclude_none=True
                        ) for message in messages]
                    ),
                ),
            )
        ] + [
            Message(
                role='user',
                content=Content(
                    text='Extract the latest Bloom scores. You must reply with a JSON object.'),
            ),
        ],
        output_model=BloomReport,
    )[-1].content.text
    bloom_report = BloomReport.model_validate_json(bloom_report).model_dump()
    
    # print out for easier debugging and illustration
    rich_print([message for message in messages if message.name=='test_writer'][-1].content.text)
    rich_print(bloom_report)

    difference = sum(
        [
            abs(bloom_report[objective] - target[objective]) for objective in ['knowledge', 'comprehension', 'application', 'analysis', 'synthesis', 'evaluation']
        ]
    )
    
    return difference / sum(target.values()) * 10

# Cost approximates the number of LLM calls by counting the messages added in this step
def cost(messages:List[Message], next_message:List[Message]) -> float:
    return len(next_message)
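
To see what the heuristic above returns, here is the same arithmetic on hand-written numbers, with no LLM involved. The helper name `normalised_distance` and the `report` values are made up for this illustration; only the formula is taken from the `heuristic` function.

```python
from typing import Dict

OBJECTIVES = ['knowledge', 'comprehension', 'application', 'analysis', 'synthesis', 'evaluation']


def normalised_distance(report: Dict[str, float], target: Dict[str, float]) -> float:
    """Sum of absolute score differences, scaled by the sum of the target scores."""
    difference = sum(abs(report[objective] - target[objective]) for objective in OBJECTIVES)
    # Note: this raises ZeroDivisionError if every target score is 0.
    return difference / sum(target.values()) * 10


target = {'knowledge': 10, 'comprehension': 10, 'application': 0, 'analysis': 0, 'synthesis': 0, 'evaluation': 0}
# Hypothetical scores, as if extracted from a review.
report = {'knowledge': 8, 'comprehension': 6, 'application': 2, 'analysis': 0, 'synthesis': 0, 'evaluation': 0}

score = normalised_distance(report, target)
print(score)  # (2 + 4 + 2) / 20 * 10 = 4.0
```

A report that matches the target exactly scores 0. Because the differences are summed over all six objectives while the divisor covers only the target scores, the value can exceed 10 when the review is far off.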

In [19]:

# notice how the target is iteratively reached

target = {'knowledge': 10, 'comprehension': 10, 'application': 0, 'analysis': 0, 'synthesis': 0, 'evaluation': 0}
init_message = Message(
    role = 'user',
    content = Content(
        text = '''Write a multiple choice test question / answer pair about Mendelian Genetics.
The student should be tested on these areas {target}'''.format(
            target=target
        )
    )
)

reconstructed_path, came_from, cost_so_far, heuristic_map, hash_map = await astar_chat(
    agents = [test_writer, reviewer],
    heuristic = partial(heuristic, target=target),
    cost = cost,
    messages = [init_message],
    threshold = 2,
    n_replies = 1,
    max_iteration = 8,
)

100%|██████████| 8/8 [00:35<00:00,  4.47s/it]


In [17]:
reconstructed_path

[Message(role='user', content=Content(text="Write a multiple choice test question / answer pair about Mendelian Genetics.\nThe student should be tested on these areas {'knowledge': 10, 'comprehension': 5, 'application': 0, 'analysis': 0, 'synthesis': 0, 'evaluation': 0}", files=None, urls=None, tool_calls=None, tool_response=None), name=None)]

In [18]:
heuristic_map

{(6469935075520154183,): 10,
 (-4580926200682588662,): 10,
 (-1518112528675522610,
  4714398374246581160,
  7163238287259037138,
  -7837850031291768279): 10,
 (-6065979392944444587,
  1658435270493097711,
  1385579723353972364,
  -6118312179635416503): 10.0,
 (-1788646112146464354,): 10,
 (5264585814736223471,): 10,
 (5703292310741949036,): 10.0,
 (-6038740678385575172,): 10.0,
 (6570973958616454994,): 10,
 (-5234326009116715241,): 10}