In [1]:
from scripts.convert_brat_annotations_to_json import load_jsonl, used_entities
def span_match(span_1, span_2) :
    sa, ea = span_1
    sb, eb = span_2
    iou = (min(ea, eb) - max(sa, sb)) / (max(eb, ea) - min(sa, sb))
    return iou

In [2]:
ents = used_entities + ['None']

In [3]:
annotated_data = load_jsonl('../model_data/all_data_propagated.jsonl')[:440]
original_data = load_jsonl('../model_data/all_data_original_propagated.jsonl')[:440]

In [4]:
from collections import defaultdict
for d in original_data :
    span_to_cluster_ids = defaultdict(list)
    for cluster_name in d['coref']:
        for span in d['coref'][cluster_name]:
            span_to_cluster_ids[tuple(span)].append(cluster_name)

    span_to_cluster_ids = {span: set(sorted(v)) for span, v in span_to_cluster_ids.items()}
    d['span_to_cluster'] = span_to_cluster_ids
    
for d in annotated_data :
    span_to_cluster_ids = defaultdict(list)
    for cluster_name in d['coref']:
        for span in d['coref'][cluster_name]:
            span_to_cluster_ids[tuple(span)].append(cluster_name)

    span_to_cluster_ids = {span: set(sorted(v)) for span, v in span_to_cluster_ids.items()}
    d['span_to_cluster'] = span_to_cluster_ids

In [10]:
import numpy as np

def matching(spans_1, spans_2) :
    match = np.zeros((len(spans_1), len(spans_2)))
    match_types = {k + '_1':{l + '_2':0 for l in ents} for k in ents}
    for i, s1 in enumerate(spans_1) :
        for j, s2 in enumerate(spans_2) :
            if (s1[0], s1[1]) == (s2[0], s2[1]) :
                match[i, j] = 1
                
    assert all(x <= 1 for x in match.sum(0)), breakpoint()
    assert all(x <= 1 for x in match.sum(1)), breakpoint()
    for i, s1 in enumerate(spans_1) :
        matched = [j for j, x in enumerate(match[i]) if x == 1]
        assert len(matched) <= 1
        if len(matched) == 0 :
            match_types[s1[2] + '_1']['None_2'] += 1 / len(spans_1)
        else :
            match_types[s1[2] + '_1'][spans_2[matched[0]][2] + '_2'] += 1 
            
            
    for i, s2 in enumerate(spans_2) :
        matched = [j for j, x in enumerate(match[:, i]) if x == 1]
        assert len(matched) <= 1
        if len(matched) == 0 :
            match_types['None_1'][s2[2] + '_2'] += 1 
            
    s = sum(y for k, v in match_types.items() for x, y in v.items())
    for k, v in match_types.items() :
        for x, y in v.items() :
            match_types[k][x] /= s
    return match_types

In [11]:
from tqdm import tqdm
match_types = {k + '_1':{l + '_2':0 for l in ents} for k in ents}
for i in tqdm(range(len(original_data))) :
    match_types_doc = matching(original_data[i]['ner'], annotated_data[i]['ner'])
    for k in match_types :
        for l in match_types[k] :
            match_types[k][l] += match_types_doc[k][l]

100%|██████████| 440/440 [00:44<00:00,  9.92it/s]


In [12]:
import pandas as pd
match_types = pd.DataFrame(match_types).loc[[e + '_2' for e in ents], [e + '_1' for e in ents]].T / len(annotated_data)

In [15]:
match_types.sum()

Material_2    0.069475
Metric_2      0.095343
Task_2        0.210575
Method_2      0.624067
None_2        0.000540
dtype: float64

In [14]:
print(match_types.rename(index=lambda x : x.replace('_1', ''), columns=lambda x : x.replace('_2', '')).to_latex(float_format=lambda x : "{:0.2f}".format(x*100)))

\begin{tabular}{lrrrrr}
\toprule
{} &  Material &  Metric &  Task &  Method &  None \\
\midrule
Material &      3.55 &    0.01 &  0.07 &    0.16 &  0.03 \\
Metric   &      0.02 &    7.95 &  0.00 &    0.03 &  0.00 \\
Task     &      0.32 &    0.07 & 17.92 &    0.44 &  0.01 \\
Method   &      0.65 &    0.21 &  0.24 &   53.27 &  0.02 \\
None     &      2.40 &    1.30 &  2.82 &    8.50 &  0.00 \\
\bottomrule
\end{tabular}



