In [30]:
from datasets import load_dataset, Dataset
from jinja2 import Template
from transformers import AutoTokenizer
from pydantic import BaseModel, TypeAdapter
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

import json
import random
import traceback

In [3]:
NUM_GPUS = 4
BATCH_SIZE = 256

In [4]:
content_ds = load_dataset("amang1802/wiki_topic_conditioned_405B")['train']

In [5]:
persona_content_ds = load_dataset('amang1802/wiki_topic_persona_sampled_405B')['train']

In [6]:
def add_text_to_persona_ds(ds):
    id_to_text = {}
    for i in range(content_ds.num_rows):
        id_to_text[content_ds[i]['id']] = content_ds[i]['text']

    return ds.map(lambda idx: {"text": id_to_text[idx]}, input_columns=['id'])

In [7]:
persona_content_ds = add_text_to_persona_ds(persona_content_ds)

In [8]:
def pick_one_per_persona_ds(ds):
    uniq_personas = list(set(ds['persona_id']))
    uniq_contents = list(set(ds['id']))
    included_pairs = [(cid, random.choice(uniq_personas)) for cid in uniq_contents]

    return ds.filter(lambda row: (row['id'], row['persona_id']) in included_pairs)

In [9]:
persona_uniq_ds = pick_one_per_persona_ds(persona_content_ds)

Filter:   0%|          | 0/25600 [00:00<?, ? examples/s]

In [10]:
persona_uniq_ds

Dataset({
    features: ['id', 'title', 'persona_id', 'persona', 'synthetic_content', 'text'],
    num_rows: 1024
})

In [11]:
assert persona_uniq_ds.num_rows == len(set(persona_uniq_ds['id']))

In [12]:
with open("gt_accuracy.jinja2") as f:
    template_str = f.read()

In [13]:
with open("few_shots.json") as f:
    examples_json = json.load(f)

In [14]:
for example in examples_json:
    example['matches'] = json.dumps(example['matches'], indent=2)

In [15]:
template = Template(template_str)

In [16]:
system_prompt = template.render(examples=examples_json)

In [17]:
print(system_prompt)

# Instructions

You are a fact checker and you're required to compare a pair of texts and find segments that discuss the same facts and judge if they both match on the facts. The goal is to only judge the alignment on facts stated by both. They can state unique facts which we have to ignore. For a pair of texts, output the a list of common segments and if they match or not.

On the inclusion of segments:
- Inspect every sentence in text1 and text2 and include all segments that discuss common facts.
- If one of the text has a segment with no similar segment in the other text, ignore that segment altogether.
- Repeating this instruction: Include all segments that discuss common facts.

On the matching sensitivity:
- It's possible that two statements don't exactly agree but are close enough. Like if one says the length as 50cm, another 55cm, or age in 50s and the other in the early 60s. Mark that match as true.
- Use a smartly assessed judgement rather than pointing out even the smallest 

In [18]:
model_id = "meta-llama/Llama-3.3-70B-Instruct"

In [21]:
class Judgement(BaseModel):
    text1: str
    text2: str
    rationale: str
    match: bool

ta = TypeAdapter(list[Judgement])


json_schema = ta.json_schema()

In [22]:
llm = LLM(model=model_id, max_model_len=24576, tensor_parallel_size=NUM_GPUS, gpu_memory_utilization=0.98)

INFO 12-25 09:20:28 config.py:478] This model supports multiple tasks: {'embed', 'reward', 'score', 'classify', 'generate'}. Defaulting to 'generate'.
INFO 12-25 09:20:28 config.py:1216] Defaulting to use mp for distributed inference
INFO 12-25 09:20:28 llm_engine.py:249] Initializing an LLM engine (v0.6.5) with config: model='meta-llama/Llama-3.3-70B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.3-70B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=16384, download_dir=None, load_format=auto, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


