In [None]:
! pip install nltk

In [None]:
! python -m nltk.downloader punkt

In [1]:
import json
from pathlib import Path

from nltk.tokenize import sent_tokenize
import tiktoken

# Tokenizer used for the per-document token counts computed below.
encoder = tiktoken.get_encoding('cl100k_base')

<br><br><br>
# Load the REBEL_20 dataset

In [2]:
# Path is relative to the notebook's working directory.
rebel_20_data_path = Path('REBEL_20.json')
# Pin the encoding so the read does not depend on the platform's locale default.
with rebel_20_data_path.open('r', encoding='utf-8') as f:
    rebel_20_data = json.load(f)

<br><br><br>
# Get metrics data

In [3]:
def _text_metrics(text):
    """Return the sentence count and token count of one document's text."""
    return {
        'sentence count': len(sent_tokenize(text)),
        'token count': len(encoder.encode(text)),
    }


def _entity_metrics(entities):
    """Return entity counts for one document's `entities` record."""
    gpt_entities = entities['gpt']
    # Gold (REBEL) entities are de-duplicated by surface form.
    rebel_entities = {e['surfaceform'] for e in entities['gold']}
    correct_entities = [e for e in gpt_entities if e['annotation']['entity correctness']]
    return {
        'total':         len(gpt_entities),
        'rebel count':   len(rebel_entities),
        'correct':       len(correct_entities),
        'missing':       len(entities['annotation']['gpt missed']),
        # Only correct entities are checked for a text-derived description.
        'from text':     sum(1 for e in correct_entities if e['annotation']['description from text']),
        'rebel matches': sum(1 for e in gpt_entities if e['annotation']['rebel match']),
    }


def _type_metrics(gpt_entities):
    """Return type counts over all types and over each entity's first type."""
    all_total = first_total = all_correct = first_correct = 0
    for entity in gpt_entities:
        types = entity['types']
        flags = entity['annotation']['type correctness']
        all_total += len(types)
        all_correct += sum(1 for _, is_correct in zip(types, flags) if is_correct)
        # An entity without types has no "first type": leave it out of the
        # first-type counts instead of raising IndexError.
        if types:
            first_total += 1
            if flags and flags[0]:
                first_correct += 1
    return {
        'all total':   all_total,
        'first total':   first_total,
        'all correct': all_correct,
        'first correct': first_correct,
    }


def _triple_metrics(triples):
    """Return triple counts for one document's `triples` record."""
    gpt_triples = triples['gpt']
    # Gold (REBEL) triples are de-duplicated on their stripped
    # subject|predicate|object surface forms.
    rebel_triples = {
        '|'.join(t[role]['surfaceform'].strip() for role in ('subject', 'predicate', 'object'))
        for t in triples['gold']
    }
    correct_triples = [t for t in gpt_triples if t['annotation']['triple correctness']]
    return {
        'total':         len(gpt_triples),
        'rebel count':   len(rebel_triples),
        'correct':       len(correct_triples),
        # Only correct triples are checked for a text-derived relation.
        'from text':     sum(1 for t in correct_triples if t['annotation']['relation from text']),
        'rebel matches': sum(1 for t in gpt_triples if t['annotation']['rebel match']),
    }


def get_metrics_dict(data):
    """Compute per-document counts from the annotated dataset.

    Args:
        data: Iterable of document records. Each record is read through the
            keys `doc.text`, `entities.{gpt, gold, annotation}` and
            `triples.{gpt, gold}`.

    Returns:
        Dict mapping each document's 0-based position in `data` to a dict
        with the sections 'text', 'entities', 'types' and 'triples', each
        holding integer counts.

    Relies on the notebook-level `encoder` and `sent_tokenize` names.
    """
    metrics_data = {}
    for doc_id, doc_data in enumerate(data):
        metrics_data[doc_id] = {
            'text':     _text_metrics(doc_data['doc']['text']),
            'entities': _entity_metrics(doc_data['entities']),
            'types':    _type_metrics(doc_data['entities']['gpt']),
            'triples':  _triple_metrics(doc_data['triples']),
        }
    return metrics_data

In [4]:
# Per-document metrics for the loaded dataset, keyed by each document's 0-based position in rebel_20_data.
m = get_metrics_dict(rebel_20_data)

## Counts

In [5]:
# Number of document records loaded from the dataset file.
document_count = len(rebel_20_data)
print(f'Number of documents: {document_count:,d}')

Number of documents: 20


In [6]:
# Total number of GPT entities over all documents.
ce = sum(doc_metrics['entities']['total'] for doc_metrics in m.values())
print(f'Number of entities: {ce:,d}')

