## Document to Knowledge Graph Conversion using REBEL

In [None]:
# COMBINED RELATION EXTRACTION

import spacy
import torch
import math
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Pick the compute device first: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the REBEL tokenizer and seq2seq model, moving the model to `device`.
# Swap the checkpoint name here to use a different relation-extraction model.
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "Babelscape/rebel-large").to(device)

# spaCy pipeline used for dependency parsing.
nlp = spacy.load("en_core_web_sm")


class KB:
    """Minimal knowledge base: an ordered list of unique relations."""

    def __init__(self):
        # Relations are kept in insertion order.
        self.relations = []

    def add_relation(self, relation):
        """Append `relation` unless an equal one is already stored."""
        if relation in self.relations:
            return
        self.relations.append(relation)

    def print(self):
        """Write a "Relations:" header followed by one indented relation per line."""
        # Inside the method body `print` still resolves to the builtin.
        print("Relations:", *(f"  {entry}" for entry in self.relations),
              sep="\n")

# Extract relations using spaCy dependency parsing


def extract_spacy_relations(text):
    """Build subject / root / object relations from spaCy's dependency parse.

    For every token labelled ``ROOT``, the first left child labelled
    ``nsubj``/``nsubjpass`` becomes the head, the root token's text becomes
    the relation type, and the first right child labelled
    ``dobj``/``attr``/``prep``/``pobj`` becomes the tail. Roots lacking either
    side produce nothing.

    Args:
        text: Text to parse with the module-level ``nlp`` pipeline.

    Returns:
        A list of dicts with keys ``head``, ``type``, ``tail`` and ``meta``
        (``{'sentence': text}``).
    """
    subject_labels = ("nsubj", "nsubjpass")
    object_labels = ("dobj", "attr", "prep", "pobj")

    found = []
    for root in nlp(text):
        if root.dep_ != "ROOT":
            continue

        subjects = [child for child in root.lefts
                    if child.dep_ in subject_labels]
        objects = [child for child in root.rights
                   if child.dep_ in object_labels]
        if not (subjects and objects):
            continue

        found.append({
            'head': subjects[0].text,
            'type': root.text,  # the root token's text names the relation
            'tail': objects[0].text,
            'meta': {'sentence': text}
        })
    return found

# Parse the REBEL model's decoded output into relation triples


def extract_relations_from_model_output(text, sentence):
    """Parse one decoded REBEL sequence into a list of relation dicts.

    REBEL linearizes triples as ``<triplet> head <subj> tail <obj> relation``.
    Further ``<subj> tail <obj> relation`` groups may follow and reuse the
    same head until the next ``<triplet>`` marker, so a pending triple is
    emitted at every ``<triplet>`` and ``<subj>`` marker as well as at the end
    of the sequence.

    Args:
        text: Decoded model output; ``<s>``, ``<pad>`` and ``</s>`` are
            removed before parsing.
        sentence: Source sentence, stored under ``meta['sentence']``.

    Returns:
        A list of dicts with keys ``head``, ``type``, ``tail`` and ``meta``
        (``{'sentence': sentence, 'spans': []}``). Groups missing a head,
        tail or relation are dropped.
    """
    relations = []
    head_tokens, tail_tokens, type_tokens = [], [], []

    def flush():
        # Emit the pending triple only when all three parts are present.
        if head_tokens and tail_tokens and type_tokens:
            relations.append({
                'head': ' '.join(head_tokens),
                'type': ' '.join(type_tokens),
                'tail': ' '.join(tail_tokens),
                # Initialize with empty spans list
                'meta': {'sentence': sentence, 'spans': []}
            })

    cleaned = text.replace("<s>", "").replace(
        "<pad>", "").replace("</s>", "")

    target = None  # list currently receiving tokens; None before any marker
    for token in cleaned.split():
        if token == "<triplet>":
            # New head: finish the previous triple and drop all stale state,
            # even when the previous group was incomplete.
            flush()
            head_tokens.clear()
            tail_tokens.clear()
            type_tokens.clear()
            target = head_tokens
        elif token == "<subj>":
            # New tail for the current head: finish the previous
            # (head, relation, tail) before starting the next one.
            flush()
            tail_tokens.clear()
            type_tokens.clear()
            target = tail_tokens
        elif token == "<obj>":
            type_tokens.clear()
            target = type_tokens
        elif target is not None:
            target.append(token)

    # Add last triple if exists
    flush()
    return relations


