# Language Agent Tree Search
Language Agent Tree Search (LATS), by Zhou et al., is a general LLM agent search algorithm that combines reflection/evaluation with search (specifically Monte Carlo tree search) to achieve better overall task performance than similar techniques such as ReAct, Reflexion, or Tree of Thoughts.

In [9]:
from dotenv import load_dotenv
load_dotenv()

import re
import os
import json
import math
import base64
import asyncio
import datetime
import platform
import requests
import operator
import playwright
import numpy as np
import pandas as pd
import datetime as dt

from enum import Enum
from typing import List
from typing import Dict
from typing import Tuple
from typing import Union
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Annotated
from typing import TypedDict
from operator import itemgetter
from collections import defaultdict

from IPython import display
from IPython.display import HTML
from IPython.display import Image

from langsmith import traceable

from langgraph.graph import END
from langgraph.graph import StateGraph
from langgraph.graph import MessageGraph
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt.tool_executor import ToolInvocation

from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

from langchain_core.tools import StructuredTool
from langchain_core.messages import BaseMessage
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.chat import ChatMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.messages.function import FunctionMessage
from langchain_core.prompts.image import ImagePromptTemplate

from langchain_core.pydantic_v1 import Field
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableParallel
from langchain_core.pydantic_v1 import ValidationError
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser

from langchain_core.runnables.graph import CurveStyle
from langchain_core.runnables.graph import NodeColors
from langchain_core.runnables.graph import MermaidDrawMethod

from langchain import hub
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain.prompts import HumanMessagePromptTemplate
from langchain.prompts import SystemMessagePromptTemplate
from langchain.agents import create_openai_functions_agent
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain.output_parsers.openai_tools import JsonOutputToolsParser

from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper

from langchain_fireworks.chat_models import ChatFireworks

It has 4 main steps:
1. **Select**: pick the best next actions based on the aggregate rewards from step (2). Either respond (if a solution is found or the max search depth is reached) or continue searching.
2. **Expand and simulate**: select the "best" 5 potential actions and execute them in parallel.
3. **Reflect + Evaluate**: observe the outcomes of these actions and score the decisions based on reflection (and possibly external feedback)
4. **Backpropagate**: update the scores of the root trajectories based on the outcomes.

In [7]:
# Configure tracing through environment variables: switch tracing on and
# group the runs of this notebook under the 'LATS' project name.
for _env_name, _env_value in (
    ('LANGCHAIN_TRACING_V2', 'true'),
    ('LANGCHAIN_PROJECT', 'LATS'),
):
    os.environ[_env_name] = _env_value

LATS is based on (greedy) Monte Carlo tree search. For each search step, it picks the node with the highest "upper confidence bound", a metric that balances exploitation (highest average reward) and exploration (fewest visits). Starting from that node, it generates N (5 in this case) new candidate actions to take and adds them to the tree. It stops searching either when it has generated a valid solution OR when it has reached the maximum number of rollouts (search tree depth).

Our LangGraph state will be composed of two items:
1. The root of the search tree
2. The user input

In [8]:
class Reflection(BaseModel):
    '''Structured critique of a candidate response: free-text reasoning, a 0-10 score and a solved flag.'''

    # Free-text critique of the candidate response.
    reflections: str = Field(description='The critique and reflections on the sufficiency, superfluency, and general quality of the response')
    # Integer quality score. `ge`/`le` are the pydantic constraint names, so the
    # 0-10 range is actually validated (unknown keywords such as `gte`/`lte` are
    # kept as extra schema metadata only and enforce nothing).
    score: int = Field(description='Score from 0 - 10 on the quality of the candidate response', ge=0, le=10)
    # True when the response fully solves the question or task.
    found_solution: bool = Field(description='Whether the response has fully solved the question or task')

    def as_message(self) -> HumanMessage:
        '''Render the critique and score as a HumanMessage.'''
        return HumanMessage(content=f'Reasoning: {self.reflections}\nScore: {self.score}')

    @property
    def normalized_score(self) -> float:
        '''Score divided by 10, i.e. within [0.0, 1.0] for a valid score.'''
        return self.score / 10.0

In [None]:
class Node:
    def __init__(self, messages: List[BaseMessage], reflection: Reflection, parent: Optional[Node]=None):
        self.value = 0
        self.visits = 0
        self.children = []
        self.parent = parent
        self.messages = messages
        self.reflection = reflection
        self.depth = parent.depth + 1 if parent is not None else 1
        self._is_solved = reflection.found_solution if reflection else False
        if self._is_solved:
            self._mark_tree_as_solved()
        self.backpropagate(reflection.normalized_score)

    def __repr__(self) -> str:
        return (
            f'<Node value={self.value}, visits={self.visits},'
            f' solution={self.messages} reflection={self.reflection}/>'
        )
    
    @property
    def is_solved(self):
        '''If any solutions exist, we can end the search'''
        return self._is_solved
    
    @property
    def is_terminal(self):
        return not self.children
    
    @property
    def best_child(self):
        '''Select the child with the highest UCT to search next'''
        if not self.children:
            return None
        return max(self._get_all_children(), key=lambda child: child.upper_confidence_bound())
    
    @property
    def best_child_score(self):
        '''Return the child with the highest value'''
        if not self.children:
            return None
        return max(self.children, key=lambda child: int(child.is_solved) * child.value)
    
    @property
    def height(self) -> int:
        '''Check for how far we've rolled out the tree'''
        if self.children:
            return 1 + max([child.height for child in self.children])
        return 1
    
    def upper_confidence_bound(self, exploration_weight=1.0):
        '''Return the UCT score. This helps balance exploration vs exploitation of a branch'''
        if self.parent is None:
            raise ValueError('Cannot obtain UCT from root node')
        if self.visits == 0:
            return self.value
        
        # encourages exploitation of high-value trajectories
        average_reward = self.value / self.visits

        # encourages exploration of less-visited trajectories
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration_term
    
    def backpropagate(self, reward: float):
        '''Update the score of this node and its parents'''
        node = self
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            node = node.parent

    def get_messages(self, include_reflections: bool=True):
        if include_reflections:
            return self.messages + [self.reflection.as_message()]
        return self.messages
    
    def get_trajectory(self, include_reflections: bool=True) -> List[BaseMessage]:
        '''Get messages representing this search branch'''
        messages = []
        node = self
        while node:
            messages.extend(node.get_messages(include_reflections=include_reflections)[::-1])
            node = node.parent

        # reverse the final back-tracked trajectory to return in the correct order
        return messages[::-1] # root solution, reflection, child
    
