In [9]:
%load_ext autoreload
%autoreload 2

import os
import json
from collections import Counter

In [4]:
# Directories searched for `<model file>.wordstats.json` outputs.
# NOTE(review): absolute, user-specific paths — adjust when running elsewhere.
model_dirs = ['/private/home/abisee/models', '/checkpoint/abisee/ahm/pretrain_twitter_split']
WORDSTATS_SUFFIX = '.wordstats.json'

# Model-file name (filename with WORDSTATS_SUFFIX stripped) -> parsed JSON stats.
# If two directories contain the same name, the file from the later directory wins.
mf2wordstats = {}

for model_dir in model_dirs:
    # endswith() instead of a substring test: a bare 'wordstats.json' used to pass the
    # filter and then crash on .index(), and stray files like 'x.wordstats.json.bak'
    # were silently loaded as model 'x'.
    # sorted() makes the printed order reproducible (os.listdir order is arbitrary).
    wordstat_files = sorted(
        fname for fname in os.listdir(model_dir) if fname.endswith(WORDSTATS_SUFFIX)
    )
    for json_file in wordstat_files:
        mf = json_file[:-len(WORDSTATS_SUFFIX)]
        if mf == 'goldresponse':
            # Leading space makes the gold reference sort ahead of every model name
            # when the keys are sorted for the summary table.
            mf = ' goldresponse'
        print(mf)
        with open(os.path.join(model_dir, json_file), "r", encoding="utf-8") as f:
            mf2wordstats[mf] = json.load(f)


seq2seq.valid.beam1
seq2seq.valid.beam20
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20
seq2seq.valid.beam10
 goldresponse
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid9
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid1
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid4
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid3
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid6
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid7
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid5
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid8
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid0
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_cluste

In [5]:
from IPython.display import HTML, display
import tabulate

def model2row(mf, wordstats):
    """Return the summary-table row (a list of 14 cells) for one model.

    Args:
        mf: model name, placed verbatim in the first cell.
        wordstats: parsed wordstats JSON for that model.

    Returns:
        [mf, num examples, perplexity] (unformatted), followed by pre-formatted
        strings: unique %, mean word length, mean char length, average NIWF
        (multiplied by 100 and shown as a percentage), the rare-word
        percentages for the '100', '1000' and '10000' cut-offs, and
        distinct-1 .. distinct-4.
    """
    niwf_scaled = wordstats['niwf']['avg_niwf'] * 100
    report = wordstats['report']
    cells = [mf, report['exs'], report['ppl'], "%.2f%%" % wordstats['unique_percent']]

    word_stats = wordstats['word_statistics']
    cells.append("%.2f" % word_stats['mean_wlength'])
    cells.append("%.2f" % word_stats['mean_clength'])
    cells.append("%.4f%%" % niwf_scaled)
    cells.extend("%.2f%%" % word_stats['freqs_perc'][cutoff] for cutoff in ('100', '1000', '10000'))
    cells.extend("%.4f" % wordstats['distinct-n'][n] for n in ('1', '2', '3', '4'))
    return cells

# Column labels for the summary table, in the same order as the values
# returned by model2row().
header_row = [
    'model name',
    'num_exs',
    'ppl',
    'unique_perc',
    'mean_wlength',
    'mean_clength',
    'avg_niwf',
    '% rare<100',
    '% rare<1000',
    '% rare<10000',
    'distinct-1',
    'distinct-2',
    'distinct-3',
    'distinct-4',
]

# One row per model in sorted-name order, skipping any model whose name contains
# 'seq2seq.'. The header goes in as an ordinary first row: the other cells are
# already formatted strings, and tabulate would re-parse them if told about headers.
table = [header_row] + [
    model2row(model_name, mf2wordstats[model_name])
    for model_name in sorted(mf2wordstats)
    if 'seq2seq.' not in model_name
]

display(HTML(tabulate.tabulate(table, tablefmt='html')))

0,1,2,3,4,5,6,7,8,9,10,11,12,13
model name,num_exs,ppl,unique_perc,mean_wlength,mean_clength,avg_niwf,% rare<100,% rare<1000,% rare<10000,distinct-1,distinct-2,distinct-3,distinct-4
goldresponse,7801,31.62,98.77%,11.87,51.22,11.7342%,5.40%,15.84%,38.51%,0.0541,0.3336,0.5786,0.6536
convai2_pretrain.valid.beam1,7801,21.13,60.17%,10.52,38.32,1.2864%,0.13%,2.22%,12.11%,0.0162,0.0675,0.1225,0.1724
convai2_pretrain.valid.beam10,7801,21.13,8.61%,8.48,30.19,0.1702%,0.01%,0.16%,2.45%,0.0083,0.0227,0.0325,0.0382
convai2_pretrain.valid.beam20,7801,21.13,6.13%,8.19,29.45,0.1432%,0.00%,0.11%,1.88%,0.0078,0.0206,0.0279,0.0314
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20,7801,24.08,22.14%,11.30,40.42,1.4813%,0.43%,1.34%,5.41%,0.0138,0.0423,0.0617,0.0736
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid0,7801,39.96,3.92%,7.76,28.44,0.0638%,0.00%,0.00%,1.69%,0.0038,0.0118,0.0167,0.0189
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid1,7801,34.14,11.52%,9.09,34.40,0.2123%,0.00%,0.03%,4.63%,0.0072,0.0264,0.0395,0.0467
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid2,7801,30.12,14.22%,9.81,36.89,0.3133%,0.00%,0.12%,4.72%,0.0087,0.0309,0.0463,0.0551
seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid3,7801,27.48,17.52%,10.21,38.08,0.4501%,0.00%,0.30%,4.86%,0.0108,0.0363,0.0545,0.0653


In [42]:
def show_preds(mf, num_most_common=0, wordstats_by_model=None):
    """Print how many of a model's predictions occur exactly once.

    Prints `num_unique: <count> <percent>`, where <count> is the number of
    predictions that appear exactly once. <percent> is that count as a
    percentage of all predictions; it is 0.0 when there are no predictions.

    Args:
        mf: key of the model in the wordstats mapping.
        num_most_common: if > 0, also print the `num_most_common` most frequent
            predictions with their counts.
        wordstats_by_model: mapping from model name to wordstats dict (which must
            contain a 'predictions' list). Defaults to the notebook-global
            `mf2wordstats`.
    """
    if wordstats_by_model is None:
        wordstats_by_model = mf2wordstats
    counter = Counter(wordstats_by_model[mf]['predictions'])
    num_preds = sum(counter.values())
    num_unique = sum(1 for count in counter.values() if count == 1)
    # Guard the empty case, which previously raised ZeroDivisionError.
    unique_percent = num_unique * 100 / num_preds if num_preds else 0.0
    print("num_unique: ", num_unique, unique_percent)
    if num_most_common > 0:
        for pred, count in counter.most_common(num_most_common):
            print("%3i   %s" % (count, pred))

In [43]:
# Count this model's predictions that occur exactly once.
show_preds(mf='convai2_pretrain.valid.beam20')

num_unique:  586 7.511857454172542


In [44]:
# Count this model's predictions that occur exactly once.
show_preds(mf='seq2seq_twitterpretrained_specificityclusters_10buckets.valid.beam20.fixed_clusterid9')

num_unique:  1983 25.419817972054865
