In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

import pandas as pd
# Show full cell text (no truncation) when displaying DataFrames of generated text.
pd.options.display.max_colwidth = None

In [None]:
# Number of conversations sampled from the source dataset.
NUM_SIZE = 10000

In [None]:
# Reproducible random sample (seed 42) of NUM_SIZE rows from the train split.
# NOTE(review): lmsys/lmsys-chat-1m is a gated dataset on the Hub; loading it
# presumably requires an authenticated HF token -- confirm in the run environment.
user_llm_instr_ds = (
    load_dataset('lmsys/lmsys-chat-1m', split='train')
    .shuffle(seed=42)
    .select(range(NUM_SIZE))
)

In [None]:
# System prompt for instruction back-generation: the model is shown a piece of
# text and asked to produce the instruction/question it would answer.
instr_generation_sys_prompt = (
    "Output an instruction or question to which the user provided text is the answer."
)

In [None]:
def get_chosen_rejected(llm, tokenizer, conv_batch, sampling_params=None, system_prompt=None):
    """Build preference pairs by back-generating instructions from answers.

    For each conversation, the content of the first turn is kept as the
    "chosen" instruction. The content of the second turn (the original answer)
    is sent to ``llm`` under ``system_prompt``, asking it to write the
    instruction that text answers; the generated text becomes the "rejected"
    instruction.

    Args:
        llm: vLLM ``LLM`` engine used for generation.
        tokenizer: Hugging Face tokenizer whose chat template matches the
            model loaded in ``llm``.
        conv_batch: list of conversations, each a list of message dicts with a
            ``'content'`` key. Assumes ``conv[0]`` is the user turn,
            ``conv[1]`` the assistant turn, and that every conversation has at
            least two turns -- TODO confirm for all dataset rows (a shorter
            conversation raises IndexError here).
        sampling_params: optional vLLM ``SamplingParams``; defaults to
            ``SamplingParams(temperature=0.25, top_p=0.9, max_tokens=512)``.
        system_prompt: optional system prompt; defaults to the module-level
            ``instr_generation_sys_prompt``.

    Returns:
        dict of equal-length lists, suitable as the result of a batched
        ``Dataset.map`` function:
        ``chosen`` (original first-turn text), ``rejected`` (generated text,
        stripped), ``user_input`` (the answer text shown to the model) and
        ``system_prompt`` (the system prompt, repeated per row).
    """
    if system_prompt is None:
        system_prompt = instr_generation_sys_prompt

    user_instrs = [conv[0]['content'] for conv in conv_batch]
    assistant_responses = [conv[1]['content'] for conv in conv_batch]

    # Empty batch: nothing to generate (the previous zip(*...) unpacking raised
    # ValueError here), so return empty columns.
    if not assistant_responses:
        return {"chosen": [], "rejected": [], "user_input": [], "system_prompt": []}

    if sampling_params is None:
        sampling_params = SamplingParams(temperature=0.25, top_p=0.9, max_tokens=512)

    # Tokenize through the chat template and hand vLLM token ids directly.
    # Passing the rendered template as a *string* lets vLLM re-tokenize it with
    # special tokens, which prepends a second BOS for templates (such as
    # Llama 3's) that already start with one.
    prompts = [
        {
            "prompt_token_ids": tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": text + "\n\n" + "Instruction:"},
                ],
                tokenize=True,
                add_generation_prompt=True,
            )
        }
        for text in assistant_responses
    ]

    # NOTE(review): a long answer plus max_tokens may exceed the engine's
    # max_model_len; depending on the vLLM version such prompts either error
    # out or yield empty text -- consider filtering/truncating long answers.
    outputs = llm.generate(prompts, sampling_params)

    return {
        "chosen": user_instrs,
        "rejected": [output.outputs[0].text.strip() for output in outputs],
        "user_input": assistant_responses,
        "system_prompt": [system_prompt] * len(user_instrs),
    }

In [None]:
# Hugging Face checkpoint used both for the tokenizer and the vLLM engine.
model_id = 'meta-llama/Llama-3.1-8B-Instruct'

In [None]:
# Tokenizer for the same checkpoint; its chat template builds the generation prompts.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_id,
)

In [None]:
# vLLM engine for the checkpoint; max_model_len caps its context length at 4096 tokens.
llm = LLM(
    model=model_id,
    max_model_len=4096,
)

In [None]:
# Smoke test: run the pair-generation pipeline on the first 10 sampled rows only.
# A lambda is used because `map` passes the 'conversation' column positionally,
# while get_chosen_rejected takes llm and tokenizer first.
test_run_ds = (
    user_llm_instr_ds
    .select(range(10))
    .map(
        lambda conversations: get_chosen_rejected(llm, tokenizer, conversations),
        input_columns=['conversation'],
        batched=True,
        batch_size=16,
    )
)

In [None]:
# Inspect original instruction, generated instruction and the answer text side by side.
test_run_ds.to_pandas().loc[:, ['chosen', 'rejected', 'user_input']]