In [1]:
import dill as pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open('../data_splits.pickle', 'rb') as f:
    data_splits = pickle.load(f)
# Running total of the five fold sizes, with a leading 0, so that fold k
# occupies positions val_bounds[k]:val_bounds[k + 1] (see evaluate()).
val_bounds = np.cumsum([0, *(len(data_splits[0][250]['normed'][fold_ix]) for fold_ix in range(5))])

In [3]:
def process_sample(raw_ps, index_map, bounds, gt_spans, num_tokens, match_stats):
    """Match one sample's predicted spans to its ground-truth spans.

    Predicted token spans come from ``extract_entities(raw_ps, num_tokens)`` and
    are converted with ``map_span_to_word_indices``. For each category 1..7 the
    predictions and ground-truth spans are paired greedily, best overlap first;
    a pair only qualifies when the overlap covers at least half of *both* spans.

    Args:
        raw_ps: per-token class scores, forwarded to ``extract_entities``.
        index_map, bounds: forwarded to ``map_span_to_word_indices``.
        gt_spans: mapping category -> list of inclusive (start, end) spans.
        num_tokens: number of tokens to decode, forwarded to ``extract_entities``.
        match_stats: mapping category -> {'tp', 'fp', 'fn'} counters; updated
            in place.

    Returns:
        (match_stats, num_entities) where ``num_entities`` is a length-7 array
        holding the number of predicted spans per category (index = category - 1).
    """
    predicted_spans = {}
    for category, token_spans in extract_entities(raw_ps, num_tokens).items():
        predicted_spans[category] = [map_span_to_word_indices(token_span, index_map, bounds)
                                     for token_span in token_spans]

    num_entities = np.zeros(7)
    for cat_ix in range(1, 8):
        preds = predicted_spans.get(cat_ix, [])
        truths = gt_spans.get(cat_ix, [])
        num_entities[cat_ix - 1] = len(preds)
        counts = match_stats[cat_ix]

        # Nothing to pair up: every span on the non-empty side is an error.
        if len(preds) == 0 or len(truths) == 0:
            counts['fn'] += len(truths)
            counts['fp'] += len(preds)
            continue

        # pair_scores[i, j] > 0 only if prediction i and truth j are a valid match.
        pair_scores = np.zeros((len(preds), len(truths)))
        for pred_pos, pred in enumerate(preds):
            pred_len = pred[1] - pred[0] + 1
            for truth_pos, truth in enumerate(truths):
                shared = min(pred[1], truth[1]) - max(pred[0], truth[0]) + 1
                if shared <= 0:
                    continue
                pred_frac = shared / pred_len
                truth_frac = shared / (truth[1] - truth[0] + 1)
                if min(pred_frac, truth_frac) >= .5:
                    pair_scores[pred_pos, truth_pos] = max(pred_frac, truth_frac)

        free_preds = set(range(len(preds)))
        free_truths = set(range(len(truths)))
        flat_scores = pair_scores.ravel()
        num_truths = len(truths)
        # Walk candidate pairs from best to worst score.
        for flat_ix in np.argsort(flat_scores)[::-1]:
            exhausted = not free_truths or not free_preds
            if exhausted or flat_scores[flat_ix] == 0:
                # No valid pair left: whatever is still unpaired is an error.
                counts['fp'] += len(free_preds)
                counts['fn'] += len(free_truths)
                break
            # Flat index -> (row, column) of the (len(preds), len(truths)) matrix.
            pred_pos, truth_pos = divmod(flat_ix, num_truths)
            if pred_pos in free_preds and truth_pos in free_truths:
                counts['tp'] += 1
                free_truths.remove(truth_pos)
                free_preds.remove(pred_pos)
    return match_stats, num_entities

def map_span_to_word_indices(span, index_map, bounds):
    """Translate an inclusive (first, last) span via ``bounds`` and ``index_map``.

    The start comes from column 0 of the first item's bounds row; the end comes
    from column 1 of the last item's row, minus one. Both are then looked up in
    ``index_map``.
    NOTE(review): presumably ``bounds`` holds per-token character offsets and
    ``index_map`` maps a character position to a word index -- confirm with caller.
    """
    first_item, last_item = span[0], span[1]
    start_pos = bounds[first_item, 0]
    last_pos = bounds[last_item, 1] - 1  # column 1 is treated as exclusive
    return (index_map[start_pos], index_map[last_pos])

