## Create Required Embeddings

Note: This section can be skipped if embeddings are already prepared

In [1]:
import gensim
import pickle
import os
import numpy as np
import pandas as pd
import sys
# sys.setdefaultencoding() does not exist, here!
# reload(sys)  # Reload does the trick!
# sys.setdefaultencoding('UTF8')

# sys.setdefaultencoding('UTF8')

data_file= "data/drug_review/drugsComTrain_raw.tsv"        # default='', help='a .txt file containing the corpus'
emb_file= "embeddings/embeddings.txt"                      #default='embeddings.txt', help='file to save the word embeddings'
dim_rho= 300                                               #default=300, help='dimensionality of the word embeddings'
min_count= 2                                               #default=2, help='minimum term frequency (to define the vocabulary)'
sg= 1                                                      # default=1, help='whether to use skip-gram'
workers= 6                                                 #default=25, help='number of CPU cores'
negative_samples= 10                                       # default=10, help='number of negative samples'
window_size= 4                                             # default=4, help='window size to determine context'
iters= 50                                                  #default=50, help='number of iterationst'



In [2]:

# Class for a memory-friendly iterator over the dataset
class MySentences(object):
    def __init__(self, filename):
        self.filename = filename
        self.file_type = 'text'
 
    def __iter__(self):
        if self.file_type == 'text':
            for line in open(self.filename,encoding="utf8"):
                yield line.split()
        elif self.file_type == 'csv':
            for line in self.reviews.values:
                yield line.split()
                
    def __init__(self, filename,col,delimiter = "\t"):
        self.filename = filename
        data = pd.read_csv(filename,delimiter=delimiter)
        self.reviews = data[col][:1000]
        self.file_type = 'csv'


In [3]:
sentences = MySentences(data_file,"review") # a memory-friendly iterator
model = gensim.models.Word2Vec(sentences, min_count=min_count, sg=sg, size=dim_rho, 
    iter=iters, workers=workers, negative=negative_samples, window=window_size)

In [4]:


# Write the embeddings to a file
with open(emb_file, 'w') as f:
    for v in list(model.wv.vocab):
        vec = list(model.wv.__getitem__(v))
        f.write(v + ' ')
        vec_str = ['%.9f' % val for val in vec]
        vec_str = " ".join(vec_str)
        f.write(vec_str + '\n')

## Data Preprocessing

In [5]:
# import pickle

# abc = pickle.load("data/20ng/vocab.pkl")
# abc

## Modelling


In [1]:

from __future__ import print_function

import torch
import pickle 
import numpy as np 
import os 
import math 
import random 
import sys

import matplotlib.pyplot as plt 
import data
import scipy.io

from torch import nn, optim
from torch.nn import functional as F

from etm import ETM
from utils import nearest_neighbors, get_topic_coherence, get_topic_diversity


In [2]:
# df.head()

In [3]:
# df = pd.read_csv("data/drug_review/drugsComTrain_raw.tsv",delimiter="\t")[:1000]
# df.to_csv("data/drug_review/drugs_train_1000.csv",index=None)
# reviews = df.review
# with open("train_file.txt", 'w',encoding='utf8') as f:
#     for review in reviews.values:
#         f.write(review + '\n')

In [4]:
# dataset = "20ng"

dataset =   "train_file.txt"                      #default='20ng', help='name of corpus'
data_path = 'data/drug_review/'#default='data/20ng', help='directory containing data'
emb_path = 'embeddings/embeddings.txt'#default='data/20ng_embeddings.txt', help='directory containing word embeddings'
save_path = './results'#default='./results', help='path to save results'
batch_size = 100 #default=1000, help='input batch size for training'

### model-related arguments
num_topics = 25   #default=50, help='number of topics'
rho_size = 300    #default=300, help='dimension of rho'
emb_size = 300    #default=300, help='dimension of embeddings'
t_hidden_size = 800 #default=800, help='dimension of hidden space of q(theta)'
theta_act = 'relu' #default='relu', help='tanh, softplus, relu, rrelu, leakyrelu, elu, selu, glu)'
train_embeddings = 0 #default=0, help='whether to fix rho or train it'

