In [1]:
import pandas as pd
import qgrid
import numpy as np
from scipy.stats import pearsonr, spearmanr, wasserstein_distance
import json, bz2, pickle
from pprint import pprint
from tqdm.autonotebook import tqdm
from sklearn.preprocessing import minmax_scale

  import sys


In [2]:
with bz2.open('./convai1_results.pickle.bz2') as fin:
    convai1_data = pickle.load(fin)
len(convai1_data)

2154

In [3]:
dialogue_scores = list()
indices = list()
dialogue_data = dict()

for d in tqdm(convai1_data):
    d_item = dict()
    dialogue_data[str(d['dialogId'])] = d
    indices.append(str(d['dialogId']))
    d_item['quality'] = d['quality']
    
    pred_keys = list(d['predictions'].keys())
    bert_keys = list(filter(lambda x: 'bert' in x, pred_keys))
    
    for pred_key in bert_keys:
        pred_sum = sum([np.log(x) for x in d['predictions'][pred_key] if x != 0])
        pred_avg = pred_sum / len(d['predictions'][pred_key])
        d_item['{}_log_sum'.format(pred_key)] = pred_sum
        d_item['{}_log_avg'.format(pred_key)] = pred_avg
        
        pred_sum = sum(d['predictions'][pred_key])
        pred_avg = pred_sum / len(d['predictions'][pred_key])
        d_item['{}_sum'.format(pred_key)] = pred_sum
        d_item['{}_avg'.format(pred_key)] = pred_avg
        
        d_item['{}_prd'.format(pred_key)] = np.prod(d['predictions'][pred_key])        
        d_item['{}_prd_avg'.format(pred_key)] = d_item['{}_prd'.format(pred_key)] / len(d['predictions'][pred_key])
        
    prob_keys = list(filter(lambda x: 'prob' in x and x not in bert_keys, pred_keys))
    
    for pred_key in prob_keys: 
        s_sums = [sum(x) for x in d['predictions'][pred_key]]
        s_sums_d_sum = sum([x for x in s_sums if x != 0])
        s_sums_d_avg = s_sums_d_sum / len(s_sums)
        
        d_item['{}_s_sums_d_sum'.format(pred_key)] = pred_sum
        d_item['{}_s_sums_d_avg'.format(pred_key)] = pred_avg
        
        s_sums = [sum([np.log(x_1) for x_1 in x if x_1 != 0]) for x in d['predictions'][pred_key]]
        s_sums_d_sum = sum([x for x in s_sums if x != 0])
        s_sums_d_avg = s_sums_d_sum / len(s_sums)
        
        d_item['{}_s_log_sums_d_sum'.format(pred_key)] = pred_sum
        d_item['{}_s_log_sums_d_avg'.format(pred_key)] = pred_avg
        
        s_prd = [np.prod(x) for x in d['predictions'][pred_key]]
        s_prd_d_sum = sum(s_sums)
        s_prd_d_avg = s_sums_d_sum / len(s_sums)
        
        d_item['{}_s_prod_d_sum'.format(pred_key)] = pred_sum
        d_item['{}_s_prod_d_avg'.format(pred_key)] = pred_avg
        
        s_avg = [float(sum([np.log(x_1) for x_1 in x if x_1 != 0]) / len(x)) for x in d['predictions'][pred_key] if len(x) > 0]        
        s_avg_d_sum = sum(s_avg)
        s_avg_d_avg = s_avg_d_sum / len(s_avg)        
        
        d_item['{}_s_log_avg_d_sum'.format(pred_key)] = s_avg_d_sum
        d_item['{}_s_log_avg_d_avg'.format(pred_key)] = s_avg_d_avg
        
        s_avg = [float(sum(x) / len(x)) for x in d['predictions'][pred_key] if len(x) > 0]        
        s_avg_d_sum = sum(s_avg)
        s_avg_d_avg = s_avg_d_sum / len(s_avg)        
        
        d_item['{}_s_avg_d_sum'.format(pred_key)] = s_avg_d_sum
        d_item['{}_s_avg_d_avg'.format(pred_key)] = s_avg_d_avg
        
    dialogue_scores.append(d_item)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2154.0), HTML(value='')))




In [4]:
dialogue_scores = pd.DataFrame(dialogue_scores)
dialogue_scores.index = indices
dialogue_scores.head()

