In [11]:
import os
import json

from langchain_core.outputs import LLMResult

from agent.utils.loader import load_prompt, load_processed_data
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv

from langchain_experimental.utilities import PythonREPL

_ = load_dotenv(find_dotenv())

In [12]:
# --- What to run ----------------------------------------------------------
dataset_name = 'gsm8k'
mode = "pot"
model = "gpt-4o-mini"

# --- Decoding settings forwarded to the chat model -------------------------
top_p = 0.95
temperature = 0
seed = 42
n = 1

# --- Run sizing --------------------------------------------------------------
# A non-positive num_samples leaves the dataset untruncated (see the guard below).
num_samples = -1
batch_size = 100

# --- Input / output locations (relative to the notebook's working directory) -
processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
save_results_path = (
    f"../../output/inference/{model}/{dataset_name}/{mode}/"
    f"num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"
)

# --- Prompt template and dataset ---------------------------------------------
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
if num_samples > 0:
    # Keep only the first num_samples rows.
    dataset = dataset.select(range(num_samples))

# --- Chat model client ---------------------------------------------------------
# Requests are sent to the explicitly configured base_url below.
llm = ChatOpenAI(
    model=model,
    top_p=top_p,
    n=n,
    temperature=temperature,
    base_url="https://api.chsdw.top/v1",
    seed=seed,
)

# Sanity check: show the prompt template and the first dataset record.
prompt.pretty_print()
print(dataset[0])




# Write Python Code to solve the following questions. Store your result as a variable named 'answer'. Use 'print(answer)' to output your answer. Your answer should follow the previous pattern and format.


Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?


# Python code, return answer
clips_sold_in_april = 48
clips_sold_in_may = clips_sold_in_april / 2
total_clips_sold = clips_sold_in_april + clips_sold_in_may
answer = total_clips_sold
print(answer)


Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?


# Python code, return answer
hourly_rate = 12
minutes_worked = 50
hours_worked = minutes_worked / 60
earnings = hourly_rate * hours_worked
answer = earnings
print(answer)


Question: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents deci

In [22]:
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

# Patch asyncio via nest_asyncio; presumably required because the notebook
# kernel already runs an event loop that the code below re-enters -- TODO confirm.
nest_asyncio.apply()
# Single REPL instance shared by every inference() call. It executes whatever
# code the model returns (the library itself warns "Python REPL can execute
# arbitrary code"), so only run this notebook in a trusted environment.
python_repl = PythonREPL()

async def inference(item: dict) -> dict:
    """Generate code for one dataset item with the LLM and execute it.

    Args:
        item: One dataset record; it is passed to ``prompt.invoke`` as the
            template input and copied into the returned record.

    Returns:
        A new dict containing every key of ``item`` plus:
          * ``"code"``: the raw model response (``""`` if generation failed);
          * ``"prediction"``: whatever ``python_repl.run`` returned for that
            code, the exception text if running it raised, or ``"ERROR"``
            if the LLM call itself failed.

    Exceptions from the LLM call and from code execution are printed and
    folded into the returned record instead of being raised, so one bad item
    does not abort a whole batch.
    """
    # Stage 1: ask the model for code. Any failure here (network, malformed
    # response, ...) is recorded as a generic "ERROR" prediction.
    try:
        response: LLMResult = await llm.agenerate(messages=[prompt.invoke(input=item)])
        response_content = response.generations[0][0].message.content
    except Exception as e:
        print(e)
        return {**item, "code": "", "prediction": "ERROR"}

    # Stage 2: execute the generated code with a 3-second timeout.
    # NOTE(review): python_repl.run is a synchronous call inside a coroutine,
    # so it presumably stalls the event loop while the snippet runs -- consider
    # offloading it if throughput matters.
    try:
        prediction = python_repl.run(command=response_content, timeout=3)
    except Exception as e:
        print(e)
        prediction = str(e)

    return {**item, "code": response_content, "prediction": prediction}

async def pot_inference() -> None:
    """Run ``inference`` over the whole dataset in batches, checkpointing to disk.

    Results are stored as JSON Lines at ``save_results_path``. If that file
    already exists, its records are loaded and processing resumes at the row
    index equal to the number of loaded records. After every batch the complete
    result list is written out again, so the file always holds all results
    produced so far.

    Reads the notebook-level names ``save_results_path``, ``dataset``,
    ``batch_size`` and ``inference``. Returns nothing; its effect is the
    results file.
    """
    results = []
    if os.path.exists(save_results_path):
        with open(save_results_path, 'r') as file:
            for line in file:
                # Skip blank lines so a stray empty line does not break resuming.
                if line.strip():
                    results.append(json.loads(line))
    else:
        folder_path = os.path.dirname(save_results_path)
        # dirname is "" for a bare filename; makedirs("") would raise.
        if folder_path:
            os.makedirs(folder_path, exist_ok=True)

    # Checkpoints go to a sibling temp file that is then swapped into place.
    # Rewriting the results file directly would truncate it first, so an
    # interruption mid-write could destroy every previously saved result.
    tmp_path = save_results_path + ".tmp"

    for idx in tqdm(range(len(results), dataset.num_rows, batch_size)):
        batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
        # All items of the batch are awaited together; gather keeps input order.
        results.extend(await tqdm_asyncio.gather(*(inference(item) for item in batch)))
        with open(tmp_path, 'w') as file:
            for result in results:
                file.write(json.dumps(result) + "\n")
        os.replace(tmp_path, save_results_path)


In [23]:
# Start (or resume) the full inference run. The outer progress bar counts
# batches, the inner one counts items within the current batch.
await pot_inference()

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s][A[APython REPL can execute arbitrary code. Use with caution.


  1%|          | 1/100 [00:04<07:27,  4.52s/it][A[A

  2%|▏         | 2/100 [00:05<03:43,  2.28s/it][A[A

  3%|▎         | 3/100 [00:05<02:16,  1.40s/it][A[A

  4%|▍         | 4/100 [00:06<01:47,  1.12s/it][A[A

  6%|▌         | 6/100 [00:06<00:58,  1.60it/s][A[A

  7%|▋         | 7/100 [00:07<01:00,  1.55it/s][A[A

  9%|▉         | 9/100 [00:07<00:39,  2.28it/s][A[A

 10%|█         | 10/100 [00:08<00:44,  2.02it/s][A[A

 11%|█         | 11/100 [00:08<00:40,  2.19it/s][A[A

 13%|█▎        | 13/100 [00:09<00:29,  2.98it/s][A[A

 15%|█▌        | 15/100 [00:09<00:28,  2.97it/s][A[A

 16%|█▌        | 16/100 [00:10<00:28,  2.98it/s][A[A

 17%|█▋        | 17/100 [00:10<00:34,  2.40it/s][A[A

 19%|█▉        | 19/100 [00:11<00:25,  3.14it/s][A[A

 20%|██        | 20/100 [00:11<00:31,  2.52it/s][A[A

 22%|██▏       | 22/10