In [None]:
!pip install -U transformers accelerate bitsandbytes sentencepiece

Collecting bitsandbytes
  Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl (59.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.49.0


In [None]:
import json
import re
import itertools
import os
from tqdm import tqdm
import torch

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel,
    BitsAndBytesConfig
)


In [None]:
# =========================================================
# Utility: Embeddings + Similarity (Contriever)
# =========================================================

def mean_pooling(token_embeddings, mask):
    """Average token embeddings over the positions where ``mask`` is non-zero.

    Args:
        token_embeddings: tensor whose dim 1 is the token axis; ``mask`` is
            broadcast against it along a new trailing axis.
        mask: attention mask with the same leading dims as ``token_embeddings``
            minus the last one (non-zero = keep).

    Returns:
        Sum of the kept token vectors divided by the number of kept tokens.
    """
    keep = mask[..., None].bool()
    summed = token_embeddings.masked_fill(~keep, 0.).sum(dim=1)
    counts = mask.sum(dim=1)[..., None]
    return summed / counts

def get_sent_embeddings(sents, contriever, tok, device, BSZ=32):
    """Embed sentences with a Contriever-style encoder using masked mean pooling.

    Args:
        sents: sequence of strings to embed.
        contriever: encoder called as ``contriever(**inputs)``; element 0 of its
            output is pooled with ``mean_pooling``.
        tok: tokenizer producing ``input_ids``/``attention_mask`` tensors.
        device: device the tokenized batch is moved to before encoding.
        BSZ: number of sentences encoded per forward pass.

    Returns:
        A CPU tensor with one row per sentence. For empty ``sents`` an empty
        ``(0, hidden)`` tensor is returned instead of raising, where ``hidden``
        is ``contriever.config.hidden_size`` when available, else 0.
    """
    if len(sents) == 0:
        # torch.vstack([]) raises, which made e.g. CHECK(type_template=[]) or
        # add_edits([]) crash; return an empty matrix instead.
        hidden = getattr(getattr(contriever, "config", None), "hidden_size", 0)
        return torch.empty((0, hidden))

    all_embs = []
    for start in range(0, len(sents), BSZ):
        batch = sents[start:start + BSZ]
        inputs = tok(batch, padding=True, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = contriever(**inputs)
            emb = mean_pooling(outputs[0], inputs["attention_mask"])
        # Keep results on CPU so a large bank does not accumulate on the GPU.
        all_embs.append(emb.cpu())
    return torch.vstack(all_embs)

def retrieve_similarities_cos(query_emb, fact_embs, device="cpu"):
    """Cosine similarity between ``query_emb`` and each row of ``fact_embs``.

    Both tensors are moved to ``device`` first; the similarity is taken along
    dim 1 (the default of ``torch.nn.functional.cosine_similarity``), with the
    usual broadcasting between the two inputs.
    """
    query_on_dev = query_emb.to(device)
    facts_on_dev = fact_embs.to(device)
    return torch.nn.functional.cosine_similarity(query_on_dev, facts_on_dev)

def retrieve_similarities_dot(query_emb, fact_embs, device="cpu"):
    """Dot-product scores of ``query_emb`` against every row of ``fact_embs``.

    Computes ``query_emb @ fact_embs.T`` on ``device``; for a ``(1, d)`` query
    and ``(n, d)`` facts the result is ``(1, n)``.
    """
    lhs = query_emb.to(device)
    rhs = fact_embs.T.to(device)
    return torch.matmul(lhs, rhs)


In [None]:
def greedy_chain_repair(ro, template_bank):
    """Greedily reorder a relation chain so neighbouring relations type-check.

    Walks the chain left to right. Whenever the output types of ``chain[pos]``
    share nothing with the input types of ``chain[pos + 1]``, the first later
    relation whose input types do overlap is swapped into ``pos + 1``. No
    permutations are enumerated, so the result is a single greedy repair.

    Args:
        ro: iterable of relation names.
        template_bank: mapping ``relation -> {"in": [...], "out": [...]}``;
            relations missing from it are treated as having no types.

    Returns:
        A new list with the (possibly) reordered relations.

    NOTE(review): compatibility is checked as out(chain[i]) -> in(chain[i+1]),
    while CHECK.answer_question executes the chain in reverse; confirm which
    direction the extraction prompt produces.
    """
    chain = list(ro)

    def _types_of(relation, side):
        return set(template_bank.get(relation, {}).get(side, []))

    for pos in range(len(chain) - 1):
        produced = _types_of(chain[pos], "out")
        if produced & _types_of(chain[pos + 1], "in"):
            continue
        # Neighbour does not accept our output: pull the first later match forward.
        for later in range(pos + 2, len(chain)):
            if produced & _types_of(chain[later], "in"):
                chain[pos + 1], chain[later] = chain[later], chain[pos + 1]
                break
    return chain

In [None]:
class CHECK:
    """Knowledge-editing question answerer in the style of CHECK.

    A question is answered in three steps:
    (1) the LLM extracts a ``|``-separated relation chain;
    (2) the chain is greedily reordered with per-relation type templates;
    (3) the chain is resolved hop by hop. At each hop the object of a stored
        edit is substituted when ``"subject relation"`` is similar enough, in
        Contriever embedding space, to an edited fact. Otherwise the LLM is
        asked.
    """

    # Upper bound on the number of relations kept from the extracted chain;
    # longer chains are truncated.
    MAX_CHAIN_LEN = 3

    def __init__(
        self,
        model,
        tokenizer,
        refined_model,
        type_prompt="",
        extraction_prompt="",
        subq_prompt="",
        qa_prompt="",
        type_template=None,
        similarity="cos",
        sim_thresh=0.8,
        device="cuda"
    ):
        """Store the LLM and prompts, load Contriever and parse type templates.

        Args:
            model: causal LM; only its ``generate`` method is used.
            tokenizer: tokenizer paired with ``model``.
            refined_model: optional entity linker, stored as ``self.refined``
                (not used by the methods of this class).
            type_prompt: stored; not used by the methods of this class.
            extraction_prompt: instruction prefix for relation-chain extraction.
            subq_prompt: stored; not used by the methods of this class.
            qa_prompt: instruction prefix for single-hop question answering.
            type_template: iterable of ``"REL: in_types, out_types"`` strings,
                each side a space-separated list of type names. Without the
                ``", "`` separator the same types are used for both sides.
                ``None`` (the default) means no templates.
            similarity: ``"cos"`` for cosine similarity; any other value uses
                the raw dot product.
            sim_thresh: a stored edit is used only if its similarity is
                strictly greater than this value.
            device: device for Contriever and for the tokenized LLM inputs.

        Raises:
            ValueError: if a template string lacks the ``": "`` separator.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.refined = refined_model
        self.device = device

        self.type_prompt = type_prompt
        self.extraction_prompt = extraction_prompt
        self.subq_prompt = subq_prompt
        self.qa_prompt = qa_prompt

        self.similarity = similarity
        self.sim_thresh = sim_thresh

        # Edit memory, filled by add_edits().
        self.edit_bank = []
        self.embedding_bank = []
        self.edited_entity_bank = []

        # Contriever encoder used for edit retrieval.
        self.contriever = AutoModel.from_pretrained(
            "facebook/contriever-msmarco"
        ).to(device)
        self.c_tok = AutoTokenizer.from_pretrained(
            "facebook/contriever-msmarco"
        )

        # Relation type templates: relation -> {"in": [...], "out": [...]}.
        self.template_bank = {}
        relations = []

        for raw in (type_template or []):
            raw = raw.strip()

            if ": " not in raw:
                raise ValueError(f"Invalid template format (missing ': '): {raw}")

            rel, io = raw.split(": ", 1)

            if ", " in io:
                ins, outs = io.split(", ", 1)
            else:
                # Fallback: use the same types for input and output.
                ins = outs = io

            relations.append(rel)
            self.template_bank[rel] = {
                "in": ins.split(" "),
                "out": outs.split(" ")
            }

        # Embeddings of the template relation names. Skip the encoder call when
        # there are no templates: stacking an empty batch list raises.
        if relations:
            self.template_embedding = get_sent_embeddings(
                relations, self.contriever, self.c_tok, device
            )
        else:
            self.template_embedding = torch.empty(0)

    # -----------------------------------------------------

    def query_model(self, prompt, max_new_tokens=64, temperature=0.0):
        """Greedily decode a continuation of ``prompt`` and return only the new text.

        Args:
            prompt: raw text prompt (no chat template is applied).
            max_new_tokens: generation budget.
            temperature: kept for backward compatibility and ignored. Decoding
                is always greedy (``do_sample=False``), so transformers ignores
                the value and warns about it.

        Returns:
            The decoded generated tokens with surrounding whitespace stripped.
        """
        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True
        ).to(self.device)

        outputs = self.model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=max_new_tokens
        )

        # Cut the prompt off by token count. Cutting len(prompt) characters from
        # the decoded full sequence breaks whenever decode() does not reproduce
        # the prompt text exactly (truncation, whitespace or special-token
        # normalisation), which leaks prompt text or drops answer text.
        prompt_len = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_len:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # -----------------------------------------------------

    def add_edits(self, edits, edit_sentences):
        """Replace the edit memory.

        Args:
            edits: list of ``(subject, relation, object)`` tuples; the string
                ``"subject relation"`` is embedded for retrieval.
            edit_sentences: strings stored lower-cased in ``edited_entity_bank``.
        """
        self.edit_bank = list(edits)
        self.edited_entity_bank = [sent.lower() for sent in edit_sentences]

        if not self.edit_bank:
            # Nothing to embed; check_sr() treats an empty bank as "no edit".
            self.embedding_bank = []
            return

        subject_relation = [" ".join(e[:2]) for e in self.edit_bank]
        self.embedding_bank = get_sent_embeddings(
            subject_relation, self.contriever, self.c_tok, self.device
        )

    # -----------------------------------------------------

    def check_sr(self, s, r):
        """Return the edited object for ``(s, r)`` if a stored edit matches, else None.

        All stored edits are scored in one batched similarity call. The best
        edit is used only if its score is positive and strictly above
        ``sim_thresh``. On ties, the first (lowest-index) edit wins.
        """
        if len(self.edit_bank) == 0 or len(self.embedding_bank) == 0:
            return None

        query_emb = get_sent_embeddings(
            [f"{s} {r}"], self.contriever, self.c_tok, self.device
        )

        if self.similarity == "cos":
            sims = retrieve_similarities_cos(query_emb, self.embedding_bank, self.device)
        else:
            sims = retrieve_similarities_dot(query_emb, self.embedding_bank, self.device)[0]

        # torch.argmax returns the first maximal index, matching the previous
        # "strictly greater" scan.
        best_idx = int(torch.argmax(sims))
        best_sim = sims[best_idx].item()

        # The previous scan started from best_sim = 0, so non-positive scores
        # never matched; keep that floor in addition to the threshold.
        if best_sim > 0 and best_sim > self.sim_thresh:
            return self.edit_bank[best_idx][-1]
        return None

    # -----------------------------------------------------

    def answer_question(self, question, max_new_tokens=64):
        """Answer ``question`` by resolving an LLM-extracted relation chain.

        Returns:
            The answer string after the last hop, or None if no relations were
            extracted.
        """
        # Extract the relation chain.
        prompt = f"{self.extraction_prompt}\nQuestion: {question}\nRelations:"
        ro_text = self.query_model(prompt, max_new_tokens)
        ro = [r.strip() for r in ro_text.split("|") if r.strip()]

        if not ro:
            return None

        # Hard cap on chain length.
        ro = ro[:self.MAX_CHAIN_LEN]

        # Fast greedy chain repair (no permutations).
        best_chain = greedy_chain_repair(ro, self.template_bank)

        # Execute the chain from its last relation to its first, feeding each
        # hop's answer in as the next hop's subject.
        # NOTE(review): the first hop uses the whole question text as the
        # subject; no entity extraction (type_prompt / refined_model) is done.
        s = question
        for r in reversed(best_chain):
            o = self.check_sr(s, r)
            if o is None:
                q_prompt = f"{self.qa_prompt}\nQuestion: What is the {r} of {s}?\nAnswer:"
                o = self.query_model(q_prompt, max_new_tokens)
            s = o

        return s



In [None]:

def init_qwen(device, model_name="Qwen/Qwen2.5-3B-Instruct"):
    """Load a 4-bit (NF4) quantized causal LM and its tokenizer.

    Args:
        device: accepted for API compatibility but not used; placement is
            delegated to ``device_map="auto"``.
        model_name: model id passed to both ``from_pretrained`` calls. Using one
            value keeps the tokenizer and model from drifting apart.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4"
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_fast=False
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        quantization_config=quant_config
    )

    model.eval()
    return model, tokenizer


