In [1]:
import json
import os
from collections import Counter

from dotenv import load_dotenv, find_dotenv
from langchain_core.outputs import LLMResult
from langchain_openai import ChatOpenAI

from agent.utils.loader import load_prompt, load_processed_data

_ = load_dotenv(find_dotenv())

In [17]:
dataset_name = 'gsm8k'
mode = "self-consistency"
model = "gpt-4o-mini"
num_samples = -1  # <= 0 means "use the whole dataset" (see the `if num_samples > 0` below)
top_p = 0.95
# The original paper used GPT-3 with temperature 0.7 and n=10 samples per question.
temperature = 0.7
seed = 42
batch_size = 100  # number of questions sent concurrently per batch
n = 10  # completions sampled per question, then majority-voted
# OpenAI-compatible endpoint; override with the OPENAI_BASE_URL environment variable
# instead of editing the notebook.
base_url = os.getenv("OPENAI_BASE_URL", "https://api.chsdw.top/v1")
processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
# NOTE(review): `n` is not part of this file name, so runs with different `n` would share
# (and resume from) the same results file — consider adding it.
save_results_path = f"../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
if num_samples > 0:
    dataset = dataset.select(range(num_samples))
llm = ChatOpenAI(model=model, top_p=top_p, n=n, temperature=temperature, base_url=base_url, seed=seed)

prompt.pretty_print()
print(dataset[0])




Remember your answer should follow previous pattern and format.


Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?


Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. FINAL ANSWER: 72


Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?


Answer: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute. Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10. FINAL ANSWER: 10


Question: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?


Answer: In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50. Betty's grandpar

In [11]:
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

# Patch asyncio so event loops can be nested inside the notebook's already-running loop.
# NOTE(review): the run cell below uses top-level `await`, which IPython supports natively,
# so this patch is probably only needed if `asyncio.run(...)` is used somewhere — confirm.
nest_asyncio.apply()

def vote(candidate: list, split_signal: str) -> str:
    """Majority-vote the final answers extracted from sampled completions.

    Args:
        candidate: Raw completion texts. ``None`` entries (completions without
            text content) are ignored.
        split_signal: Marker that precedes the final answer. The text after its
            last occurrence, stripped, is that completion's answer; a completion
            without the marker contributes its whole stripped text.

    Returns:
        The most frequent answer. Ties go to the answer that appears first in
        ``candidate``, so the result is reproducible across interpreter runs.

    Raises:
        ValueError: if ``candidate`` contains no usable completions.
    """
    predictions = [
        message.split(split_signal)[-1].strip()
        for message in candidate
        if message is not None
    ]
    if not predictions:
        raise ValueError("vote() needs at least one non-None candidate")
    # Counter.most_common keeps first-seen order among equal counts; iterating a set
    # (the previous approach) depends on per-process string-hash randomisation.
    return Counter(predictions).most_common(1)[0][0]

async def inference(item: dict) -> dict:
    """Sample completions for one dataset item and majority-vote them.

    Uses the notebook-level ``llm`` (configured with ``n`` samples), ``prompt``,
    ``dataset_name`` and ``vote``.

    Args:
        item: One row of ``dataset``; its fields fill the prompt and are copied
            into the returned record.

    Returns:
        ``item`` plus ``"candidate"`` (the raw completion texts) and, for
        short-answer datasets, ``"prediction"`` (the voted answer). On any
        failure, ``item`` plus ``"prediction": "ERROR"`` and ``"error"`` (the
        exception message), so failed rows can be recognised and retried.
    """
    # Datasets whose completions end in a short "FINAL ANSWER: ..." that can be voted on.
    voted_datasets = {"hotpot_qa", "trivia_qa", "ambig_qa", "gsm8k", "tabmwp", "svamp"}
    try:
        # agenerate_prompt accepts PromptValues directly (agenerate expects message lists).
        response: LLMResult = await llm.agenerate_prompt([prompt.invoke(input=item)])
        # A single prompt was sent, so generations[0] holds all of its sampled choices.
        candidate = [choice.message.content for choice in response.generations[0]]
        result = {**item, "candidate": candidate}
        # Free-form datasets (e.g. toxicity) keep only the raw samples; voting there
        # would be meaningless and could fail the whole row.
        if dataset_name in voted_datasets:
            result["prediction"] = vote(candidate, split_signal="FINAL ANSWER: ")
    except Exception as e:
        # Best effort: API failures (quota, connectivity) must not abort the whole batch.
        # The row is marked so a later run can retry it.
        print(f"{type(e).__name__}: {e}")
        result = {**item, "prediction": "ERROR", "error": str(e)}
    return result

