In [1]:
# import os
# import re
# import json
# import glob
# import spacy
# import pickle
# import random
# import difflib
# import textwrap
# import datetime
# import jsonlines
# import numpy as np
# import pandas as pd
# from utils.shared_configs import LLAMA_MODEL_PATH, ZEHPYR_MODEL_PATH, get_sampling_params, initialize_llm

In [2]:
import ast
import json
import random
import re
from pathlib import Path

import pandas as pd
from utils.prompts_utils import construct_contextual_prompt, parse_contextual_el_output, COT_POOL
from utils.llm_configs import setup_llm, get_sampling_params
from utils.io import save_intermediate_outputs
from utils.EL_eval import evaluate_contextual_linking
from utils.EL_eval import evaluate_contextual_linking

2025-05-17 20:25:21.640989: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-17 20:25:21.650582: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747513521.662773 2852962 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747513521.666317 2852962 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-17 20:25:21.678649: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [3]:
def build_contextual_prompts(df: pd.DataFrame, model: str = "llama"):
    """Create one prompt record per row that still has candidates after pointwise filtering.

    Args:
        df: rows to link; each row's 'candidates_after_pointwise' holds its candidate list.
        model: forwarded to construct_contextual_prompt to pick the prompt format.

    Returns:
        List of (row index label, {1-based label: candidate}, prompt) tuples, in row
        order. Rows whose candidate list is empty/falsy are skipped.
    """
    records = []
    for row_label, row in df.iterrows():
        cands = row['candidates_after_pointwise']
        if cands:
            prompt_text = construct_contextual_prompt(row, model=model)
            # Labels start at 1 so they match the numbering presented to the model.
            records.append((row_label, dict(enumerate(cands, start=1)), prompt_text))
    return records

In [4]:
def run_contextual_inference(prompt_records, llm, sampling_params, model, batch_size=1000):
    """Send prompt records to the LLM in batches and resolve each answer to a candidate.

    Args:
        prompt_records: list of (row index, label_map, prompt) tuples, as produced by
            build_contextual_prompts.
        llm: engine object; ``chat`` is used for model "llama", ``generate`` otherwise.
        sampling_params: passed through unchanged to the engine call.
        model: model family name; selects chat-style vs. raw-text generation.
        batch_size: number of prompts sent per engine call.

    Returns:
        List of (row index, selected candidate) in input order. The candidate is 0 when
        the parsed label is missing or not present in that record's label_map.
    """
    all_outputs = []
    for chunk_start in range(0, len(prompt_records), batch_size):
        batch = prompt_records[chunk_start:chunk_start + batch_size]
        prompts = [rec[2] for rec in batch]
        # Both branches must bind `responses`; with only the llama branch, any other
        # model raised UnboundLocalError.
        if model == "llama":
            responses = llm.chat(messages=prompts, sampling_params=sampling_params)
        else:
            # NOTE(review): assumes construct_contextual_prompt returns plain-text
            # prompts for non-llama models — confirm.
            responses = llm.generate(prompts=prompts, sampling_params=sampling_params)

        for (idx, label_map, _prompt), response in zip(batch, responses):
            text = response.outputs[0].text.strip()
            selected_label = parse_contextual_el_output(text)
            all_outputs.append((idx, label_map.get(selected_label, 0)))
    return all_outputs

In [5]:
def write_linked_entities(df, all_outputs):
    """Store each row's predicted wiki id in a new 'top_linked_entity' column.

    Args:
        df: DataFrame the predictions refer to; it is modified in place and returned.
        all_outputs: iterable of (row index label, candidate) pairs. ``candidate`` is a
            dict with a 'wiki_id' key, or a falsy value when nothing was linked.

    Returns:
        ``df`` with a nullable-Int64 'top_linked_entity' column; 0 marks unlinked rows.
    """
    # The idx values are df index labels (from df.iterrows()), so key the result by
    # df.index. A positional list breaks (IndexError / misaligned NaN column) whenever
    # df does not have a default 0..n-1 RangeIndex.
    linked = pd.Series(0, index=df.index, dtype="Int64")
    for idx, selected_candidate in all_outputs:
        if selected_candidate:
            linked.loc[idx] = int(selected_candidate['wiki_id'])
    df['top_linked_entity'] = linked
    return df

