In [1]:
import pickle
import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from tqdm import tqdm
import torch

In [2]:
cd ..

/research4/projects/topic_modeling_autoencoding/fidgit


In [3]:
#Loading vocab
# The pickle holds a word -> index mapping (it is iterated with .items() below).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
vocab_path = './data/vocab.pkl'
with open(vocab_path, 'rb') as vocab_file:  # context manager closes the handle
    vocab = pickle.load(vocab_file)
vocab_size = len(vocab)
# Reverse mapping (index -> word), used later to turn token ids back into text.
swapped_vocab_docu = dict((v,k) for k,v in vocab.items())
# NOTE(review): later cells pass `vocab_20ng` to the models, but no visible
# cell assigns that name - confirm whether it should be this `vocab`.
print("Vocab Size: %d"%(vocab_size))

Vocab Size: 21725


## Global 1: Topic Embeddings vs Word Embeddings (Un-Modified) vs Word Embeddings (Modified)

In [4]:
import torch
from sklearn.decomposition import PCA
import os
import numpy as np
import pandas as pd

In [5]:
import sys

# Drop the (up to two) extra command-line tokens that follow the program name,
# presumably so the parse_args() helpers called in later cells do not see the
# kernel's own arguments - TODO confirm.
# Slice deletion removes the same elements as two pop(1) calls but does not
# raise when the cell is re-run or argv is already short.
del sys.argv[1:3]

# Make the project packages importable. Any existing entry is removed first so
# re-running the cell keeps exactly one copy at the front of sys.path.
project_root = '/research4/projects/topic_modeling_autoencoding/fidgit'
if project_root in sys.path:
    sys.path.remove(project_root)
sys.path.insert(0, project_root)

### Loading Topic Embeddings

In [6]:
#Load topic embeddings
# The pickled object is a torch tensor (it is detached and moved to CPU below);
# it is converted to a NumPy array for the PCA/plotting cells.
with open('./myetm/results/model_0_concept_embed.pickle','rb') as etm_embed_file:
    # `with` closes the file handle that the bare open() call used to leak
    etm_topic_embs = pickle.load(etm_embed_file).detach().cpu().numpy()
print("ETM Topic Embeddings Shape: "+str(etm_topic_embs.shape))

ETM Topic Embeddings Shape: (90, 300)


In [8]:
# Load the Soft VQ-VAE topic embeddings as a NumPy array.
with open('./vqvae_soft/results/model_0_concept_embed.pickle','rb') as soft_embed_file:
    # detach() mirrors the ETM loading cell: .numpy() raises on a tensor that
    # still requires grad, and detach() is a no-op otherwise.
    soft_topic_embs = pickle.load(soft_embed_file).detach().cpu().numpy()
print("Soft VQ-VAE Topic Embeddings Shape: "+str(soft_topic_embs.shape))

Soft VQ-VAE Topic Embeddings Shape: (90, 300)


### Loading Un-modified Word Embeddings

In [9]:
# Read the word-embedding matrix stored in data/embed.pth and keep it on the
# GPU; the model cells below feed it to the networks as a CUDA tensor.
embed_path = os.path.join('./data/', 'embed.pth')
unmodified_word_embs = torch.load(embed_path).cuda()
print("Un-Modified Word Embeddings Shape: " + str(unmodified_word_embs.shape))

Un-Modified Word Embeddings Shape: torch.Size([21725, 300])


### Loading Modified Word Embeddings

In [None]:
#Hard-vqvae model
# Rebuild the hard VQ-VAE topic model, restore its checkpoint and pass the
# original word embeddings through its `vqvae` module to obtain the modified
# word embeddings.
# NOTE(review): `vocab_20ng` is not assigned in any visible cell (the
# vocabulary loaded at the top is named `vocab`) - confirm which is intended.
from vqvae_hard.model import TopicVAE as hard
from vqvae_hard.config import parse_args as hard_parse_args

hard_args = hard_parse_args()
hard_args.numbr_concepts = 20
hard_pretrained = hard(hard_args, vocab_20ng)
hard_checkpoint = torch.load('./vqvae_hard/checkpoints/model_8.pt')
hard_pretrained.load_state_dict(hard_checkpoint['model'])
hard_pretrained.cuda()
hard_pretrained.eval()
# Only the second of the four returned values is used here.
_, hard_modified_word_embs , _, _ = hard_pretrained.vqvae(unmodified_word_embs, "test")
# detach() first: .numpy() raises on a tensor that still requires grad.
hard_modified_word_embs = hard_modified_word_embs.detach().cpu().numpy()

In [None]:
#Soft-vqvae model
# Rebuild the soft VQ-VAE topic model, restore its checkpoint and pass the
# original word embeddings through its `vqvae` module to obtain the modified
# word embeddings.
# NOTE(review): `vocab_20ng` is not assigned in any visible cell (the
# vocabulary loaded at the top is named `vocab`) - confirm which is intended.
from vqvae_soft.model import TopicVAE as soft
from vqvae_soft.config import parse_args as soft_parse_args

soft_args = soft_parse_args()
soft_args.numbr_concepts = 20
soft_pretrained = soft(soft_args, vocab_20ng)
soft_checkpoint = torch.load('./vqvae_soft/checkpoints/model_8.pt')
soft_pretrained.load_state_dict(soft_checkpoint['model'])
soft_pretrained.cuda()
soft_pretrained.eval()
# Only the third of the five returned values is used here.
_, _, soft_modified_word_embs , _, _ = soft_pretrained.vqvae(unmodified_word_embs, "test")
# detach() first: .numpy() raises on a tensor that still requires grad.
soft_modified_word_embs = soft_modified_word_embs.detach().cpu().numpy()

In [None]:
#Multi-vqvae model views
# Rebuild the multi-view VQ-VAE topic model, restore its checkpoint, compute
# the modified word embeddings, then project the topic embeddings through the
# model's feed-forward layers.
# NOTE(review): `vocab_20ng` is not assigned in any visible cell (the
# vocabulary loaded at the top is named `vocab`) - confirm which is intended.
from vqvae_multi.model import TopicVAE as multi
from vqvae_multi.config import parse_args as multi_parse_args

multi_args = multi_parse_args()
multi_args.numbr_concepts = 20
heads = 8
multi_pretrained = multi(multi_args, vocab_20ng)
multi_checkpoint = torch.load('./vqvae_multi/checkpoints/model_20.pt')
multi_pretrained.load_state_dict(multi_checkpoint['model'])
multi_pretrained.cuda()
multi_pretrained.eval()
_, multi_modified_word_embs , _, _ = multi_pretrained.vqvae(unmodified_word_embs, "test")
# detach() first: .numpy() raises on a tensor that still requires grad.
multi_modified_word_embs = multi_modified_word_embs.detach().cpu().numpy()

# NOTE(review): `multi_topic_embs` is read here before any visible cell
# assigns it (expected to be a NumPy array, given from_numpy) - the cell that
# loads it appears to be missing.
# The topic embeddings are tiled `heads` times along the feature axis and then
# passed through ff_3 / ff_2 / ff_1 (each followed by tanh) depending on `heads`;
# for any other value of `heads` no layer is applied.
multi_topic_embs = torch.from_numpy(multi_topic_embs).cuda().repeat(1, heads)
if heads == 2:
    multi_topic_embs = torch.tanh(multi_pretrained.ff_1(multi_topic_embs))
elif heads == 4:
    multi_topic_embs = torch.tanh(multi_pretrained.ff_2(multi_topic_embs))
    multi_topic_embs = torch.tanh(multi_pretrained.ff_1(multi_topic_embs))
elif heads == 8:
    multi_topic_embs = torch.tanh(multi_pretrained.ff_3(multi_topic_embs))
    multi_topic_embs = torch.tanh(multi_pretrained.ff_2(multi_topic_embs))
    multi_topic_embs = torch.tanh(multi_pretrained.ff_1(multi_topic_embs))
multi_topic_embs = multi_topic_embs.detach().cpu().numpy()

In [None]:
cd /research4/projects/topic_modeling_autoencoding/20ng/dvtm_bc_loss

In [None]:
#DVITM
# Rebuild the DVTM model from the current directory, restore its checkpoint and
# derive (a) modified word embeddings and (b) topic embeddings from it.
# NOTE(review): this cell overwrites `soft_modified_word_embs` and
# `soft_topic_embs`, which earlier cells assigned from the Soft VQ-VAE.
# NOTE(review): `vocab_20ng` is not assigned in any visible cell - confirm.
from model import DVTM as model_dvtm
from config import parse_args as parse_args

args = parse_args()
args.numbr_concepts = 20
dvtm_pretrained = model_dvtm(args, vocab_20ng)
dvtm_checkpoint = torch.load('./msft_checkpoints/checkpoints_numbr_concepts20_latent_dim16_kld_wt5e-05_bcbf4096.0.pt')
dvtm_pretrained.load_state_dict(dvtm_checkpoint['model'])
dvtm_pretrained.eval()
dvtm_pretrained = dvtm_pretrained.cuda()
# Encode the original word embeddings; only the third output is used below.
thetas, log_thetas, logits_thetas = dvtm_pretrained.encoder(unmodified_word_embs, "test")
# gumbel_softmax samples random noise, so the result changes between runs
# unless torch is seeded. hard=False keeps the soft (non one-hot) samples.
scores = torch.nn.functional.gumbel_softmax(logits_thetas, tau=dvtm_pretrained.temp, hard=False, eps=dvtm_pretrained.eps)
# Mix the concept embeddings with the sampled scores, flatten to
# (latent_dim * in_dim) features per row and apply the first decoder layer.
soft_modified_word_embs = torch.matmul(scores, dvtm_pretrained.emb_concept.weight)
soft_modified_word_embs = soft_modified_word_embs.view(-1, dvtm_pretrained.latent_dim * dvtm_pretrained.in_dim)
soft_modified_word_embs = torch.relu(dvtm_pretrained.dec_0(soft_modified_word_embs))
# soft_modified_word_embs = dvtm_pretrained.dec_1(soft_modified_word_embs)

