## Create Required Embeddings

Note: This section can be skipped if embeddings are already prepared

In [3]:
import gensim
import pickle
import os
import numpy as np
import pandas as pd
import sys
from tqdm import tqdm
# sys.setdefaultencoding() does not exist, here!
# reload(sys)  # Reload does the trick!
# sys.setdefaultencoding('UTF8')

# sys.setdefaultencoding('UTF8')

data_file= "data/drug_review/preprocessed_reviews.csv"        # default='', help='a .txt file containing the corpus'

dim_rho= 200                                               #default=300, help='dimensionality of the word embeddings'
min_count= 4                                               #default=2, help='minimum term frequency (to define the vocabulary)'
sg= 1                                                      # default=1, help='whether to use skip-gram'
# workers= 6                                                #default=25, help='number of CPU cores'
negative_samples= 10                                       # default=10, help='number of negative samples'
window_size= 8                                             # default=4, help='window size to determine context'
iters= 50                                                  #default=50, help='number of iterationst'

emb_file= "embeddings/embeddings"+"_dim_"+str(dim_rho)+"_min_count_"+str(min_count)+"_sg_"+str(sg)+"_negative_samples_"+str(negative_samples)+"_window_size_"+str(window_size)+"_iters_"+str(iters)+".txt"                      #default='embeddings.txt', help='file to save the word embeddings'



In [4]:

# Class for a memory-friendly iterator over the dataset
class MySentences(object):
    def __init__(self, filename):
        self.filename = filename
        self.file_type = 'text'
 
    def __iter__(self):
        if self.file_type == 'text':
            for line in open(self.filename,encoding="utf8"):
                yield line.split()
        elif self.file_type == 'csv':
            for line in tqdm(self.reviews.values):
                try:
                    yield line.split()
#                     print(line)
                except AttributeError:
#                     print(line)
                    continue
                
    def __init__(self, filename,col):
        self.filename = filename
        data = pd.read_csv(filename)
        self.reviews = data[col]#[:10000]
        self.file_type = 'csv'


In [5]:
sentences = MySentences(data_file,'review') # a memory-friendly iterator
model = gensim.models.Word2Vec(sentences, min_count=min_count, sg=sg, size=dim_rho, 
    iter=(iters), negative=negative_samples, window=window_size)

100%|████████████████████████████████████████████████████████████████████████| 15000/15000 [00:00<00:00, 111225.44it/s]
100%|██████████████████████████████████████████████████████████████████████████| 15000/15000 [00:02<00:00, 6091.18it/s]
100%|██████████████████████████████████████████████████████████████████████████| 15000/15000 [00:02<00:00, 6544.27it/s]
100%|██████████████████████████████████████████████████████████████████████████| 15000/15000 [00:02<00:00, 6862.31it/s]
100%|██████████████████████████████████████████████████████████████████████████| 15000/15000 [00:02<00:00, 6743.28it/s]
100%|██████████████████████████████████████████████████████████████████████████| 15000/15000 [00:02<00:00, 6848.07it/s]
100%|██████████████████████████████████████████████████████████████████████████| 15000/15000 [00:02<00:00, 7137.36it/s]
100%|██████████████████████████████████████████████████████████████████████████| 15000/15000 [00:02<00:00, 6816.98it/s]
100%|███████████████████████████████████

In [6]:


# Write the embeddings to a file
with open(emb_file, 'w',encoding='utf8') as f:
    for v in list(model.wv.vocab):
        vec = list(model.wv.__getitem__(v))
        f.write(v + ' ')
        vec_str = ['%.9f' % val for val in vec]
        vec_str = " ".join(vec_str)
        f.write(vec_str + '\n')

## Data Preprocessing

In [7]:
# import pickle

# abc = pickle.load("data/20ng/vocab.pkl")
# abc

## Modelling


In [8]:

from __future__ import print_function

import torch
import pickle 
import numpy as np 
import os 
import math 
import random 
import sys

import matplotlib.pyplot as plt 
import data
import scipy.io

from torch import nn, optim
from torch.nn import functional as F

from etm import ETM
from utils import nearest_neighbors, get_topic_coherence, get_topic_diversity


In [9]:
# df.head()

In [10]:
# df = pd.read_csv("data/drug_review/drugsComTrain_raw.tsv",delimiter="\t")[:1000]
# df.to_csv("data/drug_review/drugs_train_1000.csv",index=None)
# reviews = df.review
# with open("train_file.txt", 'w',encoding='utf8') as f:
#     for review in reviews.values:
#         f.write(review + '\n')

In [11]:
# dataset = "20ng"

