In [1]:
from pathlib import Path
import networkx as nx
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import warnings
import re
import requests
warnings.filterwarnings("ignore")


  from .autonotebook import tqdm as notebook_tqdm


### Extract the graphs in human readable format

In [3]:
with open("../world_leaders_qa_dataset.json") as f:
    data = json.load(f)

In [4]:
def extract_all_property_ids_from_dataset(dataset):
    property_ids = set()
    for item in dataset:
        question = item.get("question", "")
        matches = re.findall(r'http://www\.wikidata\.org/prop/direct/(p\d+)', question, flags=re.IGNORECASE)
        for match in matches:
            property_ids.add(match.upper())
        graph = item.get("graph", {})
        edges = graph.get("edges", [])
        for edge in edges:
            rel_id = edge.get("relation_id", "")
            if rel_id.startswith("P"):
                property_ids.add(rel_id.upper())
    return list(property_ids)

print(extract_all_property_ids_from_dataset(data)[:5])

['P31', 'P21', 'P735', 'P27', 'P106']


In [5]:
def extract_all_entity_ids_from_dataset(dataset):
    entity_ids = set()
    for item in dataset:
        graph = item.get("graph", {})
        nodes = graph.get("nodes", [])
        edges = graph.get("edges", [])
        for node in nodes:
            entity_ids.add(node["id"])
        for edge in edges:
            entity_ids.add(edge["source"])
            entity_ids.add(edge["target"])
    return list(entity_ids)

print(extract_all_entity_ids_from_dataset(data)[:5])

['Q18921193', 'Q3956186', 'Q2354258', 'Q102257613', 'Q16865067']


In [6]:
property_ids = extract_all_property_ids_from_dataset(data)
entity_ids = extract_all_entity_ids_from_dataset(data)

In [7]:
# human readable labels from wikidata id labels
def get_property_label(property_id):
    url = f"https://www.wikidata.org/wiki/Special:EntityData/{property_id}.json"
    response = requests.get(url)
    if response.status_code != 200:
        return property_id
    try:
        data = response.json()
        return data["entities"][property_id]["labels"]["en"]["value"]
    except:
        return property_id

def get_entity_label(entity_id):
    url = f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json"
    response = requests.get(url)
    if response.status_code != 200:
        return entity_id
    try:
        data = response.json()
        return data["entities"][entity_id]["labels"]["en"]["value"]
    except:
        return entity_id
    
def build_entity_labels_from_graph(dataset):
    labels = {}
    for item in dataset:
        for node in item.get("graph", {}).get("nodes", []):
            node_id = node["id"]
            label = node["label"]

            if isinstance(label, str):
                labels[node_id] = label
            elif isinstance(label, dict):
                name_like_keys = ["Commons category", "en", "name", "label"]
                found = None
                for key in name_like_keys:
                    if key in label:
                        found = label[key]
                        break
                if found:
                    labels[node_id] = found
                else:
                    labels[node_id] = node_id  
            else:
                labels[node_id] = node_id  

    return labels


PROPERTY_LABELS = {pid: get_property_label(pid) for pid in property_ids}
ENTITY_LABELS = {eid: get_entity_label(eid) for eid in entity_ids}
# build_entity_labels_from_graph(data)


In [14]:
def clean_question(question):
    question = re.sub(
        r'http://www\.wikidata\.org/prop/direct/(p\d+)',lambda m: m.group(1).upper(),question,flags=re.IGNORECASE)
    return question

In [15]:
def process_question(question, property_labels):
    question = clean_question(question)
    for prop_id, label in property_labels.items():
        question = question.replace(prop_id, label)
    return question

In [16]:
def get_label_by_id(entity_id, nodes=None):
    return ENTITY_LABELS.get(entity_id, entity_id)

def extract_humanized_facts(graph, focus_entity=None, property_labels=None):
    edges = graph.get("edges", [])
    facts = []
    for edge in edges:
        if focus_entity and edge["source"] != focus_entity:
            continue
        source_label = get_label_by_id(edge["source"])
        target_label = get_label_by_id(edge["target"])
        relation_id = edge.get("relation_id", "")
        relation = property_labels.get(relation_id, edge.get("relation", relation_id))
        facts.append(f"{source_label} {relation} {target_label}")
    return facts


In [17]:
def build_prompt_with_context(question, facts, property_labels):
    cleaned_question = process_question(question, property_labels)
    context = "\n".join(facts)
    return f"Context:\n{context}\n\nQuestion: {cleaned_question}\nAnswer:"


