In [None]:
!pip install -U transformers accelerate bitsandbytes sentencepiece


Collecting bitsandbytes
  Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl (59.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.49.0


In [None]:
import json
import os
import re
import unicodedata
import urllib.request

import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)


In [None]:

def init_qwen(device, model_name="Qwen/Qwen2.5-3B-Instruct"):
    """Load a 4-bit (NF4) quantized causal LM and its tokenizer.

    Args:
        device: Accepted for symmetry with the other helpers but NOT used for
            placement here; weights are placed by ``device_map="auto"``.
            Callers later move inputs with ``.to(device)``, so pass the device
            the model actually lands on.
        model_name: Hugging Face Hub id (or local path) used for BOTH the
            tokenizer and the model, so the two can never drift apart.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_fast=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        quantization_config=bnb_config,
    )

    model.eval()
    return model, tokenizer


In [None]:

def download_mquake(
    path="MQuAKE-CF-3k.json",
    url="https://raw.githubusercontent.com/dominic-simon/CHECK-Knowledge-Editing/refs/heads/main/datasets/MQuAKE-CF-3k.json",
):
    """Download the MQuAKE-CF-3k dataset to ``path`` unless it already exists.

    The payload is written to ``<path>.part`` first and renamed into place only
    after the transfer completes, so an interrupted download never leaves a
    truncated JSON file behind for ``load_mquake`` to choke on.

    Args:
        path: Destination file; skipped entirely if it already exists.
        url: Source URL.

    Raises:
        OSError (including urllib.error.URLError): if the download fails.
        Failures are no longer silently ignored as they were with
        ``os.system("wget ...")``.
    """
    if os.path.exists(path):
        return

    tmp_path = path + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)  # atomic on the same filesystem
    finally:
        # Clean up a partial download if anything above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def load_mquake(path="MQuAKE-CF-3k.json"):
    """Parse the MQuAKE JSON file at ``path`` and return its top-level object.

    The file is decoded explicitly as UTF-8 rather than with the platform
    default encoding (e.g. cp1252 on Windows), since entity names in the
    dataset may contain non-ASCII characters.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


In [None]:

# Global list of edit dicts consulted by get_edited_object(); the main cell
# rebinds it to populate_edit_bank(data) once the dataset is loaded.
EDIT_BANK = list()

def populate_edit_bank(dataset):
    """Flatten every requested rewrite in ``dataset`` into edit-bank entries.

    ``case["requested_rewrite"]`` is a list and a case may request several
    edits. ALL of them are included, not just the first; otherwise later edits
    of multi-edit cases would never be applied.

    Each entry is a dict with:
        ``subject``  - lower-cased, stripped subject string (lookup key),
        ``relation`` - stripped ``relation_id`` (lookup key),
        ``object``   - stripped ``target_new["str"]`` (the edited answer).

    NOTE(review): ``relation_id`` is presumably a Wikidata property id such as
    "P27" (TODO confirm). Lookups made with free-text relation names (as
    produced by ``extract_relations``) will not match it.
    """
    return [
        {
            "subject": rewrite["subject"].lower().strip(),
            "relation": rewrite["relation_id"].strip(),
            "object": rewrite["target_new"]["str"].strip(),
        }
        for case in dataset
        for rewrite in case["requested_rewrite"]
    ]

def get_edited_object(subject, relation, bank=None):
    """Return the edited object for ``(subject, relation)``, or ``None``.

    Args:
        subject: Subject entity; compared after ``lower().strip()``, matching
            how ``populate_edit_bank`` stores subjects.
        relation: Relation key; compared after ``strip()``, matching how
            ``populate_edit_bank`` stores relations.
        bank: Optional list of edit dicts with keys ``subject``, ``relation``
            and ``object``. Defaults to the module-level ``EDIT_BANK``, read at
            call time so a later rebinding of the global is picked up.

    Returns:
        The ``object`` of the first matching entry. Returns ``None`` if nothing
        matches or if ``subject``/``relation`` is ``None``.
    """
    if subject is None or relation is None:
        return None
    if bank is None:
        bank = EDIT_BANK

    subject_key = subject.lower().strip()
    relation_key = relation.strip()
    return next(
        (
            entry["object"]
            for entry in bank
            if entry["subject"] == subject_key and entry["relation"] == relation_key
        ),
        None,
    )