dataset =   "train_file.txt"                      #default='20ng', help='name of corpus'
data_path = 'data/drug_review/'#default='data/20ng', help='directory containing data'
emb_path = "embeddings/embeddings"+"_dim_"+str(dim_rho)+"_min_count_"+str(min_count)+"_sg_"+str(sg)+"_negative_samples_"+str(negative_samples)+"_window_size_"+str(window_size)+"_iters_"+str(iters)+".txt"#default='data/20ng_embeddings.txt', help='directory containing word embeddings'
save_path = './results'#default='./results', help='path to save results'
batch_size = 100 #default=1000, help='input batch size for training'

### model-related arguments
num_topics = 15   #default=50, help='number of topics'
rho_size = 200    #default=300, help='dimension of rho'
emb_size = 200    #default=300, help='dimension of embeddings'
t_hidden_size = 800 #default=800, help='dimension of hidden space of q(theta)'
theta_act = 'relu' #default='relu', help='tanh, softplus, relu, rrelu, leakyrelu, elu, selu, glu)'
train_embeddings = 0 #default=0, help='whether to fix rho or train it'

### optimization-related arguments
lr = 0.005 # default=0.005, help='learning rate'
lr_factor =4.0  #default=4.0, help='divide learning rate by this...'
epochs = 10 # default=20, help='number of epochs to train...150 for 20ng 100 for others'
mode = 'train'# default='train', help='train or eval model'
optimizer = 'adam'# default='adam', help='choice of optimizer'
seed = 2019# default=2019, help='random seed (default: 1)
enc_drop = 0.0# default=0.0, help='dropout rate on encoder'
clip = 0.0# default=0.0, help='gradient clipping'
nonmono = 10# default=10, help='number of bad hits allowed'
wdecay = 1.2e-6# default=1.2e-6, help='some l2 regularization'
anneal_lr = 0#  default=0, help='whether to anneal the learning rate or not'
bow_norm = 1# default=1, help='normalize the bows or not'

### evaluation, visualization, and logging-related arguments
num_words = 10  # default=10, help='number of words for topic viz' 
log_interval = 2 # default=2, help='when to log training'
visualize_every = 1 # default=10, help='when to visualize results'
eval_batch_size = 1000 # default=1000, help='input batch size for evaluation'
load_from = 'results/etm_train_file.txt_K_15_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_100_RhoSize_200_trainEmbeddings_0'
tc = 0# default=0, help='whether to compute topic coherence or not'
td = 0# default=0, help='whether to compute topic diversity or not'


In [12]:
device = torch.device("cpu")#("cuda" if torch.cuda.is_available() else "cpu")

print('\n')
np.random.seed(seed)
torch.manual_seed(seed)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed(seed)





<torch._C.Generator at 0x283ac112670>

In [13]:
vocab, train, valid, test,test_1,test_2 = data.get_data(os.path.join(data_path))
vocab_size = len(vocab)

# 1. training data
train_tokens = train['tokens']
train_counts = train['counts']
num_docs_train = len(train_tokens)

# 2. dev set
valid_tokens = valid['tokens']
valid_counts = valid['counts']
num_docs_valid = len(valid_tokens)

# 3. test data
test_tokens = test['tokens']
test_counts = test['counts']
num_docs_test = len(test_tokens)
test_1_tokens = test_1['tokens']
test_1_counts = test_1['counts']
num_docs_test_1 = len(test_1_tokens)
test_2_tokens = test_2['tokens']
test_2_counts = test_2['counts']
num_docs_test_2 = len(test_2_tokens)

embeddings = None

In [14]:
#/usr/bin/python
## get data
# 1. vocabulary

if not train_embeddings:
    emb_path = emb_path
    vect_path = os.path.join(data_path.split('/')[0], 'vocab.pkl')   
    vectors = {}
    with open(emb_path, 'rb') as f:
        for l in f:
            line = l.decode().split()
            word = line[0]
            if word in vocab:
                vect = np.array(line[1:]).astype(np.float)
                vectors[word] = vect
    embeddings = np.zeros((vocab_size, emb_size))
    words_found = 0
    for i, word in enumerate(vocab):
        try: 
            embeddings[i] = vectors[word]
            words_found += 1
        except KeyError:
            embeddings[i] = np.random.normal(scale=0.6, size=(emb_size, ))
    embeddings = torch.tensor(embeddings).to(device)
    embeddings_dim = embeddings.size()

print('=*'*100)
# print('Training an Embedded Topic Model on {} with the following settings: {}'.format(dataset.upper()))
print('=*'*100)


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


In [15]:
emb_path

'embeddings/embeddings_dim_200_min_count_4_sg_1_negative_samples_10_window_size_8_iters_50.txt'

In [16]:

## define checkpoint
if not os.path.exists(save_path):
    os.makedirs(save_path)

if mode == 'eval':
    ckpt = load_from
