### 2. Extraction Notebook

##### In this notebook, we:
1. use the latest versions (at the time of writing) of Python, PyTorch and Transformers, unlike the Data Notebook;
2. load the data that we preprocessed in the previous notebook;
3. set the generation parameters for the model;
4. define a post-processing function that shapes the triples from the REBEL output;
5. run the extraction;
6. save the extracted triples, entities and relations as separate fields of each record in a JSONL file.

In [1]:
# REBEL environment details.
# Report the interpreter the kernel is actually running: `!python --version` would
# report whichever `python` is first on the shell PATH, which can differ from the
# kernel's interpreter (e.g. with conda/venv kernels).
import json
import platform

import torch
import transformers

print('Python version: ', platform.python_version())
print('PyTorch version: ', torch.__version__)
print('Transformers version: ', transformers.__version__)
print('CUDA available: ', torch.cuda.is_available())

Python 3.10.9
Pythorch version:  2.0.1
Transformers version:  4.30.1


In [2]:
# Let's start by loading the data we preprocessed in the previous notebook.
def load_jsonl(path):
    """Read a JSON-lines file and return its records as a list of dicts.

    Blank lines are skipped. A trailing literal ``None`` is removed from each
    line before parsing: presumably the writer appended the return value of
    ``json.dump`` (always ``None``) after every object. The original
    ``rstrip('None\\n')`` stripped any trailing run of the characters
    N/o/n/e/newline (a character set, not a suffix) and only worked by accident.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            payload = raw_line.strip().removesuffix('None').strip()
            if not payload:
                continue
            records.append(json.loads(payload))
    return records


data = load_jsonl('data/preprocessed_data_for_extraction.jsonl')
print('Number of data instances: ', len(data))

Number of data instances:  917


In [3]:
# Load the REBEL tokenizer and model from the HuggingFace library.
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (much more slowly) on machines without CUDA.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "Babelscape/rebel-large"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)

In [4]:
# Generation settings passed to `model.generate` for every chunk.
# Two values deliberately differ from the defaults:
#   - num_beams=10: search over more candidate sequences, for more diverse results;
#   - num_return_sequences=10: keep all ten returned sequences rather than only
#     the best one, so long-tail triples are extracted as well.
gen_kwargs = dict(
    max_length=1024,
    length_penalty=0,
    num_beams=10,
    num_return_sequences=10,
)

In [5]:
# Post-processing: turn one linearized REBEL output sequence into triples.
def extract_triples(text):
    """Parse a decoded REBEL sequence into ``(head, relation, tail)`` tuples.

    The sequence uses the markers
    ``<triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation ...]``,
    where several (tail, relation) pairs may share one head. ``<s>``, ``</s>``
    and ``<pad>`` are removed before parsing; any tokens before the first marker
    are ignored.

    Args:
        text: one decoded sequence. It must be decoded with
            ``skip_special_tokens=False`` so the markers survive.

    Returns:
        A list of ``(head, relation, tail)`` string tuples, in order of
        appearance. Only complete triples (all three parts non-empty) are
        returned. Malformed or truncated generations (e.g. a sequence cut off
        after ``<subj> tail``) therefore do not yield partial triples or reuse a
        stale relation.
    """
    triplets = []

    def emit(head, rel, tail):
        # Single completeness check shared by every flush point.
        head, rel, tail = head.strip(), rel.strip(), tail.strip()
        if head and rel and tail:
            triplets.append((head, rel, tail))

    subject, relation, object_ = '', '', ''
    # 'x': before the first marker, 't': reading the head,
    # 's': reading the tail, 'o': reading the relation.
    current = 'x'
    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            # A new head starts: flush the pending triple and reset all parts.
            emit(subject, relation, object_)
            current = 't'
            subject, relation, object_ = '', '', ''
        elif token == "<subj>":
            # Another tail for the same head: flush the pending pair and reset
            # tail and relation, so a stale relation is never paired with a new tail.
            emit(subject, relation, object_)
            current = 's'
            relation, object_ = '', ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token
    emit(subject, relation, object_)
    return triplets

In [6]:
# Extraction time!
def extract_from_chunk(chunk_text):
    """Run REBEL on one text chunk.

    Returns ``(triples, entities, relations)`` as sorted, de-duplicated lists,
    merged across all returned beam sequences. Entities are the heads and tails
    of the extracted triples.
    """
    model_inputs = tokenizer(chunk_text, max_length=1024, padding=True, truncation=True, return_tensors='pt')
    # Follow the model's device instead of hard-coding 'cuda'.
    generated_tokens = model.generate(
        model_inputs["input_ids"].to(model.device),
        attention_mask=model_inputs["attention_mask"].to(model.device),
        **gen_kwargs,
    )
    # Keep special tokens: the <triplet>/<subj>/<obj> markers are needed by extract_triples.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    triple_set = set()  # a set removes triples repeated across beams
    for pred in decoded_preds:
        triple_set.update(extract_triples(pred))
    entity_set = {head for head, _, _ in triple_set} | {tail for _, _, tail in triple_set}
    relation_set = {rel for _, rel, _ in triple_set}
    # Sort so the saved output is reproducible: the iteration order of a set of
    # strings depends on per-process hash randomization.
    return sorted(triple_set), sorted(entity_set), sorted(relation_set)


for line in data:
    triples, entities, relations = extract_from_chunk(line["Chunk text"])
    line["Extracted Triples"] = triples
    line["Extracted Entities"] = entities
    line["Extracted Relations"] = relations
            

In [7]:
# Let's save the data in a JSONL file: one JSON object per line.
# `json.dumps` returns the serialized string. `json.dump` writes to the file and
# returns None, which previously appended a literal "None" after every object
# and produced invalid JSONL.
with open('data/preprocessed_data_with_REBEL_extracted_triples.jsonl', 'w', encoding='utf-8') as f:
    for record in data:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')