In [8]:
import os
import json

from jedi.api.helpers import infer
from langchain_core.outputs import LLMResult

from agent.utils.loader import load_prompt, load_processed_data
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv

from langchain_experimental.utilities import PythonREPL

# Load variables from the nearest .env file into the process environment.
# The result is bound to `_` so the cell does not display the return value.
_ = load_dotenv(find_dotenv())

In [13]:
# --- Experiment configuration -------------------------------------------------
# Every name below is a notebook-level global read by the later cells.
dataset_name = 'tabmwp'
# Prompting mode handed to load_prompt (presumably "program of thought" - confirm).
mode = "pot"
model = "gpt-4o-mini"
# Number of leading dataset rows to keep; any value <= 0 keeps the full dataset.
num_samples = -1
# Sampling settings forwarded unchanged to ChatOpenAI below.
top_p = 0.95
temperature = 0
seed = 42
# Number of dataset rows sent concurrently per checkpointed batch in pot_inference().
batch_size = 100
# Forwarded to ChatOpenAI as `n`; inference() only reads the first generation.
n = 1
# Input/output locations are relative to the notebook's working directory.
processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
# The results file name encodes the run settings, so each configuration gets its own checkpoint file.
save_results_path = f"../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
# Optionally truncate to the first `num_samples` rows for a quick run.
if num_samples > 0:
	dataset = dataset.select(range(num_samples))
# NOTE(review): requests go to a hard-coded, non-default base_url; no API key is passed here,
# so it is presumably read from the environment populated by load_dotenv - confirm.
llm = ChatOpenAI(model=model, top_p=top_p, n=n, temperature=temperature, base_url="https://api.chsdw.top/v1", seed=seed)

# Sanity check: show the prompt and the first dataset record.
prompt.pretty_print()
print(dataset[0])




# Write Python Code to solve the following questions. Store your result as a variable named 'answer'. Use 'print(answer)' to output your answer. Your answer should follow the previous pattern and format.


Read the following table regarding "Class size" and and then answer a question.

Teacher | Number of students\nMrs. Truman | 23\nMiss Urban | 26\nMrs. Woodworth | 27\nMs. Hershfeld | 28

Question: Some teachers compared how many students are in their classes. Which teacher has the most students? Choose from the the options: ['Mrs. Truman', 'Miss Urban', 'Mrs. Woodworth', 'Ms. Hershfeld']


# Python code, return answer 
teachers = {
    'Mrs. Truman': 23,
    'Miss Urban': 26,
    'Mrs. Woodworth': 27,
    'Ms. Hershfeld': 28
}
# Find the teacher with the most students
most_students_teacher = max(teachers, key=teachers.get)
answer = most_students_teacher
print(answer)


Read the following table regarding "" and and then answer a question.

potatoes | $0.75 per kilogram\nzucchini | $0

In [14]:
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

# Patch asyncio via nest_asyncio before any coroutine in this notebook is awaited
# (intended to allow nested use of the already-running event loop).
nest_asyncio.apply()
# Single REPL instance shared by every inference() call to execute model-generated code.
python_repl = PythonREPL()

async def inference(item: dict) -> dict:
    """Generate code for one dataset item with the LLM and execute it.

    The prompt is rendered from ``item`` and sent to ``llm``; the content of
    the first returned generation is then executed through ``python_repl``
    with ``timeout=3``.

    Args:
        item: One dataset record. It is passed to ``prompt.invoke`` as the
            input and all of its keys are copied into the returned dict.

    Returns:
        A new dict with every key of ``item`` plus:

        * ``"code"``: the raw LLM response content, or ``""`` if the LLM
          call failed.
        * ``"prediction"``: whatever ``python_repl.run`` returned, the text
          of the exception if running the code raised, or ``"ERROR"`` if
          the LLM call failed.

    Ordinary exceptions are printed and encoded in the result instead of
    being raised, so a single failing item does not abort the whole batch.
    """
    # Function-scope import: the notebook's import cells do not import asyncio.
    import asyncio

    try:
        response: LLMResult = await llm.agenerate(messages=[prompt.invoke(input=item)])
        response_content = response.generations[0][0].message.content
    except Exception as e:
        # LLM/request failure (e.g. a connection error): keep the row, mark it.
        print(e)
        return {**item, "code": "", "prediction": "ERROR"}

    try:
        # python_repl.run is a synchronous call; running it directly in the
        # coroutine would stall the event loop (and every other in-flight
        # request of the batch) until it returns, so hand it to a worker thread.
        prediction = await asyncio.to_thread(
            python_repl.run, command=response_content, timeout=3
        )
    except Exception as e:
        # Keep the generated code and record why executing it failed.
        print(e)
        prediction = str(e)

    return {**item, "code": response_content, "prediction": prediction}

