In [18]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

import pandas as pd
# Never truncate cell text when displaying frames: conversations and generated
# instructions are long strings we want to read in full.
pd.options.display.max_colwidth = None

In [11]:
# Number of conversations to sample from lmsys-chat-1m (one million).
NUM_SIZE = 1_000_000

In [12]:
user_llm_instr_ds = load_dataset('lmsys/lmsys-chat-1m')['train'].shuffle(seed=42).select(range(NUM_SIZE))

In [4]:
instr_generation_sys_prompt = "Output an instruction or question to which the user provided text is the answer."

In [5]:
def get_chosen_rejected(llm, tokenizer, conv_batch):
    pair_0, pair_1 = zip(*[(conv[0]['content'], conv[1]['content']) for conv in conv_batch])
    user_instrs, assistant_responses = list(pair_0), list(pair_1)
    prompt_messages = [[{"role": "system", "content": instr_generation_sys_prompt},
                       {"role": "user", "content": text + "\n\n" + "Instruction:"}] for text in assistant_responses]
    prompts = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) for messages in prompt_messages]

    outputs = llm.generate(prompts, SamplingParams(temperature=0.25, top_p=0.9, max_tokens=512))

    return {
        "chosen": user_instrs,
        "rejected": [output.outputs[0].text.strip() for output in outputs],
        "user_input": assistant_responses,
        "system_prompt": [instr_generation_sys_prompt] * len(user_instrs)
    }    

In [6]:
model_id = "meta-llama/Llama-3.1-70B-Instruct"

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [8]:
llm = LLM(model=model_id, max_model_len=4096, tensor_parallel_size=4)

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

INFO 11-30 23:50:29 config.py:350] This model supports multiple tasks: {'generate', 'embedding'}. Defaulting to 'generate'.
INFO 11-30 23:50:29 config.py:1020] Defaulting to use mp for distributed inference
INFO 11-30 23:50:29 llm_engine.py:249] Initializing an LLM engine (v0.6.4.post1) with config: model='meta-llama/Llama-3.1-70B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.1-70B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_ti

generation_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]

INFO 11-30 23:50:29 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
INFO 11-30 23:50:30 selector.py:135] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=372)[0;0m [1;36m(VllmWorkerProcess pid=371)[0;0m INFO 11-30 23:50:30 selector.py:135] Using Flash Attention backend.
INFO 11-30 23:50:30 selector.py:135] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=372)[0;0m [1;36m(VllmWorkerProcess pid=371)[0;0m INFO 11-30 23:50:30 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
INFO 11-30 23:50:30 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
[1;36m(VllmWorkerProcess pid=373)[0;0m INFO 11-30 23:50:30 selector.py:135] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=373)[0;0m INFO 11-30 23:50:30 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
INFO 11-30 23:50:32 utils.py:961] Found nccl from library libnccl.so.2
INFO 11-30 23:50:32 pynccl.py:69

model-00008-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00030.safetensors:   0%|          | 0.00/4.58G [00:00<?, ?B/s]

model-00006-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00005-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00003-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00007-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00004-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00009-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00011-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00010-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00012-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00013-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00014-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00016-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00015-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00017-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00019-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00018-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00020-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00021-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00022-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00023-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00024-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00025-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00026-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00027-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00030-of-00030.safetensors:   0%|          | 0.00/2.10G [00:00<?, ?B/s]

model-00028-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00029-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/59.6k [00:00<?, ?B/s]

Loading safetensors checkpoint shards:   0% Completed | 0/30 [00:00<?, ?it/s]


[1;36m(VllmWorkerProcess pid=372)[0;0m INFO 11-30 23:58:45 model_runner.py:1077] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=371)[0;0m INFO 11-30 23:58:45 model_runner.py:1077] Loading model weights took 32.8892 GB
INFO 11-30 23:58:45 model_runner.py:1077] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=373)[0;0m INFO 11-30 23:58:45 model_runner.py:1077] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=371)[0;0m [1;36m(VllmWorkerProcess pid=373)[0;0m [1;36m(VllmWorkerProcess pid=372)[0;0m INFO 11-30 23:58:47 worker.py:232] Memory profiling results: total_gpu_memory=79.15GiB initial_memory_usage=34.38GiB peak_torch_memory=33.26GiB memory_usage_post_profile=35.45GiB non_torch_memory=2.54GiB kv_cache_size=35.44GiB gpu_memory_utilization=0.90
INFO 11-30 23:58:47 worker.py:232] Memory profiling results: total_gpu_memory=79.15GiB initial_memory_usage=34.24GiB peak_torch_memory=33.26GiB memory_usage_post_profile=35.17GiB 

In [None]:
instr_preference_ds = user_llm_instr_ds.map(lambda batch: get_chosen_rejected(llm, tokenizer, batch),
                                            input_columns=['conversation'],
                                            batched=True,
                                            batch_size=512)

Map:   0%|          | 0/1000000 [00:00<?, ? examples/s]


Processed prompts:   0% 0/512 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][A
Processed prompts:   0% 1/512 [00:11<1:38:07, 11.52s/it, est. speed input: 5.90 toks/s, output: 0.52 toks/s][A
Processed prompts:   1% 6/512 [00:11<12:25,  1.47s/it, est. speed input: 38.23 toks/s, output: 3.54 toks/s] [A
Processed prompts:   3% 13/512 [00:12<04:48,  1.73it/s, est. speed input: 84.25 toks/s, output: 8.60 toks/s][A
Processed prompts:   4% 19/512 [00:12<02:50,  2.89it/s, est. speed input: 190.47 toks/s, output: 13.23 toks/s][A
Processed prompts:   5% 26/512 [00:13<01:54,  4.23it/s, est. speed input: 294.44 toks/s, output: 18.53 toks/s][A
Processed prompts:   6% 32/512 [00:13<01:21,  5.91it/s, est. speed input: 364.48 toks/s, output: 23.63 toks/s][A
Processed prompts:   9% 47/512 [00:14<00:50,  9.24it/s, est. speed input: 467.65 toks/s, output: 35.88 toks/s][A
Processed prompts:  12% 60/512 [00:14<00:35, 12.68it/s, est. speed input: 637.10 toks/s, output: 47.12 to

In [19]:
train_test_ds = instr_preference_ds.train_test_split(test_size=0.05, shuffle=True)
train_test_ds.push_to_hub('lmsys_synthetic_instruction_preferences')

Uploading the dataset shards:   0%|          | 0/8 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/119 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/119 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/119 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/119 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/119 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/119 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/119 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/119 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/50 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/amang1802/lmsys_synthetic_instruction_preferences/commit/0104e41873131dd2120a36d1f7ef0452510710bf', commit_message='Upload dataset', commit_description='', oid='0104e41873131dd2120a36d1f7ef0452510710bf', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/amang1802/lmsys_synthetic_instruction_preferences', endpoint='https://huggingface.co', repo_type='dataset', repo_id='amang1802/lmsys_synthetic_instruction_preferences'), pr_revision=None, pr_num=None)