In [1]:
import dill as pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
with open('../data_splits.pickle', 'rb') as f:
    data_splits = pickle.load(f)
val_bounds = np.cumsum([0] + [len(data_splits[0][250]['normed'][fold_ix]) for fold_ix in range(5)])

In [29]:
def extract_entities(ps, n):
    cat_ps = ps.argmax(-1)
    all_entities = {}
    current_cat = None
    current_start = None
    for ix in range(1, n - 1):
        if cat_ps[ix] % 2 == 1:
            if current_cat is not None:
                if current_cat not in all_entities:
                    all_entities[current_cat] = []
                all_entities[current_cat].append((current_start, ix - 1))
            current_cat = (cat_ps[ix] + 1) // 2
            current_start = ix
        elif cat_ps[ix] == 0:
            if current_cat is not None:
                if current_cat not in all_entities:
                    all_entities[current_cat] = []
                all_entities[current_cat].append((current_start, ix - 1))
            current_cat = None
    if current_cat is not None:
        if current_cat not in all_entities:
            all_entities[current_cat] = []
        all_entities[current_cat].append((current_start, ix))
    
    for cat_ix, min_len in zip(range(1, 8), (2, 2, 5, 2, 4, 3, 2)):
        if cat_ix in all_entities:
            all_entities[cat_ix] = [x for x in all_entities[cat_ix] if x[1] - x[0] + 1 >= min_len]
        
    return all_entities

def process_sample(raw_ps, index_map, bounds, gt_spans, num_tokens, match_stats, min_len=0):
    
    #bounds[num_tokens - 2, 1] = min(len(index_map) - 1, bounds[num_tokens - 2, 1])
    predicted_spans = {x: [map_span_to_word_indices(span, index_map, bounds) for span in y] 
                       for x, y in extract_entities(raw_ps, num_tokens).items()}
    predicted_spans = {x: [z for z in y if z[1] - z[0] >= min_len] for x, y in predicted_spans.items()}
    
    for cat_ix in range(1, 8):
        
        pspans = predicted_spans.get(cat_ix, [])
        gspans = gt_spans.get(cat_ix, [])
        if not len(pspans) or not len(gspans):
            match_stats[cat_ix]['fn'] += len(gspans)
            match_stats[cat_ix]['fp'] += len(pspans)
        else:
            all_overlaps = np.zeros((len(pspans), len(gspans)))
            for x1 in range(len(pspans)):
                pspan = pspans[x1]
                for x2 in range(len(gspans)):
                    gspan = gspans[x2]
                    start_ix = max(pspan[0], gspan[0])
                    end_ix = min(pspan[1], gspan[1])
                    overlap = max(0, end_ix - start_ix + 1)
                    if overlap > 0:
                        o1 = overlap / (pspan[1] - pspan[0] + 1)
                        o2 = overlap / (gspan[1] - gspan[0] + 1)
                        if min(o1, o2) >= .5:
                            all_overlaps[x1, x2] = max(o1, o2)
            unused_p_ix = set(range(len(pspans)))
            unused_g_ix = set(range(len(gspans)))
            col_size = len(pspans)
            row_size = len(gspans)
            for ix in np.argsort(all_overlaps.ravel())[::-1]:
                if not len(unused_g_ix) or not len(unused_p_ix) or all_overlaps.ravel()[ix] == 0:
                    match_stats[cat_ix]['fp'] += len(unused_p_ix)
                    match_stats[cat_ix]['fn'] += len(unused_g_ix)
                    break
                p_ix = ix // row_size
                g_ix = ix % row_size
                if p_ix not in unused_p_ix or g_ix not in unused_g_ix:
                    continue
                match_stats[cat_ix]['tp'] += 1
                unused_g_ix.remove(g_ix)
                unused_p_ix.remove(p_ix)
    return match_stats