# soft_modified_word_embs = soft_modified_word_embs.mean(dim=1)

soft_modified_word_embs = soft_modified_word_embs.detach().cpu().numpy()
# Project to 300 components, presumably to match the width of the original
# word embeddings so distances can be taken later - TODO confirm.
pca_dvtm = PCA(n_components=300)
soft_modified_word_embs = pca_dvtm.fit_transform(soft_modified_word_embs)
# Scale every row by its largest component.
# NOTE(review): a row whose maximum is zero or negative would blow up or flip
# sign here - confirm this cannot happen for the PCA output.
soft_modified_word_embs = soft_modified_word_embs/np.max(soft_modified_word_embs, axis=1, keepdims=True)
# print(soft_modified_word_embs)
# The concept-embedding weights are used as the topic embeddings.
soft_topic_embs = dvtm_pretrained.emb_concept.weight.detach().cpu().numpy()
print(soft_modified_word_embs.shape)
print(soft_topic_embs.shape)

# args_with_lts = parse_args_with_lts()
# args_with_lts.numbr_concepts = 20
# dvtm_with_lts_pretrained = model_dvtm_with_lts(args_with_lts, vocab_20ng)
# dvtm_with_lts_checkpoint = torch.load('./checkpoints/%s%d.pt'%('model_', expno))
# dvtm_with_lts_pretrained.load_state_dict(dvtm_with_lts_checkpoint['model'])
# dvtm_with_lts_pretrained.eval()
# dvtm_with_lts_pretrained = dvtm_with_lts_pretrained.cuda()
# thetas, log_thetas = dvtm_with_lts_pretrained.encoder(unmodified_word_embs, "test")
# scores, _ = dvtm_with_lts_pretrained.rt(thetas, log_thetas, len(unmodified_word_embs), "test")
# dvtm_with_lts_modified_word_embs = torch.matmul(scores, soft_topic_embs)
# soft_modified_word_embs = dvtm_with_lts_modified_word_embs.detach().cpu().numpy()
# soft_topic_embs = soft_topic_embs.cpu()
# print(soft_modified_word_embs.shape)

In [None]:
# Convert the original word embeddings to a NumPy array for the distance and
# PCA cells below. The isinstance guard makes the cell safe to re-run: without
# it a second run fails because an ndarray has no .cpu().
if isinstance(unmodified_word_embs, torch.Tensor):
    unmodified_word_embs = unmodified_word_embs.detach().cpu().numpy()

### Top moved words

In [None]:
#by calculating distance between Modified and Unmodifed word embeddings
# Rank every vocabulary word by the L2 distance between its modified and its
# original embedding and keep the `numbr_top_words` that moved the most.
numbr_top_words = 1000

# Keep the records as a list of (word, idx, dist) tuples. Wrapping them in
# np.array coerces every field - including the distance - to a string, and
# sort_values would then order the distances lexicographically ('9.1' > '12.3').
# if word in ['jet','mile','km','train','road','honda','bmw','motor','bus','rear','engine','wheel','brake']:
data = [
    (word, idx, float(np.linalg.norm(soft_modified_word_embs[idx]-unmodified_word_embs[idx], ord=2)))
    for word, idx in vocab_20ng.items()
]

df = pd.DataFrame(data, columns=['word','idx','dist'])
df = df.sort_values(by='dist', ascending=False)
top_moved_words = df.head(numbr_top_words)
print(top_moved_words)

In [None]:
# For each top moved word, collect its modified and its original embedding
# (same order as top_moved_words) and remember the positions of the words that
# should be drawn in the PCA figure.

# Words highlighted in the current run (dvitm with lts).
# Word lists used for other runs:
#   soft : 'jet','mile','km','train','road','honda','bmw','motor','bus','rear','engine','wheel','brake'
#          'mit','cambridge','library','department','facility','professor','scientist','foundation','period','nasa','corporation'
#   hard : 'university','group','institute','information','computer','year','science','research','college','education','program','professor','school','engineering'
#          (israel, people, war, israeli, jews, turkish, armenians, country, armenian, government)
#   multi: 'jew','hitler','surrender','massacre','mary','ignorance','atheism'
#          'mit','engineering','holy','ab','biblical','prophet','nec','bible','wiretap','simm'
#          'palestinian','troops','security','foreign','government'
#          'unix','server','os','macintosh','dos','windows','ftp','mac','software','amiga',
#          'use','drive','system','card','run','disk','problem','image'
#   dvitm without lts: 'lebanese','resistance','israeli','buffer','zone','lebanon','israel','palestinian','civilian','bomb',
#          'people','war','jews','turkish','armenians','country','armenian','government'
highlighted_words = ('hockey', 'canada', 'state', 'cover', 'southern', 'board', 'engineering', 'canadian', 'helmet',
                     'team', 'division', 'san', 'nhl', 'toronto', 'york')

soft_top_modified_word_embs = []
soft_top_unmodified_word_embs = []
etm_top_unmodified_word_embs = []
considered_lst_words = []

for position, (wrd, idx) in enumerate(zip(top_moved_words['word'], top_moved_words['idx'])):
    row = int(idx)
    soft_top_modified_word_embs.append(soft_modified_word_embs[row])
    soft_top_unmodified_word_embs.append(unmodified_word_embs[row])
    # the ETM list holds the same original embeddings as the soft list
    etm_top_unmodified_word_embs.append(unmodified_word_embs[row])
    if wrd in highlighted_words:
        considered_lst_words.append(position)
count = len(soft_top_modified_word_embs)  # number of words processed

soft_top_modified_word_embs = np.array(soft_top_modified_word_embs)
soft_top_unmodified_word_embs = np.array(soft_top_unmodified_word_embs)
etm_top_unmodified_word_embs = np.array(etm_top_unmodified_word_embs)

print("Soft Top Modified Word embeddings shape: " + str(soft_top_modified_word_embs.shape))
print("Soft Top Unmodified Word embeddings shape: " + str(soft_top_unmodified_word_embs.shape))
print("ETM Top Unmodified Word embeddings shape: " + str(etm_top_unmodified_word_embs.shape))

# Row order of the PCA input: soft topics, ETM topics, modified words, original words.
pca_data = np.concatenate([
    soft_topic_embs,
    etm_topic_embs,
    soft_top_modified_word_embs,
    soft_top_unmodified_word_embs,
])


In [None]:
# Project the stacked topic and word embeddings to 2-D for plotting.
# For an input of this size the default 'auto' solver can fall back to the
# randomized SVD, so random_state is fixed to make the figure reproducible.
pca = PCA(n_components=2, random_state=0)

pca_result = pca.fit_transform(pca_data)

# First and second principal component of every row of pca_data.
x_pca  = pca_result[:,0]
y_pca = pca_result[:,1]

In [None]:
# Positions (within top_moved_words) of the words selected for plotting.
considered_lst_words

In [None]:
# Figure: one highlighted topic plus, for every selected word, its modified and
# its original embedding in PCA space, joined by a dotted line.
fig_pca = make_subplots(rows=1, cols=1, 
                    shared_yaxes= True, 
                    shared_xaxes= True)

num_topics = soft_topic_embs.shape[0]

# Row layout of x_pca / y_pca (same as pca_data):
#   [soft topics | ETM topics | modified words | original words]
# The offsets are taken from the actual block sizes; `2*num_topics` and
# `numbr_top_words` are only right when both topic sets have the same number
# of rows and the vocabulary holds at least numbr_top_words words.
modified_offset = soft_topic_embs.shape[0] + etm_topic_embs.shape[0]
unmodified_offset = modified_offset + len(top_moved_words)
# Hover labels, built once instead of once per plotted word.
word_labels = list(top_moved_words['word'])

# Only topic 13 is drawn. To draw every topic use x_pca[:num_topics],
# y_pca[:num_topics] and text=np.arange(num_topics).
soft_topic_trace = go.Scatter(x=[x_pca[13]],
                            y=[y_pca[13]],
                            mode='markers+text',
                            name='Topics',
                            legendgroup = 'a',
                            textposition="bottom center",
                            text = ['13'],
                            marker=dict(color='red'))

# To draw every top word instead of the selection, plot the slices
# x_pca[modified_offset:unmodified_offset] and x_pca[unmodified_offset:]
# (same for y_pca) with text=word_labels.
shapes = []
for idx in considered_lst_words:
    modified_row = modified_offset + idx
    unmodified_row = unmodified_offset + idx

    soft_mod_trace = go.Scatter(x=[x_pca[modified_row]],
                                y=[y_pca[modified_row]],
                                mode='markers',
                                name='Modified Words',
                                #textposition="bottom center",
                                text=word_labels[idx],
                                legendgroup = 'b',
                                marker=dict(color='green'))

    unmod_trace = go.Scatter(x=[x_pca[unmodified_row]],
                                y=[y_pca[unmodified_row]],
                                mode='markers',
                                name='Un-Modified Words',
                                #textposition="bottom center",
                                text=word_labels[idx],
                                legendgroup = 'c',
                                marker=dict(color='lightgreen'))

    fig_pca.append_trace(unmod_trace,1,1)
    fig_pca.append_trace(soft_mod_trace,1,1)

    # Dotted connector from the modified to the original position of the word.
    shapes.append(go.layout.Shape(
                type="line",
            x0=x_pca[modified_row],
            y0=y_pca[modified_row],
            x1=x_pca[unmodified_row],
            y1=y_pca[unmodified_row],
            line=dict(
                color="grey",
                width=0.5,
                dash="dot",
            )
        ))

fig_pca.append_trace(soft_topic_trace,1,1)
# shapes = []
# for x1,x0,y1,y0 in zip(x_pca[2*num_topics:2*num_topics+numbr_top_words], \
#                     x_pca[2*num_topics+numbr_top_words:], \
#                     y_pca[2*num_topics:2*num_topics+numbr_top_words], \
#                     y_pca[2*num_topics+numbr_top_words:]):
#     shapes.append(go.layout.Shape(
#                 type="line",
#             x0=x0,
#             y0=y0,
#             x1=x1,
#             y1=y1,
#             line=dict(
#                 color="grey",
#                 width=0.5,
#                 dash="dot",
#             )
#         ))



