*This is the supplementary material for the AKBC 2022 submission "Understanding Relation Extraction through
Knowledge Graphs" (anonymous).*

# Benchmark results

In [1]:
# Evaluation datasets: for each one, the test file, the relation-to-id map,
# a flag controlling whether gold ids are shifted down by one, and the
# serialized prediction-score file of every model.
datasets = {
    'wiki20m_test': {
        'test_file': '../OpenNRE/benchmark/wiki20m/wiki20m_test.txt',
        'rel2id_file': '../OpenNRE/benchmark/wiki20m/wiki20m_rel2id.json',
        'nonna': False,
        'score_files': {
            'bert-masking-cls': 'pred/bert-supervised-cls-masking.wiki20m_test.npz',
            'bert-noMasking-cls': 'pred/bert-supervised-cls-noMasking.wiki20m_test.npz',
            'bert-masking-entity': 'pred/bert-supervised-entity-masking.wiki20m_test.npz',
            'bert-noMasking-entity': 'pred/bert-supervised-entity-noMasking.wiki20m_test.npz',
            'ptr': 'pred/ptr.wiki20m_test.npz',
        },
    },
}

import json, tqdm
import numpy as np

# For every dataset: read the relation-to-id map and the test items, derive
# the gold label ids, and load each model's prediction scores.
for dataset_name, data in datasets.items():

    # Context managers close the files deterministically instead of leaving
    # the handles to the garbage collector.
    with open(data['rel2id_file']) as rel2id_fh:
        data['rel2id'] = json.load(rel2id_fh)
    with open(data['test_file']) as test_fh:
        # One JSON object per line.
        data['test_items'] = [json.loads(line) for line in test_fh]
    data['gold'] = np.array([data['rel2id'][item['relation']] for item in data['test_items']])
    if data['nonna']:
        # Shift the gold ids down by one when the dataset is flagged 'nonna'.
        data['gold'] -= 1

    data['pred_scores'] = []
    for score_file in tqdm.tqdm(data['score_files'].values(), desc=dataset_name):
        # NOTE(review): allow_pickle=True can execute arbitrary code when a
        # file is unpickled -- only load prediction files from a trusted source.
        npz = np.load(score_file, allow_pickle=True)
        # An archive exposes `keys`; in that case take its first unnamed array.
        npz = npz['arr_0'] if 'keys' in dir(npz) else npz
        data['pred_scores'].append( npz )

wiki20m_test: 100%|██████████| 5/5 [00:01<00:00,  4.69it/s]


In [2]:
# Calculate evaluation scores
from IPython.display import display
from sklearn.metrics import classification_report, precision_recall_fscore_support
import pandas as pd

from opennre.framework import SentenceREDataset

import warnings
warnings.filterwarnings("ignore")

# Accuracy plus macro/micro precision, recall and F1 for every model. The
# micro scores are cross-checked by hand and against OpenNRE's own evaluator.
scores = []
for dataset_name, data in datasets.items():
    # Id of the 'NA' relation. It is excluded from the scored labels, and the
    # manual sanity check below uses the same id instead of a hard-coded 0.
    na_id = data['rel2id']['NA']
    labels = [l for l in set(data['gold']) if l != na_id]
    gold = data['gold']
    opennre_dataset = SentenceREDataset(data['test_file'], data['rel2id'], lambda x: x['token'], {})

    for name, pred_scores in zip(data['score_files'], data['pred_scores']):
        # Predicted relation id = highest-scoring column of each row.
        pred = pred_scores.argmax(axis=1)
        micro_score = list(precision_recall_fscore_support(gold, pred, average='micro', labels=labels))
        macro_score = list(precision_recall_fscore_support(gold, pred, average='macro', labels=labels))

        # Sanity check: recompute micro precision / recall by hand.
        predicted_relation = pred != na_id
        n_hits = (pred[predicted_relation] == gold[predicted_relation]).sum()
        mp = n_hits / predicted_relation.sum()
        mr = n_hits / (gold != na_id).sum()
        assert (abs(micro_score[0]-mp) < 1E-6) and (abs(micro_score[1]-mr) < 1E-6)

        # OpenNRE Sanity check
        opennre_score = opennre_dataset.eval(list(pred))
        assert abs(micro_score[0] - opennre_score['micro_p']) < 1E-6
        assert abs(micro_score[1] - opennre_score['micro_r']) < 1E-6

        acc = (gold == pred).mean()
        # [:-1] drops the trailing support entry returned by sklearn.
        scores.append( (dataset_name, name, acc, *macro_score[:-1], *micro_score[:-1]))

pd.options.display.float_format = '{:,.2f}'.format
pd.options.display.max_colwidth = 100
cols = 'eval_data model acc macro_p macro_r macro_f1 micro_p micro_r micro_f1'
score_df = pd.DataFrame(scores, columns=cols.split())
score_df.insert(0, 'train_data', 'wiki20m_train')
score_df

