In [13]:
def get_entities(data, preds):
    data_entities = []

    for cur_d in data:
        cur_dict = {}
        for ent in cur_d['entities']:
            cur_dict[(ent['start'], ent['end'])] = ent['type']
        data_entities.append(cur_dict)

    pred_entities = []

    for cur_p in preds:
        cur_dict = {}
        for ent in cur_p['entities']:
            cur_dict[(ent['start'], ent['end'])] = ent['type']
        pred_entities.append(cur_dict)
    return data_entities, pred_entities


def analyze_entities(data_entities, pred_entities):
    # cb: correct boundary
    # wb: wrong boundary
    # ct: correct type
    # wt: wrong type

    correct = []
    missing = []
    extra = []
    cb_wt = []
    wb_ct = []
    wb_wt = []

    for doc_id, (data_ent, pred_ent) in enumerate(zip(data_entities, pred_entities)):

        cbs = set(data_ent.keys()) & set(pred_ent.keys())
        for cb in cbs:
            if data_ent[cb] == pred_ent[cb]:
                correct.append([doc_id, cb, cb, data_ent[cb], pred_ent[cb]])
            else:
                cb_wt.append({'doc_id': doc_id, 'data_cb': cb, 'true_type': data_ent[cb], 'pred_type': pred_ent[cb]})

        data_wbs = set(data_ent.keys()) - cbs
        pred_wbs = set(pred_ent.keys()) - cbs

        used_cb = set()

        for b_s, b_e in pred_wbs:
            flag = True
            for d_s, d_e in data_wbs:
                if min([d_e, b_e]) > max([d_s, b_s]):
                    if pred_ent[(b_s, b_e)] == data_ent[(d_s, d_e)]:
                        wb_ct.append({'doc_id': doc_id, 'pre_cb': (b_s, b_e), 'data_cb': (d_s, d_e), 'true_type': data_ent[(d_s, d_e)]})
                    else:
                        wb_wt.append({'doc_id': doc_id, 'pre_cb': (b_s, b_e), 'data_cb': (d_s, d_e), 'true_type': data_ent[(d_s, d_e)], 'pred_type': pred_ent[(b_s, b_e)]})
                    flag = False
                    used_cb.add((d_s, d_e))
                    break
            if flag:
                extra.append({'doc_id': doc_id, 'pred_cb': (b_s, b_e), 'pred_type': pred_ent[(b_s, b_e)]})
        
        for cb in data_wbs - used_cb:
            missing.append({'doc_id': doc_id, 'data_cb': cb, 'true_type': data_ent[cb]})
        
    return {'correct': correct, 'missing': missing, 'extra': extra, 'cb_wt': cb_wt, 'wb_ct': wb_ct, 'wb_wt': wb_wt}


In [15]:
import json

# data = json.load(open('../data/datasets/scierc/scierc_test.json'))
# preds = json.load(open('../data/log/scierc_train/baseline/predictions_valid_epoch_20.json'))

# data = json.load(open('../data/datasets/scierc/scierc_test_new_doc.json'))
# preds = json.load(open('../data/log/scierc_train/doc_baseline/predictions_valid_epoch_27.json'))
# preds = json.load(open('../data/log/scierc_train/2022-06-24_22:29:36.379915/predictions_valid_epoch_30.json'))


# data = json.load(open('../data/datasets/conll04/conll04_test.json'))
# preds = json.load(open('../data/log/conll04_train/baseline/predictions_valid_epoch_20.json'))

data = json.load(open('../data/datasets/ontonotes/doc_test.json'))
preds = json.load(open('../data/log/ontonotes_train/2022-07-04_14:36:41.174231/predictions_valid_epoch_30.json'))

data_entities, pred_entities = get_entities(data, preds)
results = analyze_entities(data_entities, pred_entities)

print('correct', len(results['correct']))
print('missing', len(results['missing']))
print('extra', len(results['extra']))
print('cb_wt', len(results['cb_wt']))
print('wb_ct', len(results['wb_ct']))
print('wb_wt', len(results['wb_wt']))

tp = len(results['correct'])
fn = len(results['missing']) + len(results['cb_wt']) + len(results['wb_ct']) + len(results['wb_wt'])
fp = len(results['extra']) + len(results['cb_wt']) + len(results['wb_ct']) + len(results['wb_wt'])

# tp = len(correct) + len(cb_wt)
# fn = len(missing) + len(wb_ct) + len(wb_wt)
# fp = len(extra) + len(wb_ct) + len(wb_wt)

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * tp / (2 * tp + fp + fn)

print('results -------->', tp, fn, fp, precision, recall, f1)

