Dependencies and helper functions

In [12]:
import os
from typing import Dict

import pickle
from tqdm import tqdm
import seaborn as sns
import pandas as pd
from PyPDF2 import PdfReader
import networkx as nx
from pyvis.network import Network

from transformers import pipeline
from transformers import BartTokenizerFast, MBart50TokenizerFast
import ray

# Set at import time, before any tokenizer is instantiated further down.
# NOTE(review): presumably this controls the HuggingFace `tokenizers` library's
# internal parallelism (and its fork warning) - confirm against the library docs.
os.environ['TOKENIZERS_PARALLELISM'] = "true"

###################### READ FILES ######################
def process_pages(reader):
    """
    Concatenates the text of every page of a PDF into a single string.

    Args:
        reader: PDF reader exposing `pages`; every page must offer `extract_text()`.

    Returns:
        str: The text of all pages in order, each page followed by one space.
        Hyphen + newline sequences are removed (re-joining words broken across
        lines) and the remaining newlines are replaced by spaces.
    """
    page_texts = []
    for page_index in tqdm(range(len(reader.pages))):
        raw_text = reader.pages[page_index].extract_text()
        # first undo end-of-line hyphenation, then flatten what is left of the line breaks
        flattened = raw_text.replace('-\n', '').replace('\n', ' ')
        page_texts.append(flattened + ' ')

    return "".join(page_texts)

###################### CHUNKING ######################
# NOTE: the chunk size in tokens is int(tokenizer.model_max_length * MULTIPLIER);
# split_chunks and check_chunk both derive their limit from this constant.
MULTIPLIER = 0.25

def split_chunks(tokenizer, text) -> list:
    """
    Splits the input text into chunks of at most
    int(tokenizer.model_max_length * MULTIPLIER) tokens, moving every cut back
    to the start of a word so that chunks do not split over words.

    If a window contains no word boundary at all (e.g. one very long token
    run), the chunk is cut at the hard token limit instead, so the function
    always makes progress and terminates.

    Args:
        tokenizer: Tokenizer exposing `tokenize`, `convert_tokens_to_string`
            and `model_max_length`.
        text (str): The input text to be split into chunks.

    Returns:
        list: A list of text chunks (empty when the text produces no tokens).

    """
    # The separator is matched as a token prefix by the 'sentence_piece'
    # branch of adjust_chunk_end; only the prefix symbol differs per tokenizer.
    tokenizer_type = 'sentence_piece'
    separator = 'Ġ' if isinstance(tokenizer, BartTokenizerFast) else '▁'
    # never allow a zero-sized window, otherwise no progress could be made
    max_chunk_tokens = max(1, int(tokenizer.model_max_length * MULTIPLIER))

    tokens = tokenizer.tokenize(text)
    text_chunks = []
    chunk_start = 0
    while chunk_start < len(tokens):
        if len(tokens) - chunk_start > max_chunk_tokens:
            hard_end = chunk_start + max_chunk_tokens
            # prefer a cut on a word boundary; fall back to the hard limit
            aligned_end = _word_boundary_before(tokenizer_type, separator, tokens, chunk_start, hard_end)
            chunk_end = hard_end if aligned_end is None else aligned_end
        else:
            # the remainder fits in one chunk
            chunk_end = len(tokens)

        current_chunk_tokens = tokens[chunk_start : chunk_end]
        current_chunk_text = tokenizer.convert_tokens_to_string(current_chunk_tokens)
        if not check_chunk(tokenizer, current_chunk_text):
            # check_chunk re-tokenizes the decoded text; when that count exceeds
            # the limit, drop one more word (if the chunk has more than one)
            shorter_end = _word_boundary_before(tokenizer_type, separator, tokens, chunk_start, chunk_end - 1)
            if shorter_end is not None:
                chunk_end = shorter_end
                current_chunk_tokens = tokens[chunk_start : chunk_end]
                current_chunk_text = tokenizer.convert_tokens_to_string(current_chunk_tokens)

        text_chunks.append(current_chunk_text)
        # chunk_end > chunk_start is guaranteed above, so the loop terminates
        chunk_start = chunk_end

    return text_chunks

