In [129]:
import json
import data
from eval_model import compute_accuracy
import pandas as pd

tag_vocab, _ = data.load_vocab('data/swda_tag_vocab.json')
test_data = data.load_data('data/swda_test.json', 'utts', 'tags')

with open('models/wordvec-avg.L/preds.E10.json') as f:
    preds_wva = json.load(f)
preds_wva = [[tag_vocab[t] for t in d] for d in preds_wva]

with open('models/bert.L/preds.E5.json') as f:
    preds_bert = json.load(f)
preds_bert = [[tag_vocab[t] for t in d] for d in preds_bert]

In [130]:
df = []
for i, diag in enumerate(test_data):
    for j, (utt, tag) in enumerate(zip(*diag)):
        df.append(
            { 'diag_id': i
            , 'utt_id': j
            , 'utt': utt
            , 'tag': tag
            , 'pred_wva': preds_wva[i][j]
            , 'pred_bert': preds_bert[i][j]}
        )
df = pandas.DataFrame(df)
df['correct_wva'] = df.pred_wva == df.tag
df['correct_bert'] = df.pred_bert == df.tag

In [131]:
print("Wordvec-Avg accuracy {:.4f}".format(sum(df.correct_wva) / len(df)))
print("Bert accuracy        {:.4f}".format(sum(df.correct_bert) / len(df)))

Wordvec-Avg accuracy 0.7017
Bert accuracy        0.7788


In [192]:
def scores(df, pred):   
    tp = df[df[pred] == df.tag].tag.value_counts()
    p  = df[pred].value_counts()
    pr = (tp / p).fillna(0)
    re = (tp / df.tag.value_counts()).fillna(0)
    f1 = (2 * (pr * re) / (pr + re)).fillna(0)
    return pd.DataFrame([pr, re, f1], index=['precision', 'recall', 'f1']).transpose()


In [197]:
sdf = pd.merge(scores(df, 'pred_wva'), scores(df, 'pred_bert'), left_index=True, right_index=True, suffixes=('_wva', '_bert'))
sdf['tag_count'] = df.tag.value_counts()
sdf['f1_diff'] = sdf.f1_bert - sdf.f1_wva

In [198]:
sdf = sdf.sort_values('tag_count', ascending=False)
sdf

Unnamed: 0,precision_wva,recall_wva,f1_wva,precision_bert,recall_bert,f1_bert,tag_count,f1_diff
sd,0.754957,0.834095,0.792555,0.799805,0.880527,0.838227,14882,0.045672
b,0.753401,0.94886,0.839909,0.807865,0.926188,0.86299,7763,0.023081
fx/sv,0.571247,0.582145,0.576644,0.659843,0.609002,0.633404,5399,0.05676
+,0.675864,0.787972,0.727625,0.899218,0.868531,0.883608,3575,0.155983
%,0.699776,0.704049,0.701906,0.84183,0.774743,0.806894,3112,0.104988
aa,0.490935,0.28362,0.359533,0.594778,0.477168,0.529521,2387,0.169988
fe/ba,0.640479,0.518878,0.573302,0.787469,0.691478,0.736358,927,0.163057
qr/qy,0.480803,0.634793,0.54717,0.752066,0.83871,0.793028,868,0.245859
x,0.843931,0.897081,0.869695,0.965672,0.993856,0.979561,651,0.109866
fc,0.73913,0.468966,0.57384,0.856195,0.667241,0.75,580,0.17616


In [199]:
sdf = sdf.sort_values('f1_diff', ascending=False)
sdf

Unnamed: 0,precision_wva,recall_wva,f1_wva,precision_bert,recall_bert,f1_bert,tag_count,f1_diff
fp,,0.0,0.0,0.938776,0.666667,0.779661,69,0.779661
h,,0.0,0.0,0.744589,0.747826,0.746204,230,0.746204
qo,0.0,0.0,0.0,0.803419,0.696296,0.746032,135,0.746032
"fo/o/fw/""/by/bc",0.0,0.0,0.0,0.824561,0.556213,0.664311,169,0.664311
fa,,0.0,0.0,0.75,0.5,0.6,12,0.6
qrr,,0.0,0.0,0.638889,0.534884,0.582278,43,0.582278
qw,0.442529,0.213296,0.28785,0.790026,0.833795,0.811321,361,0.52347
bk,0.4375,0.028926,0.054264,0.603093,0.483471,0.536697,242,0.482434
ad,,0.0,0.0,0.568627,0.306878,0.398625,189,0.398625
^h,,0.0,0.0,0.575758,0.263889,0.361905,72,0.361905


In [173]:
scores.sort_values('tag_count', ascending=False)

Unnamed: 0,precision,recall,f1,tag_count
sd,0.754957,0.834095,0.792555,14882
b,0.753401,0.94886,0.839909,7763
fx/sv,0.571247,0.582145,0.576644,5399
+,0.675864,0.787972,0.727625,3575
%,0.699776,0.704049,0.701906,3112
aa,0.490935,0.28362,0.359533,2387
fe/ba,0.640479,0.518878,0.573302,927
qr/qy,0.480803,0.634793,0.54717,868
x,0.843931,0.897081,0.869695,651
fc,0.73913,0.468966,0.57384,580


In [94]:
pd.crosstab(df.tag, df.pred_wva)

pred_wva,%,+,aa,b,bh,bk,fc,fe/ba,"fo/o/fw/""/by/bc",fx/sv,na,nn,ny,qo,qr/qy,qw,sd,x
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
%,2191,165,45,247,0,0,6,5,0,86,0,1,10,0,1,1,346,8
+,82,2817,25,33,0,0,2,8,0,196,1,0,0,0,26,8,374,3
^2,15,22,10,14,1,0,1,1,0,26,0,0,0,0,2,0,39,5
^g,0,0,1,4,3,0,0,0,0,0,0,0,0,0,8,0,0,0
^h,16,2,3,1,0,0,2,4,0,7,1,3,3,0,0,0,28,2
^q,5,7,3,5,0,0,0,5,0,39,0,0,0,0,8,0,132,2
aa,49,56,677,1228,0,1,2,51,0,142,0,5,27,0,0,0,138,11
aap/am,0,1,4,0,0,0,1,0,0,5,0,0,0,0,0,0,2,0
ad,5,19,3,3,0,0,4,4,1,22,0,0,1,0,14,1,111,1
ar,4,1,44,0,0,0,0,2,0,2,0,2,0,0,0,0,9,4


In [95]:
pd.crosstab(df.tag, df.pred_bert)

pred_bert,%,+,^2,aa,b,bh,fe/ba,fx/sv,h,ny,qr/qy,qw,qy^d,sd,x
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
%,1627,150,0,1,292,0,3,79,1,0,3,0,0,939,17
+,91,2386,0,11,109,0,10,246,1,0,26,1,0,684,10
^2,11,16,0,2,35,0,3,12,0,0,1,0,0,50,6
^g,0,0,0,0,5,0,0,0,0,0,10,0,0,1,0
^h,7,3,0,0,4,0,2,11,0,0,0,0,0,44,1
^q,4,9,0,1,10,0,2,18,0,0,4,1,0,156,1
aa,41,65,1,123,1494,0,65,100,0,1,5,0,0,482,10
aap/am,1,2,0,1,2,0,0,0,1,0,0,0,0,6,0
ad,6,13,0,2,18,0,2,19,0,0,5,0,0,123,1
ar,3,1,0,10,5,0,2,1,0,0,3,0,0,43,0
