<a href="https://colab.research.google.com/github/LindaSekhoasha/Axiom-Extraction-From-Text-Using-Deep-Learning/blob/main/Axioms_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Environment Setup (Installs • Imports • Config)

In [None]:
%%capture
!pip -q install "transformers>=4.40" torch spacy datasets rdflib pandas --upgrade
!python -m spacy download en_core_web_sm

In [None]:
# standard library
import io
import gzip
import json
import os
import re
import shutil
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional
from urllib.parse import quote

# third-party
import pandas as pd
import pydotplus
import spacy
import torch
import torch.nn.functional as F
from datasets import load_dataset
from google.colab import files
from IPython.display import display, Image
from rdflib import Graph, URIRef, Literal, Namespace, RDF, XSD
from rdflib import Graph as RDFGraph
from rdflib.namespace import OWL
from rdflib.tools.rdf2dot import rdf2dot
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LukeTokenizer,
    LukeForEntityPairClassification,
)

## Helper Functions

In [None]:
_TAG = re.compile(r"<[^>]+>")


def _clean_local(s: str) -> str:
    """Normalise free text into a string usable as a URI local name.

    Steps, in order: drop ``<...>`` tag spans, trim both ends, remove commas,
    turn every whitespace run into a single underscore, then discard each
    character that is not an ASCII letter, a digit, or one of ``_ - . / :``.
    The result may be an empty string.
    """
    untagged = _TAG.sub("", s)
    words = untagged.strip().replace(",", "").split()
    return re.sub(r"[^A-Za-z0-9_\-./:]", "", "_".join(words))

def subject_uri_from_head(head: str):
    """Return the ``DBR`` resource URI for *head*, or ``None`` when cleaning leaves nothing."""
    local_name = _clean_local(head)
    return _uri(DBR, local_name) if local_name else None

def _uri(ns: Namespace, local: str) -> URIRef:
    """Build a ``URIRef`` from namespace *ns* plus the cleaned, percent-quoted *local*.

    *local* is run through ``_clean_local`` first; ``quote`` then escapes
    everything except ``: / . _ -`` (and its always-safe characters).
    """
    encoded_local = quote(_clean_local(local), safe=":/._-")
    return URIRef(str(ns) + encoded_local)

def _object_term_for_tail(tail: str):
    """Turn a raw tail string into an RDF object term.

    Returns:
        A ``Literal`` typed ``xsd:gYear`` for exactly four ASCII digits,
        a ``Literal`` typed ``xsd:integer`` for any other run of ASCII digits,
        a ``DBR`` ``URIRef`` for everything else, or ``None`` when the tail
        cleans down to an empty local name.
    """
    tail = tail.strip()
    # str.isdigit() alone is also true for characters such as "²", which
    # int() rejects with ValueError, so the numeric branches are limited to
    # ASCII digits.
    if tail.isascii() and tail.isdigit():
        if len(tail) == 4:
            return Literal(tail, datatype=XSD.gYear)
        return Literal(int(tail), datatype=XSD.integer)
    # An empty local name would make the object the bare namespace URI.
    if not _clean_local(tail):
        return None
    return _uri(DBR, tail)


def add_triplets_to_luke_graph(triplets, g: Graph, cls_name: str):
    """Add ``(head, relation, tail)`` triplets to *g* and type each subject.

    Args:
        triplets: iterable of 3-item ``(head, relation, tail)`` string tuples.
        g: graph that receives the triples (mutated in place).
        cls_name: local name looked up in the ``DBO`` namespace and used as
            the ``rdf:type`` of every subject that is added.

    A triplet is skipped entirely (no relation triple and no ``rdf:type``
    triple) when its head, relation, or tail is empty after cleaning.
    """
    cls_uri = DBO[cls_name]
    for h, r, t in triplets:
        head = _clean_local(h)
        pred = _clean_local(r)
        if not head or not pred:
            continue
        obj_val = _object_term_for_tail(t)
        if obj_val is None:
            continue
        subj_uri = _uri(DBR, head)
        g.add((subj_uri, _uri(EX, pred), obj_val))
        g.add((subj_uri, RDF.type, cls_uri))


