In [9]:
import torch
import pickle 
import math 
import random 
import sys
import matplotlib.pyplot as plt 
import data
import scipy.io

from torch import nn, optim
from torch.nn import functional as F

from etm import ETM
from utils import nearest_neighbors, get_topic_coherence, get_topic_diversity

from easydict import EasyDict as edict
import os 
import numpy as np
import pandas as pd

ModuleNotFoundError: No module named 'data'

In [None]:
# Notebook-wide configuration. `__C` and `args` are two names for the same
# EasyDict, so every setting below is reachable as `args.<name>`.
__C = edict({
    # --- data and file related arguments ---
    'dataset': 'diag',
    'data_path': './scripts/diag_dim60',
    'emb_path': '../diag_embedding.txt',
    'save_path': '../results',
    'batch_size': 1000,

    # --- model-related arguments ---
    'num_topics': 10,
    'rho_size': 60,
    'emb_size': 60,
    't_hidden_size': 800,
    'theta_act': 'relu',
    'train_embeddings': 1,

    # --- optimization-related arguments ---
    'lr': 0.005,
    'lr_factor': 4.0,
    'epochs': 100,
    'mode': 'eval',  # train or eval
    'optimizer': 'adam',
    'seed': 2019,
    'enc_drop': 0.0,
    'clip': 0.0,
    'nonmono': 10,
    'wdecay': 1.2e-6,
    'anneal_lr': 0,
    'bow_norm': 1,

    # --- evaluation, visualization, and logging-related arguments ---
    'num_words': 10,  # number of words for topic viz
    'log_interval': 40,
    'visualize_every': 10,
    'eval_batch_size': 1000,
    # `load_from` is deliberately not set by this cell; example value kept for reference:
    # './results/etm_diag_K_40_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_0'
    'tc': 1,
    'td': 1,
})
args = __C

In [5]:
# Show every entry of the current working directory whose name starts with 'Dec'.
dec_entries = [entry for entry in os.listdir() if entry.startswith('Dec')]
for entry in dec_entries:
    print(entry)

Dec17_etm_diag_K_25_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_1
Dec17_etm_diag_K_35_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_1
Dec17_etm_diag_K_45_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_1
Dec17_etm_diag_K_10_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_1
Dec17_etm_diag_K_20_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_1
Dec17_etm_diag_K_5_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_1
Dec17_etm_diag_K_15_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_1
Dec17_etm_diag_K_50_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbeddings_1
Dec17_etm_diag_K_30_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_60_trainEmbedd

In [10]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('\n')

# Seed numpy and torch from the configured seed; the CUDA generator is seeded
# only when the selected device is a GPU.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if device.type == "cuda":
    torch.cuda.manual_seed(args.seed)

## get data
# Vocabulary plus the train / validation / test splits.
vocab, train, valid, test = data.get_data(os.path.join(args.data_path))
args.vocab_size = vocab_size = len(vocab)

# Training split.
train_tokens, train_counts = train['tokens'], train['counts']
args.num_docs_train = len(train_tokens)

# Validation (dev) split.
valid_tokens, valid_counts = valid['tokens'], valid['counts']
args.num_docs_valid = len(valid_tokens)

# Test split: the full documents, then the '_1' and '_2' variants that
# evaluate() reads as the first and second half of each test document.
test_tokens, test_counts = test['tokens'], test['counts']
args.num_docs_test = len(test_tokens)

test_1_tokens, test_1_counts = test['tokens_1'], test['counts_1']
args.num_docs_test_1 = len(test_1_tokens)

test_2_tokens, test_2_counts = test['tokens_2'], test['counts_2']
args.num_docs_test_2 = len(test_2_tokens)






NameError: name 'args' is not defined

In [None]:
# Path of the trained checkpoint to evaluate. `load_from` may be missing from
# `args`, so fail here with an actionable message rather than with a bare
# AttributeError (or an obscure open() failure in a later cell).
ckpt = args.get('load_from', '')
if not ckpt:
    raise ValueError(
        "args.load_from is not set; point it at a saved ETM checkpoint "
        "before running the evaluation cells."
    )


