In [12]:
from IPython.display import display
import matplotlib.pyplot as plt
from sklearn import metrics

import json
import os
import pandas as pd

from util import gen_model_preds_df

# Dialogue-act group mappings, one per corpus. Downstream code indexes them as
# mapping[tag] -> group. NOTE(review): presumably each file is a flat
# {tag: group} JSON object — confirm against the files.
# Use context managers so the file handles are closed deterministically.
with open("SWDA_dialogue-acts-groups.json") as f:
    groups_s = json.load(f)
with open("AMI-DA_dialogue-acts-groups.json") as f:
    groups_a = json.load(f)

def report_metrics(frames, conditions, group=False, groups=None,
                   names=('SWBD', 'AMI')):
    """Tabulate classification metrics for every model condition on every corpus.

    Parameters
    ----------
    frames : list of pandas.DataFrame
        One prediction frame per corpus. Each must contain the gold
        ``'da_tag'`` column plus one predicted-tag column per condition.
    conditions : list of str
        Names of the prediction columns to score against ``'da_tag'``.
    group : bool, default False
        If True, collapse gold and predicted tags to their dialogue-act group
        (via ``groups``) before scoring.
    groups : list of dict, optional
        One ``{tag: group}`` mapping per frame, in the same order as
        ``frames``. Only used when ``group=True``; defaults to the
        module-level ``[groups_s, groups_a]``.
    names : sequence of str, default ('SWBD', 'AMI')
        Column label for each frame, in the same order as ``frames``.

    Returns
    -------
    pandas.DataFrame
        Rows indexed by (condition, metric name), one column per frame.

    Raises
    ------
    ValueError
        If ``group=True`` and the number of mappings differs from the number
        of frames.
    KeyError
        If ``group=True`` and a tag is missing from its frame's mapping.
    """
    metric_funcs = [
        lambda x, y: metrics.precision_score(x, y, average='macro'),
        lambda x, y: metrics.recall_score(x, y, average='macro'),
        lambda x, y: metrics.f1_score(x, y, average='macro'),
        # Micro-averaged precision on single-label multiclass data equals
        # plain accuracy, hence the 'micro accuracy' label below.
        lambda x, y: metrics.precision_score(x, y, average='micro')]
    metric_names = [
        'macro precision',
        'macro recall',
        'macro f1',
        'micro accuracy']

    if group:
        # One {tag: group} mapping per frame, aligned with `frames`.
        if groups is None:
            groups = [groups_s, groups_a]
        if len(groups) != len(frames):
            raise ValueError(
                f"need one group mapping per frame: got {len(groups)} "
                f"mappings for {len(frames)} frames")

    def labels(ix, df, col):
        # Tags in column `col` of frame `ix`, collapsed to groups if requested.
        if group:
            return [groups[ix][tag] for tag in df[col]]
        return df[col]

    # One row per (condition, metric); one column per frame.
    table = [
        [metric(labels(ix, df, 'da_tag'), labels(ix, df, cond))
         for ix, df in enumerate(frames)]
        for cond in conditions for metric in metric_funcs]

    index = pd.MultiIndex.from_product([conditions, metric_names])
    return pd.DataFrame(table, columns=list(names), index=index)

# Model conditions; each name is also the prediction column in the loaded frames.
conditions = ['NL_bert', 'L_bert', 'NL_cnn', 'L_cnn']

# Location of the trained models and the run date embedded in their directory
# names. Edit here rather than in each path below.
MODEL_ROOT = '/scratch/DistributionalDiscourse/models'
RUN_DATE = '2019-11-20'


def model_dirs_for(corpus, conds=conditions, root=MODEL_ROOT, date=RUN_DATE):
    """Return one model directory path per condition for ``corpus``."""
    return [f'{root}/{corpus}-{c}_{date}/' for c in conds]


model_dirs = model_dirs_for('SWDA')
dfs = gen_model_preds_df('SWDA', conditions, model_dirs, group=False)
display(dfs.head(10))

model_dirs = model_dirs_for('AMI-DA')
dfa = gen_model_preds_df('AMI-DA', conditions, model_dirs, group=False)
display(dfa.head(10))

# Drop utterances without a gold tag: sklearn metrics cannot score NaN labels.
# Applied to both corpora for consistency (the displayed heads are unfiltered).
dfs = dfs[dfs['da_tag'].notnull()]
dfa = dfa[dfa['da_tag'].notnull()]

Unnamed: 0_level_0,Unnamed: 1_level_0,speaker,utt,da_tag,NL_bert,L_bert,NL_cnn,L_cnn
dialogue_id,utt_no,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
sw2608,0,A,Okay .,"fo_o_fw_""_by_bc","fo_o_fw_""_by_bc","fo_o_fw_""_by_bc","fo_o_fw_""_by_bc","fo_o_fw_""_by_bc"
sw2608,1,B,Do you want to go first ?,qy,qy,qy,qy,qy
sw2608,2,A,"You can go first ,",oo_co_cc,ad,ad,sd,sd
sw2608,3,A,or I will .,oo_co_cc,sd,sd,sd,sd
sw2608,4,B,"Well , you go ahead .",ad,ad,ad,ad,sd
sw2608,5,A,Okay .,aa,b,b,bk,bk
sw2608,6,A,"Well , I'm going to tell you what I'd have <la...",sd,^h,sd,sd,sd
sw2608,7,B,<laughter> .,x,x,x,x,x
sw2608,8,A,"Down in the south , we have a lot of shrimp ,",sd,sd,sd,sd,sd
sw2608,9,B,"# Oh , #",b,b,b,b,b


