In [1]:
import os
import json

from langchain_core.outputs import LLMResult

from agent.utils.loader import load_prompt, load_processed_data
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv

# Load environment variables from the .env file located by find_dotenv().
# The return value is not needed, so it is bound to the throwaway name `_`
# (this also keeps the notebook from echoing it as the cell output).
# NOTE(review): presumably this provides the API credentials that ChatOpenAI
# picks up later — confirm which variables the .env file is expected to define.
_ = load_dotenv(find_dotenv())

In [2]:
# --- Experiment configuration ------------------------------------------------
dataset_name = 'gsm8k'
mode = "cot"
model = "gpt-4o-mini"
num_samples = 1000  # values <= 0 keep the full dataset
top_p = 0.95
seed = 42
batch_size = 100  # rows dispatched concurrently per batch by the inference loop

# Sampling settings depend on the dataset: "toxicity" draws 25 samples at
# temperature 0.9; every other dataset uses one sample at temperature 0.
# (The unconditional `temperature = 0` that used to precede this branch was
# dead code: both branches assign it.)
if dataset_name == "toxicity":
    temperature = 0.9
    n = 25
else:
    temperature = 0
    n = 1

# --- Input / output locations ------------------------------------------------
processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
save_results_path = f"../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"

# --- Prompt, data and model --------------------------------------------------
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
if num_samples > 0:
    # Clamp to the number of available rows so that asking for more samples
    # than the dataset holds keeps every row instead of failing inside select().
    # NOTE(review): assumes `dataset` exposes `num_rows`/`select` like a Hugging
    # Face `Dataset` (both are already used by the inference loop) — confirm.
    dataset = dataset.select(range(min(num_samples, dataset.num_rows)))

# NOTE(review): base_url points at a third-party OpenAI-compatible endpoint, so
# the API key and every prompt are sent there — confirm this is intended.
llm = ChatOpenAI(model=model, top_p=top_p, n=n, temperature=temperature, base_url="https://api.chsdw.top/v1", seed=seed)

prompt.pretty_print()
print(dataset[0])

  warn_beta(



Remember your answer should follow previous pattern and format.


Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?


Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. FINAL ANSWER: 72


Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?


Answer: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute. Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10. FINAL ANSWER: 10


Question: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?


Answer: In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50. Betty's grandpar

In [3]:
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

# Patch asyncio so event loops may be nested.
# NOTE(review): presumably needed because the notebook kernel already runs an
# event loop while these coroutines are awaited — confirm it is still required.
nest_asyncio.apply()

async def inference(item: dict) -> dict:
    """Query the LLM for one dataset row and return the row with its output attached.

    Args:
        item: One dataset row. It is passed unchanged to ``prompt.invoke`` as the
            template input and is not mutated.

    Returns:
        A new dict containing every key of ``item`` plus:

        * ``"generation"`` - the text content of the first returned generation.
        * ``"prediction"`` - for ``hotpot_qa``, ``trivia_qa``, ``ambig_qa``,
          ``gsm8k``, ``tabmwp`` and ``svamp`` the text after the last
          ``"FINAL ANSWER: "`` marker (the whole generation when the marker is
          absent); for any other dataset the generation itself.

        If anything raises while generating or parsing, the error is printed and
        the row is returned with ``"generation": ""`` and ``"prediction": "ERROR"``
        so that one failed request does not abort the whole batch.
    """
    # Datasets whose prompts ask the model to finish with "FINAL ANSWER: <answer>".
    final_answer_datasets = ("hotpot_qa", "trivia_qa", "ambig_qa", "gsm8k", "tabmwp", "svamp")

    try:
        response: LLMResult = await llm.agenerate(
            messages=[prompt.invoke(input=item)],
            # For toxicity the generation is cut at the first "." or newline.
            stop=[".", "\n"] if dataset_name == 'toxicity' else None,
        )

        # NOTE(review): only the first generation of the first prompt is kept,
        # even if the model was configured with n > 1 — confirm that dropping
        # the remaining samples is intended.
        response_content = response.generations[0][0].message.content

        if dataset_name in final_answer_datasets:
            prediction = response_content.split("FINAL ANSWER: ")[-1]
        else:
            prediction = response_content
        result = {**item, "generation": response_content, "prediction": prediction}
    except Exception as e:
        # Deliberately best-effort: report and mark the row instead of raising.
        # The exception type is included because str(e) alone can be empty or
        # ambiguous (e.g. a KeyError prints only the missing key).
        print(f"{type(e).__name__}: {e}")
        result = {**item, "generation": "", "prediction": "ERROR"}
    return result

async def cot_inference() -> None:
    """Run ``inference`` over ``dataset`` in batches, checkpointing to a JSONL file.

    Rows already stored in ``save_results_path`` are counted and skipped, so an
    interrupted run resumes at the first row that has no saved result. Each
    batch of ``batch_size`` rows is awaited concurrently and then appended to
    the file, one JSON object per line.

    Returns:
        None. The results are only written to ``save_results_path``.
    """
    results = []
    if os.path.exists(save_results_path):
        with open(save_results_path, 'r') as file:
            for line in file:
                # Tolerate blank lines (e.g. a trailing newline added by hand).
                if line.strip():
                    results.append(json.loads(line))
    else:
        folder_path = os.path.dirname(save_results_path)
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        if folder_path:
            os.makedirs(folder_path, exist_ok=True)

    # NOTE(review): resuming is purely positional, so rows previously saved with
    # prediction "ERROR" count as done and are never retried — confirm intended.
    for idx in tqdm(range(len(results), dataset.num_rows, batch_size)):
        batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
        batch_results = await tqdm_asyncio.gather(*(inference(item) for item in batch))
        results.extend(batch_results)
        # Append only the new batch. Rewriting the whole file in 'w' mode after
        # every batch truncates all earlier results if the run is interrupted
        # mid-write, and re-serialises every previous row each time.
        with open(save_results_path, 'a') as file:
            file.writelines(json.dumps(result) + "\n" for result in batch_results)


In [4]:
await cot_inference()

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:05<08:39,  5.25s/it][A
  6%|▌         | 6/100 [00:05<01:03,  1.49it/s][A
 12%|█▏        | 12/100 [00:05<00:25,  3.51it/s][A
 15%|█▌        | 15/100 [00:05<00:20,  4.18it/s][A
 17%|█▋        | 17/100 [00:06<00:16,  5.07it/s][A
 19%|█▉        | 19/100 [00:06<00:13,  6.10it/s][A
 24%|██▍       | 24/100 [00:06<00:07,  9.99it/s][A
 28%|██▊       | 28/100 [00:06<00:05, 12.74it/s][A
 32%|███▏      | 32/100 [00:06<00:04, 16.26it/s][A
 37%|███▋      | 37/100 [00:06<00:03, 20.46it/s][A
 41%|████      | 41/100 [00:06<00:02, 21.21it/s][A
 46%|████▌     | 46/100 [00:06<00:02, 26.46it/s][A
 50%|█████     | 50/100 [00:07<00:01, 28.73it/s][A
 54%|█████▍    | 54/100 [00:07<00:01, 28.03it/s][A
 59%|█████▉    | 59/100 [00:07<00:01, 32.83it/s][A
 63%|██████▎   | 63/100 [00:07<00:01, 32.58it/s][A
 68%|██████▊   | 68/100 [00:07<00:00, 35.93it/s][A
 73%|███████▎  | 73/100 [00:07<00:00