<a href="https://colab.research.google.com/github/Theonimfi/Text-mining/blob/main/merged_coreference_preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# import joblib
# keep_docs = joblib.load('/content/drive/MyDrive/Text_mining_Shared/enwiki20220701-stripped/random/Saved_docs')

In [3]:
# for doc in keep_docs:
#   print(doc)
#   break;

In [4]:
!pip install transformers
!pip install coreferee
!pip install fastcoref
!python3 -m coreferee install en
!python3 -m spacy download en_core_web_lg

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.23.1-py3-none-any.whl (5.3 MB)
[K     |████████████████████████████████| 5.3 MB 5.6 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 65.7 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 44.4 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.10.1 tokenizers-0.13.1 transformers-4.23.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting coreferee
  Downloading coreferee-1.3.0-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB

In [5]:
import re
import coreferee
from typing import List,Tuple
import json
import pandas as pd 
import spacy
from spacy import Language, util
from spacy.tokens import Doc, Span
from transformers import pipeline
import time
import string
from fastcoref import FCoref

In [6]:
def extract_triplets(text: str) -> List[str]:
    """
    parses the text to triplets
    1. Split the text into tokens
    2. If the token is <triplet>, <subj>, or <obj>, then set the current variable to the appropriate value
    3. If the token is not one of the above, then append it to the appropriate variable
    4. If the current variable is <subj>, then append the triplet to the list of triplets
    :param text: str - the text to be parsed
    :type text: str
    :return: A list of dictionaries.
    """
    triplets = []
    relation, subject, relation, object_ = "", "", "", ""
    text = text.strip()
    current = "x"
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = "t"
            if relation != "":
                triplets.append(
                    {"head": subject.strip(), "type": relation.strip(), "tail": object_.strip()}
                )
                relation = ""
            subject = ""
        elif token == "<subj>":
            current = "s"
            if relation != "":
                triplets.append(
                    {"head": subject.strip(), "type": relation.strip(), "tail": object_.strip()}
                )
            object_ = ""
        elif token == "<obj>":
            current = "o"
            relation = ""
        else:
            if current == "t":
                subject += " " + token
            elif current == "s":
                object_ += " " + token
            elif current == "o":
                relation += " " + token
    if subject != "" and relation != "" and object_ != "":
        triplets.append(
            {"head": subject.strip(), "type": relation.strip(), "tail": object_.strip()}
        )

    return triplets
  

In [7]:


@Language.factory(
    "rebel",
    requires=["doc.sents"],
    assigns=["doc._.rel"],
    default_config={
        "model_name": "Babelscape/rebel-large",
        "device": 0,
    },
)
class RebelComponent:
    def __init__(
        self,
        nlp,
        name,
        model_name: str,
        device: int,
    ):
        assert model_name is not None, ""
        self.triplet_extractor = pipeline(
            "text2text-generation", model=model_name, tokenizer=model_name, device=device
        )
        # Register custom extension on the Doc
        if not Doc.has_extension("rel"):
            Doc.set_extension("rel", default={})

    def _generate_triplets(self, sents: List[Span]) -> List[List[dict]]:
        """
        1. We pass the text of the sentence to the triplet extractor.
        2. The triplet extractor returns a list of dictionaries.
        3. We extract the token ids from the dictionaries.
        4. We decode the token ids into text.
        5. We extract the triplets from the text.
        6. We return the triplets.
        The triplet extractor is a model that takes a sentence as input and returns a list of dictionaries.
        Each dictionary contains the token ids of the extracted triplets.
        The token ids are the numbers that represent the words in the sentence.
        For example, the token id of the word "the" is 2.
        The token ids are decoded into text using the tokenizer.
        The tokenizer is a model that takes a list of token ids as input and returns a list of words.
        :param sents: List[Span]
        :type sents: List[Span]
        :return: A list of lists of dicts.
        """
        
        output_ids = self.triplet_extractor(
            [sent.text for sent in sents], return_tensors=True, return_text=False
        )  # [0]["generated_token_ids"]
        extracted_texts = self.triplet_extractor.tokenizer.batch_decode(
            [out["generated_token_ids"] for out in output_ids]
        )
        extracted_triplets = []
        for text in extracted_texts:
            extracted_triplets.extend(extract_triplets(text))
        return extracted_triplets

    def set_annotations(self, doc: Doc, triplets: List[dict]):
        """
        The function takes a spacy Doc object and a list of triplets (dictionaries) as input.
        For each triplet, it finds the substring in the Doc object that matches the head and tail of the triplet.
        It then creates a spacy span object for each of the head and tail.
        Finally, it creates a dictionary of the relation type, head span and tail span and adds it to the Doc object
        :param doc: the spacy Doc object
        :type doc: Doc
        :param triplets: List[dict]
        :type triplets: List[dict]
        """
        for triplet in triplets:
            # get substring to spacy span
            head_span = re.search(triplet["head"], doc.text)
            tail_span = re.search(triplet["tail"], doc.text)
            # get spacy span
            if head_span is not None:
                head_span = doc.char_span(head_span.start(), head_span.end())
            else:
                #print(f"can't find spacy head span: {triplet['head']}")
                continue
                #head_span = triplet["head"]
            if tail_span is not None:
                tail_span = doc.char_span(tail_span.start(), tail_span.end())
            else:
                #print(f"can't find spacy tail span: {triplet['tail']}")
                continue
                #tail_span = triplet["tail"]
            
            if head_span is not None and tail_span is not None:
              offset = (head_span.start, tail_span.start)
              if offset not in doc._.rel:
                  doc._.rel[offset] = {
                      "relation": triplet["type"],
                      "head_span": head_span,
                      "tail_span": tail_span,
                  }

    def __call__(self, doc: Doc) -> Doc:
        """
        The function takes a doc object and returns a doc object
        :param doc: Doc
        :type doc: Doc
        :return: A Doc object with the sentence triplets added as annotations.
        """
        sentence_triplets = self._generate_triplets(doc.sents)
        self.set_annotations(doc, sentence_triplets)
        return doc

    def pipe(self, stream, batch_size=128):
        """
        It takes a stream of documents, and for each document,
        it generates a list of sentence triplets,
        and then sets the annotations for each sentence in the document
        :param stream: a generator of Doc objects
        :param batch_size: The number of documents to process at a time, defaults to 128 (optional)
        """
        for docs in util.minibatch(stream, size=batch_size):
            sents = []
            for doc in docs:
                sents += doc.sents
            sentence_triplets = self._generate_triplets(sents)
            index = 0
            for doc in docs:
                n_sent = len(list(doc.sents))
                self.set_annotations(doc, sentence_triplets[index : index + n_sent])
                index += n_sent
                yield doc



In [8]:
@Language.factory(
    "rebel_optimized",
    requires=["doc.sents"],
    assigns=["doc._.rel"],
    default_config={
        "model_name": "Babelscape/rebel-large",
        "device": 0,
    },
)
class OptimizedRebelComponent:
    def __init__(
        self,
        nlp,
        name,
        model_name: str,
        device: int,
    ):
        assert model_name is not None, ""
        self.triplet_extractor = pipeline(
            "text2text-generation", model=model_name, tokenizer=model_name, device=device
        )
        # Register custom extension on the Doc
        if not Doc.has_extension("rel"):
            Doc.set_extension("rel", default={})
    
    def _filter_sentences(self,sents:List[Span]) -> List[List[str]]:
        """"Filters a list of sentences returning only the ones containing multiple people as text"""
        sentences = []
        for sentence in sents:
            people = []
            for ent in sentence.ents:
                if ent.label_ == "PERSON" and ent.text not in people:
                    people.append(ent.text)
            if len(people) > 1:
                sentences.append(sentence.text)
        return sentences
            
    def _generate_triplets(self, sents: List[Span]) -> List[List[dict]]:
        """
        1. We pass the text of the sentence to the triplet extractor.
        2. The triplet extractor returns a list of dictionaries.
        3. We extract the token ids from the dictionaries.
        4. We decode the token ids into text.
        5. We extract the triplets from the text.
        6. We return the triplets.
        The triplet extractor is a model that takes a sentence as input and returns a list of dictionaries.
        Each dictionary contains the token ids of the extracted triplets.
        The token ids are the numbers that represent the words in the sentence.
        For example, the token id of the word "the" is 2.
        The token ids are decoded into text using the tokenizer.
        The tokenizer is a model that takes a list of token ids as input and returns a list of words.
        :param sents: List[Span]
        :type sents: List[Span]
        :return: A list of lists of dicts.
        """
        sentences = self._filter_sentences(sents)
        if len(sentences) == 0: return []
        output_ids = self.triplet_extractor(
            sentences, return_tensors=True, return_text=False
        )  # [0]["generated_token_ids"]
        extracted_texts = self.triplet_extractor.tokenizer.batch_decode(
            [out["generated_token_ids"] for out in output_ids]
        )
        extracted_triplets = []
        for text in extracted_texts:
            extracted_triplets.extend(extract_triplets(text))
        return extracted_triplets

    def set_annotations(self, doc: Doc, triplets: List[dict]):
        """
        The function takes a spacy Doc object and a list of triplets (dictionaries) as input.
        For each triplet, it finds the substring in the Doc object that matches the head and tail of the triplet.
        It then creates a spacy span object for each of the head and tail.
        Finally, it creates a dictionary of the relation type, head span and tail span and adds it to the Doc object
        :param doc: the spacy Doc object
        :type doc: Doc
        :param triplets: List[dict]
        :type triplets: List[dict]
        """
        for triplet in triplets:
            # get substring to spacy span
            head_span = re.search(triplet["head"], doc.text)
            tail_span = re.search(triplet["tail"], doc.text)
            # get spacy span
            if head_span is not None:
                head_span = doc.char_span(head_span.start(), head_span.end())
            else:
                #print(f"can't find spacy head span: {triplet['head']}")
                continue
                #head_span = triplet["head"]
            if tail_span is not None:
                tail_span = doc.char_span(tail_span.start(), tail_span.end())
            else:
                #print(f"can't find spacy tail span: {triplet['tail']}")
                continue
                #tail_span = triplet["tail"]
            
            if head_span is not None and tail_span is not None:
              offset = (head_span.start, tail_span.start)
              if offset not in doc._.rel:
                  doc._.rel[offset] = {
                      "relation": triplet["type"],
                      "head_span": head_span,
                      "tail_span": tail_span,
                  }

    def __call__(self, doc: Doc) -> Doc:
        """
        The function takes a doc object and returns a doc object
        :param doc: Doc
        :type doc: Doc
        :return: A Doc object with the sentence triplets added as annotations.
        """
        sentence_triplets = self._generate_triplets(doc.sents)
        self.set_annotations(doc, sentence_triplets)
        return doc

    def pipe(self, stream, batch_size=128):
        """
        It takes a stream of documents, and for each document,
        it generates a list of sentence triplets,
        and then sets the annotations for each sentence in the document
        :param stream: a generator of Doc objects
        :param batch_size: The number of documents to process at a time, defaults to 128 (optional)
        """
        
        for docs in util.minibatch(stream, size=batch_size):
            for doc in docs:
                sentence_triplets = self._generate_triplets(doc.sents)
                self.set_annotations(doc, sentence_triplets)
                yield doc



In [9]:

@Language.factory(
    "coref_resolver",
    assigns=["doc._.resolved_text"],
    default_config={
        "model_name": "FCoref",
        "device": "cuda:0",
    },
)

class CorefResolver:
    """a class that implements the logic from
    https://towardsdatascience.com/how-to-make-an-effective-coreference-resolution-model-55875d2b5f19"""
    def __init__(
        self,
        nlp,
        name,
        model_name: str,
        device: str,
    ):
        assert model_name is not None, ""
            
        self.coref_model = FCoref(device=device)
        # Register custom extension on the Doc
        if not Doc.has_extension("resolved_text"):
            Doc.set_extension("resolved_text", default=None)
    def get_span_noun_indices(self, doc: Doc, cluster: List[Tuple]) -> List[int]:
        """
        > Get the indices of the spans in the cluster that contain at least one noun or proper noun
        :param doc: Doc
        :param cluster: List[List[int]]
        :return: A list of indices of spans that contain at least one noun or proper noun.
        """
        spans = [doc.char_span(span[0],span[1]) for span in cluster]
        spans_pos = [[token.pos_ for token in span] for span in spans]
        span_noun_indices = [
            i for i, span_pos in enumerate(spans_pos) if any(pos in span_pos for pos in ["NOUN", "PROPN"])
        ]
        return span_noun_indices
    def get_cluster_head(self, doc: Doc, cluster: List[Tuple], noun_indices: List[int]):
        """
        > Given a spaCy Doc, a list of clusters, and a list of noun indices, return the head span and its start and end
        indices
        :param doc: the spaCy Doc object
        :type doc: Doc
        :param cluster: a list of lists, where each sublist is a span of tokens in the document
        :type cluster: List[List[int]]
        :param noun_indices: a list of indices of the nouns in the cluster
        :type noun_indices: List[int]
        :return: The head span and the start and end indices of the head span.
        """
        head_idx = noun_indices[0]
        head_start,head_end = cluster[head_idx]
        head_span = doc.char_span(head_start,head_end)
        return head_span, [head_start, head_end]
    def is_containing_other_spans(self,span: List[int], all_spans: List[List[int]]):
        """
        It returns True if there is any span in all_spans that is contained within span and is not equal to span
        :param span: the span we're checking to see if it contains other spans
        :type span: List[int]
        :param all_spans: a list of all the spans in the document
        :type all_spans: List[List[int]]
        :return: A list of all spans that are not contained in any other span.
        """
        return any([s[0] >= span[0] and s[1] <= span[1] and s != span for s in all_spans])
    def core_logic_part(self,document: Doc, coref: List[int], resolved: List[str], mention_span: Span):
        """
        If the last token of the mention is a possessive pronoun, then add an apostrophe and an s to the mention.
        Otherwise, just add the last token to the mention
        :param document: Doc object
        :type document: Doc
        :param coref: List[int]
        :param resolved: list of strings, where each string is a token in the sentence
        :param mention_span: The span of the mention that we want to replace
        :return: The resolved list is being returned.
        """
        char_span = document.char_span(coref[0],coref[1])
        final_token = char_span[-1]
        final_token_tag = str(final_token.tag_).lower()
        test_token_test = False
        for option in ["PRP$", "POS", "BEZ"]:
            if option.lower() in final_token_tag:
                test_token_test = True
                break
        if test_token_test:
            resolved[char_span.start] = mention_span.text + "'s" + final_token.whitespace_
        else:
            resolved[char_span.start] = mention_span.text + final_token.whitespace_
        for i in range(char_span.start + 1, char_span.end):
            resolved[i] = ""
        return resolved
    def _has_multiple_people(self,doc:Doc) -> bool:
        return True
        # people = []
        # for entity in doc.ents:
        #     if entity.label_ == 'PERSON' and entity.text not in people:
        #         people.append(entity.text)
        #     if len(people)>1:
        #         return True
        # return False
    def __call__(self, doc: Doc) -> Doc:
        """
        The function takes a doc object and returns a doc object
        :param doc: Doc
        :type doc: Doc
        :return: A Doc object with the sentence triplets added as annotations.
        """
        preds = self.coref_model.predict(
                texts=[doc.text])
        clusters = preds[0].get_clusters(as_strings=False)
        resolved = list(tok.text_with_ws for tok in doc)
        cluster_heads = {}
        all_spans = [span for cluster in clusters for span in cluster]
        for cluster in clusters:
            indices = self.get_span_noun_indices(doc,cluster)
            if indices:
                mention_span, mention = self.get_cluster_head(doc, cluster, indices)
                cluster_heads[str(mention_span)] = mention

                for coref in cluster:
                    if coref != mention and not self.is_containing_other_spans(coref, all_spans):
                        self.core_logic_part(doc, coref, resolved, mention_span)
        doc._.resolved_text = "".join(resolved)
        return doc
    def pipe(self, stream, batch_size=512):
        for docs in util.minibatch(stream, size=batch_size):
            filtered_docs = [doc for doc in docs if self._has_multiple_people(doc)]
            preds = self.coref_model.predict(
                    texts=[doc.text for doc in filtered_docs],max_tokens_in_batch=512)
                    
            doc_keys = [doc.text[:15] for doc in filtered_docs]
            for pred in preds:
                clusters = pred.get_clusters(as_strings=False)
                doc = filtered_docs[doc_keys.index(pred.text[:15])] #Find document since preds returns random order
                resolved = list(tok.text_with_ws for tok in doc)
                cluster_heads = {}
                all_spans = [span for cluster in clusters for span in cluster]
                for cluster in clusters:
                    try:
                        indices = self.get_span_noun_indices(doc,cluster)
                        if indices:
                            mention_span, mention = self.get_cluster_head(doc, cluster, indices)
                            cluster_heads[str(mention_span)] = mention

                            for coref in cluster:
                                if coref != mention and not self.is_containing_other_spans(coref, all_spans):
                                    self.core_logic_part(doc, coref, resolved, mention_span)
                    except:
                        print("Couldn't process: ",doc.text[:15])
                
                        continue
                doc._.resolved_text = "".join(resolved)
                yield doc
  

In [10]:
def resolve_corefs(doc):
  s = set(["his","her","their"])
  mutable_doc = [str(word) for word in doc]
  if doc._.coref_chains:
    for chain in doc._.coref_chains.chains:
      #print(chain)
      for token in chain:
        #print(token)
        if len(token) == 1:
          coref = doc._.coref_chains.resolve(doc[token[0]])
          if coref:
            if len(coref) > 1:
              mutable_doc[token[0]] = " and ".join([str(word) for word in coref])
            else:
              t = str(doc[token[0]])
              c = str(coref[0])
              #print(t)
              if t in s:
                mutable_doc[token[0]] = f"{c}'s"
              else:
                mutable_doc[token[0]] = c
  output = ""
  cnt = 0
  for token in mutable_doc:
    if token not in string.punctuation and cnt>0:
      output+=f" {token}"
    else:
      output+=token
    cnt+=1
  return output

In [11]:
# import joblib
# keep_docs = joblib.load('/content/drive/MyDrive/Text_mining_Shared/enwiki20220701-stripped/random/Saved_docs')
# print(len(keep_docs))

In [12]:

# # Using readlines()
# file1 = open('/content/drive/MyDrive/Text mining_Shared/enwiki20220701-stripped/AB/wiki_63', 'r')
# Lines = file1.readlines()
  
# content = []
# count = 0
# # Strips the newline character
# for line in Lines:
#     content.append(json.loads(line.split('\n')[0]))

In [13]:
import pandas as pd
from google.colab import auth
import gspread
from google.auth import default
#autenticating to google
auth.authenticate_user()
creds, _ = default()
gc = gspread.authorize(creds)

directory = '/content/drive/MyDrive/Data Science and AI/Text mining_Shared/enwiki20220701-stripped/Evaluation dataset'

# read data and put it in a dataframe
gsheets = gc.open_by_url('https://docs.google.com/spreadsheets/d/1sRp-1FyAQ-WWzVvxxClKY7-rQOBVi6JOfIaVlfsPkqE/edit#gid=1048303571')
sheets = gsheets.worksheet('Clustered doc').get_all_values()
df_evaluation_dataset = pd.DataFrame(sheets[1:], columns=sheets[0])

In [14]:
# df_evaluation_dataset.rename(columns={"title": "text"}, inplace = True)
df_evaluation_dataset.head()

Unnamed: 0,text,cluster,text_id,COUNTER
0,Lidia Chojecka-Leandro (born 25 January 1977 i...,0,91,0
1,"Nubkhesbed (""Gold and Lapis lazuli"") was an an...",0,92,1
2,Gustavo Selbach (born 25 August 1974 in Três C...,0,93,2
3,"The Friedland was an 80-gun ""Bucentaure""-class...",0,94,3
4,Sergey Shayslamov (born 23 July 1970) is an Uz...,0,95,4


In [15]:
# df.head()

In [16]:
# df = df.loc[df['text']!= ""] # Discard empty pages
# df = df.reset_index()
# print(len(df))

## Actual Pipeline

In [17]:
# preprocess = spacy.load("en_core_web_lg")
# preprocess.add_pipe('coreferee')
preprocess = spacy.load("en_core_web_lg")
preprocess.add_pipe("coref_resolver",config={'device':'cuda:0'})
nlp = spacy.load("en_core_web_lg")
nlp.add_pipe("rebel_optimized",config={
    'device':-1, # Number of the GPU, -1 if want to use CPU
    'model_name':'Babelscape/rebel-large'} # Model used, will default to 'Babelscape/rebel-large' if not given
    )

Downloading:   0%|          | 0.00/819 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/393 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/362M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.23k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/123 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/344 [00:00<?, ?B/s]

<__main__.OptimizedRebelComponent at 0x7fdcbaf73b50>

In [18]:
# #Only used with coreferee
# def has_multiple_people(doc):
#   people = []
#   for entity in doc.ents:
#     if entity.label_ == 'PERSON' and entity.text not in people:
#       people.append(entity.text)
#     if len(people)>1:
#       return True
#   return False

In [19]:
sample_df = df_evaluation_dataset[:5]
sample_df

Unnamed: 0,text,cluster,text_id,COUNTER
0,Lidia Chojecka-Leandro (born 25 January 1977 i...,0,91,0
1,"Nubkhesbed (""Gold and Lapis lazuli"") was an an...",0,92,1
2,Gustavo Selbach (born 25 August 1974 in Três C...,0,93,2
3,"The Friedland was an 80-gun ""Bucentaure""-class...",0,94,3
4,Sergey Shayslamov (born 23 July 1970) is an Uz...,0,95,4


In [20]:
import math 

def is_first_batch(batch_counter):
  if batch_counter>0:
      return False
  return True

def cast_rel_dict(rel_dict):
  return {"head": str(rel_dict["head_span"]), "relation": rel_dict["relation"], "tail": str(rel_dict["tail_span"])}

relations_lst = []
idx_list = []
## New implementation
RELATIONSHIPS = {'spouse','sibling','father','child','family','mother','relative','student of'} #set
BATCH_SIZE = 1 # Colab usually breaks with higher batch sizes.
NUM_BATCHES = math.ceil(len(sample_df)/BATCH_SIZE)

for batch in range(0,NUM_BATCHES):
    print("Processing batch: ",batch)
    start = time.time()
    texts = sample_df.iloc[batch*BATCH_SIZE:min((batch+1)*BATCH_SIZE,len(sample_df))]["text"].values
    docs = preprocess.pipe(texts,batch_size=BATCH_SIZE) 
    prepped_texts = [doc._.resolved_text for doc in docs]
    #prepped_texts = [resolve_corefs(doc) for doc in docs if has_multiple_people(doc)]
    
    print(f"Resolving coreferences took {time.time() - start} seconds")
    start = time.time()
    docs = nlp.pipe(prepped_texts)
    relations = [[cast_rel_dict(rel_dict) for _,rel_dict in doc._.rel.items() if rel_dict["relation"] in RELATIONSHIPS] for doc in docs]
    relations_lst.append(relations)
    idx_list.append(batch)

    print(f"Relation extraction took {time.time() - start} seconds")
    # pd.DataFrame(relations).to_csv("/content/drive/MyDrive/relations.csv",mode='a',index=False,header=is_first_batch(batch))

Processing batch:  0


  0%|          | 0/1 [00:00<?, ?ba/s]


1it [00:00, 38.44it/s]


Inference:   0%|          | 0/1 [00:00<?, ?it/s]

You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Resolving coreferences took 1.5738275051116943 seconds
Relation extraction took 24.32903265953064 seconds
Processing batch:  1


  0%|          | 0/1 [00:00<?, ?ba/s]


1it [00:00, 49.23it/s]


Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Resolving coreferences took 1.658994197845459 seconds
Relation extraction took 36.17378640174866 seconds
Processing batch:  2


  0%|          | 0/1 [00:00<?, ?ba/s]


1it [00:00, 36.44it/s]


Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Resolving coreferences took 0.9595954418182373 seconds
Relation extraction took 51.98384475708008 seconds
Processing batch:  3


  0%|          | 0/1 [00:00<?, ?ba/s]


1it [00:00, 75.21it/s]


Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Resolving coreferences took 0.9831550121307373 seconds
Relation extraction took 10.4212007522583 seconds
Processing batch:  4


  0%|          | 0/1 [00:00<?, ?ba/s]


1it [00:00, 64.96it/s]


Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Resolving coreferences took 0.9242794513702393 seconds
Relation extraction took 7.817996263504028 seconds


In [21]:
new_list = [x[0] for x in relations_lst]
new_list[4]

[{'head': 'Sergey Shayslamov',
  'relation': 'sibling',
  'tail': 'Vladimir Shayslamov'},
 {'head': 'Vladimir Shayslamov',
  'relation': 'sibling',
  'tail': 'Sergey Shayslamov'}]

In [22]:
df_relations = pd.DataFrame(
    {'text_index': idx_list,
     'relations': new_list,
    })
df_relations

Unnamed: 0,text_index,relations
0,0,"[{'head': 'Lidia Chojecka-Leandro', 'relation'..."
1,1,"[{'head': 'Ramesses VI', 'relation': 'child', ..."
2,2,"[{'head': 'Gustavo Selbach', 'relation': 'sibl..."
3,3,"[{'head': 'Napoleon', 'relation': 'spouse', 't..."
4,4,"[{'head': 'Sergey Shayslamov', 'relation': 'si..."


In [23]:
import json 

# Convert data to json format
for relation in df_relations.relations:
  convert_json = []
  for rel in relation:
    convert_json.append(json.loads(str(rel).replace("'","\"")))
  df_relations['relations'] = df_relations['relations'].replace(relation,convert_json)
df_relations

Unnamed: 0,text_index,relations
0,0,"[{'head': 'Lidia Chojecka-Leandro', 'relation'..."
1,1,"[{'head': 'Ramesses VI', 'relation': 'child', ..."
2,2,"[{'head': 'Gustavo Selbach', 'relation': 'sibl..."
3,3,"[{'head': 'Napoleon', 'relation': 'spouse', 't..."
4,4,"[{'head': 'Sergey Shayslamov', 'relation': 'si..."


In [24]:
# %%time
# # PROCESS_FIRST_N_ARTICLES = 100

# # Keep only rows with non empty text
# # df = df[df.text != '']
# # df['text_preprocessed'] = df['text']
# relationships = ['spouse','sibling','father','child','employer','family','mother','relative','student of']

# # Reindexing
# # df.index = range(len(df))
# # test= df[:PROCESS_FIRST_N_ARTICLES] # only use first 100 for test change this later to process everything
# sample_df["relations"] = ''
# relations_lst = []
# idx_list = []
# # testing_df['text_preprocessed']= testing_df['text_preprocessed'].apply(lambda x: remove_punctuation(x))
# cnt = 0
# for idx,row in sample_df.iterrows():
#   cnt+=1
#   start = time.time()
#   processed_doc = preprocess(row["text"])
#   # if has_multiple_people(processed_doc):
#     #print(processed_doc)
#   resolved = resolve_corefs(processed_doc)
#   doc = nlp(resolved)
#   doc_list = nlp.pipe([doc])
  
#   relations = [rel_dict for _,rel_dict in doc._.rel.items() if rel_dict["relation"] in relationships]
#   relations_lst.append(relations)
#   idx_list.append(cnt)
#   print(f"{idx} took {time.time()-start} seconds..")

In [25]:
# relations_lst

In [26]:
# df_relations = pd.read_csv("/content/drive/MyDrive/relations.csv")
# df_relations

# Filter and close

In [27]:
# df_filtered = pd.read_csv("/content/drive/MyDrive/relations.csv")

In [93]:
df_filtered = df_relations
df_filtered

Unnamed: 0,text_index,relations
0,0,"[{'tail': 'Lidia Chojecka-Leandro', 'head': 'J..."
1,1,"[{'relation': 'sibling', 'tail': 'Ramesses VII..."
2,2,"[{'head': 'Gustavo Selbach', 'tail': 'Leonardo..."
3,3,"[{'tail': 'Napoleon', 'head': 'Marie Louise', ..."
4,4,"[{'tail': 'Vladimir Shayslamov', 'relation': '..."


In [275]:
def filter_data(df):
    # Filter the names beginning with a lower case letter
    df["relations"] = df["relations"].apply(lambda rels: [rel for rel in rels if rel["head"][0].isupper() and rel["tail"][0].isupper()] if rels != "" else [])
    print(df["relations"])
    # Filter the names that consist of the upper case letter only
    df["relations"] = df["relations"].apply(lambda rels: [rel for rel in rels if not rel["head"].isupper() and not rel["tail"].isupper()])
    print(df["relations"])
    # Filter the names that consist of the lower case letter only
    df["relations"] = df["relations"].apply(lambda rels: [rel for rel in rels if not rel["head"].islower() and not rel["tail"].islower()])
    print(df["relations"])
    # Certain relations are symmetric. So, we need to add the reverse relation as well.
    # For perfectly symmetric:
    df["relations"] = df["relations"].apply(lambda rels: rels + [{"head":rel["tail"], "tail":rel["head"], "relation":rel["relation"]} for rel in rels if rel["relation"] in ["spouse","sibling","family","relative"]])
    print(df["relations"])
    # For symmetric ones with non-trivial symmetry:
    df["relations"] = df["relations"].apply(lambda rels: rels + [{"head":rel["tail"], "tail":rel["head"], "relation":"child"} for rel in rels if rel["relation"] in ["father","mother"]])
    print(df["relations"])
    # Remove reflexive relations
    df["relations"] = df["relations"].apply(lambda rels: [rel for rel in rels if  rel["head"] != rel["tail"]])
    print(df["relations"])
    # Remove duplicates just in case
    df["relations"] = df["relations"].apply(lambda rels: [dict(s) for s in set(frozenset(d.items()) for d in rels)])

    return df
df_filtered = filter_data(df_filtered)
unique_ids = len(df_filtered)

0    [{'tail': 'Lidia Chojecka-Leandro', 'head': 'J...
1    [{'relation': 'sibling', 'tail': 'Ramesses VII...
2    [{'head': 'Gustavo Selbach', 'tail': 'Leonardo...
3    [{'relation': 'spouse', 'head': 'Marie Louise'...
4    [{'tail': 'Vladimir Shayslamov', 'relation': '...
Name: relations, dtype: object
0    [{'tail': 'Lidia Chojecka-Leandro', 'head': 'J...
1    [{'relation': 'sibling', 'tail': 'Ramesses VII...
2    [{'head': 'Gustavo Selbach', 'tail': 'Leonardo...
3    [{'relation': 'spouse', 'head': 'Marie Louise'...
4    [{'tail': 'Vladimir Shayslamov', 'relation': '...
Name: relations, dtype: object
0    [{'tail': 'Lidia Chojecka-Leandro', 'head': 'J...
1    [{'relation': 'sibling', 'tail': 'Ramesses VII...
2    [{'head': 'Gustavo Selbach', 'tail': 'Leonardo...
3    [{'relation': 'spouse', 'head': 'Marie Louise'...
4    [{'tail': 'Vladimir Shayslamov', 'relation': '...
Name: relations, dtype: object
0    [{'tail': 'Lidia Chojecka-Leandro', 'head': 'J...
1    [{'relation': 'sibling

In [276]:
# df_filtered.to_csv("/content/drive/MyDrive/relations_filtered.csv")

# Evaluation

In [290]:
# Read evaluation data and put it in a dataframe
gsheets = gc.open_by_url('https://docs.google.com/spreadsheets/d/1sRp-1FyAQ-WWzVvxxClKY7-rQOBVi6JOfIaVlfsPkqE/edit#gid=1048303571')
sheets = gsheets.worksheet('eval').get_all_values()

# Get only first 20 rows for testing
df_evaluation = pd.DataFrame(sheets[1:], columns=sheets[0])[:18]

In [291]:
df_evaluation

Unnamed: 0,text_index,Head,Type,Tail
0,0,Lidia Chojecka-Leandro,spouse,Jean-Marc Léandro.
1,0,Jean-Marc Léandro.,spouse,Lidia Chojecka-Leandro
2,1,Nubkhesbed,spouse,Pharaoh Ramesses VI
3,1,Pharaoh Ramesses VI,spouse,Nubkhesbed
4,1,Nubkhesbed,spouse,Pharaoh Ramesses VII
5,1,Nubkhesbed,mother,Princess Iset
6,1,Nubkhesbed,mother,Amenherkhepshef
7,1,Nubkhesbed,mother,Panebenkemyt.
8,1,Amenherkhepshef's tomb,son,Nubkhesbed
9,1,Princess Iset,daugher,Nubkhesbed


In [292]:
# Create the same format as filtered format to filter also the evaluation set
df_evaluation['relations'] = ''
for i in range(len(df_evaluation)):
  value = "{'head': '"+df_evaluation['Head'][i].replace("'"," ")+"','relation': '"+df_evaluation['Type'][i]+"','tail': '"+df_evaluation['Tail'][i].replace("'"," ")+"'}"
  df_evaluation['relations'][i] = json.loads(value.replace("'","\"").replace(".",""))
df_evaluation = df_evaluation.drop(columns=['Head','Type','Tail'])
df_evaluation

Unnamed: 0,text_index,relations
0,0,"{'head': 'Lidia Chojecka-Leandro', 'relation':..."
1,0,"{'head': 'Jean-Marc Léandro', 'relation': 'spo..."
2,1,"{'head': 'Nubkhesbed', 'relation': 'spouse', '..."
3,1,"{'head': 'Pharaoh Ramesses VI', 'relation': 's..."
4,1,"{'head': 'Nubkhesbed', 'relation': 'spouse', '..."
5,1,"{'head': 'Nubkhesbed', 'relation': 'mother', '..."
6,1,"{'head': 'Nubkhesbed', 'relation': 'mother', '..."
7,1,"{'head': 'Nubkhesbed', 'relation': 'mother', '..."
8,1,"{'head': 'Amenherkhepshef s tomb', 'relation':..."
9,1,"{'head': 'Princess Iset', 'relation': 'daugher..."


In [293]:
# Filter also the evaluation to compare it later
df_evaluation = df_evaluation.groupby('text_index')['relations'].apply(list).reset_index(name="relations")

# Problem with filtering
# df_evaluation = filter_data(df_evaluation)

In [294]:
# Unlist elements in relations column
df_evaluation_exploded = df_evaluation.explode('relations')
df_results_exploded = df_filtered.explode('relations')
df_evaluation_exploded

Unnamed: 0,text_index,relations
0,0,"{'head': 'Lidia Chojecka-Leandro', 'relation':..."
0,0,"{'head': 'Jean-Marc Léandro', 'relation': 'spo..."
1,1,"{'head': 'Nubkhesbed', 'relation': 'spouse', '..."
1,1,"{'head': 'Pharaoh Ramesses VI', 'relation': 's..."
1,1,"{'head': 'Nubkhesbed', 'relation': 'spouse', '..."
1,1,"{'head': 'Nubkhesbed', 'relation': 'mother', '..."
1,1,"{'head': 'Nubkhesbed', 'relation': 'mother', '..."
1,1,"{'head': 'Nubkhesbed', 'relation': 'mother', '..."
1,1,"{'head': 'Amenherkhepshef s tomb', 'relation':..."
1,1,"{'head': 'Princess Iset', 'relation': 'daugher..."


In [295]:
# Find equal relations
count_total = 0
count_equal = 0
for i in range(unique_ids):
  evaluate = df_evaluation_exploded[df_evaluation_exploded['text_index']==str(i)]['relations'].values
  results = df_results_exploded[df_results_exploded['text_index']==i]['relations'].values
  for relation_e in evaluate:
    count_total+=1
    for relation_r in results:
      # if relation_e['head']!=relation_r['head'] or relation_e['tail']!=relation_r['tail'] or relation_e['relation']!=relation_r['relation']:
      #     continue
      if (((relation_r['head'] in relation_e['head']) or (relation_e['head'] in relation_r['head'])) and \
      ((relation_r['tail'] in relation_e['tail'])  or (relation_e['tail'] in relation_r['tail'])) and \
      ((relation_r['relation'] in relation_e['relation']) or (relation_e['relation'] in relation_r['relation']))):
          count_equal +=1
          break
print("Total relations in evaluation set", count_total, ". From these, rebel found:", count_equal)

Total relations in evaluation set 18 . From these, rebel found: 8


# Save to the evaluation format

In [None]:
reformatted_df = pd.DataFrame(columns=["Index", "Head", "Type", "Tail"])

for i in range(len(df_filtered)):
  for relation in df_filtered["relations"][i]:
    reformatted_df = reformatted_df.append({"Index": df_filtered["text_index"][i], "Head": relation["head"], "Type": relation["relation"], "Tail": relation["tail"]}, ignore_index=True)

reformatted_df

In [None]:
reformatted_df.to_csv("/content/drive/MyDrive/relations_reformatted.csv")

In [None]:
for index in df_evaluation.index:
  print(index)
  print(df_evaluation.loc[df_evaluation['Index'] == str(index), 'Head'].values)
  break

# Visualisation

In [None]:
!pip install networkx

In [None]:
test_vis = pd.read_csv(f"/content/drive/MyDrive/relations_filtered.csv")

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

edges = [(rel_dict['head'], rel_dict['tail']) for relations in test_vis["relations"] for rel_dict in eval(relations) if relations != ""]
G = nx.Graph()
G.add_edges_from(edges)
pos = nx.spring_layout(G, k=5, scale=1.0)
plt.figure(figsize=(16, 10))
nx.draw(
    G,
    pos,
    edge_color='black',
    width=1,
    linewidths=1,
    node_color='pink',
    labels={node: node for node in G.nodes()}
)
nx.draw_networkx_edge_labels(
    G,
    pos,
    edge_labels={
        (rel_dict['head'], rel_dict['tail']): rel_dict['relation']
        for relations in test_vis["relations"] for rel_dict in eval(relations) if relations != ""
    },
    font_color='grey'
)
plt.axis('off')
plt.show()