def _word_boundary_before(tokenizer_type, separator, tokens, chunk_start, candidate_end):
    """Returns a word-aligned end index strictly after chunk_start, or None if there is none."""
    try:
        aligned_end = adjust_chunk_end(tokenizer_type, separator, tokens, candidate_end)
    except IndexError:
        # the backward scan ran off the token list without finding a boundary
        return None
    # an index at/before chunk_start (or a wrapped negative one) would give an empty or bogus chunk
    return aligned_end if aligned_end > chunk_start else None

def adjust_chunk_end(tokenizer_type:str, separator:str, tokens:list, chunk_end:int) -> int:
    """
    Moves a tentative chunk end backwards until it points at a token that
    starts a word, so that slicing `tokens[:chunk_end]` does not cut a word.

    Args:
        tokenizer_type (str): 'word_piece' (continuation pieces are prefixed
            with '##'; `separator` is ignored) or 'sentence_piece' (the first
            piece of a word is prefixed with `separator`).
        separator (str): Word-start prefix used for 'sentence_piece'.
        tokens (list): Token strings.
        chunk_end (int): Index of the first token after the tentative chunk.

    Returns:
        int: The adjusted index. The scan stops at index 0, so the result is
        never negative for a non-negative input; 0 means no word boundary was
        found before `chunk_end`.

    Raises:
        ValueError: If `tokenizer_type` is not one of the two supported values.
    """
    if tokenizer_type == 'word_piece':
        # a token starting with '##' continues the previous word
        continuation_prefix = '##'
        while chunk_end > 0 and tokens[chunk_end].startswith(continuation_prefix):
            chunk_end -= 1
    elif tokenizer_type == 'sentence_piece':
        # a token NOT starting with the separator continues the previous word;
        # startswith also copes with empty tokens, unlike indexing [0]
        while chunk_end > 0 and not tokens[chunk_end].startswith(separator):
            chunk_end -= 1
    else:
        # ValueError is a subclass of Exception, so existing handlers still match
        raise ValueError('Invalid tokenizer type')

    return chunk_end

def check_chunk(tokenizer, chunk:str) -> bool:
    """
    Tells whether a text chunk fits into the token budget
    int(tokenizer.model_max_length * MULTIPLIER).

    Args:
        tokenizer: Tokenizer providing `tokenize` and `model_max_length`.
        chunk (str): The text chunk to be checked.

    Returns:
        bool: True when the chunk tokenizes to at most the budget, False otherwise.

    """
    token_count = len(tokenizer.tokenize(chunk))
    token_budget = int(tokenizer.model_max_length * MULTIPLIER)
    return token_count <= token_budget
    
###################### INFERENCE ######################
# @ray.remote(num_gpus=0.5)
class MRebelExtractor:
    """
    Callable that runs the 'Babelscape/mrebel-large' pipeline on batches of
    text and returns the decoded output with special tokens kept.

    The no-argument constructor keeps the previous hard-coded settings, so the
    class can still be passed as-is to `map_batches`.
    """

    # Decoder start token used for every generation call.
    # NOTE(review): presumably the id of the triplet-language token of this
    # checkpoint - confirm against the model card.
    DECODER_START_TOKEN_ID = 250058

    def __init__(self, device="cuda:0", src_lang="fr_XX", max_length=1024):
        """
        Args:
            device: Device handed to the pipeline (e.g. "cuda:0" or "cpu").
            src_lang: Source language code passed with every call.
            max_length: `max_length` handed to the pipeline.
        """
        self.src_lang = src_lang
        self.model = pipeline('translation_xx_to_yy',
                               model='Babelscape/mrebel-large',
                               tokenizer='Babelscape/mrebel-large',
                               max_length=max_length,
                               device=device
                               )

    def __call__(self, batch: Dict[str, str]) -> Dict[str, list]:
        """
        Extracts raw triplet text for every entry of `batch['item']`.

        Args:
            batch: Mapping whose 'item' entry is an iterable of input texts.

        Returns:
            dict: {'output': [...]} with one element per input text; each
            element is the list returned by `batch_decode` for that text
            (a one-element list holding the decoded string).
        """
        extracted_texts = []
        for text in batch['item']:
            seqs = self.model(
                text,
                decoder_start_token_id=self.DECODER_START_TOKEN_ID,
                src_lang=self.src_lang,
                tgt_lang="<triplet>",
                return_tensors=True,
                return_text=False)

            # decode through the tokenizer directly instead of letting the
            # pipeline return text, so the special tokens stay in the string
            token_ids = seqs[0]["translation_token_ids"]
            extracted_text = self.model.tokenizer.batch_decode([token_ids])
            extracted_texts.append(extracted_text)

        return {'output': extracted_texts}
    