2022-05-06 10:35:21,503 - root - INFO - Loaded sentence RE dataset ../OpenNRE/benchmark/wiki20m/wiki20m_test.txt with 137986 lines and 81 relations.
2022-05-06 10:35:22,240 - root - INFO - Evaluation result: {'acc': 0.428905831026336, 'micro_p': 0.6094844821465274, 'micro_r': 0.3145242939378453, 'micro_f1': 0.41492609458875007}.
2022-05-06 10:35:23,089 - root - INFO - Evaluation result: {'acc': 0.706202078471729, 'micro_p': 0.7904100551287643, 'micro_r': 0.6815609059965866, 'micro_f1': 0.731960898020048}.
2022-05-06 10:35:23,832 - root - INFO - Evaluation result: {'acc': 0.4051932804777296, 'micro_p': 0.5962063047601842, 'micro_r': 0.318226961980156, 'micro_f1': 0.4149650455162701}.
2022-05-06 10:35:24,695 - root - INFO - Evaluation result: {'acc': 0.786108735668836, 'micro_p': 0.8423145197288515, 'micro_r': 0.7823911135966984, 'micro_f1': 0.81124775044991}.
2022-05-06 10:35:25,573 - root - INFO - Evaluation result: {'acc': 0.7900511646109025, 'micro_p': 0.8316271525078386, 'micro_r': 

Unnamed: 0,train_data,eval_data,model,acc,macro_p,macro_r,macro_f1,micro_p,micro_r,micro_f1
0,wiki20m_train,wiki20m_test,bert-masking-cls,0.43,0.62,0.3,0.35,0.61,0.31,0.41
1,wiki20m_train,wiki20m_test,bert-noMasking-cls,0.71,0.79,0.68,0.71,0.79,0.68,0.73
2,wiki20m_train,wiki20m_test,bert-masking-entity,0.41,0.69,0.32,0.37,0.6,0.32,0.41
3,wiki20m_train,wiki20m_test,bert-noMasking-entity,0.79,0.84,0.78,0.79,0.84,0.78,0.81
4,wiki20m_train,wiki20m_test,ptr,0.79,0.83,0.82,0.81,0.83,0.81,0.82


In [3]:
# Short display names for the models, keyed by the model ids used in
# `datasets[...]['score_files']`.
nice_names = dict([
    ('bert-masking-cls', 'BERT-M-CLS'),
    ('bert-noMasking-cls', 'BERT-CLS'),
    ('bert-masking-entity', 'BERT-M-ENT'),
    ('bert-noMasking-entity', 'BERT-ENT'),
    ('ptr', 'PTR'),
])

In [4]:
# Styled benchmark table: one row per model (display names), sorted by micro
# F1, with the column maximum in bold/underline and an in-cell bar per value.
# The model ids are replaced once, before `model` becomes the index; the
# former second `.replace({'model': ...})` ran after `set_index('model')`,
# when no such column existed any more, and therefore did nothing.
score_styled = (
    score_df.replace({'model':nice_names}).set_index('model').sort_values(by='micro_f1')
    .drop(columns=['train_data', 'eval_data'])
    .style
    .apply(lambda col: ['font-weight:bold;text-decoration:underline' if x==col.max() else '' for x in col])
    .format('{:.2f}')
    .set_table_styles([{
        'selector': 'tbody tr th', 'props': 'border-top: 1px solid black; white-space: nowrap;'}])
    .bar(vmax=1, color='lightblue', props="width: 8em; text-align:left;")
)
display(score_styled)
# LaTeX version of the same table: float columns are formatted to two
# decimals and the column maximum is wrapped in \underline{}.
print(
    score_styled.data.reset_index()
    .apply(lambda x:
           x.map(lambda x: '%.2f'%x).mask(x==x.max(), '\\underline{%.2f}'%x.max()) if x.dtype==float else x )
    .to_latex(index=False, escape=False)
)

Unnamed: 0_level_0,acc,macro_p,macro_r,macro_f1,micro_p,micro_r,micro_f1
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
BERT-M-CLS,0.43,0.62,0.3,0.35,0.61,0.31,0.41
BERT-M-ENT,0.41,0.69,0.32,0.37,0.6,0.32,0.41
BERT-CLS,0.71,0.79,0.68,0.71,0.79,0.68,0.73
BERT-ENT,0.79,0.84,0.78,0.79,0.84,0.78,0.81
PTR,0.79,0.83,0.82,0.81,0.83,0.81,0.82


\begin{tabular}{llllllll}
\toprule
     model &              acc &          macro_p &          macro_r &         macro_f1 &          micro_p &          micro_r &         micro_f1 \\
\midrule
BERT-M-CLS &             0.43 &             0.62 &             0.30 &             0.35 &             0.61 &             0.31 &             0.41 \\
BERT-M-ENT &             0.41 &             0.69 &             0.32 &             0.37 &             0.60 &             0.32 &             0.41 \\
  BERT-CLS &             0.71 &             0.79 &             0.68 &             0.71 &             0.79 &             0.68 &             0.73 \\
  BERT-ENT &             0.79 & \underline{0.84} &             0.78 &             0.79 & \underline{0.84} &             0.78 &             0.81 \\
       PTR & \underline{0.79} &             0.83 & \underline{0.82} & \underline{0.81} &             0.83 & \underline{0.81} & \underline{0.82} \\
\bottomrule
\end{tabular}



In [5]:
# Calculate confusion matrices
from sklearn.metrics import confusion_matrix
from IPython.display import display
import pandas as pd

# One long-format confusion series per model: index (true, pred) relation
# names, value = number of test items, sorted from most to least frequent.
ds_conf = {}
for dataset_name, data in datasets.items():
    n = len(data['rel2id'])
    gold = data['gold']
    relation_names = list(data['rel2id'])
    relation_ids = list(range(n))

    for (name, fname), pred_scores in zip(data['score_files'].items(), data['pred_scores']):
        predicted = pred_scores.argmax(axis=1)
        counts = confusion_matrix(gold, predicted, labels=relation_ids)
        true_axis = pd.Series(relation_names).rename('true')
        pred_axis = pd.Series(relation_names).rename('pred')
        conf = pd.DataFrame(counts, index=true_axis, columns=pred_axis)
        # Flatten the square matrix to one row per (true, pred) pair.
        conf = conf.stack()
        conf = conf.astype('Int64').rename('confused')
        conf = conf.sort_values(ascending=False)
        ds_conf[name] = conf

In [6]:
# Show worst relations
import numpy as np
# Per-relation precision / recall / F1 for two models; keep each model's five
# lowest-F1 relations and lay the two rankings out side by side.
pscores = {}
for model in ['ptr', 'bert-noMasking-entity']:
    # Square (true x pred) count matrix of this model.
    matrix = ds_conf[model].unstack()
    gold = ds_conf[model].groupby(level=0).sum()
    pred = ds_conf[model].groupby(level=1).sum()
    # Keep only the diagonal (correct predictions). The mask size follows the
    # matrix instead of a hard-coded relation count.
    diagonal_mask = np.eye(len(matrix)).astype(int)
    good = (matrix * diagonal_mask).stack().groupby(level=1).sum()
    pscore = pd.DataFrame({' Precision': good / pred, ' Recall': good / gold})
    # Harmonic mean of precision and recall.
    pscore['F1'] = 2 / (1/pscore[' Precision'] + 1/pscore[' Recall'])
    pscore.index.name = '  Relation'
    pscores[model] = pscore.sort_values('F1').head().reset_index()
pscores = pd.concat(pscores, names=['Model']).T.stack().reorder_levels([1,0]).unstack()
display( pscores )
print( pscores.to_latex(index=False) )

Model,ptr,ptr,ptr,ptr,bert-noMasking-entity,bert-noMasking-entity,bert-noMasking-entity,bert-noMasking-entity
Unnamed: 0_level_1,Relation,Precision,Recall,F1,Relation,Precision,Recall,F1
0,residence,0.63,0.26,0.36,after a work by,0.91,0.14,0.24
1,screenwriter,0.3,0.52,0.38,residence,0.76,0.17,0.28
2,part of,0.48,0.34,0.4,part of,0.46,0.23,0.31
3,after a work by,0.85,0.32,0.47,followed by,0.96,0.26,0.41
4,work location,0.45,0.54,0.49,screenwriter,0.39,0.5,0.44


\begin{tabular}{llllllll}
\toprule
            ptr & \multicolumn{4}{l}{bert-noMasking-entity} \\
       Relation &  Precision &  Recall &   F1 &              Relation &  Precision &  Recall &   F1 \\
\midrule
      residence &       0.63 &    0.26 & 0.36 &       after a work by &       0.91 &    0.14 & 0.24 \\
   screenwriter &       0.30 &    0.52 & 0.38 &             residence &       0.76 &    0.17 & 0.28 \\
        part of &       0.48 &    0.34 & 0.40 &               part of &       0.46 &    0.23 & 0.31 \\
after a work by &       0.85 &    0.32 & 0.47 &           followed by &       0.96 &    0.26 & 0.41 \\
  work location &       0.45 &    0.54 & 0.49 &          screenwriter &       0.39 &    0.50 & 0.44 \\
\bottomrule
\end{tabular}



# Direct Confusion Analysis

In [7]:
# Combine the per-model confusion series into one frame indexed by
# (true, pred, model), keeping only pairs that occur at least once.
allconf = (
    pd.DataFrame(ds_conf)
    .rename_axis(columns='model')
    .stack()
    .rename('confused')
)
allconf = allconf[allconf > 0].dropna().to_frame()
allconf

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,confused
true,pred,model,Unnamed: 3_level_1
,,bert-masking-cls,26564
,,bert-noMasking-cls,26762
,,bert-masking-entity,22908
,,bert-noMasking-entity,27331
,,ptr,24936
...,...,...,...
work location,work location,bert-masking-cls,51
work location,work location,bert-noMasking-cls,911
work location,work location,bert-masking-entity,48
work location,work location,bert-noMasking-entity,859


In [8]:
%matplotlib inline
# Create confusion network for PTR predictions
model = 'ptr'
weights = (allconf.loc[:, :, model]
    .reset_index().rename(columns={'true':'source', 'pred':'target', 'confused':'weight'})
)
size = weights[(weights['source'] == weights['target'])].set_index(['source','target'])['weight']

weights = weights[(weights['source'] != 'NA') & (weights['target'] != 'NA')]
weights = weights[(weights['source'] != weights['target'])]
weights = weights.sort_values('weight')[::-1].head(100)

adj = weights.set_index(['source','target'])['weight'].unstack()
sym = adj.index.union(adj.columns)
adj = adj.reindex(sym).T.reindex(sym).fillna(0)
adj

from d3graph import d3graph
import logging
logging.getLogger().setLevel(40)
d3 = d3graph()
d3.graph(adj)
d3.set_edge_properties(directed=True)
d3.set_node_properties(
    color='#ddddff',
    size=10, 
    edge_color='#000000', 
    cmap='Set2'
)
filepath = './confusion-network.html'
d3.show(filepath=filepath, figsize=(1000, 700), title='Confusion Network')
# Improve graph
import json
lines = open(filepath).readlines()
for i, line in enumerate(lines):
    if 'graph = ' in line:
        graph = eval(line.split('=')[1])
        graph['links'] = [
            {
                k:v.replace('_',' ') if type(v) == str else v
                for k,v in l.items()
            } 
            for l in graph['links']
        ]
        graph['nodes'] = [
            {
                k:v.replace('_',' ') if type(v) == str else v
                for k,v in l.items()
            } 
            for l in graph['nodes']
        ]
        lines[i] = 'graph = ' + json.dumps(graph) + '\n'
with open(filepath, 'w') as fw:
    print(''.join(lines), file=fw)

from IPython.display import IFrame
display( IFrame(filepath, '100%', '600px') )

[d3graph] INFO> Keep only edges with weight>0
[d3graph] INFO> Keep only edges with weight>0
[d3graph] INFO> Slider range is set to [16, 515]
[d3graph] INFO> Write to path: [/project/rekg/notebooks/confusion-network.html]
[d3graph] INFO> File already exists and will be overwritten: [/project/rekg/notebooks/confusion-network.html]


Link threshold 16 [16                  ] 515



www-browser: /sw/arch/Debian10/EB_production/2020/software/ncurses/6.2-GCCcore-9.3.0/lib/libncursesw.so.6: version `NCURSESW6_5.1.20000708' not found (required by www-browser)
lynx: /sw/arch/Debian10/EB_production/2020/software/ncurses/6.2-GCCcore-9.3.0/lib/libncursesw.so.6: version `NCURSESW6_5.1.20000708' not found (required by lynx)


# Semantic Confusion Analysis

In [9]:
# Load ontology statements about Wikidata properties

def transitive_closure(pairs):
    """Return every (a, b) such that b is reachable from a via `pairs`.

    `pairs` is an iterable of directed (a, b) edges. The result is a set of
    tuples containing each edge plus every pair connected through a chain of
    edges. A node is paired with itself only if `pairs` contains that
    self-loop directly; a cycle leading back to the start node does not add
    (a, a).

    The traversal is iterative and tracks visited nodes, so it terminates on
    any input -- including cycles that do not pass through the start node,
    which previously recursed without bound -- and is not limited by the
    interpreter's recursion depth.
    """
    lookup = {}
    for a, b in pairs:
        lookup.setdefault(a, set()).add(b)

    closure = set()
    for start, direct in lookup.items():
        reached = set()
        stack = list(direct)
        while stack:
            node = stack.pop()
            if node in reached:
                continue
            reached.add(node)
            for successor in lookup.get(node, ()):
                # Never walk back into the start node, and never expand a
                # node twice.
                if successor != start and successor not in reached:
                    stack.append(successor)
        closure.update((start, node) for node in reached)
    return closure

def siblings(pairs):
    """Return all (a, b) whose first elements share at least one second element.

    `pairs` is an iterable of (a, b) tuples. Two first elements are related
    when the sets of second elements they are paired with overlap; every first
    element is therefore also related to itself.
    """
    targets = {}
    for source, target in pairs:
        if source not in targets:
            targets[source] = set()
        targets[source].add(target)

    related = set()
    for left, left_targets in targets.items():
        for right, right_targets in targets.items():
            if not left_targets.isdisjoint(right_targets):
                related.add((left, right))
    return related

def _load_property_pairs(path, labels):
    """Read a TSV with columns `p1`/`p2` and attach their labels.

    Returns the frame (with added `p1Label`/`p2Label` columns, produced by
    replacing the values of `p1`/`p2` through `labels`) and the set of
    (p1Label, p2Label) tuples it contains.
    """
    df = pd.read_csv(path, sep='\t')
    df['p1Label'] = df['p1'].replace(labels)
    df['p2Label'] = df['p2'].replace(labels)
    return df, set(df[['p1Label', 'p2Label']].apply(tuple, axis=1))


# Lookup from the `item` column to the `itemLabel` column.
plabel = pd.read_csv('../wikidata-prop-label.tsv', sep='\t').set_index('item')['itemLabel']

# Inverse (stored in both directions)
invdf, inv = _load_property_pairs('../wikidata-prop-inverse.tsv', plabel)
inv = set(s for a,b in inv for s in [(a,b), (b,a)])

# Sub / Super-properties & Siblings
subdf, sub = _load_property_pairs('../wikidata-subproperty.tsv', plabel)
sub = transitive_closure(sub)
# A set, like every other category, so that `pair in sup` is a hash lookup
# rather than a linear scan of a list.
sup = set((b, a) for a, b in sub)
sib = siblings(sub)

# See Also
seedf, see = _load_property_pairs('../wikidata-prop-seealso.tsv', plabel)

# Range and Domain: collect the classes listed as range / domain of every
# property, then compare all property pairs for shared classes.
rddf = pd.read_csv('../wikidata-range-domain.csv')
rddf['pLabel'] = rddf['p'].replace(plabel)

p_range, p_domain = {}, {}
for prop, edge, cls in rddf[['pLabel','edgeLabel','cLabel']].itertuples(index=False):
    # Every edge that is not labelled 'range' is treated as a domain edge.
    bucket = p_range if edge == 'range' else p_domain
    bucket.setdefault(prop, set()).add(cls)

ps = set(p_range) | set(p_domain)
rangeDomainMatch, rangeMatch, domainMatch = set(), set(), set()
for left in tqdm.tqdm(ps, desc='Loading Range and Domain'):
    left_range = p_range.get(left, set())
    left_domain = p_domain.get(left, set())
    for right in ps:
        shares_range = bool(left_range & p_range.get(right, set()))
        shares_domain = bool(left_domain & p_domain.get(right, set()))
        pair = (left, right)
        if shares_range and shares_domain:
            rangeDomainMatch.add(pair)
            continue
        # At most one of the two can hold from here on.
        if shares_range:
            rangeMatch.add(pair)
        if shares_domain:
            domainMatch.add(pair)
            
# Category name -> collection of (property, property) pairs, in display order.
order = [
    ('inverse', inv),
    ('subProp', sub),
    ('superProp', sup),
    ('sibling', sib),
    ('seeAlso', see),
    ('rangeDomainMatch', rangeDomainMatch),
    ('onlyRangeMatch', rangeMatch),
    ('onlyDomainMatch', domainMatch),
]

# Only use pairs that are ever confused
allpairs = allconf['confused'].unstack().index

# For every such pair, list the categories whose pair collection contains it.
pair_categories = {}
for pair in tqdm.tqdm(allpairs, 'Making property-pair analysis categories'):
    true_rel, pred_rel = pair
    pair_categories[(true_rel, pred_rel)] = [
        label for label, members in order if (true_rel, pred_rel) in members
    ]
ont = pd.Series(pair_categories, name = 'category')
ont.index.names = ['true', 'pred']
ont.to_csv('../wikidata-prop-pair-analysis.csv')
categories, _ = zip(*order)
ont

Loading Range and Domain: 100%|██████████| 6300/6300 [00:25<00:00, 244.17it/s]
Making property-pair analysis categories: 100%|██████████| 1873/1873 [00:00<00:00, 47157.85it/s]


true           pred                   
NA             NA                                                           []
               after a work by                                              []
               applies to jurisdiction                                      []
               architect                                                    []
               characters                                                   []
                                                          ...                 
work location  religion                                      [onlyDomainMatch]
               residence                  [sibling, seeAlso, rangeDomainMatch]
               said to be the same as                                       []
               subsidiary                                                   []
               work location                       [sibling, rangeDomainMatch]
Name: category, Length: 1873, dtype: object

In [10]:
# Statistics about categories: for each category, count the pairs whose two
# members both occur in the first index level of `allconf`, and the distinct
# relations involved in those pairs.
# The relation set is built once instead of once per pair member.
known_relations = set(allconf.index.levels[0])
for name, pairs in order:
    pairs80 = set(p for p in pairs if all(x in known_relations for x in p))
    ps80 = set(x for p in pairs80 for x in p)
    print(f"{name:20s}: {len(ps80):4d} unique relations; {len(pairs80):6d} pairs")

inverse             :   13 unique relations;     14 pairs
subProp             :   34 unique relations;     44 pairs
superProp           :   34 unique relations;     44 pairs
sibling             :   43 unique relations;    529 pairs
seeAlso             :   48 unique relations;     77 pairs
rangeDomainMatch    :   58 unique relations;    178 pairs
onlyRangeMatch      :   53 unique relations;    462 pairs
onlyDomainMatch     :   58 unique relations;   1074 pairs


In [11]:
# For illustrative purposes: relations that share an ancestor in the property hierarchy
def sibling_track(pairs):
    """Map each related (a, b) to the second elements they have in common.

    `pairs` is an iterable of (a, b) tuples. For every two first elements
    whose sets of second elements overlap (including an element with itself),
    the result holds that non-empty overlap.
    """
    targets = {}
    for source, target in pairs:
        targets.setdefault(source, set()).add(target)

    shared = {}
    for left in targets:
        for right in targets:
            common = targets[left] & targets[right]
            if common:
                shared[(left, right)] = common
    return shared

# Group the benchmark relations by the super-properties they share.
# The relation inventory is read from `allconf` (as in the category
# statistics) rather than from `conf`, which is only a leftover loop variable
# of another cell, and it is built once instead of inside the innermost loop.
known_relations = set(allconf.index.levels[0])
super_siblings = {}
# Loop names are chosen so that the global `sup` (the super-property pairs)
# is not overwritten.
for pair, shared_supers in sorted(sibling_track(sub).items()):
    for super_prop in shared_supers:
        for x in pair:
            if x in known_relations:
                super_siblings.setdefault(super_prop, set()).add(x)
# Only super-properties that group at least two benchmark relations are shown.
for super_prop, sibs in super_siblings.items():
    if len(sibs) > 1:
        print('%30s'%super_prop, sibs)

                      has part {'heritage designation', 'contains administrative territorial entity'}
                   instance of {'heritage designation', 'taxon rank'}
                       part of {'military branch', 'member of political party', 'located on terrain feature', 'member of', 'country of citizenship', 'mountain range', 'country of origin', 'country', 'constellation', 'sports season of league or competition', 'participant of', 'located in the administrative territorial entity', 'league'}
            significant person {'performer', 'after a work by', 'winner', 'head of government', 'sibling', 'father', 'child', 'participant', 'participating team', 'successful candidate', 'spouse', 'director', 'mother', 'screenwriter'}
                   affiliation {'member of political party', 'member of'}
                      location {'residence', 'work location', 'headquarters location', 'applies to jurisdiction', 'located on terrain feature', 'location of formation', 'country of 

In [12]:
# Attach the category list of every (true, pred) pair to the confusion counts,
# then overwrite the category for the two special cases.
analysis = allconf.join(ont)
true_rel = analysis.index.get_level_values('true')
pred_rel = analysis.index.get_level_values('pred')
# Predicted 'NA' ...
analysis.loc[pred_rel == 'NA', 'category'] = 'NA'
# ... and the diagonal; assigned second, so it wins where both apply.
eye = true_rel == pred_rel
analysis.loc[eye, 'category'] = 'CORRECT'
# Anything still missing a category is marked 'UNK'.
analysis = analysis.fillna('UNK')

In [13]:
# Most confused (true, pred) pair in each category for one model.
model = 'ptr'
# Select this model's rows (third index level), emit one row per category of
# each pair, and label rows left without a category as 'UNK'.
conf = analysis.loc[:,:, model].explode('category').fillna('UNK').reset_index()
# Drop the diagonal (correct predictions).
conf = conf[conf['true'] != conf['pred']].set_index(['true', 'pred'])
(conf
     # Sorted by count first, so head(1) inside each group is its maximum.
     .sort_values('confused', ascending=False)
     # A pair with several categories appears once under each of them.
     .groupby('category').apply(lambda x: x.head(1)).droplevel(0).reset_index().set_index('category')
     .style.set_caption(f'Most confused relation pair per category, {model} model')
)

Unnamed: 0_level_0,true,pred,confused
category,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
,residence,,1158
UNK,,participant of,1562
inverse,followed by,follows,55
onlyDomainMatch,location of formation,headquarters location,247
onlyRangeMatch,screenwriter,director,117
rangeDomainMatch,residence,work location,515
seeAlso,residence,work location,515
sibling,residence,work location,515
subProp,league,participant of,80
superProp,language of work or name,original language of film or TV show,92


In [14]:
# Top-n most confused (true, pred) pairs in each category for one model.
# NOTE(review): `n` is a notebook-global name that other cells also assign.
n = 3
model = 'ptr'
# Select this model's rows (third index level), emit one row per category of
# each pair, and label rows left without a category as 'UNK'.
conf = analysis.loc[:,:, model].explode('category').fillna('UNK').reset_index()
# Drop the diagonal (correct predictions).
conf = conf[conf['true'] != conf['pred']].set_index(['true', 'pred'])
(conf
     # Sorted by count first, so head(n) inside each group are its n largest.
     .sort_values('confused', ascending=False)
     .groupby('category').apply(lambda x: x.head(n)).droplevel(0).reset_index().set_index(['category', 'true'])
     .style.set_caption(f'Top-{n} confused relation pair per category, {model} model') 
)

Unnamed: 0_level_0,Unnamed: 1_level_0,pred,confused
category,true,Unnamed: 2_level_1,Unnamed: 3_level_1
,residence,,1158
,said to be the same as,,846
,after a work by,,711
UNK,,participant of,1562
UNK,,screenwriter,970
UNK,,operator,940
inverse,followed by,follows,55
inverse,follows,followed by,37
inverse,child,mother,32
onlyDomainMatch,location of formation,headquarters location,247


In [15]:
# The 50 most frequent wrong (true, pred) pairs of one model, one row per
# category of each pair.
model = 'ptr'
conf = analysis.loc[:,:, model].explode('category').fillna('UNK').reset_index()
is_error = conf['true'] != conf['pred']
conf = conf[is_error].set_index(['true', 'pred'])
top_errors = conf.sort_values('confused', ascending=False).head(50)
top_errors.style.set_caption(f'Most confused relation pairs, {model} model')

Unnamed: 0_level_0,Unnamed: 1_level_0,confused,category
true,pred,Unnamed: 2_level_1,Unnamed: 3_level_1
,participant of,1562,UNK
residence,,1158,
,screenwriter,970,UNK
,operator,940,UNK
said to be the same as,,846,
after a work by,,711,
head of government,,690,
,main subject,623,UNK
characters,,612,
applies to jurisdiction,,602,


In [16]:
# Top-5 categorised confusions of PTR and BERT-ENT as a table and as LaTeX.
conf = analysis.fillna('UNK').rename(index=nice_names).reset_index()
# Wrong predictions only ...
conf = conf[conf['true'] != conf['pred']]
# ... that have a non-empty category and were not predicted as 'NA'.
conf = conf[(conf['category'].astype(bool)) & (conf['category'] != 'NA')]
conf = pd.concat([(
    conf[conf['model'] == model].sort_values('confused', ascending=False).head(5)
) for model in ['PTR', 'BERT-ENT']])[['model','true','pred','confused', 'category']]
# Join category lists for display. A plain string category (e.g. 'UNK') is
# kept as is; joining it would split it into single characters.
conf['category'] = conf['category'].map(
    lambda x: x if isinstance(x, str) else ' / '.join(x))
display( conf )

print( conf.to_latex(index=False, escape=False) )

Unnamed: 0,model,true,pred,confused,category
4400,PTR,residence,work location,515,sibling / seeAlso / rangeDomainMatch
388,PTR,after a work by,screenwriter,367,sibling / rangeDomainMatch
2497,PTR,location of formation,headquarters location,247,sibling / seeAlso / onlyDomainMatch
4154,PTR,publisher,developer,178,rangeDomainMatch
4889,PTR,tributary,mouth of the watercourse,158,seeAlso / rangeDomainMatch
4399,BERT-ENT,residence,work location,767,sibling / seeAlso / rangeDomainMatch
387,BERT-ENT,after a work by,screenwriter,395,sibling / rangeDomainMatch
1338,BERT-ENT,followed by,follows,350,inverse
4153,BERT-ENT,publisher,developer,228,rangeDomainMatch
1793,BERT-ENT,headquarters location,location of formation,224,sibling / seeAlso / onlyDomainMatch


\begin{tabular}{lllrl}
\toprule
   model &                  true &                     pred &  confused &                             category \\
\midrule
     PTR &             residence &            work location &       515 & sibling / seeAlso / rangeDomainMatch \\
     PTR &       after a work by &             screenwriter &       367 &           sibling / rangeDomainMatch \\
     PTR & location of formation &    headquarters location &       247 &  sibling / seeAlso / onlyDomainMatch \\
     PTR &             publisher &                developer &       178 &                     rangeDomainMatch \\
     PTR &             tributary & mouth of the watercourse &       158 &           seeAlso / rangeDomainMatch \\
BERT-ENT &             residence &            work location &       767 & sibling / seeAlso / rangeDomainMatch \\
BERT-ENT &       after a work by &             screenwriter &       395 &           sibling / rangeDomainMatch \\
BERT-ENT &           followed by &             

In [17]:
def _category_rank(column):
    """Position of `column` in `categories`; -1 for any other column."""
    return categories.index(column) if column in categories else -1


# Number of test items per (model, true relation) and category; a pair with
# several categories is counted under each of them.
per_gold = (
    analysis
    .explode('category')
    .groupby(['model','true', 'category'])['confused']
    .sum()
    .unstack()
    .fillna(0)
    .astype('int')
)
# Columns outside `categories` first, then the categories in their own order.
per_gold = per_gold[sorted(per_gold.columns, key=_category_rank)]
per_gold

Unnamed: 0_level_0,category,CORRECT,NA,inverse,subProp,superProp,sibling,seeAlso,rangeDomainMatch,onlyRangeMatch,onlyDomainMatch
model,true,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
bert-masking-cls,,26564,0,0,0,0,0,0,0,0,0
bert-masking-cls,after a work by,0,1496,0,0,0,20,0,18,7,4
bert-masking-cls,applies to jurisdiction,1932,2913,0,0,0,10,0,51,8,177
bert-masking-cls,architect,232,854,0,0,0,0,0,5,1,28
bert-masking-cls,characters,458,2330,0,0,0,0,0,0,13,10
...,...,...,...,...,...,...,...,...,...,...,...
ptr,taxon rank,1901,5,0,0,0,0,0,0,0,0
ptr,tributary,2373,92,0,0,0,0,158,158,0,0
ptr,voice type,699,0,0,0,0,0,0,0,0,1
ptr,winner,927,57,0,24,3,56,27,0,7,122


In [18]:
# Extend the benchmark table with the share of predictions that are wrong
# although a relation other than 'NA' was predicted.
# The number of test items is taken from the loaded data instead of a
# hard-coded constant.
n_test = len(datasets['wiki20m_test']['gold'])
caterrors = (per_gold.groupby(level='model').sum() / n_test ).rename(index=nice_names)
# 1 - (share correct + share predicted as 'NA').
score_df2 = score_styled.data.copy().join((1-caterrors[['CORRECT', 'NA']].sum(axis=1).rename('Non-NA Errors')))
# Move accuracy next to the new error column under the name 'Acc.'.
score_df2.insert(len(score_df2.columns)-1, 'Acc.', score_df2['acc'])
score_df2 = score_df2.drop(columns=['acc'])
display( score_df2 )
# NOTE(review): the column maximum is underlined in every column, including
# 'Non-NA Errors' -- confirm whether the minimum should be highlighted there.
print(
    score_df2
    .apply(lambda x: x.map(lambda x: '%.2f'%x).mask(x==x.max(), '\\underline{%.2f}'%x.max()) )
    .reset_index().to_latex(index=False, escape=False)
)

Unnamed: 0_level_0,macro_p,macro_r,macro_f1,micro_p,micro_r,micro_f1,Acc.,Non-NA Errors
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
BERT-M-CLS,0.62,0.3,0.35,0.61,0.31,0.41,0.43,0.15
BERT-M-ENT,0.69,0.32,0.37,0.6,0.32,0.41,0.41,0.16
BERT-CLS,0.79,0.68,0.71,0.79,0.68,0.73,0.71,0.14
BERT-ENT,0.84,0.78,0.79,0.84,0.78,0.81,0.79,0.11
PTR,0.83,0.82,0.81,0.83,0.81,0.82,0.79,0.12


\begin{tabular}{lllllllll}
\toprule
     model &          macro_p &          macro_r &         macro_f1 &          micro_p &          micro_r &         micro_f1 &             Acc. &    Non-NA Errors \\
\midrule
BERT-M-CLS &             0.62 &             0.30 &             0.35 &             0.61 &             0.31 &             0.41 &             0.43 &             0.15 \\
BERT-M-ENT &             0.69 &             0.32 &             0.37 &             0.60 &             0.32 &             0.41 &             0.41 & \underline{0.16} \\
  BERT-CLS &             0.79 &             0.68 &             0.71 &             0.79 &             0.68 &             0.73 &             0.71 &             0.14 \\
  BERT-ENT & \underline{0.84} &             0.78 &             0.79 & \underline{0.84} &             0.78 &             0.81 &             0.79 &             0.11 \\
       PTR &             0.83 & \underline{0.82} & \underline{0.81} &             0.83 & \underline{0.81} & \underline{0.82} 

## Figures

In [19]:
# Horizontal bar chart of model errors per relation-pair category, saved as PGF.
# The backend, rcParams and style are configured BEFORE the figure is created:
# switching the backend afterwards discarded the current figure, so the
# pyplot-level tight_layout()/savefig() calls acted on a new, empty figure.
import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
})
import matplotlib.pyplot as plt
# NOTE(review): newer matplotlib releases renamed this style to
# 'seaborn-v0_8' -- confirm against the installed version.
plt.style.use('seaborn')