def add_triplets_to_rebel_graph(triplets, cls_label):
    """
    Add extracted triplets to the REBEL RDF graph of one dbpedia class.

    The target graph is ``graphs_rebel[DBPEDIA_CLASSES[cls_label]]``; it is
    mutated in place and nothing is returned.

    Args:
        triplets (list of dict): [{'head': ..., 'type': ..., 'tail': ...}]
        cls_label (int): index of the dbpedia class in DBPEDIA_CLASSES

    A triplet is skipped entirely (no relation triple and no ``rdf:type``
    triple) when its head, relation, or tail is empty after cleaning.
    """
    cls_name = DBPEDIA_CLASSES[cls_label]
    g = graphs_rebel[cls_name]
    cls_uri = DBO[cls_name]

    for t in triplets:
        head = _clean_local(t['head'])
        pred = _clean_local(t['type'])

        # skip if subject or predicate vanished after cleaning
        if not pred or not head:
            continue

        # tail: year literal, integer literal, or entity URI
        obj_val = _object_term_for_tail(t['tail'])
        if obj_val is None:
            continue

        subj_uri = _uri(DBR, head)

        # add the relation triple, then assert rdf:type for the subject
        g.add((subj_uri, _uri(EX, pred), obj_val))
        g.add((subj_uri, RDF.type, cls_uri))

# process the text from the rebel model and return the triplets
def extract_triplets(text):
    """Parse linearised REBEL output into ``{'head', 'type', 'tail'}`` dicts.

    ``<triplet>`` starts a new head, ``<subj>`` starts a tail and ``<obj>``
    starts a relation label; ``<s>``, ``<pad>`` and ``</s>`` are removed
    before tokenising. A triplet is emitted whenever a ``<triplet>`` or
    ``<subj>`` marker is reached while a relation is pending, and once more at
    the end of the text if head, relation and tail are all non-empty.
    """
    triplets = []
    parts = {'head': [], 'type': [], 'tail': []}
    slot = None  # which part ordinary tokens are currently appended to

    def flush():
        triplets.append({key: ' '.join(words) for key, words in parts.items()})

    cleaned = text.strip()
    for special in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(special, "")

    for token in cleaned.split():
        if token == "<triplet>":
            slot = 'head'
            if parts['type']:
                flush()
                parts['type'] = []
            parts['head'] = []
        elif token == "<subj>":
            slot = 'tail'
            # the relation is deliberately left pending here; the next
            # "<obj>" marker is what clears it
            if parts['type']:
                flush()
            parts['tail'] = []
        elif token == "<obj>":
            slot = 'type'
            parts['type'] = []
        elif slot is not None:
            parts[slot].append(token)

    if all(parts.values()):
        flush()
    return triplets

# prune candidate pairs
def candidate_pairs_for_ents(ents):
    """
    Enumerate ordered (head, tail) entity pairs worth classifying.

    ents: List[(text, (s,e), label)]
    returns: List[(h_text, (hs,he), t_text, (ts,te))]

    An entity is never paired with itself. When ALLOWED_PAIRS is non-empty the
    (head label, tail label) combination must be in it; when MAX_CHAR_DISTANCE
    is set the two start offsets may differ by at most that many characters;
    when MAX_PAIRS_PER_DOC is set only the first that many pairs are kept.
    """
    selected = [
        (head_text, (head_start, head_end), tail_text, (tail_start, tail_end))
        for head_pos, (head_text, (head_start, head_end), head_label) in enumerate(ents)
        for tail_pos, (tail_text, (tail_start, tail_end), tail_label) in enumerate(ents)
        if head_pos != tail_pos
        and (not ALLOWED_PAIRS or (head_label, tail_label) in ALLOWED_PAIRS)
        and (not MAX_CHAR_DISTANCE or abs(head_start - tail_start) <= MAX_CHAR_DISTANCE)
    ]
    if MAX_PAIRS_PER_DOC and len(selected) > MAX_PAIRS_PER_DOC:
        return selected[:MAX_PAIRS_PER_DOC]
    return selected

