In [9]:
import spacy
import requests
import json
import re
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

In [10]:
# Load spaCy and REBEL
nlp = spacy.load("en_core_web_sm")
rel_ext = pipeline("text2text-generation", model="Babelscape/rebel-large")

# Load Mistral
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

print("[INFO] Models loaded.")


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Models loaded.


In [11]:
# ---------------------------------------------
# Wikidata utility functions
# ---------------------------------------------
def get_wikidata_qid(entity_name):
    try:
        search_url = "https://www.wikidata.org/w/api.php"
        params = {
            "action": "wbsearchentities",
            "language": "en",
            "format": "json",
            "search": entity_name
        }
        response = requests.get(search_url, params=params).json()
        if response['search']:
            return response['search'][0]['id']
    except Exception as e:
        print(f"Error getting QID for {entity_name}: {e}")
    return None


In [12]:
def get_wikidata_types(qid):
    url = f"https://www.wikidata.org/wiki/Special:EntityData/{qid}.json"
    try:
        data = requests.get(url).json()
        claims = data["entities"][qid]["claims"]
        if 'P31' in claims:
            types = []
            for inst in claims['P31']:
                if 'datavalue' in inst['mainsnak']:
                    value = inst['mainsnak']['datavalue']['value']
                    if 'id' in value:
                        types.append(value['id'])
            return types
    except Exception as e:
        print(f"Error fetching types for {qid}: {e}")
    return []


In [13]:
def get_label(qid):
    try:
        url = f"https://www.wikidata.org/wiki/Special:EntityData/{qid}.json"
        data = requests.get(url).json()
        return data["entities"][qid]["labels"]["en"]["value"]
    except:
        return qid


In [14]:
def get_pid_from_label(label):
    try:
        url = "https://www.wikidata.org/w/api.php"
        params = {
            "action": "wbsearchentities",
            "language": "en",
            "format": "json",
            "type": "property",
            "search": label
        }
        response = requests.get(url, params=params).json()
        if response.get("search"):
            return response["search"][0]["id"]
    except Exception as e:
        print(f"Error getting PID for {label}: {e}")
    return ""


In [15]:
# ---------------------------------------------
# Detection and extraction logic
# ---------------------------------------------
def is_wikidata_text(text, threshold=0.5):
    doc = nlp(text)
    matched = 0
    for ent in doc.ents:
        qid = get_wikidata_qid(ent.text)
        if qid:
            matched += 1
    ratio = matched / max(1, len(doc.ents))
    print(f"[DEBUG] Entity match ratio: {ratio:.2f}")
    return ratio > threshold


In [16]:
def extract_rebel_triples(text):
    output = rel_ext(text, max_length=512)[0]['generated_text']
    print("\n[DEBUG] Raw REBEL Output:\n", output)

    triples = []
    if "<triplet>" in output:
        for t in output.split("<triplet>")[1:]:
            try:
                parts = t.split("<subj>")[1].split("<obj>")
                subj_pred = parts[0].strip()
                obj = parts[1].strip()
                if "<rel>" in subj_pred:
                    subj, pred = subj_pred.split("<rel>")
                    triples.append((subj.strip(), pred.strip(), obj.strip()))
            except:
                continue
    else:
        tokens = [t.strip() for t in output.split("  ") if t.strip()]
        for i in range(0, len(tokens) - 2, 3):
            triples.append((tokens[i], tokens[i + 1], tokens[i + 2]))
    return triples

In [17]:
# ---------------------------------------------
# Wikidata-based Ontology Builder
# ---------------------------------------------
def build_ontology(text, title="Generated Ontology", ontology_id="ont_wikidata_001"):
    concepts = {}
    relation_defs = []
    seen_relations = set()
    
    doc = nlp(text)
    for ent in doc.ents:
        qid = get_wikidata_qid(ent.text)
        if qid:
            types = get_wikidata_types(qid)
            for t in types:
                label = get_label(t)
                concepts[label] = {"qid": t, "label": label}

    triples = extract_rebel_triples(text)
    
    for subj, rel, obj in triples:
        subj_qid = get_wikidata_qid(subj)
        obj_qid = get_wikidata_qid(obj)
        pid = get_pid_from_label(rel)

        for qid in [subj_qid, obj_qid]:
            if qid:
                types = get_wikidata_types(qid)
                for t in types:
                    label = get_label(t)
                    concepts[label] = {"qid": t, "label": label}

        key = (rel.lower(), subj_qid, obj_qid)
        if key not in seen_relations:
            relation_defs.append({
                "pid": pid,
                "label": rel,
                "domain": subj_qid if subj_qid else "",
                "range": obj_qid if obj_qid else ""
            })
            seen_relations.add(key)

    unique_concepts = {c["qid"]: c for c in concepts.values()}

    return {
        "title": title,
        "id": ontology_id,
        "concepts": list(unique_concepts.values()),
        "relations": relation_defs
    }

In [18]:
# ---------------------------------------------
# Mistral-based fallback logic
# ---------------------------------------------
def mistral_extract_plain_concepts_relations(text):
    prompt = f"""Extract all the key concepts and relations from the text below. 
Return them as a list of triples (subject, relation, object).

Text:
{text}

Output (as list of triples):"""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.7)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = decoded.split("Output (as list of triples):")[-1].strip()
    print("[Stage 1 Output]\n", response)
    return response

