# MQuAKE-Remastered (CF3k) — GPT-2 Subject-Relation Check

This Colab notebook:
1. Loads `henryzhongsc/MQuAKE-Remastered` (split **CF3k**)
2. Extracts two cloze prompts from `single_hops` and ground-truth answers (incl. aliases)
3. Generates GPT-2 answers for each cloze
4. Computes cosine similarity via sentence-transformers and keeps a row only if **both** clozes pass the threshold
5. Saves two CSVs: detailed rows and passing `case_id`s

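> You can tweak thresholds and generation params in the **Config** cell below.
>
> **Note:** the active cells below differ from the steps listed above. They load `zjunlp/KnowEdit` (`benchmark/wiki_counterfact/train_cf.json`) instead of MQuAKE, and use `EleutherAI/gpt-j-6B` instead of GPT-2. Instead of two `single_hops` clozes, they score each case's main prompt plus its locality prompts. A case passes when the main similarity and the mean locality similarity both reach the threshold.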

In [None]:
from google.colab import drive

# Mount Google Drive so later cells can write results under /content/drive/MyDrive.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
#@title Setup: install dependencies
!pip install -q -U transformers datasets accelerate sentence-transformers


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.6/511.6 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m58.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
#@title Imports and Config
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer, util
import torch, re, pandas as pd


In [None]:
import numpy as np

In [None]:
import os

In [None]:
from transformers import logging

# Only report errors from transformers; its warnings (e.g. generation notices) are hidden.
logging.set_verbosity(logging.ERROR)

In [None]:

# Dataset/model config
# NOTE(review): DATASET_NAME / SPLIT are not read by the active cells; the
# "Load DataSet" cell loads zjunlp/KnowEdit directly.
DATASET_NAME = "henryzhongsc/MQuAKE-Remastered"  #@param {type:"string"}
SPLIT = "CF3k"                                   #@param {type:"string"}
MODEL_NAME = "EleutherAI/gpt-j-6B"               #@param {type:"string"}
# Alternative model used in earlier runs: "gpt2"

# Similarity & generation config
# SIM_THRESHOLD: minimum similarity score for a prompt's answer to count as correct.
SIM_THRESHOLD = 0.7    #@param {type:"number"}
MAX_NEW_TOKENS = 30    #@param {type:"integer"}
TEMPERATURE = 0.2      #@param {type:"number"}
TOP_P = 0.95           #@param {type:"number"}
# NOTE(review): only referenced by the commented-out MQuAKE loops.
REQUIRE_EXACTLY_TWO = False #@param {type:"boolean"}
SEED = 42              #@param {type:"integer"}

# Generation samples (do_sample=True), so seed torch to make the answers reproducible.
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


Using device: cuda


In [None]:
#@title Helper functions

# Fragment boundaries: line breaks/tabs, or sentence-ending punctuation (ASCII and full-width).
_PUNCT_BREAK = re.compile(r"[\n\r\t]|[\.!?，。？！]")

# Characters trimmed from fragment edges: space, double quote, single quote.
_FRAGMENT_TRIM = ' "\''


def clean_first_fragment(text: str) -> str:
    """Return the first non-empty punctuation-delimited fragment of ``text``.

    Written for GPT-2 style outputs, whose continuations tend to contain stray
    quotes and run-on clauses. Non-string input is converted with ``str``. If
    every fragment is empty after trimming, the trimmed whole text is returned.
    """
    if not isinstance(text, str):
        text = str(text)
    for fragment in _PUNCT_BREAK.split(text):
        trimmed = fragment.strip(_FRAGMENT_TRIM)
        if trimmed:
            return trimmed
    return text.strip(_FRAGMENT_TRIM)


def get_first_sentence(text: str) -> str:
    """Return the stripped text before the first '.', '!', '?' or newline.

    Note: abbreviations such as "St. Louis" are cut at their period.
    """
    head = re.split(r'[.!?\n]', text.strip(), maxsplit=1)[0]
    return head.strip()


def generate_answer(model, tokenizer, prompt: str, *, max_new_tokens=None,
                    temperature=None, top_p=None) -> str:
    """Sample a continuation of ``prompt`` and return its first sentence.

    Args:
        model: causal LM exposing ``.device`` and ``.generate`` (Hugging Face style).
        tokenizer: matching tokenizer; its ``pad_token_id`` is forwarded to ``generate``.
        prompt: text to continue; surrounding whitespace is stripped first.
        max_new_tokens, temperature, top_p: optional overrides. When ``None``,
            the notebook-level ``MAX_NEW_TOKENS`` / ``TEMPERATURE`` / ``TOP_P``
            are used.

    Returns:
        The generated text only (prompt excluded), cut at its first sentence
        break and stripped; may be an empty string.
    """
    max_new_tokens = MAX_NEW_TOKENS if max_new_tokens is None else max_new_tokens
    temperature = TEMPERATURE if temperature is None else temperature
    top_p = TOP_P if top_p is None else top_p

    input_text = prompt.strip()
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            # Explicit pad id avoids generate() falling back to EOS with a warning.
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens. Stripping the prompt by string
    # prefix fails whenever decode() does not reproduce the prompt verbatim.
    # The question (ending in "?") would then be returned as the "answer".
    continuation = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return get_first_sentence(continuation)

In [None]:
# #@title Load dataset, GPT-2, and embedding model
# print(f"Loading dataset: {DATASET_NAME} [{SPLIT}] ...")
# ds = load_dataset(DATASET_NAME, split=SPLIT)

# print("Loading GPT-2 ...")
# tok = AutoTokenizer.from_pretrained(MODEL_NAME)
# if tok.pad_token is None:
#     tok.pad_token = tok.eos_token
# mdl = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()

# print("Loading sentence embedding model ...")
# emb = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)
# print("Ready.")


In [None]:
#@title Load GPT-J and embedding model
# The dataset is loaded separately, in the "Load DataSet" cell.
print("Loading GPT-J ...")
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
if tok.pad_token is None:
    # GPT-style tokenizers have no pad token; reuse EOS so padding/generation have one.
    tok.pad_token = tok.eos_token

# float32 GPT-J-6B weights are ~24 GB (see the 24.2G pytorch_model.bin download);
# device_map="auto" lets accelerate decide where to place the weights.
mdl = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    device_map="auto",
).eval()

print("Loading sentence embedding model ...")
# multi-qa-mpnet-base-dot-v1 replaced the earlier all-MiniLM-L6-v2 choice.
emb = SentenceTransformer("sentence-transformers/multi-qa-mpnet-base-dot-v1", device=DEVICE)
print("Ready.")

Loading dataset: henryzhongsc/MQuAKE-Remastered [CF3k] ...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/CF3k-00000-of-00001.parquet:   0%|          | 0.00/1.53M [00:00<?, ?B/s]

data/CF9k-00000-of-00001.parquet:   0%|          | 0.00/5.14M [00:00<?, ?B/s]

data/CF6334-00000-of-00001.parquet:   0%|          | 0.00/5.16M [00:00<?, ?B/s]

data/T-00000-of-00001.parquet:   0%|          | 0.00/595k [00:00<?, ?B/s]

Generating CF3k split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Generating CF9k split:   0%|          | 0/9171 [00:00<?, ? examples/s]

Generating CF6334 split:   0%|          | 0/9171 [00:00<?, ? examples/s]

Generating T split:   0%|          | 0/1864 [00:00<?, ? examples/s]

Loading GPT-J ...


tokenizer_config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/930 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/24.2G [00:00<?, ?B/s]

Loading sentence embedding model ...


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Ready.


In [None]:
def semantic_score(pred, truth):
    """Return the cosine similarity of the embeddings of *pred* and *truth* as a Python float.

    Uses the notebook-global SentenceTransformer ``emb``; each text is encoded on its own.
    """
    vec_pred, vec_truth = (emb.encode(text, convert_to_tensor=True) for text in (pred, truth))
    return float(util.cos_sim(vec_pred, vec_truth))

In [None]:
#@title Load DataSet

from datasets import load_dataset

REPO_ID   = "zjunlp/KnowEdit"
DATA_FILE = "benchmark/wiki_counterfact/train_cf.json"

# Load the wiki_counterfact JSON file from the Hub as a single "train" split.
train_files = {"train": DATA_FILE}
ds = load_dataset(
    REPO_ID,
    data_files=train_files,
    split="train",
)

README.md: 0.00B [00:00, ?B/s]

train_cf.json: 0.00B [00:00, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Inspect one record to see the fields (subject, prompt, ground_truth, portability, locality).
first_record = ds[0]
print(first_record)


{'subject': 'Goursez Vreizh', 'prompt': 'The name of the country which Goursez Vreizh is associated with is', 'target_new': 'Franche-Comté', 'ground_truth': 'France', 'portability': {'Logical_Generalization': None, 'Reasoning': [{'ground_truth': 'Besançon', 'prompt': 'The name of the capital city of the country Goursez Vreizh is associated with is'}], 'Subject_Aliasing': [{'ground_truth': 'Franche-Comté', 'prompt': 'The name of the country which Gorsedd of Brittany is associated with is'}]}, 'locality': {'Forgetfulness': None, 'Relation_Specificity': [{'ground_truth': 'Jean Le Fustec', 'prompt': 'The name of the founder of Goursez Vreizh is'}]}}


In [None]:
#@title Turn into Question
import re

_THE_PREFIX = re.compile(r"^\s*The\s+", re.I)

def is_the_prompt(p: str) -> bool:
    """True when *p* is a string beginning (after optional whitespace) with "The" plus whitespace, case-insensitive."""
    if not isinstance(p, str):
        return False
    return _THE_PREFIX.match(p) is not None

_THE_TO_Q = re.compile(r"^\s*The\s+(.*?)(?:\s+is)?\s*$", re.I)

def to_what_is_question(prompt: str) -> str | None:
    """
      The occupation of Priest Petrus is  →  What is the occupation of Priest Petrus?
      The name of the country ... is      →  What is the name of the country ...?
    """
    if not isinstance(prompt, str):
        return None
    m = _THE_TO_Q.match(prompt.strip())
    if not m:
        return None
    phrase = m.group(1).strip()
    chars = list(phrase)
    for i, c in enumerate(chars):
        if c.isalpha():
            chars[i] = c.lower()
            break
    phrase_lc = "".join(chars)
    return f"What is the {phrase_lc}?"

In [None]:
def extract_locality_items(locality_field):
    """Collect (prompt, ground_truth) pairs from a KnowEdit-style ``locality`` dict.

    Every dict value that is a list is scanned; each dict entry whose "prompt" and
    "ground_truth" are both non-blank strings contributes one pair (values kept
    unstripped). Anything that is not a dict yields an empty list.
    """
    if not isinstance(locality_field, dict):
        return []

    def _usable(value):
        return isinstance(value, str) and bool(value.strip())

    pairs = []
    for group in locality_field.values():
        if not isinstance(group, list):
            continue
        for entry in group:
            if not isinstance(entry, dict):
                continue
            prompt_text = entry.get("prompt", "")
            answer_text = entry.get("ground_truth", "")
            if _usable(prompt_text) and _usable(answer_text):
                pairs.append((prompt_text, answer_text))
    return pairs

In [None]:
#@title Run evaluation
# Accumulators: one result record per scored case, and the ids of cases that passed.
rows, passed_case_ids = [], []

In [None]:
# First 100 rows, for quick experiments (the active evaluation loop uses the full `ds`).
subset = ds.select(range(0, 100))

In [None]:
# def extract_truth(hop):
#   cloze = hop.get("cloze", "")
#   ans = hop.get("answer", "")
#   aliases = hop.get("answer_alias", []) or []
#   return cloze, ans, aliases

In [None]:
ds.num_rows

1427

In [None]:
# from collections import Counter

In [None]:
# lens = [len(x.get("single_hops", []) or []) for x in ds]
# cnt = Counter(lens)
# print("single_hops length counts:", cnt)
# print(">=2 hops:", sum(l >= 2 for l in lens))

In [None]:
# for i, row in enumerate(subset):
#     case_id = row.get("case_id", i)
#     single_hops = row.get("single_hops", None)
#     if not single_hops or not isinstance(single_hops, list):
#         continue
#     if REQUIRE_EXACTLY_TWO and len(single_hops) != 2:
#         continue

#     hop1, hop2 = single_hops[0], single_hops[1]

#     ans1, alias1 = extract_truth(hop1)
#     ans2, alias2 = extract_truth(hop2)
#     cloze1 = hop1.get("cloze", "")
#     cloze2 = hop2.get("cloze", "")

#     try:
#         gen1 = generate_answer(mdl, tok, cloze1)
#         gen2 = generate_answer(mdl, tok, cloze2)
#     except Exception as e:
#         print(f"[warn] generation failed at case_id={case_id}: {e}")
#         continue

#     pool1 = [ans1] + (alias1 or [])
#     pool2 = [ans2] + (alias2 or [])

#     sim1 = float(util.cos_sim(
#         emb.encode([gen1], convert_to_tensor=True),
#         emb.encode(pool1, convert_to_tensor=True)
#     ).max()) if pool1 else 0.0

#     sim2 = float(util.cos_sim(
#         emb.encode([gen2], convert_to_tensor=True),
#         emb.encode(pool2, convert_to_tensor=True)
#     ).max()) if pool2 else 0.0

#     both_pass = (sim1 >= SIM_THRESHOLD) and (sim2 >= SIM_THRESHOLD)
#     if both_pass:
#         passed_case_ids.append(case_id)

#     rows.append({
#         "case_id": case_id,
#         "cloze_1": cloze1,
#         "true_answer_1": ans1,
#         "gen_answer_1": gen1,
#         "sim_1": sim1,
#         "cloze_2": cloze2,
#         "true_answer_2": ans2,
#         "gen_answer_2": gen2,
#         "sim_2": sim2,
#         "both_pass": both_pass,
#     })

#     if (i + 1) % 100 == 0:
#         print(f"Processed {i+1} rows ...")

In [None]:
from statistics import mean

In [None]:
# for i, row in enumerate(ds):
#     case_id = row.get("case_id", i)
#     single_hops = row.get("single_hops", None)
#     if not single_hops or not isinstance(single_hops, list):
#         continue

#     hop_results = []
#     all_pass = True

#     for hop_idx, hop in enumerate(single_hops):
#         cloze, ans, alias = extract_truth(hop)
#         question = hop.get("question", "")

#         try:
#             gen = generate_answer(mdl, tok, question)
#         except Exception as e:
#             print(f"[warn] generation failed at case_id={case_id}, hop={hop_idx}: {e}")
#             all_pass = False
#             break

#         # pool = [ans] + (alias or [])
#         pool = cloze + ans
#         sim = 0.0
#         if pool:
#             sim = float(util.cos_sim(
#                 emb.encode([gen], convert_to_tensor=True),
#                 emb.encode(pool, convert_to_tensor=True)))
#             # ).max())

#         hop_results.append({
#             "question": question,
#             "cloze": cloze,
#             "true_answer": ans,
#             "gen_answer": gen,
#             "sim": sim,
#         })

#         if sim < SIM_THRESHOLD:
#             all_pass = False

#     if all_pass:
#         passed_case_ids.append(case_id)

#     row_dict = {"case_id": case_id, "all_pass": all_pass}
#     for idx, res in enumerate(hop_results, start=1):
#         row_dict.update({
#             f"question_{idx}": res["question"],
#             f"cloze_{idx}": res["cloze"],
#             f"true_answer_{idx}": res["true_answer"],
#             f"gen_answer_{idx}": res["gen_answer"],
#             f"sim_{idx}": res["sim"],
#         })
#     rows.append(row_dict)

#     if (i + 1) % 100 == 0:
#         print(f"Processed {i+1} rows ...")


In [None]:
#@title Score main and locality prompts for every case
# Reset the accumulators so re-running this cell does not append duplicate
# records to results left over from a previous run.
rows = []
passed_case_ids = []


def _score_cloze(cloze: str, truth: str):
    """Ask the model the question form of *cloze* and score its answer.

    The answer is compared via ``semantic_score`` with the reference sentence
    "<cloze> <truth>". Returns ``(reference, generated, similarity)``, or ``None``
    when the cloze does not start with "The" or cannot be turned into a question.
    """
    if not is_the_prompt(cloze):
        return None
    question = to_what_is_question(cloze)
    if not question:
        return None
    reference = f"{cloze.strip()} {truth.strip()}"
    generated = generate_answer(mdl, tok, question)
    return reference, generated, semantic_score(generated, reference)


def _evaluate_case(index: int, row: dict):
    """Score one dataset row; return its result record, or ``None`` if the row is skipped."""
    # The printed example row has no "case_id" field, so the row index is the fallback id.
    case_id = row.get("case_id", index)
    subject = row.get("subject", "")
    prompt = row.get("prompt", "")
    truth = row.get("ground_truth", "")
    if not prompt or not truth:
        return None

    main = _score_cloze(prompt, truth)
    if main is None:
        return None
    true_sentence_main, gen_main, sim_main = main

    loc_sims, loc_prompts, loc_truths, loc_gens = [], [], [], []
    for lp, lt in extract_locality_items(row.get("locality", {})):
        scored = _score_cloze(lp, lt)
        if scored is None:
            # Skip locality prompts that cannot be rewritten as a question
            # (checked on the locality question itself, not the main one).
            continue
        _, g, s = scored
        loc_sims.append(s); loc_prompts.append(lp); loc_truths.append(lt); loc_gens.append(g)

    mean_loc_sim = float(mean(loc_sims)) if loc_sims else float("nan")
    pass_main = sim_main >= SIM_THRESHOLD
    # A case with no usable locality prompts passes the locality check by default.
    pass_loc = (mean_loc_sim >= SIM_THRESHOLD) if loc_sims else True

    # NOTE: the "*_dot" column names are historical; values come from semantic_score
    # (cosine similarity). Column names are kept unchanged for CSV compatibility.
    return {
        "case_id": case_id,
        "subject": subject,
        "prompt": prompt,
        "ground_truth": truth,
        "true_sentence_main": true_sentence_main,
        "gen_main": gen_main,
        "sim_main_dot": sim_main,
        "num_locality": len(loc_sims),
        "mean_loc_sim_cos": mean_loc_sim,
        "loc_sims_dot": "; ".join(f"{x:.6f}" for x in loc_sims),
        "loc_prompts": " || ".join(loc_prompts),
        "loc_truths": " || ".join(loc_truths),
        "loc_gens": " || ".join(loc_gens),
        "pass_main": pass_main,
        "pass_loc": pass_loc,
        "pass_both": pass_main and pass_loc,
    }


for i, row in enumerate(ds):
    record = _evaluate_case(i, row)
    if record is not None:
        rows.append(record)
        if record["pass_both"]:
            passed_case_ids.append(record["case_id"])
    # Report progress every 100 rows, including rows that were skipped.
    if (i + 1) % 100 == 0:
        print(f"Processed {i+1}/{len(ds)} rows ...")

Processed 100/1427 rows ...
Processed 200/1427 rows ...
Processed 300/1427 rows ...
Processed 400/1427 rows ...
Processed 500/1427 rows ...
Processed 600/1427 rows ...
Processed 700/1427 rows ...
Processed 900/1427 rows ...
Processed 1000/1427 rows ...
Processed 1100/1427 rows ...
Processed 1200/1427 rows ...
Processed 1300/1427 rows ...
Processed 1400/1427 rows ...


In [None]:
# Output folder on the mounted Google Drive.
save_dir = notebook_dir = '/content/drive/MyDrive/DL_FINAL'

In [None]:
"""
df = pd.DataFrame(rows)
df.to_csv("gpt2_subject_check_full.csv", index=False)
pd.DataFrame({"case_id": passed_case_ids}).to_csv("gpt2_subject_check_passed.csv", index=False)

print("\nDone.")
print(f"Saved detailed results -> gpt2_subject_check_full.csv  (rows={len(df)})")
print(f"Saved passed case_ids   -> gpt2_subject_check_passed.csv (count={len(passed_case_ids)})")
"""

In [None]:
# df = pd.DataFrame(rows)

# full_path = os.path.join(save_dir, "gpt2_subject_check_full.csv")
# passed_path = os.path.join(save_dir, "gpt2_subject_check_passed.csv")

# df.to_csv(full_path, index=False)
# pd.DataFrame({"case_id": passed_case_ids}).to_csv(passed_path, index=False)

# print(f"✅ Saved detailed results -> {full_path}")
# print(f"✅ Saved passed case_ids   -> {passed_path}")

In [None]:
# Persist the per-case records and the list of passing case ids to Google Drive.
df = pd.DataFrame(rows)

# to_csv does not create parent directories, so make sure the output folder exists.
os.makedirs(save_dir, exist_ok=True)

full_path = os.path.join(save_dir, "KnowEdit_gptJ_subject_check_full.csv")
passed_path = os.path.join(save_dir, "KnowEdit_gptJ_subject_check_passed.csv")

df.to_csv(full_path, index=False)
pd.DataFrame({"case_id": passed_case_ids}).to_csv(passed_path, index=False)

print(f"✅ Saved detailed results -> {full_path}")
print(f"✅ Saved passed case_ids   -> {passed_path}")

✅ Saved detailed results -> /content/drive/MyDrive/DL_FINAL/KnowEdit_gptJ_subject_check_full.csv
✅ Saved passed case_ids   -> /content/drive/MyDrive/DL_FINAL/KnowEdit_gptJ_subject_check_passed.csv


In [None]:
# Rows where both the main prompt and the locality prompts passed the threshold.
passed = df.loc[df["pass_both"]]
len(passed)

601