# Construct the knowledge graph
def _compute_span_boundaries(num_tokens, span_length):
    """Return equal-length ``[start, end]`` windows covering ``num_tokens``.

    A sequence no longer than ``span_length`` gets the single window
    ``[0, span_length]`` (slicing clips it to the real length). Longer
    sequences get ``ceil(num_tokens / span_length)`` overlapping windows whose
    starts are spread evenly, so the first window starts at 0 and the last one
    ends exactly at ``num_tokens`` -- no trailing token is left out.
    """
    num_spans = math.ceil(num_tokens / span_length)
    if num_spans <= 1:
        return [[0, span_length]]

    last_start = num_tokens - span_length
    starts = ((i * last_start) // (num_spans - 1) for i in range(num_spans))
    return [[start, start + span_length] for start in starts]


def from_text_to_kb(text, kb, span_length=64):
    """Extract relations from `text` sentence by sentence and add them to `kb`.

    The text is split on "." (so abbreviations and decimals also split).
    For each non-empty sentence, relations from `extract_spacy_relations` are
    added first; then the sentence is tokenized, cut into overlapping windows
    of `span_length` tokens, run through the REBEL model with beam search, and
    every decoded sequence is parsed with
    `extract_relations_from_model_output`.

    Args:
        text: Input document.
        kb: Object exposing ``add_relation(relation)`` (e.g. ``KB``).
        span_length: Number of tokens per model input window.

    Returns:
        The same `kb` object that was passed in.

    Raises:
        ValueError: If `span_length` is not positive.
    """
    if span_length <= 0:
        raise ValueError("span_length must be a positive integer")

    # Generation settings are the same for every sentence.
    num_return_sequences = 5
    gen_kwargs = {
        "max_length": 512,
        "length_penalty": 1.0,
        "num_beams": 5,
        "num_return_sequences": num_return_sequences
    }

    sentences = [sentence.strip()
                 for sentence in text.split(".") if sentence.strip()]

    for sentence in sentences:
        # Relations from spaCy's dependency parse.
        for relation in extract_spacy_relations(sentence):
            kb.add_relation(relation)

        # Relations from the REBEL model.
        encoded = tokenizer([sentence], return_tensors="pt")
        input_ids = encoded["input_ids"][0]
        attention_mask = encoded["attention_mask"][0]

        # All windows have the same length, so they can be stacked as a batch.
        spans_boundaries = _compute_span_boundaries(
            len(input_ids), span_length)
        inputs = {
            "input_ids": torch.stack(
                [input_ids[begin:end] for begin, end in spans_boundaries]
            ).to(device),
            "attention_mask": torch.stack(
                [attention_mask[begin:end] for begin, end in spans_boundaries]
            ).to(device),
        }

        # Generate triples for the sentence; special tokens are kept because
        # the parser relies on the <triplet>/<subj>/<obj> markers.
        generated_tokens = model.generate(**inputs, **gen_kwargs)
        decoded_preds = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=False)

        for sentence_pred in decoded_preds:
            for relation in extract_relations_from_model_output(
                    sentence_pred, sentence):
                kb.add_relation(relation)

    return kb

# Example usage
# text = """
# AI ethics addresses questions of fairness, transparency, and accountability in AI design. Machine learning involves training algorithms on data to make predictions or decisions. Natural Language Processing is a major field in AI that enables interaction between humans and computers using natural language
# """
# # Generate KB from text
# kb = KB()
# kb = from_text_to_kb(text, kb)
# kb.print()