In [1]:
import os
import json

from langchain_core.outputs import LLMResult
from langchain_core.prompts import ChatPromptTemplate

from agent.utils.loader import load_prompt, load_processed_data
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv
from langchain_core.prompts.chat import BaseMessage

# Load API credentials (e.g. the OpenAI key) from the nearest .env file.
# Bind a named flag rather than `_`: a user-assigned `_` stops IPython from
# updating its last-output variables (`_`, `__`, `___`).
dotenv_loaded = load_dotenv(find_dotenv())
if not dotenv_loaded:
	print("Warning: no variables loaded from a .env file; relying on the existing environment.")

In [9]:
# --- Run configuration -------------------------------------------------------
dataset_name = 'trivia_qa'
mode = "direct"
model = "gpt-4o-mini"
num_samples = 1000
top_p = 0.95
seed = 42
batch_size = 100

# Sampling settings depend on the dataset: toxicity draws 25 completions per
# prompt at temperature 0.9 (presumably so toxicity can be measured over a
# distribution of outputs); every other dataset uses greedy decoding with a
# single completion.
if dataset_name == "toxicity":
	temperature = 0.9
	n = 25
else:
	temperature = 0
	n = 1

processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
save_results_path = f"../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}.jsonl"
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)

# The endpoint can be overridden with OPENAI_BASE_URL (e.g. in .env) instead of
# editing the notebook; it defaults to the OpenAI-compatible endpoint below.
base_url = os.environ.get("OPENAI_BASE_URL", "https://api.chsdw.top/v1")
llm = ChatOpenAI(model=model, top_p=top_p, n=n, temperature=temperature, base_url=base_url, seed=seed)

prompt.pretty_print()
print(dataset)



Generating train split: 0 examples [00:00, ? examples/s]


Remember your answer should follow previous pattern.


Q: Mendelssohn's 'Wedding March' was. originally written as incidental music for which Shakespeare play in 1842?


A: A Midsummer Night's Dream