def _load_results(path: str) -> list:
    """Read previously saved JSONL results from ``path``, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]


def _save_results(path: str, results: list) -> None:
    """Write ``results`` to ``path`` as JSONL via a temp file and ``os.replace``."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        for result in results:
            file.write(json.dumps(result) + "\n")
    # Swap in the finished file so an interrupt mid-write never truncates saved results
    # (the rename is atomic on POSIX when both paths are on the same filesystem).
    os.replace(tmp_path, path)


async def self_consistency_inference() -> None:
    """Run (or resume) self-consistency inference over ``dataset``.

    Results are stored in dataset order in ``save_results_path`` (one JSON record
    per line). On resume, rows whose earlier attempt failed
    (``"prediction" == "ERROR"``) are retried, followed by every row not yet
    attempted. Rows are processed ``batch_size`` at a time, each batch
    concurrently, and the file is rewritten after every batch so progress
    survives interruptions.
    """
    folder_path = os.path.dirname(save_results_path)
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)
    results = _load_results(save_results_path) if os.path.exists(save_results_path) else []

    # results[i] belongs to dataset[i]: failed rows first, then the untouched tail.
    pending = [i for i, record in enumerate(results) if record.get("prediction") == "ERROR"]
    pending.extend(range(len(results), dataset.num_rows))

    for start in tqdm(range(0, len(pending), batch_size)):
        indices = pending[start:start + batch_size]
        batch = dataset.select(indices)
        outputs = await tqdm_asyncio.gather(*(inference(item) for item in batch))
        for idx, output in zip(indices, outputs):
            # Retried rows overwrite their slot; new rows arrive in order at the end.
            if idx < len(results):
                results[idx] = output
            else:
                results.append(output)
        _save_results(save_results_path, results)


In [19]:
# Run (or resume) inference over `dataset`, saving results to `save_results_path`.
# IPython allows top-level `await` directly in a cell.
await self_consistency_inference()

  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<02:33,  1.55s/it][A
100%|██████████| 100/100 [00:01<00:00, 58.23it/s][A
 25%|██▌       | 1/4 [00:01<00:05,  1.75s/it]

Connection error.
Connection error.
Connection error.
Connection error.
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163905589000418qp5BXJuO)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163905588043925YmayGLRn)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163905597155525uiVkmPr0)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163905618929733I9F5zpbh)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163905610228420ZC3exGPG)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163905619304813G1f7hsPY)', 


  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<01:55,  1.17s/it][A
100%|██████████| 100/100 [00:01<00:00, 73.41it/s][A

Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 202410081639078909871awvQ3ZhJ)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 2024100816390712480565jT5mjqVp)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 2024100816390715477013CU6KiAWW)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 2024100816390717849398GIA8hjxw)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163907196285284j0kUnJS)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163907237919405580R43D)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 Tok


 50%|█████     | 2/4 [00:03<00:03,  1.54s/it]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<02:03,  1.25s/it][A
100%|██████████| 100/100 [00:01<00:00, 71.08it/s][A
 75%|███████▌  | 3/4 [00:04<00:01,  1.49s/it]

Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908465678735uw0EN0WI)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908464812071RiDchsGR)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908468751668U6d8nGkT)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 202410081639084710669912CfNmqwd)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908475675527QyBxtfPG)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908473103213TtBt4VSM)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度


  0%|          | 0/19 [00:00<?, ?it/s][A
100%|██████████| 19/19 [00:00<00:00, 63.77it/s][A
100%|██████████| 4/4 [00:04<00:00,  1.23s/it]

Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908986028015fw1PTCsB)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908989041308zR3XK2Ym)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908991732425n1LSgqas)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908992815354EgIHYJ5R)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 20241008163908997151805nyYgJcGr)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度已用尽 TokenStatusExhausted[sk-4CW***4Da] (request id: 202410081639089999084497Pr7Vaob)', 'type': 'new_api_error'}}
Error code: 401 - {'error': {'message': '该令牌额度


