In [5]:
import json
from langchain_core.messages.base import BaseMessage
from langchain_core.prompt_values import PromptValue
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from utils import load_prompt

import getpass
import os


def _set_if_undefined(var: str):
    """Make sure the environment variable ``var`` holds a non-empty value.

    When ``var`` is already set to a non-empty string nothing happens.
    Otherwise the user is prompted (without echo, via ``getpass``) and the
    typed value is stored in ``os.environ`` for the current process.
    """
    already_set = bool(os.environ.get(var))
    if already_set:
        return
    # An unset variable and an empty string are both treated as "undefined".
    secret = getpass.getpass(f"Please provide your {var}")
    os.environ[var] = secret

# Ask for the LangChain API key interactively unless it is already set.
_set_if_undefined("LANGCHAIN_API_KEY")

# Optional LangSmith tracing settings: the tracing flag is set to "false"
# and the project name to "cot".
os.environ.update({
    "LANGCHAIN_TRACING_V2": "false",
    "LANGCHAIN_PROJECT": "cot",
})

# Parallel (asynchronous) processing
import asyncio

async def cot(item, prompt: "ChatPromptTemplate", llm: "ChatOpenAI", dataset_name: str = "hotpot_qa") -> str:
    """Run one chain-of-thought query for a single dataset example.

    The prompt is filled with the fields that ``dataset_name`` requires and
    the rendered prompt is sent to ``llm`` asynchronously.

    Args:
        item: Mapping for one example. Keys read per dataset:
            ``gsm8k`` / ``hotpot_qa`` / ``ambig_qa`` / ``trivia_qa`` -> ``question``;
            ``svamp`` -> ``Body`` and ``Question``;
            ``tabmwp`` -> ``question``, ``table`` and ``table_title``.
        prompt: Prompt template; its ``invoke`` receives the fields above.
        llm: Chat model; its ``ainvoke`` receives the rendered prompt.
        dataset_name: Selects which fields of ``item`` are passed to the prompt.

    Returns:
        ``content`` of the model response. If the model call raises, the
        string form of the exception is returned instead, so a single failed
        request does not raise out of this coroutine. Callers that need to
        tell failures from predictions must inspect the returned text.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
        KeyError: If ``item`` lacks a field required for ``dataset_name``.
    """
    if dataset_name in ("gsm8k", "hotpot_qa", "ambig_qa", "trivia_qa"):
        prompt_inputs = {"question": item["question"]}
    elif dataset_name == "svamp":
        prompt_inputs = {"Body": item["Body"], "Question": item["Question"]}
    elif dataset_name == "tabmwp":
        prompt_inputs = {
            "question": item["question"],
            "table": item["table"],
            "table_title": item["table_title"],
        }
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    prompt_value = prompt.invoke(prompt_inputs)
    try:
        result = await llm.ainvoke(prompt_value)
        return result.content
    except Exception as e:
        # Best effort: report the failure as the prediction text.
        return str(e)

In [3]:
# extract question and answer from ambig_qa
import random
def extract_question_answer(item, rng=None):
    """Reduce one ambig_qa record to a single question/answer pair.

    For records whose first annotation type is ``"singleAnswer"`` the original
    question is kept and the answer is the de-duplicated union of
    ``item['nq_answer']`` and every list in ``item['annotations']['answer']``,
    in first-seen order. For all other records one question/answer pair is
    drawn at random from the first entry of ``item['annotations']['qaPairs']``.

    ``item`` is never modified; the result is carried only by the return value.

    Args:
        item: One record, read through the keys ``question``, ``nq_answer``
            and ``annotations`` (``type``, ``answer``, ``qaPairs``).
            Answers are assumed to be hashable (e.g. strings) - TODO confirm.
        rng: Optional object with a ``randint(a, b)`` method (for example a
            seeded ``random.Random``) used for the random pick. Defaults to
            the global ``random`` module.

    Returns:
        ``{"question": ..., "answer": ...}``.

    Raises:
        ValueError: From ``randint`` if the chosen ``qaPairs`` entry has no questions.
    """
    annotations = item['annotations']

    if annotations['type'][0] == "singleAnswer":
        # Copy first so the caller's nq_answer list is not extended in place.
        merged = list(item['nq_answer'])
        for extra_answers in annotations['answer']:
            merged.extend(extra_answers)
        # dict.fromkeys removes duplicates while keeping a deterministic order.
        return {"question": item["question"], "answer": list(dict.fromkeys(merged))}

    # Multiple interpretations: pick one question together with its answers.
    qa_pairs = annotations['qaPairs'][0]
    picker = random if rng is None else rng
    rand_i = picker.randint(0, len(qa_pairs['question']) - 1)
    return {"question": qa_pairs['question'][rand_i], "answer": qa_pairs['answer'][rand_i]}