async def pot_inference() -> list[dict]:
    """Run ``inference`` over ``dataset`` in batches, checkpointing to JSONL.

    If ``save_results_path`` already exists, its records are loaded and the
    run resumes at row ``len(loaded records)``; otherwise the parent folder
    is created. Each batch of ``batch_size`` rows is processed concurrently
    and its results are appended to the file before the next batch starts.

    Returns:
        All results (previously saved ones followed by the newly computed
        ones), one dict per processed dataset row.

    NOTE(review): rows whose saved prediction is ``"ERROR"`` count as done
    and are not retried on resume.
    """
    results = []
    # True when an existing checkpoint does not end with a newline, in which
    # case one must be written before appending so records are not glued together.
    needs_newline = False

    if os.path.exists(save_results_path):
        with open(save_results_path, "r", encoding="utf-8") as file:
            for line in file:
                needs_newline = not line.endswith("\n")
                # Tolerate blank lines instead of crashing in json.loads.
                if line.strip():
                    results.append(json.loads(line))
    else:
        folder_path = os.path.dirname(save_results_path)
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        if folder_path:
            os.makedirs(folder_path, exist_ok=True)

    for idx in tqdm(range(len(results), dataset.num_rows, batch_size)):
        batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
        batch_results = await tqdm_asyncio.gather(*(inference(item) for item in batch))
        results.extend(batch_results)

        # Serialise before touching the file so a non-serialisable result
        # cannot leave a half-written batch behind.
        lines = [json.dumps(result) + "\n" for result in batch_results]
        # Append only the new batch: reopening in "w" mode would truncate the
        # checkpoint and lose every earlier result if interrupted mid-write.
        with open(save_results_path, "a", encoding="utf-8") as file:
            if needs_newline:
                file.write("\n")
                needs_newline = False
            file.writelines(lines)

    return results


In [15]:
# Run the inference pass (resuming from any saved results file) and keep the per-item results in memory.
results = await pot_inference()

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s][A[A

  1%|          | 1/100 [00:03<06:15,  3.79s/it][A[A

  2%|▏         | 2/100 [00:04<03:12,  1.97s/it][A[A

  3%|▎         | 3/100 [00:04<02:00,  1.24s/it][A[A

  4%|▍         | 4/100 [00:05<01:24,  1.13it/s][A[A

  6%|▌         | 6/100 [00:05<00:56,  1.67it/s][A[A

  7%|▋         | 7/100 [00:06<00:49,  1.87it/s][A[A

  8%|▊         | 8/100 [00:06<00:44,  2.07it/s][A[A

  9%|▉         | 9/100 [00:06<00:41,  2.21it/s][A[A

 10%|█         | 10/100 [00:07<00:38,  2.34it/s][A[A

 11%|█         | 11/100 [00:08<00:45,  1.94it/s][A[A

 14%|█▍        | 14/100 [00:09<00:42,  2.03it/s][A[A

 17%|█▋        | 17/100 [00:09<00:27,  3.07it/s][A[A

 19%|█▉        | 19/100 [00:10<00:28,  2.87it/s][A[A

 20%|██        | 20/100 [00:11<00:28,  2.84it/s][A[A

 21%|██        | 21/100 [00:11<00:28,  2.79it/s][A[A

 22%|██▏       | 22/100 [00:11<00:28,  2.78it/s][A[A

 23%|██▎       | 23/100 [

Connection error.




 39%|███▉      | 39/100 [00:17<00:21,  2.86it/s][A[A

 41%|████      | 41/100 [00:18<00:20,  2.84it/s][A[A

 43%|████▎     | 43/100 [00:18<00:16,  3.44it/s][A[A

 44%|████▍     | 44/100 [00:19<00:16,  3.31it/s][A[A

 46%|████▌     | 46/100 [00:19<00:17,  3.14it/s][A[A

 47%|████▋     | 47/100 [00:20<00:17,  3.07it/s][A[A

 48%|████▊     | 48/100 [00:20<00:17,  2.98it/s][A[A

 50%|█████     | 50/100 [00:21<00:17,  2.91it/s][A[A

 51%|█████     | 51/100 [00:21<00:17,  2.84it/s][A[A

 52%|█████▏    | 52/100 [00:21<00:16,  2.85it/s][A[A

 54%|█████▍    | 54/100 [00:22<00:16,  2.81it/s][A[A

 56%|█████▌    | 56/100 [00:23<00:18,  2.39it/s][A[A

 58%|█████▊    | 58/100 [00:24<00:14,  2.99it/s][A[A

 60%|██████    | 60/100 [00:24<00:13,  2.95it/s][A[A

 61%|██████    | 61/100 [00:25<00:13,  2.90it/s][A[A

 63%|██████▎   | 63/100 [00:25<00:12,  2.87it/s][A[A

 64%|██████▍   | 64/100 [00:26<00:12,  2.87it/s][A[A

 66%|██████▌   | 66/100 [00:27<00:14,  2.39it/

Connection error.




 48%|████▊     | 48/100 [00:21<00:25,  2.02it/s][A[A

 54%|█████▍    | 54/100 [00:22<00:13,  3.51it/s][A[A

 56%|█████▌    | 56/100 [00:23<00:14,  3.01it/s][A[A

 59%|█████▉    | 59/100 [00:24<00:15,  2.69it/s][A[A

 62%|██████▏   | 62/100 [00:24<00:11,  3.42it/s][A[A

 64%|██████▍   | 64/100 [00:25<00:10,  3.27it/s][A[A

 66%|██████▌   | 66/100 [00:26<00:10,  3.15it/s][A[A

 68%|██████▊   | 68/100 [00:27<00:11,  2.70it/s][A[A

 70%|███████   | 70/100 [00:27<00:09,  3.10it/s][A[A

 71%|███████   | 71/100 [00:28<00:09,  3.06it/s][A[A

 73%|███████▎  | 73/100 [00:28<00:09,  2.99it/s][A[A

 75%|███████▌  | 75/100 [00:29<00:09,  2.57it/s][A[A

 78%|███████▊  | 78/100 [00:30<00:07,  3.07it/s][A[A

 79%|███████▉  | 79/100 [00:30<00:07,  2.99it/s][A[A

 80%|████████  | 80/100 [00:31<00:06,  2.94it/s][A[A

 82%|████████▏ | 82/100 [00:31<00:06,  2.93it/s][A[A

 84%|████████▍ | 84/100 [00:33<00:06,  2.50it/s][A[A

 87%|████████▋ | 87/100 [00:33<00:04,  3.00it/