else:
    ckpt = os.path.join(save_path, 
        'etm_{}_K_{}_Htheta_{}_Optim_{}_Clip_{}_ThetaAct_{}_Lr_{}_Bsz_{}_RhoSize_{}_trainEmbeddings_{}'.format(
        dataset, num_topics, t_hidden_size, optimizer, clip, theta_act, 
            lr, batch_size, rho_size, train_embeddings))

## define model and optimizer
model = ETM(num_topics, vocab_size, t_hidden_size, rho_size, emb_size, 
                theta_act, embeddings, train_embeddings, enc_drop).to(device)

print('model: {}'.format(model))

if optimizer == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wdecay)
elif optimizer == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=lr, weight_decay=wdecay)
elif optimizer == 'adadelta':
    optimizer = optim.Adadelta(model.parameters(), lr=lr, weight_decay=wdecay)
elif optimizer == 'rmsprop':
    optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=wdecay)
elif optimizer == 'asgd':
    optimizer = optim.ASGD(model.parameters(), lr=lr, t0=0, lambd=0., weight_decay=wdecay)
else:
    print('Defaulting to vanilla SGD')
    optimizer = optim.SGD(model.parameters(), lr=lr)

model: ETM(
  (t_drop): Dropout(p=0.0)
  (theta_act): ReLU()
  (alphas): Linear(in_features=200, out_features=15, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=19856, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=15, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=15, bias=True)
)


In [17]:
def train(epoch):
    model.train()
    acc_loss = 0
    acc_kl_theta_loss = 0
    cnt = 0
    indices = torch.randperm(num_docs_train)
    indices = torch.split(indices, batch_size)
    for idx, ind in enumerate(indices):
        try:
            optimizer.zero_grad()
            model.zero_grad()
            data_batch = data.get_batch(train_tokens, train_counts, ind, vocab_size, device)
            sums = data_batch.sum(1).unsqueeze(1)
            if bow_norm:
                normalized_data_batch = data_batch / sums
            else:
                normalized_data_batch = data_batch
            recon_loss, kld_theta = model(data_batch, normalized_data_batch)
            total_loss = recon_loss + kld_theta
            total_loss.backward()

            if clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            acc_loss += torch.sum(recon_loss).item()
            acc_kl_theta_loss += torch.sum(kld_theta).item()
            cnt += 1

            if idx % log_interval == 0 and idx > 0:
                cur_loss = round(acc_loss / cnt, 2) 
                cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 
                cur_real_loss = round(cur_loss + cur_kl_theta, 2)

                print('Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} .. Rec_loss: {} .. NELBO: {}'.format(
                    epoch, idx, len(indices), optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss))
        except IndexError:
            cnt+=1
            continue
    cur_loss = round(acc_loss / cnt, 2) 
    cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 
    cur_real_loss = round(cur_loss + cur_kl_theta, 2)
    print('*'*100)
    print('Epoch----->{} .. LR: {} .. KL_theta: {} .. Rec_loss: {} .. NELBO: {}'.format(
            epoch, optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss))
    print('*'*100)

In [18]:
def visualize(m, show_emb=True):
    if not os.path.exists('./results'):
        os.makedirs('./results')

    m.eval()

    queries = ['skin','cycle','effects','price','worst','best','efficacy','performance','cancer','disease']

    ## visualize topics using monte carlo
    with torch.no_grad():
        print('#'*100)
        print('Visualize topics...')
        topics_words = []
        gammas = m.get_beta()
        
        for k in range(num_topics):
            gamma = gammas[k]
            top_words = list(gamma.cpu().numpy().argsort()[-num_words+1:][::-1])
            topic_words = [vocab[a] for a in top_words]
            topics_words.append(' '.join(topic_words))
            print('Topic {}: {}'.format(k, topic_words))

        if show_emb:
            ## visualize word embeddings by using V to get nearest neighbors
            print('#'*100)
            print('Visualize word embeddings by using output embedding matrix')
            try:
                embeddings = m.rho.weight  # Vocab_size x E
            except:
                embeddings = m.rho         # Vocab_size x E
            neighbors = []
            for word in queries:
                try:
                    print('word: {} .. neighbors: {}'.format(
                        word, nearest_neighbors(word, embeddings, vocab)))
                except ValueError:
                    print("querry doesn't exist!!")
            print('#'*100)