# Supported dataset names: gsm8k, svamp, tabmwp, hotpot_qa, ambig_qa, trivia_qa

In [6]:
from datasets.arrow_dataset import Dataset
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
from datasets.iterable_dataset import IterableDataset
from langchain_core.prompts.chat import ChatPromptTemplate


from datasets import load_dataset

# --- Experiment configuration -------------------------------------------
dataset_name = "tabmwp"   # one of: gsm8k, svamp, tabmwp, hotpot_qa, ambig_qa, trivia_qa
mode = "cot"              # used in the name of the output file
num_test_sample = 1000    # keep only the first N examples; a value <= 0 keeps the whole split
temperature = 0
top_p = 1
batch_size = 100          # number of concurrent requests per batch
prompt: ChatPromptTemplate = load_prompt(dataset_name)
# The sampling parameters come from the variables above so that editing the
# configuration actually changes the model call (they were hardcoded before).
llm = ChatOpenAI(temperature=temperature, base_url="https://api.chsdw.top/v1", top_p=top_p, model="gpt-4o-mini")

print(prompt)

# --- Dataset --------------------------------------------------------------
dataset: DatasetDict | Dataset | IterableDatasetDict | IterableDataset = load_dataset("json", data_files=f"../data/{dataset_name}.jsonl", split="train")
if num_test_sample > 0:
    # Clamp to the dataset size so a small file does not raise an index error.
    dataset = dataset.select(range(min(num_test_sample, len(dataset))))

if dataset_name == "ambig_qa":
    # ambig_qa records are reduced to plain question/answer pairs; the listed raw columns are dropped.
    dataset = dataset.map(extract_question_answer, remove_columns=['id', 'annotations', 'viewed_doc_titles', 'used_queries', 'nq_answer', 'nq_doc_title'])
print(dataset, dataset[0])
# Predictions collected by the inference cell.
results = []

input_variables=['question', 'table', 'table_title'] metadata={'lc_hub_owner': 'ariete', 'lc_hub_repo': 'tabmwp_4shot', 'lc_hub_commit_hash': 'b81b7059054de7a694a0572488479e04d0729b68068924d13b4a4ff5c9b2516c'} messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template="# Write Python Code to solve the following questions. Store your result as a variable named 'answer'. Use 'print(answer)' to output your answer. You should follow the format of the cases below.")), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='Read the following table regarding "Coin collections" and then write Python code to answer a question:\n\nName | Number of coins\nBraden | 76\nCamilla | 94\nRick | 86\nMary | 84\nHector | 80\nDevin | 83\nEmily | 82\nAvery | 87\n\nQuestion: Some friends discussed the sizes of their coin collections. What is the mean of the numbers?')), AIMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='# Python C

In [3]:
result: str = await cot(dataset[0], prompt, llm, dataset_name)
print(result)

# Python Code, return answer
number_of_coins = [76, 94, 86, 84, 80, 83, 82, 87]
mean_coins = sum(number_of_coins) / len(number_of_coins)
answer = mean_coins
print(answer)


## Run Batched Inference

In [7]:
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

for i in range(900, 1000, batch_size):
    batch = dataset.select(range(i, min(i+batch_size, len(dataset))))
    results.extend(await tqdm_asyncio.gather(*(cot(item, prompt, llm, dataset_name) for item in batch)))
    with open(f"/Users/ariete/Projects/self-improve/output/inference/{dataset_name}/{mode}_{900-1000}_temperature_{temperature}_top-p_{top_p}.jsonl", "w") as f:
        for idx, result in enumerate(results):
            f.write(json.dumps({"question":dataset[idx]["question"], "answer": dataset[idx]["answer"], "prediction": result}) + "\n")

100%|██████████| 100/100 [01:04<00:00,  1.54it/s]