INFO 12-25 09:20:30 selector.py:120] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=209422)[0;0m INFO 12-25 09:20:30 selector.py:120] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=209422)[0;0m INFO 12-25 09:20:30 multiproc_worker_utils.py:222] Worker ready; awaiting tasks
[1;36m(VllmWorkerProcess pid=209420)[0;0m INFO 12-25 09:20:30 selector.py:120] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=209420)[0;0m INFO 12-25 09:20:30 multiproc_worker_utils.py:222] Worker ready; awaiting tasks
[1;36m(VllmWorkerProcess pid=209421)[0;0m INFO 12-25 09:20:30 selector.py:120] Using Flash Attention backend.
[1;36m(VllmWorkerProcess pid=209421)[0;0m INFO 12-25 09:20:30 multiproc_worker_utils.py:222] Worker ready; awaiting tasks
INFO 12-25 09:20:33 utils.py:922] Found nccl from library libnccl.so.2
INFO 12-25 09:20:33 pynccl.py:69] vLLM is using nccl==2.21.5
[1;36m(VllmWorkerProcess pid=209420)[0;0m [1;36m(VllmWorkerProcess pid=209422)[0;0m INFO

Loading safetensors checkpoint shards:   0% Completed | 0/30 [00:00<?, ?it/s]


[1;36m(VllmWorkerProcess pid=209420)[0;0m INFO 12-25 09:20:35 weight_utils.py:243] Using model weights format ['*.safetensors']
[1;36m(VllmWorkerProcess pid=209422)[0;0m INFO 12-25 09:20:47 model_runner.py:1097] Loading model weights took 32.8892 GB
INFO 12-25 09:20:47 model_runner.py:1097] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=209421)[0;0m INFO 12-25 09:20:47 model_runner.py:1097] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=209420)[0;0m INFO 12-25 09:20:47 model_runner.py:1097] Loading model weights took 32.8892 GB
[1;36m(VllmWorkerProcess pid=209422)[0;0m INFO 12-25 09:20:50 worker.py:241] Memory profiling takes 3.14 seconds
[1;36m(VllmWorkerProcess pid=209422)[0;0m INFO 12-25 09:20:50 worker.py:241] the current vLLM instance can use total_gpu_memory (139.72GiB) x gpu_memory_utilization (0.98) = 136.92GiB
[1;36m(VllmWorkerProcess pid=209422)[0;0m INFO 12-25 09:20:50 worker.py:241] model weights take 32.89GiB; non_torch

In [31]:
def compute_gt_accuracy(gt_texts, synthetic_texts):
    messages = [[{"role": "system", "content": system_prompt},
                {"role": "user", "content": "text1:\n" + text1 + "\n\ntext2:\n" + text2 + "\n\nresponse:" }]
                for text1, text2 in zip(gt_texts, synthetic_texts)]

    guided_decoding_params = GuidedDecodingParams(json=json_schema)
    outputs = llm.chat(messages, SamplingParams(temperature=0.3, top_p=0.9, max_tokens=1536, guided_decoding=guided_decoding_params))

    judgements = []
    scores = []
    for output in outputs:
        response = output.outputs[0].text.strip()
        judgement = []
        score = -1.0
        try:
            judgement = json.loads(response)
            num_matches = sum([1 for j in judgement if j['match']])
            score = num_matches / len(judgement) if len(judgement) > 0 else 0
        except Exception:
            print(traceback.format_exc())

        judgements.append(judgement)
        scores.append(score)
        
    return {
        "judgement": judgements,
        "accuracy_score": scores
    }   

In [32]:
def get_score(ds):
    valid_scores = [score for score in ds['accuracy_score'] if score >= 0]
    return sum(valid_scores) / len(valid_scores)

In [35]:
judged_ds1 = persona_uniq_ds.map(compute_gt_accuracy, input_columns=['text', 'synthetic_content'], batched=True, batch_size=BATCH_SIZE)

Map:   0%|          | 0/1024 [00:00<?, ? examples/s]


[Acessed prompts:   0% 0/256 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[Acessed prompts:   0% 1/256 [01:25<6:05:13, 85.94s/it, est. speed input: 59.82 toks/s, output: 0.54 toks/s]
[Acessed prompts:   1% 2/256 [01:27<2:33:37, 36.29s/it, est. speed input: 121.64 toks/s, output: 1.21 toks/s]
[Acessed prompts:   1% 3/256 [01:28<1:25:14, 20.21s/it, est. speed input: 162.46 toks/s, output: 1.99 toks/s]
[Acessed prompts:   2% 4/256 [01:29<52:28, 12.49s/it, est. speed input: 202.99 toks/s, output: 2.82 toks/s]  
[Acessed prompts:   3% 7/256 [01:29<19:21,  4.66s/it, est. speed input: 334.51 toks/s, output: 5.38 toks/s]
[Acessed prompts:   4% 9/256 [01:30<12:36,  3.06s/it, est. speed input: 416.51 toks/s, output: 7.19 toks/s]
[Acessed prompts:   4% 10/256 [01:30<10:16,  2.51s/it, est. speed input: 456.72 toks/s, output: 8.13 toks/s]
[Acessed prompts:   5% 14/256 [01:30<04:38,  1.15s/it, est. speed input: 622.35 toks/s, output: 12.02 toks/s]
[Acessed prompts:




[Acessed prompts:   0% 1/256 [00:30<2:11:03, 30.84s/it, est. speed input: 581.49 toks/s, output: 0.00 toks/s]




[Acessed prompts:   1% 2/256 [00:41<1:21:14, 19.19s/it, est. speed input: 917.87 toks/s, output: 0.00 toks/s]
[Acessed prompts:   1% 3/256 [01:28<2:13:04, 31.56s/it, est. speed input: 486.79 toks/s, output: 0.74 toks/s]
[Acessed prompts:   2% 4/256 [01:28<1:20:24, 19.15s/it, est. speed input: 530.44 toks/s, output: 1.48 toks/s]
[Acessed prompts:   2% 5/256 [01:28<52:13, 12.48s/it, est. speed input: 569.79 toks/s, output: 2.28 toks/s]  
[Acessed prompts:   2% 6/256 [01:29<34:29,  8.28s/it, est. speed input: 612.09 toks/s, output: 3.10 toks/s]
[Acessed prompts:   3% 7/256 [01:29<23:55,  5.76s/it, est. speed input: 650.59 toks/s, output: 3.92 toks/s]
[Acessed prompts:   3% 8/256 [01:29<16:23,  3.97s/it, est. speed input: 690.91 toks/s, output: 4.76 toks/s]
[Acessed prompts:   4% 9/256 [01:30<11:38,  2.83s/it, est. speed input: 730.61 toks/s, output: 5.62 toks/s]
[Acessed prompts:   4% 10/256 [01:30<08:09,  1.99s/it, est. speed input: 770.23 toks/s, output: 6.50 toks/s]
[Acessed

Traceback (most recent call last):
  File "/tmp/ipykernel_208820/603597235.py", line 18, in compute_gt_accuracy
    judgement = json.loads(response)
  File "/usr/lib/python3.10/json/__init__.py", line 346, in loads
    return _default_decoder.decode(s)
  File "/usr/lib/python3.10/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/usr/lib/python3.10/json/decoder.py", line 355, in raw_decode
    raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)

Traceback (most recent call last):
  File "/tmp/ipykernel_208820/603597235.py", line 18, in compute_gt_accuracy
    judgement = json.loads(response)
  File "/usr/lib/python3.10/json/__init__.py", line 346, in loads
    return _default_decoder.decode(s)
  File "/usr/lib/python3.10/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/usr/lib/python3.10/json/decoder.p




In [36]:
get_score(judged_ds1)

0.3080615807967022

In [37]:
judged_ds1.push_to_hub('amang1802/wiki_topic_persona_405B_uniq_gt_accuracy')

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/datasets/amang1802/wiki_topic_persona_405B_uniq_gt_accuracy/commit/b7df52cd5334ededf1757fc581c4a9d92321502a', commit_message='Upload dataset', commit_description='', oid='b7df52cd5334ededf1757fc581c4a9d92321502a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/amang1802/wiki_topic_persona_405B_uniq_gt_accuracy', endpoint='https://huggingface.co', repo_type='dataset', repo_id='amang1802/wiki_topic_persona_405B_uniq_gt_accuracy'), pr_revision=None, pr_num=None)

In [38]:
judged_ds2 = content_ds.map(compute_gt_accuracy, input_columns=['text', 'synthetic_content'], batched=True, batch_size=BATCH_SIZE)

Map:   0%|          | 0/1024 [00:00<?, ? examples/s]


[Acessed prompts:   0% 0/256 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[Acessed prompts:   0% 1/256 [01:30<6:26:11, 90.87s/it, est. speed input: 43.42 toks/s, output: 0.88 toks/s]
[Acessed prompts:   1% 3/256 [01:31<1:39:41, 23.64s/it, est. speed input: 123.96 toks/s, output: 2.66 toks/s]
[Acessed prompts:   2% 5/256 [01:31<48:32, 11.60s/it, est. speed input: 206.09 toks/s, output: 4.48 toks/s]  
[Acessed prompts:   3% 7/256 [01:31<27:57,  6.74s/it, est. speed input: 307.36 toks/s, output: 6.36 toks/s]
[Acessed prompts:   4% 9/256 [01:32<17:37,  4.28s/it, est. speed input: 388.36 toks/s, output: 8.29 toks/s]
[Acessed prompts:   4% 10/256 [01:32<13:55,  3.40s/it, est. speed input: 433.60 toks/s, output: 9.28 toks/s]
[Acessed prompts:   4% 11/256 [01:32<10:57,  2.68s/it, est. speed input: 472.83 toks/s, output: 10.29 toks/s]
[Acessed prompts:   5% 12/256 [01:32<08:18,  2.04s/it, est. speed input: 519.26 toks/s, output: 11.32 toks/s]
[Acessed prompts:




[Acessed prompts:   0% 1/256 [00:30<2:09:46, 30.53s/it, est. speed input: 603.58 toks/s, output: 0.00 toks/s]




[Acessed prompts:   1% 2/256 [00:42<1:24:12, 19.89s/it, est. speed input: 912.00 toks/s, output: 0.00 toks/s]
[Acessed prompts:   1% 3/256 [01:29<2:15:50, 32.21s/it, est. speed input: 480.74 toks/s, output: 0.73 toks/s]
[Acessed prompts:   2% 4/256 [01:31<1:24:07, 20.03s/it, est. speed input: 515.77 toks/s, output: 1.58 toks/s]
[Acessed prompts:   2% 5/256 [01:32<54:49, 13.11s/it, est. speed input: 552.62 toks/s, output: 2.44 toks/s]  
[Acessed prompts:   3% 7/256 [01:32<27:15,  6.57s/it, est. speed input: 637.21 toks/s, output: 4.23 toks/s]
[Acessed prompts:   3% 8/256 [01:32<20:25,  4.94s/it, est. speed input: 675.15 toks/s, output: 5.17 toks/s]
[Acessed prompts:   4% 10/256 [01:33<11:42,  2.86s/it, est. speed input: 760.73 toks/s, output: 7.09 toks/s]
[Acessed prompts:   4% 11/256 [01:33<09:13,  2.26s/it, est. speed input: 800.10 toks/s, output: 8.07 toks/s]
[Acessed prompts:   5% 12/256 [01:33<06:59,  1.72s/it, est. speed input: 841.05 toks/s, output: 9.07 toks/s]
[Acess

Traceback (most recent call last):
  File "/tmp/ipykernel_208820/603597235.py", line 18, in compute_gt_accuracy
    judgement = json.loads(response)
  File "/usr/lib/python3.10/json/__init__.py", line 346, in loads
    return _default_decoder.decode(s)
  File "/usr/lib/python3.10/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/usr/lib/python3.10/json/decoder.py", line 355, in raw_decode
    raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)

Traceback (most recent call last):
  File "/tmp/ipykernel_208820/603597235.py", line 18, in compute_gt_accuracy
    judgement = json.loads(response)
  File "/usr/lib/python3.10/json/__init__.py", line 346, in loads
    return _default_decoder.decode(s)
  File "/usr/lib/python3.10/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/usr/lib/python3.10/json/decoder.p




In [39]:
get_score(judged_ds2)

0.4178142132108833

In [40]:
judged_ds2.push_to_hub('amang1802/wiki_topic_conditioned_405B')

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/datasets/amang1802/wiki_topic_conditioned_405B/commit/d2a162e269c483d543c5a8b1bbdd7b4e5cf84618', commit_message='Upload dataset', commit_description='', oid='d2a162e269c483d543c5a8b1bbdd7b4e5cf84618', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/amang1802/wiki_topic_conditioned_405B', endpoint='https://huggingface.co', repo_type='dataset', repo_id='amang1802/wiki_topic_conditioned_405B'), pr_revision=None, pr_num=None)