Unnamed: 0,quality,bert-base-uncased_nsp_0_log_sum,bert-base-uncased_nsp_0_log_avg,bert-base-uncased_nsp_0_sum,bert-base-uncased_nsp_0_avg,bert-base-uncased_nsp_0_prd,bert-base-uncased_nsp_0_prd_avg,bert-base-uncased_nsp_1_log_sum,bert-base-uncased_nsp_1_log_avg,bert-base-uncased_nsp_1_sum,...,gpt2-large_sentences_best_word_probs_s_sums_d_sum,gpt2-large_sentences_best_word_probs_s_sums_d_avg,gpt2-large_sentences_best_word_probs_s_log_sums_d_sum,gpt2-large_sentences_best_word_probs_s_log_sums_d_avg,gpt2-large_sentences_best_word_probs_s_prod_d_sum,gpt2-large_sentences_best_word_probs_s_prod_d_avg,gpt2-large_sentences_best_word_probs_s_log_avg_d_sum,gpt2-large_sentences_best_word_probs_s_log_avg_d_avg,gpt2-large_sentences_best_word_probs_s_avg_d_sum,gpt2-large_sentences_best_word_probs_s_avg_d_avg
-749262821,1.5,-0.000357,-8.9e-05,3.999643,0.999911,0.9996427,0.2499107,-40.210713,-10.052678,0.000357,...,5.4e-05,1.4e-05,5.4e-05,1.4e-05,5.4e-05,1.4e-05,-6.364283,-1.591071,0.97861,0.244652
-155769874,0.5,-0.073824,-0.012304,5.92879,0.988132,0.928835,0.1548058,-51.52024,-8.586707,0.07121,...,0.030035,0.005006,0.030035,0.005006,0.030035,0.005006,-8.110848,-1.351808,2.07365,0.345608
1327080259,1.5,-15.297059,-1.019804,12.062486,0.804166,2.272855e-07,1.515237e-08,-115.073049,-7.671537,2.937514,...,1.512429,0.100829,1.512429,0.100829,1.512429,0.100829,-20.992727,-1.399515,4.736203,0.315747
-1682987452,0.5,-18.694987,-1.699544,8.993401,0.817582,7.600989e-09,6.90999e-10,-83.706999,-7.609727,2.006599,...,2.006208,0.182383,2.006208,0.182383,2.006208,0.182383,-13.923161,-1.265742,3.801247,0.345568
2037906078,2.5,-0.00212,-0.000353,5.997881,0.999647,0.997882,0.1663137,-51.210174,-8.535029,0.002119,...,0.00346,0.000577,0.00346,0.000577,0.00346,0.000577,-9.343397,-1.557233,1.896593,0.316099


In [5]:
for col in dialogue_scores.columns:
    dialogue_scores[col] = minmax_scale(dialogue_scores[col])
    
dialogue_scores.head()

Unnamed: 0,quality,bert-base-uncased_nsp_0_log_sum,bert-base-uncased_nsp_0_log_avg,bert-base-uncased_nsp_0_sum,bert-base-uncased_nsp_0_avg,bert-base-uncased_nsp_0_prd,bert-base-uncased_nsp_0_prd_avg,bert-base-uncased_nsp_1_log_sum,bert-base-uncased_nsp_1_log_avg,bert-base-uncased_nsp_1_sum,...,gpt2-large_sentences_best_word_probs_s_sums_d_sum,gpt2-large_sentences_best_word_probs_s_sums_d_avg,gpt2-large_sentences_best_word_probs_s_log_sums_d_sum,gpt2-large_sentences_best_word_probs_s_log_sums_d_avg,gpt2-large_sentences_best_word_probs_s_prod_d_sum,gpt2-large_sentences_best_word_probs_s_prod_d_avg,gpt2-large_sentences_best_word_probs_s_log_avg_d_sum,gpt2-large_sentences_best_word_probs_s_log_avg_d_avg,gpt2-large_sentences_best_word_probs_s_avg_d_sum,gpt2-large_sentences_best_word_probs_s_avg_d_avg
-749262821,0.222222,0.999998,0.99999,0.044356,0.999917,0.9996593,0.4998296,0.946148,0.188477,1.2e-05,...,2e-06,1.1e-05,2e-06,1.1e-05,2e-06,1.1e-05,0.976474,0.394803,0.029466,0.185227
-155769874,0.0,0.999573,0.99854,0.065752,0.988135,0.9288504,0.3096168,0.931001,0.306823,0.002469,...,0.001237,0.005007,0.001237,0.005007,0.001237,0.005007,0.968154,0.505829,0.069868,0.317857
1327080259,0.222222,0.911509,0.878952,0.133782,0.804123,2.272893e-07,3.030523e-08,0.845887,0.380704,0.101861,...,0.062299,0.100903,0.062299,0.100903,0.062299,0.100903,0.906785,0.483691,0.168104,0.278627
-1682987452,0.0,0.891852,0.798268,0.099743,0.817543,7.601115e-09,1.382021e-09,0.887895,0.385694,0.069581,...,0.082638,0.182519,0.082638,0.182519,0.082638,0.182519,0.940464,0.545767,0.133608,0.317804
2037906078,0.444444,0.999988,0.999959,0.066519,0.999653,0.9978985,0.3326328,0.931416,0.310995,7.3e-05,...,0.000142,0.000575,0.000142,0.000575,0.000142,0.000575,0.962282,0.410505,0.063335,0.279089


