In [1]:
import os
import json

from langchain_core.outputs import LLMResult

from agent.utils.loader import load_prompt, load_processed_data
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())

In [13]:
# --- Experiment configuration -------------------------------------------------
dataset_name = 'tabmwp'
mode = "cot"
model = "gpt-4o-mini"
num_samples = 1000  # values <= 0 keep the whole dataset
top_p = 0.95
seed = 42
batch_size = 100  # number of items sent concurrently per checkpointed batch

# Sampling settings depend on the dataset: "toxicity" draws 25 candidates at
# temperature 0.9, every other dataset draws a single candidate at temperature 0.
if dataset_name == "toxicity":
    temperature = 0.9
    n = 25
else:
    temperature = 0
    n = 1

# --- Paths (the output path encodes the run's hyper-parameters) ---------------
processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
save_results_path = f"../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"

# --- Prompt, data and model ---------------------------------------------------
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
if num_samples > 0:
    # Clamp to the dataset size so a dataset smaller than num_samples does not
    # make select() fail on out-of-range indices.
    dataset = dataset.select(range(min(num_samples, dataset.num_rows)))

# NOTE(review): base_url points at a non-default endpoint — confirm this is intended.
llm = ChatOpenAI(model=model, top_p=top_p, n=n, temperature=temperature, base_url="https://api.chsdw.top/v1", seed=seed)

# Sanity check: show the prompt template and the first example.
prompt.pretty_print()
print(dataset[0])




Remember your answer should follow previous pattern.


Read the following table regarding "Class size" and and then answer a question.

Teacher | Number of students Mrs. Truman | 23 Miss Urban | 26 Mrs. Woodworth | 27 Ms. Hershfeld | 28

Question: Some teachers compared how many students are in their classes. Which teacher has the most students? Choose from the the options: ['Mrs. Truman', 'Miss Urban', 'Mrs. Woodworth', 'Ms. Hershfeld']


Answer: Find the greatest number in the table. Remember to compare the numbers starting with the highest place value. The greatest number is 29.\n\nNow find the corresponding teacher. Ms. Hershfeld corresponds to 29. FINAL ANSWER: Ms. Hershfeld


Read the following table regarding "" and and then answer a question.

potatoes | $0.75 per kilogram zucchini | $0.60 per kilogram beets | $0.29 per kilogram cucumbers | $0.96 per kilogram carrots | $0.62 per kilogram

Question: Hanson went to the store and bought 2 kilograms of cucumbers. How much did he s

In [14]:
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

# Presumably needed because the notebook kernel already runs an event loop;
# nest_asyncio patches asyncio so nested loop usage is allowed — TODO confirm it is still required.
nest_asyncio.apply()

async def inference(item: dict) -> dict:
    """Run the LLM on a single dataset item and attach its output.

    The item is rendered through the module-level ``prompt`` and sent to the
    module-level ``llm``. For the "toxicity" dataset, generation stops at the
    first "." or newline.

    Args:
        item: One dataset record; it is passed to ``prompt.invoke`` as input
            and all of its keys are copied into the result.

    Returns:
        A new dict containing every key of ``item`` plus:
          * ``"generation"``: the raw text of the first returned candidate.
          * ``"prediction"``: for the QA / math datasets, the text after the
            last ``"FINAL ANSWER: "`` marker (the whole generation if the
            marker is absent); otherwise the raw generation.
        If anything raises, the error is printed and the result carries
        ``"generation": ""`` and ``"prediction": "ERROR"`` instead.
    """
    # Datasets whose prediction is the text following the last "FINAL ANSWER: ".
    final_answer_datasets = ("hotpot_qa", "trivia_qa", "ambig_qa", "gsm8k", "tabmwp", "svamp")

    try:
        response: LLMResult = await llm.agenerate(
            messages=[prompt.invoke(input=item)],
            stop=[".", "\n"] if dataset_name == 'toxicity' else None
        )

        # Only the first candidate of the single submitted prompt is kept.
        # NOTE(review): when the model is configured with n > 1 the remaining
        # candidates are discarded here — confirm that is intended.
        response_content = response.generations[0][0].message.content

        if dataset_name in final_answer_datasets:
            prediction = response_content.split("FINAL ANSWER: ")[-1]
        else:
            prediction = response_content
        result = {**item, "generation": response_content, "prediction": prediction}
    except Exception as e:
        # Best-effort: one failed request must not abort the whole batch, so the
        # failure is reported and recorded as an "ERROR" prediction.
        print(e)
        result = {**item, "generation": "", "prediction": "ERROR"}
    return result

async def cot_inference() -> None:
    """Run inference over ``dataset`` in batches, checkpointing to JSONL.

    If ``save_results_path`` already exists, its records are loaded and the
    run resumes at the dataset index equal to the number of loaded records;
    otherwise the output directory is created. After each batch of
    ``batch_size`` items the full result list is written back to
    ``save_results_path``, one JSON object per line.
    """
    results = []
    if os.path.exists(save_results_path):
        with open(save_results_path, 'r') as file:
            # Skip blank lines so a stray empty line does not break resuming.
            results = [json.loads(line) for line in file if line.strip()]
    else:
        os.makedirs(os.path.dirname(save_results_path), exist_ok=True)

    # Checkpoints go to a sibling temp file and are then swapped in, so an
    # interruption mid-write never leaves a truncated results file behind.
    tmp_results_path = save_results_path + ".tmp"

    for idx in tqdm(range(len(results), dataset.num_rows, batch_size)):
        batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
        # All items of the batch are awaited concurrently; gather returns the
        # results in the same order as the submitted items.
        results.extend(await tqdm_asyncio.gather(*(inference(item) for item in batch)))

        with open(tmp_results_path, 'w') as file:
            for result in results:
                file.write(json.dumps(result) + "\n")
        os.replace(tmp_results_path, save_results_path)


In [15]:
# Run (or resume) the batched inference; results are checkpointed to save_results_path.
await cot_inference()

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:04<06:47,  4.11s/it][A
  3%|▎         | 3/100 [00:04<01:49,  1.13s/it][A
  6%|▌         | 6/100 [00:04<00:44,  2.11it/s][A
  9%|▉         | 9/100 [00:04<00:24,  3.69it/s][A
 13%|█▎        | 13/100 [00:04<00:13,  6.25it/s][A
 17%|█▋        | 17/100 [00:04<00:08,  9.35it/s][A
 22%|██▏       | 22/100 [00:05<00:06, 12.70it/s][A
 25%|██▌       | 25/100 [00:05<00:05, 13.82it/s][A
 28%|██▊       | 28/100 [00:05<00:05, 14.07it/s][A
 33%|███▎      | 33/100 [00:05<00:03, 19.13it/s][A
 36%|███▌      | 36/100 [00:05<00:03, 20.71it/s][A
 40%|████      | 40/100 [00:05<00:02, 23.77it/s][A
 43%|████▎     | 43/100 [00:05<00:02, 23.39it/s][A
 46%|████▌     | 46/100 [00:06<00:02, 23.46it/s][A
 50%|█████     | 50/100 [00:06<00:01, 26.03it/s][A
 53%|█████▎    | 53/100 [00:06<00:01, 25.75it/s][A
 57%|█████▋    | 57/100 [00:06<00:01, 27.62it/s][A
 61%|██████    | 61/100 [00:06<00:01, 