In [1]:
# Imports:

import os
import huggingface_hub
import datasets
from pprint import pprint
from types import SimpleNamespace
import asyncio
import nest_asyncio
import oaib
import json

In [2]:
# Define constants:

# Dataset related.
ds_conf = SimpleNamespace(
    name = "lmsys/chatbot_arena_conversations",
    skip_rows = 27000,
    rows_to_process = None, # None for all rows.
)

# Model Related. Prices: https://openai.com/api/pricing/
model_conf = SimpleNamespace(
    name = "gpt-4o", 
    in_tok_price = 1.25/1e6, 
    out_tok_price = 10.00/1e6,
)
# model_conf = SimpleNamespace(
#     name = "gpt-4o-mini",
#     in_tok_price = 0.15/1e6,
#     out_tok_price = 0.60/1e6,
# )

# Output file related.
out_file_name = f"DATA/dataset_arena_{model_conf.name}.jsonl"
cache_file_name = f"DATA/dataset_arena_{model_conf.name}.arrow"

# Prompt related.
class LimitedPrompt:
    def __init__(self, max_words):
        self.max_words = max_words
        # A word is "between 5 and 6.5 characters per word including spaces and punctuation":
        # https://charactercounter.com/characters-to-words
        self.max_ch_soft = max_words * 6
        # Hard limit adds 20% buffer and divides by 4 to get LLM token limit.
        # https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
        self.max_tokens_hard = int(self.max_ch_soft * 1.2 / 4)

normal_prompt = LimitedPrompt(120)
normal_prompt.text = f'''Answer the user prompt below "---" line. Never exceed {normal_prompt.max_ch_soft} characters / {normal_prompt.max_words} words.
---
{{question}}
'''

brief_prompt = LimitedPrompt(20)
brief_prompt.text = f'''Given the user prompt and a normal answer, generate a brief answer. A brief answer should be as short as possible but still answer the question and give relevant information. Never exceed {brief_prompt.max_ch_soft} characters / {brief_prompt.max_words} words.
Examples between --- lines:
--- Example 1 ---
Your input:
Question: How much is 2+3?
Normal answer: Expression 2+3 is equal to 5.
Your output:
5
--- Example 2 ---
Your input:
Question: What is the color of the sky?
Normal answer: The sky is blue.
Your output:
Blue
--- End of examples

Considering all the above, give a brief answer to the prompt and normal answer below:
Question: {{question}}
Normal answer: {{normal_asnwer}}
Your output: 
'''

# Run related.
batch_size = 500
total_price = 0.0

In [3]:
# Load the dataset from HuggingFace.

huggingface_hub.login(token=os.getenv("HF_KEY"))

ds = datasets.load_dataset(ds_conf.name)['train']

if ds_conf.rows_to_process==None:
    ds_conf.rows_to_process=len(ds)-ds_conf.skip_rows

ds_range = range(ds_conf.skip_rows, ds_conf.skip_rows+ds_conf.rows_to_process)
ds = ds.select(ds_range) 
print(f"Loaded dataset range: {ds_range}")

ds = ds.select_columns(['question_id', 'conversation_a'])\
    .rename_column('question_id', 'question-id')\
    .map(lambda example: {'question': example['conversation_a'][0]['content']})