json.dump(results['extra'], open('analysisResults/extra.json', 'w'))
json.dump(results['missing'], open('analysisResults/missing.json', 'w'))
json.dump(results['wb_ct'], open('analysisResults/wb_ct.json', 'w'))
json.dump(results['wb_wt'], open('analysisResults/wb_wt.json', 'w'))
json.dump(results['cb_wt'], open('analysisResults/cb_wt.json', 'w'))

correct 10276
missing 350
extra 473
cb_wt 228
wb_ct 327
wb_wt 131
results --------> 10276 1036 1159 0.8986445124617403 0.9084158415841584 0.903503758737416


In [16]:
missing_stat = {}
for rec in results['missing']:
    t_type = rec['true_type']
    missing_stat[t_type] = missing_stat.get(t_type, 0) + 1

print('missing stat -------->', missing_stat)

extra_stat = {}
for rec in results['extra']:
    t_type = rec['pred_type']
    extra_stat[t_type] = extra_stat.get(t_type, 0) + 1

print('extra stat -------->', extra_stat)

wb_ct_stat = {}
for rec in results['wb_ct']:
    t_type = rec['true_type']
    wb_ct_stat[t_type] = wb_ct_stat.get(t_type, 0) + 1

print('wrong boundary correct type stat -------->', wb_ct_stat)

cb_wt_stat = {}
for rec in results['cb_wt']:
    t_type = tuple(sorted([rec['true_type'], rec['pred_type']]))
    cb_wt_stat[t_type] = cb_wt_stat.get(t_type, 0) + 1

print('Correct boundary wrong type stat -------->', cb_wt_stat)

wb_wt_stat = {}
for rec in results['wb_wt']:
    t_type = tuple(sorted([rec['true_type'], rec['pred_type']]))
    wb_wt_stat[t_type] = wb_wt_stat.get(t_type, 0) + 1

print('Wrong boundary wrong type stat -------->', wb_wt_stat)

missing stat --------> {'ORG': 52, 'DATE': 72, 'TIME': 32, 'CARDINAL': 69, 'GPE': 10, 'QUANTITY': 3, 'PERSON': 44, 'ORDINAL': 10, 'NORP': 8, 'LOC': 7, 'WORK_OF_ART': 21, 'MONEY': 3, 'PRODUCT': 8, 'FAC': 3, 'LAW': 3, 'PERCENT': 2, 'LANGUAGE': 1, 'EVENT': 2}
extra stat --------> {'TIME': 32, 'CARDINAL': 96, 'QUANTITY': 12, 'DATE': 123, 'GPE': 12, 'ORDINAL': 28, 'PERSON': 33, 'ORG': 56, 'NORP': 16, 'LOC': 6, 'WORK_OF_ART': 14, 'PRODUCT': 8, 'FAC': 1, 'MONEY': 15, 'LAW': 4, 'PERCENT': 13, 'EVENT': 4}
wrong boundary correct type stat --------> {'TIME': 25, 'EVENT': 6, 'GPE': 18, 'FAC': 5, 'DATE': 86, 'ORG': 45, 'CARDINAL': 40, 'QUANTITY': 11, 'PERSON': 20, 'NORP': 3, 'LAW': 3, 'PERCENT': 27, 'MONEY': 23, 'PRODUCT': 3, 'WORK_OF_ART': 4, 'LOC': 8}
Correct boundary wrong type stat --------> {('CARDINAL', 'PERSON'): 1, ('CARDINAL', 'DATE'): 6, ('GPE', 'ORG'): 21, ('CARDINAL', 'TIME'): 1, ('GPE', 'LOC'): 17, ('ORG', 'WORK_OF_ART'): 32, ('DATE', 'TIME'): 7, ('LOC', 'ORG'): 6, ('EVENT', 'WORK_OF_A

In [12]:
import json
from sklearn.metrics import precision_recall_fscore_support
from sklearn import metrics

# probs = json.load(open('../data/log/scierc_train/2022-06-24_23:21:02.651449/predictions_valid_epoch_30.json_b_probs'))
probs = json.load(open('../data/log/ontonotes_train/2022-07-04_14:36:41.174231/predictions_valid_epoch_29.json_b_probs'))

all_probs = []
for prob in probs:
    all_probs.extend(prob)

probs = [p[0] for p in all_probs]
preds = [int(prob[0] > 0.5) for prob in all_probs]
gts = [0 if prob[1] == 0 else 1 for prob in all_probs]

print(precision_recall_fscore_support(gts, preds, average='binary'))

fpr, tpr, thresholds = metrics.roc_curve(gts, probs)
print(metrics.auc(fpr, tpr))

(0.9140801109762442, 0.9373221906116643, 0.9255552629268721, None)
0.9970328247429978