### optimization-related arguments
lr = 0.05 # default=0.005, help='learning rate'
lr_factor =4.0  #default=4.0, help='divide learning rate by this...'
epochs = 20 # default=20, help='number of epochs to train...150 for 20ng 100 for others'
mode = 'train'# default='train', help='train or eval model'
optimizer = 'adam'# default='adam', help='choice of optimizer'
seed = 2019# default=2019, help='random seed (default: 1)
enc_drop = 0.0# default=0.0, help='dropout rate on encoder'
clip = 0.0# default=0.0, help='gradient clipping'
nonmono = 10# default=10, help='number of bad hits allowed'
wdecay = 1.2e-6# default=1.2e-6, help='some l2 regularization'
anneal_lr = 0#  default=0, help='whether to anneal the learning rate or not'
bow_norm = 1# default=1, help='normalize the bows or not'

### evaluation, visualization, and logging-related arguments
num_words = 10  # default=10, help='number of words for topic viz' 
log_interval = 2 # default=2, help='when to log training'
visualize_every = 1 # default=10, help='when to visualize results'
eval_batch_size = 1000 # default=1000, help='input batch size for evaluation'
load_from = 'results/etm_20ng_K_50_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_1'# default='', help='the name of the ckpt to eval from'
tc = 0# default=0, help='whether to compute topic coherence or not'
td = 0# default=0, help='whether to compute topic diversity or not'


In [5]:
device = torch.device("cpu")#("cuda" if torch.cuda.is_available() else "cpu")

print('\n')
np.random.seed(seed)
torch.manual_seed(seed)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed(seed)





<torch._C.Generator at 0x20b1dc03b10>

In [6]:
vocab, train, valid, test = data.get_data(os.path.join(data_path))
vocab_size = len(vocab)

# 1. training data
train_tokens = train['tokens']
train_counts = train['counts']
num_docs_train = len(train_tokens)

# 2. dev set
valid_tokens = valid['tokens']
valid_counts = valid['counts']
num_docs_valid = len(valid_tokens)

# 3. test data
test_tokens = test['tokens']
test_counts = test['counts']
num_docs_test = len(test_tokens)
# test_1_tokens = test['tokens_1']
# test_1_counts = test['counts_1']
# num_docs_test_1 = len(test_1_tokens)
# test_2_tokens = test['tokens_2']
# test_2_counts = test['counts_2']
# num_docs_test_2 = len(test_2_tokens)

embeddings = None

In [7]:
#/usr/bin/python
## get data
# 1. vocabulary

if not train_embeddings:
    emb_path = emb_path
    vect_path = os.path.join(data_path.split('/')[0], 'vocab.pkl')   
    vectors = {}
    with open(emb_path, 'rb') as f:
        for l in f:
            line = l.decode().split()
            word = line[0]
            if word in vocab:
                vect = np.array(line[1:]).astype(np.float)
                vectors[word] = vect
    embeddings = np.zeros((vocab_size, emb_size))
    words_found = 0
    for i, word in enumerate(vocab):
        try: 
            embeddings[i] = vectors[word]
            words_found += 1
        except KeyError:
            embeddings[i] = np.random.normal(scale=0.6, size=(emb_size, ))
    embeddings = torch.tensor(embeddings).to(device)
    embeddings_dim = embeddings.size()

print('=*'*100)
# print('Training an Embedded Topic Model on {} with the following settings: {}'.format(dataset.upper()))
print('=*'*100)


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


In [8]:
emb_path

'embeddings/embeddings.txt'

In [9]:

## define checkpoint
if not os.path.exists(save_path):
    os.makedirs(save_path)

if mode == 'eval':
    ckpt = load_from
else:
    ckpt = os.path.join(save_path, 
        'etm_{}_K_{}_Htheta_{}_Optim_{}_Clip_{}_ThetaAct_{}_Lr_{}_Bsz_{}_RhoSize_{}_trainEmbeddings_{}'.format(
        dataset, num_topics, t_hidden_size, optimizer, clip, theta_act, 
            lr, batch_size, rho_size, train_embeddings))

## define model and optimizer
model = ETM(num_topics, vocab_size, t_hidden_size, rho_size, emb_size, 
                theta_act, embeddings, train_embeddings, enc_drop).to(device)

print('model: {}'.format(model))

if optimizer == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wdecay)
elif optimizer == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=lr, weight_decay=wdecay)
elif optimizer == 'adadelta':
    optimizer = optim.Adadelta(model.parameters(), lr=lr, weight_decay=wdecay)
elif optimizer == 'rmsprop':
    optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=wdecay)
