In [1]:
# Imports:

import os
import huggingface_hub
import datasets
from pprint import pprint
from types import SimpleNamespace
import asyncio
import nest_asyncio
import oaib
import json

In [2]:
# Define constants:

# Dataset related.
ds_conf = SimpleNamespace(
    name = "lmsys/chatbot_arena_conversations",
    skip_rows = 0,
    rows_to_process = None, # None for all rows.
)

# Model Related. Prices: https://openai.com/api/pricing/
# Prices are stored in USD per single token (list price per 1M tokens / 1e6).
model_conf = SimpleNamespace(
    name = "gpt-4o",
    # Standard (uncached) input rate. The previous value of 1.25 matches the discounted
    # cached-input rate, which under-estimates the cost of fresh prompts.
    # NOTE(review): prices change over time - re-check against the pricing page above.
    in_tok_price = 2.50/1e6,
    out_tok_price = 10.00/1e6,
)
# Cheaper alternative model; swap with the block above to use it.
# model_conf = SimpleNamespace(
#     name = "gpt-4o-mini",
#     in_tok_price = 0.15/1e6,
#     out_tok_price = 0.60/1e6,
# )

# Output file related. Both paths embed the model name, so every model gets its own files.
out_file_name = f"DATA/dataset_arena_{model_conf.name}.jsonl"
cache_file_name = f"DATA/dataset_arena_{model_conf.name}.arrow"

# Prompt related.
class LimitedPrompt:
    """Length budget for a prompt, derived from a maximum word count.

    Attributes:
        max_words: Word limit the budget was built from.
        max_ch_soft: Soft limit in characters (6 characters per word).
        max_tokens_hard: Hard limit in LLM tokens (soft limit + 20%, 4 characters per token).
    """

    def __init__(self, max_words):
        # A word is "between 5 and 6.5 characters per word including spaces and punctuation";
        # 6 is used as the conversion factor:
        # https://charactercounter.com/characters-to-words
        soft_char_limit = max_words * 6

        # The hard limit adds a 20% buffer on top of the soft limit and divides by 4 to
        # convert characters into LLM tokens:
        # https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
        hard_token_limit = int(soft_char_limit * 1.2 / 4)

        self.max_words = max_words
        self.max_ch_soft = soft_char_limit
        self.max_tokens_hard = hard_token_limit

# Template for the first LLM pass (the "normal" answer). The f-string fills in the length
# limits immediately; the doubled braces leave a literal {question} placeholder that is
# filled later with str.format().
normal_prompt = LimitedPrompt(120)
normal_prompt.text = f'''Answer the user prompt below "---" line. Never exceed {normal_prompt.max_ch_soft} characters / {normal_prompt.max_words} words.
---
{{question}}
'''

# Template for the second LLM pass, which condenses a normal answer into a brief one.
# NOTE: the placeholder name `normal_asnwer` is misspelled, but the .format() call in
# process_batch_sync passes the same spelling as a keyword, so the two must be renamed together.
brief_prompt = LimitedPrompt(20)
brief_prompt.text = f'''Given the user prompt and a normal answer, generate a brief answer. A brief answer should be as short as possible but still answer the question and give relevant information. Never exceed {brief_prompt.max_ch_soft} characters / {brief_prompt.max_words} words.
Examples between --- lines:
--- Example 1 ---
Your input:
Question: How much is 2+3?
Normal answer: Expression 2+3 is equal to 5.
Your output:
5
--- Example 2 ---
Your input:
Question: What is the color of the sky?
Normal answer: The sky is blue.
Your output:
Blue
--- End of examples

Considering all the above, give a brief answer to the prompt and normal answer below:
Question: {{question}}
Normal answer: {{normal_asnwer}}
Your output: 
'''

# Run related.
# Number of dataset rows passed to process_batch_sync per ds.map() batch.
batch_size = 500
# Running cost estimate in USD; batch_call_llm adds to it after each batch of API calls.
total_price = 0.0

In [None]:
# Load the dataset from HuggingFace.

# The access token is read from the HF_KEY environment variable rather than hardcoded.
huggingface_hub.login(token=os.getenv("HF_KEY"))

ds = datasets.load_dataset(ds_conf.name)['train']

# `None` means "every row after the skipped ones"; resolve it to a concrete count.
if ds_conf.rows_to_process is None:
    ds_conf.rows_to_process = len(ds) - ds_conf.skip_rows

ds_range = range(ds_conf.skip_rows, ds_conf.skip_rows + ds_conf.rows_to_process)
ds = ds.select(ds_range)
print(f"Loaded dataset range: {ds_range}")

# Keep only the id and conversation_a, rename the id column, and expose the content of
# the first message of conversation_a as the 'question' column.
ds = ds.select_columns(['question_id', 'conversation_a'])\
    .rename_column('question_id', 'question-id')\
    .map(lambda example: {'question': example['conversation_a'][0]['content']})

# Print a sample row: row 4 when it exists, otherwise the last loaded row, so that a
# small `rows_to_process` does not raise IndexError.
if len(ds) > 0:
    pprint(ds[min(4, len(ds) - 1)])

In [4]:
# Define the generic batch processing function.

# Allow nested event loops: process_batch_sync calls asyncio.run(), which would otherwise
# fail inside an already-running loop (presumably the notebook kernel's own loop).
nest_asyncio.apply()