# rdf graph visualization
def visualize(graphs, doc_subjects, k: int = 5):
    """
    Render, as a PNG, the triples whose subjects were recorded for the first rows.

    Subject URIs are gathered from ``doc_subjects`` for row indices
    ``0 .. min(k, len(texts)) - 1`` (``texts`` is read from module scope).
    Every triple in any graph of ``graphs`` whose subject is in that set is
    copied into a fresh graph, which is converted to DOT and displayed. If no
    triple matches, a message is printed and nothing is drawn.
    """
    row_count = min(k, len(texts))
    wanted_subjects = set()
    for row in range(row_count):
        wanted_subjects = wanted_subjects | doc_subjects.get(row, set())

    # fresh graph holding only the selected triples, with readable prefixes
    excerpt = RDFGraph()
    for prefix, namespace in (("ex", EX), ("dbr", DBR), ("dbo", DBO)):
        excerpt.bind(prefix, namespace)

    for source_graph in graphs.values():
        for triple in source_graph:
            if triple[0] in wanted_subjects:
                excerpt.add(triple)

    if len(excerpt) == 0:
        print(f"No triples found for the first {k} rows.")
        return

    dot_buffer = io.StringIO()
    rdf2dot(excerpt, dot_buffer)
    rendered = pydotplus.graph_from_dot_data(dot_buffer.getvalue())
    display(Image(rendered.create_png()))

def _is_uri(x):
    """Return True when *x* is a ``URIRef`` instance."""
    return isinstance(x, URIRef)


def _is_lit(x):
    """Return True when *x* is a ``Literal`` instance."""
    return isinstance(x, Literal)

def compute_metrics_for_graph(g, cls_name, DBO):
    """Summarise one class graph as a flat dict of counts and ratios.

    Args:
        g: graph to measure; it is iterated as (s, p, o) triples and queried
            with ``g.triples``.
        cls_name: class local name, appended to ``str(DBO)`` to form the
            class URI.
        DBO: namespace the class URI is built from.

    Returns:
        dict with the class name, triple count, number of individuals
        (subjects typed as the class), unique subject/predicate/object counts,
        counts of triples with non-literal vs. literal objects,
        ``relationship_richness`` (non-literal-object triples / all triples),
        ``attribute_richness`` (literal-object triples / individuals) and
        ``avg_degree_individuals`` (mean of out-degree + in-degree over the
        individuals). Each ratio is 0.0 when its denominator is zero.
    """
    class_term = URIRef(str(DBO) + cls_name)
    individuals = {subj for subj, _, _ in g.triples((None, RDF.type, class_term))}

    seen_subjects, seen_predicates, seen_objects = set(), set(), set()
    literal_count = 0
    resource_count = 0
    outgoing = defaultdict(int)
    incoming = defaultdict(int)

    # single pass: distinct terms, literal/non-literal split, per-individual degrees
    for subj, pred, obj in g:
        seen_subjects.add(subj)
        seen_predicates.add(pred)
        seen_objects.add(obj)
        if isinstance(obj, Literal):
            literal_count += 1
        else:
            resource_count += 1
        if subj in individuals:
            outgoing[subj] += 1
        if obj in individuals:
            incoming[obj] += 1

    triple_total = resource_count + literal_count
    individual_total = len(individuals)

    relationship_richness = resource_count / triple_total if triple_total else 0.0
    attribute_richness = literal_count / individual_total if individual_total else 0.0
    if individual_total:
        degree_sum = sum(outgoing[node] + incoming[node] for node in individuals)
        avg_degree = degree_sum / individual_total
    else:
        avg_degree = 0.0

    return {
        "class": cls_name,
        "triples": len(g),
        "individuals": individual_total,
        "unique_subjects": len(seen_subjects),
        "unique_predicates": len(seen_predicates),
        "unique_objects": len(seen_objects),
        "object_triples": resource_count,
        "data_triples": literal_count,
        "relationship_richness": relationship_richness,
        "attribute_richness": attribute_richness,
        "avg_degree_individuals": avg_degree,
    }

## CONFIG & Variables

In [None]:
# Model checkpoints (Hugging Face hub ids)
LUKE_MODEL          = "studio-ousia/luke-large-finetuned-tacred"  # pretrained, non-LLM baseline
REBEL_MODEL         = "Babelscape/rebel-large"                    # pretrained, LLM-based extractor