In [6]:
def main(model: str = "llama"):
    """Run contextual entity linking over the pointwise-filtered candidates.

    Reads the pointwise-stage CSV for ``model``, prompts the LLM once per row that
    still has candidates, writes the chosen wiki id to 'top_linked_entity', saves
    the result, prints the metrics from evaluate_contextual_linking and returns them.

    Args:
        model: model family ("llama" or "zephyr"); selects the engine, the prompt
            format and the input CSV.
    """
    llm, _setup_default_params = setup_llm(model=model)
    # setup_llm also returns sampling params; the explicitly tuned params below are the
    # ones meant to be used. Building them before setup_llm let its return value
    # silently overwrite them.
    sampling_params = get_sampling_params(max_tokens=350, temperature=0.6, top_p=0.9, stops=["</s>", "\n}"])

    df = pd.read_csv(f"outputs/pointwise/final/intermediate_results_{model}_500_v2.csv", dtype={'wiki_ID': 'Int64'})
    # The candidate lists were serialized as Python literals in the CSV; NaN -> no candidates.
    df['candidates_after_pointwise'] = df['candidates_after_pointwise'].apply(
        lambda x: ast.literal_eval(x) if pd.notna(x) else []
    )

    prompt_records = build_contextual_prompts(df, model=model)

    all_outputs = run_contextual_inference(prompt_records, llm, sampling_params, model)
    df = write_linked_entities(df, all_outputs)

    save_intermediate_outputs(df, "outputs/contextual/contextual_linked_results.csv")
    metrics = evaluate_contextual_linking(df)
    print(metrics)
    return metrics

In [7]:
# Entry point. In a notebook kernel __name__ is "__main__" as well, so running this
# cell executes the whole contextual-linking pipeline.
if __name__ == "__main__":
    main()

INFO 05-17 20:25:26 __init__.py:207] Automatically detected platform cuda.
INFO 05-17 20:25:36 config.py:549] This model supports multiple tasks: {'generate', 'embed', 'score', 'reward', 'classify'}. Defaulting to 'generate'.
INFO 05-17 20:25:36 config.py:1555] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 05-17 20:25:36 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.3) with config: model='/datasets/ai/llama3/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f', speculative_config=None, tokenizer='/datasets/ai/llama3/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir='/work/pi_wenlongzhao_umass_edu/8/aranade/models', load_format=auto, tensor_parallel_size=1, pipeli

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]


INFO 05-17 20:25:41 model_runner.py:1115] Loading model weights took 14.9888 GB
INFO 05-17 20:25:42 worker.py:267] Memory profiling takes 0.56 seconds
INFO 05-17 20:25:42 worker.py:267] the current vLLM instance can use total_gpu_memory (44.40GiB) x gpu_memory_utilization (0.90) = 39.96GiB
INFO 05-17 20:25:42 worker.py:267] model weights take 14.99GiB; non_torch_memory takes 0.08GiB; PyTorch activation peak memory takes 1.19GiB; the rest of the memory reserved for KV Cache is 23.70GiB.
INFO 05-17 20:25:42 executor_base.py:111] # cuda blocks: 12136, # CPU blocks: 2048
INFO 05-17 20:25:42 executor_base.py:116] Maximum concurrency for 131072 tokens per request: 1.48x
INFO 05-17 20:25:43 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory

Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:16<00:00,  2.11it/s]

INFO 05-17 20:26:00 model_runner.py:1562] Graph capturing finished in 17 secs, took 0.26 GiB
INFO 05-17 20:26:00 llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 18.92 seconds





INFO 05-17 20:26:01 chat_utils.py:332] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.


Processed prompts: 100%|██████████| 465/465 [01:59<00:00,  3.90it/s, est. speed input: 8584.95 toks/s, output: 467.98 toks/s]



1. Cont — No parseable candidate index found

The context is: "The Pu — No parseable candidate index found

1. Context: The text mentions the London  — No parseable candidate index found

1. Context: The context is abou — No parseable candidate index found
"final_decision": "2.Olympic Games",
"reasoning": "The context is about the men's team final in artistic gymnastics at — No parseable candidate index found