def split_predstring(x):
    """Return (first, last) integers of a whitespace-separated index string."""
    pieces = x.split()
    first, last = pieces[0], pieces[-1]
    return int(first), int(last)

In [4]:
def calc_entity_score(span, ps, c):
    """Average score of an entity of category ``c`` over an inclusive span.

    The first position contributes its value in column ``2 * c - 1``; every
    later position contributes its value in column ``2 * c``. The sum is
    divided by the span length.
    """
    first, last = span
    opening = ps[first, c * 2 - 1]
    continuation = ps[first + 1: last + 1, c * 2].sum()
    length = last - first + 1
    return (opening + continuation) / length

In [75]:
# Named post-processing threshold sets kept from earlier experiments.
#   'score':  indexed by category id; slot 0 is unused.
#   'length': one minimum length per category 1..7, in order.
_ENTITY_THRESHOLD_PRESETS = {
    'xlnet': {
        'score': [None, -.9, -.547, -.537, -.685, -.53, -.8, -.77],
        'length': (5, 2, 17, 2, 5, 7, 5),
    },
    'debertav1new': {
        'score': [None, -1., -.565, -.545, -.615, -.56, -1.05, -1.1],
        'length': (6, 3, 18, 2, 7, 6, 7),
    },
    'debertav1x2': {
        'score': [None, -1.71, -1.3, -1.15, -1.35, -1.5, -3.5, -3.5],
        'length': (3, 3, 17, 2, 6, 7, 7),
    },
    'debertav1old': {
        'score': [None, -.8, -.5, -.515, -.635, -.535, -.82, -.88],
        'length': (6, 3, 16, 2, 8, 7, 5),
    },
}

# Categories for which only the single best-scoring entity is kept.
_SINGLE_ENTITY_CATEGORIES = (1, 2, 5)


def extract_entities(ps, n, score_thresholds=None, length_thresholds=None):
    """Decode per-position class scores into filtered entity spans.

    The predicted class of each position is ``ps.argmax(-1)``. An odd class id
    opens an entity of category ``(id + 1) // 2``; the even id ``2 * category``
    continues it; class 0 or any other even id closes it. Spans are inclusive.

    Each category's spans are then kept only if they are at least the minimum
    length and their ``calc_entity_score`` exceeds the score threshold. For
    categories 1, 2 and 5 only the best-scoring surviving span is kept.

    Args:
        ps: 2-D array of scores, positions along axis 0, classes along the last axis.
        n: number of leading positions to decode.
        score_thresholds: optional sequence indexed by category id (slot 0
            unused). Defaults to the 'debertav1old' preset.
        length_thresholds: optional sequence of 7 minimum lengths for
            categories 1..7. Defaults to the 'debertav1old' preset.

    Returns:
        dict mapping category -> list of (start, end) spans. Categories that
        never appeared in the decoding are absent.
    """
    if score_thresholds is None:
        score_thresholds = _ENTITY_THRESHOLD_PRESETS['debertav1old']['score']
    if length_thresholds is None:
        length_thresholds = _ENTITY_THRESHOLD_PRESETS['debertav1old']['length']

    cat_ps = ps.argmax(-1)
    all_entities = {}
    current_cat = None
    current_start = None
    for ix in range(n):
        tag = cat_ps[ix]
        opens_entity = tag % 2 == 1
        if opens_entity or tag == 0:
            closes_entity = current_cat is not None
        else:
            # Continuation tag: only closes when it belongs to another category.
            closes_entity = current_cat is not None and tag != current_cat * 2
        if closes_entity:
            all_entities.setdefault(current_cat, []).append((current_start, ix - 1))
            current_cat = None
        if opens_entity:
            current_cat = (tag + 1) // 2
            current_start = ix
    # An entity still open after the last decoded position ends there.
    if current_cat is not None:
        all_entities.setdefault(current_cat, []).append((current_start, n - 1))

    for cat_ix, min_len in zip(range(1, 8), length_thresholds):
        if cat_ix not in all_entities:
            continue
        scored = []
        for span in all_entities[cat_ix]:
            if span[1] - span[0] + 1 < min_len:
                continue
            # Score each surviving span once and reuse it for the best-of selection.
            score = calc_entity_score(span, ps, cat_ix)
            if score > score_thresholds[cat_ix]:
                scored.append((score, span))
        if cat_ix in _SINGLE_ENTITY_CATEGORIES and len(scored) > 1:
            # max() returns the first of equally scored spans.
            scored = [max(scored, key=lambda item: item[0])]
        all_entities[cat_ix] = [span for _, span in scored]
    return all_entities

