In [None]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, modified_precision
from nltk.translate.chrf_score import sentence_chrf, corpus_chrf
from nltk.metrics import scores
import scipy.io.wavfile
from IPython.display import Audio
from IPython.display import display
from nltk.stem import *
# from nltk.stem.snowball import SnowballStemmer
from stemming.porter2 import stem
import stemming
from nltk.metrics.scores import recall

from nltk.corpus import stopwords

%matplotlib inline

In [None]:
# Tableau-20 colour palette. The 8-bit RGB triples are rescaled to floats in
# the [0, 1] range, which is the format matplotlib accepts.
tableau20 = [
    (red / 255., green / 255., blue / 255.)
    for red, green, blue in [
        (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
    ]
]

In [None]:
smooth_fun = nltk.translate.bleu_score.SmoothingFunction()

In [None]:
from nmt_run import *

In [None]:
cfg_path = "./sp2enw/sp_.50"

In [None]:
last_epoch, model, optimizer, m_cfg, t_cfg = check_model(cfg_path)

In [None]:
# Load the comma-separated train/dev logs from the model directory and
# transpose them so that each file column becomes one row (log_*[k]).
# skiprows=False is treated as 0 here, i.e. no leading line is skipped.
log_train = np.loadtxt(os.path.join(cfg_path, "train.log"), delimiter=',', skiprows=False).transpose()
# Only the first three columns of dev.log are read; the plotting cell uses them
# as x-axis value, dev loss and dev BLEU respectively.
log_test = np.genfromtxt(os.path.join(cfg_path, "dev.log"), delimiter=',', usecols = (0,1,2)).transpose()

In [None]:
# Training curves: train and dev loss on the left y-axis (red tick labels) and
# dev BLEU on a twin right y-axis (blue tick labels), all plotted against the
# first column of the corresponding log.
fig, ax1 = plt.subplots()
fig.set_size_inches(8,5)
ax1.plot(log_train[0], log_train[1], color='#ff7700')
ax1.plot(log_test[0], log_test[1], 'r--')
ax1.set_xlabel('epoch', size=24)
ax1.set_ylabel('loss', color='r', size=24)
for tl in ax1.get_yticklabels():
    tl.set_color('r')
    tl.set_fontsize(18)
# Legend for the two loss curves, anchored outside the axes area.
plt.legend(['train loss', 'dev loss'], bbox_to_anchor=(1.45, 0.96), framealpha=0, fontsize=20)    
# Second y-axis sharing ax1's x-axis; the third dev-log column is scaled by 100
# before plotting.
ax2 = ax1.twinx()
ax2.plot(log_test[0], log_test[2]*100, 'b-')
# NOTE(review): ax1 labels the shared x-axis 'epoch' while this says
# 'iteration' — confirm which unit the first log column actually holds.
ax2.set_xlabel('iteration')
ax2.set_ylabel('dev bleu', color='b', size=24)
# ax1.set_xlim(0, 60)
for tl in ax2.get_yticklabels():
    tl.set_color('b')
    tl.set_fontsize(18) 
# plt.legend acts on the current axes, which is ax2 after twinx().
plt.legend(['dev bleu'], bbox_to_anchor=(1.44, 1.04), framealpha=0, fontsize=20)
# plt.legend(['dev bleu'], bbox_to_anchor=(1.06, 0.9), framealpha=0, fontsize=20)
plt.grid(False)
plt.tight_layout()

In [None]:
import nltk.translate.bleu_score

In [None]:
def play_utt(utt, m_dict):
    """Embed an audio player for one utterance in the notebook output.

    The wav file name is the utterance id with its last "-"-separated field
    removed, looked up under the notebook-level ``wavs_path``. The clip runs
    from the earliest segment start to the latest segment end listed in
    ``m_dict[utt]['seg']``; those times are multiplied by the sample rate to
    obtain sample offsets. The start/end times are also printed.
    """
    wav_name = utt.rsplit("-", 1)[0] + '.wav'
    sr, y = scipy.io.wavfile.read(os.path.join(wavs_path, wav_name))

    segments = m_dict[utt]['seg']
    start_t = min(seg['start'] for seg in segments)
    end_t = max(seg['end'] for seg in segments)
    print(start_t, end_t)

    # Convert the time window to sample indices before slicing the signal.
    first_sample = int(start_t * sr)
    last_sample = int(end_t * sr)
    display(Audio(y[first_sample:last_sample], rate=sr))

In [None]:
def display_words(m_dict, v_dict, preds, utts, dec_key, key, play_audio=False, displayN=-1):
    """Print source, reference and predicted text for a set of utterances.

    Args:
        m_dict: maps utterance id -> dict holding at least 'es_w' and 'en_w'
            (sequences of byte-string words) and the ``dec_key`` entry.
        v_dict: vocabulary dict; ``v_dict['i2w']`` maps token id -> byte string.
        preds: predicted token-id sequences, aligned with ``utts``.
        utts: utterance ids.
        dec_key: decoder key; tokens are joined with a space when it ends in
            '_w' and with the empty string otherwise.
        key: unused; kept so existing callers keep working.
        play_audio: when True, ``play_utt`` is called for every shown utterance.
        displayN: number of utterances to show; -1 shows all of them.

    Returns:
        None. One table per utterance is printed, in sorted utterance order.
    """
    if displayN == -1:
        displayN = len(utts)

    es_ref = []
    en_ref = []
    for u in utts:
        es_ref.append(" ".join([w.decode() for w in m_dict[u]['es_w']]))
        if type(m_dict[u][dec_key]) == list:
            # single reference stored as a flat list of words
            en_ref.append(" ".join([w.decode() for w in m_dict[u]['en_w']]))
        else:
            # several references: only the first one is displayed
            en_ref.append(" ".join([w.decode() for w in m_dict[u]['en_w'][0]]))

    en_pred = []
    join_str = ' ' if dec_key.endswith('_w') else ''

    for p in preds:
        t_str = join_str.join([v_dict['i2w'][i].decode() for i in p])
        # Cut at the end-of-sentence marker only when it is present.
        # str.find returns -1 on a miss, and slicing with [:-1] would
        # silently drop the last character of the prediction.
        eos_pos = t_str.find('_EOS')
        if eos_pos != -1:
            t_str = t_str[:eos_pos]
        en_pred.append(t_str)

    for u, es, en, p in sorted(zip(utts, es_ref, en_ref, en_pred))[:displayN]:
        # for reference, 1st word is GO_ID, no need to display
        print("Utterance: {0:s}".format(u))
        display_pp = PrettyTable(["cat","sent"], hrules=True)
        display_pp.align = "l"
        display_pp.header = False
        display_pp.add_row(["es ref", textwrap.fill(es,50)])
        display_pp.add_row(["en ref", textwrap.fill(en,50)])
        display_pp.add_row(["en pred", textwrap.fill(p,50)])

        print(display_pp)
        if play_audio:
            play_utt(u, m_dict)
    

In [None]:
def display_words_latex(m_dict, v_dict, preds, utts, dec_key):
    """Print source, reference and prediction as LaTeX table rows.

    Reads the notebook-level ``min_len`` and ``max_len`` for the header line.

    Args:
        m_dict: maps utterance id -> dict holding at least 'es_w' and 'en_w'
            (sequences of byte-string words) and the ``dec_key`` entry.
        v_dict: vocabulary dict; ``v_dict['i2w']`` maps token id -> byte string.
        preds: predicted token-id sequences, aligned with ``utts``.
        utts: utterance ids.
        dec_key: decoder key; tokens are joined with a space when it ends in
            '_w' and with the empty string otherwise.

    Returns:
        None. Rows are printed in the order of ``utts``.
    """
    print("min length={0:d}, max length={1:d}".format(min_len, max_len))
    es_ref = []
    en_ref = []
    for u in utts:
        es_ref.append(" ".join([w.decode() for w in m_dict[u]['es_w']]))
        if type(m_dict[u][dec_key]) == list:
            # single reference stored as a flat list of words
            en_ref.append(" ".join([w.decode() for w in m_dict[u]['en_w']]))
        else:
            # several references: only the first one is shown
            en_ref.append(" ".join([w.decode() for w in m_dict[u]['en_w'][0]]))

    en_pred = []
    join_str = ' ' if dec_key.endswith('_w') else ''

    for p in preds:
        t_str = join_str.join([v_dict['i2w'][i].decode() for i in p])
        # Cut at the end-of-sentence marker only when it is present.
        # str.find returns -1 on a miss, and slicing with [:-1] would
        # silently drop the last character of the prediction.
        eos_pos = t_str.find('_EOS')
        if eos_pos != -1:
            t_str = t_str[:eos_pos]
        en_pred.append(t_str)

    # Always 0 in this function; kept because it is the first column of each row.
    total_matching_len = 0

    for u, es, en, p in zip(utts, es_ref, en_ref, en_pred):
        # for reference, 1st word is GO_ID, no need to display
        print("Utterance: {0:s}".format(u))
        print("{0:d} & {1:s} & {2:s} & {3:s} \\\\".format(total_matching_len, es, en, p))



In [None]:
m_cfg['model_dir']

In [None]:
def write_predictions_to_file(m_dict, v_dict, preds, utts, dec_key, key, stemmify=False):
    """Write hypotheses and references to plain-text files, one sentence per line.

    Files are created in ``m_cfg['model_dir']`` (notebook-level config):
      * ``<key>_mt-output``            -- hypotheses
      * ``<key>.ref0``                 -- references, single-reference data
      * ``<key>_en.ref<i>``            -- i-th reference, multi-reference data

    Args:
        m_dict: maps utterance id -> dict with the reference entry ('en_w' when
            ``dec_key`` contains 'en_', otherwise 'es_w'). A ``list`` value is
            one reference (sequence of byte-string words); any other value is
            iterated as several references.
        v_dict: vocabulary dict; ``v_dict['i2w']`` maps token id -> byte string.
        preds: predicted token-id sequences, aligned with ``utts``.
        utts: utterance ids.
        dec_key: decoder key; tokens are joined with a space when it ends in
            '_w' and with the empty string otherwise.
        key: prefix of the output file names.
        stemmify: when True every word is passed through ``stem``; the
            notebook-level ``EOS_ID``/``EOS`` keep the end marker unstemmed.

    Returns:
        (en_ref, en_hyp): the reference strings (or lists of strings for
        multi-reference data) and the hypothesis strings that were written.
    """
    ref_key = 'en_w' if 'en_' in dec_key else 'es_w'

    def _render(words):
        # Decode one reference, optionally stemming every word.
        if stemmify:
            return " ".join([stem(w.decode()) for w in words])
        return " ".join([w.decode() for w in words])

    en_ref = []
    for u in tqdm(utts, ncols=80):
        refs = m_dict[u][ref_key]
        if type(refs) == list:
            # Single reference. Rendered from ref_key, the same entry whose
            # type was just tested, instead of a hard-coded 'en_w'.
            en_ref.append(_render(refs))
        else:
            en_ref.append([_render(r) for r in refs])

    join_str = ' ' if dec_key.endswith('_w') else ''

    en_hyp = []
    # zip keeps the original behaviour of stopping at the shorter sequence.
    for _, p in zip(utts, preds):
        if stemmify:
            t_str = join_str.join([stem(v_dict['i2w'][i].decode()) if i != EOS_ID else EOS.decode() for i in p])
        else:
            t_str = join_str.join([v_dict['i2w'][i].decode() for i in p])
        # Cut at the end-of-sentence marker only when it is present.
        # str.find returns -1 on a miss, and slicing with [:-1] would
        # silently drop the last character of the hypothesis.
        eos_pos = t_str.find('_EOS')
        if eos_pos != -1:
            t_str = t_str[:eos_pos]
        en_hyp.append(t_str)

    out_fname = os.path.join(m_cfg['model_dir'], "{0:s}_mt-output".format(key))
    print("writing to file: {0:s}".format(out_fname))

    with open(out_fname, "w") as pred_f:
        for p in en_hyp:
            pred_f.write("{0:s}\n".format(p))

    # The layout is decided from the first rendered reference, which mirrors
    # the type of the first utterance's entry. An empty utterance list yields
    # an empty single-reference file rather than an IndexError.
    if not en_ref or isinstance(en_ref[0], str):
        with open(os.path.join(m_cfg['model_dir'], "{0:s}.ref0".format(key)), "w") as ref_f:
            for r in en_ref:
                ref_f.write("{0:s}\n".format(r))
    else:
        # Number of reference files = smallest reference count over all
        # utterances, so r[i] below can never run out of range. This no longer
        # depends on a loop variable left over from an earlier loop.
        num_ref = min(len(r) for r in en_ref)
        for i in range(num_ref):
            with open(os.path.join(m_cfg['model_dir'], "{0:s}_en.ref{1:d}".format(key,i)), "w") as ref_f:
                for r in en_ref:
                    ref_f.write("{0:s}\n".format(r[i]))
    print("done")
    return en_ref, en_hyp

### Fisher dev

In [None]:
# Pull the dataset keys and the encoder/decoder keys out of the model config.
train_key = m_cfg['train_set']
dev_key = m_cfg['dev_set']
# NOTE(review): this value is overwritten by the dict assigned to batch_size at
# the end of this cell, so the training-config batch size is never used.
batch_size=t_cfg['batch_size']
enc_key=m_cfg['enc_key']
dec_key=m_cfg['dec_key']
# Location of the dev-set inputs, passed to feed_model as input_path.
input_path = os.path.join(m_cfg['data_path'], m_cfg['dev_set'])
# -------------------------------------------------------------------------
# get data dictionaries
# -------------------------------------------------------------------------
map_dict, vocab_dict, bucket_dict = get_data_dicts(m_cfg)
# Batch-size settings handed to feed_model; the meaning of the individual keys
# is defined in nmt_run — presumably per-bucket sizes, TODO confirm.
batch_size = {'max': 96, 'med': 128, 'min': 256, 'scale': 1}

In [None]:
random.seed("meh")
# random.seed("haha")

In [None]:
# Run the loaded model over the dev set with train=False and use_y=False.
# Judging by the names the result is unpacked into, feed_model returns the
# predicted token sequences, the matching utterance ids and a loss value —
# confirm against nmt_run.
pred_sents, utts, dev_loss = feed_model(model,
                                        optimizer=optimizer,
                                        m_dict=map_dict[dev_key],
                                        b_dict=bucket_dict[dev_key],
                                        vocab_dict=vocab_dict,
                                        batch_size=batch_size,
                                        x_key=enc_key,
                                        y_key=dec_key,
                                        train=False,
                                        input_path=input_path,
                                        max_dec=m_cfg['max_en_pred'],
                                        t_cfg=t_cfg,
                                        use_y=False)

In [None]:
# Eval parameters
ref_index = -1
min_len, max_len= 0, m_cfg['max_en_pred']
# min_len, max_len = 0, 5
displayN = 50
m_dict=map_dict[dev_key]
# wavs_path = os.path.join(m_cfg['data_path'], "wavs")
wavs_path = os.path.join("../chainer2/speech2text/both_fbank_out/", "wavs")
v_dict = vocab_dict['en_w']
key = m_cfg['dev_set']

In [None]:
# Keep the (prediction, utterance) pairs whose Spanish word count ('es_w') lies
# in [min_len, max_len], sort the pairs (prediction first, then utterance id)
# and split them back into two parallel tuples.
# NOTE(review): the unpacking fails with ValueError when no utterance passes
# the filter, because zip(*[]) yields nothing.
fsh_filt_pred, fsh_filt_utts = zip(*sorted([(p,u) for p, u in zip(pred_sents, utts) if (len(m_dict[u]['es_w']) >= min_len) and 
                                        (len(m_dict[u]['es_w']) <= max_len)]))

In [None]:
print("length filtered utterances = {0:d}".format(len(fsh_filt_utts)))

In [None]:
display_words(m_dict, v_dict, 
              fsh_filt_pred, 
              fsh_filt_utts, dec_key, 
              key, 
              play_audio=True, 
              displayN=displayN)

In [None]:
# Build notebook-level Spanish and English reference strings for the filtered
# utterances. This repeats the reference-building loop inside display_words:
# a list-typed dec_key entry is one reference; otherwise the first reference
# of 'en_w' is used.
es_ref = []
en_ref = []
for u in fsh_filt_utts:
    es_ref.append(" ".join([w.decode() for w in m_dict[u]['es_w']]))
    if type(m_dict[u][dec_key]) == list:
        en_ref.append(" ".join([w.decode() for w in m_dict[u]['en_w']]))
    else:
        en_ref.append(" ".join([w.decode() for w in m_dict[u]['en_w'][0]]))

# NOTE(review): en_pred is left empty and join_str is not used in this cell.
en_pred = []
join_str = ' ' if dec_key.endswith('_w') else ''

In [None]:
# Change the working directory for the cells that follow; every relative path
# used afterwards is resolved against the second directory.
# NOTE(review): the first chdir has no lasting effect because the second one
# uses an absolute path; that absolute path is site-specific and stops the
# notebook from running elsewhere.
os.chdir("..")
os.chdir("/afs/inf.ed.ac.uk/group/project/lowres/work/speech2text")

In [None]:
# Corpus-level scores for the length-filtered predictions. From the way the
# results are used below, h and r appear to be per-utterance hypotheses and
# references in the same order as fsh_filt_utts — confirm against calc_bleu
# in nmt_run. chrf is not used in this cell.
b, chrf, h, r = calc_bleu(m_dict, 
                          v_dict, 
                          fsh_filt_pred, 
                          fsh_filt_utts, 
                          dec_key, 
                          ref_index=ref_index)

print("BLEU score on: {0:s} = {1:.2f}".format(key, b * 100))
print("-"*60)

# Index references and hypotheses by utterance id.
model_refs = {u: mr for u, mr in zip(fsh_filt_utts, r)}
model_hyps = {u: mh for u, mh in zip(fsh_filt_utts, h)}

# 4-tuples of n-gram weights: four single-order settings followed by uniform
# weights over the first 2, 3 and 4 orders.
all_weights=[(1.,0.,0.,0.),
             (0.,1.,0.,0.),
             (0.,0.,1.,0.),
             (0.,0.,0.,1.),
             (1./2,1./2,0.,0.),
             (1./3,1./3,1./3,0.),
             (.25,.25,.25,.25)]

In [None]:
en_ref, en_hyp = write_predictions_to_file(m_dict, v_dict, fsh_filt_pred, fsh_filt_utts, 
                                           dec_key, key, stemmify=False)

In [None]:
len(en_hyp)

In [None]:
_, _ = corpus_precision_recall(r, h)

In [None]:
# Per-utterance and pooled unigram (n=1) precision/recall.
# Running totals as used below: 'tc' sums the precision numerators, 'tp' the
# precision denominators and 'tr' the recall denominators.
model_prec_recall = {'precision' : {}, 'recall' : {}, "tp": 0, "tc": 0, "tr": 0}

for utt_id, ref, hyp in zip(fsh_filt_utts, r, h):
    # NOTE(review): es_ref is not used in this loop and rebinds the
    # notebook-level es_ref list.
    es_ref = [w.decode() for w in m_dict[utt_id]['es_w']]
    
    # NOTE(review): assumes pval/rval are fraction-like objects whose
    # numerator/denominator are not reduced and that both share the same
    # numerator — confirm against modified_precision_recall in nmt_run.
    pval, rval = modified_precision_recall(ref, hyp, n=1)
    model_prec_recall['tc'] += pval.numerator
    model_prec_recall['tp'] += pval.denominator
    model_prec_recall['tr'] += rval.denominator

    model_prec_recall['precision'][utt_id] = float(pval)
    model_prec_recall['recall'][utt_id] = float(rval)
# end for
    
# Pooled scores; either line raises ZeroDivisionError when its total is 0.
model_prec_recall['total_precision'] = model_prec_recall['tc'] / model_prec_recall['tp']
model_prec_recall['total_recall'] = model_prec_recall['tc'] / model_prec_recall['tr']

In [None]:
print("\t\tModel metrics")
print("-"*60)
print("{0:10s} = {1:0.3f}\n{2:10s} = {3:0.3f}".format("precision",
                                                      model_prec_recall['total_precision'],
                                                      "recall",
                                                      model_prec_recall['total_recall']))


In [None]:
print("model-s2t BLEU score:")
"{0:0.3f}".format(100.0 * corpus_bleu(model_refs.values(), model_hyps.values()))

In [None]:
print("-"*60)
print("\t\tMODEL")
print("-"*60)
_, _ = corpus_precision_recall(model_refs.values(), model_hyps.values())

In [None]:
print("-"*60)
print("\t\tMODEL")
print("-"*60)
# n-gram precision ('p') and recall ('r') for orders 1-4, with one value
# appended per reference translation.
metrics = {"p": {1: [], 2: [], 3: [], 4: []},
           "r": {1: [], 2: [], 3: [], 4: []}}

# The number of references is taken from the first utterance; the corpus is
# then scored against each reference on its own.
for ix in range(len(list(model_refs.values())[0])):
    temp_refs = [[i[ix]] for i in model_refs.values()]
    print("-"*60)
    print("\t\tUsing reference = {0:d}".format(ix+1))
    print("-"*60)
    ps, rs = corpus_precision_recall(temp_refs, model_hyps.values())
    # The loop variables deliberately avoid the names `p` and `r`: `r` holds
    # the references returned by calc_bleu at notebook level, and overwriting
    # it here would break any cell that reads it afterwards or is re-run.
    for order, (prec_val, rec_val) in enumerate(zip(ps, rs), start=1):
        metrics['p'][order].append(prec_val)
        metrics['r'][order].append(rec_val)
    

In [None]:
def get_mean_std(vals_dict):
    """Summarise per-order score lists as mean and standard error.

    For every ``order -> values`` entry the mean and ``std / sqrt(len)`` are
    computed and printed as "<order>-gram = mean ± error". Two comma-separated
    lines with all means and all errors follow.

    Args:
        vals_dict: maps an integer n-gram order to a sequence of numbers.

    Returns:
        (orders, means, errors): three lists in the dict's iteration order.
    """
    orders, means, errors = [], [], []
    for order, samples in vals_dict.items():
        mean = np.mean(samples)
        std_err = np.std(samples) / np.sqrt(len(samples))
        orders.append(order)
        means.append(mean)
        errors.append(std_err)
        print("{0:10d}-gram = {1:.2f} ± {2:.2f}".format(order, mean, std_err))

    for series in (means, errors):
        print(",".join("{0:.2f}".format(x) for x in series))
    return orders, means, errors

In [None]:
s2w_p_r = {}

print("Precision:")
_, s2w_p_r['p'], s2w_p_r['p_std'] = get_mean_std(metrics['p'])

print("Recall:")
_, s2w_p_r['r'], s2w_p_r['r_std'] = get_mean_std(metrics['r'])

In [None]:
with open(os.path.join(m_cfg['model_dir'], "s2w_p_r.json"), "w") as out_f:
    json.dump(s2w_p_r, out_f, indent=4)

In [None]:
# Hypotheses indexed by utterance id; fsh_filt_utts and h are paired
# positionally.
model_predictions = dict(zip(fsh_filt_utts, h))

### word level analysis

In [None]:
top_k = 100
min_word_len = 1

In [None]:
stop_words = set(nltk.corpus.stopwords.words("english"))
len(stop_words)

In [None]:
def get_words(m_dict):
    """Count English word tokens over all utterances.

    Args:
        m_dict: maps utterance id -> dict with an 'en_w' entry. A ``list``
            value is one reference (sequence of byte-string words); any other
            value is iterated as several references, all of which are counted.

    Returns:
        collections.Counter mapping the decoded word to its token count.
    """
    counts = Counter()
    for entry in m_dict.values():
        en_words = entry['en_w']
        references = [en_words] if type(en_words) == list else en_words
        for ref in references:
            counts.update(w.decode() for w in ref)
    return counts

In [None]:
# words in train
# Word counts over the training references.
# NOTE(review): the set name is hard-coded instead of using train_key.
train_words = get_words(map_dict['fisher_train'])
# Most frequent first, dropping stop words and words shorter than
# min_word_len, then truncated to top_k entries.
train_words_top_k = [(w,f) for w, f in sorted(train_words.items(), reverse=True, key=lambda t:t[1]) 
                     if w not in stop_words and len(w) >= min_word_len][:top_k]

# Despite the name, this is the set of all training word types.
train_only_words = set(train_words.keys())

print("{0:20s} | {1:10d}".format("# train word types", len(train_words)))
print("{0:20s} | {1:10d}".format("# train word tokens", sum(train_words.values())))

In [None]:
train_words_top_k[:5]

In [None]:
[(w,f) for w,f in train_words_top_k if "'" in w]

In [None]:
# Word counts over the dev references (all references of each utterance are
# counted by get_words).
# NOTE(review): the set name is hard-coded instead of using dev_key.
dev_words = get_words(map_dict['fisher_dev'])
# Most frequent first, dropping stop words and words shorter than
# min_word_len, then truncated to top_k entries.
dev_words_top_k = [(w,f) for w, f in sorted(dev_words.items(), reverse=True, key=lambda t:t[1]) 
                     if w not in stop_words and len(w) >= min_word_len][:top_k]

# Despite the name, this is the set of all dev word types.
dev_only_words = set(dev_words.keys())

In [None]:
dev_words_top_k[:5]

In [None]:
oov_words = {w:f for w,f in dev_words.items() if w not in train_only_words}

In [None]:
print("{0:20s} | {1:10d}".format("# dev word types", len(dev_only_words)))
print("{0:20s} | {1:10d}".format("# dev word tokens", sum(dev_words.values())))

print("{0:20s} | {1:10d}".format("# oov word types", len(oov_words)))
print("{0:20s} | {1:10d}".format("# oov word tokens", sum(oov_words.values())))


In [None]:
"{0:.1f}%".format(sum(oov_words.values()) / sum(dev_words.values()) * 100)

Look at the utterances on which our model does better, and compare them to Google.

Are we doing better on a few calls only? Is there any particular speaker or call messing up our results?

### Query task

In [None]:
crises_lex_fname = "../criseslex/CrisisLexLexicon/CrisisLexRec.txt"
train_text_fname= "../installs/fisher-callhome-corpus/corpus/ldc/fisher_train.en"
topics_fname = "../criseslex/fsp06_topics_in_english.txt"

In [None]:
# Candidate query terms. The list contains case variants and repeats
# (e.g. "Interracial"/"interracial", "phone" twice); they are collapsed below.
topics = [ "peace", "Music", "Marriage", "Religion", "Cell phones", 
           "Dating", "Telemarketing and SPAM", "Politics", "Travel", 
           "Technical devices", "Healthcare", "Advertisements", "Power", 
           "Occupations", "Movies", "Welfare", "Breaking up", "Location", 
           "Justice", "Memories", "Crime", "Violence against women", "Equality", 
            "Housing", "Immigration",     
            # new topics
           "Interracial", "Christians", "muslims", "jews", "e-mail", 
           "phone", "democracy", "Democratic", "Republican", "technology", 
           "leadership", "community", "jury", "police", "inequality", 
           "renting", "Violence", "immigrants", "immigrant", "skilled", 
           "Telemarketing", "SPAM", "skill", "job", "health", "mobile", 
            "ads", "physical", "emotional", "bubble", "rent", "economy", 
            "abuse", "women", "city", "country", "suburban", "dollar", 
            "united states", "laws", "phone", "race", "biracial", "interracial", 
            "marriage", "lyrics", "sexuality", "medicine", "television", "european",
            "home", "protect", "spouse", "language", "cellphone", "money",
            "doctor", "insurance", "cigarettes", "alcohol", "income", "salary",
            "class", "censor", "rating", "programs", "shows", "government",
            "relationship", "legal", "event", "life", "safe", "victim", "cops",
            "wage", "illegal"
            ]
# Lower-case and de-duplicate. sorted() is used instead of list(set(...))
# because the iteration order of a set of strings changes between interpreter
# runs (string hash randomisation); an unstable order would make the seeded
# random.sample over the pruned topics irreproducible.
topics = sorted(set(t.lower() for t in topics))
# Stemmed form of every topic, aligned index-for-index with `topics`.
topics_stem = [stem(t) for t in topics]

In [None]:
# Dev-set token counts pooled by stem: every word contributes its count to the
# entry of its stemmed form.
dev_words_stem = {}
for w in dev_words:
    stem_w = stem(w)
    dev_words_stem[stem_w] = dev_words_stem.get(stem_w, 0) + dev_words[w]

In [None]:
len(topics)

In [None]:
# NOTE: this filters on the number of occurrences of the (stemmed) word in the
# dev text. The number of utterances, which is used later, is lower than or
# equal to this number.
# Keep topics whose stem occurs more than 5 and fewer than 100 times.
prune_topics = [t for t in topics if 5 < dev_words_stem.get(stem(t), 0) < 100]
len(prune_topics)

In [None]:
prune_topics[:5]

In [None]:
# Seeded draw of at most 100 query terms (all of them, shuffled, when
# prune_topics has 100 entries or fewer). The draw is only repeatable across
# runs if prune_topics itself has the same order on every run.
random.seed("selec_query_terms")
sel_topics = random.sample(prune_topics, min(len(prune_topics), 100))

In [None]:
len(sel_topics)

In [None]:
def find_all_terms_references(utt_text, terms):
    """Locate the utterances whose references contain any of the query terms.

    Matching is done on stems: both the terms and the reference words are
    passed through ``stem`` before comparison.

    Args:
        utt_text: maps utterance id -> either a dict with an 'en_w' entry or
            the references themselves. References may be a sequence of
            references (each a sequence of ``bytes``/``str`` words) or a single
            flat sequence of words.
        terms: iterable of query terms (``str``).

    Returns:
        (terms_search_res, full_words): two dicts keyed by the stemmed term.
        The first maps to the set of utterance ids in which the stem occurs in
        at least one reference, the second to the set of surface words that
        produced the match. Stems that never occur are absent from both.
    """
    terms_set = {stem(t) for t in terms}
    terms_search_res = {}
    full_words = {}

    for u in tqdm(utt_text, ncols=80):
        entry = utt_text[u]
        refs = list(entry['en_w'] if type(entry) == dict else entry)
        # A flat list of words is one reference. Without this wrapping each
        # word would be iterated as a reference: bytes words then yield ints
        # (which have no .decode) and str words would be matched per character.
        if refs and isinstance(refs[0], (str, bytes)):
            refs = [refs]

        for ref in refs:
            # set(): a word repeated inside one reference is looked up once
            for w in set(ref):
                decoded_w = w.decode() if type(w) != str else w
                w_to_search = stem(decoded_w)
                if w_to_search in terms_set:
                    terms_search_res.setdefault(w_to_search, set()).add(u)
                    full_words.setdefault(w_to_search, set()).add(decoded_w)
    return terms_search_res, full_words


def find_all_terms_predictions(utt_text, terms):
    """Locate the utterances whose predicted words contain any query term.

    Matching is done on stems: both the terms and the predicted words are
    passed through ``stem`` before comparison.

    Args:
        utt_text: maps utterance id -> sequence of predicted words
            (``str``; ``bytes`` words are decoded first).
        terms: iterable of query terms (``str``).

    Returns:
        (terms_search_res, full_words): two dicts keyed by the stemmed term.
        The first maps to the set of utterance ids whose prediction contains
        the stem, the second to the set of surface words that produced the
        match (the same layout as ``find_all_terms_references``). Stems that
        never occur are absent from both.
    """
    terms_set = {stem(t) for t in terms}
    terms_search_res = {}
    full_words = {}

    for u in tqdm(utt_text, ncols=80):
        # set(): a word repeated inside one prediction is looked up once
        for w in set(utt_text[u]):
            decoded_w = w.decode() if isinstance(w, bytes) else w
            w_to_search = stem(decoded_w)
            if w_to_search in terms_set:
                terms_search_res.setdefault(w_to_search, set()).add(u)
                # Record the matched word itself, not the utterance id, so
                # the second return value holds surface forms.
                full_words.setdefault(w_to_search, set()).add(decoded_w)
    return terms_search_res, full_words

In [None]:
def terms_prec_recall(preds, refs):
    """Count reference, predicted and correctly predicted utterances per term.

    Only terms present in ``refs`` are scored; terms that appear solely in
    ``preds`` are ignored.

    Args:
        preds: maps term -> set of utterance ids where the model produced it.
        refs: maps term -> set of utterance ids where the reference has it.

    Returns:
        dict with pooled counts 't' (reference utterances), 'tp' (predicted
        utterances) and 'tc' (utterances in both), plus 'terms', which maps
        every reference term to its own {'t', 'tp', 'tc'} counts.
    """
    per_term = {}
    for term, ref_utts in refs.items():
        pred_utts = preds.get(term, set())
        per_term[term] = {
            't': len(ref_utts),
            'tp': len(pred_utts),
            'tc': len(ref_utts & pred_utts),
        }

    # Pool the per-term counts; key order matches the per-term layout.
    summary = {
        name: sum(counts[name] for counts in per_term.values())
        for name in ('t', 'tp', 'tc')
    }
    summary['terms'] = per_term
    return summary

In [None]:
ref_topics, ref_topic_labels = find_all_terms_references(map_dict['fisher_dev'], sel_topics)
pred_topics, pred_topics_topics_labels = find_all_terms_predictions(model_predictions, sel_topics)

In [None]:
len(ref_topics), len(pred_topics)

In [None]:
ref_topics

In [None]:
prune_topics_in_utts = {k:v for k, v in ref_topics.items() if len(v) >=3}

In [None]:
len(prune_topics_in_utts)

In [None]:
from  matplotlib.ticker import FuncFormatter

In [None]:
sns.set_style("darkgrid")

In [None]:
# Horizontal bar chart: number of dev utterances in which each query stem
# occurs, most frequent stem at the top.
sns.set_style("ticks")
fig, ax = plt.subplots(figsize=(6,10),nrows=1, ncols=1)


sorted_counts = [t[0] for t in sorted(ref_topics.items(), reverse=True, key=lambda t:len(t[1]))]
topic_count = [len(ref_topics[t]) for t in sorted_counts]

# Label each bar with up to three surface forms of the stem. The forms are
# sorted first so the label does not depend on set iteration order.
topic_labels = ["-".join(sorted(ref_topic_labels[t])[:3]) for t in sorted_counts]

ax = sns.barplot(x=topic_count, 
                 y=topic_labels, 
                 color=tableau20[6], alpha=.7)
                 #**{"label":"topic counts", "alpha":0.5}, ax=ax)

# ax.set_xlabel("# of utts in which topic occurs", size=20)

# Ticks every 5 utterances plus the smallest and largest counts. The smallest
# count is taken with min(); it was previously computed with max(), which
# duplicated the largest count.
max_count_val = max([len(v) for v in ref_topics.values()])
min_count_val = min([len(v) for v in ref_topics.values()])
plt.xticks(list(range(0,max_count_val, 5))+[min_count_val, max_count_val])

# ax.legend(loc='upper right', bbox_to_anchor=(1.1, 0.9),
#                       ncol=1, fancybox=True, shadow=True, fontsize=20)

sns.despine(left=True, bottom=True, top=True, right=False)

for t in ax.get_yticklabels():
    t.set_fontsize(15) 
    
# Only the first and the last x tick labels stay visible.
for i, t in enumerate(ax.get_xticklabels()):
    if i > 0 and i < (len(ax.get_xticklabels())-1):
        t.set_visible(False)
    else:
        t.set_fontsize(16)

fig.tight_layout()


fig.savefig("../criseslex/sel_topics.png", dpi=300)

In [None]:
model_topics_p_r = terms_prec_recall(preds = pred_topics, refs = ref_topics)

In [None]:
# Pooled retrieval scores over all reference terms: 'tc' counts utterances
# found in both prediction and reference, 'tp' predicted utterances and 't'
# reference utterances.
# NOTE(review): raises ZeroDivisionError when the model predicts none of the
# terms (tp == 0) or when there are no reference hits (t == 0).
model_topics_p_r['precision'] = model_topics_p_r['tc'] / model_topics_p_r['tp']
model_topics_p_r['recall'] = model_topics_p_r['tc'] / model_topics_p_r['t']

In [None]:
print("-"*40)
print("-------------- Query by Text")
print("-"*40)
print("{0:20s} | {1:5.2f}%".format("precision", model_topics_p_r['precision']*100.0))
print("{0:20s} | {1:5.2f}%".format("recall", model_topics_p_r['recall']*100.0))

In [None]:
with open(os.path.join(m_cfg['model_dir'], "topics_p_r.json"), "w") as out_f:
    json.dump(model_topics_p_r, out_f, indent=4)