In [1]:
from datasets import load_dataset, Dataset
from jinja2 import Template
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

import json

In [2]:
# Run configuration.
NUM_TOPICS = 1024  # number of Wikipedia articles taken as topics; also reused as the map batch size for generation
NUM_GPUS = 4  # passed to vLLM as tensor_parallel_size
BATCH_SIZE = 32  # passed to vLLM as max_num_seqs

In [3]:
# Stream the English Wikipedia dump instead of downloading it in full.
ds = load_dataset(
    'wikimedia/wikipedia',
    name='20231101.en',
    split='train',
    streaming=True,
)

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

In [4]:
# Sanity check: print ten titles from a seeded shuffle of the stream.
preview = ds.shuffle(seed=1998, buffer_size=10_000).take(10)
for row in preview:
    print(row['title'])

Komorica
Glassport Odds
Ciudad Nueva (Hato Rey)
Kamiokite
Roobaka
Wayne Ormond
The Pagans (film)
Alfred A. Gilman
1922 Austin twin tornadoes
Gornji Emovci


In [5]:
# Read the Jinja2 prompt template. The encoding is pinned so the result does
# not depend on the platform's default locale encoding.
with open("topic_persona_conditioned.jinja2", encoding="utf-8") as f:
    template_str = f.read()

In [6]:
# Load the few-shot examples that are rendered into the system prompt.
# The encoding is pinned so non-ASCII text decodes the same on every platform.
with open("topic_persona_content_shots.json", encoding="utf-8") as f:
    content_json = json.load(f)

In [7]:
# Serialise each few-shot persona to pretty-printed JSON for the prompt template.
# The isinstance guard keeps this cell idempotent: without it, re-running the
# cell would json.dumps the already-serialised string a second time.
for c in content_json:
    if not isinstance(c['persona'], str):
        c['persona'] = json.dumps(c['persona'], indent=2)

In [8]:
# Compile the prompt template text into a Jinja2 Template object.
template = Template(template_str)

In [9]:
# Render the system prompt, exposing the few-shot examples to the template as `contents`.
system_prompt = template.render(contents=content_json)

In [10]:
# Show the rendered system prompt for a visual check.
print(system_prompt)

# Instructions

Your goal is to write content about a topic.
The input will contain a persona, and content should appear that it is spoken by that person.

# Output Instructions

Respond with the content in plain text, with no structure.

# Examples