# Number of test items, taken from the loaded data instead of a hard-coded constant.
n_test = len(datasets['wiki20m_test']['gold'])
ax = (
    (per_gold.groupby(level='model').sum() / n_test * 100 ).rename(index=nice_names)
    .drop(columns=['CORRECT', 'NA'])
    .rename_axis('Model')[::-1].T[::-1]
    .plot.barh(figsize=(9,6))
)
ax.set(ylabel='Relation-pair category', xlabel='Model errors per category (% of predictions)')
# Reverse the legend entries relative to the plotting order.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], title='Model', frameon=True, facecolor='white', fancybox=True)

# Lay out and save the figure that actually holds the axes.
fig = ax.get_figure()
fig.tight_layout()
fig.savefig('semantic-barchart.pgf')

<Figure size 576x396 with 0 Axes>

In [23]:
# Most frequent errors of one model that fall into no category, excluding
# items whose true relation is 'NA'.
model = 'ptr'
flat = analysis.explode('category').reset_index()
# After explode(), pairs without any category have a missing value.
has_no_category = flat['category'].isna()
is_selected_model = flat['model'] == model
ptr_unk = flat[has_no_category & is_selected_model].sort_values('confused')[::-1]
(
    ptr_unk[ptr_unk['true'] != 'NA']
    .head(50)
    .style
    .set_caption('Unknown errors, PTR class')
)

Unnamed: 0,true,pred,model,confused,category
5203,said to be the same as,platform,ptr,188,
5170,said to be the same as,mother,ptr,43,
5198,said to be the same as,place served by transport hub,ptr,33,
4380,participant,country,ptr,23,
1572,field of work,sport,ptr,17,
2333,language of work or name,country of origin,ptr,15,
5185,said to be the same as,operating system,ptr,14,
5138,said to be the same as,has part,ptr,13,
1352,distributor,developer,ptr,12,
2926,location,participant,ptr,12,