In [None]:
# Add a single ETM topic (number 6) to the figure. ETM topics follow the
# num_topics soft topics in x_pca / y_pca, hence the num_topics offset.
# To draw all ETM topics use x_pca[num_topics:2*num_topics],
# y_pca[num_topics:2*num_topics] and text=np.arange(num_topics).
etm_row = num_topics + 6
etm_trace_options = dict(
    mode='markers+text',
    name='Topics',
    legendgroup='a',
    textposition="bottom center",
    text=['6'],
    marker=dict(color='purple'),
)
etm_topic_trace = go.Scatter(x=[x_pca[etm_row]], y=[y_pca[etm_row]], **etm_trace_options)

#fig_pca.append_trace(unmod_trace,1,1)
fig_pca.append_trace(etm_topic_trace, 1, 1)

In [None]:
# Attach the connector lines, hide the legend, set title and size, then render.
fig_pca.update_layout(
    shapes=shapes,
    showlegend=False,
    title="Topic Embeddings and Word Embeddings (Modified and Un-modified) for Soft VQ-VAE and ETM",
    width=1000,
    height=600,
)
fig_pca.show()
# Specific examples worth showing:
#   1. word "bmw":        Soft VQ-VAE topic 9,  ETM topic 8
#   2. word "university": Soft VQ-VAE topic 15, ETM topic 12

## Global 2: Beta matrix

In [None]:
from tqdm import tqdm_notebook

In [None]:
#ProdLDA
# Topic-word matrix of ProdLDA: row-wise softmax of the pickled tensor.
expno = 8
prodlda_path = '/research4/projects/topic_modeling_autoencoding/20ng/myprodlda/results/model_%d_concept_embed.pickle'%(expno)
with open(prodlda_path,'rb') as prodlda_file:  # closes the handle after loading
    # detach() so the later .numpy() conversion cannot fail on a grad tensor
    prodlda_beta = pickle.load(prodlda_file).detach().cpu()
prodlda_beta = torch.softmax(prodlda_beta, dim=1)
print("ProdLDA Beta Shape: %s"%(str(prodlda_beta.shape)))

#ETM
# Topic-word matrix of ETM: softmax over words of (topic embedding . word embedding).
# `unmodified_word_embs` has already been converted to a NumPy array by an
# earlier cell, and ndarrays have no .t(); torch.as_tensor accepts both an
# ndarray and a tensor, so the cell works in either state.
word_embs_gpu = torch.as_tensor(unmodified_word_embs).cuda()
etm_beta = torch.matmul(torch.from_numpy(etm_topic_embs).cuda(), word_embs_gpu.t())
etm_beta = torch.softmax(etm_beta, dim=1)
print("ETM Beta Shape: %s"%(str(etm_beta.shape)))

In [None]:
cd /research4/projects/topic_modeling_autoencoding/20ng/dvtm

In [None]:
#DVITM without lts loss
# Rebuild the model, restore its weights and decode one "pure" assignment per
# concept into a row of the topic-word matrix.
# NOTE(review): `model` and `config` are imported by bare name in several
# cells after a `cd`. Python caches imported modules in sys.modules, so a
# repeated `from model import ...` returns the module loaded first, whatever
# the current directory is - confirm the intended model.py/config.py is used.
# NOTE(review): `vocab_20ng` is not assigned in any visible cell - confirm.
from model import DVTM as model_dvtm_without_lts
from config import parse_args as parse_args_without_lts

args_without_lts = parse_args_without_lts()
args_without_lts.numbr_concepts = 20
dvtm_without_lts_pretrained = model_dvtm_without_lts(args_without_lts, vocab_20ng)
dvtm_without_lts_checkpoint = torch.load('./msft_checkpoints/checkpoints_numbr_concepts20_latent_dim16_kld_wt5e-05.pt')
dvtm_without_lts_pretrained.load_state_dict(dvtm_without_lts_checkpoint['model'])
dvtm_without_lts_pretrained = dvtm_without_lts_pretrained.cuda()
dvtm_without_lts_pretrained.eval()

# tmp_scores[c, :, c] = 1: for concept c every latent position selects concept c.
score_shape = (
    args_without_lts.numbr_concepts,
    dvtm_without_lts_pretrained.latent_dim,
    args_without_lts.numbr_concepts,
)
tmp_scores = torch.zeros(score_shape).cuda()
for concept in range(args_without_lts.numbr_concepts):
    tmp_scores[concept, :, concept] = 1

# Concept embeddings -> flattened latent -> decoder layers -> softmax over words.
flat_width = dvtm_without_lts_pretrained.latent_dim * dvtm_without_lts_pretrained.in_dim
out = torch.matmul(tmp_scores, dvtm_without_lts_pretrained.emb_concept.weight).view(-1, flat_width)
out = dvtm_without_lts_pretrained.dec_1(torch.relu(dvtm_without_lts_pretrained.dec_0(out)))
dvtm_without_lts_beta = torch.softmax(out, dim=1)
print("DVTM without lts beta Shape: " + str(dvtm_without_lts_beta.shape))

In [None]:
cd /research4/projects/topic_modeling_autoencoding/20ng/dvtm_bc_loss/

In [None]:
#DVITM with lts loss
# Rebuild the model, restore its weights and decode one "pure" assignment per
# concept into a row of the topic-word matrix.
# NOTE(review): `model` and `config` are imported by bare name in several
# cells after a `cd`. Python caches imported modules in sys.modules, so a
# repeated `from model import ...` returns the module loaded first, whatever
# the current directory is - confirm the intended model.py/config.py is used.
# NOTE(review): `vocab_20ng` is not assigned in any visible cell - confirm.
from model import DVTM as model_dvtm_with_lts
from config import parse_args as parse_args_with_lts

# Earlier experiments loaded pickled topic embeddings instead:
# expno = 12
# soft_topic_embs = pickle.load(open('/research4/projects/topic_modeling_autoencoding/20ng/dvtm_bc_loss/results/model_%d_concept_embed.pickle'%(expno),'rb'))
# expno = 30
# soft_topic_embs = pickle.load(open('/research4/projects/topic_modeling_autoencoding/20ng/dvtm/results_prev/model_%d_concept_embed.pickle'%(expno),'rb'))
args_with_lts = parse_args_with_lts()
args_with_lts.numbr_concepts = 20
dvtm_with_lts_pretrained = model_dvtm_with_lts(args_with_lts, vocab_20ng)
dvtm_with_lts_checkpoint = torch.load('./msft_checkpoints/checkpoints_numbr_concepts20_latent_dim16_kld_wt5e-05_bcbf4096.0.pt')
dvtm_with_lts_pretrained.load_state_dict(dvtm_with_lts_checkpoint['model'])
dvtm_with_lts_pretrained = dvtm_with_lts_pretrained.cuda()
dvtm_with_lts_pretrained.eval()

# tmp_scores[c, :, c] = 1: for concept c every latent position selects concept c.
score_shape = (
    args_with_lts.numbr_concepts,
    dvtm_with_lts_pretrained.latent_dim,
    args_with_lts.numbr_concepts,
)
tmp_scores = torch.zeros(score_shape).cuda()
for concept in range(args_with_lts.numbr_concepts):
    tmp_scores[concept, :, concept] = 1

# Concept embeddings -> flattened latent -> decoder layers -> softmax over words.
flat_width = dvtm_with_lts_pretrained.latent_dim * dvtm_with_lts_pretrained.in_dim
out = torch.matmul(tmp_scores, dvtm_with_lts_pretrained.emb_concept.weight).view(-1, flat_width)
out = dvtm_with_lts_pretrained.dec_1(torch.relu(dvtm_with_lts_pretrained.dec_0(out)))
dvtm_with_lts_beta = torch.softmax(out, dim=1)

print("DVTM with lts beta Shape: " + str(dvtm_with_lts_beta.shape))

In [None]:
# Topic-word matrices of the VQ-VAE variants: softmax over words of
# (topic embedding . word embedding). All inputs are NumPy arrays at this point.
# NOTE(review): `hard_topic_embs` is not assigned in any visible cell - the
# cell that loads it appears to be missing.
word_embs_t = torch.from_numpy(unmodified_word_embs).t()

#Soft-VQVAE
soft_logits = torch.matmul(torch.from_numpy(soft_topic_embs), word_embs_t)
soft_beta = torch.softmax(soft_logits, dim=1)

#Hard-VQVAE
hard_logits = torch.matmul(torch.from_numpy(hard_topic_embs), word_embs_t)
hard_beta = torch.softmax(hard_logits, dim=1)

#Multi-VQVAE
multi_logits = torch.matmul(torch.from_numpy(multi_topic_embs), word_embs_t)
multi_beta = torch.softmax(multi_logits, dim=1)

In [None]:
cd /research4/projects/topic_modeling_autoencoding/20ng

In [None]:
#LDA
# Parse the topic-word weight file: every line is "<topic>\t<word>\t<weight>".
lda_data = []
with open('../mymallet/20ng/20ng_output/train_tww_20.txt','r') as lda_beta_file:
    for raw_line in lda_beta_file:
        fields = raw_line.split('\t')
        # fields: topic id, word, unnormalized weight (newline stripped)
        record = (int(fields[0]), fields[1], float(fields[2].replace('\n','')))
        lda_data.append(record)

lda_df = pd.DataFrame(lda_data, columns=['topic','word','uwt'])

In [None]:
# Build the LDA vocabulary: each distinct word maps to its rank in sorted order.
lda_vocab_lst = sorted(set(lda_df['word']))
# One index per row of lda_df; zip() only consumes as many as there are words.
lda_vocab_idx = np.arange(len(lda_df['word']))
lda_vocab = dict(zip(lda_vocab_lst, lda_vocab_idx))