Persona:
{
  "identity": "Gordon James Ramsay, 57-year-old celebrity chef turned global hospitality mogul, embodies the dual nature of culinary artistry and unrelenting business acumen. His explosive temper and exacting standards in professional kitchens contrast sharply with his nurturing approach to amateur cooks and children. A former professional footballer whose career was cut short by injury, he channels the intensity of athletic competition into culinary excellence, driving himself and others toward perpetual improvement with both militaristic discipline and surprising moments of profound empathy.",
  "personalLife": "Living between London and Los Angeles with his wife Tana and five children, Gordon maintains a strict separation bet

In [11]:
# Load the train split of the persona sample dataset.
personas_ds = load_dataset(
    'amang1802/personas_sample_405B',
    split='train',
)

README.md:   0%|          | 0.00/862 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/18.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2002 [00:00<?, ? examples/s]

In [12]:
# Keep only the rows flagged as cluster centroids.
personas_deduped = personas_ds.filter(
    lambda row: row['is_cluster_centroid'] == True  # noqa: E712 - keep the exact equality check
)

Filter:   0%|          | 0/2002 [00:00<?, ? examples/s]

In [13]:
# Random sample of personas with the same row count as the deduplicated set,
# so the two variants are directly comparable.
personas_sampled = (
    personas_ds
    .shuffle(seed=1998)
    .select(range(personas_deduped.num_rows))
)

In [15]:
# Materialise NUM_TOPICS articles from the shuffled stream into an in-memory Dataset.
wiki_ds = Dataset.from_list(
    list(
        ds.shuffle(seed=1998, buffer_size=1_000_000).take(NUM_TOPICS)
    )
)

In [16]:
def persona_cross_product(int_ids, titles, personas):
    """Pair every article in the batch with every persona.

    Written for use as a batched ``Dataset.map`` function. Any batch size is
    supported; for a single-article batch the output is ``personas.num_rows``
    rows, all carrying that article's id and title.

    Args:
        int_ids: Batch (sequence) of article ids.
        titles: Batch (sequence) of article titles, aligned with ``int_ids``.
        personas: Object exposing ``num_rows`` and column access by name for
            the ``'id'`` and ``'persona'`` columns.

    Returns:
        dict with keys ``"id"``, ``"title"``, ``"persona_id"`` and
        ``"persona"``. Each value is a list of
        ``len(int_ids) * personas.num_rows`` entries, grouped by article with
        personas cycling in their original order. An empty batch yields empty
        lists.
    """
    num_personas = personas.num_rows
    # Read each persona column once so it can be tiled per article.
    persona_ids = list(personas['id'])
    persona_values = list(personas['persona'])
    num_articles = len(int_ids)

    return {
        # Repeat each article field once per persona ...
        "id": [article_id for article_id in int_ids for _ in range(num_personas)],
        "title": [title for title in titles for _ in range(num_personas)],
        # ... and tile the persona columns once per article.
        "persona_id": persona_ids * num_articles,
        "persona": persona_values * num_articles,
    }

In [17]:
# Expand each Wikipedia row into one row per sampled persona.
# remove_columns is taken from wiki_ds (the dataset being mapped) rather than
# from the streaming source, so the original columns are dropped reliably.
# batch_size=1 means each call receives a single article.
wiki_personas_sampled_ds = wiki_ds.map(
    lambda i, t: persona_cross_product(i, t, personas_sampled),
    input_columns=['id', 'title'],
    remove_columns=wiki_ds.column_names,
    batched=True,
    batch_size=1,
)

Map:   0%|          | 0/1024 [00:00<?, ? examples/s]

In [18]:
# Expand each Wikipedia row into one row per deduplicated (cluster-centroid) persona.
# remove_columns is taken from wiki_ds (the dataset being mapped) rather than
# from the streaming source, so the original columns are dropped reliably.
# batch_size=1 means each call receives a single article.
wiki_personas_deduped_ds = wiki_ds.map(
    lambda i, t: persona_cross_product(i, t, personas_deduped),
    input_columns=['id', 'title'],
    remove_columns=wiki_ds.column_names,
    batched=True,
    batch_size=1,
)

Map:   0%|          | 0/1024 [00:00<?, ? examples/s]

In [19]:
# Display both expanded datasets to compare their features and row counts.
wiki_personas_sampled_ds, wiki_personas_deduped_ds

(Dataset({
     features: ['id', 'title', 'persona_id', 'persona'],
     num_rows: 25600
 }),
 Dataset({
     features: ['id', 'title', 'persona_id', 'persona'],
     num_rows: 25600
 }))

In [20]:
# Hugging Face model id used for both the tokenizer and the vLLM engine.
model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8"

In [21]:
# Tokenizer for model_id; used below to apply the chat template when building prompts.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [22]:
# Build the vLLM engine, sharded across NUM_GPUS with tensor parallelism.
llm = LLM(
    model=model_id,
    max_model_len=4096,
    tensor_parallel_size=NUM_GPUS,
    gpu_memory_utilization=0.98,
    enable_chunked_prefill=True,
    max_num_batched_tokens=4096,
    max_num_seqs=BATCH_SIZE,
)

INFO 12-24 08:49:59 config.py:478] This model supports multiple tasks: {'score', 'generate', 'reward', 'embed', 'classify'}. Defaulting to 'generate'.
INFO 12-24 08:50:00 config.py:1216] Defaulting to use mp for distributed inference
INFO 12-24 08:50:00 config.py:1364] Chunked prefill is enabled with max_num_batched_tokens=4096.
INFO 12-24 08:50:00 llm_engine.py:249] Initializing an LLM engine (v0.6.5) with config: model='meta-llama/Llama-3.1-405B-Instruct-FP8', speculative_config=None, tokenizer='meta-llama/Llama-3.1-405B-Instruct-FP8', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fbgemm_fp8, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided

Loading safetensors checkpoint shards:   0% Completed | 0/109 [00:00<?, ?it/s]


