In [1]:
import os
import json

from datasets import DatasetDict
from langchain_core.outputs import LLMResult
from langchain_core.prompts import ChatPromptTemplate

from agent.utils.loader import load_prompt, load_processed_data
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv
from langchain_core.prompts.chat import BaseMessage

# Load variables from the nearest .env file (find_dotenv searches upward through
# parent directories) into os.environ without overriding ones already set.
# Presumably this supplies the API key used by ChatOpenAI later in the notebook
# — TODO confirm. The boolean return value is deliberately discarded.
_ = load_dotenv(find_dotenv())

In [39]:
# --- Run configuration -------------------------------------------------------
# Every knob a reader might tune lives here; the inference cell reads these
# module-level names (dataset_name, prompt, llm, dataset, batch_size,
# save_results_path) directly.
dataset_name = 'toxicity'
mode = "direct"
model = "gpt-4o-mini"
num_samples = 1000  # <= 0 uses the whole dataset; larger than the dataset is clamped
top_p = 0.95
seed = 42
batch_size = 100  # rows sent concurrently per checkpointed batch
# Sampling depends on the task: toxicity draws many stochastic continuations per
# prompt, every other dataset asks for a single deterministic answer.
if dataset_name == "toxicity":
    temperature = 0.9
    n = 25
else:
    temperature = 0
    n = 1
# OpenAI-compatible endpoint; can be overridden from the environment (e.g. a
# .env entry), otherwise the original proxy URL is used.
base_url = os.getenv("OPENAI_BASE_URL", "https://api.chsdw.top/v1")
processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
save_results_path = f"../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}.jsonl"
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
if num_samples > 0:
    # Clamp so asking for more rows than exist does not make select() raise IndexError.
    dataset = dataset.select(range(min(num_samples, dataset.num_rows)))
llm = ChatOpenAI(model=model, top_p=top_p, n=n, temperature=temperature, base_url=base_url, seed=seed)

prompt.pretty_print()
print(dataset)

ValueError: Dataset toxicity not supported

In [38]:
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

# Patch asyncio so the notebook's already-running event loop can be re-entered
# (e.g. by a nested asyncio.run / run_until_complete). The top-level `await`
# used below works in IPython on its own, so this is presumably a safeguard
# for nested loop calls made by libraries — TODO confirm it is still needed.
nest_asyncio.apply()

async def inference(item: dict) -> dict:
    """Send one dataset row through the LLM and attach its prediction.

    Reads the notebook-level names ``dataset_name``, ``prompt`` and ``llm``.

    Args:
        item: One row of the processed dataset. It is formatted into ``prompt``
            and copied into the returned record.

    Returns:
        A copy of ``item`` with an extra ``"prediction"`` key:
        - hotpot_qa / trivia_qa / ambig_qa: reply text after the last ``"A: "``
          (the whole reply if the marker is absent).
        - gsm8k / tabmwp / svamp: reply text after the last ``"Answer: "``
          (the whole reply if the marker is absent).
        - toxicity: list with one string per returned generation, each cut at
          the first ``"."``.
        - ``"ERROR"`` if the request or parsing failed (the exception is printed).

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    qa_datasets = ("hotpot_qa", "trivia_qa", "ambig_qa")
    math_datasets = ("gsm8k", "tabmwp", "svamp")
    # Fail fast on a misconfigured dataset instead of silently producing "ERROR" rows.
    if dataset_name not in (*qa_datasets, *math_datasets, "toxicity"):
        raise ValueError(f"Dataset {dataset_name} not supported")
    try:
        # agenerate expects a list of message lists, not a PromptValue.
        messages = prompt.invoke(item).to_messages()
        if dataset_name == "toxicity":
            # Stop at the first "." so every sampled continuation is about one sentence.
            response: LLMResult = await llm.agenerate(messages=[messages], stop=["."])
            prediction = [generation.message.content for generation in response.generations[0]]
        else:
            response = await llm.agenerate(messages=[messages])
            content = response.generations[0][0].message.content
            marker = "A: " if dataset_name in qa_datasets else "Answer: "
            prediction = content.split(marker)[-1]
    except Exception as e:
        # Best-effort: one failed request must not abort the whole batch.
        print(e)
        prediction = "ERROR"
    return {**item, "prediction": prediction}

async def direct_inference() -> None:
    """Run ``inference`` over ``dataset`` in batches, checkpointing to disk.

    Reads the notebook-level names ``dataset``, ``batch_size``,
    ``save_results_path`` and ``inference``.

    If ``save_results_path`` already exists, its records are loaded and
    processing resumes at the first row index not yet covered, so an
    interrupted run continues when the cell is re-executed. Existing rows,
    including ones whose prediction is ``"ERROR"``, are kept and not retried.

    After each batch the complete result list is written to a temporary file
    that then replaces ``save_results_path`` via ``os.replace``. An
    interruption mid-write therefore cannot truncate an earlier checkpoint.
    """
    results = []
    if os.path.exists(save_results_path):
        with open(save_results_path, "r", encoding="utf-8") as file:
            # Skip blank lines instead of letting json.loads fail on them.
            results = [json.loads(line) for line in file if line.strip()]
    else:
        folder_path = os.path.dirname(save_results_path)
        if folder_path:
            os.makedirs(folder_path, exist_ok=True)

    tmp_path = save_results_path + ".tmp"
    for idx in tqdm(range(len(results), dataset.num_rows, batch_size)):
        batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
        results.extend(await tqdm_asyncio.gather(*(inference(item) for item in batch)))
        # Write the full snapshot to a temp file, then swap it in atomically.
        with open(tmp_path, "w", encoding="utf-8") as file:
            for result in results:
                file.write(json.dumps(result) + "\n")
        os.replace(tmp_path, save_results_path)
				
# Top-level `await` runs the coroutine on the notebook's event loop. Re-running
# this cell picks up from the existing results file at save_results_path
# rather than starting over.
await direct_inference()

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:05<09:08,  5.54s/it][A
  2%|▏         | 2/100 [00:05<03:52,  2.37s/it][A
  3%|▎         | 3/100 [00:05<02:13,  1.37s/it][A
  5%|▌         | 5/100 [00:06<01:00,  1.57it/s][A
  9%|▉         | 9/100 [00:06<00:24,  3.76it/s][A
 12%|█▏        | 12/100 [00:06<00:16,  5.41it/s][A
 14%|█▍        | 14/100 [00:06<00:12,  6.73it/s][A
 16%|█▌        | 16/100 [00:06<00:10,  8.13it/s][A
 20%|██        | 20/100 [00:06<00:06, 11.68it/s][A
 24%|██▍       | 24/100 [00:06<00:04, 15.75it/s][A
 27%|██▋       | 27/100 [00:07<00:04, 15.11it/s][A
 32%|███▏      | 32/100 [00:07<00:03, 20.70it/s][A
 38%|███▊      | 38/100 [00:07<00:02, 27.91it/s][A
 45%|████▌     | 45/100 [00:07<00:01, 36.29it/s][A
 52%|█████▏    | 52/100 [00:07<00:01, 43.91it/s][A
 59%|█████▉    | 59/100 [00:07<00:00, 45.57it/s][A
 67%|██████▋   | 67/100 [00:07<00:00, 53.36it/s][A
 73%|███████▎  | 73/100 [00:07<00:00, 5