In [None]:
def load_mquake(
    path="MQuAKE-CF-3k.json",
    url="https://raw.githubusercontent.com/dominic-simon/CHECK-Knowledge-Editing/refs/heads/main/datasets/MQuAKE-CF-3k.json",
):
    """Load the MQuAKE-CF-3k JSON dataset, downloading it on first use.

    Args:
        path: local file to read; it is downloaded from ``url`` if missing.
        url: source used when ``path`` does not exist.

    Returns:
        The parsed JSON content.

    Raises:
        urllib.error.URLError: if the download fails. Previously a failed
            ``wget`` was ignored and surfaced later as a confusing
            FileNotFoundError.
    """
    import urllib.request

    if not os.path.exists(path):
        # Download to a temporary name and rename atomically, so an interrupted
        # download never leaves a truncated file that breaks later runs.
        tmp_path = path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    with open(path, encoding="utf-8") as f:
        return json.load(f)

def normalize(s):
    """Normalize an answer string for lenient comparison.

    Accents are folded (NFKD decomposition with combining marks dropped), the
    text is lower-cased, every character that is not a letter, digit or
    whitespace is removed, and whitespace runs are collapsed. ASCII input gives
    exactly the same result as before. Non-ASCII letters such as "ã" are kept
    (as "a") instead of being deleted.
    """
    import unicodedata

    s = unicodedata.normalize("NFKD", s)
    s = "".join(ch for ch in s if not unicodedata.combining(ch))
    s = s.lower()
    # \w also matches "_", which the previous ASCII class did not keep.
    s = re.sub(r"[^\w\s]|_", "", s)
    return re.sub(r"\s+", " ", s).strip()


