In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import pandas as pd


In [2]:
# Load the REBEL checkpoint once; the tokenizer and model must come from the same name.
REBEL_CHECKPOINT = "Babelscape/rebel-large"
tokenizer = AutoTokenizer.from_pretrained(REBEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(REBEL_CHECKPOINT)

In [3]:
# Extract relations from model output

def extract_relations_from_model_output(text):
    """Parse one decoded REBEL generation into a list of relation dicts.

    The parser walks whitespace-separated tokens and switches slot on the
    marker tokens:
      * after ``<triplet>`` tokens accumulate into the head ("subject"),
      * after ``<subj>`` tokens accumulate into the tail ("object_"),
      * after ``<obj>`` tokens accumulate into the relation type.
    A new ``<subj>`` closes the current triple and starts another tail for the
    same head; a new ``<triplet>`` closes it and starts a new head.

    Args:
        text: One decoded sequence with special tokens kept. ``<s>``, ``</s>``
            and ``<pad>`` are removed before parsing.

    Returns:
        List of ``{'head', 'type', 'tail'}`` dicts (values whitespace-stripped),
        in order of appearance. A triple is emitted only when head, type and
        tail are all non-empty. A trailing tail that never received its
        ``<obj>`` relation (e.g. output cut off by ``max_length``) is dropped
        rather than paired with a stale relation.
    """
    relations = []
    subject, relation, object_ = '', '', ''
    # 'x' = before any marker (tokens ignored), 't' = head, 's' = tail, 'o' = relation.
    current = 'x'

    def flush():
        # Record the pending triple only if every slot has been filled.
        if subject and relation and object_:
            relations.append({
                'head': subject.strip(),
                'type': relation.strip(),
                'tail': object_.strip(),
            })

    text_replaced = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            # New head entity: close the previous triple and clear every slot.
            flush()
            subject, relation, object_ = '', '', ''
            current = 't'
        elif token == "<subj>":
            # Another tail for the same head: close the previous triple, keep the head.
            # Resetting `relation` prevents it leaking onto a tail that never gets its own.
            flush()
            relation, object_ = '', ''
            current = 's'
        elif token == "<obj>":
            relation = ''
            current = 'o'
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token
    flush()
    return relations

In [4]:
# Implement KB class

class KB:
    """Minimal in-memory knowledge base: an insertion-ordered list of unique relations.

    Each relation is a dict carrying at least 'head', 'type' and 'tail'; two
    relations count as duplicates when those three values compare equal.
    """

    def __init__(self):
        # Accepted relations, in the order they were first added.
        self.relations = []

    def are_relations_equal(self, r1, r2):
        """Return True when r1 and r2 match on 'head', 'type' and 'tail'."""
        for key in ("head", "type", "tail"):
            if not (r1[key] == r2[key]):
                return False
        return True

    def exists_relation(self, r1):
        """Return True if some stored relation equals r1."""
        for stored in self.relations:
            if self.are_relations_equal(r1, stored):
                return True
        return False

    def add_relation(self, r):
        """Store r unless an equal relation is already present (the first copy wins)."""
        if self.exists_relation(r):
            return
        self.relations.append(r)

    def print(self):
        """Write a header line, then each stored relation on its own indented line."""
        # Inside this method `print` resolves to the builtin: class scope is not searched.
        print("Relations:")
        for relation in self.relations:
            print(f"  {relation}")

In [5]:
def from_small_text_to_kb(text, verbose=False, gen_kwargs=None):
    """Run the REBEL model on `text` and collect the extracted triples into a KB.

    Args:
        text: A string, or a list of strings tokenized together as one padded
            batch (e.g. several article titles). Inputs longer than 512 tokens
            are truncated.
        verbose: If True, print the per-sequence input length in tokens (for a
            batch this is the padded length, not the total).
        gen_kwargs: Optional dict of `model.generate` keyword arguments merged
            over the defaults below (e.g. ``{"num_beams": 5}``).

    Returns:
        A KB holding the de-duplicated relations parsed from every returned
        sequence (num_return_sequences per input).

    Relies on the module-level `tokenizer` and `model` loaded earlier.
    """
    kb = KB()

    # Tokenize text
    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    # Keep inputs on the model's device so generation still works if the model is moved to GPU.
    model_inputs = model_inputs.to(model.device)
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")

    # Generate
    generation_settings = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3,
    }
    if gen_kwargs:
        generation_settings.update(gen_kwargs)
    generated_tokens = model.generate(
        **model_inputs,
        **generation_settings,
    )
    # Keep special tokens so the <triplet>/<subj>/<obj> markers survive decoding for the parser.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    # Create kb; KB.add_relation drops duplicates across beams and inputs.
    for sentence_pred in decoded_preds:
        for r in extract_relations_from_model_output(sentence_pred):
            kb.add_relation(r)

    return kb


In [6]:
# Read the article titles and generate kb
N_TITLES = 10  # leading titles sent to REBEL together as a single batch

article_info = pd.read_csv("articles.csv")
article_title = list(article_info['Title'].iloc[:N_TITLES])

kb = from_small_text_to_kb(article_title, verbose=True)
kb.print()

Num tokens: 50
Relations:
  {'head': 'tick', 'type': 'subclass of', 'tail': 'pathogens'}
  {'head': 'Anaplasma phagocytophilum', 'type': 'instance of', 'tail': 'bacteria'}
  {'head': 'Tick-borne encephalitis virus', 'type': 'has cause', 'tail': 'Borrelia burgdorferi sensu lato'}
  {'head': 'SNHG6', 'type': 'subclass of', 'tail': 'lncRNA'}
  {'head': 'SNHG6', 'type': 'instance of', 'tail': 'lncRNA'}
  {'head': 'lncRNA', 'type': 'use', 'tail': 'transcription regulation'}
  {'head': 'interleukin-18', 'type': 'part of', 'tail': 'monocytic cells'}
  {'head': 'monocytic cells', 'type': 'product or material produced', 'tail': 'interleukin-18'}
  {'head': 'interleukin-18', 'type': 'part of', 'tail': 'monocytic cell'}
  {'head': 'Metabolism-regulating non-coding RNAs', 'type': 'facet of', 'tail': 'breast cancer'}
  {'head': 'Metabolism-regulating non-coding RNA', 'type': 'facet of', 'tail': 'breast cancer'}
  {'head': 'Metabolism-regulating non-coding RNAs', 'type': 'medical condition treated',