In [8]:
def evaluate(m, source, tc=False, td=False):
    """Compute document-completion perplexity on the held-out test halves.

    Topic proportions are inferred from the ``test_1_*`` bag-of-words (first
    half of each test document) and used to score the ``test_2_*``
    bag-of-words (second half). The per-batch mean of the token-normalised
    negative log-likelihood is averaged over batches and exponentiated.

    Args:
        m: model exposing ``eval()``, ``get_beta()`` and ``get_theta(bows)``.
        source: label used only in the printed report (upper-cased). The
            ``test_1`` / ``test_2`` halves are evaluated whatever its value.
        tc: when truthy, also call ``get_topic_coherence`` on the topics.
        td: when truthy, also call ``get_topic_diversity`` (with 25 as its
            second argument).

    Returns:
        The document-completion perplexity rounded to one decimal place.

    Relies on the notebook globals ``args``, ``data``, ``device``, ``vocab``,
    ``train_tokens`` and the ``test_1_*`` / ``test_2_*`` arrays.
    """
    m.eval()
    with torch.no_grad():
        ## get \beta here
        beta = m.get_beta()

        ### do dc and tc here
        acc_loss = 0
        cnt = 0
        indices_1 = torch.split(torch.arange(args.num_docs_test_1), args.eval_batch_size)
        for ind in indices_1:
            ## get theta from first half of docs
            data_batch_1 = data.get_batch(test_1_tokens, test_1_counts, ind, args.vocab_size, device)
            sums_1 = data_batch_1.sum(1).unsqueeze(1)
            if args.bow_norm:
                normalized_data_batch_1 = data_batch_1 / sums_1
            else:
                normalized_data_batch_1 = data_batch_1
            theta, _ = m.get_theta(normalized_data_batch_1)

            ## get prediction loss using second half
            data_batch_2 = data.get_batch(test_2_tokens, test_2_counts, ind, args.vocab_size, device)
            sums_2 = data_batch_2.sum(1)  # token count of each second-half document
            res = torch.mm(theta, beta)
            preds = torch.log(res)
            recon_loss = -(preds * data_batch_2).sum(1)

            # Normalise each document's loss by its length, then average the batch.
            loss = (recon_loss / sums_2).mean().item()
            acc_loss += loss
            cnt += 1
        cur_loss = acc_loss / cnt
        ppl_dc = round(math.exp(cur_loss), 1)
        print('*'*100)
        print('{} Doc Completion PPL: {}'.format(source.upper(), ppl_dc))
        print('*'*100)
        if tc or td:
            # The topic-quality helpers receive beta as a numpy array.
            beta = beta.detach().cpu().numpy()
            if tc:
                print('Computing topic coherence...')
                get_topic_coherence(beta, train_tokens, vocab)
            if td:
                print('Computing topic diversity...')
                get_topic_diversity(beta, 25)
        return ppl_dc


/home/jupyter/ETM/results


In [None]:
# Load the trained model and report: document-completion perplexity, the most
# used topics on the training set, the top words of every topic and, when the
# embeddings were trained, the nearest neighbours of a few query tokens.
#
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints you trust. Recent PyTorch releases default to
# weights_only=True, which rejects fully pickled models; pass
# weights_only=False there if this call fails for that reason.
with open(ckpt, 'rb') as f:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model = torch.load(f, map_location=device)
model = model.to(device)
model.eval()

with torch.no_grad():
    ## get document completion perplexities
    test_ppl = evaluate(model, 'test', tc=args.tc, td=args.td)

    ## get most used topics
    indices = torch.split(torch.arange(args.num_docs_train), args.batch_size)
    thetaAvg = torch.zeros(1, args.num_topics).to(device)
    thetaWeightedAvg = torch.zeros(1, args.num_topics).to(device)
    cnt = 0  # running total of token counts over the training documents
    for idx, ind in enumerate(indices):
        data_batch = data.get_batch(train_tokens, train_counts, ind, args.vocab_size, device)
        sums = data_batch.sum(1).unsqueeze(1)
        cnt += sums.sum(0).squeeze().cpu().numpy()
        if args.bow_norm:
            normalized_data_batch = data_batch / sums
        else:
            normalized_data_batch = data_batch
        theta, _ = model.get_theta(normalized_data_batch)
        thetaAvg += theta.sum(0).unsqueeze(0) / args.num_docs_train
        # Weight each document's topic proportions by its token count.
        weighed_theta = sums * theta
        thetaWeightedAvg += weighed_theta.sum(0).unsqueeze(0)
        if idx % 100 == 0 and idx > 0:
            print('batch: {}/{}'.format(idx, len(indices)))
    thetaWeightedAvg = thetaWeightedAvg.squeeze().cpu().numpy() / cnt
    print('\nThe 10 most used topics are {}'.format(thetaWeightedAvg.argsort()[::-1][:10]))

    ## show topics
    beta = model.get_beta()
    print('\n')
    for k in range(args.num_topics):
        gamma = beta[k]
        # Take exactly `num_words` highest-weight vocabulary indices, best first.
        top_words = list(gamma.cpu().numpy().argsort()[-args.num_words:][::-1])
        topic_words = [vocab[a] for a in top_words]
        print('Topic {}: {}'.format(k, topic_words))

    if args.train_embeddings:
        ## show etm embeddings
        # `rho` either exposes its matrix as `.weight` or is the matrix itself.
        try:
            rho_etm = model.rho.weight.cpu()
        except AttributeError:
            rho_etm = model.rho.cpu()
        queries = ['5855']
        print('\n')
        print('ETM embeddings...')
        for word in queries:
            print('word: {} .. etm neighbors: {}'.format(word, nearest_neighbors(word, rho_etm, vocab)))
        print('\n')