CONF_THRESHOLD      = 0.50                    # accept a relation only if softmax prob >= this
MAX_DOCS            = 120_000                      # train[:N] (subset of dbpedia_14 used due to hardware limitations)
BATCH_SIZE          = 128                     # number of texts per loop (NER + inference per pair)
NUM_RETURNS         = 3                       # presumably sequences generated per input by REBEL; confirm at the generate call
PRINT_EVERY         = 100                     # presumably the progress-logging interval; confirm where it is used
OUTPUT_DIR_LUKE     = "out_luke"
OUTPUT_DIR_REBEL    = "out_rebel"
SAVE_XML_CLASS      = "Company"               # set None to skip
# spaCy entity labels kept during NER; every other label is dropped
VALID_NER_LABELS    = {"PERSON","ORG","GPE","PRODUCT","FAC","WORK_OF_ART"}
MAX_PAIRS_PER_DOC   = 30                      # cap pairs per doc for speed; set None for all
doc_subjects_luke   = {}  # global_doc_idx -> set of subject URIs added from that row
doc_subjects_rebel  = {}  # global_doc_idx -> set of subject URIs added from that row

# (head label, tail label) combinations allowed by candidate_pairs_for_ents.
# While this set is non-empty, PRODUCT / FAC / WORK_OF_ART entities pass NER
# filtering but never form a pair, because no combination here mentions them.
# An empty set disables the label filter.
ALLOWED_PAIRS = {
    ("PERSON","ORG"), ("ORG","PERSON"),
    ("PERSON","GPE"), ("GPE","PERSON"),
    ("ORG","ORG"), ("ORG","GPE"), ("GPE","ORG"),
}
MAX_CHAR_DISTANCE = 120     # skip far-apart mentions (compares start offsets); set None to disable
PAIR_BATCH        = 128     # pairs per LUKE forward; tune 32–128 on A100

# DATASET: [dbpedia_14](https://huggingface.co/datasets/fancyzhx/dbpedia_14/viewer/dbpedia_14/train?views%5B%5D=train&row=5)
A supervised text-classification benchmark built from DBpedia article abstracts. Each sample is a short text (title + abstract) labeled with one of 14 ontology classes (e.g., Company, Person, Place). The Hugging Face variant fancyzhx/dbpedia_14 mirrors the widely used setup from the literature: ~560k training docs and 70k test docs with label IDs and class names.

In [None]:
# Load only the first MAX_DOCS rows of the train split.
dataset = load_dataset("fancyzhx/dbpedia_14", split=f"train[:{MAX_DOCS}]")
# Parallel lists: texts[i] is the "content" field of row i, labels[i] its "label" field.
texts = [ex["content"] for ex in dataset]
labels = [ex["label"] for ex in dataset]

# Class names taken from the dataset's label feature; add_triplets_to_rebel_graph
# indexes this list with a label value to get the class name.
DBPEDIA_CLASSES = dataset.features["label"].names
print(f"DBPEDIA_CLASSES (ONTOLOGIES):\n{DBPEDIA_CLASSES}")

In [None]:
# NAMESPACES
EX  = Namespace("http://example.org/ontology/")
DBR = Namespace("http://dbpedia.org/resource/")
DBO = Namespace("http://dbpedia.org/ontology/")

graphs_luke = {cls: Graph() for cls in DBPEDIA_CLASSES}
graphs_rebel = {cls: Graph() for cls in DBPEDIA_CLASSES}

In [None]:
# batched spaCy NER once for all texts - LUKE
nlp = spacy.load("en_core_web_sm", disable=["tagger","parser","lemmatizer","attribute_ruler"])
nlp.max_length = max(len(t) for t in texts) + 10

ents_by_doc = []  # List[List[ (ent_text, (start_char,end_char), ent_label) ]]
for doc in nlp.pipe(texts, batch_size=256, n_process=2):
    ents = [(e.text, (e.start_char, e.end_char), e.label_) for e in doc.ents if e.label_ in VALID_NER_LABELS]
    ents_by_doc.append(ents)