Unnamed: 0_level_0,Unnamed: 1_level_0,speaker,utt,da_tag,NL_bert,L_bert,NL_cnn,L_cnn
dialogue_id,utt_no,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
IS1001d,0,A,Okay .,ami_da_16,ami_da_2,ami_da_9,ami_da_9,ami_da_9
IS1001d,1,A,Je croix que c'est dommage de le,ami_da_3,ami_da_14,ami_da_16,ami_da_4,ami_da_4
IS1001d,2,A,it will be sad to destroy this prototype .,ami_da_16,ami_da_14,ami_da_14,ami_da_4,ami_da_4
IS1001d,3,A,It really looks like a banana .,ami_da_4,ami_da_9,ami_da_9,ami_da_9,ami_da_9
IS1001d,4,C,It is a banana .,ami_da_14,ami_da_9,ami_da_4,ami_da_4,ami_da_9
IS1001d,5,A,It is a banana . <laughter>,ami_da_12,ami_da_9,ami_da_4,ami_da_4,ami_da_9
IS1001d,6,C,It is the essence of bananas .,ami_da_14,ami_da_9,ami_da_9,ami_da_4,ami_da_4
IS1001d,7,C,I would be confused with this thing .,ami_da_15,ami_da_4,ami_da_4,ami_da_9,ami_da_4
IS1001d,8,A,,ami_da_3,ami_da_3,ami_da_3,ami_da_3,ami_da_3
IS1001d,9,C,,ami_da_3,ami_da_3,ami_da_3,ami_da_3,ami_da_3


In [8]:
# Per-group scoring on SWDA. For each condition and group g, restrict to
# utterances whose gold group is g and report the share whose predicted
# group is also g (micro precision on single-label data == accuracy).
metric_funcs = [
    lambda x, y: metrics.precision_score(x, y, average='micro')]

# `dfs` holds fine-grained tags (loaded with group=False), so comparing raw
# tags to group names would match nothing on a fresh kernel. Map tags to
# groups first; values absent from the mapping (e.g. already-grouped labels)
# pass through unchanged.
# NOTE(review): assumes groups_s is a flat {tag: group} dict — confirm.
swda_gold_group = dfs['da_tag'].map(lambda t: groups_s.get(t, t))
for cond in conditions:
    print(cond)
    pred_group = dfs[cond].map(lambda t: groups_s.get(t, t))
    for g in ['Forward-Communicative-Function', 'Backwards-Communicative-Function']:
        print(g)
        in_group = swda_gold_group == g
        for mf in metric_funcs:
            print(f"{mf(swda_gold_group[in_group], pred_group[in_group])*100:.2f}")
    print()

NL_bert
Forward-Communicative-Function
92.54
Backwards-Communicative-Function
91.67

L_bert
Forward-Communicative-Function
95.42
Backwards-Communicative-Function
90.91

NL_cnn
Forward-Communicative-Function
93.30
Backwards-Communicative-Function
90.85

L_cnn
Forward-Communicative-Function
94.64
Backwards-Communicative-Function
90.13



In [9]:
# Per-group scoring on AMI-DA, mirroring the SWDA cell. `metric_funcs` comes
# from that cell.
# `dfa` holds fine-grained tags (loaded with group=False), so map gold and
# predicted tags to groups before filtering; unmapped values pass through.
# NOTE(review): assumes groups_a is a flat {tag: group} dict — confirm.
ami_gold_group = dfa['da_tag'].map(lambda t: groups_a.get(t, t))
for cond in conditions:
    print(cond)
    pred_group = dfa[cond].map(lambda t: groups_a.get(t, t))
    for g in ['Forward-Communicative-Function', 'Backwards-Communicative-Function']:
        print(g)
        in_group = ami_gold_group == g
        for mf in metric_funcs:
            print(f"{mf(ami_gold_group[in_group], pred_group[in_group])*100:.2f}")
    print()

NL_bert
Forward-Communicative-Function
87.77
Backwards-Communicative-Function
80.83

L_bert
Forward-Communicative-Function
87.36
Backwards-Communicative-Function
79.87

NL_cnn
Forward-Communicative-Function
91.85
Backwards-Communicative-Function
71.21

L_cnn
Forward-Communicative-Function
91.25
Backwards-Communicative-Function
71.28



In [13]:
# Fine-grained tag metrics: one column per corpus (SWDA frame, then AMI-DA frame).
report_metrics(frames=[dfs, dfa], conditions=conditions)

  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


Unnamed: 0,Unnamed: 1,SWBD,AMI
NL_bert,macro precision,0.455734,0.59382
NL_bert,macro recall,0.367783,0.465963
NL_bert,macro f1,0.381006,0.490873
NL_bert,micro accuracy,0.770665,0.670764
L_bert,macro precision,0.561119,0.588414
L_bert,macro recall,0.43047,0.48445
L_bert,macro f1,0.459891,0.501652
L_bert,micro accuracy,0.769267,0.671383
NL_cnn,macro precision,0.556029,0.488473
NL_cnn,macro recall,0.343041,0.361419