def map_span_to_word_indices(span, index_map, bounds):
    return (index_map[bounds[span[0], 0]], index_map[bounds[span[1], 1] - 1])

def split_predstring(x):
    vals = x.split()
    return int(vals[0]), int(vals[-1])

In [44]:
def calc_entity_score(span, ps, c):
    s, e = span
    score = (ps[s, c * 2 - 1] + ps[s + 1: e + 1, c * 2].sum())/(e - s + 1)
    return score

In [47]:
def extract_entities(ps, n):
    cat_ps = ps.argmax(-1)
    all_entities = {}
    current_cat = None
    current_start = None
    for ix in range(1, n - 1):
        if cat_ps[ix] % 2 == 1:
            if current_cat is not None:
                if current_cat not in all_entities:
                    all_entities[current_cat] = []
                all_entities[current_cat].append((current_start, ix - 1))
            current_cat = (cat_ps[ix] + 1) // 2
            current_start = ix        
        elif cat_ps[ix] == 0:
            if current_cat is not None:
                if current_cat not in all_entities:
                    all_entities[current_cat] = []
                all_entities[current_cat].append((current_start, ix - 1))
            current_cat = None
        elif current_cat is not None and cat_ps[ix] != current_cat * 2:
            if current_cat not in all_entities:
                all_entities[current_cat] = []
            all_entities[current_cat].append((current_start, ix - 1))
            current_cat = None
    if current_cat is not None:
        if current_cat not in all_entities:
            all_entities[current_cat] = []
        all_entities[current_cat].append((current_start, ix))

    # score_thresholds = [None, -.9, -.56, -.55, -.65, -.56, -.76, -.68]
    score_thresholds = [None, -5, -3.4, -2.065, -2.5, -6, -5.75, -5.5]
    
    for cat_ix, min_len in zip(range(1, 8), (2, 2, 5, 2, 4, 3, 2)):
        if cat_ix in all_entities:
            all_entities[cat_ix] = [x for x in all_entities[cat_ix] if x[1] - x[0] + 1 >= min_len and calc_entity_score(x, ps, cat_ix) > score_thresholds[cat_ix]]
            
        
    return all_entities

In [5]:
def make_gt_dict(df):
    gt_dict = {}
    for cat_ix in range(1, 8):
        cat_name = label_names[cat_ix]
        cat_entities = df.loc[df.discourse_type==cat_name]
        if len(cat_entities):
            gt_dict[cat_ix] = [(x[0], x[1]) for x in cat_entities.predictionstring.map(split_predstring)]
    return gt_dict

In [6]:
def f1s(stats):
    f1s = np.zeros(8)
    rec = np.zeros(7)
    prec = np.zeros(7)
    for ix in range(1, 8):
        f1s[ix] = stats[ix]['tp'] / (1e-7 + stats[ix]['tp'] + .5 * (stats[ix]['fp'] + stats[ix]['fn']))
        rec[ix - 1] = stats[ix]['tp'] / (1e-7 + stats[ix]['tp'] + stats[ix]['fn'])
        prec[ix - 1] = stats[ix]['tp'] / (1e-7 + stats[ix]['tp'] + stats[ix]['fp'])
    f1s[0] = np.mean(f1s[1:])
    return f1s

In [7]:
#data = pd.read_csv('../train.csv')
label_names = ['None', 'Lead', 'Position', 'Evidence', 'Claim',
               'Concluding Statement', 'Counterclaim', 'Rebuttal']


In [8]:
with open('../gt_dicts.pickle', 'rb') as f:
    gt_dicts = pickle.load(f)

In [9]:
from tqdm import tqdm