In [19]:
def evaluate(m, source, tc=False, td=False):
    """Compute perplexity on document completion.
    """
    m.eval()
    with torch.no_grad():
        if source == 'val':
            indices = torch.split(torch.tensor(range(num_docs_valid)), eval_batch_size)
            tokens = valid_tokens
            counts = valid_counts
        else: 
            indices = torch.split(torch.tensor(range(num_docs_test)), eval_batch_size)
            tokens = test_tokens
            counts = test_counts

        ## get \beta here
        beta = m.get_beta()

        ### do dc and tc here
        acc_loss = 0
        cnt = 0
        indices_1 = torch.split(torch.tensor(range(num_docs_test_1)), eval_batch_size)
        for idx, ind in enumerate(indices_1):
            try:
                ## get theta from first half of docs
                data_batch_1 = data.get_batch(test_1_tokens, test_1_counts, ind, vocab_size, device)
                sums_1 = data_batch_1.sum(1).unsqueeze(1)
                if bow_norm:
                    normalized_data_batch_1 = data_batch_1 / sums_1
                else:
                    normalized_data_batch_1 = data_batch_1
                theta, _ = m.get_theta(normalized_data_batch_1)

                ## get prediction loss using second half
                data_batch_2 = data.get_batch(test_2_tokens, test_2_counts, ind, vocab_size, device)
                sums_2 = data_batch_2.sum(1).unsqueeze(1)
                res = torch.mm(theta, beta)
                preds = torch.log(res)
                recon_loss = -(preds * data_batch_2).sum(1)

                loss = recon_loss / sums_2.squeeze()
                loss = loss.mean().item()
                acc_loss += loss
                cnt += 1
            except IndexError:
                cnt+=1
                continue
        cur_loss = acc_loss / cnt
        ppl_dc = round(math.exp(cur_loss), 1)
        print('*'*100)
        print('{} Doc Completion PPL: {}'.format(source.upper(), ppl_dc))
        print('*'*100)
        if tc or td:
            beta = beta.data.cpu().numpy()
            if tc:
                print('Computing topic coherence...')
                get_topic_coherence(beta, train_tokens, vocab)
            if td:
                print('Computing topic diversity...')
                get_topic_diversity(beta, 25)
        return ppl_dc


In [20]:

if mode == 'train':
    ## train model on data 
    best_epoch = 0
    best_val_ppl = 1e9
    all_val_ppls = []
    print('\n')
    print('Visualizing model quality before training...')
    visualize(model)
    print('\n')
    for epoch in range(1, epochs):
        train(epoch)
        val_ppl = evaluate(model, 'val')
        if val_ppl < best_val_ppl:
            with open(ckpt, 'wb') as f:
                torch.save(model, f)
            best_epoch = epoch
            best_val_ppl = val_ppl
        else:
            ## check whether to anneal lr
            lr = optimizer.param_groups[0]['lr']
            if anneal_lr and (len(all_val_ppls) > nonmono and val_ppl > min(all_val_ppls[:-nonmono]) and lr > 1e-5):
                optimizer.param_groups[0]['lr'] /= lr_factor
        if epoch % visualize_every == 0:
            visualize(model)
        all_val_ppls.append(val_ppl)
    with open(ckpt, 'rb') as f:
        model = torch.load(f)
    model = model.to(device)
    val_ppl = evaluate(model, 'val')




