In [None]:
# Add rebel component https://github.com/Babelscape/rebel/blob/main/spacy_component.py
import spacy
import crosslingual_coreference
import requests
import re
import hashlib
import pandas as pd
from spacy import Language
from typing import List
from spacy.tokens import Doc, Span
from transformers import pipeline

In [None]:
def call_wiki_api(item, timeout=10):
    """Look up a Wikidata entity id for a free-text label.

    Sends a ``wbsearchentities`` query (English) to the Wikidata API and
    returns the id of the first search hit.

    Args:
        item: Label to search for.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The ``id`` field of the first search result, or the sentinel string
        ``'id-less'`` when the request fails or the response contains no
        usable result.
    """
    query = {
        "action": "wbsearchentities",
        "search": item,
        "language": "en",
        "format": "json",
    }
    try:
        # Hand the label to requests via ``params`` so it gets URL-encoded;
        # interpolating it into the URL breaks on labels containing
        # characters such as '&', '#' or '+'.
        response = requests.get(
            "https://www.wikidata.org/w/api.php",
            params=query,
            timeout=timeout,
        )
        data = response.json()
        # Return the first id (Could upgrade this in the future)
        return data['search'][0]['id']
    except Exception:
        # Best-effort lookup: any request/parse failure or empty result list
        # maps to the sentinel. ``Exception`` (not a bare ``except``) lets
        # KeyboardInterrupt / SystemExit propagate.
        return 'id-less'

def extract_triplets(text):
    """
    Parse REBEL-style generated text into a list of relation triplets.

    The text is a whitespace-separated token stream in which ``<triplet>``
    starts a head entity, ``<subj>`` starts a tail entity and ``<obj>`` starts
    a relation label. ``<s>``, ``<pad>`` and ``</s>`` markers are removed
    before parsing.

    Returns a list of dicts with the keys ``'head'``, ``'type'`` and ``'tail'``.
    """
    triplets = []
    # Accumulators keyed by parser state:
    #   't' -> head text, 's' -> tail text, 'o' -> relation text
    buffers = {'t': '', 's': '', 'o': ''}
    state = 'x'  # no marker seen yet; plain tokens are ignored in this state

    def emit():
        triplets.append({
            'head': buffers['t'].strip(),
            'type': buffers['o'].strip(),
            'tail': buffers['s'].strip(),
        })

    cleaned = text.strip()
    for special in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(special, "")

    for piece in cleaned.split():
        if piece == "<triplet>":
            state = 't'
            # A new head closes any triplet that is still pending.
            if buffers['o'] != '':
                emit()
                buffers['o'] = ''
            buffers['t'] = ''
        elif piece == "<subj>":
            state = 's'
            # Same head, new tail: flush the pending triplet first.
            if buffers['o'] != '':
                emit()
            buffers['s'] = ''
        elif piece == "<obj>":
            state = 'o'
            buffers['o'] = ''
        elif state in buffers:
            buffers[state] += ' ' + piece

    # Flush the trailing triplet only when all three parts were collected.
    if all(buffers[key] != '' for key in ('t', 'o', 's')):
        emit()

    return triplets