In [None]:
# One row of weights per topic, with the words of each topic in alphabetical
# order, then normalised so that every row sums to 1.
# NOTE(review): stacking into a 2-D array assumes every topic lists the same
# set of words - confirm against the input file.
lda_beta = []
for topic_idx in np.unique(lda_df['topic']):
    topic_data = lda_df[lda_df['topic'] == topic_idx].sort_values(by='word')
    lda_beta.append(topic_data['uwt'].to_numpy())

# Distinct words of the last topic processed (np.unique returns them sorted).
lda_words = np.unique(list(topic_data['word']))
lda_beta = np.array(lda_beta)
lda_beta = lda_beta / lda_beta.sum(axis=1, keepdims=True)

In [None]:
def idx2word(topk_indices, vocab):
    """Translate rows of word indices into rows of words.

    Args:
        topk_indices: iterable of rows, each row an iterable of word indices
            (ints, NumPy integers or 0-dim tensors - anything `int()` accepts).
        vocab: mapping word -> index.

    Returns:
        np.ndarray built from one array of words per input row.

    Raises:
        ValueError: if an index does not occur among the values of `vocab`.
    """
    # Invert the vocabulary once instead of rebuilding two vocabulary-sized
    # lists and scanning them for every single word. setdefault keeps the first
    # word for a repeated index, which is what list.index() returned before.
    index_to_word = {}
    for word, word_idx in vocab.items():
        index_to_word.setdefault(int(word_idx), word)

    arr_words = []
    for idx in tqdm_notebook(topk_indices):
        try:
            words = np.array([index_to_word[int(word_idx)] for word_idx in idx])
        except KeyError as missing:
            # keep the exception type raised by the former list.index() lookup
            raise ValueError("%s is not in vocab" % missing) from None
        arr_words.append(words)

    return np.array(arr_words)

In [None]:
# Ten largest entries (values and column indices) in every row of each
# topic-word matrix; the LDA matrix is a NumPy array and is moved to the GPU first.
lda_topk_values, lda_topk_indices = torch.from_numpy(lda_beta).cuda().topk(10, dim=1)
prodlda_topk_values, prodlda_topk_indices = prodlda_beta.topk(10, dim=1)
etm_topk_values, etm_topk_indices = etm_beta.topk(10, dim=1)

In [None]:
# Ten largest entries per row for the two DVITM topic-word matrices.
dvtm_without_lts_topk_values, dvtm_without_lts_topk_indices = dvtm_without_lts_beta.topk(10, dim=1)
dvtm_with_lts_topk_values, dvtm_with_lts_topk_indices = dvtm_with_lts_beta.topk(10, dim=1)

In [None]:
# Ten largest entries per row for the three VQ-VAE topic-word matrices.
soft_topk_values, soft_topk_indices = soft_beta.topk(10, dim=1)
hard_topk_values, hard_topk_indices = hard_beta.topk(10, dim=1)
multi_topk_values, multi_topk_indices = multi_beta.topk(10, dim=1)

In [None]:
# Turn the top-10 indices into words. LDA uses its own vocabulary; the other
# models use vocab_20ng. This rebinds `lda_words` from the earlier LDA cell.
lda_words = idx2word(topk_indices=lda_topk_indices, vocab=lda_vocab)
prodlda_words = idx2word(topk_indices=prodlda_topk_indices, vocab=vocab_20ng)
etm_words = idx2word(topk_indices=etm_topk_indices, vocab=vocab_20ng)

In [None]:
# Top-10 words per topic for the two DVITM variants.
dvtm_without_lts_words = idx2word(topk_indices=dvtm_without_lts_topk_indices, vocab=vocab_20ng)
dvtm_with_lts_words = idx2word(topk_indices=dvtm_with_lts_topk_indices, vocab=vocab_20ng)

In [None]:
# Top-10 words per topic for the three VQ-VAE variants.
soft_words = idx2word(topk_indices=soft_topk_indices, vocab=vocab_20ng)
hard_words = idx2word(topk_indices=hard_topk_indices, vocab=vocab_20ng)
multi_words = idx2word(topk_indices=multi_topk_indices, vocab=vocab_20ng)

In [None]:
# Sorted union of the top words of LDA, ETM, ProdLDA and both DVITM variants.
# (A previous variant united lda, etm, soft, hard and multi words instead.)
all_words = lda_words
for model_words in (etm_words, prodlda_words, dvtm_without_lts_words, dvtm_with_lts_words):
    all_words = np.union1d(all_words, model_words)
print(all_words)

In [None]:
# Restrict every topic-word matrix to the columns of the words that occur in
# both the LDA vocabulary and vocab_20ng, so all heatmaps share one x-axis.
final_word_list = [word for word in all_words
                   if word in lda_vocab and word in vocab_20ng]

# Column position of each kept word in the two vocabularies.
lda_columns = [int(lda_vocab[word]) for word in final_word_list]
model_columns = [int(vocab_20ng[word]) for word in final_word_list]

# lda_beta is already a NumPy array, so it is indexed directly; the previous
# code copied the whole matrix to the GPU once per word.
lda_sub_beta = lda_beta[:, lda_columns]
# detach() on every tensor keeps the .numpy() conversion safe for tensors that
# still carry gradients (it is a no-op for the others).
prodlda_sub_beta = prodlda_beta[:, model_columns].detach().cpu().numpy()
etm_sub_beta = etm_beta[:, model_columns].detach().cpu().numpy()
dvtm_without_lts_sub_beta = dvtm_without_lts_beta[:, model_columns].detach().cpu().numpy()
dvtm_with_lts_sub_beta = dvtm_with_lts_beta[:, model_columns].detach().cpu().numpy()

# The VQ-VAE matrices can be sliced the same way when needed:
# soft_sub_beta = soft_beta[:, model_columns].cpu().numpy()
# hard_sub_beta = hard_beta[:, model_columns].cpu().numpy()
# multi_sub_beta = multi_beta[:, model_columns].cpu().numpy()

In [None]:
# Number of words kept for the heatmaps.
print(len(final_word_list))

In [None]:
# One heatmap per model: topics on the y-axis, the shared word list on the x-axis.
num_topics = 20
fig_beta = make_subplots(
    rows=5,
    cols=1,
    shared_yaxes=True,
    subplot_titles=("CGS-LDA", "ProdLDA","ETM", "DVITM w/o lts loss", "DVITM with lts loss"),
)


def _beta_heatmap(sub_beta):
    """Return the heatmap trace for one model's reduced topic-word matrix."""
    return go.Heatmap(
        z=sub_beta,
        x=final_word_list,
        y=np.arange(num_topics),
        showscale=False,
        colorscale='algae',
    )


lda_trace = _beta_heatmap(lda_sub_beta)
prodlda_trace = _beta_heatmap(prodlda_sub_beta)
etm_trace = _beta_heatmap(etm_sub_beta)
dvtm_without_lts_trace = _beta_heatmap(dvtm_without_lts_sub_beta)
dvtm_with_lts_trace = _beta_heatmap(dvtm_with_lts_sub_beta)

# Traces for the VQ-VAE variants (hard_sub_beta, soft_sub_beta, multi_sub_beta)
# can be built with the same settings when those matrices are computed.


In [None]:
# Stack the five heatmaps in rows 1..5 of the single column, then render.
beta_traces = (lda_trace, prodlda_trace, etm_trace, dvtm_without_lts_trace, dvtm_with_lts_trace)
for row, trace in enumerate(beta_traces, start=1):
    fig_beta.append_trace(trace, row, 1)
# The VQ-VAE traces (hard, soft, multi) previously occupied rows 3-5.

fig_beta.update_layout(width=1000, height=1500)
# Topic 0 at the top; both axes are treated as categories.
fig_beta.update_yaxes(nticks=num_topics, autorange="reversed", type='category')
fig_beta.update_xaxes(tickangle=45, type='category')
fig_beta.show()

## Local 1: Document-specific topic distribution

In [None]:
cd /research4/projects/topic_modeling_autoencoding/20ng

In [None]:
#load test dataset
# The loaded object is iterated document by document in later cells, and each
# document exposes .numpy() / .cuda(), i.e. it is a sequence of tensors.
test_file = '/research4/projects/topic_modeling_autoencoding/20ng/data/test.pth'
test_dataset = torch.load(test_file)
# Number of test documents.
print(len(test_dataset))

In [None]:
# Raw text of the 20 Newsgroups test split, used below to look up documents by
# the words they contain.
from sklearn.datasets import fetch_20newsgroups
newsgroups_test = fetch_20newsgroups(subset='test')

In [None]:
# List the test documents whose lower-cased, whitespace-split text contains all
# of the required words; document 6937 is additionally printed in full.
# An earlier search used: 'boston', 'puck', 'college', 'baseball'.
print(len(newsgroups_test.data))
required_words = ('david', 'science', 'department', 'research')
for idx, doc in enumerate(newsgroups_test.data):
    doc_tokens = doc.lower().split()
    if not all(required in doc_tokens for required in required_words):
        continue
    print(idx)
    if idx == 6937:
        print(doc)
    print('*'*100)

In [None]:
#LDA
# Parse the per-document topic proportions. The first line of the file is
# skipped; on every other line the second tab-separated field is a file path
# whose basename is "<doc index>.txt" and the remaining fields are the
# topic proportions.
lda_doc_topics_data = []
with open('/research4/projects/topic_modeling_autoencoding/mymallet/20ng/20ng_output/test_doc_topics_20.txt','r') as lda_doc_topics_file:
    next(lda_doc_topics_file, None)
    for raw_line in lda_doc_topics_file:
        fields = raw_line.split('\t')
        doc_name = fields[1].rsplit('/', 1)[-1]
        filename_idx = int(doc_name.replace('.txt',''))
        topic_proportion = np.array([float(value) for value in fields[2:]])
        lda_doc_topics_data.append((filename_idx, topic_proportion))