In [23]:
from tqdm import tqdm
def matching_links(spans_1, spans_2, links_1, links_2) :
    match = np.zeros((len(spans_1), len(spans_2)))
    match_types = {k + '_2':{'Exist' : {'+' : 0, '-' : 0}, 'New' : {'+' : 0, '-' : 0}} for k in ents}
    for i, s1 in enumerate(spans_1) :
        for j, s2 in enumerate(spans_2) :
            if span_match((s1[0], s1[1]), (s2[0], s2[1])) > 0.5 :
                match[i, j] = 1
                
    assert all(x <= 1 for x in match.sum(0)), breakpoint()
    assert all(x <= 1 for x in match.sum(1)), breakpoint()
            
            
    for i, s2 in enumerate(spans_2) :
        matched = [j for j, x in enumerate(match[:, i]) if x == 1]
        assert len(matched) <= 1
        is_matched = True if len(matched) > 0 and spans_1[matched[0]][2] == s2[2] else False
        if not is_matched :
            t = 'New'
            new_links = links_2.get(tuple(s2[:2]), set())
            match_types[s2[2] + '_2'][t]['+'] += 1 if len(new_links) > 0 else 0
        else :
            t = 'Exist'
            matched_span = spans_1[matched[0]]
            new_links, old_links = links_2.get(tuple(s2[:2]), set()), links_1.get(tuple(matched_span[:2]), set())
            match_types[s2[2] + '_2'][t]['+'] += 1 if len(new_links - old_links) > 0 else 0
            match_types[s2[2] + '_2'][t]['-'] += 1 if len(old_links - new_links) > 0 else 0
            
            
    return match_types

In [24]:
match_types = {k + '_2':{(t, q) : 0 for t in ['Exist', 'New'] for q in ['+', '-']} for k in ents}

for i in tqdm(range(len(annotated_data))) :
    m = matching_links(original_data[i]['ner'], annotated_data[i]['ner'], original_data[i]['span_to_cluster'], annotated_data[i]['span_to_cluster'])
    for k in m :
        m[k] = {(s, x):n for s, v in m[k].items() for x, n in v.items()}
        for (s, x), n in m[k].items() :
            match_types[k][(s, x)] += n

100%|██████████| 440/440 [01:07<00:00,  7.27it/s]


In [32]:
pd.DataFrame(match_types).sum().sum()*100 / (440*360)

12.448232323232324

In [25]:
print((pd.DataFrame(match_types).T / len(annotated_data)).rename(index=lambda x : x.replace('_2', '')).to_latex(float_format="{:0.1f}".format))

\begin{tabular}{lrrrr}
\toprule
{} & \multicolumn{2}{l}{Exist} & \multicolumn{2}{l}{New} \\
{} &     + &    - &   + &   - \\
\midrule
Material &   1.2 &  0.8 & 6.8 & 0.0 \\
Metric   &   0.4 &  1.2 & 2.8 & 0.0 \\
Task     &   0.9 &  1.6 & 5.4 & 0.0 \\
Method   &   1.7 & 14.0 & 8.0 & 0.0 \\
None     &   0.0 &  0.0 & 0.0 & 0.0 \\
\bottomrule
\end{tabular}



In [5]:
from dygie.data.dataset_readers.read_pwc_dataset import is_x_in_y
does_overlap = lambda x, y: max(x[0], y[0]) < min(x[1], y[1])

def map_to_section_and_sentence(data) :
    mentions = data['ner']
    sections = data['sections']
    sentences = data['sentences']
    
    for e in mentions:
        in_sentences = [i for i, s in enumerate(data["sentences"]) if is_x_in_y(e, s)]
        if len(in_sentences) > 1:
            breakpoint()
        if len(in_sentences) == 0:
            in_sentences = [i for i, s in enumerate(data["sentences"]) if does_overlap(e, s)]
            assert sorted(in_sentences) == list(range(min(in_sentences), max(in_sentences) + 1)), breakpoint()
            # breakpoint()
            in_sentences = sorted(in_sentences)
            data["sentences"][in_sentences[0]][1] = data["sentences"][in_sentences[-1]][1]
            data["sentences"] = [s for i, s in enumerate(data["sentences"]) if i not in in_sentences[1:]]
    mention_to_section = {}
    mention_to_sentence = {}
    for s, e, _ in mentions :
        done = False
        for s_1, e_1 in sections :
            if is_x_in_y((s, e), (s_1, e_1)):
                mention_to_section[(s, e)] = (s_1, e_1)
                done = True
        assert done
        
    for s, e, _ in mentions :
        done = False
        for s_1, e_1 in sentences :
            if is_x_in_y((s, e), (s_1, e_1)):
                mention_to_sentence[(s, e)] = (s_1, e_1)
                done = True
        assert done
        
    data['mention_to_section'] = mention_to_section
    data['mention_to_sentence'] = mention_to_sentence
    
    return data



