In [1]:
from pathlib import Path
import networkx as nx
import json


In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "microsoft/phi-1_5"  #

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

Tok = AutoTokenizer.from_pretrained(MODEL_NAME)

LLM = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_cfg,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

print("Model loaded on:", next(LLM.parameters()).device)



  from .autonotebook import tqdm as notebook_tqdm


Model loaded on: cuda:0


In [3]:
from transformers import pipeline as hf_pipeline

pipe = hf_pipeline(
    "text-generation",
    model=LLM,
    tokenizer=Tok,
    max_new_tokens=128,
    temperature=0.2,
    do_sample=False,
)



Device set to use cuda:0


In [110]:
question = "What is the capital of France?"

response = pipe(
    f"Question: {question}\nAnswer:", 
    max_new_tokens=100
)

full_text = response[0]["generated_text"]
lines = full_text.split("\n")

answer_line = ""
for line in lines:
    if "Answer:" in line:
        answer_line = line.split("Answer:")[-1].strip()
        break

print("Model Answer:", answer_line)




Model Answer: Paris.


### Extract the graphs in human readable format

In [None]:
import json
import re
import requests

In [66]:
with open("../world_leaders_qa_dataset.json") as f:
    data = json.load(f)

In [52]:
def extract_all_property_ids_from_dataset(dataset):
    property_ids = set()
    for item in dataset:
        question = item.get("question", "")
        matches = re.findall(r'http://www\.wikidata\.org/prop/direct/(p\d+)', question, flags=re.IGNORECASE)
        for match in matches:
            property_ids.add(match.upper())
        graph = item.get("graph", {})
        edges = graph.get("edges", [])
        for edge in edges:
            rel_id = edge.get("relation_id", "")
            if rel_id.startswith("P"):
                property_ids.add(rel_id.upper())
    return list(property_ids)

print(extract_all_property_ids_from_dataset(data)[:5])

['P102', 'P22', 'P40', 'P106', 'P35']


In [53]:
def extract_all_entity_ids_from_dataset(dataset):
    entity_ids = set()
    for item in dataset:
        graph = item.get("graph", {})
        nodes = graph.get("nodes", [])
        edges = graph.get("edges", [])
        for node in nodes:
            entity_ids.add(node["id"])
        for edge in edges:
            entity_ids.add(edge["source"])
            entity_ids.add(edge["target"])
    return list(entity_ids)

print(extract_all_entity_ids_from_dataset(data)[:5])

['Q12757697', 'Q4461939', 'Q114371233', 'Q93923', 'Q1192712']


In [54]:
property_ids = extract_all_property_ids_from_dataset(data)
entity_ids = extract_all_entity_ids_from_dataset(data)

In [55]:
# get corresponing labels from wikidata
def get_property_label(property_id):
    url = f"https://www.wikidata.org/wiki/Special:EntityData/{property_id}.json"
    response = requests.get(url)
    if response.status_code != 200:
        return property_id
    try:
        data = response.json()
        return data["entities"][property_id]["labels"]["en"]["value"]
    except:
        return property_id

def get_entity_label(entity_id):
    url = f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json"
    response = requests.get(url)
    if response.status_code != 200:
        return entity_id
    try:
        data = response.json()
        return data["entities"][entity_id]["labels"]["en"]["value"]
    except:
        return entity_id

PROPERTY_LABELS = {pid: get_property_label(pid) for pid in property_ids}
ENTITY_LABELS = {eid: get_entity_label(eid) for eid in entity_ids}

In [56]:
def clean_question(question):
    question = re.sub(
        r'http://www\.wikidata\.org/prop/direct/(p\d+)',lambda m: m.group(1).upper(),question,flags=re.IGNORECASE)
    return question

In [57]:
def process_question(question, property_labels):
    question = clean_question(question)
    for prop_id, label in property_labels.items():
        question = question.replace(prop_id, label)
    return question