In [6]:
def rmse(predictions, targets):
    return np.sqrt(np.mean((predictions-targets)**2))

all_scores = {col:dict() for col in dialogue_scores.columns[1:]}

for col in dialogue_scores.columns[1:]:
    for f in (pearsonr, spearmanr, wasserstein_distance, rmse):
        scores = f(dialogue_scores.quality, dialogue_scores[col])
        if np.isscalar(scores):
            scores = [scores]
        
        for score, name in zip(scores, [f.__name__, f.__name__+'_p']):
            all_scores[col][name] = round(score, 3)

all_scores = pd.DataFrame.from_dict(all_scores, orient='index')
qgrid.show_grid(all_scores,
               grid_options={
    # SlickGrid options
    'fullWidthRows': True,
    'syncColumnCellResize': True,
    'forceFitColumns': False,
    'defaultColumnWidth': 80,
    'rowHeight': 28,
    'enableColumnReorder': False,
    'enableTextSelectionOnCells': True,
    'editable': True,
    'autoEdit': False,
    'explicitInitialization': True,

    # Qgrid options
    'maxVisibleRows': 15,
    'minVisibleRows': 8,
    'sortable': True,
    'filterable': True,
    'highlightSelectedCell': False,
    'highlightSelectedRow': True
})

QgridWidget(grid_options={'fullWidthRows': True, 'syncColumnCellResize': True, 'forceFitColumns': False, 'defa…

In [7]:
key = '-1286395059'
d = dialogue_data[key]
pprint(d['quality'])
pprint(list((idx, u) for idx,u in enumerate(d['utterances'])))
pprint({k:[(idx,'__'.join(u)) for idx,u in enumerate(v)] for k,v in d['predictions'].items() if 'best_words' in k})
pprint(dialogue_scores[dialogue_scores.index == key].to_dict('list'))

0.5
[(0, 'Do you know Utrecht?'),
 (1, 'granted the right to accept only one religion'),
 (2, 'What do you mean?'),
 (3, 'granted the right to accept only one religion'),
 (4, 'Oh no bring me more pastes'),
 (5, 'Calvinism seems like a nice place!\n'),
 (6, 'I dont think so')]
{'gpt2-large_sentences_best_words': [(0,
                                      'ind__,__Ġtitle__Ġto__Ġuse__Ġdonations__ĠDutch__Ġapplication__,'),
                                     (1, 'Ġis__Ġyou__Ġthink__Ġby__ĠI'),
                                     (2,
                                      'anted__,__Ġfact__Ġto__Ġbe__Ġor__Ġthe__Ġof__,'),
                                     (3, ',__,__Ġback__Ġthe__Ġof__ries__!'),
                                     (4,
                                      'm__Ġand__Ġis__Ġto__Ġa__Ġgood__Ġidea__Ġto__ĠI'),
                                     (5, "'m__Ġknow__Ġit__.")],
 'gpt2-medium_sentences_best_words': [(0,
                                       '?__.__Ġright__Ġto__Ġvote_

In [8]:
key = '1111790167'
d = dialogue_data[key]
pprint(d['quality'])
pprint(list((idx, u) for idx,u in enumerate(d['utterances'])))
pprint({k:[(idx,'__'.join(u)) for idx,u in enumerate(v)] for k,v in d['predictions'].items() if 'best_words' in k})
pprint(dialogue_scores[dialogue_scores.index == key].to_dict('list'))

5.0
[(0, 'fascinating :D'),
 (1, 'What do you find interesting about this region?'),
 (2, 'I hadn\'t heard of the concept of a "fall line" before. Had you?'),
 (3, 'Neither had I.'),
 (4, "Other than that, i don't find the snippet very interesting."),
 (5, 'have you been to raleigh?'),
 (6, 'No, I have not. And you?'),
 (7, 'No.'),
 (8, 'I think it would be great to visit this place for leasure.'),
 (9, 'Yea sounds like there might be some good hiking around there.')]
{'gpt2-large_sentences_best_words': [(0,
                                      "'s__Ġyou__Ġthink__Ġmost__Ġabout__Ġthe__Ġgame__?__Ċ"),
                                     (1,
                                      '\'m__\'t__Ġheard__Ġof__Ġthis__Ġregion__Ġof__Ġthe__Ġ"__C__en__"__Ġbefore__,__ĠI__Ġyou__?__ĠI'),
                                     (2, 'Ġhad__ĠI__.__ĠBut'),
                                     (3,
                                      "Ġthan__Ġthe__,__ĠI__Ġwas__'t__Ġthink__Ġit__Ġgame__Ġto__Ġinteresting__.__ĠI"

 'xlnet-base-cased_sentences_word_probs_s_avg_d_sum': [0.08813137530570679],
 'xlnet-base-cased_sentences_word_probs_s_log_avg_d_avg': [0.6530927829779828],
 'xlnet-base-cased_sentences_word_probs_s_log_avg_d_sum': [0.9604477326261017],
 'xlnet-base-cased_sentences_word_probs_s_log_sums_d_avg': [0.11185919438645127],
 'xlnet-base-cased_sentences_word_probs_s_log_sums_d_sum': [0.041437723989253536],
 'xlnet-base-cased_sentences_word_probs_s_prod_d_avg': [0.11185919438645127],
 'xlnet-base-cased_sentences_word_probs_s_prod_d_sum': [0.041437723989253536],
 'xlnet-base-cased_sentences_word_probs_s_sums_d_avg': [0.11185919438645127],
 'xlnet-base-cased_sentences_word_probs_s_sums_d_sum': [0.041437723989253536],
 'xlnet-large-cased_sentences_best_word_probs_s_avg_d_avg': [0.5534925037933643],
 'xlnet-large-cased_sentences_best_word_probs_s_avg_d_sum': [0.07870152397825131],
 'xlnet-large-cased_sentences_best_word_probs_s_log_avg_d_avg': [0.7242968797572396],
 'xlnet-large-cased_sentences_bes

In [9]:
key = '-155769874'
d = dialogue_data[key]
pprint(d['quality'])
pprint(list((idx, u) for idx,u in enumerate(d['utterances'])))
pprint({k:[(idx,'__'.join(u)) for idx,u in enumerate(v)] for k,v in d['predictions'].items() if 'best_words' in k})
pprint(dialogue_scores[dialogue_scores.index == key].to_dict('list'))

0.5
[(0, 'Hi'),
 (1, 'Who uses the four stages of civil society ?'),
 (2, 'Ehh its incorrect. Hint: first 3 answer letters is "fer" '),
 (3, 'What is your name?'),
 (4, 'What'),
 (5, 'Please, speak with me.'),
 (6, 'Please, speak with me. It gives me energy to live')]
{'gpt2-large_sentences_best_words': [(0,
                                      'a__Ġthe__Ġsite__-__Ġof__Ġgrief__Ġdisobedience__?__Ċ'),
                                     (1,
                                      '.__,__Ġa__Ġto__ĠThe__aha__:__Ġit__,__Ġstages__Ġare__Ġare__Ġthe__civil__"__Ġand'),
                                     (2, 'Ġis__Ġthe__Ġfavorite__?__Ċ'),
                                     (3, 'Ġis'),
                                     (4, 'Ċ__Ġwhat__Ġto__Ġme__.__Ċ'),
                                     (5,
                                      ",__Ġspeak__Ġwith__Ġme__.__Ċ__'s__Ġme__Ġa__.__Ġkeep__.")],
 'gpt2-medium_sentences_best_words': [(0,
                                       'a__Ġthis__Ġapp__-__Ġof_

 'xlnet-base-cased_sentences_best_word_probs_s_log_avg_d_sum': [0.9600647002234467],
 'xlnet-base-cased_sentences_best_word_probs_s_log_sums_d_avg': [0.005007333767808692],
 'xlnet-base-cased_sentences_best_word_probs_s_log_sums_d_sum': [0.0012369965259775525],
 'xlnet-base-cased_sentences_best_word_probs_s_prod_d_avg': [0.005007333767808692],
 'xlnet-base-cased_sentences_best_word_probs_s_prod_d_sum': [0.0012369965259775525],
 'xlnet-base-cased_sentences_best_word_probs_s_sums_d_avg': [0.005007333767808692],
 'xlnet-base-cased_sentences_best_word_probs_s_sums_d_sum': [0.0012369965259775525],
 'xlnet-base-cased_sentences_word_probs_s_avg_d_avg': [0.1528864338569163],
 'xlnet-base-cased_sentences_word_probs_s_avg_d_sum': [0.05473718495407207],
 'xlnet-base-cased_sentences_word_probs_s_log_avg_d_avg': [0.5243691286075626],
 'xlnet-base-cased_sentences_word_probs_s_log_avg_d_sum': [0.9662992474097879],
 'xlnet-base-cased_sentences_word_probs_s_log_sums_d_avg': [0.005007333767808692],
 'xl