In [None]:

def extract_relations(question, model, tokenizer, device, max_new_tokens=64):
    """Ask the model which relations are needed to answer ``question``.

    The model is prompted (greedy decoding) to output relation names separated
    by ``|``.

    Only the newly generated tokens are decoded. Slicing the fully decoded
    text by ``len(prompt)`` is unreliable because decode(encode(prompt)) need
    not reproduce the prompt character-for-character.

    Only the first non-empty generated line is parsed, because the model may
    keep generating unrelated text after the answer line.

    Args:
        question: Natural-language multi-hop question.
        model: Causal LM exposing ``generate``.
        tokenizer: Matching tokenizer (callable + ``decode``).
        device: Device the tokenized inputs are moved to.
        max_new_tokens: Generation budget for the relation list.

    Returns:
        List of stripped, non-empty relation names in the order the model
        listed them (possibly empty).
    """
    prompt = f"""
Extract the relations needed to answer the question.
Output ONLY relation names separated by |.

Question:
{question}

Relations:
"""

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True
    ).to(device)

    outputs = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=max_new_tokens
    )

    # Decoder-only generate() returns prompt + continuation; keep the continuation.
    prompt_len = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    lines = [line for line in generated.splitlines() if line.strip()]
    rel_text = lines[0] if lines else ""
    return [r.strip() for r in rel_text.split("|") if r.strip()]


In [None]:

def answer_hop(subject, relation, model, tokenizer, device, max_new_tokens=32):
    """Resolve a single hop ``(subject, relation) -> object``.

    An edit-bank hit from ``get_edited_object`` takes precedence over the
    model; this override is how the knowledge edits are applied.

    Otherwise the model is prompted greedily. Only the tokens it generates
    after the prompt are decoded, which avoids fragile string slicing by
    ``len(prompt)``. The first non-empty line is returned, so trailing
    rambling does not leak into the next hop's subject.

    Returns:
        The edited object, the model's one-line answer, or ``""`` if the model
        produced nothing.
    """
    # Knowledge edit override (THIS IS THE EDITING)
    edited = get_edited_object(subject, relation)
    if edited is not None:
        return edited

    # Otherwise ask the model
    prompt = f"Answer concisely.\nWhat is the {relation} of {subject}?\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    outputs = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=max_new_tokens
    )

    prompt_len = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    lines = [line.strip() for line in generated.splitlines() if line.strip()]
    return lines[0] if lines else ""



In [None]:

def baseline_CHECK(question, start_entity, model, tokenizer, device, max_hops=3):
    """Answer a multi-hop question by chaining single-hop lookups.

    The model first lists the relations in ``question`` (``extract_relations``).
    The first ``max_hops`` of them are kept (``None`` = no cap) and applied in
    reverse listing order, starting from ``start_entity``. Each hop goes
    through ``answer_hop``, which consults the edit bank before the model.

    Returns:
        The final entity string, or ``None`` if no relations were extracted.
    """
    relations = extract_relations(question, model, tokenizer, device)
    if not relations:
        return None

    # Baseline uses short chains. NOTE(review): because relations are applied
    # in reverse, truncating from the end drops the hops closest to
    # start_entity; confirm this is the intended cap direction.
    relations = relations[:max_hops]

    entity = start_entity
    # Reverse order presumably mirrors nested "R1 of R2 of X" phrasing,
    # where the last-listed relation is the first hop. TODO confirm.
    for relation in reversed(relations):
        entity = answer_hop(entity, relation, model, tokenizer, device)
    return entity


In [None]:

