# Tutorial 24: LATS (Language Agent Tree Search)

This tutorial demonstrates the **LATS (Language Agent Tree Search)** pattern, which applies Monte Carlo Tree Search (MCTS) algorithms to language agents for complex reasoning tasks.

## Overview

Whereas traditional agents commit to a single solution path, LATS:
- **Explores multiple paths** by growing a search tree of candidate actions
- **Balances exploration vs. exploitation** using UCB (Upper Confidence Bound)
- **Learns from evaluations** by backpropagating reflection scores up the tree
- **Selects the best solution** from all explored paths

## MCTS Algorithm

LATS follows the classic MCTS cycle:

1. **Selection**: Starting at the root, follow the child with the highest UCB until reaching a leaf
2. **Expansion**: Generate N candidate actions from the LLM
3. **Simulation**: Execute tools and evaluate each candidate with reflection
4. **Backpropagation**: Update visit counts and values from the new nodes up to the root

```
                    [Root]
                   v=0.5, n=10
                       │
         ┌─────────────┼─────────────┐
         │             │             │
    [Child 1]     [Child 2]     [Child 3]
    v=0.7, n=5    v=0.3, n=3    v=0.6, n=2
         │                           │
     ┌───┴───┐                   ┌───┴───┐
  [C1.1]  [C1.2]              [C3.1]  [C3.2]
```
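The four phases fit together in a short loop. The sketch below is a toy, self-contained MCTS over binary strings, not the implementation inside `langgraph_ollama_local`: `ToyNode`, `mcts`, `expand`, and `evaluate` are hypothetical names, and the two stub functions stand in for LLM candidate generation and reflection scoring.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass(eq=False)
class ToyNode:
    """Hypothetical minimal tree node (not the tutorial's `Node` class)."""

    state: str
    parent: Optional["ToyNode"] = field(default=None, repr=False)
    children: List["ToyNode"] = field(default_factory=list)
    visits: int = 0
    value: float = 0.0  # running mean of all rewards backpropagated through this node

    def ucb(self, c: float) -> float:
        # Only called on children, which are evaluated (visits >= 1) as soon as they exist
        bonus = math.sqrt(math.log(self.parent.visits) / self.visits)
        return self.value + c * bonus


def mcts(
    root: ToyNode,
    expand: Callable[[str], List[str]],
    evaluate: Callable[[str], float],
    iterations: int,
    c: float = 1.0,
) -> None:
    for _ in range(iterations):
        # 1. Selection: descend by highest UCB until reaching a leaf
        node = root
        while node.children:
            node = max(node.children, key=lambda ch: ch.ucb(c))
        # 2. Expansion: generate candidate next states (LLM candidates in LATS)
        node.children = [ToyNode(s, parent=node) for s in expand(node.state)]
        # 3. Simulation: score each candidate (tool execution + reflection in LATS)
        for child in node.children:
            reward = evaluate(child.state)
            # 4. Backpropagation: bump visits and update running means up to the root
            current: Optional[ToyNode] = child
            while current is not None:
                current.visits += 1
                current.value += (reward - current.value) / current.visits
                current = current.parent


# Toy stand-ins for the LLM: extend binary strings up to length 3,
# and score a string by its fraction of 1s (a reward in [0, 1])
def expand(state: str) -> List[str]:
    return [] if len(state) >= 3 else [state + "0", state + "1"]


def evaluate(state: str) -> float:
    return state.count("1") / len(state)


toy_root = ToyNode("")
mcts(toy_root, expand, evaluate, iterations=6)
print(f"root visits={toy_root.visits}, root value={toy_root.value:.3f}")
```

Because every new child is evaluated exactly once, `toy_root.visits` ends up equal to the number of candidates created, and `toy_root.value` equals their mean reward.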

## When to Use LATS

Use LATS when:
- Task requires complex multi-step reasoning
- Multiple solution paths should be explored
- You want to balance trying new approaches vs using what works
- Quality matters more than speed

**Trade-offs**:
- ✅ Explores multiple paths (better solutions)
- ✅ Learns which approaches work (via backpropagation)
- ✅ Can recover from mistakes (tries alternatives)
- ❌ More LLM calls (slower, more expensive)
- ❌ Requires complexity limits for local models

## Setup

In [None]:
from langchain_ollama import ChatOllama
from langchain_core.tools import tool
from langgraph_ollama_local.patterns.lats import (
    Node,
    Reflection,
    create_lats_graph,
    run_lats_task,
    select,
    get_best_solution,
)

# Initialize LLM with temperature for diversity
llm = ChatOllama(
    model="llama3.2:3b",
    temperature=0.7,  # Higher temp for diverse candidates
)

## Part 1: Understanding MCTS Components

### 1.1 Reflection Model

The `Reflection` model provides structured evaluation of each candidate:

In [None]:
# Example reflections
good_reflection = Reflection(
    reflections="Good approach, uses correct tool, but answer incomplete",
    score=7,
    found_solution=False,
)

perfect_reflection = Reflection(
    reflections="Perfect! Correct reasoning, complete answer with evidence",
    score=10,
    found_solution=True,
)

print(f"Good reflection score: {good_reflection.score}/10 ({good_reflection.normalized_score})")
print(f"Perfect reflection score: {perfect_reflection.score}/10 ({perfect_reflection.normalized_score})")
print(f"Solution found: {perfect_reflection.found_solution}")

### 1.2 Node Class

The `Node` class represents a state in the search tree:

In [None]:
from langchain_core.messages import AIMessage

# Create root node
root = Node(messages=[], reflection=None, parent=None)
print(f"Root depth: {root.depth}")
print(f"Root visits: {root.visits}")
print(f"Root value: {root.value}")
print(f"Is terminal: {root.is_terminal}")

# Create child node with reflection
child_reflection = Reflection(
    reflections="Made progress on the task",
    score=6,
    found_solution=False,
)
child = Node(
    messages=[AIMessage(content="Let me search for information")],
    reflection=child_reflection,
    parent=root,
)
root.children.append(child)

print(f"\nAfter creating child:")
print(f"Child depth: {child.depth}")
print(f"Child visits: {child.visits}")
print(f"Child value: {child.value}")
print(f"Root visits (backpropagated): {root.visits}")
print(f"Root value (backpropagated): {root.value}")

### 1.3 UCB (Upper Confidence Bound)

UCB balances **exploitation** (using what works) vs **exploration** (trying new things):

```
UCB = avg_reward + exploration_weight * sqrt(ln(parent_visits) / node_visits)
```
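To see the trade-off numerically, plug the overview diagram's numbers into the formula, reading each `v` as that node's average reward and using an exploration weight of 1.0. This is a standalone calculation; the `ucb` helper is hypothetical and is not the `Node.upper_confidence_bound` method.

```python
import math


def ucb(avg_reward: float, node_visits: int, parent_visits: int,
        exploration_weight: float = 1.0) -> float:
    """avg_reward + exploration_weight * sqrt(ln(parent_visits) / node_visits)"""
    return avg_reward + exploration_weight * math.sqrt(math.log(parent_visits) / node_visits)


# Root children from the overview diagram as (v, n); the root has n=10
diagram_children = {"Child 1": (0.7, 5), "Child 2": (0.3, 3), "Child 3": (0.6, 2)}

scores = {name: ucb(v, n, parent_visits=10) for name, (v, n) in diagram_children.items()}
for name, score in scores.items():
    print(f"{name}: UCB = {score:.3f}")
print("Selected:", max(scores, key=scores.get))
```

Child 3 has a lower average than Child 1, but with only two visits its exploration bonus (about 1.07 versus 0.68) is large enough that it gets selected.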

In [None]:
# Create two children with different scores
reflection_ok = Reflection(reflections="OK approach", score=5, found_solution=False)
reflection_good = Reflection(reflections="Good approach", score=8, found_solution=False)

child1 = Node(messages=[AIMessage(content="approach 1")], reflection=reflection_ok, parent=root)
child2 = Node(messages=[AIMessage(content="approach 2")], reflection=reflection_good, parent=root)

root.children = [child1, child2]

# Visit child2 more times
child2.backpropagate(0.8)
child2.backpropagate(0.7)

print("Child 1 (lower score, less visited):")
print(f"  Value: {child1.value:.3f}, Visits: {child1.visits}")
print(f"  UCB (low exploration): {child1.upper_confidence_bound(0.5):.3f}")
print(f"  UCB (high exploration): {child1.upper_confidence_bound(2.0):.3f}")

print("\nChild 2 (higher score, more visited):")
print(f"  Value: {child2.value:.3f}, Visits: {child2.visits}")
print(f"  UCB (low exploration): {child2.upper_confidence_bound(0.5):.3f}")
print(f"  UCB (high exploration): {child2.upper_confidence_bound(2.0):.3f}")

print("\nWith a high exploration weight, the less-visited child1 gets a larger exploration bonus!")

### 1.4 Selection Algorithm

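The `select` function chooses the leaf node to expand next: starting at the root, it repeatedly moves to the child with the highest UCB until it reaches a leaf: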

In [None]:
# Create a deeper tree
root2 = Node(messages=[], reflection=None, parent=None)

# First level
r1 = Reflection(reflections="approach A", score=6, found_solution=False)
r2 = Reflection(reflections="approach B", score=7, found_solution=False)
child_a = Node(messages=[AIMessage(content="A")], reflection=r1, parent=root2)
child_b = Node(messages=[AIMessage(content="B")], reflection=r2, parent=root2)
root2.children = [child_a, child_b]

# Second level (only under child_b)
r3 = Reflection(reflections="B.1", score=8, found_solution=False)
child_b1 = Node(messages=[AIMessage(content="B.1")], reflection=r3, parent=child_b)
child_b.children = [child_b1]

# Select best leaf to expand
selected = select(root2)
print(f"Selected node depth: {selected.depth}")
print(f"Selected node messages: {[m.content for m in selected.messages]}")
print(f"\nSelection traversed the tree to find the best leaf!")

## Part 2: Building a LATS Agent

### 2.1 Define Tools

Let's create simple tools for a math problem:

In [None]:
@tool
def calculator(expression: str) -> str:
    """Evaluate a mathematical expression.
    
    Args:
        expression: Math expression to evaluate (e.g., "2+2", "10*5")
    
    Returns:
        Result of the calculation
    """
    try:
        # Restricted eval: builtins are removed, but this is NOT a secure sandbox
        result = eval(expression, {"__builtins__": {}}, {})
        return str(result)
    except Exception as e:
        return f"Error: {e}"

@tool
def get_number_info(number: int) -> str:
    """Get information about a number.
    
    Args:
        number: Number to get info about
    
    Returns:
        Information about the number
    """
    info = []
    info.append(f"Number: {number}")
    info.append(f"Even: {number % 2 == 0}")
    info.append(f"Square: {number ** 2}")
    return "\n".join(info)

tools = [calculator, get_number_info]
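The `calculator` tool above removes builtins from `eval`, but that is not a real sandbox: crafted expressions can still reach Python internals, and inputs like `9**9**9` can hang the process. If the tool may see untrusted input, a common alternative is to parse the expression with Python's `ast` module and evaluate only whitelisted arithmetic nodes. A minimal sketch, with `safe_eval` as a hypothetical helper name and an arbitrary cap on exponents:

```python
import ast
import operator

# Whitelisted operators; any other syntax is rejected
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def safe_eval(expression: str):
    """Evaluate a purely arithmetic expression such as "(15+25)**2"."""

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            left, right = _eval(node.left), _eval(node.right)
            # Crude guard against enormous powers like 9**9**9
            if isinstance(node.op, ast.Pow) and abs(right) > 100:
                raise ValueError("exponent too large")
            return _BIN_OPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](_eval(node.operand))
        raise ValueError(f"unsupported expression: {ast.dump(node)}")

    return _eval(ast.parse(expression, mode="eval"))
```

To use it, swap `eval(expression, {"__builtins__": {}}, {})` in `calculator` for `safe_eval(expression)`; the existing `try/except` already converts the `ValueError` (or a `ZeroDivisionError`) into an error string.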

### 2.2 Create LATS Graph

Create the graph with complexity limits for local models:

In [None]:
# For 3B model, use conservative limits
graph = create_lats_graph(
    llm=llm,
    tools=tools,
    max_depth=3,           # Limit tree depth
    max_width=2,           # 2 candidates per expansion
    max_iterations=10,     # Max 10 total nodes
    exploration_weight=1.0, # Balanced exploration
)

print("LATS graph created!")
print(f"Max depth: 3")
print(f"Max width: 2 candidates per expansion")
print(f"Max iterations: 10 total nodes")

### 2.3 Run Tree Search

Let's solve a problem that benefits from exploring multiple approaches:

In [None]:
task = """What is the square of the sum of 15 and 25?
Break this down step by step and calculate the final answer."""

print(f"Task: {task}")
print("\nRunning LATS tree search...\n")

result = run_lats_task(graph, task)

print(f"\n{'='*60}")
print("SEARCH COMPLETE")
print(f"{'='*60}")
print(f"Total nodes explored: {result['total_nodes']}")
print(f"Tree height: {result['root'].height}")
print(f"Solution found: {result['best_solution'].reflection.found_solution if result['best_solution'].reflection else False}")
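The `found_solution` flag reflects the model's own judgment, so it is worth checking the answer independently: (15 + 25)² = 40² = 1600. A small sketch with a hypothetical `mentions_answer` helper that looks for the expected integer in message texts:

```python
import re
from typing import Iterable

# Expected answer for the task above: (15 + 25) ** 2 = 40 ** 2
expected = (15 + 25) ** 2


def mentions_answer(contents: Iterable[str], answer: int) -> bool:
    """True if any text contains `answer` as a standalone integer (e.g. "1600" or "1,600")."""
    pattern = re.compile(rf"(?<!\d){answer}(?!\d)")
    for text in contents:
        normalized = re.sub(r"(?<=\d),(?=\d{3})", "", text)  # drop thousands separators
        if pattern.search(normalized):
            return True
    return False
```

In the notebook, a call such as `mentions_answer([str(m.content) for m in result['best_trajectory']], expected)` gives a quick cross-check against the reflection's verdict.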

### 2.4 Inspect the Search Tree

In [None]:
def print_tree(node: Node, prefix: str = "", is_last: bool = True):
    """Print tree structure."""
    connector = "└── " if is_last else "├── "
    
    # Node info
    if node.reflection:
        info = f"score={node.reflection.score}/10, visits={node.visits}, value={node.value:.2f}"
        if node.reflection.found_solution:
            info += " ✓ SOLUTION"
    else:
        info = f"ROOT, visits={node.visits}, value={node.value:.2f}"
    
    print(f"{prefix}{connector}{info}")
    
    # Children
    extension = "    " if is_last else "│   "
    for i, child in enumerate(node.children):
        is_last_child = i == len(node.children) - 1
        print_tree(child, prefix + extension, is_last_child)

print("\nSearch Tree Structure:")
print("="*60)
print_tree(result['root'])

### 2.5 Examine Best Solution

In [None]:
best = result['best_solution']

print("Best Solution Trajectory:")
print("="*60)
for i, msg in enumerate(result['best_trajectory'], 1):
    print(f"\nStep {i} ({msg.__class__.__name__}):")
    print(f"{msg.content}")

if best.reflection:
    print(f"\n{'='*60}")
    print("Reflection:")
    print(f"Score: {best.reflection.score}/10")
    print(f"Critique: {best.reflection.reflections}")
    print(f"Solution found: {best.reflection.found_solution}")
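`run_lats_task` returns `best_solution` and `best_trajectory` directly (and the setup cell imports `get_best_solution`, which is never called here), so how the winner is chosen stays internal to the package. The sketch below shows one plausible rule under stated assumptions, using a hypothetical `TreeNode` rather than the real `Node`: prefer nodes whose reflection marked the task solved, break ties by value, then rebuild the trajectory by following parent pointers back to the root.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class TreeNode:
    """Hypothetical stand-in for a search-tree node."""

    message: str
    value: float = 0.0
    solved: bool = False
    parent: Optional["TreeNode"] = field(default=None, repr=False)
    children: List["TreeNode"] = field(default_factory=list)

    def add_child(self, message: str, value: float, solved: bool = False) -> "TreeNode":
        child = TreeNode(message, value, solved, parent=self)
        self.children.append(child)
        return child


def all_nodes(root: TreeNode) -> List[TreeNode]:
    # Iterative depth-first walk; the root is always the first element
    nodes, stack = [], [root]
    while stack:
        node = stack.pop()
        nodes.append(node)
        stack.extend(node.children)
    return nodes


def best_solution(root: TreeNode) -> TreeNode:
    # Assumed rule: any solved node beats any unsolved one; ties broken by value.
    # The root is excluded unless it is the only node.
    candidates = all_nodes(root)[1:] or [root]
    return max(candidates, key=lambda n: (n.solved, n.value))


def trajectory(node: Optional[TreeNode]) -> List[str]:
    # Walk parent pointers up to the root, then reverse into root-to-leaf order
    path = []
    while node is not None:
        path.append(node.message)
        node = node.parent
    return path[::-1]


demo_root = TreeNode("task: (15+25)^2")
good_step = demo_root.add_child("compute 15+25 = 40", value=0.6)
distractor = demo_root.add_child("restate the problem", value=0.95)  # high value, unsolved
final_step = good_step.add_child("40**2 = 1600", value=0.9, solved=True)

demo_best = best_solution(demo_root)
print(trajectory(demo_best))
```

Under this rule the solved `final_step` wins even though `distractor` has a higher value, and its trajectory runs through `good_step`.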

## Part 3: Complexity Limits for Local Models

### 3.1 Understanding the Trade-offs

LATS is computationally expensive. Each expansion:
- Generates `max_width` candidates (multiple LLM calls)
- Executes tools for each candidate
- Reflects on each candidate
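A back-of-the-envelope count makes the cost concrete. The sketch assumes one generation call and one reflection call per candidate, and one expansion per iteration; the package may batch calls, and the graph-creation cell's comments describe `max_iterations` as a cap on total nodes, which would make the real count lower, so treat the numbers as a rough upper bound. It also shows the largest tree a depth/width pair allows, counting depth as levels below the root (another assumption). All function names are hypothetical.

```python
def llm_calls_per_expansion(max_width: int) -> int:
    # Assumption: one generation call + one reflection call per candidate
    return 2 * max_width


def worst_case_llm_calls(max_width: int, max_iterations: int) -> int:
    # Assumption: each iteration performs exactly one expansion
    return max_iterations * llm_calls_per_expansion(max_width)


def max_tree_nodes(max_depth: int, max_width: int) -> int:
    # Full tree: 1 root + w + w^2 + ... + w^depth nodes
    return sum(max_width ** d for d in range(max_depth + 1))


# (max_depth, max_width, max_iterations) for the three model-size tiers below
for depth, width, iters in [(3, 2, 10), (4, 3, 20), (5, 4, 30)]:
    print(
        f"depth={depth}, width={width}, iterations={iters}: "
        f"up to {worst_case_llm_calls(width, iters)} LLM calls, "
        f"tree capacity {max_tree_nodes(depth, width)} nodes"
    )
```

Under these assumptions, the iteration budget rather than the depth/width capacity is what caps the larger configurations: 30 iterations of width 4 can create at most 120 nodes, far fewer than the 1365 a full depth-5 tree could hold.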

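**Recommended limits by model size** (the `timeout` values are guidance only; `create_lats_for_model` below does not pass them to the graph):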

In [None]:
complexity_limits = {
    "3B-8B": {
        "max_depth": 3,
        "max_width": 2,
        "max_iterations": 10,
        "timeout": "30s",
        "notes": "Conservative limits for small models",
    },
    "13B-34B": {
        "max_depth": 4,
        "max_width": 3,
        "max_iterations": 20,
        "timeout": "45s",
        "notes": "Moderate complexity for medium models",
    },
    "70B+": {
        "max_depth": 5,
        "max_width": 4,
        "max_iterations": 30,
        "timeout": "60s",
        "notes": "Higher complexity for large models",
    },
}

import json
print(json.dumps(complexity_limits, indent=2))

### 3.2 Creating Graphs for Different Model Sizes

In [None]:
def create_lats_for_model(model_name: str, model_size: str):
    """Create LATS graph with appropriate limits for model size."""
    limits = complexity_limits.get(model_size, complexity_limits["3B-8B"])
    
    llm = ChatOllama(model=model_name, temperature=0.7)
    
    graph = create_lats_graph(
        llm=llm,
        tools=tools,
        max_depth=limits["max_depth"],
        max_width=limits["max_width"],
        max_iterations=limits["max_iterations"],
    )
    
    return graph

# Example configurations
print("Example configurations:")
print("\n1. Small model (llama3.2:3b):")
print("   graph = create_lats_for_model('llama3.2:3b', '3B-8B')")
print("\n2. Medium model (qwen2.5:14b):")
print("   graph = create_lats_for_model('qwen2.5:14b', '13B-34B')")
print("\n3. Large model (llama3.1:70b):")
print("   graph = create_lats_for_model('llama3.1:70b', '70B+')")

## Part 4: Advanced Patterns

### 4.1 Adjusting Exploration Weight

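The exploration weight scales the exploration term of UCB: lower values concentrate the search on high-scoring paths, while higher values push it toward less-visited ones: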

In [None]:
# Low exploration (exploit what works)
graph_exploit = create_lats_graph(
    llm=llm,
    tools=tools,
    max_depth=3,
    max_width=2,
    exploration_weight=0.3,  # Low - focus on high-scoring paths
)

# High exploration (try diverse approaches)
graph_explore = create_lats_graph(
    llm=llm,
    tools=tools,
    max_depth=3,
    max_width=2,
    exploration_weight=2.0,  # High - try less-visited paths
)

print("Created two graphs with different exploration strategies:")
print("1. graph_exploit: exploration_weight=0.3 (focus on best paths)")
print("2. graph_explore: exploration_weight=2.0 (try diverse paths)")
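Whether a weight counts as "low" depends on the scores involved. For two siblings, UCB starts preferring the lower-average, less-visited node once the weight exceeds (avg_a - avg_b) / (bonus_b - bonus_a), where bonus = sqrt(ln(parent_visits) / visits). A sketch using the overview diagram's Child 1 (v=0.7, n=5) and Child 3 (v=0.6, n=2) under a root with n=10, reading `v` as average reward; the helper names are hypothetical.

```python
import math


def bonus(node_visits: int, parent_visits: int) -> float:
    # Exploration term of UCB, before multiplying by the weight
    return math.sqrt(math.log(parent_visits) / node_visits)


def crossover_weight(avg_a: float, visits_a: int, avg_b: float, visits_b: int,
                     parent_visits: int) -> float:
    """Weight above which UCB prefers node b (lower average, fewer visits) over node a."""
    gap = avg_a - avg_b
    extra_bonus = bonus(visits_b, parent_visits) - bonus(visits_a, parent_visits)
    if gap <= 0 or extra_bonus <= 0:
        raise ValueError("expects avg_a > avg_b and visits_b < visits_a")
    return gap / extra_bonus


threshold = crossover_weight(0.7, 5, 0.6, 2, parent_visits=10)
print(f"Child 3 overtakes Child 1 once exploration_weight > {threshold:.3f}")
```

Here the threshold is about 0.254, so even the 0.3 used for `graph_exploit` would still pick the less-visited Child 3; with averages this close, "exploit" behavior needs a weight below that threshold.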

### 4.2 Custom Reflection Criteria

You can implement custom reflection logic by modifying the reflection prompt or using multi-criteria scoring.
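As an illustration of multi-criteria scoring, the sketch below combines several 0-10 ratings into the single 0-10 `score` that `Reflection` takes. The criteria names, the weights, and the `combine_scores` helper are all hypothetical; in practice the individual ratings would come from the reflection LLM, for example through a structured-output prompt.

```python
from typing import Dict

# Hypothetical criteria and weights; tune for your task
WEIGHTS: Dict[str, float] = {"correctness": 0.5, "completeness": 0.3, "tool_use": 0.2}


def combine_scores(ratings: Dict[str, float], weights: Dict[str, float] = WEIGHTS) -> int:
    """Weighted average of 0-10 ratings, rounded and clamped to an int in [0, 10]."""
    missing = weights.keys() - ratings.keys()
    if missing:
        raise ValueError(f"missing ratings: {sorted(missing)}")
    total = sum(weights.values())
    combined = sum(weights[k] * ratings[k] for k in weights) / total
    return max(0, min(10, round(combined)))


score = combine_scores({"correctness": 9, "completeness": 6, "tool_use": 8})
print(score)
```

The combined value could then be passed as `Reflection(reflections=..., score=score, found_solution=...)`; a stricter variant might also require a high correctness rating before setting `found_solution=True`.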

## Part 5: Best Practices

### 5.1 When to Use LATS

✅ **Good use cases:**
- Complex reasoning requiring multiple steps
- Tasks where exploring alternatives is valuable
- Problems with multiple valid approaches
- When quality matters more than speed

❌ **Avoid LATS when:**
- Simple, straightforward tasks
- Tight latency requirements
- Limited compute resources
- Single obvious solution path

### 5.2 Complexity Management

**Start conservative:**
```python
graph = create_lats_graph(
    llm=llm,
    tools=tools,
    max_depth=2,      # Start small
    max_width=2,      # Few candidates
    max_iterations=6, # Limited nodes
)
```

**Then increase gradually:**
```python
# Monitor performance, then increase
graph = create_lats_graph(
    llm=llm,
    tools=tools,
    max_depth=3,
    max_width=3,
    max_iterations=15,
)
```

### 5.3 Monitoring Tree Growth

Always check tree statistics:

In [None]:
def analyze_search(result):
    """Analyze LATS search results."""
    root = result['root']
    all_nodes = [root] + root._get_all_children()
    
    terminal_nodes = [n for n in all_nodes if n.is_terminal]
    solved_nodes = [n for n in all_nodes if n.reflection and n.reflection.found_solution]
    
    print(f"Total nodes: {len(all_nodes)}")
    print(f"Terminal nodes: {len(terminal_nodes)}")
    print(f"Solved nodes: {len(solved_nodes)}")
    print(f"Tree height: {root.height}")
    print(f"Average node value: {sum(n.value for n in all_nodes) / len(all_nodes):.3f}")
    
    if solved_nodes:
        print(f"\nBest solved node score: {max(n.reflection.score for n in solved_nodes)}/10")

# Example usage
print("Search Analysis:")
print("="*60)
analyze_search(result)

## Summary

### Key Takeaways

1. **LATS = MCTS for Language Agents**
   - Select via UCB
   - Expand with multiple candidates
   - Simulate with tool execution
   - Backpropagate scores

2. **Complexity Management is Critical**
   - Set limits based on model size
   - Monitor tree growth
   - Start conservative, increase gradually

3. **Trade-offs**
   - Better solutions via exploration
   - More LLM calls (slower, costlier)
   - Best for complex reasoning tasks

### Next Steps

- Experiment with different exploration weights
- Try LATS on your complex reasoning tasks
- Compare with single-path agents (ReAct, ReWOO)
- Tune complexity limits for your use case

### Related Patterns

- **Tutorial 21 (Plan-and-Execute)**: Simpler planning approach
- **Tutorial 22 (Reflection)**: Single-path improvement
- **Tutorial 23 (Reflexion)**: Multi-attempt learning
- **Tutorial 25 (ReWOO)**: Efficient planning alternative