In [35]:
import getpass
import json
import logging
import os

from langchain_core.messages import SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI


def _set_if_undefined(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"Please provide your {var}")

# The LangChain API key is requested interactively when it is not already set.
_set_if_undefined("LANGCHAIN_API_KEY")

# LangSmith settings: tracing is switched off here; the project name is "cot".
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "false",
        "LANGCHAIN_PROJECT": "cot",
    }
)

from langchain import hub

# Load prompt
def load_prompt(dataset_name: str) -> ChatPromptTemplate:
    """Pull the few-shot prompt registered for ``dataset_name`` from the hub.

    Args:
        dataset_name: One of ``gsm8k``, ``svamp``, ``tabmwp``, ``hotpot_qa``,
            ``ambig_qa`` or ``trivia_qa``.

    Returns:
        Whatever ``hub.pull`` returns for the matching prompt name.

    Raises:
        ValueError: If ``dataset_name`` has no prompt registered below.
    """
    # Dataset name -> name of the prompt to pull.
    hub_prompt_by_dataset = dict(
        gsm8k="gsm8k_9shot",
        svamp="svamp_7shot",
        tabmwp="tabmwp_4shot",
        hotpot_qa="hotpot_qa_6shot",
        ambig_qa="ambig_qa_5shot",
        trivia_qa="trivia_qa_5shot",
    )

    hub_prompt_name = hub_prompt_by_dataset.get(dataset_name)
    if hub_prompt_name is None:
        raise ValueError(f"Dataset {dataset_name} not supported")

    return hub.pull(hub_prompt_name)
    


# Parallel processing: examples are sent to the model concurrently via asyncio
import asyncio

async def cot(item, prompt: ChatPromptTemplate, llm: ChatOpenAI, dataset_name:str="hotpot_qa") -> str:
    """Answer one dataset example with the few-shot chain-of-thought prompt.

    Args:
        item: Mapping with the example's fields. ``question`` is read for
            gsm8k/hotpot_qa/ambig_qa/trivia_qa, ``Body`` and ``Question`` for
            svamp, and ``question``/``table``/``table_title`` for tabmwp.
        prompt: Prompt template; its ``invoke`` receives the fields above.
        llm: Chat model; its ``ainvoke`` is awaited with the rendered prompt.
        dataset_name: Selects which fields of ``item`` are fed to the prompt.

    Returns:
        The ``content`` of the model's reply, or the string ``"None"`` when
        the model call raises (best-effort: one failed request must not abort
        a whole batch).

    Raises:
        ValueError: If ``dataset_name`` is not one of the datasets handled here.
    """
    if dataset_name in ("gsm8k", "hotpot_qa", "ambig_qa", "trivia_qa"):
        variables = {"question": item["question"]}
    elif dataset_name == "svamp":
        variables = {"Body": item["Body"], "Question": item["Question"]}
    elif dataset_name == "tabmwp":
        variables = {"question": item["question"], "table": item["table"], "table_title": item["table_title"]}
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"Dataset {dataset_name} not supported")

    prompt_value: PromptValue = prompt.invoke(variables)
    try:
        result: BaseMessage = await llm.ainvoke(prompt_value)
        return result.content
    except Exception as e:
        # Keep the "None" sentinel callers already store as the prediction,
        # but make the failure visible rather than swallowing it silently.
        logging.getLogger(__name__).warning(
            "cot: model call failed (%s: %s); returning 'None'", type(e).__name__, e
        )
        return "None"

In [30]:
# extract question and answer from ambig_qa
import random
def extract_question_answer(item):
    """Reduce one ambig_qa record to a single question/answer pair.

    When the first entry of ``item['annotations']['type']`` is
    ``"singleAnswer"``, the record's own question is kept and the answer is
    the de-duplicated union of ``item['nq_answer']`` and every list in
    ``item['annotations']['answer']``. Otherwise one question/answer pair is
    picked at random from the first entry of ``item['annotations']['qaPairs']``.

    The input record is not modified; a new dict is returned.

    Args:
        item: Mapping with at least ``question``, ``nq_answer`` and
            ``annotations`` keys.

    Returns:
        ``{"question": ..., "answer": ...}`` for the selected pair.
    """
    annotations = item['annotations']

    if annotations['type'][0] == "singleAnswer":
        # Work on a copy: extending item['nq_answer'] in place would grow the
        # record's own list every time the function is applied to it.
        candidates = list(item['nq_answer'])
        for ans in annotations['answer']:
            candidates.extend(ans)
        # dict.fromkeys drops duplicates while keeping first-seen order, so the
        # result is reproducible (set ordering of strings varies between runs).
        return {"question": item["question"], "answer": list(dict.fromkeys(candidates))}

    # Several question/answer pairs: pick one of them at random.
    qa_pairs = annotations['qaPairs'][0]
    rand_i = random.randrange(len(qa_pairs['question']))
    return {"question": qa_pairs['question'][rand_i], "answer": qa_pairs['answer'][rand_i]}

# Supported dataset names: gsm8k, svamp, tabmwp, hotpot_qa, ambig_qa, trivia_qa

