In [1]:
import json
import data
from eval_model import compute_accuracy
import pandas as pd

# Load the tag vocabulary and the test dialogues. The second value returned by
# load_vocab is not used here; it is not bound to `_` because in IPython `_`
# holds the last cell's output and assigning it disables that history.
tag_vocab, _unused = data.load_vocab('data/swda_tag_vocab.json')
test_data = data.load_data('data/swda_test.json', 'utts', 'tags')


def load_preds(path, vocab):
    """Load one model's saved predictions and map them to tag strings.

    Args:
        path: JSON file holding a list of dialogues, each a list of
            per-utterance predictions (presumably integer tag ids — confirm).
        vocab: indexable mapping from a stored prediction to its tag string
            (here ``tag_vocab``).

    Returns:
        List of dialogues, each a list of tag strings.
    """
    with open(path) as f:
        return [[vocab[t] for t in dialogue] for dialogue in json.load(f)]


# `.L` / `.NL` are two variants of each model. The `E<n>` suffix presumably
# selects the checkpoint epoch used for evaluation (TODO confirm).
preds_wva = load_preds('models/wordvec-avg.L/preds.E10.json', tag_vocab)
preds_bert = load_preds('models/bert.L/preds.E5.json', tag_vocab)
preds_wva_nl = load_preds('models/wordvec-avg.NL/preds.E10.json', tag_vocab)
preds_bert_nl = load_preds('models/bert.NL/preds.E6.json', tag_vocab)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
# One row per test utterance: its position (dialogue id, utterance id), text,
# gold tag, and the tag predicted by each of the four models.
model_preds = {
    'wva': preds_wva,
    'bert': preds_bert,
    'wva_nl': preds_wva_nl,
    'bert_nl': preds_bert_nl,
}
records = []
for i, diag in enumerate(test_data):
    # zip(*diag) pairs each utterance with its gold tag.
    for j, (utt, tag) in enumerate(zip(*diag)):
        row = {'diag_id': i, 'utt_id': j, 'utt': utt, 'tag': tag}
        for model, preds in model_preds.items():
            row['pred_' + model] = preds[i][j]
        records.append(row)

# Explicit columns pin the column order and still give a well-formed (empty)
# frame if test_data is empty.
df = pd.DataFrame(
    records,
    columns=['diag_id', 'utt_id', 'utt', 'tag'] + ['pred_' + m for m in model_preds],
)

# Boolean flag per model: did it predict the gold tag for this utterance?
for model in ['wva', 'wva_nl', 'bert', 'bert_nl']:
    df['correct_' + model] = df['pred_' + model] == df['tag']

In [3]:
# Utterance-level accuracy per model: the fraction of rows whose prediction
# equals the gold tag (mean of the boolean `correct_*` column).
for template, col in [
    ("Wordvec-Avg accuracy (L)  {:.4f}", 'correct_wva'),
    ("Wordvec-Avg accuracy (NL) {:.4f}", 'correct_wva_nl'),
    ("Bert accuracy        (L)  {:.4f}", 'correct_bert'),
    ("Bert accuracy        (NL) {:.4f}", 'correct_bert_nl'),
]:
    print(template.format(df[col].mean()))

Wordvec-Avg accuracy (L)  0.7017
Wordvec-Avg accuracy (NL) 0.6981
Bert accuracy        (L)  0.7788
Bert accuracy        (NL) 0.7723


In [4]:
def scores(df, model):
    """Per-tag precision, recall and F1 for one model's predictions.

    Args:
        df: frame with a gold ``tag`` column and a ``pred_<model>`` column.
        model: suffix naming the prediction column, e.g. ``'wva'``.

    Returns:
        DataFrame indexed by tag label (every label that occurs as a gold tag
        or as a prediction) with columns ``precision_<model>``,
        ``recall_<model>`` and ``f1_<model>``. Undefined ratios (0/0) are
        reported as 0.
    """
    pred = 'pred_' + model
    gold = df.tag.value_counts()
    predicted = df[pred].value_counts()
    correct = df.loc[df[pred] == df.tag, 'tag'].value_counts()

    # Align every count on the same label set. Otherwise precision exists only
    # for predicted labels and recall only for gold labels. Stacking them would
    # then leave NaN (not 0) for labels a model never predicts, or that never
    # occur in the gold data.
    labels = gold.index.union(predicted.index)
    gold = gold.reindex(labels, fill_value=0)
    predicted = predicted.reindex(labels, fill_value=0)
    correct = correct.reindex(labels, fill_value=0)

    precision = (correct / predicted).fillna(0)
    recall = (correct / gold).fillna(0)
    f1 = (2 * (precision * recall) / (precision + recall)).fillna(0)
    return pd.DataFrame({
        'precision_' + model: precision,
        'recall_' + model: recall,
        'f1_' + model: f1,
    })