Q: """Christ in the House of his Parents"" is one of the best known paintings of which artist?"


A: John Millais


Q: Who designed the National Theatre building on the South Bank in London ?


A: Sir Denys Lasdun


Q: Also a two-time World Champion, which American skier won the gold medal in the Men's Combined at the 2010 Winter Olympics?


A: Bodie Miller


Q: Famous composer, Handel, originally studied what?


A: Law


Q: Which great philosopher corresponded with Queen Christina of Sweden in his final years and died in 1650 in Stockholm where he had been invited as a teacher for her?


A: René Decartes


Q: Who was the female member of Britain's gang of four?


A: Baroness Williams


Q: An icosahedron has how many faces?


A: twenty


Q: What chemical makes hot peppers hot?


A: Capsc

In [10]:
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

# Make the notebook's already-running event loop re-entrant, so nested
# asyncio.run()/run_until_complete() calls do not fail with "event loop is
# already running". NOTE(review): the top-level `await` at the end of this cell
# does not need this by itself — confirm whether anything else still relies on it.
nest_asyncio.apply()

async def inference(item: dict) -> dict:
	"""Run the model on one dataset example and attach its parsed prediction.

	Uses the notebook-level ``llm``, ``prompt`` and ``dataset_name`` set in the
	configuration cell.

	Args:
		item: One example from ``dataset``; its fields fill the prompt template.

	Returns:
		A copy of ``item`` with an extra ``"prediction"`` key holding one of:
		the text after the last answer marker (QA and math datasets), the list
		of all sampled completions (``toxicity``), or the string ``"ERROR"`` if
		the request or parsing failed. In the error case the exception is printed.

	Raises:
		ValueError: if ``dataset_name`` has no parsing rule. The check runs
			before any API call, so a config typo fails fast instead of
			crashing later with ``UnboundLocalError``.
	"""
	if dataset_name in ("hotpot_qa", "trivia_qa", "ambig_qa"):
		answer_marker = "A: "
	elif dataset_name in ("gsm8k", "gsm9k", "tabmwp", "svamp"):
		# "gsm8k" is the benchmark's real name; "gsm9k" is kept as an alias for
		# any config that uses that spelling — TODO drop it if unused.
		answer_marker = "Answer: "
	elif dataset_name == "toxicity":
		answer_marker = None  # keep every sampled completion, unparsed
	else:
		raise ValueError(f"No prediction parser for dataset {dataset_name!r}")

	try:
		# `agenerate` expects a batch of message lists, so wrap this single
		# prompt's rendered messages in a one-element list.
		messages = prompt.invoke(item).to_messages()
		response: LLMResult = await llm.agenerate(messages=[messages])
		generations = response.generations[0]
		if answer_marker is None:
			prediction = [generation.message.content for generation in generations]
		else:
			# Keep only the text after the last marker, in case the model echoes
			# earlier few-shot question/answer pairs.
			prediction = generations[0].message.content.split(answer_marker)[-1]
	except Exception as e:
		# Best effort: record the failure for this item so the rest of the batch
		# keeps going.
		print(f"Inference failed: {e!r}")
		prediction = "ERROR"
	return {**item, "prediction": prediction}

async def direct_inference() -> None:
	"""Run ``inference`` over the first ``num_samples`` examples in batches, with resume.

	Rows already stored at ``save_results_path`` are loaded first, and
	generation continues from the next unprocessed index. Re-running the cell
	after an interruption therefore picks up where it stopped. Every saved row
	counts as done, including rows whose prediction is ``"ERROR"``; delete
	those lines from the file to retry them.

	After each batch the full result list is checkpointed as JSON Lines.
	"""
	folder_path = os.path.dirname(save_results_path)
	if folder_path:
		os.makedirs(folder_path, exist_ok=True)

	results = []
	if os.path.exists(save_results_path):
		with open(save_results_path, "r", encoding="utf-8") as file:
			# Skip blank lines (e.g. a stray trailing newline) instead of
			# failing in json.loads.
			results = [json.loads(line) for line in file if line.strip()]

	# Never index past the end of the dataset, even if num_samples is larger.
	total = min(num_samples, len(dataset))
	tmp_path = save_results_path + ".tmp"
	for start in tqdm(range(len(results), total, batch_size)):
		batch = dataset.select(range(start, min(start + batch_size, total)))
		results.extend(await tqdm_asyncio.gather(*(inference(item) for item in batch)))
		# Write to a temp file and atomically swap it in. Interrupting the cell
		# mid-write then cannot truncate the checkpoint and lose earlier batches.
		with open(tmp_path, "w", encoding="utf-8") as file:
			file.writelines(json.dumps(result) + "\n" for result in results)
		os.replace(tmp_path, save_results_path)
				
# Top-level `await` works because IPython/Jupyter runs cell code on its own
# event loop (autoawait). Re-running this cell resumes from whatever rows are
# already saved at save_results_path.
await direct_inference()

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:03<05:52,  3.56s/it][A
  2%|▏         | 2/100 [00:03<02:30,  1.53s/it][A
  5%|▌         | 5/100 [00:03<00:45,  2.07it/s][A
  7%|▋         | 7/100 [00:03<00:28,  3.25it/s][A
 14%|█▍        | 14/100 [00:04<00:10,  8.32it/s][A
 20%|██        | 20/100 [00:04<00:05, 13.47it/s][A
 27%|██▋       | 27/100 [00:04<00:03, 20.07it/s][A
 32%|███▏      | 32/100 [00:04<00:03, 20.17it/s][A
 36%|███▌      | 36/100 [00:04<00:03, 17.77it/s][A
 39%|███▉      | 39/100 [00:05<00:03, 18.73it/s][A
 46%|████▌     | 46/100 [00:05<00:01, 27.03it/s][A
 52%|█████▏    | 52/100 [00:05<00:01, 32.26it/s][A
 60%|██████    | 60/100 [00:05<00:00, 41.62it/s][A
 68%|██████▊   | 68/100 [00:05<00:00, 46.69it/s][A
 76%|███████▌  | 76/100 [00:05<00:00, 53.46it/s][A
 86%|████████▌ | 86/100 [00:05<00:00, 61.29it/s][A
100%|██████████| 100/100 [00:05<00:00, 16.86it/s][A
 10%|█         | 1/10 [00:05<00:53,  5