In [19]:
def parse_triples_from_response(response):
    triples = []
    for line in response.strip().split("\n"):
        if not line.strip():
            continue
        match = re.findall(r"\(?['\"]?(.+?)['\"]?\)?\s*,\s*['\"]?(.+?)['\"]?\s*,\s*['\"]?(.+?)['\"]?\)?", line)
        if match:
            triples.extend(match)
        else:
            parts = [p.strip(" '\"") for p in line.strip().split(" - ")]
            if len(parts) == 3:
                triples.append(tuple(parts))
    return triples


In [20]:
def build_mistral_based_ontology(text, title="Custom Ontology", ontology_id="ont_custom_mistral"):
    stage1_output = mistral_extract_plain_concepts_relations(text)
    triples = parse_triples_from_response(stage1_output)

    concepts = {}
    relation_defs = []
    seen_relations = set()

    for subj, rel, obj in triples:
        subj_qid = get_wikidata_qid(subj)
        obj_qid = get_wikidata_qid(obj)
        pid = get_pid_from_label(rel)

        if subj_qid:
            subj_types = get_wikidata_types(subj_qid)
            for t in subj_types:
                concepts[t] = {"qid": t, "label": get_label(t)}
        if obj_qid:
            obj_types = get_wikidata_types(obj_qid)
            for t in obj_types:
                concepts[t] = {"qid": t, "label": get_label(t)}

        key = (rel.lower(), subj_qid, obj_qid)
        if key not in seen_relations:
            relation_defs.append({
                "pid": pid,
                "label": rel,
                "domain": subj_qid if subj_qid else "",
                "range": obj_qid if obj_qid else ""
            })
            seen_relations.add(key)

    unique_concepts = {c["qid"]: c for c in concepts.values()}

    ontology = {
        "title": title,
        "id": ontology_id,
        "concepts": list(unique_concepts.values()),
        "relations": relation_defs
    }
    return ontology

In [21]:
# ---------------------------------------------
# Decision logic
# ---------------------------------------------
def extract_ontology(text):
    print("[INFO] Detecting if text matches Wikidata...")
    if is_wikidata_text(text):
        print("[INFO] Detected text matches Wikidata. Using REBEL + Wikidata.")
        return build_ontology(text)
    else:
        print("[INFO] Text doesn't match Wikidata. Using Mistral for custom ontology.")
        return build_mistral_based_ontology(text)

In [22]:
# ---------------------------------------------
# Test examples
# ---------------------------------------------
wikidata_text = """Barack Obama (born August 4, 1961, in Honolulu, Hawaii) served as the 44th President of the United States from 2009 to 2017. He is the son of Barack Obama Sr., an economist from Kenya, and Ann Dunham, an anthropologist from Kansas."""

custom_text = """DJ Sonic released his debut album 'Bass Drops' in 2023. The album was produced by Electro Studio."""

# Run on both
print("\n=== Ontology from Wikidata Text ===")
print(json.dumps(extract_ontology(wikidata_text), indent=2))

print("\n=== Ontology from Custom (Mistral) Text ===")
print(json.dumps(extract_ontology(custom_text), indent=2))


=== Ontology from Wikidata Text ===
[INFO] Detecting if text matches Wikidata...
[DEBUG] Entity match ratio: 1.00
[INFO] Detected text matches Wikidata. Using REBEL + Wikidata.

[DEBUG] Raw REBEL Output:
  Barack Obama  August 4, 1961  date of birth  Honolulu, Hawaii  place of birth  President of the United States  position held  Barack Obama  August 4, 1961  date of birth  Honolulu, Hawaii  place of birth  President of the United States  position held
{
  "title": "Generated Ontology",
  "id": "ont_wikidata_001",
  "concepts": [
    {
      "qid": "Q5",
      "label": "human"
    },
    {
      "qid": "Q2784",
      "label": "August 4"
    },
    {
      "qid": "Q47150325",
      "label": "calendar day of a given year"
    },
    {
      "qid": "Q62049",
      "label": "county seat"
    },
    {
      "qid": "Q1549591",
      "label": "big city"
    },
    {
      "qid": "Q3301053",
      "label": "consolidated city-county"
    },
    {
      "qid": "Q35657",
      "label": "U.S. sta

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[DEBUG] Entity match ratio: 0.33
[INFO] Text doesn't match Wikidata. Using Mistral for custom ontology.
[Stage 1 Output]
 [
  ("DJ Sonic", "released", "Bass Drops"),
  ("DJ Sonic", "produced by", "Electro Studio")
]
{
  "title": "Custom Ontology",
  "id": "ont_custom_mistral",
  "concepts": [
    {
      "qid": "Q3624078",
      "label": "sovereign state"
    },
    {
      "qid": "Q619610",
      "label": "social state"
    },
    {
      "qid": "Q4209223",
      "label": "Rechtsstaat"
    },
    {
      "qid": "Q63791824",
      "label": "country bordering the Baltic Sea"
    },
    {
      "qid": "Q6256",
      "label": "country"
    },
    {
      "qid": "Q23718",
      "label": "cardinal direction"
    },
    {
      "qid": "Q11114344",
      "label": "points of the compass"
    }
  ],
  "relations": [
    {
      "pid": "P577",
      "label": "released",
      "domain": "",
      "range": "Q183"
    },
    {
      "pid": "P162",
      "label": "produced by",
      "domain": "",
 