In [5]:
# Side-by-side per-tag scores for all four models (inner join on the tag index).
sdf = scores(df, 'wva')
for other_model in ('bert', 'wva_nl', 'bert_nl'):
    sdf = pd.merge(sdf, scores(df, other_model), left_index=True, right_index=True)

sdf['tag_count'] = df.tag.value_counts()
# F1 differences: a positive value means the left-hand model scores higher on that tag.
sdf['f1_diff_bert_wva'] = sdf['f1_bert'] - sdf['f1_wva']
sdf['f1_diff_wva_nl'] = sdf['f1_wva'] - sdf['f1_wva_nl']
sdf['f1_diff_bert_nl'] = sdf['f1_bert'] - sdf['f1_bert_nl']

f1_cols = [
    'tag_count',
    'f1_bert', 'f1_bert_nl', 'f1_wva', 'f1_wva_nl',
    'f1_diff_bert_wva', 'f1_diff_wva_nl', 'f1_diff_bert_nl',
]

In [6]:
# Tags where BERT (L) gains the most F1 over Wordvec-Avg (L).
sdf = sdf.sort_values(by='f1_diff_bert_wva', ascending=False)
sdf.loc[:, f1_cols].head(10)

Unnamed: 0,tag_count,f1_bert,f1_bert_nl,f1_wva,f1_wva_nl,f1_diff_bert_wva,f1_diff_wva_nl,f1_diff_bert_nl
fp,69,0.779661,0.786885,0.0,0.0,0.779661,0.0,-0.007224
h,230,0.746204,0.774194,0.0,0.0,0.746204,0.0,-0.02799
qo,135,0.746032,0.637555,0.0,0.0,0.746032,0.0,0.108477
"fo/o/fw/""/by/bc",169,0.664311,0.618893,0.0,0.577778,0.664311,-0.577778,0.045418
fa,12,0.6,0.444444,0.0,0.0,0.6,0.0,0.155556
qrr,43,0.582278,0.192308,0.0,0.0,0.582278,0.0,0.389971
qw,361,0.811321,0.76942,0.28785,0.409861,0.52347,-0.122011,0.0419
bk,242,0.536697,0.496454,0.054264,0.298851,0.482434,-0.244587,0.040243
ad,189,0.398625,0.441472,0.0,0.0,0.398625,0.0,-0.042846
^h,72,0.361905,0.45045,0.0,0.0,0.361905,0.0,-0.088546


In [7]:
# Tags where Wordvec-Avg (L) gains the most F1 over Wordvec-Avg (NL).
sdf = sdf.sort_values(by='f1_diff_wva_nl', ascending=False)
sdf.loc[:, f1_cols].head(10)

Unnamed: 0,tag_count,f1_bert,f1_bert_nl,f1_wva,f1_wva_nl,f1_diff_bert_wva,f1_diff_wva_nl,f1_diff_bert_nl
fc,580,0.75,0.750442,0.57384,0.517316,0.17616,0.056524,-0.000442
bh,221,0.76644,0.726437,0.622449,0.587927,0.143991,0.034522,0.040003
+,3575,0.883608,0.904035,0.727625,0.695283,0.155983,0.032342,-0.020426
fx/sv,5399,0.633404,0.580528,0.576644,0.549779,0.05676,0.026865,0.052876
na,162,0.340136,0.386861,0.03125,0.011905,0.308886,0.019345,-0.046725
fe/ba,927,0.736358,0.711951,0.573302,0.557399,0.163057,0.015903,0.024407
%,3112,0.806894,0.788073,0.701906,0.686722,0.104988,0.015184,0.018821
x,651,0.979561,0.981046,0.869695,0.857143,0.109866,0.012552,-0.001485
sd,14882,0.838227,0.835305,0.792555,0.786595,0.045672,0.00596,0.002922
t1,21,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [8]:
# Tags where BERT (L) gains the most F1 over BERT (NL).
sdf = sdf.sort_values(by='f1_diff_bert_nl', ascending=False)
sdf.loc[:, f1_cols].head(10)

Unnamed: 0,tag_count,f1_bert,f1_bert_nl,f1_wva,f1_wva_nl,f1_diff_bert_wva,f1_diff_wva_nl,f1_diff_bert_nl
qrr,43,0.582278,0.192308,0.0,0.0,0.582278,0.0,0.389971
fa,12,0.6,0.444444,0.0,0.0,0.6,0.0,0.155556
qo,135,0.746032,0.637555,0.0,0.0,0.746032,0.0,0.108477
fx/sv,5399,0.633404,0.580528,0.576644,0.549779,0.05676,0.026865,0.052876
no,58,0.289855,0.238806,0.0,0.0,0.289855,0.0,0.051049
"fo/o/fw/""/by/bc",169,0.664311,0.618893,0.0,0.577778,0.664311,-0.577778,0.045418
aa,2387,0.529521,0.485957,0.359533,0.395844,0.169988,-0.036311,0.043564
qw,361,0.811321,0.76942,0.28785,0.409861,0.52347,-0.122011,0.0419
^q,206,0.087591,0.046332,0.0,0.0,0.087591,0.0,0.041259
bk,242,0.536697,0.496454,0.054264,0.298851,0.482434,-0.244587,0.040243


