In [1]:
import requests
import re
import hashlib
from spacy import Language, util
from typing import List
from spacy.tokens import Doc, Span
from transformers import pipeline
import crosslingual_coreference
import spacy
from os.path import isfile
import os
import ftfy
import json
import glob
from tqdm import tqdm
import pandas as pd
import numpy as np

[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\Sverre\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [2]:
path = "preprocessed-rebel/"

In [3]:
def extract_triplets(text):
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})

    return triplets

In [4]:
@Language.factory(
    "rebel",
    requires=["doc.sents"],
    assigns=["doc._.rel"],
    default_config={
        "model_name": "Babelscape/rebel-large",
        "device": 0,
    },
)
class RebelComponent:
    def __init__(
        self,
        nlp,
        name,
        model_name: str,
        device: int,
    ):
        assert model_name is not None, ""
        self.triplet_extractor = pipeline("text2text-generation", model=model_name, tokenizer=model_name, device=device)
        self.entity_mapping = {}
        if not Doc.has_extension("rel"):
            Doc.set_extension("rel", default={})

    def _generate_triplets(self, sent: Span) -> List[dict]:
        output_ids = self.triplet_extractor(sent.text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
        extracted_text = self.triplet_extractor.tokenizer.batch_decode(output_ids[0])
        extracted_triplets = extract_triplets(extracted_text[0])
        return extracted_triplets

    def set_annotations(self, doc: Doc, triplets: List[dict]):
        for triplet in triplets:

            # Skip relationships, where the subject is the same as the object
            if triplet['head'] == triplet['tail']:
                continue

            # Regular expression search for subjects and objects
            head_span = re.search(triplet["head"], doc.text)
            tail_span = re.search(triplet["tail"], doc.text)

            # Skip relationships, where both subject and object are not in text
            if not head_span or not tail_span:
                continue

            index = hashlib.sha1("".join([triplet['head'], triplet['tail'], triplet['type']]).encode('utf-8')).hexdigest()
            if index not in doc._.rel:
                doc._.rel[index] = {"relation": triplet["type"], "head_span": triplet['head'], "tail_span": triplet['tail']}

    def __call__(self, doc: Doc) -> Doc:
        for sent in doc.sents:
            sentence_triplets = self._generate_triplets(sent)
            self.set_annotations(doc, sentence_triplets)
        return doc

In [5]:
DEVICE = -1

# Define rel extraction model
rel_ext = spacy.load('en_core_web_sm', disable=['ner', 'lemmatizer', 'attribute_rules', 'tagger'])
rel_ext.add_pipe("rebel", config={
    'device':DEVICE,
    'model_name':'Babelscape/rebel-large'}
    )

<__main__.RebelComponent at 0x1fba8048220>

In [6]:
def find_relations(par):
    all_relations = []
    doc = rel_ext(par)  
    for value, rel_dict in doc._.rel.items():
        all_relations.append(rel_dict)
    
    return all_relations

In [7]:
evaluation_df = pd.read_csv('Evaluation_Set_CSV.csv')
# unique indexes in EVAL set
indexes = evaluation_df['PAR ID'].unique()

In [8]:
"""
        Function to apply preprocessing on a selection of files and store it in a separate folder
        - path: root folder (preprocessed-rebel), 
        - subf desired subfolder: AA or AB, must be passed as a string (e.g. "AA") 
        - files in subfolder AA: e.g. p_r_wiki_00 
        - files in subfolder AB: e.g. p_r_wiki_00
        - start: start number file, e.g. 0-99 (no need to fill in 00, 0 is fine)
        - end: end number file, e.g. 0-99
        - the range is inclusive which means, e.g. with (0, 0) you select & pre-process file wiki_00,
        - with (32, 50) you select file wiki_32 up till wiki_50
""" 
def rel_extraction_mul_files(path, subf=None, start=None, end=None):
    relations = {}
    if subf:
        # from file start to end
        for i in tqdm(range(start, end+1)):
            # to match the filename p_r_wiki_00 up till p_r_wiki_09 we add a zero in front of the number from user input if necessary
            if i < 10:
                i = "0" + str(i)
            
            print("Processing wiki file: ", i, subf)
            # construct path to file name that falls within range
            f = path + subf + "/p_r_wiki_{}".format(i) 
            
            # if file exists
            if isfile(f):
                # OPEN FILE, GO THROUGH EACH PAR AND PASS THAT INTO COREF FUNCTION
                file = open(f, 'r', encoding='utf-8')
                doc = json.load(file)
                for k, v in doc.items():
                    v = ftfy.fix_text(v) # FIX ANY ENCODINGS
                    # if paragraph has more than one word
                    if len(v.split(" ")) > 1:
                        relations[k] = find_relations(v)
                        
    else:
        for f in glob.glob('preprocessed-rebel/*/*'):
            # OPEN EACH FILE, GO THROUGH EACH PARAGRAPH AND PASS THAT INTO COREF FUNC
            file = open(f, 'r')
            doc = json.load(file)
            for k, v in doc.items():
                v = ftfy.fix_text(v) # FIX ANY ENCODINGS
                # if paragraph has more than one word
                if len(v.split(" ")) > 1:
                    relations[k] = find_relations(v)
    
    # save to file
    with open('relations.txt', 'w') as convert_file:
        convert_file.write(json.dumps(relations))
        
    return relations

In [9]:
# OPT 1: SELECT SPECIFIC FILES TO FEED REBEL IN ONE PARTICULAR MAP / # OPT 2: FEED ALL FILES, in AA and AB BY ONLY KEEPING PATH IN THERE
# relations = rel_extraction_mul_files(path, "AB", 1, 1) # rel_extraction_mul_files(path)

In [10]:
# rel_extraction_mul_files(path, "AB", 0, 68)

In [11]:
# rel_extraction_mul_files(path, "AA", 0, 99)

LOAD EXTRACTED RELATIONS

In [12]:
new_f = open('relations.txt', encoding="utf8").readlines()[0]
relations = json.loads(new_f)

REBEL OUTPUT TO DATAFRAME

In [13]:
# load evaluation set
evaluation_df = pd.read_csv('Evaluation_Set_CSV.csv')
# unique indexes in EVAL set
indexes = evaluation_df['PAR ID'].unique()
# unique predicates in EVAL set
predicates = evaluation_df['PRED'].unique()

In [14]:
indexes

array(['51207047-1', '51207047-3', '51207047-6', '51207047-8',
       '51207047-11', '51207047-12', '51218601-1', '51218601-3',
       '51237053-1', '51237053-3', '51237053-4', '51237053-10',
       '51239375-1', '51239375-3', '51239375-6', '51239375-8',
       '51239375-15', '51244418-1', '51244418-3', '51251385-1',
       '51251385-3', '51251385-6', '51345817-1', '51345817-4',
       '51376486-1', '51376486-3', '51376486-5', '51554035-1',
       '51561288-1', '51561288-3', '51579506-1', '51579506-3',
       '51694863-1', '51694863-3', '51694863-7', '51745138-1',
       '51745138-3', '51745138-5', '51758444-1', '51758444-6',
       '51758444-7', '51758444-10', '51861231-1', '51861231-4',
       '51861231-6', '51861231-7', '51947774-1', '51947774-2',
       '51947774-5', '51947774-7', '51956447-1', '51956447-3',
       '51956447-4', '51990887-1', '52025159-1', '52025159-2',
       '52025159-5', '52025159-8', '52025159-12', '52025159-14',
       '52037349-1', '52037349-3', '52041825-1',

In [15]:
extract = dict()
for i in ['PAR ID', 'SUB', 'PRED', 'OBJ']:
    extract[i] = []


for key, list_of_dicts in relations.items():
    if key in indexes:
        for dic in list_of_dicts:
            extract['PAR ID'].append(key)
            extract['SUB'].append(dic['head_span'])
            extract['PRED'].append(dic['relation'])
            extract['OBJ'].append(dic['tail_span'])

# turn REBEL output to a DF
rebel_df = pd.DataFrame.from_dict(extract)

In [16]:
# all predicates found by REBEL
rebel_df['PRED'].unique()

array(['main subject', 'employer', 'date of birth',
       'country of citizenship', 'educated at', 'founded by', 'developer',
       'owned by', 'owner of', 'parent organization', 'subsidiary',
       'industry', 'occupation', 'field of this occupation',
       'practiced by', 'facet of', 'subclass of', 'has effect',
       'has cause', 'part of', 'field of work', 'award received',
       'manufacturer', 'instance of', 'member of', 'place of birth',
       'country', 'contains administrative territorial entity',
       'located in the administrative territorial entity', 'inception',
       'editor', 'headquarters location', 'studied by', 'date of death',
       'member of political party', 'notable work', 'author',
       'affiliation', 'sport', 'conferred by', 'studies', 'publisher',
       'point in time', 'genre', 'location', 'has part', 'cast member',
       'record label', 'product or material produced', 'operating system',
       'creator', 'developer <triplet>rtty', 'place of d

In [17]:
# all predicates found in EVAL set
predicates

array(['occupation', 'educated at', 'date of birth', 'place of birth',
       'founded by', 'member of', 'date of death', 'spouse'], dtype=object)

KEEP ONLY PREDICATES FROM EVAL SET

In [18]:
filtered_rebel_df = rebel_df[rebel_df['PRED'].isin(predicates)]
filtered_rebel_df

# remove par id so that we can filter duplicate triples later
filtered_rebel_df_id = filtered_rebel_df.set_index('PAR ID')
filtered_rebel_df_id

Unnamed: 0_level_0,SUB,PRED,OBJ
PAR ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
27049342-1,Willem van Biljon,date of birth,1961
27049342-2,Willem van Biljon,date of birth,1961
27049342-2,Willem van Biljon,educated at,University of Cape Town
27049342-3,Willem van Biljon,date of birth,1961
27049342-4,Willem van Biljon,date of birth,1961
...,...,...,...
1775773-2,"Charles Patrick ""Chuck"" Thacker",date of birth,"February 26, 1943"
1775773-2,Ralph Scott Thacker,date of birth,1906
1775773-2,"Charles Patrick ""Chuck"" Thacker",date of birth,1922
1775773-3,"Charles Patrick ""Chuck"" Thacker",educated at,"University of California, Berkeley"


In [19]:
evaluation_df

# remove par id so that we can filter duplicate triples later
evaluation_df_id = evaluation_df.set_index('PAR ID')
evaluation_df_id

Unnamed: 0_level_0,SUB,PRED,OBJ
PAR ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
51207047-1,Elissa Shevinsky,occupation,"american technology executive , entrepreneur ,..."
51207047-3,Elissa Shevinsky,educated at,Benjamin Cardozo High School
51207047-3,Elissa Shevinsky,educated at,Williams College
51207047-6,Elissa Shevinsky,occupation,chief executive of Glimpse
51207047-8,Elissa Shevinsky,occupation,Head of Product
...,...,...,...
308365-2,Amir Pnueli,educated at,the Weizmann Institute of Science
308365-2,Amir Pnueli,member of,Stanford University
308365-3,Amir Pnueli,occupation,researcher
308365-3,the computer science department at Tel Aviv Un...,founded by,Amir Pnueli


REMOVE DUPLICATE TRIPLES

In [20]:
eval_list_true_id = np.vstack({tuple(row) for row in evaluation_df_id.to_numpy()})

  eval_list_true_id = np.vstack({tuple(row) for row in evaluation_df_id.to_numpy()})


In [21]:
eval_list_pred_id = np.vstack({tuple(row) for row in filtered_rebel_df_id.to_numpy()})

  eval_list_pred_id = np.vstack({tuple(row) for row in filtered_rebel_df_id.to_numpy()})


In [22]:
dicty = dict()
for i in eval_list_pred_id:
    if not i[1] in dicty.keys():
        dicty[i[1]] = 1
    else:
        dicty[i[1]] = dicty[i[1]] + 1
dicty

{'educated at': 141,
 'date of birth': 87,
 'founded by': 27,
 'place of birth': 38,
 'occupation': 30,
 'member of': 33,
 'date of death': 14,
 'spouse': 10}

(OPTION 1) EVALUATION BASED ON CONTAINMENT ON EACH TRIPLET PART

In [23]:
new_true = [[str(element) for element in triple] for triple in eval_list_true_id]
new_pred = [[str(element) for element in triple] for triple in eval_list_pred_id]

In [24]:
correct = 0
for triple_pred in new_pred:
    for triple_true in new_true:
        if (triple_pred[0] in triple_true[0] or triple_true[0] in triple_pred[0]) and (triple_pred[1] in triple_true[1] or triple_true[1] in triple_pred[1]) and (triple_pred[2] in triple_true[2] or triple_true[2] in triple_pred[2]):
            print(triple_true)
            print(triple_pred)
            print()
            correct += 1

recall = correct/len(new_true)
precision = correct/len(new_pred)

['Jeff Seibert', 'educated at', 'Stanford University']
['Jeff Seibert', 'educated at', 'Stanford University']

['Vladimir Naumovich Vapnik', 'date of birth', '6 December 1936']
['Vladimir Naumovich Vapnik', 'date of birth', '6 December 1936']

['Cynthia Barnhart', 'occupation', 'civil engineer']
['Cynthia Barnhart', 'occupation', 'civil engineer']

['Donald Bruce Gillies', 'date of birth', 'October 15, 1928']
['Bruce Gillies', 'date of birth', 'October 15, 1928']

['David Isaac Murray', 'educated at', 'Carnegie Mellon University']
['David Isaac Murray', 'educated at', 'Carnegie Mellon University']

['Donald Bruce Gillies', 'date of birth', 'October 15, 1928']
['Donald Bruce Gillies', 'date of birth', 'October 15, 1928']

['Dr. Herbert Freeman', 'date of birth', 'December 13, 1925']
['Herbert Freeman', 'date of birth', 'December 13, 1925']

['Dharmendra S. Modha', 'educated at', 'Jacobs School of Engineering']
['Dharmendra S. Modha', 'educated at', 'Jacobs School of Engineering']

['Rob

In [25]:
print('Number of correct predictions: {}'.format(correct), 'Number of evaluation triples: {}'.format(len(new_true)))

print('Recall: {}'.format(recall))
print('Precision: {}'.format(precision))
print('F1: {}'.format(2*precision*recall/(precision+recall)))

Number of correct predictions: 290 Number of evaluation triples: 779
Recall: 0.3722721437740693
Precision: 0.7631578947368421
F1: 0.5004314063848145


SHOW ALL EVALUATION TRIPLETS THAT WERE NOT PREDICTED BY REBEL

In [26]:
missing_triples = []
for triple_true in new_true:
    triple_missing = True
    for triple_pred in new_pred:
        if (triple_pred[0] in triple_true[0] or triple_true[0] in triple_pred[0]) and (triple_pred[1] in triple_true[1] or triple_true[1] in triple_pred[1]) and (triple_pred[2] in triple_true[2] or triple_true[2] in triple_pred[2]):
            triple_missing = False
    if triple_missing:
        missing_triples.append(triple_true)

print('Number of articles not predicted by REBEL, but in Evaluation set: {}'.format(len(missing_triples)))
missing_triples

Number of articles not predicted by REBEL, but in Evaluation set: 502


[['Andi (Andrei) Gutmans', 'educated at', 'Israel Institute of Technology'],
 ['Stephen Richard ', 'educated at', 'Trinity College, Cambridge'],
 ['LiveJournal', 'founded by', 'Bradley Joseph Fitzpatrick'],
 ['Tara Hernandez', 'educated at', 'University of California'],
 ['Neil T. Heffernan', 'occupation', 'associate professor'],
 ['Berxel Photonics Co. Ltd.', 'founded by', 'Constance J. Chang-Hasnain'],
 ['Ashok K. Chandra',
  'member of',
  'Database and Distributed Systems at IBM Almaden Research Center'],
 ['Lucinda Lucy Sanders', 'educated at', 'Louisiana State University'],
 ['Dharmendra S. Modha', 'occupation', 'lead researcher'],
 ['GeneProt', 'founded by', 'Ron David Appel'],
 ['Martin Richards', 'occupation', 'computer scientist'],
 ['Elissa Shevinsky', 'occupation', 'chief operating officer'],
 ['Paul Vixie', 'occupation', 'software engineer'],
 ['Theo H\\u00e4rder', 'date of birth', 'August 28, 1945'],
 ['Stephen Richard ', 'educated at', "King's college London, England"],


SHOW ALL PREDICTED TRIPLETS THAT WERE NOT IN THE EVALUATION SET

In [27]:
missing_triples = []

for triple_pred in new_pred:
    triple_missing = True
    for triple_true in new_true:
        if (triple_pred[0] in triple_true[0] or triple_true[0] in triple_pred[0]) and (triple_pred[1] in triple_true[1] or triple_true[1] in triple_pred[1]) and (triple_pred[2] in triple_true[2] or triple_true[2] in triple_pred[2]):
            triple_missing = False
    if triple_missing:
        missing_triples.append(triple_pred)

print('Number of articles predicted by REBEL, but not in Evaluation set: {}'.format(len(missing_triples)))
missing_triples

Number of articles predicted by REBEL, but not in Evaluation set: 91


[['first successful commercial anti-spam service', 'founded by', 'Paul Vixie'],
 ['Roopam Sharma', 'place of birth', 'Faridabad'],
 ['Bernard Chazelle', 'member of', 'ACM'],
 ['Zend Technologies', 'founded by', 'Zeev Suraski'],
 ['Bhawani Shankar Chowdhry', 'educated at', 'University of Southampton'],
 ['Robin Milner', 'date of death', '20 March 2010'],
 ['Constance J. Chang-Hasnain', 'educated at', 'Stanford University'],
 ['Marilyn Kirsch', 'spouse', 'Gabor Tamas Herman'],
 ['Bernard Chazelle', 'member of', 'American Academy of Arts and Sciences'],
 ['Pixar', 'founded by', 'Ed Catmull'],
 ['Frances Yao', 'spouse', 'Andrew Chi-Chih Yao'],
 ['Bernard Chazelle',
  'member of',
  'John Simon Guggenheim Memorial Foundation'],
 ['Lucinda "Lucy" Sanders', 'educated at', 'Louisiana State University'],
 ['Andrey Petrovich Yershov',
  'member of',
  'Academy of Sciences of the Soviet Union'],
 ['Remi El-Ouazzane', 'place of birth', 'Neuilly-sur-Seine'],
 ['Zeev Suraski', 'educated at', 'Israel

EVALUATION ON EACH PREDICATE BASED ON CONTAINMENT ON EACH TRIPLET PART

In [28]:
predicates = ['date of birth', 'occupation', 'educated at', 'spouse', 'place of birth', 'member of', 'founded by', 'date of death']
recall_per_predicate = dict()
precision_per_predicate = dict()
f1_per_predicate = dict()

for predicate in predicates:
    
    true_predicate = []
    pred_predicate = []
    
    for triple in eval_list_true_id:
        if triple[1] == predicate:
            true_predicate.append(triple)
    
    for triple in eval_list_pred_id:
        if triple[1] == predicate:
            pred_predicate.append(triple)

    new_true = [[str(element) for element in triple] for triple in true_predicate]
    new_pred = [[str(element) for element in triple] for triple in pred_predicate]

    correct = 0
    for triple_pred in new_pred:
        for triple_true in new_true:
            if (triple_pred[0] in triple_true[0] or triple_true[0] in triple_pred[0]) and (triple_pred[1] in triple_true[1] or triple_true[1] in triple_pred[1]) and (triple_pred[2] in triple_true[2] or triple_true[2] in triple_pred[2]):
                correct += 1

    recall = correct/len(new_true)
    recall_per_predicate[predicate] = recall
    
    precision = correct/len(new_pred)
    precision_per_predicate[predicate] = precision
    
    f1 = 2*precision*recall/(precision+recall)
    f1_per_predicate[predicate] = f1

In [29]:
print('Recall per predicate')
recall_per_predicate

Recall per predicate


{'date of birth': 0.9759036144578314,
 'occupation': 0.09848484848484848,
 'educated at': 0.6580645161290323,
 'spouse': 0.23076923076923078,
 'place of birth': 0.6530612244897959,
 'member of': 0.12903225806451613,
 'founded by': 0.21794871794871795,
 'date of death': 1.0}

In [30]:
print('Precision per predicate')
precision_per_predicate

Precision per predicate


{'date of birth': 0.9310344827586207,
 'occupation': 0.8666666666666667,
 'educated at': 0.723404255319149,
 'spouse': 0.3,
 'place of birth': 0.8421052631578947,
 'member of': 0.48484848484848486,
 'founded by': 0.6296296296296297,
 'date of death': 0.9285714285714286}

In [31]:
print('F1 per predicate')
f1_per_predicate

F1 per predicate


{'date of birth': 0.9529411764705883,
 'occupation': 0.1768707482993197,
 'educated at': 0.6891891891891891,
 'spouse': 0.2608695652173913,
 'place of birth': 0.735632183908046,
 'member of': 0.20382165605095542,
 'founded by': 0.32380952380952377,
 'date of death': 0.962962962962963}

(OPTION 2) EVALUATION BASED ON SIMILARITY ON EACH TRIPLET PART (THUS NO NEED FOR PERFECT MATCH)

In [32]:
nlp = spacy.load('en_core_web_sm')
new_true = [[nlp(str(element)) for element in triple] for triple in eval_list_true_id]
new_pred = [[nlp(str(element)) for element in triple] for triple in eval_list_pred_id]

In [33]:
correct = 0
for triple_pred in new_pred:
    for triple_true in new_true:
        if triple_pred[0].similarity(triple_true[0]) >= 0.9 and triple_pred[1].similarity(triple_true[1]) >= 0.99 and triple_pred[2].similarity(triple_true[2]) >= 0.9:
            print(triple_true)
            print(triple_pred)
            print()
            correct += 1

recall = correct/len(new_true)
precision = correct/len(new_pred)

[Jeff Seibert, educated at, Stanford University]
[Jeff Seibert, educated at, Stanford University]

[Vladimir Naumovich Vapnik, date of birth, 6 December 1936]
[Vladimir Naumovich Vapnik, date of birth, 6 December 1936]

[the first successful commerical anti-spam service, founded by, Paul Vixie]
[first successful commercial anti-spam service, founded by, Paul Vixie]

[Cynthia Barnhart, occupation, civil engineer]
[Cynthia Barnhart, occupation, civil engineer]

[Donald Bruce Gillies, date of birth, October 15, 1928]
[Bruce Gillies, date of birth, October 15, 1928]

[David Isaac Murray, educated at, Carnegie Mellon University]
[David Isaac Murray, educated at, Carnegie Mellon University]

[Donald Bruce Gillies, date of birth, October 15, 1928]
[Donald Bruce Gillies, date of birth, October 15, 1928]

[Fernando Ramón Sallaberry Valls, date of birth, November 25, 1965]
[Donald Bruce Gillies, date of birth, October 15, 1928]

[Dharmendra S. Modha, educated at, Jacobs School of Engineering]
[D

In [34]:
print('Number of correct predictions: {}'.format(correct), 'Number of evaluation triples: {}'.format(len(new_true)))

print('Recall: {}'.format(recall))
print('Precision: {}'.format(precision))
print('F1: {}'.format(2*precision*recall/(precision+recall)))

Number of correct predictions: 268 Number of evaluation triples: 779
Recall: 0.34403080872913994
Precision: 0.7052631578947368
F1: 0.4624676445211389


SHOW ALL EVALUATION TRIPLETS THAT WERE NOT PREDICTED BY REBEL

In [35]:
missing_triples = []
for triple_true in new_true:
    triple_missing = True
    for triple_pred in new_pred:
        if triple_pred[0].similarity(triple_true[0]) >= 0.9 and triple_pred[1].similarity(triple_true[1]) >= 0.99 and triple_pred[2].similarity(triple_true[2]) >= 0.9:
            triple_missing = False
    if triple_missing:
        missing_triples.append(triple_true)

print('Number of articles not predicted by REBEL, but in Evaluation set: {}'.format(len(missing_triples)))
missing_triples

Number of articles not predicted by REBEL, but in Evaluation set: 533


[[Andi (Andrei) Gutmans, educated at, Israel Institute of Technology],
 [Stephen Richard , educated at, Trinity College, Cambridge],
 [LiveJournal, founded by, Bradley Joseph Fitzpatrick],
 [Tara Hernandez, educated at, University of California],
 [Neil T. Heffernan, occupation, associate professor],
 [Berxel Photonics Co. Ltd., founded by, Constance J. Chang-Hasnain],
 [Ashok K. Chandra,
  member of,
  Database and Distributed Systems at IBM Almaden Research Center],
 [Lucinda Lucy Sanders, educated at, Louisiana State University],
 [Dharmendra S. Modha, occupation, lead researcher],
 [GeneProt, founded by, Ron David Appel],
 [Martin Richards, occupation, computer scientist],
 [Elissa Shevinsky, occupation, chief operating officer],
 [Paul Vixie, occupation, software engineer],
 [Dr. Herbert Freeman, place of birth, Frankfurt, Germany],
 [Theo H\u00e4rder, date of birth, August 28, 1945],
 [Stephen Richard , educated at, King's college London, England],
 [Andi (Andrei) Gutmans, educat

SHOW ALL PREDICTED TRIPLETS THAT WERE NOT IN THE EVALUATION SET

In [36]:
missing_triples = []

for triple_pred in new_pred:
    triple_missing = True
    for triple_true in new_true:
        if triple_pred[0].similarity(triple_true[0]) >= 0.9 and triple_pred[1].similarity(triple_true[1]) >= 0.99 and triple_pred[2].similarity(triple_true[2]) >= 0.9:
            triple_missing = False
    if triple_missing:
        missing_triples.append(triple_pred)

print('Number of articles predicted by REBEL, but not in Evaluation set: {}'.format(len(missing_triples)))
missing_triples

Number of articles predicted by REBEL, but not in Evaluation set: 137


[[Roopam Sharma, place of birth, Faridabad],
 [Bernard Chazelle, member of, ACM],
 [Zend Technologies, founded by, Zeev Suraski],
 [Bhawani Shankar Chowdhry, educated at, University of Southampton],
 [Herbert Freeman, date of birth, December 13, 1925],
 [Robin Milner, date of death, 20 March 2010],
 [Constance J. Chang-Hasnain, educated at, Stanford University],
 [Marilyn Kirsch, spouse, Gabor Tamas Herman],
 [Bernard Chazelle, member of, American Academy of Arts and Sciences],
 [Oscar Neftalí Sallaberry Valls, place of birth, New York City],
 [Pixar, founded by, Ed Catmull],
 [Frances Yao, spouse, Andrew Chi-Chih Yao],
 [Constance J. Chang-Hasnain, educated at, University of California, Davis],
 [Bernard Chazelle, member of, John Simon Guggenheim Memorial Foundation],
 [Ian Tremere Foster, date of birth, 1959],
 [Lucinda "Lucy" Sanders, educated at, Louisiana State University],
 [Andrey Petrovich Yershov,
  member of,
  Academy of Sciences of the Soviet Union],
 [Zeev Suraski, educate

EVALUATION ON EACH PREDICATE BASED ON SIMILARITY ON EACH TRIPLET PART

In [37]:
predicates = ['date of birth', 'occupation', 'educated at', 'spouse', 'place of birth', 'member of', 'founded by', 'date of death']
recall_per_predicate = dict()
precision_per_predicate = dict()
f1_per_predicate = dict()

for predicate in predicates:
    
    true_predicate = []
    pred_predicate = []
    
    for triple in eval_list_true_id:
        if triple[1] == predicate:
            true_predicate.append(triple)
    
    for triple in eval_list_pred_id:
        if triple[1] == predicate:
            pred_predicate.append(triple)

    nlp = spacy.load('en_core_web_sm')
    new_true = [[nlp(str(element)) for element in triple] for triple in true_predicate]
    new_pred = [[nlp(str(element)) for element in triple] for triple in pred_predicate]

    correct = 0
    for triple_pred in new_pred:
        for triple_true in new_true:
            if triple_pred[0].similarity(triple_true[0]) >= 0.9 and triple_pred[1].similarity(triple_true[1]) >= 0.99 and triple_pred[2].similarity(triple_true[2]) >= 0.9:
                correct += 1

    recall = correct/len(new_true)
    recall_per_predicate[predicate] = recall
    
    precision = correct/len(new_pred)
    precision_per_predicate[predicate] = precision
    
    f1 = 2*precision*recall/(precision+recall)
    f1_per_predicate[predicate] = f1

In [38]:
print('Recall per predicate')
recall_per_predicate

Recall per predicate


{'date of birth': 0.9518072289156626,
 'occupation': 0.12121212121212122,
 'educated at': 0.6129032258064516,
 'spouse': 0.23076923076923078,
 'place of birth': 0.3673469387755102,
 'member of': 0.12096774193548387,
 'founded by': 0.20512820512820512,
 'date of death': 0.7692307692307693}

In [39]:
print('Precision per predicate')
precision_per_predicate

Precision per predicate


{'date of birth': 0.9080459770114943,
 'occupation': 1.0666666666666667,
 'educated at': 0.6737588652482269,
 'spouse': 0.3,
 'place of birth': 0.47368421052631576,
 'member of': 0.45454545454545453,
 'founded by': 0.5925925925925926,
 'date of death': 0.7142857142857143}

In [40]:
print('F1 per predicate')
f1_per_predicate

F1 per predicate


{'date of birth': 0.9294117647058823,
 'occupation': 0.21768707482993196,
 'educated at': 0.6418918918918918,
 'spouse': 0.2608695652173913,
 'place of birth': 0.41379310344827586,
 'member of': 0.1910828025477707,
 'founded by': 0.3047619047619048,
 'date of death': 0.7407407407407408}