In [6]:
from tqdm import tqdm
for d in tqdm(annotated_data) :
    map_to_section_and_sentence(d)

100%|██████████| 440/440 [00:15<00:00, 27.74it/s]


In [17]:
from scripts.entity_utils import used_entities
from itertools import combinations
def stats_for_relation(data) :
    n_ary_relations = data['n_ary_relations']
    corefs = data['coref']
    stats = {}
    for card in range(2, 5) :
        for elist in combinations(used_entities, card) :
            stats[tuple(elist)] = {'In section':0, 'Out section':0, 'In sentence':0, 'Out sentence':0}
            for rel in n_ary_relations :
                rel = [rel[k] for k in elist]
                clusters = [corefs[x] for x in rel]
#                 if any(len(c) == 0 for c in clusters) :
#                     continue
                sections = [set([data['mention_to_section'][tuple(span)] for span in cluster]) 
                            for cluster in clusters]
                sections = set.intersection(*sections)
                if len(sections) > 0 :
                    stats[tuple(elist)]['In section'] += 1
                else :
                    stats[tuple(elist)]['Out section'] += 1
#                     if card == 4 :
#                         print(data['doc_id'], rel)
                    
                sections = [set([data['mention_to_sentence'][tuple(span)] for span in cluster]) 
                            for cluster in clusters]
                sections = set.intersection(*sections)
                if len(sections) > 0 :
                    stats[tuple(elist)]['In sentence'] += 1
                else :
                    stats[tuple(elist)]['Out sentence'] += 1
                    
    return stats

In [18]:
from tqdm import tqdm
stats = stats_for_relation(annotated_data[0])
for d in tqdm(annotated_data[1:]) :
    stats_d = stats_for_relation(d)
    for key in stats_d :
        for v, x in stats_d[key].items() :
            stats[key][v] += x

100%|██████████| 439/439 [00:00<00:00, 1177.72it/s]


In [19]:
import pandas as pd
stats = {"/".join(k):v for k, v in stats.items()}

In [20]:
metrics = pd.DataFrame(stats).T / 440

In [21]:
import numpy as np
metrics['n_ary'] = pd.Series(metrics.index).apply(lambda x : len(x.split('/'))).values
print(metrics.groupby('n_ary').agg(np.sum).to_latex(float_format=lambda x: "{:0.3f}".format(x)))

\begin{tabular}{lrrrr}
\toprule
{} &  In section &  In sentence &  Out section &  Out sentence \\
n\_ary &             &              &              &               \\
\midrule
2     &      12.380 &        6.630 &       16.980 &        22.730 \\
3     &       4.955 &        0.591 &       14.618 &        18.982 \\
4     &       0.707 &        0.011 &        4.186 &         4.882 \\
\bottomrule
\end{tabular}



In [22]:
1.57 / (1.57 + 0.011)

0.9930423782416193

In [16]:
12.38+16.98

29.36

In [26]:
from scripts.entity_utils import used_entities
from itertools import combinations
def stats_for_relation(data) :
    n_ary_relations = data['n_ary_relations']
    corefs = data['coref']
    stats = {}
    for card in range(2, 5) :
        rels = []
        for elist in combinations(used_entities, card) :
            for rel in n_ary_relations :
                rel = [rel[k] for k in elist]
                rels.append(tuple(rel))
                
        rels = set(rels)
        stats[card] = len(rels)
                    
    return stats

In [28]:
from tqdm import tqdm
stats = stats_for_relation(annotated_data[0])
for d in tqdm(annotated_data[1:]) :
    stats_d = stats_for_relation(d)
    for key, x in stats_d.items() :
        stats[key] += x

100%|██████████| 439/439 [00:00<00:00, 13902.37it/s]


In [30]:
7072/440

16.072727272727274