In [12]:
import json

import numpy as np
from datasets import load_dataset
from math_agent import (
    MathAgent,
    MathEnvironment,
    extract_result_value,
    solve_task,
)
from termcolor import colored
from tqdm import tqdm

from tapeagents.llms import TrainableLLM


In [13]:
env = MathEnvironment()


def _append_result(path, name, acc, n) -> None:
    """Append one ``{name: acc, "n": n}`` JSON line to the file at ``path``."""
    with open(path, "a") as f:
        f.write(json.dumps({name: acc, "n": n}) + "\n")


def eval(
    tested_agent,
    test_set,
    name="",
    log_every: int = 10,
    results_path: str = "results.jsonl",
    raise_on_error: bool = True,
) -> float:
    """Run ``tested_agent`` on every sample of ``test_set`` and return its accuracy.

    Each raw sample is normalized with ``extract_result_value`` and solved with
    ``solve_task`` in the module-level ``env``. A sample counts as solved when
    ``tape.metadata.result["solved"]`` is truthy.

    Every ``log_every`` samples the running accuracy is printed and appended to
    ``results_path`` as a JSON line ``{name: acc, "n": n}``. After the loop the
    final accuracy is appended once more, unless the last logging step already
    recorded exactly that sample count (so no duplicate final line).

    Note: this name shadows the builtin ``eval``; it is kept because the
    notebook's later cells call it by this name.

    Args:
        tested_agent: agent passed through to ``solve_task``.
        test_set: iterable of raw dataset samples.
        name: key under which accuracy is stored in the JSON lines.
        log_every: reporting interval in samples; ``0`` disables intermediate logs.
        results_path: JSON-lines file that results are appended to.
        raise_on_error: if True (default), any exception from solving a sample
            is printed and re-raised, aborting the run; if False, the sample is
            counted as unsolved and evaluation continues.

    Returns:
        Fraction of solved samples, or NaN when ``test_set`` is empty.
    """
    n_solved = 0
    n = 0
    last_logged_n = None
    for raw_sample in tqdm(test_set):
        sample = extract_result_value(raw_sample)
        try:
            tape = solve_task(tested_agent, env, sample)
            solved = int(tape.metadata.result["solved"])
        except Exception as e:
            # f-string so the actual exception text is shown.
            print(colored(f"Failed to solve task: {e}", "red"))
            if raise_on_error:
                raise
            solved = 0
        n_solved += solved
        n += 1
        if log_every and n % log_every == 0:
            acc = n_solved / n
            print(f"{n}: Current accuracy: {acc:.3f}")
            _append_result(results_path, name, acc, n)
            last_logged_n = n
    # Empty test_set -> NaN (same value np.mean([]) gave), minus the RuntimeWarning.
    acc = n_solved / n if n else float("nan")
    # Avoid writing a second identical record when the loop just logged this n.
    if last_logged_n != n:
        _append_result(results_path, name, acc, n)
    return acc


In [3]:
# Fixed-size random subsets of GSM8K. The global NumPy RNG is re-seeded before
# each shuffle, so each subset is reproducible on its own.
SEED = 42
SUBSET_SIZE = 200

test_dataset = load_dataset("openai/gsm8k", "main", split="test")
test_samples = list(test_dataset)
np.random.seed(SEED)
np.random.shuffle(test_samples)  # type: ignore
test_set = test_samples[:SUBSET_SIZE]

# The "validation" subset is drawn from the train split, not the test split.
# NOTE(review): if the tuning data also came from the train split, val_set may
# overlap it — confirm before using val_set for model selection.
dataset = load_dataset("openai/gsm8k", "main", split="train")
val_samples = list(dataset)
np.random.seed(SEED)
np.random.shuffle(val_samples)  # type: ignore
val_set = val_samples[:SUBSET_SIZE]


## Untuned model accuracy

In [4]:
# run inference: vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct
# Baseline agent: the stock 8B Instruct model served locally, with the same
# checkpoint name used for both the model and its tokenizer.
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
untuned_llm = TrainableLLM(
    base_url="http://localhost:8000",
    model_name=BASE_MODEL,
    tokenizer_name=BASE_MODEL,
    parameters={"temperature": 0.1},
    use_cache=False,
)
untuned_agent = MathAgent(llms={"default": untuned_llm})


In [12]:
# Tag results.jsonl records as "untuned_acc" (consistent with the tuned/teacher
# cells) instead of the default empty-string key.
acc = eval(untuned_agent, test_set, "untuned_acc")
print(f"Untuned on test {acc:.3f}")


  6%|▌         | 11/200 [00:50<15:08,  4.81s/it]

10: Current accuracy: 0.545


 10%|█         | 21/200 [01:22<08:29,  2.85s/it]