1.  **Context:** The text mentio — No parseable candidate index found

1. **Context**: The text is about th — No parseable candidate index found

1. Context: The mention is in a tex — No parseable candidate index found

1. **Context**: The  — No parseable candidate index found

1. Context: The text discusses the IPC, disa — No parseable candidate index found

1. Context: The text mentions that Colo — No parseable candidate index found

1. **Context**: The text mentions the 2012 Summer — No parseable candidate index found

From th — No parseable candidate index found

1. Context:

In [11]:
# Reload the contextual-linking results that main() saved, for inspection.
df1 = pd.read_csv("outputs/contextual/contextual_linked_results.csv")
df1.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 500 entries, 0 to 499
Data columns (total 13 columns):
 #   Column                      Non-Null Count  Dtype 
---  ------                      --------------  ----- 
 0   article_text                500 non-null    object
 1   date                        500 non-null    object
 2   article_title               500 non-null    object
 3   entity_salience             500 non-null    int64 
 4   offsets                     500 non-null    object
 5   wiki_ID                     500 non-null    int64 
 6   entity_title                500 non-null    object
 7   surrounding_context         500 non-null    object
 8   candidates                  500 non-null    object
 9   pre_pt_len_candidates       500 non-null    int64 
 10  candidates_after_pointwise  500 non-null    object
 11  post_pt_len_candidates      500 non-null    int64 
 12  top_linked_entity           500 non-null    int64 
dtypes: int64(5), object(8)
memory usage: 50.9+ KB


In [15]:
# Number of distinct articles (NaN titles counted too, as len(unique()) does).
df1['article_title'].nunique(dropna=False)

123

In [21]:
# Rows where the contextual linker produced a prediction (0 marks "no link").
df1.loc[df1['top_linked_entity'].ne(0)]