def is_correct(pred, gold):
    """Lenient match: True if either normalized string contains the other.

    Args:
        pred: predicted answer, or None (never correct).
        gold: gold answer string, or a dict whose ``"answer"`` key holds it.

    Returns:
        False when either side normalizes to the empty string. The empty
        string is a substring of everything, so it previously made an empty
        prediction count as correct.
    """
    if pred is None:
        return False
    if isinstance(gold, dict):
        gold = gold.get("answer", "")
    p = normalize(pred)
    g = normalize(gold)
    if not p or not g:
        return False
    return g in p or p in g


In [None]:
if __name__ == "__main__":
    # End-to-end evaluation of CHECK on MQuAKE-CF-3k with a 4-bit Qwen model.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Number of MQuAKE cases evaluated (each case holds several questions).
    N_EVAL_CASES = 500

    model, tokenizer = init_qwen(device)

    checker = CHECK(
        model=model,
        tokenizer=tokenizer,
        refined_model=None,  # optional entity linker
        extraction_prompt="Extract the relation chain separated by |",
        qa_prompt="Answer concisely.",
        # CHECK parses "REL: in_types, out_types". Without the ", " separator
        # the whole string is used for both sides, so every relation would
        # silently accept and produce {person, place}.
        type_template=[
            "P27: person, place",
            "P19: person, place",
            "P36: place, place",
            "P50: thing, person"
        ],
        device=device
    )

    data = load_mquake()

    # Register every requested rewrite of every case. A case may request
    # several edits, and its new answer depends on all of them.
    edits = []
    edit_sents = []
    for case in data:
        for rewrite in case["requested_rewrite"]:
            new_obj = rewrite["target_new"]["str"]
            edits.append((rewrite["subject"].lower(), rewrite["relation_id"], new_obj))
            edit_sents.append(new_obj)

    checker.add_edits(edits, edit_sents)

    correct = 0
    total = 0

    for case in tqdm(data[:N_EVAL_CASES]):
        # Accept the new answer or any alias listed for it (if present).
        golds = [case["new_answer"]] + list(case.get("new_answer_alias", []))
        for q in case["questions"]:
            pred = checker.answer_question(q)
            if any(is_correct(pred, gold) for gold in golds):
                correct += 1
            total += 1

    print("Per-question accuracy:", correct / total if total else float("nan"))

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

  0%|          | 0/500 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  4%|▍         | 22/500 [19:47<7:08:31, 53.79s/it]