In [1]:
from datasets import load_dataset, Dataset
from jinja2 import Template
from transformers import AutoTokenizer
from pydantic import BaseModel, TypeAdapter
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

import json
import random
import traceback

In [28]:
# Engine sizing for the vLLM judge model constructed further down.
BATCH_SIZE = 128  # passed to LLM(max_num_seqs=...)
NUM_GPUS = 4      # passed to LLM(tensor_parallel_size=...)

In [3]:
#persona_content_ds = load_dataset('amang1802/wiki_topic_persona_sampled_405B')['train']

In [4]:
def add_text_to_persona_ds(ds, text_ds=None):
    """Attach the source ``text`` to every row of a persona/content dataset.

    Args:
        ds: dataset with an ``id`` column; each row gets the text whose ``id``
            matches.
        text_ds: dataset with ``id`` and ``text`` columns to look texts up in.
            When omitted, falls back to the notebook-global ``content_ds``
            (the original behaviour). Pass it explicitly so the function does
            not depend on hidden kernel state.

    Returns:
        The dataset returned by ``ds.map`` with a ``text`` column added.

    Raises:
        KeyError: if a row's ``id`` has no entry in ``text_ds``.
    """
    source = content_ds if text_ds is None else text_ds
    # Read the two columns once instead of materialising a row dict per index.
    # On duplicate ids the last occurrence wins, as with the original loop.
    id_to_text = dict(zip(source['id'], source['text']))

    return ds.map(lambda content_id: {"text": id_to_text[content_id]}, input_columns=['id'])

In [5]:
#persona_content_ds = add_text_to_persona_ds(persona_content_ds)

In [6]:
def pick_one_per_persona_ds(ds, rng=None):
    """Keep one randomly chosen (content, persona) pairing per content id.

    For every content ``id`` a persona is drawn uniformly from the personas
    that actually appear with that id in ``ds``. Only rows matching the drawn
    pair are kept, so every content id keeps at least one row. Choosing from
    all personas, as before, could draw a persona that content was never
    paired with and silently drop it.

    Args:
        ds: dataset with ``id`` and ``persona_id`` columns.
        rng: object with a ``choice`` method (e.g. ``random.Random(seed)``);
            defaults to the global ``random`` module.

    Returns:
        The filtered dataset returned by ``ds.filter``.
    """
    rng = random if rng is None else rng

    personas_by_content = {}
    for content_id, persona_id in zip(ds['id'], ds['persona_id']):
        personas_by_content.setdefault(content_id, set()).add(persona_id)

    # Sorting makes the draw reproducible under a fixed seed regardless of set
    # iteration order; a set gives O(1) membership checks in the filter below.
    included_pairs = {
        (content_id, rng.choice(sorted(persona_ids)))
        for content_id, persona_ids in personas_by_content.items()
    }

    return ds.filter(lambda row: (row['id'], row['persona_id']) in included_pairs)

In [7]:
#persona_uniq_ds = pick_one_per_persona_ds(persona_content_ds)

In [8]:
#persona_uniq_ds

In [9]:
#assert persona_uniq_ds.num_rows == len(set(persona_uniq_ds['id']))

In [10]:
# Jinja2 source for the judge's system prompt. The encoding is explicit so the
# prompt does not depend on the platform's default locale encoding
# (assumes the file is saved as UTF-8).
with open("gt_accuracy.jinja2", encoding="utf-8") as f:
    template_str = f.read()

In [11]:
# Few-shot examples to be rendered into the system prompt. JSON exchanged
# between systems is UTF-8, so decode it as such rather than with the locale
# default.
with open("few_shots.json", encoding="utf-8") as f:
    examples_json = json.load(f)

In [12]:
# Pre-serialise each few-shot example's `matches` to pretty-printed JSON so the
# template can insert it into the prompt verbatim. This mutates `examples_json`
# in place, so values that are already strings are skipped. Otherwise re-running
# this cell would JSON-encode them a second time (wrapping them in quotes).
# NOTE(review): assumes `matches` in few_shots.json is never meant to be a raw string.
for example in examples_json:
    if not isinstance(example['matches'], str):
        example['matches'] = json.dumps(example['matches'], indent=2)

In [13]:
# Compile the judge-prompt template read from gt_accuracy.jinja2.
template = Template(template_str)

In [14]:
# Render the judge's system prompt; the template receives the few-shot examples as `examples`.
system_prompt = template.render(examples=examples_json)

In [15]:
# Show the rendered prompt as plain text for a visual sanity check.
print(system_prompt)

# Instructions

You are a fact checker and you're required to compare a pair of texts and find segments that discuss the same facts and judge if they both match on the facts. The goal is to only judge the alignment on facts stated by both. They can state unique facts which we have to ignore. For a pair of texts, output the a list of common segments and if they match or not.