###################### EXTRACT TRIPLES ######################
def extract_triplets_typed(text):
    """
    Parses decoded model output into typed triplets.

    The markers "<s>", "<pad>", "</s>", "tp_XX" and "__en__" are removed, the
    rest is split on whitespace and fed through a small state machine:
    "<triplet>"/"<relation>" starts a new head, any other "<...>" token is a
    type tag that closes the text collected so far, and plain tokens are
    appended to the head, tail or relation text depending on the state.

    Args:
        text (str): Decoded output string, special tokens included.

    Returns:
        list: Dicts with the keys 'head', 'head_type', 'type', 'tail', 'tail_type'.
    """
    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>", "tp_XX", "__en__"):
        cleaned = cleaned.replace(marker, "")

    found = []
    # mode: 'x' nothing started, 't' collecting head, 's' collecting tail, 'o' collecting relation
    mode = 'x'
    head = head_type = rel = tail = tail_type = ''

    def current_triplet():
        # snapshot of the fields as they are at the moment of the call
        return {'head': head.strip(), 'head_type': head_type, 'type': rel.strip(),'tail': tail.strip(), 'tail_type': tail_type}

    for piece in cleaned.split():
        if piece in ("<triplet>", "<relation>"):
            mode = 't'
            if rel != '':
                found.append(current_triplet())
                rel = ''
            head = ''
        elif piece.startswith("<") and piece.endswith(">"):
            if mode in ('t', 'o'):
                # type tag after head text (or after a relation): a tail follows
                mode = 's'
                if rel != '':
                    found.append(current_triplet())
                tail = ''
                head_type = piece[1:-1]
            else:
                # type tag after tail text: the relation follows
                mode = 'o'
                tail_type = piece[1:-1]
                rel = ''
        elif mode == 't':
            head += ' ' + piece
        elif mode == 's':
            tail += ' ' + piece
        elif mode == 'o':
            rel += ' ' + piece

    # flush the last triplet only when every field has been filled
    if '' not in (head, rel, tail, tail_type, head_type):
        found.append(current_triplet())
    return found

###################### CREATE KNOWLEDGE BASE ######################
import wikipedia
import asyncio

class KB():
    """
    In-memory knowledge base holding typed entities and the relations
    (head, type, tail) between them.
    """

    def __init__(self):
        self.entities = set()     # entity titles
        self.ent_type_map = {}    # entity title -> entity type (the last type seen wins)
        self.ent_types = set()    # every entity type seen so far
        self.relations = []       # relation dicts, de-duplicated case-insensitively

    def are_relations_equal(self, r1, r2):
        """Returns True when head, type and tail of both relations match, ignoring case."""
        return all(r1[attr].lower() == r2[attr].lower() for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        """Returns True when a relation equal to `r1` is already stored (linear scan)."""
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r1):
        """
        Hook called when `r1` duplicates a stored relation.

        Relations carry nothing that needs merging, so this intentionally does
        nothing; the previous implementation only looked the duplicate up and
        discarded it.
        """
        return None

    def add_entity(self, e):
        """Registers the entity's title and type; `e` needs the keys 'title' and 'type'."""
        self.entities.add(e["title"])
        self.ent_types.add(e["type"])
        self.ent_type_map[e["title"]] = e["type"]

    async def get_wikipedia_data(self, candidate_entity):
        """
        Looks `candidate_entity` up with `wikipedia.page` in the default executor.

        Returns:
            dict with 'title', 'url' and 'summary', or None when the lookup fails.
        """
        try:
            page = await asyncio.get_running_loop().run_in_executor(None, wikipedia.page, candidate_entity, False)
            entity_data = {
                "title": page.title,
                "url": page.url,
                "summary": page.summary
            }
            return entity_data
        except Exception:
            # best effort: any lookup failure means "no data". Not a bare
            # `except`, so task cancellation and KeyboardInterrupt still propagate.
            return None

    def bad_entity_check(self, entity):
        """
        Returns True when the entity should be discarded: it is empty or more
        than a quarter of its characters are digits.
        """
        if not entity:
            # an empty name has no digit ratio (would divide by zero) and is unusable
            return True

        digit_ratio = sum(c.isdigit() for c in entity) / len(entity)
        return digit_ratio > 0.25

    def capital_count(self, string):
        """Counts the upper-case characters of `string`."""
        return sum(c.isupper() for c in string)

    # async def add_relation(self, r):
    def add_relation(self, r):
        """
        Adds relation `r` and its two entities unless it is filtered out.

        `r` needs the keys 'head', 'tail', 'type', 'head_type' and 'tail_type'.
        Skipped are self-references (case-insensitive), entities shorter than
        three characters and entities rejected by `bad_entity_check`. The
        entities are registered even when the relation itself is a duplicate.
        """
        candidate_entities = [r["head"], r["tail"]]
        candidate_entity_types = [r["head_type"], r["tail_type"]]
        #filter self-reference
        if candidate_entities[0].lower() == candidate_entities[1].lower():
            return
        #filter 1 or 2 letter entities
        if any(len(ent) < 3 for ent in candidate_entities):
            return
        if any(self.bad_entity_check(ent) for ent in candidate_entities):
            return

        entities = candidate_entities # offline: names are used as-is, no wikipedia lookup

        # manage new entities
        for e, ent_type in zip(entities, candidate_entity_types):
            ent = {"title": e,
                   "type": ent_type,
                   "url": '',
                   "summary": ''}
            self.add_entity(ent)

        # write the resolved names back; in offline mode these are the original values
        r["head"] = entities[0]
        r["tail"] = entities[1]

        # manage new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        """Prints all entities followed by all relations."""
        print("Entities:")
        for e in self.entities:
            print(f"  {e}")
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

