In [287]:
from collections import defaultdict

import pandas as pd

In [288]:
# Model predictions table; later cells read its 'Text' column and the
# per-model fold columns (e.g. 'bert_fold_0', 'bart_fold_3').
pred_df = pd.read_csv('../data/predictions.csv')
# Annotation table; later cells read 'Annotator ID', 'Answer' and 'Text'.
# Paths are relative to the notebook's working directory.
data_row = pd.read_csv('../data/row_data.csv')

In [289]:
# Show which columns the predictions frame provides.
pred_df.columns

Index(['Text', 'fold', 'agr_rate', 'bert_fold_0', 'bert_fold_1', 'bert_fold_2',
       'bert_fold_3', 'bart_fold_0', 'bart_fold_1', 'bart_fold_2',
       'bart_fold_3', 'gptv2_fold_0', 'gptv2_fold_1', 'gptv2_fold_2',
       'gptv2_fold_3'],
      dtype='object')

In [290]:
def ensemble_models(row, model_names=None):
    """Average the predictions of several models for a single row.

    Parameters
    ----------
    row : mapping-like
        Any object indexable by column name, e.g. the ``pandas.Series``
        passed by ``DataFrame.apply(..., axis=1)`` or a plain ``dict``.
    model_names : iterable of str, optional
        Columns to average. When omitted, the hand-picked subset of
        bert/bart/gptv2 fold models listed below is used.

    Returns
    -------
    The arithmetic mean of ``row[name]`` over ``model_names``.

    Raises
    ------
    ValueError
        If ``model_names`` is given but empty.
    KeyError
        If ``row`` lacks one of the requested columns (raised by the
        lookup itself).
    """
    if model_names is None:
        model_names = [
            'bert_fold_0', 'bert_fold_1',
            'bart_fold_0', 'bart_fold_1',
            'bart_fold_2', 'bart_fold_3',
            'gptv2_fold_3'
        ]
    else:
        # Materialise so one-shot iterables can be both iterated and counted.
        model_names = list(model_names)

    if not model_names:
        raise ValueError('model_names must contain at least one column name')

    # NOTE(review): an earlier version parsed the fold number out of every
    # model name but never used it. If the intent was to skip models trained
    # on the row's own fold, that filter was never implemented -- confirm.
    total = 0
    # Plain left-to-right accumulation keeps floating-point results
    # identical to the previous implementation.
    for name in model_names:
        total += row[name]
    return total / len(model_names)

In [291]:
# Apply ensemble_models to every row (axis=1) and store the result as a
# new 'ensemb_preds' column on pred_df.
pred_df['ensemb_preds'] = pred_df.apply(ensemble_models, axis=1)

In [292]:
# Map each text to the ensemble's binary decision: True when the averaged
# score in 'ensemb_preds' is strictly greater than 0.5.
# Built column-wise instead of constructing a Series per row with `.iloc`.
# If a text occurs more than once, the last occurrence wins, exactly as
# with sequential dict assignment.
text_pred_map = dict(zip(pred_df['Text'], pred_df['ensemb_preds'] > 0.5))

In [293]:
# Per-annotator tallies: user_info[annotator_id] == [agreements, answers],
# where an "agreement" is an answer equal to the ensemble decision stored in
# text_pred_map for the same text.
user_info = defaultdict(lambda: [0, 0])
# Iterate the three needed columns directly rather than building a Series
# for every row with `.iloc`.
annotation_columns = zip(
    data_row['Annotator ID'], data_row['Answer'], data_row['Text']
)
for annotator_id, answer, text in annotation_columns:
    tally = user_info[annotator_id]
    tally[1] += 1
    # A text missing from text_pred_map still raises KeyError on purpose:
    # silently skipping it would distort the per-annotator statistics.
    # int() keeps the counter a plain Python int whatever bool flavour the
    # comparison returns.
    tally[0] += int(answer == text_pred_map[text])

In [294]:
# One (annotator_id, fraction, answer_count) tuple per annotator who gave
# more than 10 answers. Despite the variable name, the fraction is the share
# of answers that MATCH the ensemble prediction (user_info[...][0] counts
# matches), not the share of errors.
error_rate = [
    (annotator, matched / answered, answered)
    for annotator, (matched, answered) in user_info.items()
    if answered > 10
]

In [295]:
# Order ascending by the fraction (tuple position 1), so annotators who
# match the ensemble least often come first.
error_rate = sorted(error_rate, key=lambda entry: entry[1])

In [296]:
# Dump the whole sorted list of (annotator_id, fraction, answer_count) tuples.
print(error_rate)

[('A3BJX6UUSOIKFN', 0.5339673913043478, 1472), ('A1MG8KNVSVZ365', 0.5512496489750071, 3561), ('AQIP3DSYXEXX5', 0.5618333813779187, 3469), ('A3OCJJMRKAIJZA', 0.5824835032993402, 5001), ('A33Y36Y252Z30U', 0.6060606060606061, 99), ('A3BCKNE5CWHODZ', 0.6521136521136521, 1443), ('A2CJFO19NY4T5R', 0.6594202898550725, 138), ('A3BISMR4GI02ZG', 0.6910002786291446, 3589), ('A9HQ3E0F2AGVO', 0.6926869350862778, 1217), ('A1YSYI926BBOHW', 0.7023809523809523, 84), ('A2KHLJ2F58BEZK', 0.7051282051282052, 78), ('AMYURTQIMAC8T', 0.7255639097744361, 532), ('AKSLU0C30G3JT', 0.7272727272727273, 11), ('A2A78DMGLC1S0Y', 0.73, 100), ('AKQAI78JTXXC9', 0.75, 20), ('A3MV3PT4TOO69P', 0.7543893874365978, 5126), ('AG36U7IOG2LAP', 0.756, 250), ('A2WPHVMLLEV5ZB', 0.7634730538922155, 334), ('AAX9LTAOIBECD', 0.7777777777777778, 18), ('AYTH0E5PUXWX8', 0.7824074074074074, 864), ('ARW1TCHCLEK1W', 0.7850287907869482, 1042), ('A3124SRR191UIL', 0.7853658536585366, 205), ('AR9AU5FY1S3RO', 0.7868686868686868, 990), ('A33B85TN97