On the inclusion of segments:
- Inspect every sentence in text1 and text2 and include all segments that discuss common facts.
- If one of the text has a segment with no similar segment in the other text, ignore that segment altogether.
- Do not pair segments that have different facts. For example: If one says Roger Federer won the Wimbledon in 2003, and the other says Roger Federer won the French Open in 2009 - they are different facts and shouldn't be paired together.
- Repeating this instruction: Include all segments that discuss common facts.

On the matching sensitivity:
- It's possible that two statements don'

In [16]:
# Hugging Face model id of the judge model loaded with vLLM below.
model_id = "Qwen/Qwen2.5-72B-Instruct"

In [17]:
# One judged pair of segments, in the shape the judge model must emit.
# Documented with comments rather than a docstring: pydantic uses a model's
# docstring as the JSON schema "description", which would alter the schema
# passed to guided decoding.
# Field order is deliberate. It becomes the property order of the generated
# schema, so `rationale` presumably gets generated before the `match` verdict.
class Judgement(BaseModel):
    text1: str  # segment taken from text1 (the ground-truth text in compute_gt_accuracy)
    text2: str  # corresponding segment from text2 (the synthetic text)
    rationale: str  # model's explanation of whether the facts agree
    match: bool  # verdict; compute_gt_accuracy scores the fraction of True values

# The full model response is a JSON array of Judgement objects.
ta = TypeAdapter(list[Judgement])

# JSON schema handed to GuidedDecodingParams(json=...) in compute_gt_accuracy.
json_schema = ta.json_schema()

In [18]:
# vLLM engine for the judge model, sharded over NUM_GPUS with tensor parallelism.
llm = LLM(
    model=model_id,
    tensor_parallel_size=NUM_GPUS,
    max_num_seqs=BATCH_SIZE,
    max_model_len=24576,          # context cap in tokens (logged as max_seq_len)
    gpu_memory_utilization=0.98,  # fraction of each GPU's memory vLLM may reserve
)

INFO 01-06 20:33:42 config.py:478] This model supports multiple tasks: {'score', 'generate', 'embed', 'reward', 'classify'}. Defaulting to 'generate'.
INFO 01-06 20:33:42 config.py:1216] Defaulting to use mp for distributed inference
INFO 01-06 20:33:42 llm_engine.py:249] Initializing an LLM engine (v0.6.5) with config: model='Qwen/Qwen2.5-72B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-72B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=24576, download_dir=None, load_format=auto, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=Fa

Loading safetensors checkpoint shards:   0% Completed | 0/37 [00:00<?, ?it/s]