elif optimizer == 'asgd':
    optimizer = optim.ASGD(model.parameters(), lr=lr, t0=0, lambd=0., weight_decay=wdecay)
else:
    print('Defaulting to vanilla SGD')
    optimizer = optim.SGD(model.parameters(), lr=lr)

model: ETM(
  (t_drop): Dropout(p=0.0)
  (theta_act): ReLU()
  (alphas): Linear(in_features=300, out_features=25, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=87165, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=25, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=25, bias=True)
)


In [10]:
def train(epoch):
    model.train()
    acc_loss = 0
    acc_kl_theta_loss = 0
    cnt = 0
    indices = torch.randperm(num_docs_train)
    indices = torch.split(indices, batch_size)
    for idx, ind in enumerate(indices):
        optimizer.zero_grad()
        model.zero_grad()
        data_batch = data.get_batch(train_tokens, train_counts, ind, vocab_size, device)
        sums = data_batch.sum(1).unsqueeze(1)
        if bow_norm:
            normalized_data_batch = data_batch / sums
        else:
            normalized_data_batch = data_batch
        recon_loss, kld_theta = model(data_batch, normalized_data_batch)
        total_loss = recon_loss + kld_theta
        total_loss.backward()

        if clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        acc_loss += torch.sum(recon_loss).item()
        acc_kl_theta_loss += torch.sum(kld_theta).item()
        cnt += 1

        if idx % log_interval == 0 and idx > 0:
            cur_loss = round(acc_loss / cnt, 2) 
            cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 
            cur_real_loss = round(cur_loss + cur_kl_theta, 2)

            print('Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} .. Rec_loss: {} .. NELBO: {}'.format(
                epoch, idx, len(indices), optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss))
    
    cur_loss = round(acc_loss / cnt, 2) 
    cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 
    cur_real_loss = round(cur_loss + cur_kl_theta, 2)
    print('*'*100)
    print('Epoch----->{} .. LR: {} .. KL_theta: {} .. Rec_loss: {} .. NELBO: {}'.format(
            epoch, optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss))
    print('*'*100)

In [11]:
def visualize(m, show_emb=True):
    if not os.path.exists('./results'):
        os.makedirs('./results')

    m.eval()

    queries = ['skin','cycle','effects','price','worst','best','neuropsychologist','efficacy','performance','cancer']

    ## visualize topics using monte carlo
    with torch.no_grad():
        print('#'*100)
        print('Visualize topics...')
        topics_words = []
        gammas = m.get_beta()
        for k in range(num_topics):
            gamma = gammas[k]
            top_words = list(gamma.cpu().numpy().argsort()[-num_words+1:][::-1])
            topic_words = [vocab[a] for a in top_words]
            topics_words.append(' '.join(topic_words))
            print('Topic {}: {}'.format(k, topic_words))

        if show_emb:
            ## visualize word embeddings by using V to get nearest neighbors
            print('#'*100)
            print('Visualize word embeddings by using output embedding matrix')
            try:
                embeddings = m.rho.weight  # Vocab_size x E
            except:
                embeddings = m.rho         # Vocab_size x E
            neighbors = []
            for word in queries:
                print('word: {} .. neighbors: {}'.format(
                    word, nearest_neighbors(word, embeddings, vocab)))
            print('#'*100)

In [12]:
def evaluate(m, source, tc=False, td=False):
    """Compute perplexity on document completion.
    """
    m.eval()
    with torch.no_grad():
        if source == 'val':
            indices = torch.split(torch.tensor(range(num_docs_valid)), eval_batch_size)
            tokens = valid_tokens
            counts = valid_counts
        else: 
            indices = torch.split(torch.tensor(range(num_docs_test)), eval_batch_size)
            tokens = test_tokens
            counts = test_counts

        ## get \beta here
        beta = m.get_beta()

        ### do dc and tc here
        acc_loss = 0
        cnt = 0
        indices_1 = torch.split(torch.tensor(range(num_docs_test_1)), eval_batch_size)
        for idx, ind in enumerate(indices_1):
            ## get theta from first half of docs
            data_batch_1 = data.get_batch(test_1_tokens, test_1_counts, ind, vocab_size, device)
            sums_1 = data_batch_1.sum(1).unsqueeze(1)
            if bow_norm:
                normalized_data_batch_1 = data_batch_1 / sums_1
            else:
                normalized_data_batch_1 = data_batch_1
            theta, _ = m.get_theta(normalized_data_batch_1)

            ## get prediction loss using second half
            data_batch_2 = data.get_batch(test_2_tokens, test_2_counts, ind, vocab_size, device)
            sums_2 = data_batch_2.sum(1).unsqueeze(1)
            res = torch.mm(theta, beta)
            preds = torch.log(res)
            recon_loss = -(preds * data_batch_2).sum(1)
            
            loss = recon_loss / sums_2.squeeze()
            loss = loss.mean().item()
            acc_loss += loss
            cnt += 1
        cur_loss = acc_loss / cnt
        ppl_dc = round(math.exp(cur_loss), 1)
        print('*'*100)
        print('{} Doc Completion PPL: {}'.format(source.upper(), ppl_dc))
        print('*'*100)
        if tc or td:
            beta = beta.data.cpu().numpy()
            if tc:
                print('Computing topic coherence...')
                get_topic_coherence(beta, train_tokens, vocab)
            if td:
                print('Computing topic diversity...')
                get_topic_diversity(beta, 25)
        return ppl_dc