In [65]:
def make_gt_dict(df):
    """Build {category id: [(first, last), ...]} from an annotation frame.

    ``df`` must expose ``discourse_type`` and ``predictionstring`` columns.
    Category ids 1..7 are resolved to names through the module-level
    ``label_names``; categories with no rows are omitted from the result.
    """
    spans_by_category = {}
    for cat_ix in range(1, 8):
        rows = df.loc[df.discourse_type == label_names[cat_ix]]
        if len(rows) == 0:
            continue
        parsed = rows.predictionstring.map(split_predstring)
        spans_by_category[cat_ix] = [(span[0], span[1]) for span in parsed]
    return spans_by_category

In [66]:
def f1s(stats):
    """Compute F1, precision and recall for categories 1..7.

    Args:
        stats: mapping category -> {'tp', 'fp', 'fn'} counts.

    Returns:
        (f1, prec, rec): ``f1`` has length 8, with the mean of the seven
        per-category scores at index 0 and category ``c`` at index ``c``;
        ``prec`` and ``rec`` have length 7 (index = category - 1).
        A 1e-7 term in every denominator avoids division by zero.
    """
    categories = range(1, 8)
    tp = np.array([stats[c]['tp'] for c in categories], dtype=float)
    fp = np.array([stats[c]['fp'] for c in categories], dtype=float)
    fn = np.array([stats[c]['fn'] for c in categories], dtype=float)

    scores = np.zeros(8)
    scores[1:] = tp / (1e-7 + tp + .5 * (fp + fn))
    scores[0] = np.mean(scores[1:])
    rec = tp / (1e-7 + tp + fn)
    prec = tp / (1e-7 + tp + fp)
    return scores, prec, rec

In [67]:
#data = pd.read_csv('../train.csv')
# Category names; make_gt_dict() looks a name up by its integer category id
# (positions 1..7), so the order of this list matters.
label_names = [
    'None',
    'Lead',
    'Position',
    'Evidence',
    'Claim',
    'Concluding Statement',
    'Counterclaim',
    'Rebuttal',
]


In [68]:
# Ground-truth spans read from disk; evaluate() indexes this object by sample id
# and passes each entry to process_sample() as its gt_spans argument.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open('../gt_dicts.pickle', 'rb') as f:
    gt_dicts = pickle.load(f)

In [None]:
from tqdm import tqdm

In [None]:
def merge_ps(logits_a, logits_b, bounds_a, bounds_b):
    """Sum two logit arrays after aligning them on a shared set of segments.

    ``bounds_a`` and ``bounds_b`` are sequences of (start, end) pairs, one per
    row of the matching logits array. The two sequences are walked in parallel
    and a new segment is cut whenever either sequence's current interval ends;
    each new segment remembers which row of A and which row of B it came from.

    The walk stops, without emitting a further segment, as soon as the current
    intervals end at the same position and either sequence is on its last
    interval.

    Returns:
        (summed_logits, new_bounds): ``logits_a`` and ``logits_b`` gathered per
        new segment and added, plus the array of (start, end) segment bounds.
    """
    last_a = len(bounds_a) - 1
    last_b = len(bounds_b) - 1
    idx_a = idx_b = 0
    rows_a, rows_b, merged_bounds = [], [], []
    start_a, end_a = bounds_a[idx_a]
    start_b, end_b = bounds_b[idx_b]
    seg_start = 0
    while True:
        if end_a == end_b:
            if idx_a == last_a or idx_b == last_b:
                break
            seg_end = end_b
            # Advance whichever sequence starts its next interval first
            # (both when they start together).
            upcoming_a = bounds_a[idx_a + 1][0]
            upcoming_b = bounds_b[idx_b + 1][0]
            step_a = upcoming_a <= upcoming_b
            step_b = upcoming_b <= upcoming_a
        elif end_a < end_b:
            seg_end = end_a
            step_a, step_b = True, False
        else:
            seg_end = end_b
            step_a, step_b = False, True

        # Record the segment against the rows that were current when it was cut.
        merged_bounds.append((seg_start, seg_end))
        rows_a.append(idx_a)
        rows_b.append(idx_b)

        if step_a:
            idx_a += 1
            start_a, end_a = bounds_a[idx_a]
        if step_b:
            idx_b += 1
            start_b, end_b = bounds_b[idx_b]
        # When both advance, the next segment starts at A's new interval start.
        seg_start = start_a if step_a else start_b

    return logits_a[rows_a] + logits_b[rows_b], np.array(merged_bounds)