[1;36m(VllmWorkerProcess pid=555068)[0;0m INFO 01-06 20:34:36 model_runner.py:1097] Loading model weights took 33.9833 GB
INFO 01-06 20:34:36 model_runner.py:1097] Loading model weights took 33.9833 GB
[1;36m(VllmWorkerProcess pid=555069)[0;0m INFO 01-06 20:34:36 model_runner.py:1097] Loading model weights took 33.9833 GB
[1;36m(VllmWorkerProcess pid=555070)[0;0m INFO 01-06 20:34:36 model_runner.py:1097] Loading model weights took 33.9833 GB
[1;36m(VllmWorkerProcess pid=555068)[0;0m INFO 01-06 20:34:40 worker.py:241] Memory profiling takes 4.23 seconds
[1;36m(VllmWorkerProcess pid=555068)[0;0m INFO 01-06 20:34:40 worker.py:241] the current vLLM instance can use total_gpu_memory (139.72GiB) x gpu_memory_utilization (0.98) = 136.92GiB
[1;36m(VllmWorkerProcess pid=555068)[0;0m INFO 01-06 20:34:40 worker.py:241] model weights take 33.98GiB; non_torch_memory takes 4.14GiB; PyTorch activation peak memory takes 2.92GiB; the rest of the memory reserved for KV Cache is 95.88GiB.
[1

In [19]:
def _parse_judgement(response):
    """Parse one judge response into ``(segments, score)``, or ``None`` if unusable.

    ``segments`` is the decoded JSON list of judgement dicts. ``score`` is the
    fraction of segments whose ``match`` is truthy, or -1.0 when the list is
    empty (nothing to score). ``None`` means the text is not valid JSON (e.g. a
    generation truncated by ``max_tokens``), or is not a list of objects that
    each carry a ``match`` key.
    """
    try:
        segments = json.loads(response)
    except json.JSONDecodeError:
        return None
    if not isinstance(segments, list) or not all(
        isinstance(segment, dict) and 'match' in segment for segment in segments
    ):
        return None
    if not segments:
        return segments, -1.0
    return segments, sum(1 for segment in segments if segment['match']) / len(segments)


def compute_gt_accuracy(gt_texts, synthetic_texts, temperature=0.3, top_p=0.9, max_tokens=1536):
    """Judge factual agreement between ground-truth and synthetic texts with the LLM.

    Written as a batched ``Dataset.map`` callback. It uses the notebook globals
    ``system_prompt``, ``json_schema`` and ``llm``, and sends every pair in the
    batch through a single ``llm.chat`` call.

    Args:
        gt_texts: ground-truth texts, sent to the judge as ``text1``.
        synthetic_texts: generated texts, sent as ``text2``; paired with
            ``gt_texts`` by position.
        temperature, top_p, max_tokens: sampling settings for the judge.

    Returns:
        Dict with two equally long lists:
        ``cpt_judgement``: the parsed list of judgement dicts per pair, or
        ``[]`` when the response could not be used.
        ``cpt_accuracy_score``: the fraction of segments judged a match, or
        -1.0 when there was nothing usable to score.
    """
    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "text1:\n" + text1 + "\n\ntext2:\n" + text2 + "\n\nresponse:"},
        ]
        for text1, text2 in zip(gt_texts, synthetic_texts)
    ]
    if not messages:
        # Nothing to judge; skip the LLM call entirely.
        return {"cpt_judgement": [], "cpt_accuracy_score": []}

    # Constrain generation to the JSON schema of list[Judgement] so responses parse.
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        guided_decoding=GuidedDecodingParams(json=json_schema),
    )
    outputs = llm.chat(messages, sampling_params)

    judgements = []
    scores = []
    num_unparsable = 0
    for output in outputs:
        parsed = _parse_judgement(output.outputs[0].text.strip())
        if parsed is None:
            # Best effort: keep the row but mark it unscored instead of aborting the batch.
            num_unparsable += 1
            parsed = ([], -1.0)
        judgements.append(parsed[0])
        scores.append(parsed[1])

    if num_unparsable:
        print(f"compute_gt_accuracy: {num_unparsable}/{len(outputs)} responses were not valid judgement JSON; scored -1.0")

    return {
        "cpt_judgement": judgements,
        "cpt_accuracy_score": scores,
    }

In [20]:
def get_score(ds, column='cpt_accuracy_score'):
    """Return the mean of the valid scores in ``ds[column]``.

    Negative values mark rows that could not be scored (the -1 sentinel written
    for unjudged rows) and are excluded from the mean.

    Args:
        ds: anything indexable by column name that yields a sequence of numbers,
            e.g. a ``datasets.Dataset``.
        column: name of the score column.

    Returns:
        The mean of the valid scores, or ``float('nan')`` when no row has a
        valid score. NaN avoids a ZeroDivisionError that would abort the
        evaluation loop.
    """
    valid_scores = [score for score in ds[column] if score >= 0]
    if not valid_scores:
        return float('nan')
    return sum(valid_scores) / len(valid_scores)

In [27]:
# Hub datasets to judge: each model size (70B, 8B) crossed with two generation
# variants, `qna` and `cpt_qna_epoch19` (presumably the CPT-trained checkpoint
# at epoch 19).
ds_list = [
    f"amang1802/cpt_gen_content_topic_conditioned_L3.1_{size}_{variant}"
    for size in ("70B", "8B")
    for variant in ("qna", "cpt_qna_epoch19")
]

In [None]:
# Judge each dataset in turn, push the judged columns back to its Hub repo,
# and collect the mean accuracy keyed by dataset id.
scores = {}
for ds in ds_list:
    # Bound as the global `content_ds`, which add_text_to_persona_ds also reads.
    content_ds = load_dataset(ds)['train']
    judged_ds = content_ds.map(
        compute_gt_accuracy,
        input_columns=['text', 'cpt_gen_content'],
        batched=True,
        # One batch spanning the whole split, so one llm.chat call gets every prompt.
        batch_size=content_ds.num_rows,
    )
    judged_ds.push_to_hub(ds)
    scores[ds] = get_score(judged_ds)

Map:   0%|          | 0/5119 [00:00<?, ? examples/s]


[Acessed prompts:   0% 0/5119 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[Acessed prompts:   0% 1/5119 [00:31<44:32:29, 31.33s/it, est. speed input: 122.09 toks/s, output: 2.97 toks/s]
[Acessed prompts:   0% 2/5119 [00:32<19:31:00, 13.73s/it, est. speed input: 315.72 toks/s, output: 6.26 toks/s]
[Acessed prompts:   0% 3/5119 [00:33<11:22:09,  8.00s/it, est. speed input: 483.34 toks/s, output: 9.88 toks/s]
[Acessed prompts:   0% 4/5119 [00:34<7:02:59,  4.96s/it, est. speed input: 609.62 toks/s, output: 13.62 toks/s]
[Acessed prompts:   0% 5/5119 [00:35<4:58:59,  3.51s/it, est. speed input: 703.66 toks/s, output: 17.27 toks/s]
[Acessed prompts:   0% 6/5119 [00:35<3:36:48,  2.54s/it, est. speed input: 852.17 toks/s, output: 21.02 toks/s]
[Acessed prompts:   0% 7/5119 [00:36<2:44:34,  1.93s/it, est. speed input: 948.62 toks/s, output: 24.77 toks/s]
[Acessed prompts:   0% 8/5119 [00:37<2:12:27,  1.56s/it, est. speed input: 1079.28 toks/s, output: 28.59 to

In [None]:
# Mean accuracy per judged dataset, keyed by Hub dataset id.
print(scores)