In [1]:
from typing import Generator
import utils.llm_utils as llm
import datasets # type: ignore

# Configuration: the Hugging Face dataset to mine and the model that answers.
input_dataset = 'lmsys/chatbot_arena_conversations'
model_name = 'microsoft/Phi-3-mini-4k-instruct'

# Only the 'train' split of the arena conversations is used downstream.
dataset = datasets.load_dataset(input_dataset)['train']
n_rows = len(dataset)
print(n_rows)

tokenizer, model = llm.load_tokenizer_and_model(model_name)

33000


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Smoke test: generate normal answers for the first chunk only.
# Written to its own file so that running this cell can never truncate the
# real output produced by the main generation cell.
test_out_jsonl = "1_dataset_creator/phi3_arena_brief_dataset_test.jsonl"
with open(test_out_jsonl, 'w') as f: pass

# Build the test view under its own name instead of rebinding `dataset`, so
# running this cell does not change what later cells read from `dataset`.
test_dataset = dataset.select_columns(['question_id', 'conversation_a'])
# The first turn of conversation_a is taken as the user question.
test_dataset = test_dataset.map(
        lambda example: {'question': example['conversation_a'][0]['content']})
# NOTE(review): the main generation cell adds a 'prompt' column before calling
# process_variable_chunks; this cell does not — confirm whether the helper
# requires it.

for normal_chunk in llm.process_variable_chunks(test_dataset, tokenizer, model, 100, 1500, 50):
    normal_chunk = normal_chunk.select_columns(['question_id', 'question', 'answer'])
    with open(test_out_jsonl, 'ab') as f:
        normal_chunk.to_json(f, lines=True)
    # A single chunk is enough for a smoke test.
    break

In [3]:
import time

# Word limits for normal and brief answers.
normal_word_limit = 120
brief_word_limit = 20
# A word is "between 5 and 6.5 characters per word including spaces and punctuation":
# https://charactercounter.com/characters-to-words
normal_max_ch_soft = normal_word_limit * 6
brief_max_ch_soft = brief_word_limit * 6
# Hard limit adds 20% buffer and divides by 4 (~4 characters per token) to get
# the LLM token limit.
# https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
normal_max_tokens_hard = int(normal_max_ch_soft * 1.2 / 4)
brief_max_tokens_hard = int(brief_max_ch_soft * 1.2 / 4)

# Instruction prepended to each user question for the first (normal) pass.
normal_prompt = f'''Answer the user prompt below "---" line. Never exceed {normal_max_ch_soft} characters / {normal_word_limit} words.
---
'''

# Few-shot instruction for the second (brief) pass. Kept under a name distinct
# from the brief_prompt() function below: if both share one name, the function
# shadows this text and its f-string interpolates the function object instead.
brief_prompt_template = f'''Given the user prompt and a normal answer, generate a brief answer. A brief answer should be as short as possible but still answer the question and give relevant information. Never exceed {brief_max_ch_soft} characters / {brief_word_limit} words.
Examples between --- lines:
--- Example 1 ---
Your input:
Question: How much is 2+3?
Normal answer: Expression 2+3 is equal to 5.
Your output:
5
--- Example 2 ---
Your input:
Question: What is the color of the sky?
Normal answer: The sky is blue.
Your output:
Blue
--- End of examples

Considering all the above, give a brief answer to the prompt and normal answer below:
'''

def brief_prompt(example: dict) -> str:
    """Build the brief-answer prompt for one dataset row.

    Args:
        example: a row with a 'question' key and an 'Answer-normal' key (the
            normal-pass 'answer' column after renaming).

    Returns:
        brief_prompt_template followed by the question and its normal answer.
    """
    q = example['question']
    na = example['Answer-normal']
    return f"{brief_prompt_template}Question: {q}\nNormal answer: {na}"

out_jsonl = "1_dataset_creator/phi3_arena_brief_dataset.jsonl"
# Truncate (or create) the output file; each brief chunk is appended below.
with open(out_jsonl, 'w') as f: pass

# How many normal-answer chunks to generate before stopping. 1 keeps this a
# quick smoke run; set to None to process the whole dataset.
max_chunks = 1

# First-pass input: the first user turn of conversation_a is the question, and
# the normal-answer instruction is prepended to form the model prompt.
normal_dataset = dataset.select_columns(['question_id', 'conversation_a'])
normal_dataset = normal_dataset.map(
        lambda example: {'question': example['conversation_a'][0]['content']})
normal_dataset = normal_dataset.map(
        lambda example: {'prompt': normal_prompt + example['question']})

start_time = time.time()
current_example = 0
print(f"Example {current_example} of {n_rows}...")
# NOTE(review): 100 and 5000 are positional arguments whose meaning is defined
# in utils.llm_utils.process_variable_chunks; consider naming them.
normal_chunks = llm.process_variable_chunks(
    normal_dataset, tokenizer, model, 100, 5000, normal_max_tokens_hard)
for chunk_index, normal_chunk in enumerate(normal_chunks):
    normal_chunk = normal_chunk.select_columns(['question_id', 'question', 'answer'])
    normal_chunk = normal_chunk.rename_column('answer', 'Answer-normal')
    # Second-pass input: ask for a brief answer given question + normal answer.
    normal_chunk = normal_chunk.map(
        lambda example: {'prompt':  brief_prompt(example)})

    for brief_chunk in llm.process_variable_chunks(normal_chunk, tokenizer, model, 100, 5000, brief_max_tokens_hard):
        brief_chunk = brief_chunk.select_columns(['question_id', 'question', 'answer', 'Answer-normal'])
        brief_chunk = brief_chunk.rename_column('answer', 'Answer-short')
        with open(out_jsonl, 'ab') as f:
            brief_chunk.to_json(f, lines=True)

    current_example += len(normal_chunk)
    last_elapsed_time = time.time() - start_time
    # Elapsed time is cumulative since start_time, so average it over every
    # sample processed so far — dividing by this chunk alone overstates the
    # per-sample cost from the second chunk on.
    if current_example:
        print(f"Time/sample: {last_elapsed_time / current_example:.2f} sec")
    print(f"Processed {current_example} of {n_rows} examples.")

    if max_chunks is not None and chunk_index + 1 >= max_chunks:
        break


Example 0 of 33000...


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/8 [00:00<?, ? examples/s]

Map:   0%|          | 0/19 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Map:   0%|          | 0/69 [00:00<?, ? examples/s]

Map:   0%|          | 0/19 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Time/sample: 0.53 sec
