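Topic modelling of Fisher call transcripts, loosely following the tutorial at https://www.machinelearningplus.com/nlp/topic-modeling-gensim-python/ (the models below use scikit-learn rather than gensim).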

In [1]:
from scipy.spatial.distance import euclidean
from scipy.spatial.distance import cosine

from sklearn.metrics.pairwise import cosine_similarity

import pickle
import os
import copy
from tqdm import tqdm
import random
import copy

import matplotlib.pyplot as plt

import seaborn as sns

import numpy as np
from collections import Counter

from scipy import stats
import scipy.io.wavfile
from IPython.display import Audio
from IPython.display import display

import time

import re
import numpy as np
import pandas as pd
import sklearn
from pprint import pprint

from collections import Counter

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation

import pyLDAvis
import pyLDAvis.sklearn
pyLDAvis.enable_notebook()

import warnings
warnings.filterwarnings("ignore",category=DeprecationWarning)

%matplotlib inline

In [2]:
# Tableau-20 categorical palette. The triples are written as 8-bit RGB and
# rescaled to floats in [0, 1], which is the format matplotlib accepts.
tableau20 = [
    (red / 255., green / 255., blue / 255.)
    for red, green, blue in [
        (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
    ]
]


In [3]:
# NLTK Stop words
from nltk.corpus import stopwords

# English and Spanish stop words, followed by conversational fillers and
# backchannels. 'umm' is deliberately left in twice so the list (and the
# length displayed in the next cell) is exactly what it was before.
stop_words = (
    stopwords.words('english')
    + stopwords.words('spanish')
    + ['hmm', 'mm', 'mhm', 'umm', 'umm', 'aha', "uh", "yes", "ah", "um", "eh", "hm"]
)

In [4]:
len(stop_words)

452

Steps:

1. Train topic model using Fisher train text
2. Predict topic per call for dev set using reference text
3. Predict topic per call for dev set using predicted text

In [5]:
def clean_out_str(out_str):
    """Strip quoting characters and BPE continuation markers from a string.

    The substrings are removed one after another in a fixed order, then
    leading and trailing whitespace is trimmed.

    Parameters
    ----------
    out_str : str
        Text to clean.

    Returns
    -------
    str
        The cleaned text.
    """
    # Order matters: "@@ " (marker followed by a space) must be removed before
    # the bare "@@" so that split BPE pieces are glued back together.
    removals = ("`", '"', '¿', "''", "@@ ", "@@")
    for fragment in removals:
        out_str = out_str.replace(fragment, "")
    return out_str.strip()


# In[46]:


def get_out_str(h, dec_key):
    """Decode a sequence of byte tokens and join them into one cleaned string.

    Parameters
    ----------
    h : iterable
        Tokens exposing ``.decode()`` (e.g. ``bytes``).
    dec_key : str
        Decoding unit. ``"en_w"`` joins words with spaces, except that tokens
        starting with an apostrophe and the token ``"n't"`` are attached to the
        previous word; ``"en_c"`` concatenates tokens with no separator; any
        key containing ``"bpe_w"`` and every other key are space-joined.

    Returns
    -------
    str
        The joined text after ``clean_out_str`` has been applied.
    """
    tokens = [tok.decode() for tok in h]

    if dec_key == "en_w":
        pieces = []
        for tok in tokens:
            attaches_to_previous = tok.startswith("'") or tok == "n't"
            pieces.append(tok if attaches_to_previous else " " + tok)
        joined = "".join(pieces)
    else:
        # Only the character-level key is joined without spaces; BPE keys and
        # any unrecognised key fall back to a plain space join.
        separator = " " if ("bpe_w" in dec_key or dec_key != "en_c") else ""
        joined = separator.join(tokens)

    return clean_out_str(joined)

In [6]:
def get_call_text(key, dec_key):
    """Concatenate the text of all utterances belonging to each call.

    Reads ``map_dict[key]``, whose keys are utterance ids; the call id is the
    part of the utterance id before the first ``"-"``.

    Parameters
    ----------
    key : str
        Data-set key into ``map_dict`` (e.g. ``"fisher_train"``).
    dec_key : str
        Which token stream of each utterance to decode; passed on to
        ``get_out_str``.

    Returns
    -------
    dict
        Call id -> utterance texts joined by single spaces, in iteration order.
    """
    utt_table = map_dict[key]
    texts_by_call = {}
    for utt_id in utt_table:
        call_id = utt_id.split("-", 1)[0]
        utt_text = get_out_str(utt_table[utt_id][dec_key], dec_key)
        texts_by_call.setdefault(call_id, []).append(utt_text)
    return {call_id: " ".join(texts) for call_id, texts in texts_by_call.items()}

### Loading Fisher data

In [7]:
# Load the utterance map. A context manager is used so the file handle is
# closed as soon as the pickle has been read.
# NOTE(review): pickle.load can execute arbitrary code; only load this file
# from a trusted source.
with open("./mfcc_13dim/bpe_map.dict", "rb") as map_f:
    map_dict = pickle.load(map_f)

In [8]:
dec_key = "bpe_w"

In [9]:
train_calls_text = get_call_text("fisher_train", "bpe_w")

In [10]:
train_counter = Counter(" ".join([t.strip() for t in train_calls_text.values()]).split())

In [11]:
train_counter.most_common(5)

[('the', 55098), ('i', 52158), ('and', 45559), ('that', 37779), ('yes', 35097)]

In [12]:
dev_calls_text = get_call_text("fisher_dev", "bpe_w")

In [13]:
# Utterance ids of the dev set, one per line, in file order.
with open("./fisher/fisher_dev/fisher_dev_eval.ids", "r", encoding="utf-8") as dev_ids_f:
    fisher_utt_ids = [raw_line.strip() for raw_line in dev_ids_f]

# Call id of each dev utterance: everything before the first "-".
fisher_ids = [utt_id.split("-", 1)[0] for utt_id in fisher_utt_ids]

# Utterance ids of the dev2 set, one per line, in file order.
with open("fisher/fisher_dev2/fisher_dev2_eval.ids", "r", encoding="utf-8") as dev2_ids_f:
    dev2_utts = [raw_line.strip() for raw_line in dev2_ids_f]

In [14]:
len(dev2_utts)

3959

In [15]:
set_key = "fisher_dev2"
dec_key = "en_w"

# Group the decoded utterances of each call, then join them with single spaces.
# Element [3] of the "en_w" entry is used for every utterance — presumably one
# of several reference translations; TODO confirm which one.
dev2_texts_by_call = {}
for utt in map_dict[set_key]:
    call_id = utt.split("-", 1)[0]
    utt_tokens = map_dict[set_key][utt][dec_key][3]
    dev2_texts_by_call.setdefault(call_id, []).append(get_out_str(utt_tokens, dec_key))

dev2_calls_text = {cid: " ".join(texts) for cid, texts in dev2_texts_by_call.items()}

In [16]:
len(train_calls_text), len(dev_calls_text), len(dev2_calls_text)

(759, 20, 20)

In [17]:
list(train_calls_text.values())[0][:100]

'hello hello hello hello with whom am i speaking eh silvia yes what is your name hello silvia eh my n'

### SKLEARN


https://medium.com/mlreview/topic-modeling-with-scikit-learn-e80d33668730

https://github.com/bmabey/pyLDAvis/blob/master/notebooks/sklearn.ipynb


In [18]:
def display_topics(model, feature_names, no_top_words):
    """Print a table of the highest-weighted words per topic and return it.

    Parameters
    ----------
    model : object
        Fitted topic model exposing ``components_`` (one weight row per topic).
    feature_names : sequence of str
        Vocabulary; position ``i`` names column ``i`` of ``components_``.
    no_top_words : int
        Number of words to list for every topic.

    Returns
    -------
    dict
        Topic index -> space-separated top words, highest weight first.
    """
    rule = "-" * 80
    print(rule)
    print("{0:>10s} | {1:s}".format("topic id", "top {0:d} words".format(no_top_words)))
    print(rule)

    topics = {}
    for topic_idx, weights in enumerate(model.components_):
        # argsort is ascending, so reverse it and keep the leading entries.
        ranked_terms = weights.argsort()[::-1][:no_top_words]
        top_words = " ".join(feature_names[term_idx] for term_idx in ranked_terms)
        print("{0:>10d} | {1:s}".format(topic_idx, top_words))
        topics[topic_idx] = top_words
    return topics

no_top_words = 10

In [19]:
no_topics = 10

In [20]:
data = list(train_calls_text.values())

In [21]:
# Bag-of-words term counts over `data` (one document per training call).
# Accents are stripped, text is lower-cased and the combined English/Spanish
# stop-word list is removed. Per scikit-learn's documented semantics, the float
# max_df = 0.5 drops terms occurring in more than half of the documents and the
# integer min_df = 1 keeps every term that occurs at all.
# The commented-out token_pattern is an unused alternative (alphabetic tokens
# of at least three letters).
tf_vectorizer = CountVectorizer(strip_accents = 'unicode',
                                stop_words = stop_words,
                                lowercase = True,
#                                 token_pattern = r'\b[a-zA-Z]{3,}\b',
                                max_df = 0.5, 
                                min_df = 1)
# Fit the vocabulary and build the document-term matrix in one pass.
dtm_tf = tf_vectorizer.fit_transform(data)
# Show (number of documents, vocabulary size).
print(dtm_tf.shape)

(759, 16737)


In [22]:
tfidf_vectorizer = TfidfVectorizer(**tf_vectorizer.get_params())
dtm_tfidf = tfidf_vectorizer.fit_transform(data)
print(dtm_tfidf.shape)

(759, 16737)


In [23]:
tfidf_feature_names = tfidf_vectorizer.get_feature_names()

In [24]:
# Run NMF
nmf = NMF(n_components=no_topics, random_state=1, alpha=.1, l1_ratio=.5, init='nndsvd').fit(dtm_tfidf)

In [25]:
nmf_topics = display_topics(nmf, tfidf_feature_names, no_top_words)

--------------------------------------------------------------------------------
  topic id | top 10 words
--------------------------------------------------------------------------------
         0 | city mexico york spanish cold philadelphia puerto wow speak huh
         1 | music listen dance play songs salsa song radio rap listening
         2 | religion church catholic religions god religious bible christian christians homosexual
         3 | insurance health system pay doctor medical hospital dollars expensive hundred
         4 | phone cell use computer cellular telephone cellphone phones internet driving
         5 | relationship married women break marriage together internet woman men kids
         6 | power vote parties politics countries party world president bush government
         7 | movies watch movie tv commercials kids television watching pg violence
         8 | jury guilty system justice case judge trial innocent cases lawyers
         9 | race white black interraci

In [26]:
pyLDAvis.sklearn.prepare(nmf, dtm_tfidf, tfidf_vectorizer, mds='tsne')

  kernel = (topic_given_term * np.log((topic_given_term.T / topic_proportion).T))
  log_lift = np.log(topic_term_dists / term_proportion)
  log_ttd = np.log(topic_term_dists)


In [27]:
# LDA fitted on the TF-IDF document-term matrix.
# `n_components` is the supported name of the topic-count argument (`n_topics`
# was deprecated and later removed from scikit-learn). The count comes from
# `no_topics`, so LDA and NMF always use the same number of topics.
lda_tfidf = LatentDirichletAllocation(n_components=no_topics, random_state=0).fit(dtm_tfidf)

In [28]:
lda_topics = display_topics(lda_tfidf, tfidf_feature_names, no_top_words)

--------------------------------------------------------------------------------
  topic id | top 10 words
--------------------------------------------------------------------------------
         0 | vote bronx letters change chicago party government company wow basque
         1 | music dangerous internet insulin phone huh scholarship send smith yea
         2 | music kids dance school aid internet system spanish mexico married
         3 | phone internet computer telemarketing pay bought use information number puerto
         4 | music religion god virgin church play kids puerto listen songs
         5 | spanish guatemala english immigrants music colorado rice tennessee mom phone
         6 | music religion city married york god mexico huh spanish phone
         7 | phone mexico york jury puerto city system pay immigrants insurance
         8 | music commercials religion god dance married listen church black english
         9 | phone music cellular sprint spanish man god zury nexte

In [29]:
lda_tfidf.score(dtm_tfidf)

-95883.1673073247

In [30]:
# pyLDAvis.sklearn.prepare(lda_tfidf, dtm_tfidf, tfidf_vectorizer, mds='tsne')

### Read predictions

In [31]:
# Hypotheses stored in the pickled dict; only the 'fisher_dev_r0' entry is used.
google_s2t_hyps_path = os.path.join("../chainer2/speech2text/both_fbank_out/", "google_s2t_hyps.dict")
# Open through a context manager so the file handle is closed after loading.
# NOTE(review): pickle.load can execute arbitrary code; only load this file
# from a trusted source.
with open(google_s2t_hyps_path, "rb") as hyps_f:
    google_s2t_hyps = pickle.load(hyps_f)
google_hyp_r0 = google_s2t_hyps['fisher_dev_r0']

In [32]:
model_pred_files = {
#     "2.5h": "experiments/nmt_asr/sp-2.5hrs_swbd1-train_nodev_ep25_baseline",
#     "2.5h+asr": "experiments/nmt_asr/sp-2.5hrs_swbd1-train_nodev_ep25",
    "5h": "experiments/nmt_asr/sp-5hrs_swbd1-train100k_baseline",
    "5h+asr": "experiments/nmt_asr/sp-5hrs_swbd1-train_nodev_ep25",
#     "10h": "experiments/nmt_asr/sp-10hrs_swbd1-train100k_baseline",
#     "10h+asr": "experiments/nmt_asr/sp-10hrs_swbd1-train_nodev_ep25_enc-attn-dec",
    "20h": "interspeech_bpe/sp_20hrs_best_bn-nobias_batch-32_buck-n25-w80/",
    "20h+asr": "experiments/nmt_asr/sp-20hrs_swbd1-train-nodev_ep25_enc-attn-dec/",
    "50h": "interspeech_bpe/sp_50hrs_best_bn-nobias_bucks-n25-w80_x0.2/",
    "50h+asr": "experiments/nmt_asr/sp-50hrs_swbd1-train-nodev_ep25/",
#     "160h": "interspeech_bpe/sp_160hrs_cnn-512-9-mfcc-13_drpt-0.3_l2e-4_rnn-3"
}

dev_beam_fname = "fisher_dev_beam_len-norm_min-0_max-300_N-5_K-5_W-0.6.en"
dev2_beam_fname = "fisher_dev2_beam_len-norm_min-0_max-300_N-5_K-5_W-0.6.en"

In [33]:
def _read_call_hyps(hyp_path, call_ids):
    """Read one hypothesis per line and concatenate the hypotheses per call.

    Line ``i`` of ``hyp_path`` is attributed to ``call_ids[i]``. Hypotheses of
    the same call are joined with single spaces in file order.
    """
    hyps_by_call = {}
    with open(hyp_path, "r", encoding="utf-8") as hyp_f:
        for line_no, hyp_line in enumerate(hyp_f):
            call_id = call_ids[line_no]
            hyp = hyp_line.strip()
            if call_id in hyps_by_call:
                hyps_by_call[call_id] = "{0:s} {1:s}".format(hyps_by_call[call_id], hyp)
            else:
                hyps_by_call[call_id] = hyp
    return hyps_by_call


# Call id of every dev2 utterance (the part before the first "-").
dev2_call_ids = [utt_id.split("-")[0] for utt_id in dev2_utts]

# model name -> {call id -> concatenated predicted text}, for dev and dev2.
pred_dev_text = {}
pred_dev2_text = {}
for model_name, model_dir in model_pred_files.items():
    pred_dev_text[model_name] = _read_call_hyps(
        os.path.join(model_dir, dev_beam_fname), fisher_ids)
    pred_dev2_text[model_name] = _read_call_hyps(
        os.path.join(model_dir, dev2_beam_fname), dev2_call_ids)

In [34]:
# Per-call text for the "Weiss" system: the tokens of every utterance
# hypothesis in google_hyp_r0 are space-joined, then the utterances of a call
# are space-joined in the order given by fisher_utt_ids.
weiss_hyps_by_call = {}
for utt in fisher_utt_ids:
    callid = utt.split("-", 1)[0]
    weiss_hyps_by_call.setdefault(callid, []).append(" ".join(google_hyp_r0[utt]))

pred_dev_text["Weiss"] = {cid: " ".join(hyps) for cid, hyps in weiss_hyps_by_call.items()}

In [35]:
pred_dev_text["20h"][fisher_ids[0]][:100]

'good afternoon yes but how long have you been in the united states so yes we we spoke here you work '

In [36]:
list(pred_dev_text["20h"].keys())[:5]

['20051009_182032_217_fsp',
 '20051009_210519_219_fsp',
 '20051010_212418_225_fsp',
 '20051016_180547_265_fsp',
 '20051016_210626_267_fsp']

### Predict topics using English translations

In [37]:
def get_ref_topics(ref_text):
    """Assign the most likely NMF topic to every reference text.

    Uses the notebook-level ``tfidf_vectorizer``, ``nmf`` and ``nmf_topics``.

    Parameters
    ----------
    ref_text : dict
        Call id -> reference text of that call.

    Returns
    -------
    dict
        Call id -> dict with keys ``"t_words"`` (top words of the chosen
        topic), ``"ref_words"`` (those top words that occur in the text),
        ``"t_prob"`` (normalised weight of the chosen topic) and ``"ref_tid"``
        (index of the chosen topic).

    Notes
    -----
    Plain ndarrays are used instead of ``np.matrix``. If a text yields no topic
    weight at all (the weights sum to zero), the weights are left unnormalised
    so ``t_prob`` is 0.0 rather than NaN; the topic index is 0 in that case, as
    before.
    """
    ref_topics = {}
    for utt in tqdm(ref_text):
        text = ref_text[utt]
        dev_sample = tfidf_vectorizer.transform([text])
        # One row of topic weights for the single document, flattened to 1-D.
        topic_weights = np.asarray(nmf.transform(dev_sample)).ravel()
        weight_sum = topic_weights.sum()
        # Guard the normalisation: dividing by a zero sum would give NaNs.
        doc_topic_dist = topic_weights / weight_sum if weight_sum > 0 else topic_weights
        topic_id = int(np.argmax(doc_topic_dist))
        topic_prob = float(np.max(doc_topic_dist))
        topic_words = nmf_topics[topic_id]
        topic_words_in_utt = " ".join(list(set(topic_words.split()) & set(text.strip().split())))
        ref_topics[utt] = {"t_words": topic_words, "ref_words": topic_words_in_utt,
                           "t_prob": topic_prob, "ref_tid": topic_id}
    return ref_topics
    

In [38]:
def get_probs(call_text, topic_model):
    """Compute the topic weights of every call with the given topic model.

    Each text is vectorised with the notebook-level ``tfidf_vectorizer`` and
    passed through ``topic_model.transform``.

    Parameters
    ----------
    call_text : dict
        Call id -> text of that call.
    topic_model : object
        Fitted model exposing ``transform``.

    Returns
    -------
    dict
        Call id -> ``np.matrix`` wrapping the output of ``transform``.
    """
    return {
        callid: np.matrix(
            topic_model.transform(tfidf_vectorizer.transform([call_text[callid]]))
        )
        for callid in call_text
    }

In [39]:
def get_topics(probs):
    """Pick the highest-weighted topic for every call.

    Parameters
    ----------
    probs : dict
        Call id -> array-like of topic weights.

    Returns
    -------
    dict
        Call id -> ``np.argmax`` of that call's weights (index into the
        flattened array).
    """
    return {callid: np.argmax(probs[callid]) for callid in probs}

In [40]:
def create_topic_df(ref_topics, seed=None):
    """Build the per-call topic table and add two baseline columns.

    Parameters
    ----------
    ref_topics : dict
        Call id -> dict of per-call fields; must contain ``"ref_tid"`` (the
        reference topic id), e.g. the output of ``get_ref_topics``.
    seed : int, optional
        Seed for the random baseline. When ``None`` (default) the global
        ``np.random`` state is used, as before.

    Returns
    -------
    pandas.DataFrame
        One row per call with the original fields plus:

        * ``t_freq`` - the most frequent reference topic, repeated for all calls;
        * ``t_rand`` - a topic drawn uniformly from the observed range of
          reference topic ids, both ends included.
    """
    topic_df = pd.DataFrame.from_dict(ref_topics, orient='index')
    ref_tids = topic_df.ref_tid.values
    n_calls = ref_tids.shape[0]

    # Majority-class baseline.
    most_common_topic = Counter(topic_df.ref_tid).most_common(1)[0][0]
    topic_df["t_freq"] = np.full(n_calls, most_common_topic)

    # Random baseline. randint's upper bound is exclusive, so add 1: otherwise
    # the highest topic id could never be drawn, and a table in which every
    # call has the same topic would raise ValueError (low >= high).
    rng = np.random if seed is None else np.random.RandomState(seed)
    topic_df["t_rand"] = rng.randint(np.min(ref_tids), np.max(ref_tids) + 1, size=n_calls)
    return topic_df

In [41]:
def get_pred_topics(pred_text, topic_df):
    pred_probs = {}
    pred_topics = {}
    for k in pred_text:
        pred_probs[k] = get_probs(pred_text[k], nmf)
        pred_topics[k] = get_topics(pred_probs[k])
    
    for k in pred_topics:
        topic_df[k] = pd.Series(pred_topics[k])

In [42]:
def eval_topics(topic_df, pred_text):
    """Print accuracy and normalised mutual information against ``ref_tid``.

    The two baseline columns (``t_freq``, ``t_rand``) are scored first,
    followed by one column per key of ``pred_text``.

    Parameters
    ----------
    topic_df : pandas.DataFrame
        Table holding ``ref_tid``, the baseline columns and one column per
        model in ``pred_text``.
    pred_text : dict
        Only its keys are used; they name the model columns to score.
    """
    print("{0:10s} | {1:10s} | {2:10s}".format("model", "accuracy", "mutual info"))
    print("-"*40)

    row_template = "{0:10s} | {1:10.2f} | {2:10.2f}"
    for column in ["t_freq", "t_rand", *pred_text.keys()]:
        reference = topic_df["ref_tid"]
        predicted = topic_df[column]
        acc = sklearn.metrics.accuracy_score(reference, predicted)
        nmi = sklearn.metrics.normalized_mutual_info_score(reference, predicted)
        print(row_template.format(column, acc, nmi))

### Fisher dev set

In [43]:
nmf_topics = display_topics(nmf, tfidf_feature_names, no_top_words)

--------------------------------------------------------------------------------
  topic id | top 10 words
--------------------------------------------------------------------------------
         0 | city mexico york spanish cold philadelphia puerto wow speak huh
         1 | music listen dance play songs salsa song radio rap listening
         2 | religion church catholic religions god religious bible christian christians homosexual
         3 | insurance health system pay doctor medical hospital dollars expensive hundred
         4 | phone cell use computer cellular telephone cellphone phones internet driving
         5 | relationship married women break marriage together internet woman men kids
         6 | power vote parties politics countries party world president bush government
         7 | movies watch movie tv commercials kids television watching pg violence
         8 | jury guilty system justice case judge trial innocent cases lawyers
         9 | race white black interraci

In [44]:
dev_topics_dict = get_ref_topics(dev_calls_text)
dev_topics_df = create_topic_df(dev_topics_dict)
get_pred_topics(pred_dev_text, dev_topics_df)

100%|██████████| 20/20 [00:00<00:00, 216.53it/s]


In [45]:
dev_topics_df.head(20)

Unnamed: 0,t_words,ref_words,t_prob,ref_tid,t_freq,t_rand,5h,5h+asr,20h,20h+asr,50h,50h+asr,Weiss
20051009_182032_217_fsp,religion church catholic religions god religio...,church christian god christians religious bibl...,0.619727,2,0,3,0,0,2,2,2,2,2
20051009_210519_219_fsp,religion church catholic religions god religio...,church christian god christians religious reli...,0.859646,2,0,5,0,2,2,2,2,2,2
20051010_212418_225_fsp,religion church catholic religions god religio...,church god religious religions religion catholic,0.744693,2,0,4,0,2,2,2,2,2,2
20051016_180547_265_fsp,city mexico york spanish cold philadelphia pue...,wow speak york spanish,0.600671,0,0,6,0,0,0,0,0,0,0
20051016_210626_267_fsp,phone cell use computer cellular telephone cel...,phone cell internet use,0.524293,4,0,5,0,0,0,0,0,0,0
20051017_180712_270_fsp,music listen dance play songs salsa song radio...,play music radio songs listening listen,0.469579,1,0,8,0,1,1,1,1,1,1
20051017_220530_275_fsp,music listen dance play songs salsa song radio...,play music radio songs salsa dance listening l...,0.727736,1,0,5,0,1,1,1,1,1,1
20051017_234550_276_fsp,music listen dance play songs salsa song radio...,music radio songs salsa dance listening listen,0.72266,1,0,1,0,1,1,1,1,1,1
20051018_210220_279_fsp,city mexico york spanish cold philadelphia pue...,wow,0.505092,0,0,3,0,0,0,0,0,3,3
20051018_210744_280_fsp,city mexico york spanish cold philadelphia pue...,cold wow york puerto,1.0,0,0,6,0,0,0,0,0,0,0


In [46]:
eval_topics(dev_topics_df, pred_dev_text)

model      | accuracy   | mutual info
----------------------------------------
t_freq     |       0.50 |       0.00
t_rand     |       0.15 |       0.47
5h         |       0.50 |       0.00
5h+asr     |       0.80 |       0.67
20h        |       0.85 |       0.80
20h+asr    |       0.90 |       0.87
50h        |       0.90 |       0.87
50h+asr    |       0.85 |       0.81
Weiss      |       0.90 |       0.88


### Fisher dev2

In [47]:
dev2_topics_dict = get_ref_topics(dev2_calls_text)
dev2_topics_df = create_topic_df(dev2_topics_dict)
get_pred_topics(pred_dev2_text, dev2_topics_df)

100%|██████████| 20/20 [00:00<00:00, 135.18it/s]


In [48]:
dev2_topics_df.head(20)

Unnamed: 0,t_words,ref_words,t_prob,ref_tid,t_freq,t_rand,5h,5h+asr,20h,20h+asr,50h,50h+asr
20050909_210655_26_fsp,relationship married women break marriage toge...,married relationship marriage together,0.55118,5,4,5,0,0,5,5,5,5
20050910_210708_33_fsp,relationship married women break marriage toge...,woman married marriage women kids,0.526548,5,4,3,0,0,0,5,5,5
20050913_210933_49_fsp,music listen dance play songs salsa song radio...,play music songs song salsa dance listening li...,0.850023,1,4,0,0,1,1,1,1,1
20050913_211649_50_fsp,music listen dance play songs salsa song radio...,music songs dance listening listen,0.581836,1,4,0,0,1,1,1,1,1
20050915_210434_65_fsp,music listen dance play songs salsa song radio...,music radio songs song salsa dance listening l...,0.857424,1,4,2,0,1,1,1,1,1
20050916_180332_68_fsp,phone cell use computer cellular telephone cel...,phone phones use cell cellphone computer telep...,0.799664,4,4,1,0,0,4,4,4,4
20050918_180733_81_fsp,phone cell use computer cellular telephone cel...,phone phones use cell cellphone computer driving,0.665164,4,4,2,0,0,0,0,0,4
20050918_210841_82_fsp,phone cell use computer cellular telephone cel...,phone internet phones use cell cellphone telep...,0.725916,4,4,0,0,0,4,4,4,4
20050920_212030_93_fsp,city mexico york spanish cold philadelphia pue...,puerto spanish mexico speak city york wow,0.643222,0,4,1,0,0,0,0,0,0
20050921_210443_99_fsp,phone cell use computer cellular telephone cel...,computer internet use,0.341486,4,4,4,0,0,0,0,4,0


In [49]:
eval_topics(dev2_topics_df, pred_dev2_text)

model      | accuracy   | mutual info
----------------------------------------
t_freq     |       0.35 |      -0.00
t_rand     |       0.10 |       0.42
5h         |       0.25 |      -0.00
5h+asr     |       0.40 |       0.51
20h        |       0.55 |       0.52
20h+asr    |       0.75 |       0.82
50h        |       0.85 |       0.84
50h+asr    |       0.85 |       0.84
