In [3]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import re

def extract_triplets_typed(text):
    """Parse one decoded mREBEL generation into typed relation triplets.

    The parser expects sequences shaped like::

        <triplet> head <head_type> tail <tail_type> relation
                       <head_type> tail2 <tail_type2> relation2   # same head
        <triplet> head3 ...

    ``<relation>`` is treated like ``<triplet>``. Any other ``<...>`` token is
    a type tag. The markers ``<s>``, ``</s>``, ``<pad>``, ``tp_XX`` and
    ``__en__`` are stripped before parsing.

    Args:
        text: One decoded sequence (special tokens kept).

    Returns:
        A list of dicts with keys ``head``, ``head_type``, ``type`` (the
        relation label), ``tail`` and ``tail_type``. It is empty when nothing
        parses.
    """
    triplets = []
    # Parser states (names kept from the mREBEL reference parser):
    #   'x' -> initial state; plain tokens are ignored
    #   't' -> collecting head tokens into `subject`
    #   's' -> head type seen; collecting tail tokens into `object_`
    #   'o' -> tail type seen; collecting relation tokens into `relation`
    current = 'x'
    subject, relation, object_, object_type, subject_type = '', '', '', '', ''

    def emit():
        # Reads the enclosing variables at call time.
        triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),
                         'tail': object_.strip(), 'tail_type': object_type})

    cleaned = (text.strip()
               .replace("<s>", "").replace("<pad>", "").replace("</s>", "")
               .replace("tp_XX", "").replace("__en__", ""))
    for token in cleaned.split():
        if token == "<triplet>" or token == "<relation>":
            # A new head begins: flush the pending triplet, if any.
            current = 't'
            if relation != '':
                emit()
                relation = ''
            subject = ''
        elif token.startswith("<") and token.endswith(">"):
            if current == 't' or current == 'o':
                # Head-type tag, either right after a head or starting another
                # tail for the same head.
                current = 's'
                if relation != '':
                    emit()
                    # Reset so a truncated/malformed tail group that never gets
                    # its own relation is not emitted again with this stale one.
                    relation = ''
                object_ = ''
                subject_type = token[1:-1]
            else:
                # Tail-type tag: the relation tokens follow.
                current = 'o'
                object_type = token[1:-1]
                relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token
    # Flush the last triplet only if every field was filled.
    if subject != '' and relation != '' and object_ != '' and object_type != '' and subject_type != '':
        emit()
    return triplets

# Load the mREBEL-32 tokenizer and model.
# The source language is English ("en_XX") and "tp_XX" is the target-side
# language token. To use another language, swap the first token of the input
# or pass a different supported src_lang. Catalan ("ca_XX") and Greek ("el_EL")
# were not part of mBART pretraining, so they need this workaround instead:
#     tokenizer._src_lang = "ca_XX"
#     tokenizer.cur_lang_code_id = tokenizer.convert_tokens_to_ids("ca_XX")
#     tokenizer.set_src_lang_special_tokens("ca_XX")
tokenizer = AutoTokenizer.from_pretrained(
    "Babelscape/mrebel-large-32",
    src_lang="en_XX",
    tgt_lang="tp_XX",
)
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/mrebel-large-32")

# Beam-search settings unpacked into model.generate().
gen_kwargs = dict(
    max_length=256,
    length_penalty=0,
    num_beams=3,
    num_return_sequences=3,  # return every beam, not just the best one
    forced_bos_token_id=None,
)