###################### CREATE KNOWLEDGE GRAPH ######################
def save_network_html(kb, filename="network.html"):
    """
    Renders the knowledge base as an interactive pyvis graph and writes it to
    `filename`.

    Every entity becomes a node coloured by its entity type and every relation
    becomes a directed edge labelled with the relation type.

    Args:
        kb: Object exposing `entities`, `ent_types`, `ent_type_map` and
            `relations` (dicts with 'head', 'tail' and 'type').
        filename (str): Path of the HTML file to write.
    """
    # local import: `itertools` is not among the notebook's top-level imports
    from itertools import cycle

    net = Network(directed=True, width="1920px", height="1080px", bgcolor="#eeeeee")

    # One colour per entity type. The palette is cycled so that having more
    # types than palette colours cannot leave a type without a colour, and the
    # types are sorted so the assignment does not depend on set iteration order.
    palette = sns.color_palette().as_hex()
    colors = dict(zip(sorted(kb.ent_types, key=str), cycle(palette)))

    # nodes
    for e in kb.entities:
        net.add_node(e,
                     shape="circle",
                     color=colors[kb.ent_type_map[e]]
                     )

    # edges
    for r in kb.relations:
        net.add_edge(r["head"], r["tail"],
                    title=r["type"], label=r["type"])

    # save network
    net.repulsion(
        node_distance=250,
        central_gravity=0.2,
        spring_length=250,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show(filename, notebook=False)

Read PDFs

In [None]:
f_dir = './irsn/corpus/'
# Keep only files with a .pdf extension (case-insensitive) rather than any
# name that merely contains "pdf"; sort because os.listdir order is arbitrary
# and the chunk indices built later depend on the file order.
dir_list = sorted(name for name in os.listdir(f_dir) if name.lower().endswith('.pdf'))
# len(dir_list)
corpus_dict = {
    'text': [],
    'file_name': []
}
for name in dir_list:
    print(f'Processing {name}')
    reader = PdfReader(os.path.join(f_dir, name))
    text = process_pages(reader)
    # append both columns together, after the file was read, so they stay aligned
    corpus_dict['file_name'].append(name)
    corpus_dict['text'].append(text)

corpus_df = pd.DataFrame(corpus_dict)
corpus_df.head()

In [None]:
# Load previously pickled results
# with open('./irsn_corpus_df.pkl', 'rb') as handle:
#     corpus_df = pd.read_pickle(handle)

Chunk documents

In [5]:
tokenizer = MBart50TokenizerFast.from_pretrained('Babelscape/mrebel-large')

# chunk_dict: position of a chunk in `chunks` -> name of the file it came from
chunk_dict = dict()
chunks = list()
chunk_length = 0  # number of chunks collected so far
documents = zip(corpus_df.file_name.values, corpus_df.text.values)
for file_name, text in tqdm(documents, total=len(corpus_df.text.values)):
    text_chunks = split_chunks(tokenizer, text)
    # the new chunks take the next free positions, starting at chunk_length
    chunk_dict.update((chunk_length + offset, file_name) for offset in range(len(text_chunks)))
    chunks += text_chunks
    chunk_length = len(chunks)

len(chunks)

Loading the tokenizer from the `special_tokens_map.json` and the `added_tokens.json` will be removed in `transformers 5`,  it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again. You will see the new `added_tokens_decoder` attribute that will store the relevant information.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  0%|          | 0/58 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (12774 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 58/58 [00:18<00:00,  3.15it/s]


4487

Run inference

In [6]:
BATCH_SIZE = 8

if not ray.is_initialized():
    ray.init(num_cpus=16, num_gpus=1, log_to_driver=False) #num_gpus is the hardware count
    ray.data.DataContext.get_current().execution_options.verbose_progress = True

try:
    ds = ray.data.from_items(chunks)
    # ds = ray.data.from_pandas([chunks_df,])

    extracts = ds.map_batches(
        MRebelExtractor,
        # batch_format="pandas",
        num_gpus=0.5, # per actor !!!
        batch_size=BATCH_SIZE,
        compute=ray.data.ActorPoolStrategy(size=2),
        )
    # NOTE: take() without an argument returns only the first rows (Ray's
    # default limit); use the line below to process every chunk.
    # results = extracts.take(ds.count()) # to process all
    results = extracts.take()
finally:
    # Always release the Ray resources. A failure is no longer printed and
    # swallowed: it propagates, so later cells cannot silently run with a
    # missing or stale `results`.
    if ray.is_initialized():
        ray.shutdown()

2023-11-21 10:19:33,217	INFO worker.py:1636 -- Started a local Ray instance.

Learn more here: https://docs.ray.io/en/master/data/faq.html#migrating-to-strict-mode[0m
2023-11-21 10:19:34,174	INFO dataset.py:2087 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2023-11-21 10:19:34,177	INFO streaming_executor.py:91 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(MRebelExtractor)]
2023-11-21 10:19:34,178	INFO streaming_executor.py:92 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=True)
2023-11-21 10:19:34,194	INFO actor_pool_map_operator.py:114 -- MapBatches(MRebelExtractor): Waiting for 2 pool actors to start...