# Order the documents by index.
lda_doc_topic_df = (
    pd.DataFrame(lda_doc_topics_data, columns=['doc_idx','topic_proportion'])
    .sort_values(by='doc_idx')
    .reset_index(drop=True)
)
# NOTE(review): the last document is dropped by [:-1]; the reason is not
# visible here - confirm it is intentional.
lda_theta = np.array(list(lda_doc_topic_df['topic_proportion'])[:-1])
print("LDA Test Dataset Topic Proportion Shape: " + str(lda_theta.shape))

In [None]:
#prodlda model
# Rebuild ProdLDA, restore its checkpoint and run the whole test set through
# it; only the fourth output (the document-topic proportions) is kept.
# NOTE(review): `vocab_20ng` is not assigned in any visible cell - confirm.
from myprodlda.model import ProdLDA as my_prodlda
from myprodlda.config import parse_args as myprodlda_parse_args

myprodlda_args = myprodlda_parse_args()
myprodlda_args.numbr_concepts = 20

myprodlda_pretrained = my_prodlda(myprodlda_args, vocab_20ng)
myprodlda_checkpoint = torch.load('/research4/projects/topic_modeling_autoencoding/20ng/myprodlda/checkpoints/model_8.pt')
myprodlda_pretrained.load_state_dict(myprodlda_checkpoint['model'])
myprodlda_pretrained.cuda()
myprodlda_pretrained.eval()

myprodlda_outputs = myprodlda_pretrained(test_dataset, "test")
_, _, _, myprodlda_theta = myprodlda_outputs

In [None]:
# Convert the ProdLDA document-topic proportions to NumPy. The isinstance
# guard keeps the cell re-runnable: an ndarray has no .detach().
if isinstance(myprodlda_theta, torch.Tensor):
    myprodlda_theta = myprodlda_theta.detach().cpu().numpy()
print("ProdLDA Test Dataset Topic Proportion Shape: "+ str(myprodlda_theta.shape))

In [None]:
#etm model
# Rebuild ETM, restore its checkpoint and run the whole test set through it;
# the first two outputs (reconstruction loss and document-topic proportions)
# are kept.
# NOTE(review): `vocab_20ng` is not assigned in any visible cell - confirm.
from myetm.model import ETM as my_etm
from myetm.config import parse_args as myetm_parse_args

myetm_args = myetm_parse_args()
myetm_args.numbr_concepts = 20

myetm_pretrained = my_etm(myetm_args, vocab_20ng)
myetm_checkpoint = torch.load('/research4/projects/topic_modeling_autoencoding/20ng/myetm/checkpoints/myetm_0.pt')
myetm_pretrained.load_state_dict(myetm_checkpoint['model'])
myetm_pretrained.cuda()
myetm_pretrained.eval()

# _, myetm_normalized_bows, _ = myetm_pretrained.convert2bow(test_dataset, "test")
myetm_outputs = myetm_pretrained(test_dataset, "test")
myetm_reconst_loss, myetm_theta, _, _ = myetm_outputs


In [None]:
# Convert the ETM outputs to NumPy. The isinstance guards keep the cell
# re-runnable: an ndarray has no .detach().
if isinstance(myetm_theta, torch.Tensor):
    myetm_theta = myetm_theta.detach().cpu().numpy()
if isinstance(myetm_reconst_loss, torch.Tensor):
    myetm_reconst_loss = myetm_reconst_loss.detach().cpu().numpy()
print("ETM Test Dataset Topic Proportion Shape: "+ str(myetm_theta.shape))
print("ETM Test Dataset Reconstructtion Loss shape: "+ str(myetm_reconst_loss.shape))

In [None]:
#ppl from myetm
# import math
# etm_ppl = math.exp(myetm_reconst_loss[3844]/len(test_dataset[3844]))
# lda_ppl = math.exp(141.4214211499424/len(test_dataset[3844]))
# print(lda_ppl)
# print(etm_ppl)

In [None]:
#to get actual tokens
# Rebuild every test document as a space-separated string of words by mapping
# its token ids through the index -> word dictionary.
# NOTE(review): swapped_vocab_docu was built from the ./data/vocab.pkl loaded
# at the top of the notebook - confirm it matches this test set's vocabulary.
docu_tokens_lst = [
    ' '.join(swapped_vocab_docu[token_id] for token_id in docu.numpy())
    for docu in test_dataset
]
print(len(docu_tokens_lst))
print(docu_tokens_lst[3844])

In [None]:
#dvtms
import math


def _dvtm_doc_stats(model, docu):
    """Score a single document with one DVITM model.

    Calls ``model(docu, "test")`` and returns ``(ppl, theta)``:

    * ``ppl``   -- ``exp(-sum(outputs) / len(docu))`` as a Python float.
    * ``theta`` -- the returned thetas averaged over dim 1, summed over dim 0
      (``keepdim=True``) and rescaled to sum to 1; detached and moved to the CPU.
    """
    outputs, _, thetas = model(docu, "test")
    ppl = math.exp(-torch.sum(outputs)/len(docu))
    # Pooling over dim 1 uses the mean; torch.max over the same dim was tried as an alternative.
    theta = torch.sum(thetas.mean(dim=1), dim=0, keepdim=True)
    theta = theta/torch.sum(theta)
    return ppl, theta.detach().cpu()


dvtm_without_lts_theta_lst = []
dvtm_with_lts_theta_lst = []
# NOTE: the hard/soft/multi VQ-VAE cell rebinds `ppl_data` to its own (4-field) list,
# so build the DVITM perplexity DataFrame from this list before running that cell.
ppl_data = []

# Inference only: no autograd graph is kept for the per-document forward passes.
with torch.no_grad():
    for idx, docu in enumerate(tqdm(test_dataset, desc='Please wait..')):
        docu = docu.cuda()
        # ETM perplexity from the reconstruction loss computed earlier for the same document.
        myetm_ppl = math.exp(myetm_reconst_loss[idx]/len(docu))

        dvtm_without_lts_ppl, dvtm_without_lts_theta = _dvtm_doc_stats(dvtm_without_lts_pretrained, docu)
        dvtm_with_lts_ppl, dvtm_with_lts_theta = _dvtm_doc_stats(dvtm_with_lts_pretrained, docu)

        dvtm_without_lts_theta_lst.append(dvtm_without_lts_theta)
        dvtm_with_lts_theta_lst.append(dvtm_with_lts_theta)
        # Fields: ETM ppl, DVITM (no lts) ppl, DVITM (lts) ppl, lts - ETM, lts - no lts.
        ppl_data.append((myetm_ppl,
                         dvtm_without_lts_ppl,
                         dvtm_with_lts_ppl,
                         dvtm_with_lts_ppl-myetm_ppl,
                         dvtm_with_lts_ppl-dvtm_without_lts_ppl))

# Stack the per-document rows into one array per model.
dvtm_without_lts_theta = torch.cat(dvtm_without_lts_theta_lst, dim=0).numpy()
dvtm_with_lts_theta = torch.cat(dvtm_with_lts_theta_lst, dim=0).numpy()

print("DVITM without lts Test Dataset Topic Proportion Shape: "+ str(dvtm_without_lts_theta.shape))
print("DVITM with lts Test Dataset Topic Proportion Shape: "+ str(dvtm_with_lts_theta.shape))

In [None]:
#hard, soft, multi-vqvae
import math

soft_theta_lst = []
hard_theta_lst = []
multi_theta_lst = []
# NOTE: this rebinds `ppl_data`, which the DVITM cell also fills (with 5-field tuples).
# After this cell the name holds the 4-field VQ-VAE tuples only.
ppl_data = []

# Inference only: no autograd graph is kept for the per-document forward passes.
with torch.no_grad():
    for idx, docu in enumerate(tqdm(test_dataset, desc='Please wait..')):
        docu = docu.cuda()

        # Hard VQ-VAE: sum the returned thetas over dim 0 and rescale to sum to 1.
        hard_outputs, hard_thetas, _ = hard_pretrained(docu, "test")
        hard_ppl = math.exp(-torch.sum(hard_outputs)/len(docu))
        hard_theta = torch.sum(hard_thetas, dim=0, keepdim=True)
        hard_theta = hard_theta/torch.sum(hard_theta)
        hard_theta_lst.append(hard_theta.detach().cpu())

        # Soft VQ-VAE: the returned theta is stored as is.
        soft_outputs, soft_theta,_ = soft_pretrained(docu, "test")
        soft_ppl = math.exp(-torch.sum(soft_outputs)/len(docu))
        soft_theta_lst.append(soft_theta.detach().cpu())

        # Multi VQ-VAE: mean over dim 0, then sum over the new dim 0 and rescale to sum to 1.
        multi_outputs, multi_thetas,_ = multi_pretrained(docu, "test")
        multi_ppl = math.exp(-torch.sum(multi_outputs)/len(docu))
        multi_theta = torch.mean(multi_thetas, dim=0)
        multi_theta = torch.sum(multi_theta, dim=0, keepdim=True)
        multi_theta = multi_theta/torch.sum(multi_theta)
        multi_theta_lst.append(multi_theta.detach().cpu())

        # Fields: hard ppl, soft ppl, multi ppl, multi - hard.
        ppl_data.append((hard_ppl, soft_ppl, multi_ppl, multi_ppl-hard_ppl))

# Stack the per-document rows into one array per model.
hard_theta = torch.cat(hard_theta_lst, dim=0)
hard_theta = hard_theta.numpy()

soft_theta = torch.cat(soft_theta_lst, dim=0)
soft_theta = soft_theta.numpy()

multi_theta = torch.cat(multi_theta_lst, dim=0)
multi_theta = multi_theta.numpy()

print("Hard VQVAE Test Dataset Topic Proportion Shape: "+ str(hard_theta.shape))
print("Soft VQVAE Test Dataset Topic Proportion Shape: "+ str(soft_theta.shape))
print("Multi VQVAE Test Dataset Topic Proportion Shape: "+ str(multi_theta.shape))