In [None]:

if mode == 'train':
    ## train model on data 
    best_epoch = 0
    best_val_ppl = 1e9
    all_val_ppls = []
    print('\n')
    print('Visualizing model quality before training...')
    visualize(model)
    print('\n')
    for epoch in range(1, epochs):
        train(epoch)
        val_ppl = evaluate(model, 'val')
        if val_ppl < best_val_ppl:
            with open(ckpt, 'wb') as f:
                torch.save(model, f)
            best_epoch = epoch
            best_val_ppl = val_ppl
        else:
            ## check whether to anneal lr
            lr = optimizer.param_groups[0]['lr']
            if anneal_lr and (len(all_val_ppls) > nonmono and val_ppl > min(all_val_ppls[:-nonmono]) and lr > 1e-5):
                optimizer.param_groups[0]['lr'] /= lr_factor
        if epoch % visualize_every == 0:
            visualize(model)
        all_val_ppls.append(val_ppl)
    with open(ckpt, 'rb') as f:
        model = torch.load(f)
    model = model.to(device)
    val_ppl = evaluate(model, 'val')
else:   
    with open(ckpt, 'rb') as f:
        model = torch.load(f)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        ## get document completion perplexities
        test_ppl = evaluate(model, 'test', tc=tc, td=td)

        ## get most used topics
        indices = torch.tensor(range(num_docs_train))
        indices = torch.split(indices, batch_size)
        thetaAvg = torch.zeros(1, num_topics).to(device)
        thetaWeightedAvg = torch.zeros(1, num_topics).to(device)
        cnt = 0
        for idx, ind in enumerate(indices):
            try:
                data_batch = data.get_batch(train_tokens, train_counts, ind, vocab_size, device)
                sums = data_batch.sum(1).unsqueeze(1)
                cnt += sums.sum(0).squeeze().cpu().numpy()
                if bow_norm:
                    normalized_data_batch = data_batch / sums
                else:
                    normalized_data_batch = data_batch
                theta, _ = model.get_theta(normalized_data_batch)
                thetaAvg += theta.sum(0).unsqueeze(0) / num_docs_train
                weighed_theta = sums * theta
                thetaWeightedAvg += weighed_theta.sum(0).unsqueeze(0)
                if idx % 100 == 0 and idx > 0:
                    print('batch: {}/{}'.format(idx, len(indices)))
            except IndexError:
                continue
        thetaWeightedAvg = thetaWeightedAvg.squeeze().cpu().numpy() / cnt
        print('\nThe 10 most used topics are {}'.format(thetaWeightedAvg.argsort()[::-1][:10]))

        ## show topics
        beta = model.get_beta()
        topic_indices = list(np.random.choice(num_topics, 10)) # 10 random topics
        print('\n')
        for k in range(num_topics):#topic_indices:
            gamma = beta[k]
            top_words = list(gamma.cpu().numpy().argsort()[-num_words+1:][::-1])
            topic_words = [vocab[a] for a in top_words]
            print('Topic {}: {}'.format(k, topic_words))

        if train_embeddings:
            ## show etm embeddings 
            try:
                rho_etm = model.rho.weight.cpu()
            except:
                rho_etm = model.rho.cpu()
            queries = ['andrew', 'woman', 'computer', 'sports', 'religion', 'man', 'love', 
                            'intelligence', 'money', 'politics', 'health', 'people', 'family']
            print('\n')
            print('ETM embeddings...')
            for word in queries:
                print('word: {} .. etm neighbors: {}'.format(word, nearest_neighbors(word, rho_etm, vocab)))
            print('\n')