In [36]:
from datasets.arrow_dataset import Dataset
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
from datasets.iterable_dataset import IterableDataset
from langchain_core.prompts.chat import ChatPromptTemplate


# Experiment configuration: which dataset to evaluate, how many examples,
# the sampling settings, and how many requests are sent per batch.
dataset_name = "ambig_qa"
mode = "cot"
num_test_sample = 1000  # number of leading examples to keep; <= 0 keeps the whole dataset
temperature = 0
top_p = 1
batch_size = 100
prompt: ChatPromptTemplate = load_prompt(dataset_name)
# Pass the sampling settings from the variables above so the model really runs
# with the values that are later recorded in the output file name.
llm = ChatOpenAI(temperature=temperature, base_url="https://api.chsdw.top/v1", top_p=top_p, model="gpt-4o-mini")

print(prompt)

from datasets import load_dataset
dataset: DatasetDict | Dataset | IterableDatasetDict | IterableDataset = load_dataset("json", data_files=f"../data/{dataset_name}.jsonl", split="train")
if num_test_sample > 0:
    # Clamp to the dataset size so a file with fewer rows than requested
    # is used in full instead of failing on out-of-range indices.
    dataset = dataset.select(range(min(num_test_sample, len(dataset))))

if dataset_name == "ambig_qa":
    # Collapse each ambig_qa record to one question/answer pair and drop the raw columns.
    dataset = dataset.map(extract_question_answer, remove_columns=['id', 'annotations', 'viewed_doc_titles', 'used_queries', 'nq_answer', 'nq_doc_title'])
print(dataset, dataset[0])
# Predictions collected by the inference loop.
results = []

input_variables=['question'] metadata={'lc_hub_owner': '-', 'lc_hub_repo': 'ambig_qa_5shot', 'lc_hub_commit_hash': 'd229fd2b0f6d226826ed84152137089274dbe51f34abe1dbdeccb180f057bc44'} messages=[HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='Q: What airport is closest to Palm Springs?')), AIMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='A: The nearest airport to Palm Springs is Indio/Palm Springs (PSP) Airport which is 2.1 miles away. FINAL ANSWER: Palm Springs International Airport')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='Q: What degree did Martin Luther King get?')), AIMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='A: Martin Luther King earned his Bachelor of Divinity degree from Crozer Theological Seminary, followed by a doctorate in Systematic Theology from Boston University. FINAL ANSWER: Bachelor of Divinity')), HumanMessagePromptTemplate(prompt=PromptTemp

In [34]:
# Smoke test: run the chain-of-thought call on the first example only and
# print the model's reply before launching the full batched run.
result: str = await cot(dataset[0], prompt, llm, dataset_name)
print(result)

A: In Season 1 of "Dexter," the character Dr. Rudy Cooper, who later becomes known as the Ice Truck Killer, is played by Christian Camargo. FINAL ANSWER: Christian Camargo.


## Run Batched Inference and Save Predictions

In [37]:
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

for i in range(0, 1000, batch_size):
    batch = dataset.select(range(i, min(i+batch_size, len(dataset))))
    results.extend(await tqdm_asyncio.gather(*(cot(item, prompt, llm, dataset_name) for item in batch)))
    with open(f"/Users/ariete/Projects/self-improve/output/inference/{dataset_name}/{mode}_{1000}_temperature_{temperature}_top-p_{top_p}.jsonl", "w") as f:
        for idx, result in enumerate(results):
            # if dataset_name == "ambig_qa":
            #     if dataset[idx]['annotations']['type'][0] == "singleAnswer":
            #         # single answer
            #         answers = dataset[idx]['nq_answer']
            #         for ans in dataset[idx]['annotations']['answer']:
            #             answers.extend(ans)
            #         answers = list(set(answers))
            #     else:
            #         # random choose a question with multiple answers
            #         qa_pairs = dataset[idx]['annotations']['qaPairs'][0]
            #         rand_i = random.randint(0, len(qa_pairs['question'])-1)
            #         dataset[idx]['question'] = qa_pairs['question'][rand_i]
            #         dataset[idx]['answer'] = qa_pairs['answer'][rand_i]
            f.write(json.dumps({"question":dataset[idx]["question"], "answer": dataset[idx]["answer"], "prediction": result}) + "\n")

100%|██████████| 100/100 [00:06<00:00, 14.93it/s]
100%|██████████| 100/100 [00:06<00:00, 14.79it/s]
100%|██████████| 100/100 [00:05<00:00, 17.79it/s]
100%|██████████| 100/100 [00:06<00:00, 14.85it/s]
100%|██████████| 100/100 [00:06<00:00, 15.58it/s]
100%|██████████| 100/100 [00:07<00:00, 13.16it/s]
100%|██████████| 100/100 [00:05<00:00, 16.95it/s]
100%|██████████| 100/100 [00:08<00:00, 12.40it/s]
100%|██████████| 100/100 [00:06<00:00, 16.22it/s]
100%|██████████| 100/100 [00:06<00:00, 15.67it/s]