20: Current accuracy: 0.476


 16%|█▌        | 31/200 [01:55<08:12,  2.91s/it]

30: Current accuracy: 0.613


 20%|██        | 41/200 [02:35<10:20,  3.90s/it]

40: Current accuracy: 0.585


 26%|██▌       | 51/200 [03:21<13:38,  5.49s/it]

50: Current accuracy: 0.608


 30%|███       | 61/200 [03:53<06:24,  2.77s/it]

60: Current accuracy: 0.656


 36%|███▌      | 71/200 [04:34<08:15,  3.84s/it]

70: Current accuracy: 0.662


 40%|████      | 81/200 [06:01<11:10,  5.63s/it]

80: Current accuracy: 0.654


 46%|████▌     | 91/200 [06:38<07:27,  4.11s/it]

90: Current accuracy: 0.681


 50%|█████     | 101/200 [07:24<07:09,  4.34s/it]

100: Current accuracy: 0.683


 56%|█████▌    | 111/200 [08:01<06:02,  4.07s/it]

110: Current accuracy: 0.676


 60%|██████    | 121/200 [08:44<05:48,  4.41s/it]

120: Current accuracy: 0.678


 66%|██████▌   | 131/200 [09:47<05:18,  4.61s/it]

130: Current accuracy: 0.679


 70%|███████   | 141/200 [11:55<11:14, 11.44s/it]

140: Current accuracy: 0.660


 76%|███████▌  | 151/200 [12:34<03:24,  4.18s/it]

150: Current accuracy: 0.656


 80%|████████  | 161/200 [13:13<02:49,  4.34s/it]

160: Current accuracy: 0.658


 86%|████████▌ | 171/200 [14:13<02:37,  5.44s/it]

170: Current accuracy: 0.655


 90%|█████████ | 181/200 [14:49<01:05,  3.45s/it]

180: Current accuracy: 0.663


 96%|█████████▌| 191/200 [15:29<00:35,  3.95s/it]

190: Current accuracy: 0.654


100%|██████████| 200/200 [16:12<00:00,  4.86s/it]

Untuned on test 0.660





## Tuned model accuracy

In [7]:
# run inference: vllm serve gsm8k/tuning/llama31_70b_train_t02/tune1/intermediate/800/
# Fine-tuned checkpoint served locally. It reuses the base 8B Instruct tokenizer,
# presumably because it was fine-tuned from that model (TODO confirm).
# NOTE(review): decoding here uses temperature 0.0, while the untuned baseline
# uses 0.1, so the two accuracies are not measured under identical settings.
TUNED_CHECKPOINT = "gsm8k/tuning/llama31_70b_train_t02/tune1/intermediate/800/"
tuned_llm = TrainableLLM(
    base_url="http://localhost:8000",
    model_name=TUNED_CHECKPOINT,
    tokenizer_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    parameters={"temperature": 0.0},
    use_cache=False,
)
tuned_agent = MathAgent(llms={"default": tuned_llm})


In [8]:
# Accuracy of the fine-tuned agent on the shared test_set.
tuned_acc = eval(tested_agent=tuned_agent, test_set=test_set, name="tuned_acc")
print(f"Tuned on test {tuned_acc:.3f}")


  6%|▌         | 11/200 [00:54<16:13,  5.15s/it]

10: Current accuracy: 0.727


 10%|█         | 21/200 [01:43<13:07,  4.40s/it]

20: Current accuracy: 0.762


 16%|█▌        | 31/200 [02:23<11:41,  4.15s/it]

30: Current accuracy: 0.806


 20%|██        | 41/200 [03:13<10:48,  4.08s/it]

40: Current accuracy: 0.756


 26%|██▌       | 51/200 [04:04<13:32,  5.45s/it]

50: Current accuracy: 0.784


 30%|███       | 61/200 [04:49<09:19,  4.03s/it]

60: Current accuracy: 0.820


 36%|███▌      | 71/200 [05:37<10:30,  4.89s/it]

70: Current accuracy: 0.789


 40%|████      | 81/200 [07:40<16:14,  8.19s/it]

80: Current accuracy: 0.765


 46%|████▌     | 91/200 [08:23<07:25,  4.09s/it]

90: Current accuracy: 0.791


 50%|█████     | 101/200 [09:12<08:19,  5.05s/it]

100: Current accuracy: 0.802


 56%|█████▌    | 111/200 [10:06<08:22,  5.64s/it]

110: Current accuracy: 0.784


 60%|██████    | 121/200 [10:59<08:44,  6.64s/it]

120: Current accuracy: 0.802


 66%|██████▌   | 131/200 [12:52<15:00, 13.04s/it]

130: Current accuracy: 0.771


 70%|███████   | 141/200 [15:14<13:32, 13.77s/it]