In [9]:
# F1 comparison for the most frequent gold tags.
sdf = sdf.sort_values(by='tag_count', ascending=False)
sdf.loc[:, f1_cols].head(20)

Unnamed: 0,tag_count,f1_bert,f1_bert_nl,f1_wva,f1_wva_nl,f1_diff_bert_wva,f1_diff_wva_nl,f1_diff_bert_nl
sd,14882,0.838227,0.835305,0.792555,0.786595,0.045672,0.00596,0.002922
b,7763,0.86299,0.867476,0.839909,0.841497,0.023081,-0.001588,-0.004486
fx/sv,5399,0.633404,0.580528,0.576644,0.549779,0.05676,0.026865,0.052876
+,3575,0.883608,0.904035,0.727625,0.695283,0.155983,0.032342,-0.020426
%,3112,0.806894,0.788073,0.701906,0.686722,0.104988,0.015184,0.018821
aa,2387,0.529521,0.485957,0.359533,0.395844,0.169988,-0.036311,0.043564
fe/ba,927,0.736358,0.711951,0.573302,0.557399,0.163057,0.015903,0.024407
qr/qy,868,0.793028,0.767762,0.54717,0.574227,0.245859,-0.027057,0.025266
x,651,0.979561,0.981046,0.869695,0.857143,0.109866,0.012552,-0.001485
fc,580,0.75,0.750442,0.57384,0.517316,0.17616,0.056524,-0.000442


In [10]:
# Confusion counts: gold tag (rows) vs. Wordvec-Avg (L) prediction (columns).
pd.crosstab(index=df['tag'], columns=df['pred_wva'])

pred_wva,%,+,aa,b,bh,bk,fc,fe/ba,"fo/o/fw/""/by/bc",fx/sv,na,nn,ny,qo,qr/qy,qw,sd,x
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
%,2191,165,45,247,0,0,6,5,0,86,0,1,10,0,1,1,346,8
+,82,2817,25,33,0,0,2,8,0,196,1,0,0,0,26,8,374,3
^2,15,22,10,14,1,0,1,1,0,26,0,0,0,0,2,0,39,5
^g,0,0,1,4,3,0,0,0,0,0,0,0,0,0,8,0,0,0
^h,16,2,3,1,0,0,2,4,0,7,1,3,3,0,0,0,28,2
^q,5,7,3,5,0,0,0,5,0,39,0,0,0,0,8,0,132,2
aa,49,56,677,1228,0,1,2,51,0,142,0,5,27,0,0,0,138,11
aap/am,0,1,4,0,0,0,1,0,0,5,0,0,0,0,0,0,2,0
ad,5,19,3,3,0,0,4,4,1,22,0,0,1,0,14,1,111,1
ar,4,1,44,0,0,0,0,2,0,2,0,2,0,0,0,0,9,4


In [11]:
# Confusion counts: gold tag (rows) vs. BERT (L) prediction (columns).
pd.crosstab(index=df['tag'], columns=df['pred_bert'])

pred_bert,%,+,^2,^h,^q,aa,ad,ar,b,b^m,...,oo/co/cc,qh,qo,qr/qy,qrr,qw,qy^d,sd,t1,x
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
%,2411,31,4,1,0,38,1,0,243,1,...,0,7,1,16,0,5,2,258,0,8
+,35,3105,7,1,0,10,2,0,26,6,...,0,3,0,21,0,4,4,241,0,3
^2,7,23,24,0,0,3,1,0,4,3,...,0,0,0,3,0,0,0,44,0,0
^g,0,0,0,0,0,2,0,0,3,0,...,0,0,0,3,0,0,0,0,0,0
^h,5,0,0,19,1,2,0,0,2,0,...,0,1,0,0,0,0,0,28,0,0
^q,1,2,0,0,12,3,5,0,3,1,...,0,7,0,8,0,1,1,119,0,1
aa,5,4,9,1,1,1139,1,0,966,0,...,0,0,0,0,0,0,0,101,0,0
aap/am,0,0,0,0,0,4,0,0,0,0,...,0,0,0,0,0,0,0,2,0,0
ad,1,2,3,1,8,1,58,0,0,0,...,1,3,0,8,0,0,1,67,0,0
ar,3,1,0,0,0,29,0,0,0,0,...,0,0,0,0,0,0,0,10,0,0