Number of entities: 379


In [7]:
# Total number of GPT triples over all documents.
ct = sum(doc_metrics['triples']['total'] for doc_metrics in m.values())
print(f'Number of triples: {ct:,d}')

Number of triples: 329


In [8]:
# Total number of sentences over all documents.
sentence_count = sum(doc_metrics['text']['sentence count'] for doc_metrics in m.values())
print(f'Sentence count: {sentence_count:,d}')

Sentence count: 374


In [9]:
# Mean token count per document, rounded to the nearest integer.
token_count_list = [doc_metrics['text']['token count'] for doc_metrics in m.values()]
average_token_count = round(sum(token_count_list) / len(token_count_list))
print(f'Average token count: {average_token_count:,d}')

Average token count: 602


## Compute metrics

In [10]:
# Micro precision: correct GPT entities over all GPT entities, pooled across documents.
entities_correct = sum(d['entities']['correct'] for d in m.values())
entities_total = sum(d['entities']['total'] for d in m.values())
pe = entities_correct / entities_total
print(f'Micro precision entities: {pe: 4.3f}')

Micro precision entities:  0.976


In [11]:
# Micro recall: correct GPT entities over (correct + missed), pooled across documents.
# NOTE: `re` is entity recall here, not the stdlib regex module; the name is kept because the F1 cell reads it.
entities_found = sum(d['entities']['correct'] for d in m.values())
entities_expected = sum(d['entities']['correct'] + d['entities']['missing'] for d in m.values())
re = entities_found / entities_expected
print(f'Micro recall entities: {re: 4.3f}')

Micro recall entities:  0.896


In [12]:
# Harmonic mean of entity precision and recall. When either is 0 the harmonic
# mean is 0, so return that instead of raising ZeroDivisionError.
f1e = 2 / (1 / pe + 1 / re) if pe and re else 0.0
print(f'Micro F1 entities: {f1e: 4.3f}')

Micro F1 entities:  0.934


In [13]:
# Micro precision over every type assigned to every GPT entity.
types_correct = sum(d['types']['all correct'] for d in m.values())
types_total = sum(d['types']['all total'] for d in m.values())
pta = types_correct / types_total
print(f'Micro precision types (all): {pta: 4.3f}')

Micro precision types (all):  0.777


In [14]:
# Micro precision over only the first type of each GPT entity.
first_types_correct = sum(d['types']['first correct'] for d in m.values())
first_types_total = sum(d['types']['first total'] for d in m.values())
ptf = first_types_correct / first_types_total
print(f'Micro precision types (only first): {ptf: 4.3f}')

Micro precision types (only first):  0.939


In [15]:
# Micro precision: correct GPT triples over all GPT triples, pooled across documents.
triples_correct = sum(d['triples']['correct'] for d in m.values())
triples_total = sum(d['triples']['total'] for d in m.values())
pr = triples_correct / triples_total
print(f'Micro precision triples: {pr: 4.3f}')

Micro precision triples:  0.836


In [16]:
# Share of correct entities whose description was NOT annotated as coming from the text.
entities_from_text = sum(d['entities']['from text'] for d in m.values())
entities_correct = sum(d['entities']['correct'] for d in m.values())
ae = 1 - (entities_from_text / entities_correct)
print(f'Percentage of entities descriptions from GPT knowledge: {ae: 4.3f}')

Percentage of entities descriptions from GPT knowledge:  0.322


In [17]:
# Share of correct triples whose relation was NOT annotated as coming from the text.
triples_from_text = sum(d['triples']['from text'] for d in m.values())
triples_correct = sum(d['triples']['correct'] for d in m.values())
at = 1 - (triples_from_text / triples_correct)
print(f'Percentage of triples from GPT knowledge: {at: 4.3f}')

Percentage of triples from GPT knowledge:  0.065


In [18]:
# Share of GPT entities that have no REBEL match.
entities_matched = sum(d['entities']['rebel matches'] for d in m.values())
entities_total = sum(d['entities']['total'] for d in m.values())
ne = 1 - (entities_matched / entities_total)
print(f'Percentage of new entities (vs REBEL): {ne: 4.3f}')

Percentage of new entities (vs REBEL):  0.451


In [19]:
# Share of GPT triples that have no REBEL match.
triples_matched = sum(d['triples']['rebel matches'] for d in m.values())
triples_total = sum(d['triples']['total'] for d in m.values())
nt = 1 - (triples_matched / triples_total)
print(f'Percentage of new triples (vs REBEL): {nt: 4.3f}')

Percentage of new triples (vs REBEL):  0.945