In [None]:
# Per-document perplexities of the hard / soft / multi VQ-VAE models, worst
# multi-vs-hard differences last.
vqvae_ppl_columns = ['hard_ppl','soft_ppl','multi_ppl','diff_ppl_multi_hard']
ppl_df_data = np.array(ppl_data)
# `ppl_data` is also filled by the DVITM cell with 5 fields per document. Fail with an
# explicit message if the name currently holds that list instead of the VQ-VAE one.
if ppl_df_data.ndim == 2 and ppl_df_data.shape[1] != len(vqvae_ppl_columns):
    raise ValueError("ppl_data has %d fields per document, expected %d; "
                     "re-run the hard/soft/multi VQ-VAE cell before this one"
                     % (ppl_df_data.shape[1], len(vqvae_ppl_columns)))
df = pd.DataFrame(ppl_df_data, columns=vqvae_ppl_columns)
df.sort_values(by=['diff_ppl_multi_hard']).tail(200)

In [None]:
# Per-document perplexities of ETM and the two DVITM variants, plus the two differences
# stored by the DVITM loop.
dvtm_ppl_columns = ['etm','dvtm_without_lts','dvtm_with_lts','diff_ppl_dvtm_with_etm','diff_ppl_dvtm_with_dvtm_without']
ppl_df_data = np.array(ppl_data)
# `ppl_data` is rebound by the hard/soft/multi VQ-VAE cell with 4 fields per document.
# Fail with an explicit message if the name currently holds that list instead of the DVITM one.
if ppl_df_data.ndim == 2 and ppl_df_data.shape[1] != len(dvtm_ppl_columns):
    raise ValueError("ppl_data has %d fields per document, expected %d; "
                     "re-run the DVITM cell before this one"
                     % (ppl_df_data.shape[1], len(dvtm_ppl_columns)))
df = pd.DataFrame(ppl_df_data, columns=dvtm_ppl_columns)
# Lift the row limit so the whole selection below is rendered.
pd.set_option('display.max_rows', df.shape[0]+1)
# df.loc[3844]
# df.sort_values(by=['diff_ppl_dvtm_with_etm','diff_ppl_dvtm_with_dvtm_without']).head(200)

# Documents where DVITM without lts has lower perplexity than ETM and DVITM with lts
# has lower perplexity than DVITM without lts.
df.loc[(df['dvtm_without_lts'] < df['etm']) & (df['dvtm_with_lts'] < df['dvtm_without_lts'])]


In [None]:
# Keep the documents where perplexity strictly improves along the chain
# ETM -> DVITM without lts -> DVITM with lts, and remember their row labels.
df_subset = df.loc[
    lambda frame: (frame['dvtm_without_lts'] < frame['etm'])
    & (frame['dvtm_with_lts'] < frame['dvtm_without_lts'])
]
df_subset_idxs = list(df_subset.index)

In [None]:
#vqtms: final examples: 3907, 3844
#dvitms: final examples: 
count = 1  # subplot column used for every panel (the grid has a single column)

num_topics = 20
doc_idxs = [3871]


def _theta_heatmap(theta_row):
    """Build a single-row heatmap trace for one document's topic proportions.

    ``theta_row`` is one row of a theta matrix. It is given a leading axis of
    length 1 (``np.expand_dims``) and plotted against x positions
    ``0 .. num_topics-1`` with an empty y label and no colour bar.
    """
    return go.Heatmap(z = np.expand_dims(theta_row, axis=0),
                      x = np.arange(num_topics),
                      y = [''],
                      showscale=False,
                      colorscale='algae')


for doc_idx in doc_idxs:
    fig = make_subplots(rows=5,
                        cols=1,
                        subplot_titles=("CGS-LDA","ProdLDA", "ETM",'DVITM without lts loss','DVITM with lts loss'))

    print(docu_tokens_lst[doc_idx])

    # One panel per model, top to bottom; the order must match subplot_titles above.
    # To compare the VQ-VAE variants instead, swap in hard_theta / soft_theta /
    # multi_theta here and update the titles accordingly.
    panel_thetas = (lda_theta,
                    myprodlda_theta,
                    myetm_theta,
                    dvtm_without_lts_theta,
                    dvtm_with_lts_theta)
    for panel_row, panel_theta in enumerate(panel_thetas, start=1):
        # add_trace(row=, col=) is the supported replacement for the deprecated append_trace.
        fig.add_trace(_theta_heatmap(panel_theta[doc_idx]), row=panel_row, col=count)

    fig.update_layout(height=500, width=500)
    fig.update_xaxes(nticks=num_topics, type='category',showticklabels=True, tickfont=dict(size=8) )
    fig.update_yaxes(showticklabels=True, type='category')


    fig.show()

In [None]:
#LDA
# 0	0.46368	[god, christian, jesus, people, bible, church, life, religion, faith, man]
# 1	0.32561	[file, entry, program, line, output, section, set, return, error, read]
# 2	0.37114	[university, fax, phone, write, internet, center, research, computer, article, information]
# 3	0.35226	[time, back, bike, write, thing, make, dod, ride, turn, leave]
# 4	0.32547	[power, ground, wire, work, water, current, circuit, make, sound, run]
# 5	0.32408	[number, write, point, article, time, rate, bit, give, man, difference]
# 6	0.33109	[problem, article, people, food, drug, write, study, year, time, effect]
# 7	0.37060	[gun, law, people, state, government, weapon, crime, police, kill, case]
# 8	0.37749	[drive, card, disk, system, problem, driver, mac, work, memory, dos]
# 9	0.44660	[people, turkish, armenian, armenians, woman, kill, child, turkey, armenia, turks]
# 10	0.33743	[question, make, argument, exist, claim, write, people, point, evidence, thing]
# 11	0.46893	[israel, war, jews, israeli, state, write, article, people, country, world]
# 12	0.34694	[president, make, work, money, people, tax, year, job, pay, clinton]
# 13	0.59197	[game, team, play, year, player, win, season, hockey, score, league]
# 14	0.37982	[key, chip, encryption, system, government, clipper, security, privacy, phone, secure]
# 15	0.38742	[car, buy, price, sell, good, sale, offer, drive, pay, engine]
# 16	0.32969	[post, send, list, book, group, information, mail, address, include, copy]
# 17	0.41678	[window, file, image, program, run, display, version, application, server, software]
# 18	0.33229	[write, article, people, thing, make, hear, read, good, post, opinion]
# 19	0.45590	[space, launch, nasa, system, earth, satellite, year, orbit, project, mission]
#ProdLDA
# 0	0.4658	[minnesota, angeles, toronto, san, montreal, vancouver, ottawa, stanley, calgary, louis]
# 1	0.3945	[finish, play, team, hitter, offense, smith, defensive, ice, tie, hit]
# 2	0.5129	[christian, god, scripture, interpretation, jesus, resurrection, teaching, doctrine, existence, holy]
# 3	0.3626	[troops, israel, border, turks, army, israeli, fire, arab, civilian, minority]
# 4	0.5116	[ram, windows, quadra, fine, speed, scsi, faster, microsoft, apple, external]
# 5	0.3740	[entry, char, compile, db, file, section, variable, contest, distribution, remark]
# 6	0.5557	[christian, scripture, god, teaching, resurrection, existence, jesus, christianity, doctrine, biblical]
# 7	0.3890	[ibm, interface, transfer, external, virtual, cpu, ram, dec, path, default]
# 8	0.4294	[fire, car, safety, btw, andy, surrender, country, rider, cold, stupid]
# 9	0.4889	[wiretap, escrow, drug, agency, warrant, clipper, illegal, encryption, crime, country]
# 10	0.4657	[wiretap, encryption, escrow, nsa, chip, clipper, pat, cheaper, scheme, car]
# 11	0.4874	[turks, armenian, army, turkish, russian, armenia, united, mountain, armenians, director]
# 12	0.4587	[music, fine, crash, ram, sl, external, simm, honda, wm, hd]
# 13	0.4824	[bmw, car, bike, rider, baseball, btw, cop, ball, hit, motorcycle]
# 14	0.4628	[satellite, mission, distribute, nasa, spacecraft, km, space, earth, distribution, module]
# 15	0.5309	[fine, ram, windows, apple, amp, simm, external, scsi, crash, quadra]
# 16	0.4720	[agency, encryption, cryptography, telephone, wiretap, enforcement, des, privacy, distribution, encrypt]
# 17	0.3659	[neighbor, christian, heart, harm, building, scripture, jesus, daughter, woman, holy]
# 18	0.4507	[scsi, ram, hd, external, mhz, meg, scsus, ide, fine, bus]
# 19	0.5172	[sl, wm, ram, mi, external, hd, connector, mg, mb, mw]
#ETM
# 0	0.34885	[get, know, one, say, think, like, see, thing, people, time]
# 1	0.3254	[use, may, make, case, many, also, part, however, system, president]
# 2	0.45049	[god, jesus, christian, say, believe, bible, one, christ, make, belief]
# 3	0.3682	[gun, people, child, kill, drug, crime, weapon, police, case, claim]
# 4	0.52558	[datum, space, db, launch, output, hus, widget, dod, nasa, sun]
# 5	0.33899	[file, use, program, send, available, list, code, email, please, line]
# 6	0.41394	[hockey, team, new, division, san, canada, nhl, toronto, york, gm]
# 7	0.34391	[new, look, buy, price, good, sell, include, package, offer, like]
# 8	0.38514	[car, power, use, drive, speed, engine, wire, water, fast, low]
# 9	0.46506	[game, year, play, win, team, player, season, go, good, run]
# 10	0.33123	[information, make, get, please, mail, use, go, file, help, take]
# 11	0.3378	[book, first, one, time, internet, study, earth, author, history, find]
# 12	0.35159	[university, group, science, information, computer, fax, center, year, call, research]
# 13	0.36281	[thanks, david, john, appreciate, steve, mark, wonder, jim, mike, michael]
# 14	0.37497	[write, article, post, question, read, opinion, ask, please, yes, answer]
# 15	0.33965	[period, la, pt, vs, de, van, pp, cal, power, second]
# 16	0.36068	[key, government, law, use, encryption, state, chip, public, right, security]
# 17	0.33544	[go, take, back, one, day, put, get, right, also, call]
# 18	0.39872	[use, drive, system, window, card, run, windows, disk, problem, image]
# 19	0.48538	[israel, people, war, israeli, jews, turkish, armenians, country, armenian, government]