Visualizing model quality before training...
####################################################################################################
Visualize topics...
Topic 0: ['haa', 'spendy', 'anythingi', 'aricept', 'skincellulitespider', 'ticsnot', 'forbid', 'weepyshort', 'wpreservatives']
Topic 1: ['depoproverai', 'simcor', 'quotplugquot', 'smother', 'badlyand', 'nondrug', 'attackanxiety', 'deminishing', 'apirl']
Topic 2: ['levoxyl', 'dermatogolist', 'driveid', 'bumbs', 'depersonalize', 'freaky', 'meso', 'ego', 'trustworthy']
Topic 3: ['omacor', 'nauseouslost', 'attacksfor', 'wrestle', 'celebrix', 'enjoyment', 'fluoxetineprozac', 'reliance', 'papuls']
Topic 4: ['awfuldidnt', 'levlen', 'perceptible', 'willprescribe', 'dominance', 'basement', 'clyamdia', 'clenil', 'so']
Topic 5: ['preauthoraziton', 'dayseww', 'addedcant', 'phoenix', 'creamits', 'incorporate', 'azurette', 'issuesinsertion', 'triphasiltrivora']
Topic 6: ['antibioticresistant', 'disinterest', 'spam', 'bodyfat', 'hammer

Epoch: 1 .. batch: 78/135 .. LR: 0.005 .. KL_theta: 0.03 .. Rec_loss: 312.42 .. NELBO: 312.45
Epoch: 1 .. batch: 80/135 .. LR: 0.005 .. KL_theta: 0.03 .. Rec_loss: 312.58 .. NELBO: 312.61
Epoch: 1 .. batch: 82/135 .. LR: 0.005 .. KL_theta: 0.04 .. Rec_loss: 312.37 .. NELBO: 312.41
Epoch: 1 .. batch: 84/135 .. LR: 0.005 .. KL_theta: 0.04 .. Rec_loss: 312.45 .. NELBO: 312.49
Epoch: 1 .. batch: 86/135 .. LR: 0.005 .. KL_theta: 0.04 .. Rec_loss: 312.04 .. NELBO: 312.08
Epoch: 1 .. batch: 88/135 .. LR: 0.005 .. KL_theta: 0.05 .. Rec_loss: 311.94 .. NELBO: 311.99
Epoch: 1 .. batch: 90/135 .. LR: 0.005 .. KL_theta: 0.06 .. Rec_loss: 311.84 .. NELBO: 311.9
Epoch: 1 .. batch: 92/135 .. LR: 0.005 .. KL_theta: 0.07 .. Rec_loss: 311.57 .. NELBO: 311.64
Epoch: 1 .. batch: 94/135 .. LR: 0.005 .. KL_theta: 0.08 .. Rec_loss: 311.8 .. NELBO: 311.88
Epoch: 1 .. batch: 96/135 .. LR: 0.005 .. KL_theta: 0.09 .. Rec_loss: 311.15 .. NELBO: 311.24
Epoch: 1 .. batch: 98/135 .. LR: 0.005 .. KL_theta: 0.1 .. Rec

Epoch: 2 .. batch: 8/135 .. LR: 0.005 .. KL_theta: 0.54 .. Rec_loss: 268.13 .. NELBO: 268.67
Epoch: 2 .. batch: 10/135 .. LR: 0.005 .. KL_theta: 0.6 .. Rec_loss: 274.14 .. NELBO: 274.74
Epoch: 2 .. batch: 12/135 .. LR: 0.005 .. KL_theta: 0.59 .. Rec_loss: 275.93 .. NELBO: 276.52
Epoch: 2 .. batch: 14/135 .. LR: 0.005 .. KL_theta: 0.6 .. Rec_loss: 279.64 .. NELBO: 280.24
Epoch: 2 .. batch: 16/135 .. LR: 0.005 .. KL_theta: 0.63 .. Rec_loss: 282.04 .. NELBO: 282.67
Epoch: 2 .. batch: 18/135 .. LR: 0.005 .. KL_theta: 0.65 .. Rec_loss: 285.95 .. NELBO: 286.6
Epoch: 2 .. batch: 20/135 .. LR: 0.005 .. KL_theta: 0.64 .. Rec_loss: 287.87 .. NELBO: 288.51
Epoch: 2 .. batch: 22/135 .. LR: 0.005 .. KL_theta: 0.65 .. Rec_loss: 288.49 .. NELBO: 289.14
Epoch: 2 .. batch: 24/135 .. LR: 0.005 .. KL_theta: 0.65 .. Rec_loss: 288.04 .. NELBO: 288.69
Epoch: 2 .. batch: 26/135 .. LR: 0.005 .. KL_theta: 0.65 .. Rec_loss: 291.47 .. NELBO: 292.12
Epoch: 2 .. batch: 28/135 .. LR: 0.005 .. KL_theta: 0.64 .. Rec_

Topic 12: ['quotdoctor', 'postmeal', 'chillstripled', 'fogwas', 'laydown', 'yearstried', 'ensures', 'wellbutrine', 'buckswhich']
Topic 13: ['quotdoctor', 'fogwas', 'postmeal', 'chillstripled', 'laydown', 'yearstried', 'wellbutrine', 'buckswhich', 'ensures']
Topic 14: ['postmeal', 'chillstripled', 'quotdoctor', 'laydown', 'fogwas', 'wellbutrine', 'yearstried', 'upscale', 'gi']
####################################################################################################
Visualize word embeddings by using output embedding matrix
vectors:  (19856, 200)
query:  (200,)
word: skin .. neighbors: ['skin', 'face', 'clear', 'acne', 'sunscreen', 'breakout', 'pimple', 'peel', 'flake', 'redness', 'epiduo', 'retin', 'papule', 'ziana', 'oily', 'neutrogena', 'clinique', 'acanya', 'complexion', 'moisturiser']
vectors:  (19856, 200)
query:  (200,)
word: cycle .. neighbors: ['cycle', 'period', 'menstrual', 'lighter', 'safyral', 'month', 'bleeding', 'menstruate', 'hormone', 'cramp', 'heavy', 'tracke

Epoch: 3 .. batch: 114/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 303.04 .. NELBO: 303.57
Epoch: 3 .. batch: 116/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 303.38 .. NELBO: 303.91
Epoch: 3 .. batch: 118/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 303.83 .. NELBO: 304.36
Epoch: 3 .. batch: 120/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 304.18 .. NELBO: 304.71
Epoch: 3 .. batch: 122/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 303.9 .. NELBO: 304.43
Epoch: 3 .. batch: 124/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 303.76 .. NELBO: 304.29
Epoch: 3 .. batch: 126/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 303.86 .. NELBO: 304.38
Epoch: 3 .. batch: 128/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 304.11 .. NELBO: 304.63
Epoch: 3 .. batch: 130/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 304.18 .. NELBO: 304.7
Epoch: 3 .. batch: 132/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 303.92 .. NELBO: 304.44
Epoch: 3 .. batch: 134/135 .. LR: 0.005 .. KL_theta:

Epoch: 4 .. batch: 44/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 305.52 .. NELBO: 306.05
Epoch: 4 .. batch: 46/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 306.11 .. NELBO: 306.64
Epoch: 4 .. batch: 48/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 305.97 .. NELBO: 306.5
Epoch: 4 .. batch: 50/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 306.12 .. NELBO: 306.64
Epoch: 4 .. batch: 52/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 306.18 .. NELBO: 306.7
Epoch: 4 .. batch: 54/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 306.31 .. NELBO: 306.83
Epoch: 4 .. batch: 56/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 306.12 .. NELBO: 306.65
Epoch: 4 .. batch: 58/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 306.39 .. NELBO: 306.92
Epoch: 4 .. batch: 60/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 305.91 .. NELBO: 306.44
Epoch: 4 .. batch: 62/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 305.35 .. NELBO: 305.88
Epoch: 4 .. batch: 64/135 .. LR: 0.005 .. KL_theta: 0.53 .. Re

word: efficacy .. neighbors: ['efficacy', 'generess', 'buti', 'version', 'maintenance', 'unisom', 'weighti', 'threshold', 'simponi', 'approval', 'reluctantly', 'sufficient', 'psoriatic', 'accustom', 'national', 'intolerable', 'ssrisnris', 'fee', 'attest', 'pdoc']
vectors:  (19856, 200)
query:  (200,)
word: performance .. neighbors: ['performance', 'sexual', 'cialis', 'bph', 'semen', 'erection', 'ratio', 'distraction', 'climax', 'avodart', 'stendra', 'viagra', 'benign', 'smarter', 'elate', 'stamen', 'prostate', 'anorgasmia', 'marriage', 'adhdadd']
vectors:  (19856, 200)
query:  (200,)
word: cancer .. neighbors: ['cancer', 'ibrance', 'radiation', 'chemo', 'lymph', 'node', 'tarceva', 'prostate', 'zoladex', 'femara', 'oncologist', 'chemotherapy', 'leukemia', 'sutent', 'alfuzosin', 'carcinoma', 'arimidex', 'casodex', 'psa', 'avastin']
vectors:  (19856, 200)
query:  (200,)
word: disease .. neighbors: ['disease', 'degenerative', 'crohn', 'autoimmune', 'lyme', 'obstructive', 'herniated', 'diag

Topic 0: ['fogwas', 'quotdoctor', 'postmeal', 'chillstripled', 'laydown', 'yearstried', 'upscale', 'buckswhich', 'wellbutrine']
Topic 1: ['chillstripled', 'fogwas', 'postmeal', 'quotdoctor', 'upscale', 'laydown', 'yearstried', 'wellbutrine', 'buckswhich']
Topic 2: ['fogwas', 'quotdoctor', 'postmeal', 'laydown', 'chillstripled', 'yearstried', 'upscale', 'wellbutrine', 'buckswhich']
Topic 3: ['quotdoctor', 'chillstripled', 'yearstried', 'postmeal', 'fogwas', 'laydown', 'upscale', 'buckswhich', 'gi']
Topic 4: ['quotdoctor', 'chillstripled', 'postmeal', 'fogwas', 'buckswhich', 'yearstried', 'laydown', 'upscale', 'wellbutrine']
Topic 5: ['quotdoctor', 'chillstripled', 'postmeal', 'fogwas', 'laydown', 'yearstried', 'upscale', 'buckswhich', 'gi']
Topic 6: ['period', 'gain', 'swing', 'weight', 'acne', 'sex', 'bleeding', 'birth', 'spot']
Topic 7: ['quotdoctor', 'postmeal', 'chillstripled', 'fogwas', 'laydown', 'upscale', 'wellbutrine', 'yearstried', 'buckswhich']
Topic 8: ['quotdoctor', 'yearst

Epoch: 6 .. batch: 82/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 305.55 .. NELBO: 306.05
Epoch: 6 .. batch: 84/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 305.68 .. NELBO: 306.18
Epoch: 6 .. batch: 86/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 302.12 .. NELBO: 302.62
Epoch: 6 .. batch: 88/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 302.0 .. NELBO: 302.5
Epoch: 6 .. batch: 90/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 301.87 .. NELBO: 302.37
Epoch: 6 .. batch: 92/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 301.86 .. NELBO: 302.36
Epoch: 6 .. batch: 94/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 301.71 .. NELBO: 302.21
Epoch: 6 .. batch: 96/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 301.62 .. NELBO: 302.12
Epoch: 6 .. batch: 98/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 301.79 .. NELBO: 302.29
Epoch: 6 .. batch: 100/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 301.39 .. NELBO: 301.89
Epoch: 6 .. batch: 102/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 3

Epoch: 7 .. batch: 14/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec_loss: 305.49 .. NELBO: 306.0
Epoch: 7 .. batch: 16/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 304.11 .. NELBO: 304.63
Epoch: 7 .. batch: 18/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 302.91 .. NELBO: 303.43
Epoch: 7 .. batch: 20/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 303.46 .. NELBO: 303.99
Epoch: 7 .. batch: 22/135 .. LR: 0.005 .. KL_theta: 0.53 .. Rec_loss: 303.19 .. NELBO: 303.72
Epoch: 7 .. batch: 24/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 304.07 .. NELBO: 304.59
Epoch: 7 .. batch: 26/135 .. LR: 0.005 .. KL_theta: 0.52 .. Rec_loss: 306.51 .. NELBO: 307.03
Epoch: 7 .. batch: 28/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec_loss: 305.25 .. NELBO: 305.76
Epoch: 7 .. batch: 30/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 304.14 .. NELBO: 304.64
Epoch: 7 .. batch: 32/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 304.52 .. NELBO: 305.02
Epoch: 7 .. batch: 34/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec

Topic 14: ['quotdoctor', 'postmeal', 'laydown', 'fogwas', 'buckswhich', 'chillstripled', 'upscale', 'wellbutrine', 'yearstried']
####################################################################################################
Visualize word embeddings by using output embedding matrix
vectors:  (19856, 200)
query:  (200,)
word: skin .. neighbors: ['skin', 'face', 'clear', 'acne', 'sunscreen', 'breakout', 'pimple', 'peel', 'flake', 'redness', 'epiduo', 'retin', 'papule', 'ziana', 'oily', 'neutrogena', 'clinique', 'acanya', 'complexion', 'moisturiser']
vectors:  (19856, 200)
query:  (200,)
word: cycle .. neighbors: ['cycle', 'period', 'menstrual', 'lighter', 'safyral', 'month', 'bleeding', 'menstruate', 'hormone', 'cramp', 'heavy', 'tracker', 'predictable', 'natazia', 'amethia', 'spot', 'bleed', 'mononessa', 'heavier', 'pill']
querry doesn't exist!!
vectors:  (19856, 200)
query:  (200,)
word: price .. neighbors: ['price', 'pay', 'retail', 'formulary', 'insurance', 'paid', 'fee', 'doll

Epoch: 8 .. batch: 118/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 301.97 .. NELBO: 302.46
Epoch: 8 .. batch: 120/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 302.13 .. NELBO: 302.62
Epoch: 8 .. batch: 122/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 302.16 .. NELBO: 302.65
Epoch: 8 .. batch: 124/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 301.99 .. NELBO: 302.48
Epoch: 8 .. batch: 126/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 301.95 .. NELBO: 302.44
Epoch: 8 .. batch: 128/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 301.89 .. NELBO: 302.38
Epoch: 8 .. batch: 130/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 302.21 .. NELBO: 302.7
Epoch: 8 .. batch: 132/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 302.81 .. NELBO: 303.3
Epoch: 8 .. batch: 134/135 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 302.83 .. NELBO: 303.32
****************************************************************************************************
Epoch----->8 .. LR: 0.005 .. KL_theta: 0.49 ..

Epoch: 9 .. batch: 48/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec_loss: 304.61 .. NELBO: 305.12
Epoch: 9 .. batch: 50/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec_loss: 304.44 .. NELBO: 304.95
Epoch: 9 .. batch: 52/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec_loss: 304.22 .. NELBO: 304.73
Epoch: 9 .. batch: 54/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec_loss: 304.68 .. NELBO: 305.19
Epoch: 9 .. batch: 56/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec_loss: 305.13 .. NELBO: 305.64
Epoch: 9 .. batch: 58/135 .. LR: 0.005 .. KL_theta: 0.51 .. Rec_loss: 305.03 .. NELBO: 305.54
Epoch: 9 .. batch: 60/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 304.93 .. NELBO: 305.43
Epoch: 9 .. batch: 62/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 304.95 .. NELBO: 305.45
Epoch: 9 .. batch: 64/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 305.15 .. NELBO: 305.65
Epoch: 9 .. batch: 66/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_loss: 305.37 .. NELBO: 305.87
Epoch: 9 .. batch: 68/135 .. LR: 0.005 .. KL_theta: 0.5 .. Rec_l

word: performance .. neighbors: ['performance', 'sexual', 'cialis', 'bph', 'semen', 'erection', 'ratio', 'distraction', 'climax', 'avodart', 'stendra', 'viagra', 'benign', 'smarter', 'elate', 'stamen', 'prostate', 'anorgasmia', 'marriage', 'adhdadd']
vectors:  (19856, 200)
query:  (200,)
word: cancer .. neighbors: ['cancer', 'ibrance', 'radiation', 'chemo', 'lymph', 'node', 'tarceva', 'prostate', 'zoladex', 'femara', 'oncologist', 'chemotherapy', 'leukemia', 'sutent', 'alfuzosin', 'carcinoma', 'arimidex', 'casodex', 'psa', 'avastin']
vectors:  (19856, 200)
query:  (200,)
word: disease .. neighbors: ['disease', 'degenerative', 'crohn', 'autoimmune', 'lyme', 'obstructive', 'herniated', 'diagnose', 'fistula', 'classify', 'lumbar', 'hashimotos', 'disc', 'remicade', 'currently', 'inflammatory', 'ankylose', 'kapidex', 'herniation', 'leukemia']
####################################################################################################
*************************************************

In [21]:

with open(ckpt, 'rb') as f:
    model = torch.load(f)
model = model.to(device)
model.eval()

with torch.no_grad():
    ## get document completion perplexities
    test_ppl = evaluate(model, 'test', tc=tc, td=td)

    ## get most used topics
    indices = torch.tensor(range(num_docs_train))
    indices = torch.split(indices, batch_size)
    thetaAvg = torch.zeros(1, num_topics).to(device)
    thetaWeightedAvg = torch.zeros(1, num_topics).to(device)
    cnt = 0
    for idx, ind in enumerate(indices):
        try:
            data_batch = data.get_batch(train_tokens, train_counts, ind, vocab_size, device)
            sums = data_batch.sum(1).unsqueeze(1)
            cnt += sums.sum(0).squeeze().cpu().numpy()
            if bow_norm:
                normalized_data_batch = data_batch / sums
            else:
                normalized_data_batch = data_batch
            theta, _ = model.get_theta(normalized_data_batch)
            thetaAvg += theta.sum(0).unsqueeze(0) / num_docs_train
            weighed_theta = sums * theta
            thetaWeightedAvg += weighed_theta.sum(0).unsqueeze(0)
            if idx % 100 == 0 and idx > 0:
                print('batch: {}/{}'.format(idx, len(indices)))
        except IndexError:
            continue
    thetaWeightedAvg = thetaWeightedAvg.squeeze().cpu().numpy() / cnt
    print('\nThe 10 most used topics are {}'.format(thetaWeightedAvg.argsort()[::-1][:10]))

    ## show topics
    beta = model.get_beta()
    topic_indices = list(np.random.choice(num_topics, 10)) # 10 random topics
    print('\n')
    for k in range(num_topics):#topic_indices:
        gamma = beta[k]
        top_words = list(gamma.cpu().numpy().argsort()[-num_words+1:][::-1])
        topic_words = [vocab[a] for a in top_words]
        print('Topic {}: {}'.format(k, topic_words))

    if train_embeddings:
        ## show etm embeddings 
        try:
            rho_etm = model.rho.weight.cpu()
        except:
            rho_etm = model.rho.cpu()
        queries = ['andrew', 'woman', 'computer', 'sports', 'religion', 'man', 'love', 
                        'intelligence', 'money', 'politics', 'health', 'people', 'family']
        print('\n')
        print('ETM embeddings...')
        for word in queries:
            print('word: {} .. etm neighbors: {}'.format(word, nearest_neighbors(word, rho_etm, vocab)))
        print('\n')


****************************************************************************************************
TEST Doc Completion PPL: 876.8
****************************************************************************************************
batch: 100/135

The 10 most used topics are [ 6  4  9  5 14 12  2 10  7 11]


Topic 0: ['quotdoctor', 'postmeal', 'fogwas', 'yearstried', 'upscale', 'chillstripled', 'laydown', 'wellbutrine', 'valiumis']
Topic 1: ['postmeal', 'quotdoctor', 'yearstried', 'upscale', 'chillstripled', 'laydown', 'wellbutrine', 'fogwas', 'quotmyquot']
Topic 2: ['quotdoctor', 'chillstripled', 'postmeal', 'upscale', 'wellbutrine', 'laydown', 'yearstried', 'fogwas', 'quotmyquot']
Topic 3: ['quotdoctor', 'postmeal', 'chillstripled', 'wellbutrine', 'upscale', 'yearstried', 'fogwas', 'laydown', 'quotmyquot']
Topic 4: ['quotdoctor', 'postmeal', 'upscale', 'laydown', 'chillstripled', 'wellbutrine', 'fogwas', 'yearstried', 'quotmyquot']
Topic 5: ['postmeal', 'upscale', 'quotdoctor', 'lay

In [None]:
# !python main.py --mode train --dataset 20ng --data_path data/20ng --num_topics 50 --train_embeddings 1 --epochs 1000