@Language.factory(
    "rebel",
    requires=["doc.sents"],
    assigns=["doc._.rel"],
    default_config={
        "model_name": "Babelscape/rebel-large",
        "device": 0,
    },
)
class RebelComponent:
    """spaCy pipeline component that extracts relation triplets per sentence.

    Each sentence of the Doc is passed through a ``text2text-generation``
    transformers pipeline, the generated text is parsed with
    ``extract_triplets`` and the surviving relations are stored in the
    ``doc._.rel`` dict, keyed by a SHA-1 hash of head + tail + relation type.
    """

    def __init__(
        self,
        nlp,
        name,
        model_name: str,
        device: int,
    ):
        """Load the generation pipeline and register the ``rel`` extension.

        Args:
            nlp: Passed in by the spaCy factory; not used here.
            name: Passed in by the spaCy factory; not used here.
            model_name: Model/tokenizer identifier handed to ``pipeline``.
            device: Device index handed to ``pipeline``.
        """
        assert model_name is not None, "model_name must not be None"
        self.triplet_extractor = pipeline("text2text-generation", model=model_name, tokenizer=model_name, device=device)
        # Cache of entity text -> wiki id so each label is looked up once.
        self.entity_mapping = {}
        # Register custom extension on the Doc
        if not Doc.has_extension("rel"):
            Doc.set_extension("rel", default={})

    def get_wiki_id(self, item: str):
        """Return the wiki id for ``item``, calling the API only on a cache miss."""
        cached = self.entity_mapping.get(item)
        if cached:
            return cached
        res = call_wiki_api(item)
        self.entity_mapping[item] = res
        return res

    def _generate_triplets(self, sent: Span) -> List[dict]:
        """Generate and parse the triplets for a single sentence span."""
        output_ids = self.triplet_extractor(sent.text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
        extracted_text = self.triplet_extractor.tokenizer.batch_decode(output_ids[0])
        extracted_triplets = extract_triplets(extracted_text[0])
        return extracted_triplets

    def set_annotations(self, doc: Doc, triplets: List[dict]):
        """Store the valid, not-yet-seen triplets in ``doc._.rel``.

        A triplet is skipped when it is a self-loop, when its head or tail is
        blank, or when either entity does not occur verbatim in ``doc.text``.
        """
        doc_text = doc.text
        for triplet in triplets:
            head, tail, relation = triplet['head'], triplet['tail'], triplet['type']

            # Remove self-loops (relationships that start and end at the entity)
            if head == tail:
                continue

            # A blank entity would trivially "occur" in any text; drop it.
            if not head or not tail:
                continue

            # Skip the relation unless both head and tail are present in the
            # text; sometimes the Rebel model hallucinates some entities.
            # Literal substring test: the entity strings are model output and
            # must not be interpreted as regex patterns (metacharacters such
            # as '(', '+', '.' would raise re.error or match the wrong text).
            if head not in doc_text or tail not in doc_text:
                continue

            index = hashlib.sha1("".join([head, tail, relation]).encode('utf-8')).hexdigest()
            if index not in doc._.rel:
                # Get wiki ids and store results
                doc._.rel[index] = {
                    "relation": relation,
                    "head_span": {'text': head, 'id': self.get_wiki_id(head)},
                    "tail_span": {'text': tail, 'id': self.get_wiki_id(tail)},
                }

    def __call__(self, doc: Doc) -> Doc:
        """Run relation extraction sentence by sentence and return the Doc."""
        for sent in doc.sents:
            sentence_triplets = self._generate_triplets(sent)
            self.set_annotations(doc, sentence_triplets)
        return doc

In [None]:
DEVICE = 0 # Number of the GPU, -1 if want to use CPU

# Add coreference resolution model
coref = spacy.load('en_core_web_sm', disable=['ner', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])
coref.add_pipe("xx_coref",
    config={"chunk_size": 2500,
            "chunk_overlap": 2,
            "device": DEVICE}
              )

# Define rel extraction model
# The parser is left enabled: the "rebel" factory declares requires=["doc.sents"].
# The component name is 'attribute_ruler' (as in the coref pipeline above);
# the previous spelling 'attribute_rules' did not match any component.
rel_ext = spacy.load('en_core_web_sm', disable=['ner', 'lemmatizer', 'attribute_ruler', 'tagger'])
rel_ext.add_pipe("rebel",
                 config={
                     'device':DEVICE,
                     'model_name':'Babelscape/rebel-large'} # Model used, will default to 'Babelscape/rebel-large' if not given
                )

In [None]:
# crosslingual_coreference implementation
def coref_res(text_series):
    """Apply the module-level ``coref`` pipeline to every text in a Series.

    Returns a Series (same index) holding each document's
    ``_.resolved_text``.
    """
    def _resolve(raw_text):
        return coref(raw_text)._.resolved_text

    return text_series.apply(_resolve)

# # choose minilm for speed/memory and info_xlm for accuracy
# predictor = Predictor(
#     language="en_core_web_sm", device=-1, model_name="minilm"
# )

In [None]:
def link_entities_text(text):
    """Extract the relations of one text into a flat DataFrame.

    Runs the module-level ``rel_ext`` pipeline on ``text`` and turns every
    entry of ``doc._.rel`` into one row.

    Args:
        text: The text to run relation extraction on.

    Returns:
        A DataFrame with the columns ``head_text``, ``head_wiki_id``,
        ``relation``, ``tail_text`` and ``tail_wiki_id``. If extraction
        raises, a message is printed and a single placeholder row filled
        with ``'rel_err'`` is returned instead.
    """
    try:
        ent_rel_lst = list(rel_ext(text)._.rel.values())
    except Exception:
        # Best-effort: one failing text must not abort the whole batch.
        # ``Exception`` rather than a bare ``except`` so that Ctrl-C
        # (KeyboardInterrupt) can still stop a long run.
        print("Could not extract relationships for text")
        ent_rel_lst = [{'relation': 'rel_err',
                        'head_span': {'text': 'rel_err', 'id': 'rel_err'},
                        'tail_span': {'text': 'rel_err', 'id': 'rel_err'}}]

    records = [
        {
            'head_text': rel['head_span']['text'],
            'head_wiki_id': rel['head_span']['id'],
            'relation': rel['relation'],
            'tail_text': rel['tail_span']['text'],
            'tail_wiki_id': rel['tail_span']['id'],
        }
        for rel in ent_rel_lst
    ]
    # Explicit column list keeps the schema stable even when no relation
    # was found and ``records`` is empty.
    return pd.DataFrame(
        records,
        columns=['head_text', 'head_wiki_id', 'relation', 'tail_text', 'tail_wiki_id'],
    )

def link_entities(text_series):
    """Run ``link_entities_text`` on every element of ``text_series``.

    Returns the result of ``Series.apply``, one entry per input text.
    """
    return text_series.apply(link_entities_text)

In [None]:
# Read the exported CSV, drop exact duplicate rows and renumber the index.
# Optional extra step, currently disabled: .drop('File', axis=1)
data_df = (
    pd.read_csv('E:\\GIT_REPOS\\LAB\\Literature_summary\\Test\\Entity_edgelist\\Input\\entomology-machine-learning-csv.csv')
    .drop_duplicates()
    .reset_index(drop=True)
)
data_df

In [None]:
# Clean the abstracts
# data_df = pd.read_csv('E:\GIT_REPOS\LAB\Literature_summary\TPN\Papers\\scopus.csv')

# Drop the rows whose abstract is the '[No abstract available]' placeholder.
# Filtering + reset_index (not in place) yields a fresh frame, so the column
# assignment below does not write into a slice of the original.
data_df = data_df[data_df["Abstract"] != '[No abstract available]'].reset_index(drop=True)

# Strip parentheses and quote characters in a single pass.
# `regex=True` is passed explicitly because the default of `regex` differs
# between pandas versions; inside a character class '(' and ')' are literal.
# NOTE(review): the previous code repeated the ' and " replacements twice;
# if the second copies were meant for typographic quotes, add those
# characters to the class below.
data_df["Abstract"] = (
    data_df["Abstract"]
    .str.replace("[()'\"]", '', regex=True)
    .astype(str)
)

In [None]:
win_size = 100
start_point = 0 # default to 0
entities_df_lst = []

# Work through the abstracts in windows of `win_size` rows: resolve
# coreferences, extract the relations, and keep one combined frame per window.
for i in range(start_point, len(data_df), win_size):
    coref_series = coref_res(text_series=data_df["Abstract"].iloc[i:i+win_size])
    print('coref done', i)

    link_entities_series = link_entities(text_series=coref_series)
    print('entity linking done', i)

    entities_df = pd.concat(link_entities_series.tolist())
    print('df create done', i)

    entities_df_lst.append(entities_df)
    print('df to list done', i, '\n')

# Stack all windows into one frame with a fresh 0..n-1 index.
all_entities_df = pd.concat(entities_df_lst, ignore_index=True)

# Identical rows are collapsed by value_counts(); the number of occurrences
# becomes the edge weight. The rename labels the counts column "count" when
# reset_index() leaves it named 0.
edge_lst_df = (
    all_entities_df
    .value_counts()
    .reset_index()
    .rename(columns={0: "count"})
)
edge_lst_df.to_csv('entity_weighted_edgelist_entemology_ML.csv')

In [None]:
# coref_series = coref_res(text_series=data_df["Abstract"][:1000])
# link_entities_series = link_entities(text_series=coref_series)
# all_entities_df = pd.concat(link_entities_series.tolist())