#DVITM without lts loss
# 0	0.49933	[msg, eat, heaven, koresh, love, one, food, jesus, hell, sin]
# 1	0.49305	[koresh, love, heaven, eat, msg, one, hell, sin, pray, cult]
# 2	0.45914	[msg, eat, koresh, food, heaven, compound, one, love, hell, heart]
# 3	0.45914	[msg, eat, koresh, food, heaven, compound, love, one, hell, heart]
# 4	0.46723	[msg, eat, koresh, love, heaven, one, compound, hell, food, pray]
# 5	0.44593	[msg, eat, koresh, food, compound, heaven, one, heart, fbi, love]
# 6	0.46723	[msg, eat, koresh, food, heaven, one, love, pray, hell, compound]
# 7	0.46723	[msg, eat, koresh, heaven, love, one, hell, pray, food, compound]
# 8	0.42821	[msg, eat, koresh, food, compound, heaven, one, heart, love, child]
# 9	0.45914	[msg, eat, koresh, food, one, heaven, compound, love, hell, heart]
# 10	0.42445	[mother, father, child, msg, love, eat, die, compound, koresh, heart]
# 11	0.45914	[msg, eat, koresh, food, heaven, one, love, compound, hell, heart]
# 12	0.49725	[love, koresh, eat, pray, heaven, one, hell, msg, sin, god]
# 13	0.47096	[card, rg, mr, ax, vs, eus, oo, dos, bhj, slot]
# 14	0.50001	[lebanese, resistance, israeli, buffer, zone, lebanon, israel, palestinian, civilian, bomb]
# 15	0.46723	[msg, eat, koresh, one, food, heaven, love, hell, pray, compound]
# 16	0.45914	[msg, eat, koresh, compound, food, one, heaven, love, heart, hell]
# 17	0.45914	[msg, eat, koresh, love, one, heaven, hell, compound, food, heart]
# 18	0.46723	[msg, eat, koresh, heaven, one, love, food, compound, pray, hell]
# 19	0.42821	[msg, koresh, eat, compound, one, food, heart, heaven, love, child]

#DVITM with lts loss
# 0	0.56466	[andor, btw, heat, flame, reach, hd, eliminate, tax, shipping, hus]
# 1	0.36240	[sexual, provide, sex, male, physical, animal, morality, natural, activity, normal]
# 2	0.36994	[need, form, hus, surface, side, change, ring, normal, compound, order]
# 3	0.55837	[compound, xterm, andor, plane, initial, patient, rg, colormap, batf, hus]
# 4	0.31337	[necessarily, accident, least, person, cover, society, fact, moral, understanding, eternal]
# 5	0.33684	[show, april, russia, team, tv, local, allen, movie, board, satellite]
# 6	0.33354	[must, responsibility, human, conflict, means, desire, without, action, accept, upon]
# 7	0.49377	[compound, batf, flame, fire, btw, dod, proper, andor, plane, scsus]
# 8	0.32176	[compound, strip, wide, main, ground, rule, inside, four, shop, path]
# 9	0.38172	[andor, part, nuclear, call, server, include, char, date, game, concept]
# 10	0.38304	[technology, company, network, phone, business, corporation, development, agency, sector, government]
# 11	0.50270	[dod, target, zone, al, civilian, oo, batf, rg, mi, border]
# 12	0.44349	[race, impact, compare, hitter, previous, pitcher, truck, mark, braves, far]
# 13	0.43725	[hockey, canada, state, cover, southern, board, engineering, soon, canadian, helmet]
# 14	0.54549	[colormap, compound, initial, andor, xlib, range, distribution, detail, info, xterm]
# 15	0.51924	[koresh, batf, believe, die, btw, compound, else, hell, eternal, everyone]
# 16	0.57708	[xterm, hus, colormap, dod, pl, buf, char, printf, int, routine]
# 17	0.41406	[like, much, probably, look, feeling, worse, feel, hear, mine, sound]
# 18	0.48340	[every, compound, hit, writer, dod, imho, btw, eat, injury, interpretation]
# 19	0.40255	[cell, range, colormap, sequence, distance, input, null, match, compound, clock]




In [None]:
#LDA
# 0	0.46368	[god, christian, jesus, people, bible, church, life, religion, faith, man]
# 1	0.32561	[file, entry, program, line, output, section, set, return, error, read]
# 2	0.37114	[university, fax, phone, write, internet, center, research, computer, article, information]
# 3	0.35226	[time, back, bike, write, thing, make, dod, ride, turn, leave]
# 4	0.32547	[power, ground, wire, work, water, current, circuit, make, sound, run]
# 5	0.32408	[number, write, point, article, time, rate, bit, give, man, difference]
# 6	0.33109	[problem, article, people, food, drug, write, study, year, time, effect]
# 7	0.37060	[gun, law, people, state, government, weapon, crime, police, kill, case]
# 8	0.37749	[drive, card, disk, system, problem, driver, mac, work, memory, dos]
# 9	0.44660	[people, turkish, armenian, armenians, woman, kill, child, turkey, armenia, turks]
# 10	0.33743	[question, make, argument, exist, claim, write, people, point, evidence, thing]
# 11	0.46893	[israel, war, jews, israeli, state, write, article, people, country, world]
# 12	0.34694	[president, make, work, money, people, tax, year, job, pay, clinton]
# 13	0.59197	[game, team, play, year, player, win, season, hockey, score, league]
# 14	0.37982	[key, chip, encryption, system, government, clipper, security, privacy, phone, secure]
# 15	0.38742	[car, buy, price, sell, good, sale, offer, drive, pay, engine]
# 16	0.32969	[post, send, list, book, group, information, mail, address, include, copy]
# 17	0.41678	[window, file, image, program, run, display, version, application, server, software]
# 18	0.33229	[write, article, people, thing, make, hear, read, good, post, opinion]
# 19	0.45590	[space, launch, nasa, system, earth, satellite, year, orbit, project, mission]
#ETM
# 0	0.34885	[get, know, one, say, think, like, see, thing, people, time]
# 1	0.3254	[use, may, make, case, many, also, part, however, system, president]
# 2	0.45049	[god, jesus, christian, say, believe, bible, one, christ, make, belief]
# 3	0.3682	[gun, people, child, kill, drug, crime, weapon, police, case, claim]
# 4	0.52558	[datum, space, db, launch, output, hus, widget, dod, nasa, sun]
# 5	0.33899	[file, use, program, send, available, list, code, email, please, line]
# 6	0.41394	[hockey, team, new, division, san, canada, nhl, toronto, york, gm]
# 7	0.34391	[new, look, buy, price, good, sell, include, package, offer, like]
# 8	0.38514	[car, power, use, drive, speed, engine, wire, water, fast, low]
# 9	0.46506	[game, year, play, win, team, player, season, go, good, run]
# 10	0.33123	[information, make, get, please, mail, use, go, file, help, take]
# 11	0.3378	[book, first, one, time, science, study, earth, author, history, find]
# 12	0.35159	[university, group, internet, information, computer, fax, center, year, call, research]
# 13	0.36281	[thanks, david, john, appreciate, steve, mark, wonder, jim, mike, michael]
# 14	0.37497	[write, article, post, question, read, opinion, ask, please, yes, answer]
# 15	0.33965	[period, la, pt, vs, de, van, pp, cal, power, second]
# 16	0.36068	[key, government, law, use, encryption, state, chip, public, right, security]
# 17	0.33544	[go, take, back, one, day, put, get, right, also, call]
# 18	0.39872	[use, drive, system, window, card, run, windows, disk, problem, image]
# 19	0.48538	[israel, people, war, israeli, jews, turkish, armenians, country, armenian, government]
#Hard
# 0	0.3978	[murder, law, criminal, child, death, court, police, case, woman, man]
# 1	0.3923	[article, page, journal, magazine, constitution, newspaper, editor, report, book, news]
# 2	0.3738	[really, think, know, want, something, get, thing, feel, lot, anything]
# 3	0.3290	[ball, back, right, water, inside, away, around, space, front, small]
# 4	0.6077	[god, christ, jesus, christianity, christian, religious, religion, holy, church, faith]
# 5	0.3495	[help, want, make, need, try, take, get, must, able, give]
# 6	0.4231	[john, mike, david, steve, michael, james, chris, tom, jim, scott]
# 7	0.4001	[write, read, publish, book, writer, please, reader, mail, page, author]
# 8	0.6378	[software, pc, computer, server, disk, macintosh, windows, processor, interface, modem]
# 9	0.5482	[car, vehicle, engine, driver, truck, bus, wheel, motor, drive, speed]
# 10	0.6024	[ftp, unix, usenet, mb, newsgroup, pgp, server, toolkit, modem, faq]
# 11	0.3340	[think, might, need, much, whether, want, even, something, really, anything]
# 12	0.3808	[please, ask, call, tell, hear, want, answer, listen, know, let]
# 13	0.4515	[use, method, technique, allow, equipment, manufacture, systems, tool, available, common]
# 14	0.6510	[game, league, season, playoff, team, player, win, play, cup, hockey]
# 15	0.4908	[mail, internet, phone, email, fax, information, telephone, service, access, services]
# 16	0.3348	[people, many, population, number, least, among, million, dead, muslims, living]
# 17	0.4787	[system, digital, software, systems, data, computer, use, video, standard, format]
# 18	0.4883	[university, institute, science, research, college, education, program, professor, school, engineering]
# 19	0.3281	[year, last, two, government, world, three, time, since, first, week]