Unnamed: 0,article_text,date,article_title,entity_salience,offsets,wiki_ID,entity_title,surrounding_context,candidates,pre_pt_len_candidates,candidates_after_pointwise,post_pt_len_candidates,top_linked_entity
0,"Homebush Bay, New South Wales —Earlier today, ...",2012-07-20,Australian Gliders glide past China women's na...,1,"(50, 60)",4689264,Australia,"Homebush Bay, New South Wales —Earlier today, ...","[{'mentions': 'australia', 'wiki_id': 4689264,...",93,"[{'mentions': 'australia', 'wiki_id': 4689264,...",27,20611325
2,"Homebush Bay, New South Wales —Earlier today, ...",2012-07-20,Australian Gliders glide past China women's na...,1,"(78, 85)",5405,China,"Homebush Bay, New South Wales —Earlier today, ...","[{'mentions': 'china', 'wiki_id': 5405, 'title...",85,"[{'mentions': 'china', 'wiki_id': 5405, 'title...",35,887850
4,"Homebush Bay, New South Wales —Earlier today, ...",2012-07-20,Australian Gliders glide past China women's na...,1,"(14, 29)",21654,New South Wales,"Homebush Bay, ###New South Wales### —Earlier t...","[{'mentions': 'new south wales', 'wiki_id': 21...",76,"[{'mentions': 'new south wales', 'wiki_id': 21...",37,21654
5,Border patrols at Britain's airports may be le...,2012-07-20,UK border officers go on Olympic strike,1,"(584, 597)",419342,David Cameron,"PCS members working in the Home Office, which ...","[{'mentions': 'david cameron', 'wiki_id': 4193...",11,"[{'mentions': 'david cameron', 'wiki_id': 4193...",9,419342
8,"Homebush Bay, New South Wales —Earlier today, ...",2012-07-21,China women's national wheelchair basketball t...,1,"(137, 158)",848348,wheelchair basketball,"Homebush Bay, New South Wales —Earlier today, ...","[{'mentions': 'wheelchair basketball', 'wiki_i...",23,"[{'mentions': 'wheelchair basketball', 'wiki_i...",19,848348
...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,The Philippines said yesterday it will take Ch...,2013-01-23,Philippines seeks United Nations arbitration o...,1,"(782, 787)",106539,ASEAN,Disputes such as those involving the Scarborou...,"[{'mentions': 'asean', 'wiki_id': 28741, 'titl...",8,"[{'mentions': 'asean', 'wiki_id': 28741, 'titl...",6,106539
496,The Philippines said yesterday it will take Ch...,2013-01-23,Philippines seeks United Nations arbitration o...,1,"(771, 777)",25734,Taiwan,Disputes such as those involving the Scarborou...,"[{'mentions': 'taiwan', 'wiki_id': 25734, 'tit...",84,"[{'mentions': 'taiwan', 'wiki_id': 25734, 'tit...",51,25734
497,The Philippines said yesterday it will take Ch...,2013-01-23,Philippines seeks United Nations arbitration o...,1,"(194, 209)",74209,South China Sea,The Philippines said yesterday it will take Ch...,"[{'mentions': 'south china sea', 'wiki_id': 74...",4,"[{'mentions': 'south china sea', 'wiki_id': 74...",3,74209
498,The Philippines said yesterday it will take Ch...,2013-01-23,Philippines seeks United Nations arbitration o...,1,"(4, 15)",23440,Philippines,The ###Philippines### said yesterday it will t...,"[{'mentions': 'philippines', 'wiki_id': 23440,...",90,"[{'mentions': 'philippines', 'wiki_id': 23440,...",25,23440


In [23]:
# Collect the linked entities per article: one record per article with a list of
# {entity_title, top_linked_entity} dicts, written as JSON records.
SALIENT_ENTITIES_PATH = "outputs/article_to_salient_entities.json"

# Selecting the columns before .apply keeps the grouping column out of the frame the
# lambda sees, which avoids pandas' "apply operated on the grouping columns"
# deprecation warning while producing the same per-article lists.
grouped = (
    df1.groupby('article_title')[['entity_title', 'top_linked_entity']]
    .apply(lambda group: group.to_dict(orient='records'))
    .reset_index(name='salient_entities')
)

grouped.to_json(SALIENT_ENTITIES_PATH, orient="records", indent=2)

print(f"Done. Output saved to {SALIENT_ENTITIES_PATH}")

Done. Output saved to article_to_salient_entities.json


  grouped = df1.groupby('article_title').apply(


In [8]:
    # Disambiguation checklist inserted into the user turn of each contextual-EL prompt.
    INSTRUCTION_PROMPT = (
        "1. Context: Look at the surrounding text to understand the topic.\n"
        "2. Categories: Consider the type of the entity (person, organization, location, etc.).\n"
        "3. Modifiers: Pay attention to words or phrases that add details to the mention.\n"
        "4. Co-references: Check other mentions of the same entity in the text.\n"
        "5. Temporal and Geographical Factors: Consider when and where the text was written.\n"
        "6. External Knowledge: Use knowledge from outside the text.\n"
        "Remember, effective entity disambiguation requires understanding the text thoroughly, "
        "having world knowledge, and exercising good judgment.\n"
    )

In [9]:
def contextual_el_prompt(entity, candidates=None, CoT_POOL=None, INSTRUCTION_PROMPT=None,
                         system_prompt=None, rng=None):
    """Build a Llama-3 chat prompt asking the model to pick one candidate for a mention.

    The user turn contains the instruction checklist, one chain-of-thought exemplar
    drawn from the CoT pool, and the target mention with its shuffled candidate list.

    Args:
        entity: dict with 'left_context', 'right_context', 'entity_title' and
            'candidates' keys.
        candidates: candidate dicts, each with 'cand_name' and 'cand_summary'.
            Defaults to ``entity['candidates']``.
        CoT_POOL: exemplar dicts with 'left_context', 'mention', 'right_context',
            'candidates' and 'answer' keys. Defaults to the module-level ``COT_POOL``.
        INSTRUCTION_PROMPT: instruction text. Defaults to the module-level
            ``INSTRUCTION_PROMPT``.
        system_prompt: system message. Defaults to the module-level ``SYSTEM_PROMPT``.
        rng: object providing ``choice``/``sample`` (e.g. ``random.Random(seed)``) for
            reproducible prompts. Defaults to the ``random`` module.

    Returns:
        (prompt, label_map): the prompt string and a dict mapping each 1-based label
        shown in the prompt to the candidate dict it labels.
    """
    # Defaults are resolved at call time so the function can be defined before (or
    # used without) the notebook-level constants.
    module_globals = globals()
    cot_pool = module_globals["COT_POOL"] if CoT_POOL is None else CoT_POOL
    instructions = module_globals["INSTRUCTION_PROMPT"] if INSTRUCTION_PROMPT is None else INSTRUCTION_PROMPT
    sys_prompt = module_globals["SYSTEM_PROMPT"] if system_prompt is None else system_prompt
    cands = entity["candidates"] if candidates is None else candidates
    rng = random if rng is None else rng

    # Chain-of-thought exemplar.
    ex = rng.choice(cot_pool)
    ex_ctx = f"{ex['left_context']} ###{ex['mention']}### {ex['right_context']}"
    ex_block = "\n".join([
        "The following example illustrates the task:",
        f"Mention: {ex['mention']}",
        f"Context: {ex_ctx}",
        f"Candidates: {ex['candidates']}",
        f"Answer: {ex['answer']}",
    ])

    # Shuffle ONCE and use that order for both the prompt text and label_map, so the
    # label the model answers with resolves to the candidate it was shown.
    label_map = {}
    cand_lines = []
    for lbl, cand in enumerate(rng.sample(cands, len(cands)), 1):
        label_map[lbl] = cand
        cand['prompt_label'] = lbl  # NOTE: mutates the caller's candidate dict (kept from original)
        cand_lines.append(f"{lbl}. {cand['cand_name']} – {cand['cand_summary']}")

    tgt_ctx = (entity['left_context'].strip() + ' ###' + entity['entity_title'] + '### '
               + entity['right_context'].strip())
    tgt_block = "\n".join([
        "Now I will give you a new mention, its context, and a list of candidate entities.",
        "The mention is highlighted with '###'.",
        "",
        f"Mention: {entity['entity_title']}",
        f"Context: {tgt_ctx}",
        f"Candidates: {'; '.join(cand_lines)}",
        "",
        "Think step by step.  At the end output exactly one line with the ID",
        "and name of the chosen entity, e.g.  '3.Barack Obama'.",
        "If none fit, output '-1.None'.",
    ])

    # Llama-3 chat layout: each header is followed by a blank line, and the prompt ends
    # with an open assistant header for generation.
    prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{sys_prompt}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{instructions}\n{ex_block}\n\n{tgt_block}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    return prompt, label_map

In [10]:
MAX_PROMPTS = 7500  # upper bound on prompts collected for one generation run

# NOTE(review): `pointwise_sed_outputs` is not defined anywhere in this notebook (this
# cell raised NameError); load the pointwise-stage entities in an earlier cell first.
# Each record is (art_idx, ent_idx, label_map, prompt).
cx_prompt_records = []
count = 0  # entities skipped because they have no candidates
for art_idx, art in enumerate(pointwise_sed_outputs):
    for ent_idx, ent in enumerate(art["entities"]):
        if not ent.get("candidates"):
            count += 1
            continue

        if len(cx_prompt_records) >= MAX_PROMPTS:
            break

        prompt, cmap = contextual_el_prompt(ent, ent.get("candidates"))
        cx_prompt_records.append((art_idx, ent_idx, cmap, prompt))

    if len(cx_prompt_records) >= MAX_PROMPTS:
        break

# Guard the summary: indexing [-1] on an empty list raised IndexError.
if cx_prompt_records:
    print(f"Collected {len(cx_prompt_records)} prompts "
          f"(last = {cx_prompt_records[-1][:2]})")
else:
    print("Collected 0 prompts")

NameError: name 'pointwise_sed_outputs' is not defined

In [None]:
# Number of entities skipped during prompt collection because they had no candidates.
count

In [None]:
# Inspect the first prompt record: (art_idx, ent_idx, label_map, prompt).
cx_prompt_records[0]

In [None]:
# NOTE(review): `llm` and `sampling_params` are only created inside main(); define them
# at module level (e.g. `llm, sampling_params = setup_llm(model="llama")`) before this cell.
# NOTE(review): each prompt already starts with <|begin_of_text|>; verify the tokenizer
# does not prepend a second BOS token.
# Prompt texts in the same order as cx_prompt_records, so outputs[i] pairs with record i.
prompts = [prompt for _art_idx, _ent_idx, _label_map, prompt in cx_prompt_records]
outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)

In [None]:
# Parse each contextual-linking response and write the chosen candidate back onto its
# entity in pointwise_sed_outputs. The prompt asks for a final "<label>.<name>" line
# (e.g. "3.Barack Obama", or "-1.None"), so only the last non-empty line is parsed.
answer_pat = re.compile(r'(-?\d+)\s*\.\s*(.+)')

# zip() would silently drop unmatched records if generation was re-run on a different
# prompt list, so check the pairing explicitly.
if len(outputs) != len(cx_prompt_records):
    raise ValueError(
        f"{len(outputs)} outputs for {len(cx_prompt_records)} prompt records; "
        "re-run generation so responses line up with records"
    )

for (art_idx, ent_idx, label_map, _prompt), output in zip(cx_prompt_records, outputs):
    ent = pointwise_sed_outputs[art_idx]["entities"][ent_idx]

    linker_resp = output.outputs[0].text.strip()
    ent["linker_response"] = linker_resp  # raw text, kept for inspection
    ent["top_linked_entity"] = None       # default: no prediction

    last_line = next((ln.strip() for ln in reversed(linker_resp.splitlines()) if ln.strip()), "")
    m = answer_pat.search(last_line)
    if not m:
        continue  # unparseable response -> leave None

    # label_map only holds the positive labels shown in the prompt, so "-1.None" and
    # out-of-range labels both resolve to None here.
    cand = label_map.get(int(m.group(1)))
    if cand is None:
        continue

    ent["top_linked_entity"] = {
        "cand_name": cand["cand_name"],
        "cand_wiki_id": cand["cand_wiki_id"],
    }

In [None]:
# Persist the entities (now carrying linker_response / top_linked_entity) for later stages.
CONTEXTUAL_EL_OUTPUT_PATH = Path("outputs/pointwise/pointwise_iteration_curr/contextual_el_sed_outputs_v2.json")
# open(..., "w") fails if the directory is missing, so create it first.
CONTEXTUAL_EL_OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
with CONTEXTUAL_EL_OUTPUT_PATH.open("w", encoding="utf-8") as f:
    json.dump(pointwise_sed_outputs, f, indent=2)

In [None]:
def evaluate_linking(data):
    """Print and return entity-linking metrics over every entity of every article.

    Args:
        data: list of article dicts whose "entities" list holds entity dicts with an
            optional gold id under "entity wiki_ID" and an optional prediction dict
            under "top_linked_entity" carrying "cand_wiki_id".

    Returns:
        dict with the counts "total", "ground_truths", "linked" and "correct", plus the
        ratios "accuracy" (correct / all entities), "precision" (correct / entities
        with a prediction) and "recall" (correct / entities with a gold id). Each ratio
        is 0 when its denominator is 0.
    """
    total = 0    # every entity
    linked = 0   # entities with a prediction
    gts = 0      # entities that have a gold wiki_ID
    correct = 0

    for art in data:
        for ent in art["entities"]:
            total += 1
            gt = ent.get("entity wiki_ID") or None  # falsy ids (None, 0, "") count as missing
            pred_id = (ent.get("top_linked_entity") or {}).get("cand_wiki_id")
            if gt:
                gts += 1
            if pred_id:
                linked += 1
            # NOTE(review): plain ==, so an int gold id never matches a str prediction —
            # confirm both sides share a type.
            if gt and pred_id and gt == pred_id:
                correct += 1

    accuracy = correct / total if total else 0
    precision = correct / linked if linked else 0
    recall = correct / gts if gts else 0

    print(f"Entities evaluated : {total}")
    print(f"Ground truths non‑null : {gts}")
    print(f"Predictions made : {linked}")
    print(f"Correct links : {correct}\n")

    print(f"Accuracy : {accuracy:.4%}")
    print(f"Precision @ linked : {precision:.4%}")
    print(f"Recall : {recall:.4%}")

    return {
        "total": total,
        "ground_truths": gts,
        "linked": linked,
        "correct": correct,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
    }

# Report linking metrics over all entities in pointwise_sed_outputs.
evaluate_linking(pointwise_sed_outputs)