- MapBatches(MRebelExtractor) 1:   0%|          | 0/200 [00:00<?, ?it/s]

Running 0:   0%|          | 0/200 [00:00<?, ?it/s]

2023-11-21 10:20:17,016	INFO streaming_executor.py:149 -- Shutting down <StreamingExecutor(Thread-7, started daemon 140036175546112)>.


In [None]:
# Load previously pickled results
# with open('./irsn_rebel_output.pkl', 'rb') as handle:
#     results = pickle.load(handle)

Parse Triples

In [7]:
# Each result row holds the decoded pieces for one chunk under 'output';
# join them into one string and parse it into typed triplets.
extracted_triplets = [
    triplet
    for result in results
    for triplet in extract_triplets_typed(' '.join(result['output']))
]

triplets_df = pd.DataFrame.from_dict(extracted_triplets, orient='columns')
triplets_df.head()
# triplets_df.groupby(['type']).size()

Unnamed: 0,head,head_type,type,tail,tail_type
0,2010,time,point in time,janvier 2010,date
1,Finlande,loc,diplomatic relation,États-Unis,loc
2,États-Unis,loc,diplomatic relation,Finlande,loc
3,Olkiluoto,loc,country,Finlande,loc
4,2010,time,point in time,janvier 2010,date


Construct Knowledge Base

In [9]:
kb = KB()
for _, entry in tqdm(triplets_df.iterrows(), total=triplets_df.shape[0]):
    relation = {"head": entry['head'],
                "type": entry['type'],
                "tail": entry['tail'],
                "head_type": entry["head_type"],
                # was entry["head_type"]: every tail inherited the head's type
                "tail_type": entry["tail_type"]}
    kb.add_relation(relation)

100%|██████████| 35/35 [00:00<00:00, 14838.84it/s]


Make Knowledge Graph

In [13]:
# Render the knowledge base built above as an HTML graph written to `filename`.
filename = "kg_test.html"
save_network_html(kb, filename=filename)

kg_test.html
