<a href="https://colab.research.google.com/github/Shushman/jkg-shushman/blob/main/MultiAgentMaxPlusLLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

This notebook applies ideas from scalable anytime multi-agent coordination to two LLM-based settings:

1.  **Instruction Tuning (Meta-Prompting):** This section implements a meta-prompting framework where multiple LLM agents collaborate in a coordination graph to propose, refine, and evaluate candidate prompts.

2.  **Adaptive Divide-and-Conquer Arithmetic:** This section demonstrates a strategy for solving complex multi-step arithmetic problems by adaptively breaking them down into smaller, binary sub-expressions. Specialized agents then solve these sub-problems in parallel, and the results are progressively substituted back into the main expression until a final answer is reached. This approach aims to handle complex arithmetic in an interruptible and potentially more robust manner.

The notebook includes code for setting up the environment, loading and preparing datasets, defining the necessary DSPy modules for each method, and running experiments to evaluate their performance.

# Sec 1: Setup and Imports

This section configures API keys, imports the required libraries, and loads the datasets used in later sections. The LLM wrappers themselves are defined in Sections 2 and 3.

## Environment Configuration

In [None]:
!pip install -U dspy==2.6.10 litellm

In [None]:
# Configure API keys from Colab user data
import os
from google.colab import userdata


os.environ["OPENAI_API_KEY"] = userdata.get("openai_api_key")
os.environ["ANTHROPIC_API_KEY"] = userdata.get("anthropic_api_key")

## Imports

In [None]:
# Standard libraries
import json
import random
from collections import OrderedDict

# Numerical utilities
import numpy as np
from sklearn.model_selection import train_test_split

# DSPy framework and evaluation
import dspy
from dspy.datasets import DataLoader, MATH
from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
from dspy.evaluate import Evaluate
from litellm import completion
import textwrap

dspy.configure(experimental=True)

## Load and Prepare Datasets

### [GSM8K](https://huggingface.co/datasets/openai/gsm8k)

GSM8K is available natively in DSPy.

In [None]:
# GSM8K math word problems
gsm8k = GSM8K()
gsm8k_train, gsm8k_dev, gsm8k_test = gsm8k.train, gsm8k.dev, gsm8k.test
gsm8k_train_dev = gsm8k_train + gsm8k_dev

### MetaPrompting Datasets

We load several datasets from the [Meta-Prompting](https://github.com/suzgunmirac/meta-prompting) benchmark and split them into train/dev.

In [None]:
# Helper to load and split datasets
dl = DataLoader()
meta_prompter = dl.from_huggingface("turingmachine/meta-prompting")

def preprocess_metaprompter_dataset(ds_name: str, test_size: float = 0.5):
    """
    Load a MetaPrompter dataset and split into train/dev.
    Sets input keys for DSPy compatibility.
    """
    # with_inputs marks "input" as the model input without touching private attributes
    ds = [example.with_inputs("input") for example in meta_prompter[ds_name]]
    train, dev = train_test_split(ds, test_size=test_size, random_state=42)
    return train, dev

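# Define metric for MetaPrompter tasks
def meta_prompter_metric(example, response, trace=None) -> int:
    # Compare as stripped strings so an int answer (e.g. -32) still matches a str target ("-32")
    return int(str(example.target).strip() == str(response.answer).strip())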
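
Exact-match metrics of this kind depend on both sides having the same type: a signature that declares `answer: int` yields an integer, while dataset targets are strings. The sketch below uses stand-in objects (`SimpleNamespace`, not DSPy classes) and two hypothetical helper names to show the pitfall and one possible normalization.

```python
from types import SimpleNamespace

def strict_match(example, response) -> int:
    # Plain equality: a str target never equals an int answer
    return int(example.target == response.answer)

def normalized_match(example, response) -> int:
    # Cast both sides to stripped strings before comparing
    return int(str(example.target).strip() == str(response.answer).strip())

example = SimpleNamespace(target="-32")
int_response = SimpleNamespace(answer=-32)
padded_response = SimpleNamespace(answer=" -32 ")

print(strict_match(example, int_response))
print(normalized_match(example, int_response))
print(normalized_match(example, padded_response))
```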

**Geometric Shapes and Word Sorting**

In [None]:
geom_shapes_train, geom_shapes_dev = preprocess_metaprompter_dataset("GeometricShapes")
word_sorting_train, word_sorting_dev = preprocess_metaprompter_dataset("WordSorting")

**Multi-Step Arithmetic and Extended Combinations**

In [None]:
multi_step_train, multi_step_dev = preprocess_metaprompter_dataset("MultistepArithmeticTwo")
multi_step_combined = multi_step_train + multi_step_dev

def combine_multi_step_arithmetic(ex1, ex2):
    """
    Combine two arithmetic examples into a single multiplication problem.
    """
    new_input = f"{ex1.input.rstrip(' ?')} * {ex2.input.rstrip(' ?')}"
    new_target = str(int(ex1.target) * int(ex2.target))
    return dspy.Example(input=new_input, target=new_target)

# Generate extended dataset by chaining multiplications
multi_step_extended = [
    combine_multi_step_arithmetic(multi_step_combined[i], multi_step_combined[i+1])
    for i in range(len(multi_step_combined) - 1)
]

# Peek at an example
multi_step_extended[5]
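
As a sanity check on the combination rule, the sketch below replays it on two hypothetical, hand-written examples (real dataset entries may be formatted differently) using plain `SimpleNamespace` objects instead of `dspy.Example`. It also shows an assumption the rule relies on: each input must be fully parenthesized, otherwise operator precedence makes the product of the targets the wrong label.

```python
from types import SimpleNamespace

def combine(ex1, ex2):
    # Mirrors combine_multi_step_arithmetic, but returns a plain namespace
    new_input = f"{ex1.input.rstrip(' ?')} * {ex2.input.rstrip(' ?')}"
    new_target = str(int(ex1.target) * int(ex2.target))
    return SimpleNamespace(input=new_input, target=new_target)

# Hypothetical, fully parenthesized inputs
a = SimpleNamespace(input="((1 + 2) * (3 - 5))", target="-6")
b = SimpleNamespace(input="((4 - 6) + (2 * 3))", target="4")
combined = combine(a, b)
# eval is used only because these are trusted demo strings
print(combined.input, "->", combined.target, "| eval:", eval(combined.input))

# Without outer parentheses, precedence changes the value of the combined input
c = SimpleNamespace(input="1 + 2", target="3")
d = SimpleNamespace(input="3 + 4", target="7")
broken = combine(c, d)
print(broken.input, "->", broken.target, "| eval:", eval(broken.input))
```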

# Sec 2: Instruction Tuning
This section implements a meta‑prompting framework where multiple LLM agents propose, refine, and evaluate candidate prompts in a coordination graph. We omit file writes and simply print metrics for downstream adjustment.

## 2.1 Library Modules

### 2.1.1 Base LLM Signatures and Wrappers
We define DSPy signatures and modules for three task types (GSM8K, GeometricShapes, MultiStepArithmetic). Each module exposes `set_instructions`, which swaps a candidate prompt into the predictor's signature.

In [None]:
# GSM8K Signature & Module
class GSM8kSignature(dspy.Signature):
    question: str = dspy.InputField(desc="math word problem")
    reasoning: str = dspy.OutputField(desc="step-by-step reasoning")
    answer: int = dspy.OutputField(desc="integer answer")

class GSM8kModule(dspy.Module):
    def __init__(self, llm):
        super().__init__()
        self.predictor = dspy.Predict(GSM8kSignature)
        self.predictor.set_lm(llm)

    def set_instructions(self, instr: str):
        self.predictor.signature = GSM8kSignature.with_instructions(instr)

    def __call__(self, question: str) -> dspy.Prediction:
        resp = self.predictor(question=question)
        return dspy.Prediction(reasoning=resp.reasoning, answer=str(resp.answer))

# GeometricShapes Signature & Module
class GeomShapesSignature(dspy.Signature):
    input: str = dspy.InputField(desc="SVG path question")
    reasoning: str = dspy.OutputField(desc="chain-of-thought reasoning")
    answer: str = dspy.OutputField(desc="shape label in parentheses")

class GeomShapesModule(dspy.Module):
    def __init__(self, llm):
        super().__init__()
        self.predictor = dspy.Predict(GeomShapesSignature)
        self.predictor.set_lm(llm)

    def set_instructions(self, instr: str):
        self.predictor.signature = GeomShapesSignature.with_instructions(instr)

    def __call__(self, input: str) -> dspy.Prediction:
        resp = self.predictor(input=input)
        # Truncate answer to first option label
        return dspy.Prediction(reasoning=resp.reasoning, answer=resp.answer[:3])

# MultiStepArithmetic Signature & Module
class MultiStepSignature(dspy.Signature):
    input: str = dspy.InputField(desc="arithmetic problem")
    reasoning: str = dspy.OutputField(desc="chain-of-thought")
    answer: int = dspy.OutputField(desc="integer result")

class MultiStepModule(dspy.Module):
    def __init__(self, llm):
        super().__init__()
        self.predictor = dspy.Predict(MultiStepSignature)
        self.predictor.set_lm(llm)

    def set_instructions(self, instr: str):
        self.predictor.signature = MultiStepSignature.with_instructions(instr)

    def __call__(self, input: str) -> dspy.Prediction:
        resp = self.predictor(input=input)
        return dspy.Prediction(reasoning=resp.reasoning, answer=str(resp.answer))

### 2.1.2 Proposer & Refiner
Two cooperating modules generate and refine candidate prompts via message passing.

In [None]:
class ProposalAgentSignature(dspy.Signature):
  task_description: str = dspy.InputField(desc="The task the prompt should address.")
  prompt: str = dspy.OutputField(desc="The generated candidate prompt.")

class RefinerSignature(dspy.Signature):
  task_description: str = dspy.InputField(desc="The task the prompt should address.")
  current_prompt: str = dspy.InputField(desc="The agent's current candidate prompt.")
  neighbor_prompts: list[str] = dspy.InputField(desc="The current prompts of the agent's neighbors in the coordination graph.")
  feedback_to_refine_prompt: str = dspy.InputField(desc="Feedback for refining the agent's current prompt.")
  refined_prompt: str = dspy.OutputField(desc="The improved candidate prompt after refinement.")

class ProposalAgent(dspy.Module):
  def __init__(self, idx, task_description, llm, verbose=False):
    self.task_description = task_description
    self.verbose = verbose
    self.idx = idx

    self.proposal_instr = """
    You are an expert prompt generator. Craft a creative and effective prompt for a large language model that
    will improve its performance at solving the given described task.
    """
    self.refinement_instr = """
    You are an expert prompt editor.
    Refine the current task prompt based on the feedback and on the prompts of your coordinating neighbors.
    Ensure the refined prompt is effective at solving the described task.
    """
    prop_signature = ProposalAgentSignature.with_instructions(self.proposal_instr)
    ref_signature = RefinerSignature.with_instructions(self.refinement_instr)
    self.generate_module = dspy.Predict(prop_signature)
    self.refine_module = dspy.Predict(ref_signature)

    self.current_prompt = ""
    self.current_score = 0.0
    self.set_lm(llm)

  def generate_initial_prompt(self):
    """
    Generate the initial prompt for the task.
    """
    self.current_prompt = self.generate_module(
        task_description=self.task_description).prompt

  def refine_prompt(self, neighbor_prompts, feedback):
    """
    Refine the agent's current prompt using its neighbors' prompts and the feedback it received.
    """
    self.current_prompt = self.refine_module(
      task_description=self.task_description,
      current_prompt=self.current_prompt,
      neighbor_prompts=neighbor_prompts,
      feedback_to_refine_prompt=feedback
    ).refined_prompt

### 2.1.3 Evaluator

In [None]:
# Define Evaluator Agent; just a wrapper around dspy.Evaluate
class EvaluatorAgent:
  def __init__(self, model_lm, devset, metric, minibatch, **eval_kwargs):
    """
    Initialize the EvaluatorAgent.

    Args:
      model_lm (dspy.Module): Task module to evaluate; must expose set_instructions(prompt).
      devset (list): The full development dataset.
      metric (callable): Metric function with signature (example, prediction, trace=None).
      minibatch (int): Number of samples to use for each evaluation.
      **eval_kwargs: Additional arguments for dspy.Evaluate.
    """
    self.model_lm = model_lm
    self.eval_kwargs = eval_kwargs
    self.devset = devset
    self.metric = metric
    self.minibatch = minibatch

    # Ensure subset size is valid
    assert self.minibatch <= len(self.devset), "Subset size exceeds dataset size."

    self.current_subset = None  # Placeholder for the subset used in a round

  def set_subset_for_round(self):
    """
    Set a random subset of the dataset to be used for the current round.
    """
    self.current_subset = random.sample(self.devset, self.minibatch)

  def set_full_set(self):
    self.current_subset = self.devset

  def evaluate_prompt(self, candidate_prompt):
    """
    Evaluate a candidate prompt using the model and the current subset of the development dataset.

    Args:
      candidate_prompt (str): The prompt to evaluate.

    Returns:
      float: The evaluation score of the prompt.
    """
    # Ensure the subset is set
    if self.current_subset is None:
      raise ValueError("Dataset subset for the current round has not been set. Call set_subset_for_round() first.")

    # Create the evaluator
    evaluator = dspy.Evaluate(
        devset=self.current_subset, metric=self.metric, **self.eval_kwargs
    )

    # Define the model to evaluate with the given prompt
    self.model_lm.set_instructions(candidate_prompt)

    # Evaluate the model and return the score
    score = evaluator(self.model_lm)
    return score
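
The evaluator draws one minibatch per round and reuses it for every agent, so scores within a round are measured on the same examples. A minimal sketch of that idea with a toy scoring function follows; `score_round` and `is_correct` are illustrative names, not part of DSPy.

```python
import random

def score_round(prompts, devset, is_correct, minibatch, rng):
    """Score every prompt on one shared random subset; returns percentages."""
    subset = rng.sample(devset, minibatch)  # drawn once per round
    return [
        100.0 * sum(is_correct(prompt, ex) for ex in subset) / len(subset)
        for prompt in prompts
    ]

# Toy setup: "examples" are integers, and a prompt "solves" an example
# when the example is divisible by the prompt's length.
devset = list(range(1, 101))

def is_correct(prompt, ex):
    return ex % len(prompt) == 0

scores = score_round(["ab", "abcd"], devset, is_correct, minibatch=20, rng=random.Random(0))
print(scores)
```

Because every multiple of 4 is also a multiple of 2, the first prompt can never score below the second on a shared subset. That ordering would not be guaranteed if each prompt were scored on its own random sample.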

### 2.1.4 Coordination Graph & Orchestrator
Defines the message‑passing algorithm over multiple rounds.

In [None]:
class CoordinationGraphSignature(dspy.Signature):
  task_description: str = dspy.InputField(desc="The task the prompts should address.")
  agent_prompts: list[str] = dspy.InputField(desc="List of all current prompts from the agents.")
  previous_round_scores: list[float] = dspy.InputField(desc="List of prompt scores from the previous round.")
  feedback_to_refine_prompts: list[str] = dspy.OutputField(desc="Feedback for refining each agent's prompt.")
  coordination_graph: list[list[int]] = dspy.OutputField(
      desc="For each agent in order, a list of other agents it should coordinate prompts with.")

# Define Coordination Graph and Message Passing
class CoordinationGraph:
  def __init__(self, proposal_agents, task_description, llm, max_neighbors=3, verbose=False):
    """
    Initialize the CoordinationGraph with a list of proposal agents, task description,
    and an LLM for generating the graph.

    Args:
      proposal_agents (list): List of ProposalAgent objects.
      task_description (str): Description of the task the prompts are addressing.
      llm (dspy.LM): The language model to generate the coordination graph.
      max_neighbors (int): Maximum number of neighbors kept for each agent.
      verbose (bool): Whether to print verbose output.
    """
    self.proposal_agents = proposal_agents
    self.task_description = task_description
    self.verbose = verbose

    self.graph = {agent.idx: [] for agent in proposal_agents}
    self.current_refinement_feedback = []

    # Define the DSPy module for generating the coordination graph
    self.graph_instr = f"""
      You are an expert at prompt writing and editing.
      Based on the task description, current agent prompts, and corresponding scores in the previous
      round, determine which agents should coordinate their prompts in the next round to improve their
      performance on the described task.
      Consider the following principles for deciding which agents should coordinate:
      1. Agents with very similar prompts should coordinate to avoid redundancy.
      2. If a pair of agents have different prompts and different scores,
      the lower scoring agent should coordinate with the higher scoring agent but not vice versa.
      3. The highest scoring agents don't necessarily have to coordinate but should still
      get feedback on possible improvements.
      Limit the number of such neighbors for each agent to a maximum of {max_neighbors}, prioritizing the most relevant dependencies.
      Return a mapping of each agent to the agents it should coordinate with and
      specific and targeted feedback for each agent on how they should refine their prompts to improve
      performance on the described task.
    """
    cg_signature = CoordinationGraphSignature.with_instructions(self.graph_instr)
    self.graph_module = dspy.Predict(cg_signature)

    self.graph_module.set_lm(llm)
    self.max_neighbors = max_neighbors

  def update_graph(self, previous_round_scores):
    """
    Update the coordination graph using the LLM to evaluate relationships between prompts.

    Args:
      previous_round_scores (list[float]): Score of each agent's prompt in the previous round, in agent order.
    """
    # Gather current prompts from all agents
    prompts = [agent.current_prompt for agent in self.proposal_agents]
    idxs = [agent.idx for agent in self.proposal_agents]

    # Use the LLM to generate the coordination graph
    result = self.graph_module(
        task_description=self.task_description,
        agent_prompts=prompts,
        previous_round_scores=previous_round_scores
    )

    # Limit the number of neighbors per agent
    self.graph = {}
    for (i, neighbors) in enumerate(result.coordination_graph):
        # Keep the first max_neighbors entries, in the order the LLM returned them
        prioritized_neighbors = neighbors[:self.max_neighbors]
        self.graph[idxs[i]] = prioritized_neighbors
        if self.verbose:
          print(f"Agent {idxs[i]}: Neighbors -> {prioritized_neighbors}")

    # Set current refinement feedback
    self.current_refinement_feedback = result.feedback_to_refine_prompts
    if self.verbose:
      print("Refinement Feedback:")
      for idx, feedback in enumerate(self.current_refinement_feedback):
          print(f"Agent {idx}: Feedback -> {feedback}")
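
The neighbor lists come straight from an LLM, so they may contain self-references, duplicates, or indices that do not correspond to any agent. The helper below is a hypothetical hardening step (it is not part of the notebook's `CoordinationGraph`) that could be applied before truncating to `max_neighbors`.

```python
def sanitize_neighbors(raw_graph, n_agents, max_neighbors):
    """Return one cleaned neighbor list per agent, in agent order."""
    clean = []
    for agent_idx in range(n_agents):
        # Agents missing from the LLM output get no neighbors
        raw = raw_graph[agent_idx] if agent_idx < len(raw_graph) else []
        kept = []
        for nbr in raw:
            valid = isinstance(nbr, int) and 0 <= nbr < n_agents
            if valid and nbr != agent_idx and nbr not in kept:
                kept.append(nbr)
        clean.append(kept[:max_neighbors])
    return clean

# Agent 0 lists a duplicate and an out-of-range index, agent 1 lists itself,
# and agent 3 is missing from the raw output.
raw = [[1, 1, 7], [0, 1, 2], []]
print(sanitize_neighbors(raw, n_agents=4, max_neighbors=2))
```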


In [None]:
class Orchestrator:
  def __init__(self, proposal_agents, evaluator, coordination_graph, num_rounds, subset=True, verbose=False):
    """
    Initialize the orchestrator.

    Args:
      proposal_agents (list): List of ProposalAgent instances.
      evaluator (EvaluatorAgent): Evaluator agent for scoring prompts.
      coordination_graph (CoordinationGraph): Dynamic coordination graph.
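      num_rounds (int): Number of message-passing rounds.
      subset (bool): If True, score each round on a fresh random minibatch; otherwise use the full dev set.
      verbose (bool): Whether to enable verbose output.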
    """
    self.proposal_agents = proposal_agents
    self.evaluator = evaluator
    self.coordination_graph = coordination_graph
    self.num_rounds = num_rounds
    self.verbose = verbose
    self.subset = subset

    # Track all scores and prompts
    self.all_scores = []
    self.all_prompts = []
    self.best_score_in_round = []
    self.best_prompt_in_round = []

  def run(self):
    # Generate initial prompts
    print("\n--- Generating and Evaluating Initial Prompts ---")
    for agent in self.proposal_agents:
      agent.generate_initial_prompt()

    # Initial evaluation
    self._evaluate_all_prompts()

    # Get best score and prompt after round
    best_prompt, best_score = self._get_best_prompt_and_score_in_round()
    self.best_prompt_in_round.append(best_prompt)
    self.best_score_in_round.append(best_score)

    # Iterative message passing and refinement
    for round_num in range(self.num_rounds):
      print(f"\n--- Round {round_num + 1} ---")

      # Update coordination graph
      self.coordination_graph.update_graph(
          previous_round_scores=[
              agent.current_score for agent in self.proposal_agents
          ]
      )

      # Refine prompts using message passing
      for agent in self.proposal_agents:
        neighbor_prompts = [
            self.proposal_agents[nbr_idx].current_prompt for nbr_idx in self.coordination_graph.graph[agent.idx]]
        feedback = self.coordination_graph.current_refinement_feedback[agent.idx]
        agent.refine_prompt(neighbor_prompts, feedback)

      # Evaluate refined prompts
      self._evaluate_all_prompts()
      best_prompt, best_score = self._get_best_prompt_and_score_in_round()
      self.best_prompt_in_round.append(best_prompt)
      self.best_score_in_round.append(best_score)

      # Print the best score so far
      if self.verbose:
        print(f"\n--- Best Score: {max(self.all_scores)}")

    # Decide final output
    print("\n--- Finalizing Output ---")
    best_prompt, best_score = self._get_best_prompt_and_score()

    final_results = {
        "best_prompt": best_prompt,
        "best_score": best_score,
        "best_prompt_after_round": self.best_prompt_in_round,
        "best_score_after_round": self.best_score_in_round
    }
    return final_results

  def _evaluate_all_prompts(self):
    """
    Evaluate all prompts and track their scores and prompts.
    """
    if self.subset:
      self.evaluator.set_subset_for_round()
    else:
      self.evaluator.set_full_set()
    for agent in self.proposal_agents:
      score = self.evaluator.evaluate_prompt(agent.current_prompt)
      agent.current_score = score

      # Track all scores and prompts
      self.all_scores.append(score)
      self.all_prompts.append(agent.current_prompt)

      if self.verbose:
        print(f"Agent Prompt: {agent.current_prompt}, Score: {score}")

  def _get_best_prompt_and_score(self):
    """
    Get the best prompt and its score across all rounds.

    Returns:
      tuple: Best prompt and its score.
    """
    best_index = self.all_scores.index(max(self.all_scores))
    best_prompt = self.all_prompts[best_index]
    best_score = self.all_scores[best_index]
    return best_prompt, best_score

  def _get_best_prompt_and_score_in_round(self):
    curr_scores = []
    curr_prompts = []
    for agent in self.proposal_agents:
      curr_scores.append(agent.current_score)
      curr_prompts.append(agent.current_prompt)
    best_score = max(curr_scores)
    best_index = curr_scores.index(best_score)
    best_prompt = curr_prompts[best_index]
    return best_prompt, best_score
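
The orchestrator keeps every prompt and score it has seen, so the best prompt overall survives even when a later round scores worse. The sketch below isolates that bookkeeping with toy stand-ins for scoring and refinement; `run_rounds`, `score_fn`, and `refine_fn` are illustrative names, and no LLM is called.

```python
def run_rounds(initial_prompts, score_fn, refine_fn, num_rounds):
    prompts = list(initial_prompts)
    scores = [score_fn(p) for p in prompts]
    all_prompts, all_scores = list(prompts), list(scores)
    best_after_round = [max(scores)]

    for _ in range(num_rounds):
        # Refine every prompt, then re-score the whole population
        prompts = [refine_fn(p, s) for p, s in zip(prompts, scores)]
        scores = [score_fn(p) for p in prompts]
        all_prompts.extend(prompts)
        all_scores.extend(scores)
        best_after_round.append(max(scores))

    # Best across all rounds, not just the last one
    best_idx = all_scores.index(max(all_scores))
    return all_prompts[best_idx], all_scores[best_idx], best_after_round

# Toy score table in which refinement makes the strongest prompt worse
toy_scores = {"a": 1, "bb": 5, "a+": 3, "bb+": 2}
best_prompt, best_score, per_round = run_rounds(
    ["a", "bb"],
    score_fn=toy_scores.get,
    refine_fn=lambda prompt, score: prompt + "+",
    num_rounds=1,
)
print(best_prompt, best_score, per_round)
```

The per-round best drops from 5 to 3 here, yet the returned prompt is still the one that scored 5 in the initial round.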

## 2.2 Usage

### 2.2.1 Configure Models & Task

In [None]:
# Define LLM instances for this section
gpt4o = dspy.LM('openai/gpt-4o', temperature=0.7, max_tokens=3000, cache=False)
gpt4o_mini = dspy.LM('openai/gpt-4o-mini', temperature=0.7, max_tokens=3000, cache=False)
haiku3 = dspy.LM('anthropic/claude-3-haiku-20240307', max_tokens=3000, cache=False)
gpt_35 = dspy.LM('openai/gpt-3.5-turbo', max_tokens=3000, cache=False)

# Select configurable params for experiments
task = 'gsm8k' #@param ["geom_shapes", "gsm8k", "multi_step"]
n_agents = 4 #@param {type:"integer"}
n_rounds = 2 #@param {type:"integer"}
max_cg_nbrs = 1 #@param {type:"integer"}
verbose = False #@param {type:"boolean"}

base_map = {
    'gsm8k': (GSM8kModule, gpt_35),
    'multi_step': (MultiStepModule, gpt_35),
    'geom_shapes': (GeomShapesModule, haiku3),
}
ModuleClass, base_llm = base_map[task]

if task == "gsm8k":
  trainset = gsm8k_dev
  testset = gsm8k_test[0:1300:6]
  metric = gsm8k_metric
elif task == "geom_shapes":
  trainset = geom_shapes_train
  testset = geom_shapes_dev
  metric = meta_prompter_metric
elif task == "multi_step":
  trainset = multi_step_train
  testset = multi_step_dev
  metric = meta_prompter_metric

TASK_DESCRIPTION_MAP = {
    "gsm8k": "Solve complex mathematical reasoning problems correctly",
    "multi_step": "Solve complex arithmetic problems correctly",
    "geom_shapes": "Correctly identify a 2D geometric shape from an SVG path element",
}

task_description = TASK_DESCRIPTION_MAP[task]

eval_kwargs = dict(num_threads=2, display_progress=True, display_table=0)
dspy_evaluator = dspy.Evaluate(devset=testset, metric=metric, **eval_kwargs)

# Instantiate evaluation metric and evaluator
evaluator = EvaluatorAgent(
    model_lm=ModuleClass(base_llm),
    devset=trainset,
    metric=metric,
    minibatch=50,
    num_threads=2,
    display_progress=True
)

### 2.2.2 Run Multi-Agent Prompt Optimization

In [None]:
# Initialize proposal agents and coordination graph
agents = [ProposalAgent(i, task_description, gpt4o_mini) for i in range(n_agents)]
cg = CoordinationGraph(agents, task_description, gpt4o, max_cg_nbrs, verbose)
orch = Orchestrator(agents, evaluator, cg, num_rounds=n_rounds, subset=True, verbose=True)

results = orch.run()
print("Optimization Rounds → Best Scores on minibatch:", results['best_score_after_round'])

### 2.2.3 Held‑Out Evaluation & Baselines
We apply the best prompts to a held‑out test set and compare against baselines.

In [None]:
# Held‑out scoring
heldout_scores = []

for i in range(n_rounds+1):
  prompt = results['best_prompt_after_round'][i]
  if i >=1 and prompt == results['best_prompt_after_round'][i-1]:
    print("Skipping duplicate prompt")
    heldout_scores.append(heldout_score)
    continue
  maxplus_optimized_module = ModuleClass(base_llm)
  maxplus_optimized_module.set_instructions(prompt)
  heldout_score = dspy_evaluator(maxplus_optimized_module)
  heldout_scores.append(heldout_score)
print("Test set scores per round:", heldout_scores)

# Baselines
# 1) MIPRO
mipro_module_before_opt = ModuleClass(llm=base_llm)
optimizer_kwargs = dict(num_threads=2, prompt_model=gpt4o_mini, auto="medium")
optimizer = dspy.MIPROv2(metric=metric, **optimizer_kwargs)
# Instruction-only optimization: no bootstrapped or labeled demos
compile_kwargs = dict(requires_permission_to_run=False, max_bootstrapped_demos=0, max_labeled_demos=0)
# n_rounds * 50 training examples, matching the minibatch size of 50 used above
mipro_trainset = random.sample(trainset, min(n_rounds * 50, len(trainset)))
mipro_optimized_module = optimizer.compile(mipro_module_before_opt, trainset=mipro_trainset, **compile_kwargs)

mipro_res = dspy_evaluator(mipro_optimized_module)
print("MIPRO baseline:", mipro_res)

# 2) Base LLM
base_mod = ModuleClass(base_llm)
print("Base LLM accuracy:", dspy_evaluator(base_mod))

# 3) CoT
# For GSM8K the task module already emits a reasoning field, so it doubles as the CoT baseline;
# the Meta-Prompting tasks use a generic dspy.ChainOfThought program.
if task == "gsm8k":
  cot_module = ModuleClass(llm=base_llm)
else:
  cot_module = dspy.ChainOfThought("input -> answer")
  cot_module.set_lm(base_llm)
print("Chain-of-Thought:", dspy_evaluator(cot_module))

# Sec 3: Adaptive Divide-and-Conquer Arithmetic
This notebook section demonstrates how to split multi-step arithmetic into binary sub-expressions,
coordinate specialized agents, and stream intermediate results in an interruptible fashion.
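
To make the round structure concrete, here is a deterministic, LLM-free reference sketch built on Python's `ast` module (Python 3.9+ for `ast.unparse`). It is not the notebook's method, which delegates each step to LLM agents; it only shows which binary operations are ready in each round and how results are substituted back. Only `+`, `-`, and `*` on integers are handled, and all names are illustrative.

```python
import ast
import operator

OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}

def _const(node):
    """Return the integer value of a literal (allowing unary minus), else None."""
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        inner = _const(node.operand)
        return None if inner is None else -inner
    return None

class _ReduceReadyOps(ast.NodeTransformer):
    """Replace every binary op whose operands are both literals with its value."""
    def __init__(self):
        self.solved = []  # (sub-expression text, result) pairs for this round

    def visit_BinOp(self, node):
        left, right = _const(node.left), _const(node.right)
        if left is not None and right is not None and type(node.op) in OPS:
            value = OPS[type(node.op)](left, right)
            self.solved.append((ast.unparse(node), value))
            return ast.Constant(value=value)
        # Not ready yet: look for ready ops deeper in the tree
        self.generic_visit(node)
        return node

def solve_in_rounds(expr):
    tree = ast.parse(expr, mode="eval")
    history = []
    while _const(tree.body) is None:
        reducer = _ReduceReadyOps()
        tree = reducer.visit(tree)
        if not reducer.solved:
            raise ValueError(f"unsupported expression: {expr!r}")
        history.append(reducer.solved)
    return _const(tree.body), history

answer, history = solve_in_rounds("((1 + 0 + 2 - 4) + (-9 + 6 * -5 + 8))")
for round_idx, solved in enumerate(history):
    print(round_idx, solved)
print("answer:", answer)
```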

## 3.1 Library Modules

### 3.1.1 Expression Decomposer
Breaks a complex expression into as many binary sub-expressions as possible for one round.

In [None]:
class ExpressionDecompositionSignature(dspy.Signature):
  full_expr: str = dspy.InputField(desc="The expression to decompose.")
  sub_expressions: list[str] = dspy.OutputField(desc="Binary sub-expressions to solve in parallel.")
  plan: str = dspy.OutputField(desc="Placeholder-based plan for substituting results.")

class ExpressionDecomposer:
  def __init__(self, llm, verbose=False):
    self.llm = llm
    self.verbose = verbose
    self.decomp_instructions = (
      """
      You are an expert arithmetic expression analyzer.
      Input: a multi-step arithmetic string.
      Outputs:
        1. sub_expressions: list of minimal binary sub-exprs solvable in parallel this round.
        2. plan: concise instructions using placeholders (X1, X2, ...) to substitute results.
      Rules:
        - Only decompose the current expression, not future rounds.
        - Carefully handle signs and parentheses.
        - Maximize parallelizable binary ops.
        - Use as few words as possible while describing the plan
      """
    )
    sig = ExpressionDecompositionSignature.with_instructions(self.decomp_instructions)
    self.decomp_module = dspy.Predict(sig)
    self.decomp_module.set_lm(self.llm)

  def decompose(self, expr: str) -> tuple[list[str], str]:
    result = self.decomp_module(full_expr=expr)
    if self.verbose:
      print(f"[Decomposer] sub_expressions: {result.sub_expressions}")
      print(f"[Decomposer] plan: {result.plan}")
    return result.sub_expressions, result.plan


### 3.1.2 Expression Substitutor
Merges numeric results back into the expression according to the plan.

In [None]:
class ExpressionSubstitutionSignature(dspy.Signature):
  old_expr: str = dspy.InputField(desc="Previous expression.")
  plan: str = dspy.InputField(desc="Placeholder plan from decomposer.")
  partial_results: list[str] = dspy.InputField(desc="Results for each sub-expression.")
  new_expr: str = dspy.OutputField(desc="Updated expression or final binary op.")
  done: bool = dspy.OutputField(desc="True if only one operation remains.")

class ExpressionSubstitutor:
  def __init__(self, llm, verbose=False):
    self.llm = llm
    self.verbose = verbose
    self.subst_instructions = (
      """
      You are a specialized expression aggregator.
      Inputs:
        - old_expr: expression from last round
        - plan: how to substitute placeholders
        - partial_results: values for each sub-expression
      Outputs:
        new_expr: updated expression after substitution
        done: true if it's a single binary op
      Steps:
        1. Substitute results per plan, no new computation.
        2. Set done = true if only one operator remains.
        3. Remove redundant parentheses.
      """
    )
    sig = ExpressionSubstitutionSignature.with_instructions(self.subst_instructions)
    self.subst_module = dspy.Predict(sig)
    self.subst_module.set_lm(self.llm)

  def substitute(self, old_expr: str, plan: str, partial_results: list[str]) -> tuple[str, bool]:
    result = self.subst_module(
      old_expr=old_expr,
      plan=plan,
      partial_results=partial_results
    )
    if self.verbose:
      print(f"[Substitutor] new_expr: {result.new_expr}, done: {result.done}")
    return result.new_expr, result.done
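
For intuition, the substitution step can be mimicked without an LLM when the plan is a template string containing placeholders `X1`, `X2`, and so on. That template format is an assumption made for this sketch; the notebook's plan is free-form text interpreted by the model. The `done` check counts binary operators with `ast`, which avoids confusing a leading minus sign with subtraction.

```python
import ast
import re

def substitute(template, partial_results):
    """Replace X1, X2, ... with the corresponding result strings."""
    def repl(match):
        return partial_results[int(match.group(1)) - 1]
    return re.sub(r"\bX(\d+)\b", repl, template)

def is_single_binary_op(expr):
    """True when exactly one binary operator remains in the expression."""
    tree = ast.parse(expr, mode="eval")
    return sum(isinstance(node, ast.BinOp) for node in ast.walk(tree)) == 1

new_expr = substitute("((X1 + 2 - 4) + (-9 + X2 + 8))", ["1", "-30"])
print(new_expr, is_single_binary_op(new_expr))
print(is_single_binary_op("-1 + -31"))
```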

### 3.1.3 Subproblem Solver Agent
Solves a single binary arithmetic expression. The solver runs on a small model with a tight token budget (configured in 3.2.1).

In [None]:
class SubproblemSignature(dspy.Signature):
  sub_expr: str = dspy.InputField(desc="A minimal single-operator expression.")
  result: int = dspy.OutputField(desc="Integer result.")

class SubproblemAgent:
  def __init__(self, idx: int, llm, verbose=False):
    self.idx = idx
    self.llm = llm
    self.verbose = verbose
    instr = (
      """
      You are a specialized solver for one binary operation.
      Input: '3 * -4' or '5 + 10'.
      Output: the integer result.
      """
    )
    sig = SubproblemSignature.with_instructions(instr)
    self.compute_module = dspy.Predict(sig)
    self.compute_module.set_lm(self.llm)

  def compute(self, sub_expr: str) -> str:
    response = self.compute_module(sub_expr=sub_expr)
    if self.verbose:
      print(f"[SubproblemAgent {self.idx}] {sub_expr} -> {response.result}")
    return str(response.result)
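
Because each sub-problem is a single integer operation, an LLM's result can be cross-checked cheaply. The helper below is a hypothetical add-on, not something the notebook uses: it parses expressions of the form `a op b` with `+`, `-`, or `*` and returns `None` for anything else.

```python
import operator
import re

_BINARY = re.compile(r"^\s*(-?\d+)\s*([+\-*])\s*(-?\d+)\s*$")
_OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}

def exact_binary_result(sub_expr):
    """Return the exact integer result, or None if the string is not 'a op b'."""
    match = _BINARY.match(sub_expr)
    if match is None:
        return None
    left, op, right = match.groups()
    return _OPS[op](int(left), int(right))

def agrees(sub_expr, llm_result):
    """True when the LLM's answer matches the exact result (or cannot be checked)."""
    expected = exact_binary_result(sub_expr)
    return expected is None or str(expected) == str(llm_result).strip()

print(exact_binary_result("3 * -4"), exact_binary_result("5 + 10"))
print(agrees("1 - -3", "4"), agrees("6 * -5", "30"))
```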


### 3.1.4 Arithmetic Orchestrator
Coordinates decomposer, subproblem agents, and substitutor until completion.

In [None]:
class AnswerFormatterSignature(dspy.Signature):
  final_expression: str = dspy.InputField(desc="Final binary operation.")
  answer: int = dspy.OutputField(desc="Integer answer.")

class ArithmeticOrchestrator:
  def __init__(self, llm_decomp, llm_subst, llm_small, verbose=False):
    self.decomposer = ExpressionDecomposer(llm_decomp, verbose)
    self.substitutor = ExpressionSubstitutor(llm_subst, verbose)
    self.llm_small = llm_small
    self.answer_formatter = dspy.Predict(AnswerFormatterSignature)
    self.answer_formatter.set_lm(llm_small)
    self.verbose = verbose

  def solve_expression(self, expression: str, max_rounds: int = 20) -> tuple[dspy.Prediction, list[str]]:
    partial_exprs: list[str] = []
    current_expr = expression
    done = False
    round_idx = 0

    # Loop until the expression resolves to one binary operation.
    # max_rounds caps the loop in case the substitutor never reports done.
    while not done and round_idx < max_rounds:
      if self.verbose:
        print(f"\n=== Round {round_idx} ===")
        print(f"Current expr: {current_expr}")

      partial_exprs.append(current_expr)

      # 1) Decompose
      sub_exprs, plan = self.decomposer.decompose(current_expr)
      if not sub_exprs:
        # Fully simplified
        break

      # 2) Solve each sub-expression. The sub-problems are independent, so this
      #    step could run in parallel; here they are solved one after another.
      partial_results = [
        SubproblemAgent(i, self.llm_small, self.verbose).compute(sub)
        for i, sub in enumerate(sub_exprs)
      ]

      # 3) Substitute
      current_expr, done = self.substitutor.substitute(current_expr, plan, partial_results)
      round_idx += 1

    # Final answer formatting
    final = self.answer_formatter(final_expression=current_expr)
    if self.verbose:
      print(f"\nFinal answer: {final.answer}")
    return final, partial_exprs
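
The sub-expressions produced in one round do not depend on each other, so they could be dispatched concurrently. A minimal sketch with `concurrent.futures` follows; `solve_all` is an illustrative name and `toy_solver` is a local stand-in for a call such as `SubproblemAgent.compute`. Whether the underlying LM client is safe to call from several threads is an assumption to verify before using this with real models.

```python
from concurrent.futures import ThreadPoolExecutor

def solve_all(sub_exprs, solve_one, max_workers=4):
    """Solve independent sub-expressions concurrently, preserving input order."""
    if not sub_exprs:
        return []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in the order of the inputs
        return list(pool.map(solve_one, sub_exprs))

def toy_solver(sub_expr):
    # Stand-in for an LLM call; eval is used only on trusted demo strings
    return str(eval(sub_expr))

print(solve_all(["1 + 0", "6 * -5"], toy_solver))
```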

## 3.2 Usage

### 3.2.1 Configure models

In [None]:
gpt_4o3 = dspy.LM('openai/o3-mini-2025-01-31', temperature=1.0, max_tokens=20000, cache=False)
gpt4o_mini_250_lotemp = dspy.LM('openai/gpt-4o-mini', temperature=0.1, max_tokens=250, cache=False)

LLM_DECOMPOSE = gpt_4o3
LLM_SUBSTITUTE = gpt_4o3
LLM_SMALL = gpt4o_mini_250_lotemp

### 3.2.2 Single expression demo

In [None]:
orchestrator = ArithmeticOrchestrator(LLM_DECOMPOSE, LLM_SUBSTITUTE, LLM_SMALL, verbose=True)
expr = "((1 + 0 + 2 - 4) + (-9 + 6 * -5 + 8))"
final_ans, steps = orchestrator.solve_expression(expr)
print(f"Answer: {final_ans.answer}")

# Baseline on same example
o3_reasoning_agent = dspy.Predict("input -> reasoning, answer")
o3_reasoning_agent.set_lm(gpt_4o3)
response = o3_reasoning_agent(input=expr)
print(response.reasoning)
print(f"Baseline answer: {response.answer}")

### 3.2.3 Batch Evaluation
**Note:** This will take a LONG time to run

In [None]:
evalset = multi_step_combined[10:210]
correct = []
for ex in evalset:
  ans, _ = orchestrator.solve_expression(ex['input'])
  correct.append(meta_prompter_metric(ex, ans))
print(f"Batch accuracy: {np.mean(correct) * 100:.2f}%")

# Baseline comparison
o3_reasoning_agent = dspy.Predict("input -> reasoning, answer")
o3_reasoning_agent.set_lm(gpt_4o3)
evaluate = dspy.Evaluate(devset=evalset, metric=meta_prompter_metric)
# With the pinned DSPy 2.6.10, Evaluate returns the score directly as a percentage
res = evaluate(o3_reasoning_agent)
print(f"Baseline batch accuracy: {res:.2f}%")