def normalize(text):
    """Canonicalize ``text`` for lenient answer comparison.

    Steps:
    - Unicode NFKD decomposition, then combining marks are dropped, so accented
      letters fold to their ASCII base ("São" -> "sao") instead of being
      deleted outright.
    - The text is lower-cased and everything except ``[a-z0-9]`` and
      whitespace is removed.
    - Whitespace runs are collapsed to a single space and the result is
      trimmed.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    text = "".join(ch for ch in decomposed if not unicodedata.combining(ch)).lower()
    text = re.sub(r"[^a-z0-9\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()

def is_correct(pred, gold):
    """Lenient match of a prediction against gold answer(s).

    Args:
        pred: Predicted string or ``None`` (``None`` is always wrong).
        gold: One of:
            - a string;
            - a list/tuple of acceptable strings;
            - a dict with optional ``"answer"`` (str) and ``"aliases"`` (list).

    Both sides are passed through ``normalize``. A match is whole-word
    containment in either direction. This blocks partial-word hits such as
    "a" matching "canada".

    Empty normalized strings never match: an empty string is a substring of
    everything and would otherwise count as correct.
    """
    if pred is None:
        return False

    p = normalize(pred)
    if not p:
        return False

    if isinstance(gold, dict):
        answers = []
        if "answer" in gold:
            answers.append(gold["answer"])
        if "aliases" in gold:
            answers.extend(gold["aliases"] or [])
    elif isinstance(gold, (list, tuple)):
        answers = list(gold)
    else:
        answers = [gold]

    # Pad with spaces so containment respects word boundaries.
    padded_p = f" {p} "
    for g in answers:
        if g is None:
            continue
        n = normalize(g)
        if not n:
            continue
        padded_g = f" {n} "
        if padded_g in padded_p or padded_p in padded_g:
            return True
    return False


def evaluate_baseline(dataset, model, tokenizer, device, limit=None):
    """Run ``baseline_CHECK`` on MQuAKE cases and score it against the edited answers.

    Scoring:
    - Every question of a case is answered from the subject of the case's
      first requested rewrite.
    - A question is correct if ``is_correct`` accepts the prediction against
      ``new_answer`` or any alias.
    - A case is correct if ANY of its questions is.

    Args:
        dataset: List of MQuAKE case dicts.
        model, tokenizer, device: Forwarded to ``baseline_CHECK``.
        limit: Evaluate only the first ``limit`` cases (``None`` = all).

    Returns:
        Dict with per-question / per-case accuracy and raw counts. Accuracies
        are 0.0 (instead of raising ZeroDivisionError) when there is nothing
        to score.
    """
    cases = dataset[:limit]
    correct_q = 0
    total_q = 0
    correct_cases = 0

    for case in tqdm(cases):
        # NOTE(review): the chain's start entity may differ from the first
        # rewrite's subject when the edit is on a later hop; confirm against
        # the dataset schema.
        start_entity = case["requested_rewrite"][0]["subject"]

        # Aliases are presumably stored under "new_answer_alias" (TODO confirm);
        # a missing or null key just means "no aliases".
        gold = {
            "answer": case["new_answer"],
            "aliases": case.get("new_answer_alias") or [],
        }

        case_ok = False
        for q in case["questions"]:
            pred = baseline_CHECK(q, start_entity, model, tokenizer, device)
            if is_correct(pred, gold):
                correct_q += 1
                case_ok = True
            total_q += 1

        if case_ok:
            correct_cases += 1

    return {
        "per_question_accuracy": correct_q / total_q if total_q else 0.0,
        "per_case_accuracy": correct_cases / len(cases) if cases else 0.0,
        "correct_questions": correct_q,
        "total_questions": total_q
    }


In [None]:

if __name__ == "__main__":
    # Prefer the GPU when one is visible to torch.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = init_qwen(device)

    # Fetch the dataset once, load it, and rebind the global edit bank from it.
    download_mquake()
    data = load_mquake()
    EDIT_BANK = populate_edit_bank(data)

    # Sanity run on a 300-case subset; pass limit=None for the full dataset.
    results = evaluate_baseline(data, model, tokenizer, device, limit=300)

    print("\nBASELINE RESULTS (WITH KNOWLEDGE EDITING)", results, sep="\n")

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]