pprint(ds[4]) # Print a sample row.

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/zel/.cache/huggingface/token
Login successful
Loaded dataset range: range(27000, 33000)
{'conversation_a': [{'content': 'make an original tongue twister with black '
                                'metal aesthetics',
                     'role': 'user'},
                    {'content': '"Faster than a black metal lightning, more '
                                'evil than a demonic riff, stronger than an '
                                'abyssal howl, and more intense than a blast '
                                'beat, the darkness shall consume all in its '
                                'wake, leaving only ash and embers in its '
                                'path!"',
              

In [4]:
# Define the generic batch processing function.

nest_asyncio.apply()

async def batch_call_llm(prompts, max_tokens):
    ''' Call the LLM API in a batch using OAIB library. '''

    global total_price

    # This is very strange. If index is not specified, the API returns the results in a different
    # order. If the index is specified, the results are returned in the same order as the input. So,
    # we add index but never actually use it.
    auto_batch = oaib.Auto(workers=8, index=["idx"])
    # auto_batch = oaib.Auto(workers=8)
    
    for idx, prompt in enumerate(prompts):
        messages=[{"role": "user", "content": prompt}]
        await auto_batch.add("chat.completions.create", 
                             metadata={"idx": idx+1000}, # Add 1000 for debugging purposes.
                             model=model_conf.name, 
                             messages=messages, 
                             max_tokens=max_tokens)                       
    
    output = await auto_batch.run()

    if len(output)!=len(prompts):
        print(f"MISSING ANSWERS ISSUE: len(output)={len(output)}, len(prompts)={len(prompts)}")
    
    answers = [""]*batch_size
    in_tokens, out_tokens = 0, 0
    for _, row in output.iterrows():
        idx = row.name-1000 # ChatGPT: "The Name attribute of a Series in pandas always corresponds to the index value of that row in the DataFrame."
        answers[idx] = row.result['choices'][0]['message']['content']
        in_tokens += row.result['usage']['prompt_tokens']
        out_tokens += row.result['usage']['completion_tokens']
    
    total_price += in_tokens*model_conf.in_tok_price + out_tokens*model_conf.out_tok_price
    
    return answers

def process_batch_sync(batch, indices, file_handle):
    ''' Process a batch of questions synchronously. '''

    print(f"Batch size: {len(indices)}, Start index: {indices[0]}")

    # Run 2-step LLM calls for all questions in the batch.
    retry = 0
    while retry<2:
        questions = batch['question']
        normal_prompts = [normal_prompt.text.format(question=q) for q in questions]
        normal_answers = asyncio.run(batch_call_llm(normal_prompts, normal_prompt.max_tokens_hard))
        brief_prompts = [brief_prompt.text.format(question=q, normal_asnwer=na) 
                        for q, na in zip(questions, normal_answers)]
        brief_answers = asyncio.run(batch_call_llm(brief_prompts, brief_prompt.max_tokens_hard))
        if len(questions)==len(normal_answers)==len(brief_answers)==batch_size:
            break
        else:
            print(f"ANSWERS MISSING, RETRYING: len(questions)={len(questions)}, " + 
                   f"len(normal_answers)={len(normal_answers)}, len(brief_answers)={len(brief_answers)}") 
            retry += 1

    # Append the batch to JSONL file.
    for id, q, np, na, bp, ba in zip(batch['question-id'], questions, 
                                     normal_prompts, normal_answers, 
                                     brief_prompts, brief_answers):
            entry = {
                'question-id': id, 'question': q,
                'normal-prompt': np, 'normal-answer': na,
                'brief-prompt': bp, 'brief-answer': ba,
            }
            json.dump(entry, file_handle, ensure_ascii=False)
            file_handle.write('\n')
    
    # Return new columns.
    return {
        'normal-prompt': normal_prompts,
        'normal-answer': normal_answers,
        'brief-prompt': brief_prompts,
        'brief-answer': brief_answers,
    }

In [5]:
# Generate both normal and brief answers in batches.

with open(out_file_name, 'a') as f:
    ds = ds.map(process_batch_sync, batched=True, batch_size=batch_size, with_indices=True,
                cache_file_name=cache_file_name, fn_kwargs={'file_handle': f})

pprint(ds[4]) # Print a sample row.

Map:   0%|          | 0/6000 [00:00<?, ? examples/s]

Batch size: 500, Start index: 0


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 19.57s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 27.50s.

Batch size: 500, Start index: 500


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 17.69s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 26.60s.

Batch size: 500, Start index: 1000


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 22.97s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 28.18s.

Batch size: 500, Start index: 1500


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 15.25s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 61.40s.

MISSING ANSWERS ISSUE: len(output)=499, len(prompts)=500
endpoint                                chat.completions.create
model                                                    gpt-4o
messages      [{'role': 'user', 'content': 'Given the user p...
max_tokens                                                   36
result        {'id': 'chatcmpl-AgVbYazfCPRj5mSF7dSkdSjC1h5Cs...
Name: 1067, dtype: object
{'choices': [{'finish_reason': 'stop',
              'index': 0,
              'logprobs': None,
              'message': {'audio': None,
                          'content': 'Autobiography: full life story; memoir: '
                                     'specific themes/events, introspective.',
                          'function_call': None,
                          'refusal': None,
                          'role': 'assistant',
                          'tool_calls': None}}],
 'created': 1734694628,
 'id': 'chatcmpl-AgVbYazfCPRj5mSF7dSkdSjC1h5Cs',
 'model': 'gpt-4o-20

  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 27.08s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 24.21s.

Batch size: 500, Start index: 2500


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 18.96s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 26.01s.

Batch size: 500, Start index: 3000


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 19.17s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 26.91s.

Batch size: 500, Start index: 3500


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 63.42s.

MISSING ANSWERS ISSUE: len(output)=499, len(prompts)=500


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 34.08s.

Batch size: 500, Start index: 4000


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 18.53s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 13.24s.

Batch size: 500, Start index: 4500


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 18.21s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 29.97s.

Batch size: 500, Start index: 5000


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 18.53s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 23.79s.

Batch size: 500, Start index: 5500


  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 25.19s.



  0%|          | 0/500 [00:00<?, ?req/s]

RPM:   0%|          | 0/500

TPM:   0%|          | 0/10000


Run took 27.48s.

{'brief-answer': 'Bleak black bats bask beneath bloodred banners, blasting '
                 'black metal before midnight’s mist.',
 'brief-prompt': 'Given the user prompt and a normal answer, generate a brief '
                 'answer. A brief answer should be as short as possible but '
                 'still answer the question and give relevant information. '
                 'Never exceed 120 characters / 20 words.\n'
                 'Examples between --- lines:\n'
                 '--- Example 1 ---\n'
                 'Your input:\n'
                 'Question: How much is 2+3?\n'
                 'Normal answer: Expression 2+3 is equal to 5.\n'
                 'Your output:\n'
                 '5\n'
                 '--- Example 2 ---\n'
                 'Your input:\n'
                 'Question: What is the color of the sky?\n'
                 'Normal answer: The sky is blue.\n'
                 'Your output:\n'
                 'Blue\n'
              

In [4]:
# Load JSONL, filter it, rename columns, and push to HuggingFace.

# Load the JSONL file.
ds = datasets.load_dataset('json', data_files='DATA/dataset_arena_gpt-4o.jsonl')['train']
print(f"Original dataset len: {len(ds)}")

# Filter out duplicates.
def filter_out_duplicates(example, seen_questions = set()):
    ''' Filter out duplicates. '''
    if example['question'] in seen_questions:
        return False
    else:
        seen_questions.add(example['question'])
        return True

ds = ds.filter(filter_out_duplicates)
print(f"Len after removing duplicate questions: {len(ds)}")

# Filter out bad data.
def filter_out_bad_data(example):
    ''' Filter out errors and brief answers that are not much shorter than normal answers. '''
    brief, normal = example['brief-answer'], example['normal-answer']
    # Include only if:
    return len(brief)>=1 and len(brief)<len(normal)/2

ds = ds.filter(filter_out_bad_data)
print(f"Len after removing bad data: {len(ds)}")

# Rename columns to match DPO format:
# https://huggingface.co/docs/trl/dpo_trainer#expected-dataset-format
ds = ds.map(lambda row: {
    # 'question-id': row['question-id'],
    'prompt': row['question'],
    'chosen': row['brief-answer'],
    'rejected': row['normal-answer'],
    })
ds = ds.select_columns(['question-id', 'prompt', 'chosen', 'rejected'])
print(f"Columns after renaming: {ds.column_names}")

# Split the dataset into train and test.
splits = ds.train_test_split(test_size=0.1)
split_ds = datasets.DatasetDict({'train': splits['train'], 'test': splits['test']})
print(f"Train len: {len(split_ds['train'])}, Test len: {len(split_ds['test'])}")

# Upload the dataset to HuggingFace.
split_ds.push_to_hub("ZSvedic/gpt4o-arena-brevity-dpo")

Original dataset len: 33000
Len after removing duplicate questions: 26968
Len after removing bad data: 25490
Columns after renaming: ['question-id', 'prompt', 'chosen', 'rejected']
Train len: 22941, Test len: 2549


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/23 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/526 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/ZSvedic/gpt4o-arena-brevity-dpo/commit/05370ec13657037acda3a24ba66922076e40f367', commit_message='Upload dataset', commit_description='', oid='05370ec13657037acda3a24ba66922076e40f367', pr_url=None, pr_revision=None, pr_num=None)

In [None]:
### DEBUG CELL ###

print(f"Total price: {total_price:.2f} USD")