In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import wikipedia
import IPython
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face Hub id of the seq2seq checkpoint used by every cell below.
# Named once so the tokenizer and the model are always loaded from the same
# checkpoint.
REBEL_CHECKPOINT = "Babelscape/rebel-large"

tokenizer = AutoTokenizer.from_pretrained(REBEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(REBEL_CHECKPOINT)

In [12]:
# Sample passage fed to the tokenizer/model cell. Line breaks, trailing spaces
# and the repeated parenthetical are part of the input and are kept verbatim.
# NOTE(review): the text ends on a dangling "'s" and looks truncated — confirm
# this is the intended input.
test_text='''DEWA Launched DEWA-SAT-1 
nanosatellite  
(the world’s first utility to use 
nanosatellite) on 13 January 
2022 and has received DEWA-SAT-1 
nanosatellite  
(the world’s first utility to use 
nanosatellite)'s  '''

In [11]:
# Encode the single sample string as PyTorch tensors, truncating anything
# beyond 512 tokens.
model_inputs = tokenizer(
    test_text,
    return_tensors='pt',
    truncation=True,
    padding=True,
    max_length=512,
)
print(f"Num tokens:{model_inputs['input_ids'].shape[1]}")

# Beam search over 5 beams; the 3 best sequences are returned.
gen_kwargs = dict(
    max_length=216,
    length_penalty=0,
    num_beams=5,
    num_return_sequences=3,
)

generated_tokens = model.generate(**model_inputs, **gen_kwargs)

# Special tokens are kept so the <triplet>/<subj>/<obj> markers stay visible
# in the decoded strings.
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
decoded_preds

Num tokens:79


['<s><triplet> DEWA-SAT-1 ークnanosatellite <subj> 13 January <obj> service entry</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s><triplet> DEWA-SAT-1 ークnanosatellite <subj> 13 January <obj> start time</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s><triplet> DEWA-SAT-1 ークnanosatellite <subj> 13 January <obj> service entry <triplet> DEWA-SAT-1 ークnanosatellite <subj> 13 January <obj> service entry</s>']

In [5]:
# Second sample passage: two short sentences, replacing the previous test_text.
test_text="The Solar Park is the largest in the world. It supports the Dubai Clean Energy Strategy 2050."

In [8]:
# Encode the single sample string as PyTorch tensors, truncating anything
# beyond 512 tokens.
model_inputs = tokenizer(
    test_text,
    return_tensors='pt',
    truncation=True,
    padding=True,
    max_length=512,
)
print(f"Num tokens:{model_inputs['input_ids'].shape[1]}")

# Beam search over 5 beams; all 5 sequences are returned.
gen_kwargs = dict(
    max_length=216,
    length_penalty=0,
    num_beams=5,
    num_return_sequences=5,
)

generated_tokens = model.generate(**model_inputs, **gen_kwargs)

# Special tokens are kept so the <triplet>/<subj>/<obj> markers stay visible
# in the decoded strings.
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
decoded_preds

Num tokens:21


['<s><triplet> Dubai Clean Energy Strategy 2050 <subj> 2050 <obj> point in time</s>',
 '<s><triplet> Solar Park <subj> Dubai <obj> located in the administrative territorial entity</s>',
 '<s><triplet> Dubai Clean Energy Strategy <subj> 2050 <obj> point in time</s><pad>',
 '<s><triplet> Clean Energy Strategy 2050 <subj> 2050 <obj> point in time</s><pad>',
 '<s><triplet> Solar Park <subj> Dubai Clean Energy Strategy 2050 <obj> part of</s>']

In [13]:
def extract_relations_from_model_output(text):
    """Parse one decoded model output string into a list of relation dicts.

    The string is stripped of ``<s>``, ``<pad>`` and ``</s>`` and split on
    whitespace. The markers ``<triplet>``, ``<subj>`` and ``<obj>`` switch
    which field the following words are appended to:

    * after ``<triplet>`` -> ``head``
    * after ``<subj>``    -> ``tail``
    * after ``<obj>``     -> ``type``

    Words seen before the first marker are ignored.

    Args:
        text: Decoded output string, special tokens included.

    Returns:
        A list of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts in the
        order they were completed.
    """
    # Buffers are keyed by the parser state that fills them:
    # 't' -> head, 's' -> tail, 'o' -> relation type.
    buffers = {'t': '', 's': '', 'o': ''}
    state = 'x'  # no marker seen yet; plain words are dropped
    triples = []

    def emit():
        # Snapshot the current buffers as one relation.
        triples.append({
            'head': buffers['t'].strip(),
            'type': buffers['o'].strip(),
            'tail': buffers['s'].strip()
        })

    cleaned = text.strip()
    for special in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(special, "")

    for token in cleaned.split():
        if token == "<triplet>":
            state = 't'
            # A new head begins: flush the pending relation, if any.
            if buffers['o'] != '':
                emit()
                buffers['o'] = ''
            buffers['t'] = ''
        elif token == "<subj>":
            state = 's'
            # Same head, new tail: flush the pending relation but leave the
            # type buffer alone until the next <obj> resets it.
            if buffers['o'] != '':
                emit()
            buffers['s'] = ''
        elif token == "<obj>":
            state = 'o'
            buffers['o'] = ''
        elif state in buffers:
            buffers[state] += ' ' + token

    # The trailing relation is kept only when all three parts are present.
    if all(buffers[key] != '' for key in ('t', 'o', 's')):
        emit()
    return triples


In [15]:
class KB():
    """In-memory store of unique relations.

    Each relation is a dict with the keys ``"head"``, ``"type"`` and
    ``"tail"``. Relations are kept in insertion order in ``self.relations``.
    """

    def __init__(self):
        # List of relation dicts; add_relation keeps it free of duplicates.
        self.relations = []

    def are_relations_equal(self, r1, r2):
        """Return True when r1 and r2 agree on head, type and tail."""
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        """Return True when a relation equal to r1 is already stored."""
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        """Store r unless an equal relation is already present."""
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        """Print a "relations:" header followed by every stored relation."""
        print("relations:")
        # Fixed: the original iterated over the undefined name
        # `self_relations`, which raised NameError on every call.
        for r in self.relations:
            print(f"  {r}")

: 

In [None]:
def from_small_text_to_kb(text,verbose=False):
    kb = KB()