[1;36m(VllmWorkerProcess pid=113435)[0;0m INFO 12-24 08:52:24 model_runner.py:1097] Loading model weights took 113.4847 GB
INFO 12-24 08:52:25 model_runner.py:1097] Loading model weights took 113.4847 GB
[1;36m(VllmWorkerProcess pid=113438)[0;0m INFO 12-24 08:52:25 model_runner.py:1097] Loading model weights took 113.4847 GB
[1;36m(VllmWorkerProcess pid=113434)[0;0m INFO 12-24 08:52:25 model_runner.py:1097] Loading model weights took 113.4847 GB
[1;36m(VllmWorkerProcess pid=113435)[0;0m [1;36m(VllmWorkerProcess pid=113434)[0;0m INFO 12-24 08:52:29 worker.py:241] Memory profiling takes 4.09 seconds
INFO 12-24 08:52:29 worker.py:241] Memory profiling takes 4.09 seconds
[1;36m(VllmWorkerProcess pid=113435)[0;0m [1;36m(VllmWorkerProcess pid=113434)[0;0m [1;36m(VllmWorkerProcess pid=113438)[0;0m INFO 12-24 08:52:29 worker.py:241] the current vLLM instance can use total_gpu_memory (139.72GiB) x gpu_memory_utilization (0.98) = 136.92GiB
INFO 12-24 08:52:29 worker.py:241] the c

In [23]:
def generate_content(topics, personas):
    """Generate persona-voiced text for a batch of (topic, persona) pairs.

    Written for use as a batched ``Dataset.map`` function with
    ``input_columns=["title", "persona"]``. Relies on the module-level
    ``system_prompt``, ``tokenizer`` and ``llm``.

    Args:
        topics: Sequence of topic strings, one per row.
        personas: Sequence of personas aligned with ``topics``. Each one is
            serialised with ``json.dumps(..., indent=2)`` before being placed
            in the prompt.
            NOTE(review): if this column already holds JSON *strings*, the
            dump double-encodes them (quoted, with escaped newlines) -
            confirm the dataset schema.

    Returns:
        dict with a single key ``"synthetic_content"``: the stripped text of
        the first completion for each input row, in input order.
    """
    personas_str = [json.dumps(persona, indent=2) for persona in personas]

    # The user turn mirrors the few-shot layout of the system prompt. The
    # header is "Persona:" with a colon, matching "Topic:" and "Content:"
    # below and the "Persona:" header in the rendered examples.
    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "Persona:\n" + persona + "\n\nTopic:\n" + topic + "\n\nContent:"},
        ]
        for topic, persona in zip(topics, personas_str)
    ]
    # Render the chats to plain prompt strings; vLLM tokenises them itself.
    prompts = [
        tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        for chat in messages
    ]

    sampling_params = SamplingParams(temperature=0.25, top_p=0.9, max_tokens=3072)
    outputs = llm.generate(prompts, sampling_params)

    return {"synthetic_content": [output.outputs[0].text.strip() for output in outputs]}

In [None]:
# Generate content for the sampled-persona pairs.
# NUM_TOPICS is reused here as the number of rows handed to each generate_content call.
syn_sampled_ds = wiki_personas_sampled_ds.map(
    generate_content,
    input_columns=["title", "persona"],
    batched=True,
    batch_size=NUM_TOPICS,
)



Map:   0%|          | 0/25600 [00:00<?, ? examples/s]


[Acessed prompts:   0% 0/1024 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[Acessed prompts:   0% 1/1024 [00:31<8:51:00, 31.14s/it, est. speed input: 95.33 toks/s, output: 3.66 toks/s]
[Acessed prompts:   0% 2/1024 [00:33<4:06:16, 14.46s/it, est. speed input: 175.01 toks/s, output: 8.22 toks/s]
[Acessed prompts:   0% 3/1024 [00:35<2:24:12,  8.47s/it, est. speed input: 253.15 toks/s, output: 12.98 toks/s]
[Acessed prompts:   0% 4/1024 [00:38<1:51:55,  6.58s/it, est. speed input: 306.24 toks/s, output: 17.61 toks/s]
[Acessed prompts:   0% 5/1024 [00:40<1:19:04,  4.66s/it, est. speed input: 370.01 toks/s, output: 22.84 toks/s]
[Acessed prompts:   1% 6/1024 [00:42<1:04:21,  3.79s/it, est. speed input: 417.46 toks/s, output: 27.96 toks/s]
[Acessed prompts:   1% 7/1024 [00:43<48:09,  2.84s/it, est. speed input: 480.35 toks/s, output: 33.75 toks/s]  
[Acessed prompts:   1% 8/1024 [00:44<37:14,  2.20s/it, est. speed input: 535.43 toks/s, output: 39.18 toks/s]


In [None]:
# Publish the sampled-persona generations to the Hugging Face Hub.
syn_sampled_ds.push_to_hub('amang1802/wiki_topic_persona_sampled_405B')

In [None]:
# Generate content for the deduplicated-persona pairs.
# NUM_TOPICS is reused here as the number of rows handed to each generate_content call.
syn_deduped_ds = wiki_personas_deduped_ds.map(
    generate_content,
    input_columns=["title", "persona"],
    batched=True,
    batch_size=NUM_TOPICS,
)

In [None]:
# Publish the deduplicated-persona generations to the Hugging Face Hub.
syn_deduped_ds.push_to_hub('amang1802/wiki_topic_persona_deduped_405B')