def get_pos_from_triplets(triplets, text, idx=None):
    """Locate each triplet's head and tail as character spans in ``text``.

    Heads are assigned left to right. Each head takes its first occurrence
    that starts after the previously chosen head start. If there is none, it
    takes its last occurrence. Each tail takes the occurrence whose start is
    closest to the chosen head's end; ties keep the earliest occurrence.
    Triplets whose head or tail is empty, or whose head or tail does not occur
    literally in ``text``, are dropped.

    Args:
        triplets: Dicts with ``head``, ``head_type``, ``type``, ``tail`` and
            ``tail_type`` keys (the shape built by ``extract_triplets_typed``).
        text: The sentence to search.
        idx: Optional identifier. When not None, it is stored under
            ``text_idx`` as the first key of every output dict.

    Returns:
        A list of dicts with ``head_start``, ``head_end``, ``head_type``,
        ``head``, ``tail_start``, ``tail_end``, ``tail_type``, ``tail`` and
        ``relation``. Offsets are character positions with an exclusive end.
    """
    triples = []
    min_start_head = -1
    for triple in triplets:
        head = triple["head"]
        tail = triple["tail"]
        if not head or not tail:
            # An empty pattern would "match" at every offset of the text.
            continue

        # Entities are literal text, not regular expressions ("C++", "X (Y)").
        matches_head = [(m.start(), m.end()) for m in re.finditer(re.escape(head), text)]
        if not matches_head:
            continue
        later = [span for span in matches_head if span[0] > min_start_head]
        if later:
            start_pos_head, end_pos_head = later[0]
            min_start_head = start_pos_head
        else:
            # No occurrence after the previous head: use the last one rather
            # than another triplet's stale position.
            start_pos_head, end_pos_head = matches_head[-1]

        matches_tail = [(m.start(), m.end()) for m in re.finditer(re.escape(tail), text)]
        if not matches_tail:
            continue
        # min() keeps the first of equally distant candidates.
        start_pos_tail, end_pos_tail = min(matches_tail, key=lambda span: abs(span[0] - end_pos_head))

        output = {} if idx is None else {"text_idx": idx}
        output.update({
            "head_start": start_pos_head,
            "head_end": end_pos_head,
            "head_type": triple["head_type"],
            "head": head,
            "tail_start": start_pos_tail,
            "tail_end": end_pos_tail,
            "tail_type": triple["tail_type"],
            "tail": tail,
            "relation": triple["type"],
        })
        triples.append(output)
    return triples


In [4]:
import csv

# newline="" lets the csv module handle line endings itself, so quoted
# sentences that contain line breaks are read back intact.
with open("../data/sentences.csv", "r", encoding="utf-8", newline="") as f:
    reader = csv.DictReader(f, delimiter=",")
    data = list(reader)

# The extraction loop reads the "id" and "sentence" columns of every row.
assert {"id", "sentence"} <= set(reader.fieldnames or []), \
    f"sentences.csv must have 'id' and 'sentence' columns, got {reader.fieldnames}"

In [5]:
import csv
import os

from tqdm import tqdm

# Destination for the deduplicated, span-annotated triples.
OUTPUT_CSV = "../results/mrebel/output-32.csv"

output = []
for sample in tqdm(data, total=len(data)):
    text_idx = sample["id"]
    text = sample["sentence"]
    model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors="pt")
    generated_tokens = model.generate(
        model_inputs["input_ids"].to(model.device),
        attention_mask=model_inputs["attention_mask"].to(model.device),
        decoder_start_token_id=tokenizer.convert_tokens_to_ids("tp_XX"),
        **gen_kwargs,
    )
    # Keep special tokens: extract_triplets_typed relies on the <triplet> and
    # <type> tags to split heads, tails and relations.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    # gen_kwargs asks for several sequences per sentence, and they often repeat
    # triples. Keep only the first copy of each located triple.
    triples = []
    for prediction in decoded_preds:
        located = get_pos_from_triplets(extract_triplets_typed(prediction), text, text_idx)
        for triplet in located:
            if triplet not in triples:
                triples.append(triplet)
    output.extend(triples)

if output:
    print(output[0])
    os.makedirs(os.path.dirname(OUTPUT_CSV), exist_ok=True)
    # newline="" stops the writer's "\r\n" row terminator from becoming
    # "\r\r\n" (blank rows) on Windows. `with` closes the file even on error.
    with open(OUTPUT_CSV, "w", encoding="utf-8", newline="") as a_file:
        dict_writer = csv.DictWriter(a_file, output[0].keys())
        dict_writer.writeheader()
        dict_writer.writerows(output)
else:
    print(f"No triples were located in {len(data)} sentences; {OUTPUT_CSV} was not written.")

  0%|                                                                                           | 0/55 [00:00<?, ?it/s]

KeyError: 'type_head'