140: Current accuracy: 0.752


 76%|███████▌  | 151/200 [17:05<06:28,  7.93s/it]

150: Current accuracy: 0.748


 80%|████████  | 161/200 [17:50<03:11,  4.91s/it]

160: Current accuracy: 0.752


 86%|████████▌ | 171/200 [19:59<08:19, 17.21s/it]

170: Current accuracy: 0.754


 90%|█████████ | 181/200 [20:48<01:37,  5.14s/it]

180: Current accuracy: 0.768


 96%|█████████▌| 191/200 [21:43<00:51,  5.74s/it]

190: Current accuracy: 0.775


100%|██████████| 200/200 [22:35<00:00,  6.78s/it]

Tuned on test 0.775





In [14]:
# check teacher model
# Teacher agent: Llama 3.1 70B Instruct (Turbo) via the Together API endpoint.
# No credentials are passed here; presumably TrainableLLM picks them up from the
# environment (TODO confirm).
big_llm = TrainableLLM(
    model_name="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    tokenizer_name="meta-llama/Meta-Llama-3.1-70B-Instruct",
    base_url="https://api.together.xyz",
    parameters={"temperature": 0.2},
    use_cache=False,
)
big_agent = MathAgent(llms={"default": big_llm})


In [15]:
# Teacher accuracy on the same test_set, as a reference point for the tuned model.
big_acc = eval(tested_agent=big_agent, test_set=test_set, name="big_acc")
print(f"Teacher on test {big_acc:.3f}")


  5%|▌         | 10/200 [01:59<38:17, 12.09s/it]

10: Current accuracy: 0.800


 10%|█         | 20/200 [04:12<39:39, 13.22s/it]

20: Current accuracy: 0.900


 15%|█▌        | 30/200 [05:27<18:59,  6.70s/it]

30: Current accuracy: 0.933


 20%|██        | 40/200 [07:17<32:46, 12.29s/it]

40: Current accuracy: 0.925


 25%|██▌       | 50/200 [08:58<24:59, 10.00s/it]

50: Current accuracy: 0.920


 30%|███       | 60/200 [10:27<16:43,  7.17s/it]

60: Current accuracy: 0.933


 35%|███▌      | 70/200 [12:01<21:15,  9.82s/it]

70: Current accuracy: 0.929


 40%|████      | 80/200 [13:53<22:59, 11.50s/it]

80: Current accuracy: 0.938


 45%|████▌     | 90/200 [15:38<17:53,  9.76s/it]

90: Current accuracy: 0.944


 50%|█████     | 100/200 [17:26<14:13,  8.54s/it]

100: Current accuracy: 0.940


 52%|█████▏    | 103/200 [17:51<13:42,  8.48s/it]Failed to parse agent output: {"kind": "use_calculator_action", "expression": "30 * (1 - 0.3)"}
{"kind": "reasoning_thought", "reasoning": "Calculate how many times Anne went down the slide, which is 30% less than Mitchel."}

Error: Extra data: line 2 column 1 (char 66)
Traceback (most recent call last):
  File "/home/toolkit/TapeAgents/tapeagents/guided_agent.py", line 102, in parse_completion
    step_dicts = json.loads(sanitize_json_completion(completion))
  File "/home/toolkit/.conda/envs/tapeagents2/lib/python3.10/json/__init__.py", line 346, in loads
    return _default_decoder.decode(s)
  File "/home/toolkit/.conda/envs/tapeagents2/lib/python3.10/json/decoder.py", line 340, in decode
    raise JSONDecodeError("Extra data", s, end)
json.decoder.JSONDecodeError: Extra data: line 2 column 1 (char 66)
 55%|█████▌    | 110/200 [20:23<52:32, 35.03s/it]

110: Current accuracy: 0.936


 60%|██████    | 120/200 [21:54<13:48, 10.36s/it]

120: Current accuracy: 0.942


 65%|██████▌   | 130/200 [25:12<24:50, 21.29s/it]

130: Current accuracy: 0.938


 70%|███████   | 140/200 [27:36<16:47, 16.79s/it]

140: Current accuracy: 0.929


 75%|███████▌  | 150/200 [29:17<08:06,  9.73s/it]

150: Current accuracy: 0.927


 80%|████████  | 160/200 [30:40<06:06,  9.17s/it]

160: Current accuracy: 0.925


 85%|████████▌ | 170/200 [32:37<06:13, 12.44s/it]

170: Current accuracy: 0.929


 90%|█████████ | 180/200 [34:44<03:17,  9.90s/it]

180: Current accuracy: 0.933


 95%|█████████▌| 190/200 [36:17<01:27,  8.75s/it]

190: Current accuracy: 0.937


100%|██████████| 200/200 [37:49<00:00, 11.35s/it]

200: Current accuracy: 0.935
Teacher on test 0.935



