In [1]:
from datasets import load_dataset, Dataset
from jinja2 import Template
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

import json

In [2]:
# --- run configuration ---------------------------------------------------
NUM_TOPICS = 1_024  # articles taken from the shuffled Wikipedia stream (also the map batch size)
NUM_GPUS = 4        # tensor_parallel_size passed to vLLM
NUM_DUPS = 10       # generations per article, i.e. output rows per topic

In [3]:
# English Wikipedia, 2023-11-01 snapshot. streaming=True yields an IterableDataset,
# so articles are read lazily rather than materialized up front.
ds = load_dataset(
    "wikimedia/wikipedia",
    name="20231101.en",
    split="train",
    streaming=True,
)

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

In [4]:
# Jinja2 source for the system prompt (rendered further down).
# The explicit encoding keeps decoding independent of the platform's locale default.
with open("fulltext_conditioned.jinja2", encoding="utf-8") as f:
    template_str = f.read()

In [5]:
# Few-shot examples, passed to the template as `contents`.
# JSON interchange is UTF-8 (RFC 8259), so decode it as such regardless of locale.
with open("fulltext_content_shots.json", encoding="utf-8") as f:
    content_json = json.load(f)

In [6]:
# NOTE(review): a bare Template uses Jinja's default Undefined, so a variable the
# template references but render() does not supply becomes "" instead of raising.
template = Template(source=template_str)

In [7]:
# The few-shot examples are exposed to the template under the name `contents`.
system_prompt = template.render({"contents": content_json})

In [8]:
# print() rather than a bare expression, so the prompt's newlines render instead of a repr.
print(system_prompt, end="\n")

# Instructions

Imagine you're an expert on the topic given by the user. Your goal is to rewrite the original content in your own words, matching it's original length in number of words and preserving all the facts and concepts from the original. Do not summarize any detail, but dive deeper and explain every word and sentence from the original in as much detail as possible.

# Output Instructions

Respond with the content in plain text, with no structure.

# Examples

Topic:
Jaguar Logo Rebrand

Original Content:
**Jaguar Logo Changes Are First Step in the Luxury Brand's Rebirth**

The British automaker is repositioning itself as an ultra-luxury EV brand, and the new direction is accompanied by a fresh set of badges for future Jaguars.

- Jaguar has debuted new logos and graphic designs, preparing to relaunch itself as an exclusive, high-priced luxury EV brand that competes in the league of Rolls-Royce and Bentley.
- A new "leaper" emulates Jaguar hood ornaments of old but with a more 

In [9]:
# Hugging Face model id; the same string is used for the tokenizer (see engine log).
model_id: str = "meta-llama/Llama-3.3-70B-Instruct"

In [10]:
llm = LLM(
    model=model_id,
    tensor_parallel_size=NUM_GPUS,  # weights are sharded across NUM_GPUS workers
    max_model_len=32768,            # prompt + generation budget per request
    gpu_memory_utilization=0.98,    # leaves very little GPU headroom for anything else
)

INFO 12-31 18:07:14 config.py:478] This model supports multiple tasks: {'generate', 'embed', 'reward', 'score', 'classify'}. Defaulting to 'generate'.
INFO 12-31 18:07:14 config.py:1216] Defaulting to use mp for distributed inference
INFO 12-31 18:07:14 llm_engine.py:249] Initializing an LLM engine (v0.6.5) with config: model='meta-llama/Llama-3.3-70B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.3-70B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model

Loading safetensors checkpoint shards:   0% Completed | 0/30 [00:00<?, ?it/s]


[1;36m(VllmWorkerProcess pid=424320)[0;0m INFO 12-31 18:07:23 weight_utils.py:243] Using model weights format ['*.safetensors']
INFO 12-31 18:08:10 model_runner.py:1097] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=424321)[0;0m INFO 12-31 18:08:10 model_runner.py:1097] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=424319)[0;0m INFO 12-31 18:08:10 model_runner.py:1097] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=424320)[0;0m INFO 12-31 18:08:10 model_runner.py:1097] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=424321)[0;0m INFO 12-31 18:08:15 worker.py:241] Memory profiling takes 4.68 seconds
[1;36m(VllmWorkerProcess pid=424321)[0;0m INFO 12-31 18:08:15 worker.py:241] the current vLLM instance can use total_gpu_memory (139.72GiB) x gpu_memory_utilization (0.98) = 136.92GiB
[1;36m(VllmWorkerProcess pid=424321)[0;0m INFO 12-31 18:08:15 worker.py:241] model weights take 32.89GiB; non_torch

In [25]:
def dup_list(l, num_dups=None):
    """Repeat every element of ``l`` ``num_dups`` times, keeping the copies adjacent.

    Example: ``dup_list(['a', 'b'], 3) == ['a', 'a', 'a', 'b', 'b', 'b']``.
    ``generate_content`` expands the chat messages and every input column with
    this same function. As a result, row i of each expanded list refers to the
    same source article.

    Args:
        l: iterable to expand (any iterable, not only an indexable sequence).
        num_dups: copies per element. Defaults to the notebook-level
            ``NUM_DUPS``, looked up at call time so edits to the config cell
            take effect without redefining this function.

    Returns:
        A new list with ``num_dups`` consecutive references to each element
        (elements are not copied).
    """
    if num_dups is None:
        num_dups = NUM_DUPS
    return [item for item in l for _ in range(num_dups)]

In [57]:
def generate_content(ids, urls, titles, texts):
    """Batched ``map`` function: rewrite each article ``NUM_DUPS`` times with the LLM.

    Receives the ``id``/``url``/``title``/``text`` columns of one batch as
    parallel lists. Each article becomes a two-turn chat: the rendered
    ``system_prompt``, plus a user turn holding the title and the original
    text. Every chat and every input column is expanded with ``dup_list``, so
    each article yields ``NUM_DUPS`` output rows, each from a separately
    sampled generation.

    Returns:
        dict with the duplicated ``id``, ``url``, ``title`` and ``text``
        columns plus ``synthetic_content`` (stripped generation text), one
        entry per generated row.
    """
    def to_chat(title, text):
        # One conversation: shared system prompt + article-specific user turn.
        user_turn = "".join(
            ("Topic:\n", title, "\n\nOriginal Content:\n", text, "\n\nRewritten Content:")
        )
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_turn},
        ]

    chats = [to_chat(title, text) for title, text in zip(titles, texts)]

    # Sampling (temperature 0.9, top_p 0.9) is what lets the NUM_DUPS copies differ.
    # NOTE(review): no per-request seed is set; a fixed seed would presumably make
    # duplicated prompts produce identical text. Confirm before adding one.
    sampling = SamplingParams(temperature=0.9, top_p=0.9, max_tokens=16384)
    responses = llm.chat(dup_list(chats), sampling)

    synthetic = [response.outputs[0].text.strip() for response in responses]
    return {
        "id": dup_list(ids),
        "url": dup_list(urls),
        "title": dup_list(titles),
        "text": dup_list(texts),
        "synthetic_content": synthetic,
    }

In [58]:
# Lazy pipeline: nothing runs until the stream is iterated.
# take(NUM_TOPICS) with batch_size=NUM_TOPICS means generate_content sees a single
# batch, so all NUM_TOPICS * NUM_DUPS prompts go to vLLM in one llm.chat call.
syn_ds_stream = (
    ds.shuffle(seed=1998, buffer_size=1_000_000)
    .take(NUM_TOPICS)
    .map(
        generate_content,
        batched=True,
        batch_size=NUM_TOPICS,
        input_columns=["id", "url", "title", "text"],
    )
)

In [59]:
# Iterating the stream triggers shuffle -> take -> map, i.e. the whole generation run.
syn_ds_list = [row for row in syn_ds_stream]

Processed prompts:   0% 0/10240 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]



Processed prompts: 100% 10240/10240 [2:40:54<00:00,  1.06it/s, est. speed input: 7097.64 toks/s, output: 746.41 toks/s] 


In [60]:
hub_repo_id = "amang1802/synthetic_data_dup10_fulltext_conditioned_L3.3_70B"
syn_ds = Dataset.from_list(syn_ds_list)
# Last expression, so the returned commit info is displayed.
syn_ds.push_to_hub(hub_repo_id)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/11 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/datasets/amang1802/synthetic_data_dup10_fulltext_conditioned_L3.3_70B/commit/626a13413c1b912743069cb46531165dbe2eca5e', commit_message='Upload dataset', commit_description='', oid='626a13413c1b912743069cb46531165dbe2eca5e', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/amang1802/synthetic_data_dup10_fulltext_conditioned_L3.3_70B', endpoint='https://huggingface.co', repo_type='dataset', repo_id='amang1802/synthetic_data_dup10_fulltext_conditioned_L3.3_70B'), pr_revision=None, pr_num=None)

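Adding this sentence to the system prompt changes the length of the generated content drastically:

`Do not summarize any detail, but dive deeper and explain every word and sentence from the original in as much detail as possible.`

On a sample of 512 generations, it leads to a ~64% increase in the average length of the generated content, measured in characters: 2238 chars for `syn_ds` versus 3682 chars for `with_prompt_ds` (cells below). Note that the comparison cells below (In [16]–[22]) were executed before `syn_ds` was rebuilt in In [60], so they measured an earlier `syn_ds`. Re-running them now would measure the 10×-duplicated dataset generated above, whose system prompt (printed earlier) already contains this sentence.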

In [16]:
def avg_length(ds, column="synthetic_content"):
    """Mean character length of the strings in one column of a dataset.

    Args:
        ds: anything indexable by column name that yields strings, e.g. a
            ``datasets.Dataset`` or a plain ``dict`` of lists. (The parameter
            name shadows the global Wikipedia stream ``ds`` inside this function.)
        column: column to measure. Defaults to ``"synthetic_content"``; pass
            ``"text"`` to measure the source articles instead.

    Returns:
        Average ``len()`` of the column's values, as a float.

    Raises:
        ValueError: if the column has no rows.
    """
    lengths = [len(value) for value in ds[column]]
    if not lengths:
        raise ValueError(f"cannot average lengths of empty column {column!r}")
    return sum(lengths) / len(lengths)

In [21]:
# First topic; compared with with_prompt_ds[0]['title'] below to check both runs start from the same article.
syn_ds[0]["title"]

'Ruel Brathwaite'

In [17]:
# NOTE(review): this cell ran as In [17], before In [60] rebuilt `syn_ds`, so the
# reported figure belongs to an earlier `syn_ds`. Under Restart & Run All it would
# measure the 10x-duplicated dataset generated above instead.
avg_length(ds=syn_ds)

2238.46875

In [20]:
# First 512 rows of the comparison dataset, to match the 512-generation sample size.
with_prompt_ds = (
    load_dataset("amang1802/synthetic_data_fulltext_conditioned_L3.3_70B")["train"]
    .select(range(512))
)
with_prompt_ds[0]["title"]

'Ruel Brathwaite'

In [22]:
# Average synthetic_content length (chars) over the 512-row comparison sample.
avg_length(ds=with_prompt_ds)

3682.220703125