async def batch_call_llm(prompts, max_tokens):
    ''' Call the LLM API in a batch using OAIB library.

    Args:
        prompts: Sequence of prompt strings; each is sent as a single user message.
        max_tokens: Completion-token cap passed to every request.

    Returns:
        List of answer strings, one per prompt, in the same order as `prompts`.

    Raises:
        RuntimeError: If the number of results differs from the number of prompts.

    Side effects:
        Adds the estimated price of the batch to the global `total_price`.
    '''

    global total_price

    prompts = list(prompts)
    if not prompts:
        # Nothing to send; skip creating a batch altogether.
        return []

    # Every request is tagged with its position through `metadata` and "idx" is declared as
    # the index. Without an index the results were observed to come back in a different
    # order than the input.
    auto_batch = oaib.Auto(workers=8, index=["idx"])

    for idx, prompt in enumerate(prompts):
        messages = [{"role": "user", "content": prompt}]
        await auto_batch.add("chat.completions.create",
                             metadata={"idx": idx},
                             model=model_conf.name,
                             messages=messages,
                             max_tokens=max_tokens)

    output = await auto_batch.run()

    # Restore the prompt order explicitly instead of relying on the library's behavior.
    # "idx" is expected to be the index; fall back to the column if it was left as one.
    if "idx" in output.columns:
        output = output.sort_values("idx")
    else:
        output = output.sort_index()

    answers, in_tokens, out_tokens = [], 0, 0
    for _, row in output.iterrows():
        answers.append(row.result['choices'][0]['message']['content'])
        in_tokens += row.result['usage']['prompt_tokens']
        out_tokens += row.result['usage']['completion_tokens']

    # Account for the tokens that were actually consumed before any validation error.
    total_price += in_tokens*model_conf.in_tok_price + out_tokens*model_conf.out_tok_price

    # A short result set (e.g. requests that never produced a row) would silently pair
    # answers with the wrong questions downstream, so fail loudly instead.
    if len(answers) != len(prompts):
        raise RuntimeError(
            f"Expected {len(prompts)} LLM results but received {len(answers)}."
        )

    return answers

def process_batch_sync(batch, indices, file_handle):
    ''' Process a batch of questions synchronously.

    Used as a batched `ds.map` callback with `with_indices=True`.

    Args:
        batch: Mapping of column name to list of values; the 'question' and 'question-id'
            columns are read.
        indices: Dataset row indices of the batch (only used for progress output).
        file_handle: Writable text file; one JSON line per row is appended to it.

    Returns:
        Dict with the new columns 'normal-prompt', 'normal-answer', 'brief-prompt' and
        'brief-answer', each a list aligned with the batch rows.
    '''

    print(f"Batch size: {len(indices)}, Start index: {indices[0]}")

    # Run 2-step LLM calls for all questions in the batch.
    questions = batch['question']
    normal_prompts = [normal_prompt.text.format(question=q) for q in questions]
    normal_answers = asyncio.run(batch_call_llm(normal_prompts, normal_prompt.max_tokens_hard))
    # The keyword `normal_asnwer` matches the (misspelled) placeholder in brief_prompt.text.
    brief_prompts = [brief_prompt.text.format(question=q, normal_asnwer=na)
                     for q, na in zip(questions, normal_answers)]
    brief_answers = asyncio.run(batch_call_llm(brief_prompts, brief_prompt.max_tokens_hard))

    # Append the batch to JSONL file.
    rows = zip(batch['question-id'], questions,
               normal_prompts, normal_answers,
               brief_prompts, brief_answers)
    for question_id, question, n_prompt, n_answer, b_prompt, b_answer in rows:
        entry = {
            'question-id': question_id, 'question': question,
            'normal-prompt': n_prompt, 'normal-answer': n_answer,
            'brief-prompt': b_prompt, 'brief-answer': b_answer,
        }
        json.dump(entry, file_handle, ensure_ascii=False)
        file_handle.write('\n')
    # Push the finished batch to disk right away so that results already paid for are not
    # left in the write buffer if the run is interrupted later.
    file_handle.flush()

    # Return new columns.
    return {
        'normal-prompt': normal_prompts,
        'normal-answer': normal_answers,
        'brief-prompt': brief_prompts,
        'brief-answer': brief_answers,
    }

In [None]:
# Generate both normal and brief answers in batches.

# The JSONL output and the Arrow cache share a directory; create it up front because
# open() in append mode does not create missing directories.
os.makedirs(os.path.dirname(out_file_name) or ".", exist_ok=True)

# Entries are dumped with ensure_ascii=False, so the encoding is pinned to UTF-8 instead of
# relying on the platform default (which cannot encode every character on some systems).
with open(out_file_name, 'a', encoding='utf-8') as f:
    ds = ds.map(process_batch_sync, batched=True, batch_size=batch_size, with_indices=True,
                cache_file_name=cache_file_name, fn_kwargs={'file_handle': f})

# Print a sample row: row 4 when it exists, otherwise the last row.
if len(ds) > 0:
    pprint(ds[min(4, len(ds) - 1)])

In [None]:
### DEBUG CELL ###

# Access and print the cache file paths
for cache_file in ds.cache_files:
    print(cache_file['filename'])

# Show a few sample rows. The upper bound is clamped to the dataset size so that a short
# run prints fewer (or no) rows instead of raising IndexError.
for i in range(11, min(15, len(ds))):
    row = ds[i]  # Fetch the row once instead of once per printed field.
    print(f"Q: {row['question']}")
    print(f"NORMAL: {row['normal-answer']}")
    print(f"BRIEF: {row['brief-answer']}")
    print("-------------------------------")

# total_price only covers API calls made since the constants cell last reset it to 0.
print(f"Total price: {total_price:.2f} USD")