In [1]:
import spacy
# import crosslingual_coreference

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
print(sys.executable)


c:\Users\layas\anaconda3\envs\Python310\python.exe


In [3]:
# Add rebel component https://github.com/Babelscape/rebel/blob/main/spacy_component.py
import requests
import re
import hashlib
from spacy import Language
from typing import List

from spacy.tokens import Doc, Span

from transformers import pipeline

def call_wiki_api(item):
  print("call_wiki_api: ")
  try:
    url = f"https://www.wikidata.org/w/api.php?action=wbsearchentities&search={item}&language=en&format=json"
    data = requests.get(url).json()
    # Return the first id (Could upgrade this in the future)
    print(data['search'][0]['id'])
    return data['search'][0]['id']
  except:
    return 'id-less'

def extract_triplets(text):
    print("extract_triplets: ")
    """
    Function to parse the generated text and extract the triplets
    """
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("", "").replace("", "").replace("", "").split():
        print("Token:",token)
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(),
                                 'type': relation.strip(),
                                 'tail': object_.strip()})
                print("triplets1:",triplets)
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
                print("triplets2:",triplets)
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
        print("triplets3:",triplets)
    print("Final triplet:",triplets)
    return triplets


@Language.factory(
    "rebel",
    requires=["doc.sents"],
    assigns=["doc._.rel"],
    default_config={
        "model_name": "Babelscape/rebel-large",
        "device": 0,
    },
)
class RebelComponent:
    def __init__(
        self,
        nlp,
        name,
        model_name: str,
        device: int,
    ):        
        assert model_name is not None, ""
        self.triplet_extractor = pipeline("text2text-generation", model=model_name, tokenizer=model_name, device=device)
        self.entity_mapping = {}
        # Register custom extension on the Doc
        if not Doc.has_extension("rel"):
          Doc.set_extension("rel", default={})

    def get_wiki_id(self, item: str):
        print("get_wiki_id:")
        mapping = self.entity_mapping.get(item)
        if mapping:
          print("mapping:",mapping)
          return mapping
        else:
          res = call_wiki_api(item)
          self.entity_mapping[item] = res
          print("Res",res)
          return res

    def _generate_triplets(self, sent: Span) -> List[dict]:
        output_ids1 = self.triplet_extractor(sent.text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
        extracted_text1 = self.triplet_extractor.tokenizer.batch_decode(output_ids1[0])
        print("1:",extracted_text1)
        extracted_triplets1 = extract_triplets(extracted_text1[0])
        print("1:",extracted_triplets1)
        return extracted_triplets1
 

    def set_annotations(self, doc: Doc, triplets: List[dict]):
        print("set_annotations")
        for triplet in triplets:
            print("set_annotations: triplet:",triplet)
            # Remove self-loops (relationships that start and end at the entity)
            if triplet['head'] == triplet['tail']:
                continue

            # Use regex to search for entities
            head_span = re.search(triplet["head"], doc.text)
            print("Heads:",head_span,"\n")
            tail_span = re.search(triplet["tail"], doc.text)
            print("Tails:",tail_span,"\n")

            # Skip the relation if both head and tail entities are not present in the text
            # Sometimes the Rebel model hallucinates some entities
            if not head_span or not tail_span:
              continue

            index = hashlib.sha1("".join([triplet['head'], triplet['tail'], triplet['type']]).encode('utf-8')).hexdigest()
            if index not in doc._.rel:
                # Get wiki ids and store results
                doc._.rel[index] = {"relation": triplet["type"], "head_span": {'text': triplet['head'], 'id': self.get_wiki_id(triplet['head'])}, "tail_span": {'text': triplet['tail'], 'id': self.get_wiki_id(triplet['tail'])}}

    def __call__(self, doc: Doc) -> Doc:
        print("__call__")
        for sent in doc.sents:
            print("SENT",sent)
            sentence_triplets = self._generate_triplets(sent)
            self.set_annotations(doc, sentence_triplets)
        print("Doc:",doc)
        return doc

In [4]:

coref = spacy.load('en_core_web_sm', disable=['ner', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])
coref.add_pipe("xx_coref", config={"chunk_size": 2500, "chunk_overlap": 2, "device": -1})

[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\layas\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\layas\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [None]:
input_text = '''The Solar Park is the largest in the world. It supports the Dubai Clean Energy Strategy 2050.'''

coref_text = coref(input_text)._.resolved_text
print(coref_text)

The Solar Park is the largest in the world. The Solar Park supports the Dubai Clean Energy Strategy 2050.


In [None]:
input_text = '''DEWA has identified the following 6 SDGs where it can have the greatest impact. These goals are also critical for DEWA as a power and water provider.'''

coref_text = coref(input_text)._.resolved_text
print(coref_text)

DEWA has identified the following 6 SDGs where DEWA can have the greatest impact. the following 6 SDGs where it can have the greatest impact are also critical for DEWA as a power and water provider.


In [None]:
input_text = '''DEWA takes into account other excellence models including EFQM, Harvard and others. Its achievements have become a role model for excellence.'''

coref_text = coref(input_text)._.resolved_text
print(coref_text)

DEWA takes into account other excellence models including EFQM, Harvard and others. DEWA's achievements have become a role model for excellence.


DEWA plays a pivotal role in multiple national and international organisations, councils and committees. multiple national and international organisations, councils and committees include but are not limited.


In [None]:
# Add coreference resolution model
coref = spacy.load('en_core_web_sm', disable=['ner', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])
coref.add_pipe(
    "xx_coref", config={"chunk_size": 2500, "chunk_overlap": 2, "device": -1})

# Define rel extraction model

rel_ext = spacy.load('en_core_web_sm', disable=['ner', 'lemmatizer', 'attribute_rules', 'tagger'])
rel_ext.add_pipe("rebel", config={
    'device':-1, # Number of the GPU, -1 if want to use CPU
    'model_name':'Babelscape/rebel-large'} # Model used, will default to 'Babelscape/rebel-large' if not given
    )


error loading _jsonnet (this is expected on Windows), treating C:\Users\layas\AppData\Local\Temp\tmpxe4yo2sa\config.json as plain json
Some weights of the model checkpoint at nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large were not used when initializing XLMRobertaModel: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaModel were not initialized from the model checkpoint at nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large an

<__main__.RebelComponent at 0x27a8552ed10>

In [None]:
input_text = '''The Solar Park is the largest in the world. It supports the Dubai Clean Energy Strategy 2050.'''

coref_text = coref(input_text)._.resolved_text

doc = rel_ext(coref_text)

for value, rel_dict in doc._.rel.items():
    print(f"{value}: {rel_dict}")

__call__
SENT The Solar Park is the largest in the world.
1: ['<s><triplet> Solar Park <subj> largest in the world <obj> instance of</s>']
1: <s><triplet> Solar Park <subj> largest in the world <obj> instance of</s>
extract_triplets: 
Final triplet: []
set_annotations
SENT The Solar Park supports the Dubai Clean Energy Strategy 2050.
1: ['<s><triplet> Dubai Clean Energy Strategy 2050 <subj> 2050 <obj> point in time</s>']
1: <s><triplet> Dubai Clean Energy Strategy 2050 <subj> 2050 <obj> point in time</s>
extract_triplets: 
Final triplet: []
set_annotations
Doc: The Solar Park is the largest in the world. The Solar Park supports the Dubai Clean Energy Strategy 2050.


In [None]:
print(coref_text)

Christian Drosten works in Germany. Christian Drosten likes to work for Google.


In [None]:
print(doc)

Christian Drosten works in Germany. Christian Drosten likes to work for Google.


In [None]:
item = "Google"
try:
    url = f"https://www.wikidata.org/w/api.php?action=wbsearchentities&search={item}&language=en&format=json"
    data = requests.get(url).json()
    # Return the first id (Could upgrade this in the future)
    print(data['search'][0]['id'])
except:
    print('id-less')

Q95