#Soft
# 0	0.40969	[blood, hus, knife, cancer, die, death, woman, disease, drug, girl]
# 1	0.50091	[university, science, research, institute, college, professor, engineering, education, physics, school]
# 2	0.35567	[really, know, think, somebody, want, get, everybody, ca, thing, lot]
# 3	0.41657	[near, station, line, area, north, road, east, west, across, south]
# 4	0.60767	[god, christ, jesus, religion, christianity, religious, christian, holy, faith, church]
# 5	0.34947	[help, want, get, make, try, take, need, must, able, give]
# 6	0.34497	[think, know, anyone, anything, whether, believe, really, understand, something, anybody]
# 7	0.38435	[write, read, publish, book, reader, writer, please, mail, tell, reading]
# 8	0.50053	[billion, percent, price, million, rate, dollar, increase, higher, per, value]
# 9	0.59022	[car, engine, driver, vehicle, truck, motorcycle, motor, wheel, bmw, honda]
# 10	0.63715	[software, server, interface, pc, computer, windows, macintosh, disk, modem, user]
# 11	0.45493	[use, system, systems, method, specific, function, common, data, type, standard]
# 12	0.42623	[mail, please, fax, phone, email, call, telephone, message, address, information]
# 13	0.4241	[two, three, one, four, five, six, second, another, single, first]
# 14	0.42306	[john, mike, steve, david, michael, chris, james, tom, jim, scott]
# 15	0.41854	[government, security, administration, military, state, federal, agency, policy, foreign, economic]
# 16	0.42642	[people, many, say, among, muslims, americans, jews, population, number, live]
# 17	0.42113	[article, magazine, journal, page, constitution, newspaper, editor, edition, book, publish]
# 18	0.62005	[game, season, league, win, cup, match, playoff, team, final, play]
# 19	0.3505	[really, good, pretty, little, something, lot, feel, bit, think, look]

#Multi
# 0	0.5670	[mb, homosexual, atheist, mg, cache, simm, valid, mhz, hitler, device]
# 1	0.3279	[kevin, mike, justice, al, goal, director, brother, bush, think, son, leader]
# 2	0.3805	[frequency, mhz, budget, space, digital, research, technology, useful, operate, domain]
# 3	0.4839	[fax, mail, please, email, reply, phone, comment, sorry, posting, mailing]
# 4	0.3460	[index, atlanta, rate, department, percent, fax, billion, economic, chicago, report]
# 5	0.5937	[oo, eus, batf, simm, koresh, wm, stanley, mit, dod, nec]
# 6	0.4527	[constitution, lewis, lebanese, murder, gordon, insurance, british, waco, wm, koresh]
# 7	0.3606	[army, islamic, military, la, police, fire, heavy, palestinian, wing, tank]
# 8	0.3449	[request, letter, food, secretary, case, congress, disease, director, judge, animal]
# 9	0.5344	[scsi, rocket, spacecraft, hd, fuel, radar, ac, converter, station, motherboard]
# 10	0.6278	[datum, buf, stephanopoulos, compression, pixel, vga, delete, bias, genocide, encrypt]
# 11	0.4720	[datum, mw, turkey, islamic, firearm, armenians, turks, turkish, island, tank]
# 12	0.3695	[goal, best, run, win, try, game, ahead, good, award, drive]
# 13	0.5676	[algorithm, binary, unix, server, interface, byte, configuration, os, cpu, encryption]
# 14	0.5825	[usenet, uucp, printf, btw, cryptography, newsgroup, vga, xt, widget, xlib]
# 15	0.3411	[relative, economic, former, significant, israel, family, male, lebanon, slave, war]
# 16	0.4289	[audio, bill, coverage, digital, law, card, services, full, cover, deep]
# 17	0.3409	[state, religious, town, local, police, texas, political, power, culture, government]
# 18	0.4240	[sick, care, option, son, older, insurance, family, child, choose, mine]
# 19	0.5686	[calgary, oo, min, bay, beat, buffer, jumper, neighbor, toolkit, mi]

In [None]:
# docus_len = []
# for docu in test_dataset:
#     docus_len.append(len(docu))

In [None]:
# import glob
# test_files = sorted(glob.glob('../20ng/data/test_mallet/*.txt'))
# len_file = []
# for each in test_files:
#     with open(each, 'r') as each_file:
#         len_file.append((int(each.split('/')[4].replace('.txt','')), len(each_file.read().split(' '))))

# df_file = pd.DataFrame(len_file, columns=['doc','len'])
# df_file = df_file.sort_values(by='doc')
# df_file = df_file.reset_index(drop=True)
# print(df_file)


In [None]:
# Display options for the wide entropy table below.
# 'display.height' is intentionally not set: current pandas has no such option and
# set_option raises for it.
pd.set_option('display.max_rows', 5000)
pd.set_option('display.max_columns', 9)
#pd.set_option('display.width', 2000)
pd.set_option('display.max_colwidth', None)  # None disables truncation (replaces the deprecated -1)


def _entropy_bits(theta):
    """Shannon entropy in bits of one topic-proportion vector.

    Entries equal to zero are left out of the sum (0 * log2(0) is taken as 0),
    so a vector containing exact zeros gives a finite entropy instead of NaN.
    """
    theta = np.asarray(theta)
    nonzero = theta[theta != 0]
    return -np.sum(nonzero * np.log2(nonzero))


entropy_data = []

for docu_count, (theta0, theta1, theta2, docu) in enumerate(zip(lda_theta, myetm_theta, soft_theta, test_dataset)):
    # The document length is read from the document itself: `docus_len` is only
    # produced by a commented-out cell (as len(docu) per document), so relying on it
    # fails on a fresh kernel.
    docu_len = len(docu)
    docu = ' '.join([swapped_vocab_docu[idx] for idx in docu.numpy()])

    lda_entropy = _entropy_bits(theta0)
    etm_entropy = _entropy_bits(theta1)
    soft_entropy = _entropy_bits(theta2)

    # Pairwise entropy differences; a negative value means the first-named model
    # has the lower entropy on this document.
    diff_soft_etm = soft_entropy - etm_entropy
    diff_etm_lda = etm_entropy - lda_entropy
    diff_soft_lda = soft_entropy - lda_entropy

    entropy_data.append((docu_count,
                         docu_len,
                         docu,
                         lda_entropy,
                         etm_entropy,
                         soft_entropy,
                         diff_etm_lda,
                         diff_soft_etm,
                         diff_soft_lda))

df_entropy = pd.DataFrame(entropy_data, columns=['doc_idx',
                                                 'doc_len',
                                                 'doc',
                                                 'lda',
                                                 'etm',
                                                 'soft',
                                                 'diff_etm_lda',
                                                 'diff_soft_etm',
                                                 'diff_soft_lda'])

In [None]:
# Count documents by the sign of the entropy difference: "winning" means the Soft
# VQ-VAE entropy is strictly lower (negative difference), "losing" strictly higher.
stat_winning_docs_soft_lda = int((df_entropy['diff_soft_lda'] < 0).sum())
stat_losing_docs_soft_lda = int((df_entropy['diff_soft_lda'] > 0).sum())
print(stat_winning_docs_soft_lda)
print(stat_losing_docs_soft_lda)

stat_winning_docs_soft_etm = int((df_entropy['diff_soft_etm'] < 0).sum())
stat_losing_docs_soft_etm = int((df_entropy['diff_soft_etm'] > 0).sum())
print(stat_winning_docs_soft_etm)
print(stat_losing_docs_soft_etm)

### Docs where Soft is winning only against ETM

In [None]:
# Documents where Soft VQ-VAE has lower entropy than ETM, largest gap first.
(
    df_entropy
    .loc[lambda frame: frame['diff_soft_etm'] < 0]
    .sort_values(by='diff_soft_etm')
)

### Docs where Soft is winning only against LDA

In [None]:
# Documents longer than 5 tokens where Soft VQ-VAE has lower entropy than LDA,
# sorted by the gap; rows 60-89 of that ranking are shown.
(
    df_entropy
    .loc[lambda frame: (frame['diff_soft_lda'] < 0) & (frame['doc_len'] > 5)]
    .sort_values(by='diff_soft_lda')
    .iloc[60:90]
)

In [None]:
# Side-by-side heatmaps of the document-topic distribution under LDA (left
# column) and Soft VQ-VAE (right column) for a hand-picked set of documents.
# ETM is not plotted in this figure.
# winning docs = 4805, 3616, 318, 2306, 62
doc_idxs = [4805, 3616, 318, 2306, 62]

# One subplot row per selected document, so the grid has no empty axes.
fig_entropy = make_subplots(rows=len(doc_idxs), cols=2,
                    subplot_titles=("LDA","Soft VQ-VAE"))

for row, doc_idx in enumerate(doc_idxs, start=1):
    # expand_dims turns the document's topic vector into a single-row matrix
    # (assumes lda_theta[doc_idx] is 1-D with num_topics entries -- TODO confirm).
    # The y label combines the 'doc' column value with the rounded LDA entropy.
    lda_entropy_trace = go.Heatmap(z = np.expand_dims(lda_theta[doc_idx],axis=0),
                                   x = np.arange(num_topics),
                                   y = np.array([(df_entropy['doc'][doc_idx]," ", round(df_entropy['lda'][doc_idx],2))]),
                                   showscale=False)

    # Same layout for Soft VQ-VAE; here the y label is only the rounded entropy.
    soft_entropy_trace = go.Heatmap(z = np.expand_dims(soft_theta[doc_idx],axis=0),
                                    x = np.arange(num_topics),
                                    y = np.array([round(df_entropy['soft'][doc_idx],2)]),
                                    showscale=False)

    # add_trace(row=, col=) is the supported replacement for the deprecated
    # Figure.append_trace.
    fig_entropy.add_trace(lda_entropy_trace, row=row, col=1)
    fig_entropy.add_trace(soft_entropy_trace, row=row, col=2)

fig_entropy.update_layout(height=500)
# Category axes give every topic index / label its own tick rather than a numeric scale.
fig_entropy.update_xaxes(nticks=num_topics, type='category',showticklabels=True, tickfont=dict(size=8) )
fig_entropy.update_yaxes(showticklabels=True, type='category')


fig_entropy.show()