Visualizing model quality before training...
####################################################################################################
Visualize topics...
Topic 0: ['dancequotany', 'reviewxx', 'touching', 'implore', 'beggars', 'withdrawalsand', 'bequotmorequot', 'becaus', 'stigmatizm']
Topic 1: ['flarei', 'hypoxmia', 'yyaaaa', 'janumetx', 'arthroc', 'unwind', 'someweird', 'primitive', 'milkywhite']
Topic 2: ['relearned', 'regulari', 'cod', 'invegai', 'quotenduringquot', 'periodspotting', 'sleepynessauditory', 'lsquojollyrsquo', 'dizzinessdid']
Topic 3: ['quotunderwaterquot', 'cruch', 'quotwillpower', 'benzoids', 'inhe', 'laundry', 'scapula', 'finances', 'contolling']
Topic 4: ['genericpregnancy', 'tabletsmg', 'jesting', 'quotzilchesquot', 'maralthia', 'sciatica', 'wellaches', 'quotserenquot', 'venlafazxine']
Topic 5: ['craigadkinslipozene', 'lightskinned', 'workincreased', 'selfcritical', 'neuropsychologist', 'advert', 'optomistic', 'chancequot', 'reheats']
Topic 6: ['dist

Epoch: 1 .. batch: 34/1452 .. LR: 0.05 .. KL_theta: 677.27 .. Rec_loss: 677.13 .. NELBO: 1354.4
Epoch: 1 .. batch: 36/1452 .. LR: 0.05 .. KL_theta: 640.8 .. Rec_loss: 673.0 .. NELBO: 1313.8
Epoch: 1 .. batch: 38/1452 .. LR: 0.05 .. KL_theta: 608.16 .. Rec_loss: 666.82 .. NELBO: 1274.98
Epoch: 1 .. batch: 40/1452 .. LR: 0.05 .. KL_theta: 578.64 .. Rec_loss: 662.3 .. NELBO: 1240.94
Epoch: 1 .. batch: 42/1452 .. LR: 0.05 .. KL_theta: 551.99 .. Rec_loss: 660.37 .. NELBO: 1212.36
Epoch: 1 .. batch: 44/1452 .. LR: 0.05 .. KL_theta: 527.61 .. Rec_loss: 655.88 .. NELBO: 1183.49
Epoch: 1 .. batch: 46/1452 .. LR: 0.05 .. KL_theta: 505.35 .. Rec_loss: 651.2 .. NELBO: 1156.55
Epoch: 1 .. batch: 48/1452 .. LR: 0.05 .. KL_theta: 484.89 .. Rec_loss: 646.62 .. NELBO: 1131.51
Epoch: 1 .. batch: 50/1452 .. LR: 0.05 .. KL_theta: 466.05 .. Rec_loss: 643.35 .. NELBO: 1109.4
Epoch: 1 .. batch: 52/1452 .. LR: 0.05 .. KL_theta: 448.64 .. Rec_loss: 640.35 .. NELBO: 1088.99
Epoch: 1 .. batch: 54/1452 .. LR: 0.0

Epoch: 1 .. batch: 202/1452 .. LR: 0.05 .. KL_theta: 121.05 .. Rec_loss: 574.59 .. NELBO: 695.64
Epoch: 1 .. batch: 204/1452 .. LR: 0.05 .. KL_theta: 119.92 .. Rec_loss: 574.06 .. NELBO: 693.98
Epoch: 1 .. batch: 206/1452 .. LR: 0.05 .. KL_theta: 118.82 .. Rec_loss: 573.63 .. NELBO: 692.45
Epoch: 1 .. batch: 208/1452 .. LR: 0.05 .. KL_theta: 117.74 .. Rec_loss: 573.04 .. NELBO: 690.78
Epoch: 1 .. batch: 210/1452 .. LR: 0.05 .. KL_theta: 116.67 .. Rec_loss: 572.9 .. NELBO: 689.57
Epoch: 1 .. batch: 212/1452 .. LR: 0.05 .. KL_theta: 115.62 .. Rec_loss: 572.78 .. NELBO: 688.4
Epoch: 1 .. batch: 214/1452 .. LR: 0.05 .. KL_theta: 114.61 .. Rec_loss: 572.3 .. NELBO: 686.91
Epoch: 1 .. batch: 216/1452 .. LR: 0.05 .. KL_theta: 113.61 .. Rec_loss: 572.02 .. NELBO: 685.63
Epoch: 1 .. batch: 218/1452 .. LR: 0.05 .. KL_theta: 112.62 .. Rec_loss: 572.03 .. NELBO: 684.65
Epoch: 1 .. batch: 220/1452 .. LR: 0.05 .. KL_theta: 111.64 .. Rec_loss: 571.33 .. NELBO: 682.97
Epoch: 1 .. batch: 222/1452 .. LR

Epoch: 1 .. batch: 372/1452 .. LR: 0.05 .. KL_theta: 68.54 .. Rec_loss: 562.74 .. NELBO: 631.28
Epoch: 1 .. batch: 374/1452 .. LR: 0.05 .. KL_theta: 68.21 .. Rec_loss: 562.72 .. NELBO: 630.93
Epoch: 1 .. batch: 376/1452 .. LR: 0.05 .. KL_theta: 67.88 .. Rec_loss: 562.39 .. NELBO: 630.27
Epoch: 1 .. batch: 378/1452 .. LR: 0.05 .. KL_theta: 67.55 .. Rec_loss: 562.57 .. NELBO: 630.12
Epoch: 1 .. batch: 380/1452 .. LR: 0.05 .. KL_theta: 67.23 .. Rec_loss: 562.52 .. NELBO: 629.75
Epoch: 1 .. batch: 382/1452 .. LR: 0.05 .. KL_theta: 66.91 .. Rec_loss: 562.26 .. NELBO: 629.17
Epoch: 1 .. batch: 384/1452 .. LR: 0.05 .. KL_theta: 66.59 .. Rec_loss: 562.22 .. NELBO: 628.81
Epoch: 1 .. batch: 386/1452 .. LR: 0.05 .. KL_theta: 66.28 .. Rec_loss: 562.25 .. NELBO: 628.53
Epoch: 1 .. batch: 388/1452 .. LR: 0.05 .. KL_theta: 65.97 .. Rec_loss: 562.14 .. NELBO: 628.11
Epoch: 1 .. batch: 390/1452 .. LR: 0.05 .. KL_theta: 65.66 .. Rec_loss: 562.02 .. NELBO: 627.68
Epoch: 1 .. batch: 392/1452 .. LR: 0.05 

Epoch: 1 .. batch: 542/1452 .. LR: 0.05 .. KL_theta: 48.98 .. Rec_loss: 557.34 .. NELBO: 606.32
Epoch: 1 .. batch: 544/1452 .. LR: 0.05 .. KL_theta: 48.82 .. Rec_loss: 557.42 .. NELBO: 606.24
Epoch: 1 .. batch: 546/1452 .. LR: 0.05 .. KL_theta: 48.67 .. Rec_loss: 557.26 .. NELBO: 605.93
Epoch: 1 .. batch: 548/1452 .. LR: 0.05 .. KL_theta: 48.51 .. Rec_loss: 557.16 .. NELBO: 605.67
Epoch: 1 .. batch: 550/1452 .. LR: 0.05 .. KL_theta: 48.36 .. Rec_loss: 557.04 .. NELBO: 605.4
Epoch: 1 .. batch: 552/1452 .. LR: 0.05 .. KL_theta: 48.2 .. Rec_loss: 557.06 .. NELBO: 605.26
Epoch: 1 .. batch: 554/1452 .. LR: 0.05 .. KL_theta: 48.05 .. Rec_loss: 556.89 .. NELBO: 604.94
Epoch: 1 .. batch: 556/1452 .. LR: 0.05 .. KL_theta: 47.9 .. Rec_loss: 556.75 .. NELBO: 604.65
Epoch: 1 .. batch: 558/1452 .. LR: 0.05 .. KL_theta: 47.75 .. Rec_loss: 556.58 .. NELBO: 604.33
Epoch: 1 .. batch: 560/1452 .. LR: 0.05 .. KL_theta: 47.6 .. Rec_loss: 556.63 .. NELBO: 604.23
Epoch: 1 .. batch: 562/1452 .. LR: 0.05 .. K

In [None]:
# !python main.py --mode train --dataset 20ng --data_path data/20ng --num_topics 50 --train_embeddings 1 --epochs 1000