In [19]:
# exmaple question + context 
sample = data[4]  

leader_id = sample.get("leader_id")  
facts = extract_humanized_facts(sample["graph"], focus_entity=leader_id, property_labels=PROPERTY_LABELS)

prompt = build_prompt_with_context(sample["question"], facts, PROPERTY_LABELS)
print(prompt)


Context:
Donald Trump relative John G. Trump
Donald Trump relative Vanessa Trump
Donald Trump relative Jared Kushner
Donald Trump relative Donald Trump III
Donald Trump relative Elizabeth Christ Trump
Donald Trump relative Lara Trump
Donald Trump relative Mary L. Trump
Donald Trump relative John Whitney Walter
Donald Trump medical condition COVID-19
Donald Trump topic's main Wikimedia portal Portal:Donald J. Trump

Question: Which person serves as the head of state of United States?
Answer:


In [20]:
def load_llm(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    return model, tokenizer


In [21]:
def query_llm(prompt, model, tokenizer, max_tokens=40):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=0.0,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if "Answer:" in output:
        return output.split("Answer:")[1].split("\n")[0].strip()
    else:
        return output.strip().split("\n")[0]


In [22]:
models = {
    "phi": "microsoft/phi-1_5",
    "deepseek": "deepseek-ai/deepseek-coder-1.3b-instruct",
    "tinyllama": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
}

loaded = {name: load_llm(path) for name, path in models.items()}



In [34]:
sample = data[4]
leader_id = sample.get("leader_id")
facts = extract_humanized_facts(sample["graph"], focus_entity=leader_id, property_labels=PROPERTY_LABELS)
prompt = build_prompt_with_context(sample["question"], facts, PROPERTY_LABELS)

print("Prompt:\n" + prompt)
print("\n--- Model Answers ---")

for i, (name, (model, tokenizer)) in enumerate(loaded.items()):
    # print(prompt)
    answer = query_llm(prompt, model, tokenizer)
    print(f"[{name}] Answer:", answer)

print("\nGround truth:", sample["answer"])

Prompt:
Context:
Donald Trump relative John G. Trump
Donald Trump relative Vanessa Trump
Donald Trump relative Jared Kushner
Donald Trump relative Donald Trump III
Donald Trump relative Elizabeth Christ Trump
Donald Trump relative Lara Trump
Donald Trump relative Mary L. Trump
Donald Trump relative John Whitney Walter
Donald Trump medical condition COVID-19
Donald Trump topic's main Wikimedia portal Portal:Donald J. Trump

Question: Which person serves as the head of state of United States?
Answer:

--- Model Answers ---
[phi] Answer: Donald Trump
[deepseek] Answer: John F. Kennedy
[tinyllama] Answer: Donald Trump

Ground truth: Donald Trump


In [37]:
def build_prompt_with_context_given_instruction(question, facts, property_labels):
    cleaned_question = process_question(question, property_labels)
    context = "\n".join(facts)
    return ("The context provided contains relevant facts. Answer based on them.\n\n"
        f"Context:\n{context}\n\nQuestion: {cleaned_question}\nAnswer:"
    )


In [38]:
prompt = build_prompt_with_context_given_instruction(sample["question"], facts, PROPERTY_LABELS)
print("Prompt:\n" + prompt)
print("\n--- Model Answers ---")

for i, (name, (model, tokenizer)) in enumerate(loaded.items()):
    # print(prompt)
    answer = query_llm(prompt, model, tokenizer)
    print(f"[{name}] Answer:", answer)

print("\nGround truth:", sample["answer"])

Prompt:
The context provided contains relevant facts. Answer based on them.

Context:
Donald Trump relative John G. Trump
Donald Trump relative Vanessa Trump
Donald Trump relative Jared Kushner
Donald Trump relative Donald Trump III
Donald Trump relative Elizabeth Christ Trump
Donald Trump relative Lara Trump
Donald Trump relative Mary L. Trump
Donald Trump relative John Whitney Walter
Donald Trump medical condition COVID-19
Donald Trump topic's main Wikimedia portal Portal:Donald J. Trump

Question: Which person serves as the head of state of United States?
Answer:

--- Model Answers ---
[phi] Answer: Donald Trump
[deepseek] Answer: Donald J. Trump
[tinyllama] Answer: Donald Trump

Ground truth: Donald Trump
