Given a task (e.g., arithmetic reasoning) and a model, I want to know which types of problems within that task the model finds difficult.
The approach: generate samples from a structured representation of the task, and use that representation to search for the worst-performing problems.
For now, I consider only GSM8k-style math word problems. I specify the set of valid math word problems with a context-free grammar, which lets me generate problems with a specific structure. For example, I can convert the following arithmetic expression into the problem:
Being a plumber pays $5 an hour. Tina worked 19 hours as a plumber. Being a nurse pays $2 an hour. David worked 49 hours as a nurse. A ball costs $96. The group bought a ball using the amount David made working as a nurse. The group pooled the amount Tina made working as a plumber and the amount of money left after buying a ball and put it in a silver piggy bank. What is the value of the amount of money in the silver piggy bank?
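To make the mapping concrete, the arithmetic underlying the problem above can be checked directly (the variable names below are mine, chosen to mirror the story):

```python
# Arithmetic underlying the example problem above.
tina_earnings = 5 * 19                 # plumber: $5/hour for 19 hours
david_earnings = 2 * 49                # nurse: $2/hour for 49 hours
left_after_ball = david_earnings - 96  # the ball costs $96
piggy_bank = tina_earnings + left_after_ball
print(piggy_bank)  # → 97
```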
Using this structured representation, I then search through the space of problems with Thompson sampling, an approach inspired by Sclar et al. (2023).
I search with a budget of 2,500 queries (to Claude Instant v1.2) over a space of arithmetic reasoning problems with similar levels of complexity.
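The search loop itself is standard Thompson sampling over a finite set of problem structures ("arms"), with a Beta(1, 1) prior on each arm's per-query accuracy. The sketch below is a minimal illustration, not the actual implementation; `evaluate` is a hypothetical stand-in for querying the model:

```python
import random

def thompson_search(arms, evaluate, budget, samples_per_pull=8):
    """Find the worst-performing arm under a per-query Bernoulli accuracy model.

    arms: list of problem structures.
    evaluate: hypothetical black box; evaluate(arm, n) returns how many of
        n sampled instances of that problem structure the model answers correctly.
    """
    # Beta(1, 1) prior on each arm's accuracy.
    alpha = {a: 1.0 for a in arms}
    beta = {a: 1.0 for a in arms}
    queries = 0
    while queries < budget:
        # Sample an accuracy from each arm's posterior; pull the one that looks worst.
        arm = min(arms, key=lambda a: random.betavariate(alpha[a], beta[a]))
        correct = evaluate(arm, samples_per_pull)
        alpha[arm] += correct
        beta[arm] += samples_per_pull - correct
        queries += samples_per_pull
    # Rank arms by posterior mean accuracy, worst first.
    return sorted(arms, key=lambda a: alpha[a] / (alpha[a] + beta[a]))
```

Because arms that look hard get pulled more often, most of the query budget concentrates on the worst-performing structures.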
The worst-performing problem has an accuracy of 67.5% over 80 queries and the following structure:
Each query involves sampling random integers for x_1, ..., x_12 such that all intermediate values fall between 2 and 100, and sampling random names, items, jobs, etc. to use in the story.
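One simple way to satisfy the intermediate-value constraint is rejection sampling. The sketch below uses a hypothetical two-level expression rather than the actual 12-leaf structure, and draws leaves from a small range to keep rejection cheap:

```python
import random

def sample_leaves(max_tries=10_000):
    """Rejection-sample leaf values so every intermediate value is in [2, 100].

    Uses a hypothetical expression x1*x2 + (x3*x4 - x5) as a stand-in for
    the actual 12-leaf problem structure.
    """
    for _ in range(max_tries):
        x1, x2, x3, x4, x5 = (random.randint(2, 10) for _ in range(5))
        intermediates = [x1 * x2, x3 * x4, x3 * x4 - x5, x1 * x2 + (x3 * x4 - x5)]
        if all(2 <= v <= 100 for v in intermediates):
            return x1, x2, x3, x4, x5
    raise RuntimeError("no valid assignment found")
```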
The (relatively) poor performance on this problem type isn't just due to its complexity: one of the best-performing problems achieves 100% accuracy over 40 queries, with the structure:
The intermediate chain-of-thought reasoning allows us to analyze the failure mode of this problem type. Most of the time, the error comes from the model "forgetting" to add the result from the subtree with leaf nodes x_5, x_6, x_7. My guess is that this happens because the problem contains two consecutive additions, which is unusual in, e.g., the original GSM8k dataset.
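A crude way to flag this failure mode automatically is to check whether the forgotten subtree's intermediate value ever appears in the chain of thought; this is a hypothetical diagnostic, not part of the search itself:

```python
import re

def mentions_subtree_value(chain_of_thought: str, subtree_value: int) -> bool:
    """Heuristic check for the "forgotten subtree" failure mode: does the
    model's chain of thought ever mention the subtree's intermediate value?
    """
    numbers = {int(n) for n in re.findall(r"-?\d+", chain_of_thought)}
    return subtree_value in numbers
```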
See `experiemnts/find_spread.py` for a usage example. To generate problems, run:
```python
problem = ProblemBuilder("jobs")
df = problem.build_dataset(
    max_depth,
    num_samples_per_problem,
    min_depth=min_depth,
    subsample=num_subsample,
)
```

where `jobs` is a problem specified in `mwp/problems/problem_values.py`.
To search for the best/worst problems, run:
```python
ps = ProblemSpace(df, replace=True)
opt = ThompsonOpt(ps, reward_func, beta_prior=(1, 1))
ranked_arms, arm_vals = opt.optimize(
    budget=budget,
    samples_per_pull=samples_per_pull,
    maximize=maximize,
    verbose=True,
)
```

where `reward_func` is a black-box function that returns the accuracy given a batch of samples, and `samples_per_pull` specifies the size of the batch (see Sclar et al., 2023 for more detail).
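A minimal `reward_func` might look like the following, where `query_model` is a hypothetical helper that sends a problem to the model and parses a numeric answer from its response:

```python
def reward_func(samples):
    """Black-box reward: fraction of problems in the batch answered correctly.

    `samples` is assumed to be a list of (problem_text, true_answer) pairs;
    `query_model` is a hypothetical helper that queries the model and parses
    a numeric answer from its response.
    """
    correct = sum(query_model(text) == answer for text, answer in samples)
    return correct / len(samples)
```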


