In [1]:
# Add rebel component https://github.com/Babelscape/rebel/blob/main/spacy_component.py
import spacy
import crosslingual_coreference
import requests
import re
import os
import hashlib
import pandas as pd
from spacy import Language
from typing import List
from spacy.tokens import Doc, Span
from transformers import pipeline

[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\GCM\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [2]:
def call_wiki_api(item, timeout=10):
    """Look up `item` with the Wikidata `wbsearchentities` API and return the first hit's id.

    Parameters
    ----------
    item : str
        Entity text to search for (English).
    timeout : float, optional
        Seconds to wait for the HTTP request before giving up (default 10).

    Returns
    -------
    str
        The id of the first search result, or the sentinel 'id-less' when the
        request fails, the response is not valid JSON, or the search has no hits.
        Only the first hit is used (could be upgraded to pick a better match).
    """
    try:
        # Let requests URL-encode the query: raw entity text containing '&', '#', '+'
        # or spaces would otherwise corrupt or truncate the query string.
        response = requests.get(
            "https://www.wikidata.org/w/api.php",
            params={"action": "wbsearchentities", "search": item,
                    "language": "en", "format": "json"},
            timeout=timeout,
        )
        data = response.json()
        return data['search'][0]['id']
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError):
        # Best-effort lookup: network errors, bad JSON and empty results map to the
        # sentinel, while KeyboardInterrupt/SystemExit still propagate.
        return 'id-less'

def extract_triplets(text):
    """
    Parse REBEL's linearized generation into a list of triplet dicts.

    The text is a sequence of ``<triplet> head <subj> tail <obj> relation`` groups,
    where one head may be followed by several ``<subj> tail <obj> relation`` pairs.
    The special tokens ``<s>``, ``</s>`` and ``<pad>`` are removed, and words that
    appear before the first marker are ignored.

    Returns a list of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts with
    whitespace-stripped values, in order of appearance.
    """
    def _make(head_txt, rel_txt, tail_txt):
        return {'head': head_txt.strip(), 'type': rel_txt.strip(), 'tail': tail_txt.strip()}

    found = []
    head = tail = rel = ''
    section = None  # which field the next plain word is appended to
    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for word in cleaned.split():
        if word == "<triplet>":
            # A new head begins: flush the pending triplet if a relation was collected.
            section = 'head'
            if rel != '':
                found.append(_make(head, rel, tail))
                rel = ''
            head = ''
        elif word == "<subj>":
            # Another tail for the same head: flush the pending triplet, keep the head.
            section = 'tail'
            if rel != '':
                found.append(_make(head, rel, tail))
            tail = ''
        elif word == "<obj>":
            section = 'rel'
            rel = ''
        elif section == 'head':
            head += ' ' + word
        elif section == 'tail':
            tail += ' ' + word
        elif section == 'rel':
            rel += ' ' + word
    # The last triplet has no closing marker, so emit it if all three parts exist.
    if head and rel and tail:
        found.append(_make(head, rel, tail))
    return found