In [58]:
def get_label_by_id(entity_id, nodes=None):
    return ENTITY_LABELS.get(entity_id, entity_id)

def extract_humanized_facts(graph, focus_entity=None, property_labels=None):
    edges = graph.get("edges", [])
    facts = []
    for edge in edges:
        if focus_entity and edge["source"] != focus_entity:
            continue
        source_label = get_label_by_id(edge["source"])
        target_label = get_label_by_id(edge["target"])
        relation_id = edge.get("relation_id", "")
        relation = property_labels.get(relation_id, edge.get("relation", relation_id))
        facts.append(f"{source_label} {relation} {target_label}")
    return facts


In [59]:
def build_prompt_with_context(question, facts, property_labels):
    cleaned_question = process_question(question, property_labels)
    context = "\n".join(facts)
    return f"Context:\n{context}\n\nQuestion: {cleaned_question}\nAnswer:"


In [89]:
# exmaple question + context 
sample = data[4]  

leader_id = sample.get("leader_id")  
facts = extract_humanized_facts(sample["graph"], focus_entity=leader_id, property_labels=PROPERTY_LABELS)

prompt = build_prompt_with_context(sample["question"], facts, PROPERTY_LABELS)
print(prompt)


Context:
Donald Trump relative John G. Trump
Donald Trump relative Vanessa Trump
Donald Trump relative Jared Kushner
Donald Trump relative Donald Trump III
Donald Trump relative Elizabeth Christ Trump
Donald Trump relative Lara Trump
Donald Trump relative Mary L. Trump
Donald Trump relative John Whitney Walter
Donald Trump medical condition COVID-19
Donald Trump topic's main Wikimedia portal Portal:Donald J. Trump

Question: Which person serves as the head of state of United States?
Answer:


In [107]:
def query_llm(prompt, model, tokenizer, max_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "Answer:" in decoded:
        response = decoded.split("Answer:")[-1].strip()
        return response.split("Question:")[0].strip()
    return decoded.strip()

In [91]:
answer = query_llm(prompt, LLM, Tok)
print(prompt)
print("Model answer:", answer)
print("Ground truth:", sample["answer"])

Context:
Donald Trump relative John G. Trump
Donald Trump relative Vanessa Trump
Donald Trump relative Jared Kushner
Donald Trump relative Donald Trump III
Donald Trump relative Elizabeth Christ Trump
Donald Trump relative Lara Trump
Donald Trump relative Mary L. Trump
Donald Trump relative John Whitney Walter
Donald Trump medical condition COVID-19
Donald Trump topic's main Wikimedia portal Portal:Donald J. Trump

Question: Which person serves as the head of state of United States?
Answer:
Model answer: The White House

Question:
Ground truth: Donald Trump


In [103]:
def build_prompt_with_context_given_instruction(question, facts, property_labels):
    cleaned_question = process_question(question, property_labels)
    context = "\n".join(facts)
    return ("The context provided contains relevant facts. Stick to them when answering.\n\n"
        f"Context:\n{context}\n\nQuestion: {cleaned_question}\nAnswer:"
    )


In [109]:
prompt = build_prompt_with_context_given_instruction(sample["question"], facts, PROPERTY_LABELS)
answer = query_llm(prompt, LLM, Tok)
print(prompt)
print("Model answer:", answer)
print("Ground truth:", sample["answer"])

The context provided contains relevant facts. Stick to them when answering.

Context:
Donald Trump relative John G. Trump
Donald Trump relative Vanessa Trump
Donald Trump relative Jared Kushner
Donald Trump relative Donald Trump III
Donald Trump relative Elizabeth Christ Trump
Donald Trump relative Lara Trump
Donald Trump relative Mary L. Trump
Donald Trump relative John Whitney Walter
Donald Trump medical condition COVID-19
Donald Trump topic's main Wikimedia portal Portal:Donald J. Trump

Question: Which person serves as the head of state of United States?
Answer:
Model answer: Donald Trump
Ground truth: Donald Trump