In [69]:
def evaluate(list_of_dir_names):
    """Print per-fold and overall F1 / precision / recall for saved predictions.

    Only ``list_of_dir_names[0]`` is loaded, from ``oof_ps/<dir>/<name>.npy``;
    further entries appear in the header line but are otherwise ignored
    (blending several directories is not performed here).

    Relies on the module-level ``val_bounds`` (fold boundaries) and
    ``gt_dicts`` (ground truth keyed by sample id).

    Each sample is processed once: the overall counts are the sum of the
    per-fold tp/fp/fn counts, which is what a separate pass over all samples
    would produce.

    Returns:
        Array of shape (len(all_outs), 7) with the number of predicted
        entities per category for every processed sample.
    """
    print(f'evaluating {list_of_dir_names}')
    oof_dir_name = list_of_dir_names[0]
    array_names = ('all_outs', 'all_bounds', 'all_token_nums', 'all_word_indices', 'all_sample_ids')
    all_outs, all_bounds, all_token_nums, all_word_indices, all_sample_ids = (
        np.load(f'oof_ps/{oof_dir_name}/{array_name}.npy', allow_pickle=True)
        for array_name in array_names
    )

    all_num_entities = np.zeros((len(all_outs), 7))
    total_stats = {ix: {'fp': 0, 'fn': 0, 'tp': 0} for ix in range(1, 8)}
    text_ix = 0
    for fold_ix in range(5):
        start_ix = val_bounds[fold_ix]
        end_ix = val_bounds[fold_ix + 1]
        fold_stats = {ix: {'fp': 0, 'fn': 0, 'tp': 0} for ix in range(1, 8)}
        fold_ps = all_outs[start_ix:end_ix]
        fold_bounds = all_bounds[start_ix:end_ix]
        fold_token_nums = all_token_nums[start_ix:end_ix]
        fold_word_indices = all_word_indices[start_ix:end_ix]
        fold_sample_ids = all_sample_ids[start_ix:end_ix]
        for sample_ix in range(len(fold_ps)):
            fold_stats, num_entities = process_sample(fold_ps[sample_ix],
                                                      fold_word_indices[sample_ix],
                                                      fold_bounds[sample_ix],
                                                      gt_dicts[fold_sample_ids[sample_ix]],
                                                      fold_token_nums[sample_ix],
                                                      fold_stats)
            all_num_entities[text_ix] = num_entities
            text_ix += 1

        fs, prec, rec = f1s(fold_stats)
        print(f'fold {fold_ix}\nf1s: {fs}\nprec: {prec}\nrec: {rec}')

        # Fold counts are plain integers, so the overall counts are their sum.
        for cat_ix, counts in fold_stats.items():
            for key, value in counts.items():
                total_stats[cat_ix][key] += value

    fs, prec, rec = f1s(total_stats)
    print(f'Total\nf1s: {fs}\nprec: {prec}\nrec: {rec}')
    return all_num_entities

In [None]:
# Score the predictions stored under oof_ps/debertav1_rce01; the trailing ';'
# keeps the notebook from echoing the returned array.
evaluate(['debertav1_rce01']);