@Language.factory(
    "rebel",
    requires=["doc.sents"],
    assigns=["doc._.rel"],
    default_config={
        "model_name": "Babelscape/rebel-large",
        "device": 0,
    },
)
class RebelComponent:
    """spaCy pipeline component that runs a REBEL text2text model on every sentence
    and stores the extracted relations in ``doc._.rel``.

    ``doc._.rel`` maps a SHA-1 hex key (of head, tail and relation) to
    ``{"relation": str, "head_span": {"text": str, "id": str},
    "tail_span": {"text": str, "id": str}}``, where ``id`` comes from
    ``call_wiki_api`` (first Wikidata search hit, or 'id-less').
    """

    def __init__(
        self,
        nlp,
        name,
        model_name: str,
        device: int,
    ):
        assert model_name is not None, "model_name must be provided"
        # Hugging Face text2text pipeline; `device` is a GPU index (the notebook uses -1 for CPU).
        self.triplet_extractor = pipeline("text2text-generation", model=model_name, tokenizer=model_name, device=device)
        # Cache entity text -> Wikidata id, kept for the lifetime of this component (across docs).
        self.entity_mapping = {}
        # Register custom extension on the Doc
        if not Doc.has_extension("rel"):
            Doc.set_extension("rel", default={})

    def get_wiki_id(self, item: str):
        """Return the Wikidata id for `item`, calling the API only on a cache miss."""
        if item not in self.entity_mapping:
            self.entity_mapping[item] = call_wiki_api(item)
        return self.entity_mapping[item]

    def _generate_triplets(self, sent: Span) -> List[dict]:
        """Run REBEL on one sentence and parse its decoded output into triplet dicts."""
        # NOTE(review): the nested ["generated_token_ids"]["output_ids"] indexing matches the
        # transformers version this was written against; re-check after upgrading transformers.
        output_ids = self.triplet_extractor(sent.text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
        extracted_text = self.triplet_extractor.tokenizer.batch_decode(output_ids[0])
        extracted_triplets = extract_triplets(extracted_text[0])
        return extracted_triplets

    def set_annotations(self, doc: Doc, triplets: List[dict]):
        """Store each usable triplet in ``doc._.rel``.

        A triplet is skipped when its head or tail is blank, when head == tail
        (self-loop), or when either entity does not occur verbatim in ``doc.text``
        (REBEL sometimes hallucinates entities). Identical triplets are stored once.
        """
        for triplet in triplets:
            head, tail, rel_type = triplet['head'], triplet['tail'], triplet['type']

            # Blank entities carry no information and would trigger a useless Wikidata lookup.
            if not head or not tail:
                continue

            # Remove self-loops (relationships that start and end at the same entity)
            if head == tail:
                continue

            # Skip the relation if either entity is missing from the text.
            # Literal substring test on purpose: passing entity text to re.search treated
            # '(', '$', '+', '?' ... as regex syntax, dropping real matches or raising re.error.
            if head not in doc.text or tail not in doc.text:
                continue

            # The separator keeps e.g. ("ab", "c") and ("a", "bc") from hashing to the same key.
            index = hashlib.sha1("\x1f".join([head, tail, rel_type]).encode('utf-8')).hexdigest()
            if index not in doc._.rel:
                # Get wiki ids and store results
                doc._.rel[index] = {"relation": rel_type,
                                    "head_span": {'text': head, 'id': self.get_wiki_id(head)},
                                    "tail_span": {'text': tail, 'id': self.get_wiki_id(tail)}}

    def __call__(self, doc: Doc) -> Doc:
        """Extract relations sentence by sentence and annotate `doc` in place."""
        for sent in doc.sents:
            sentence_triplets = self._generate_triplets(sent)
            self.set_annotations(doc, sentence_triplets)
        return doc

In [3]:
DEVICE = 0 # Number of the GPU, -1 if want to use CPU

# Add coreference resolution model
coref = spacy.load('en_core_web_sm', disable=['ner', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])
coref.add_pipe("xx_coref",
    config={"chunk_size": 2500,
            "chunk_overlap": 2,
            "device": DEVICE}
              )

# Define rel extraction model.
# The parser stays enabled because the "rebel" component requires doc.sents.
# 'attribute_ruler' was previously misspelled 'attribute_rules', which matched no pipe,
# so the attribute ruler kept running for nothing.
rel_ext = spacy.load('en_core_web_sm', disable=['ner', 'lemmatizer', 'attribute_ruler', 'tagger'])
rel_ext.add_pipe("rebel",
                 config={
                     'device': DEVICE,
                     'model_name': 'Babelscape/rebel-large'}  # Model used, will default to 'Babelscape/rebel-large' if not given
                )

error loading _jsonnet (this is expected on Windows), treating C:\Users\GCM\AppData\Local\Temp\tmp39_md_v9\config.json as plain json
Some weights of the model checkpoint at nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large were not used when initializing XLMRobertaModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaModel were not initialized from the model checkpoint at nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large and 

<__main__.RebelComponent at 0x25511426d60>

In [4]:
# crosslingual_coreference implementation
def coref_res(text_series, min_words=3):
    """Resolve coreferences in every text of a pandas Series.

    Texts with at most `min_words` whitespace-separated words are returned
    unchanged; longer ones are replaced by the `coref` pipeline's resolved text.
    The result keeps the input Series' index.
    """
    def _resolve(text):
        if len(text.split()) <= min_words:
            return text
        return coref(text)._.resolved_text

    resolved = text_series.apply(_resolve)
    return resolved

# # choose minilm for speed/memory and info_xlm for accuracy
# predictor = Predictor(
#     language="en_core_web_sm", device=-1, model_name="minilm"
# )

In [5]:
def link_entities_text(text):
    """Extract relations from one text with `rel_ext` and return them as a DataFrame.

    Columns, in order: head_text, head_wiki_id, relation, tail_text, tail_wiki_id;
    one row per relation stored in ``doc._.rel``. If extraction raises an
    ordinary exception, a single row filled with 'rel_err' is returned instead,
    so a long batch run can continue. KeyboardInterrupt/SystemExit are not caught.
    """
    try:
        ent_rel_lst = list(rel_ext(text)._.rel.values())
    except Exception as err:  # not bare: Ctrl+C must still stop the batch loop
        print(f"Could not extract relationships for text: {err!r}")
        ent_rel_lst = [{'relation': 'rel_err',
                        'head_span': {'text': 'rel_err', 'id': 'rel_err'},
                        'tail_span': {'text': 'rel_err', 'id': 'rel_err'}}]

    records = [
        (rel['head_span']['text'], rel['head_span']['id'], rel['relation'],
         rel['tail_span']['text'], rel['tail_span']['id'])
        for rel in ent_rel_lst
    ]
    entity_df = pd.DataFrame(
        records,
        columns=['head_text', 'head_wiki_id', 'relation', 'tail_text', 'tail_wiki_id'],
    )
    return entity_df

def link_entities(text_series):
    """Apply `link_entities_text` to each text, giving a Series of per-text DataFrames (same index)."""
    def _link(text):
        return link_entities_text(text)

    per_text_frames = text_series.apply(_link)
    return per_text_frames

In [6]:
# data_df = pd.read_csv('E:\\GIT_REPOS\\LAB\\Literature_summary\\Test\\Entity_edgelist\\Input\\entomology-machine-learning-csv.csv')
# data_df = data_df.drop('File', axis=1)
# coref.max_length = 2612577
# coref.max_length = 2612577
pwd = os.getcwd()
# Build the path portably; the old "\\Data\\NDPs\\" concatenation only worked on Windows.
# The trailing "" keeps a trailing separator so `data_path + "<file>"` still works.
data_path = os.path.join(os.path.dirname(pwd), "Data", "NDPs", "")
data_df = pd.read_feather(os.path.join(data_path, "docs.feather"))
data_df = data_df.drop_duplicates().reset_index(drop=True)
data_df["text"] = data_df["text"].astype(str)
# One row per line of text.
data_df['text'] = data_df['text'].str.split('\n')
data_df = data_df.explode('text')
# Drop blank lines: they yield no relations but would still cost a coref + REBEL pass each.
data_df = data_df[data_df['text'].str.strip() != ''].reset_index(drop=True)

In [7]:
# Distinct source files in the corpus, in order of first appearance.
pd.unique(data_df['file'])

array(['Bangladesh MDP.pdf', 'Bangladesh NDP.pdf', 'Bangladesh VNR.pdf',
       'Botswana NDP.pdf', 'Botswana VNR.pdf', 'Cameroon MDP.pdf',
       'Cameroon NDP.pdf', 'Eswatini NDP.pdf', 'Eswatini VNR.pdf',
       'Gambia NDP.pdf', 'Gambia VNR.pdf', 'Ghana MDP.pdf',
       'Ghana NDP.pdf', 'Ghana VNR.pdf', 'Kenya MDP.pdf', 'Kenya NDP.pdf',
       'Kenya VNR.pdf', 'Lao NDP.pdf', 'Lao VNR.pdf', 'Liberia MDP.pdf',
       'Liberia NDP.pdf', 'Liberia VNR.pdf', 'Malawi NDP.pdf',
       'Malawi VNR.pdf', 'Namibia MDP.pdf', 'Namibia NDP.pdf',
       'Namibia VNR.pdf', 'Nigeria NDP.pdf', 'Nigeria VNR.pdf',
       'Pakistan NDP.pdf', 'Pakistan VNR.pdf', 'Rwanda NDP.pdf',
       'Rwanda VNR.pdf', 'South Africa NDP.pdf', 'South Africa VNR.pdf',
       'Soutn Africa MDP.pdf', 'Sri Lanka NDP.pdf', 'Sri Lanka VNR.pdf',
       'Tanzania MDP.pdf', 'Tanzania NDP.pdf', 'Togo NDP.pdf',
       'Togo VNR.pdf', 'Zambia NDP.pdf', 'Zambia VNR.pdf',
       'Zimbabwe NDP.pdf', 'Zimbabwe VNR.pdf'], dtype=object)

In [8]:
# NOTE(review): bare expression; it is only displayed, never assigned or used.
# These three files precede 'Ghana MDP.pdf' in the unique-file list and are not in file_lst;
# presumably set aside for separate handling. Confirm, then assign it to a name or remove it.
['Gambia VNR.pdf','Botswana NDP.pdf','Cameroon MDP.pdf']

['Gambia VNR.pdf', 'Botswana NDP.pdf', 'Cameroon MDP.pdf']

In [9]:
# Files to process in this run: the tail of data_df['file'].unique() starting at
# 'Ghana MDP.pdf' (presumably earlier files were processed in a previous run).
# 'Soutn Africa MDP.pdf' matches the file name as it appears in the data, typo included.
file_lst = [
    'Ghana MDP.pdf', 'Ghana NDP.pdf', 'Ghana VNR.pdf',
    'Kenya MDP.pdf', 'Kenya NDP.pdf', 'Kenya VNR.pdf',
    'Lao NDP.pdf', 'Lao VNR.pdf',
    'Liberia MDP.pdf', 'Liberia NDP.pdf', 'Liberia VNR.pdf',
    'Malawi NDP.pdf', 'Malawi VNR.pdf',
    'Namibia MDP.pdf', 'Namibia NDP.pdf', 'Namibia VNR.pdf',
    'Nigeria NDP.pdf', 'Nigeria VNR.pdf',
    'Pakistan NDP.pdf', 'Pakistan VNR.pdf',
    'Rwanda NDP.pdf', 'Rwanda VNR.pdf',
    'South Africa NDP.pdf', 'South Africa VNR.pdf',
    'Soutn Africa MDP.pdf',
    'Sri Lanka NDP.pdf', 'Sri Lanka VNR.pdf',
    'Tanzania MDP.pdf', 'Tanzania NDP.pdf',
    'Togo NDP.pdf', 'Togo VNR.pdf',
    'Zambia NDP.pdf', 'Zambia VNR.pdf',
    'Zimbabwe NDP.pdf', 'Zimbabwe VNR.pdf',
]

In [None]:
# Calculate and save one weighted edge list per paper (file).
win_size = 100   # rows of text sent through coref + relation extraction per batch
start_point = 0  # row offset to resume the FIRST file in file_lst from; later files start at row 0
for file_idx, j in enumerate(file_lst):
    print('_______________________________\n\n', j)
    df = data_df[data_df['file'] == j]
    # Reset per file. The list used to be created once before this loop, so every paper's
    # saved edge list also contained the edges of all papers processed before it.
    entities_df_lst = []
    # Applying start_point to every file would silently skip the first rows of each later paper.
    first_row = start_point if file_idx == 0 else 0
    for i in range(first_row, len(df), win_size):
        coref_series = coref_res(text_series=df["text"].iloc[i:i+win_size])
        print('coref done', i, '-', i+win_size)
        link_entities_series = link_entities(text_series=coref_series)
        print('entity linking done', i, '-', i+win_size)
        entities_df = pd.concat(link_entities_series.tolist())
        print('df create done', i, '-', i+win_size)
        entities_df_lst.append(entities_df)
        print('df to list done', i, '-', i+win_size, '\n')
    if not entities_df_lst:
        # pd.concat([]) raises ValueError; there is nothing to save for this file.
        print('no rows to process for', j)
        continue
    all_entities_df = pd.concat(entities_df_lst, ignore_index=True)
    # The rename names the counts column "count" on pandas versions where value_counts
    # returns an unnamed Series (column 0 after reset_index); otherwise it is a no-op.
    edge_lst_df = all_entities_df.value_counts().reset_index().rename(columns={0: "count"})
    out_stem = 'entity_weighted_edgelist_ALL_' + os.path.splitext(j)[0]
    edge_lst_df.to_csv(out_stem + '.csv')
    edge_lst_df.to_feather(out_stem + '.feather')

_______________________________

 Ghana MDP.pdf
coref done 0 - 100
entity linking done 0 - 100
df create done 0 - 100
df to list done 0 - 100 

coref done 100 - 200
entity linking done 100 - 200
df create done 100 - 200
df to list done 100 - 200 

coref done 200 - 300
entity linking done 200 - 300
df create done 200 - 300
df to list done 200 - 300 

coref done 300 - 400
entity linking done 300 - 400
df create done 300 - 400
df to list done 300 - 400 

coref done 400 - 500
entity linking done 400 - 500
df create done 400 - 500
df to list done 400 - 500 

coref done 500 - 600
entity linking done 500 - 600
df create done 500 - 600
df to list done 500 - 600 

coref done 600 - 700
entity linking done 600 - 700
df create done 600 - 700
df to list done 600 - 700 

coref done 700 - 800
entity linking done 700 - 800
df create done 700 - 800
df to list done 700 - 800 

coref done 800 - 900
entity linking done 800 - 900
df create done 800 - 900
df to list done 800 - 900 

coref done 900 - 1000
en

In [None]:
# win_size = 1000
# start_point = 0 #46000 # default to 0
# entities_df_lst = []
# for i in range(start_point, len(data_df), win_size):
#     coref_series = coref_res(text_series=data_df["text"].iloc[i:i+win_size])
#     print('coref done', i, '-', i+win_size)
#     link_entities_series = link_entities(text_series=coref_series)
#     print('entity linking done', i, '-', i+win_size)
#     entities_df = pd.concat(link_entities_series.tolist())
#     print('df create done', i, '-', i+win_size)
#     entities_df_lst.append(entities_df)
#     print('df to list done', i, '-', i+win_size, '\n')
# all_entities_df = pd.concat(entities_df_lst)
# # all_entities_df.reset_index(drop=True, inplace=True)
# edge_lst_df = all_entities_df.value_counts().reset_index().rename(columns={0: "count"})
# edge_lst_df.to_csv('entity_weighted_edgelist_ALL.csv')
# edge_lst_df.to_feather('entity_weighted_edgelist_ALL.feather')