In [42]:
def evaluate(list_of_dir_names):
    print(f'evaluating {list_of_dir_names}')
    arrays = []
    oof_dir_name = list_of_dir_names[0]
    
    for array_name in 'all_outs, all_bounds, all_token_nums, all_word_indices, all_sample_ids'.split(', '):
        arrays.append(np.load(f'oof_ps/{oof_dir_name}/{array_name}.npy', allow_pickle=True))
    all_outs, all_bounds, all_token_nums, all_word_indices, all_sample_ids = arrays
    
    for oof_dir_name in list_of_dir_names[1:]:
        all_outs += np.load(f'oof_ps/{oof_dir_name}/all_outs.npy', allow_pickle=True)

    for fold_ix in range(5):
        start_ix = val_bounds[fold_ix]
        end_ix = val_bounds[fold_ix + 1]
        fold_stats = {ix:  {'fp': 0, 'fn': 0, 'tp': 0} for ix in range(1, 8)}
        fold_ps = all_outs[start_ix:end_ix]
        fold_bounds = all_bounds[start_ix:end_ix]
        fold_token_nums = all_token_nums[start_ix:end_ix]
        fold_word_indices = all_word_indices[start_ix:end_ix]
        fold_sample_ids = all_sample_ids[start_ix:end_ix]
        for sample_ix in range(len(fold_ps)):
            fold_stats = process_sample(fold_ps[sample_ix], fold_word_indices[sample_ix], 
                                        fold_bounds[sample_ix],
                                        gt_dicts[fold_sample_ids[sample_ix]],
                                        fold_token_nums[sample_ix],
                                       fold_stats)
        print(f'fold {fold_ix} f1s: {f1s(fold_stats)}')
    
    for sample_ix in range(val_bounds[-1]):
        fold_stats = process_sample(all_outs[sample_ix], all_word_indices[sample_ix], 
                                    all_bounds[sample_ix],
                                    gt_dicts[all_sample_ids[sample_ix]],
                                    all_token_nums[sample_ix],
                                   fold_stats)
    print(f'CV f1s: {f1s(fold_stats)}')


In [36]:
evaluate(['debertav1_rce01'])

evaluating ['debertav1_rce01']
fold 0 f1s: [0.68948203 0.83719705 0.70455643 0.7406348  0.65539329 0.85863686
 0.5621446  0.46781116]
fold 1 f1s: [0.68838468 0.82464455 0.71379815 0.73202322 0.64304235 0.85083075
 0.57084856 0.48350515]
fold 2 f1s: [0.69156138 0.82933263 0.71834625 0.73452787 0.65588681 0.84813127
 0.57349398 0.48121086]
fold 3 f1s: [0.68134482 0.82512643 0.70401829 0.738637   0.65273914 0.85127453
 0.5434606  0.45415778]
fold 4 f1s: [0.69342359 0.83226837 0.72140762 0.73610489 0.6591209  0.85425101
 0.56492411 0.48588821]
CV f1s: [0.6895975  0.83014037 0.7139146  0.73634442 0.6542189  0.85290444
 0.56329936 0.47636039]


In [48]:
evaluate(['debertav1_rce01', 'debertav1_norce_mnli', 'longformer'])

evaluating ['debertav1_rce01', 'debertav1_norce_mnli', 'longformer']
fold 0 f1s: [0.69677957 0.83508403 0.71330724 0.75843333 0.66666667 0.85514019
 0.56665238 0.48217317]
fold 1 f1s: [0.6976502  0.82411796 0.72124113 0.75004301 0.65961901 0.8537607
 0.58078053 0.49398907]
fold 2 f1s: [0.70187565 0.83753943 0.72438863 0.75313241 0.67021777 0.85516739
 0.5864094  0.48627451]
fold 3 f1s: [0.69346296 0.83466667 0.71780175 0.7560636  0.66850917 0.85866961
 0.55795848 0.46057143]
fold 4 f1s: [0.70347084 0.84249867 0.73143996 0.75455756 0.67549738 0.85252525
 0.56514085 0.5026362 ]
CV f1s: [0.69945575 0.83604097 0.72325151 0.75446629 0.66933601 0.85463064
 0.57050731 0.48795752]
