In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import scipy

import torch
from torch import nn, optim
from torch.nn import functional as F

from src import data
from src.detm import DETM

from gensim.models import CoherenceModel
from gensim.corpora import Dictionary

from analysis_utils import get_detm_topics, topic_diversity

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

In [3]:
### data and file related arguments
arg_str = """
parser.add_argument('--dataset', type=str, default='un', help='name of corpus')
parser.add_argument('--data_path', type=str, default='un/', help='directory containing data')
parser.add_argument('--emb_path', type=str, default='skipgram/embeddings.txt', help='directory containing embeddings')
parser.add_argument('--save_path', type=str, default='./results', help='path to save results')
parser.add_argument('--batch_size', type=int, default=1000, help='number of documents in a batch for training')
parser.add_argument('--min_df', type=int, default=100, help='to get the right data..minimum document frequency')

### model-related arguments
parser.add_argument('--num_topics', type=int, default=50, help='number of topics')
parser.add_argument('--rho_size', type=int, default=300, help='dimension of rho')
parser.add_argument('--emb_size', type=int, default=300, help='dimension of embeddings')
parser.add_argument('--t_hidden_size', type=int, default=800, help='dimension of hidden space of q(theta)')
parser.add_argument('--theta_act', type=str, default='relu', help='tanh, softplus, relu, rrelu, leakyrelu, elu, selu, glu)')
parser.add_argument('--train_embeddings', type=int, default=1, help='whether to fix rho or train it')
parser.add_argument('--eta_nlayers', type=int, default=3, help='number of layers for eta')
parser.add_argument('--eta_hidden_size', type=int, default=200, help='number of hidden units for rnn')
parser.add_argument('--delta', type=float, default=0.005, help='prior variance')

### optimization-related arguments
parser.add_argument('--lr', type=float, default=0.005, help='learning rate')
parser.add_argument('--lr_factor', type=float, default=4.0, help='divide learning rate by this')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train')
parser.add_argument('--mode', type=str, default='train', help='train or eval model')
parser.add_argument('--optimizer', type=str, default='adam', help='choice of optimizer')
parser.add_argument('--seed', type=int, default=2019, help='random seed (default: 1)')
parser.add_argument('--enc_drop', type=float, default=0.0, help='dropout rate on encoder')
parser.add_argument('--eta_dropout', type=float, default=0.0, help='dropout rate on rnn for eta')
parser.add_argument('--clip', type=float, default=0.0, help='gradient clipping')
parser.add_argument('--nonmono', type=int, default=10, help='number of bad hits allowed')
parser.add_argument('--wdecay', type=float, default=1.2e-6, help='some l2 regularization')
parser.add_argument('--anneal_lr', type=int, default=0, help='whether to anneal the learning rate or not')
parser.add_argument('--bow_norm', type=int, default=1, help='normalize the bows or not')

### evaluation, visualization, and logging-related arguments
parser.add_argument('--num_words', type=int, default=20, help='number of words for topic viz')
parser.add_argument('--log_interval', type=int, default=10, help='when to log training')
parser.add_argument('--visualize_every', type=int, default=1, help='when to visualize results')
parser.add_argument('--eval_batch_size', type=int, default=1000, help='input batch size for evaluation')
parser.add_argument('--load_from', type=str, default='', help='the name of the ckpt to eval from')
parser.add_argument('--tc', type=int, default=0, help='whether to compute tc or not')
""".split('\n')

In [5]:
keys = [x.strip("parser.add_argument('").split(',')[0].strip('--').strip("'") for x in arg_str if (len(x) > 0) and (not x.startswith('#'))]
values = [x.strip("parser.add_argument('").split(',')[2].strip(" default=").strip("'") for x in arg_str if (len(x) > 0) and (not x.startswith('#'))]
tmp_dict = dict(zip(keys, values))

for k, v in tmp_dict.items():
    if v.isnumeric():
        tmp_dict[k] = int(v)
    elif ('.' in v) and (v[0].isnumeric()):
        tmp_dict[k] = float(v)    

args = AttrDict()
args.update(tmp_dict)

args.train_embeddings = 0
args.rho_size = 768
args.num_topics = 10
args.batch_size = 100

In [55]:
train_arr = np.load('test_data.npz', allow_pickle=True)

In [56]:
train_tokens = train_arr['train_tokens']
train_counts = train_arr['train_counts']
train_times = train_arr['train_times']
vocab = train_arr['vocab']
embeddings = train_arr['embeddings']

args.num_times = len(np.unique(train_times))
args.num_docs_train = len(train_tokens)
args.vocab_size = len(vocab)
args.num_words = 10

In [57]:
args.num_times

4

In [58]:
%%time
train_rnn_inp = data.get_rnn_input(train_tokens, train_counts, train_times, args.num_times, args.vocab_size, args.num_docs_train)

idx: 0/2
CPU times: user 3.83 s, sys: 471 ms, total: 4.3 s
Wall time: 383 ms


In [59]:
if not os.path.exists(args['save_path']):
    os.makedirs(args['save_path'])

In [60]:
if args.mode == 'eval':
    ckpt = args.load_from
else:
    ckpt = os.path.join(args.save_path, 
        'detm_{}_K_{}_Htheta_{}_Optim_{}_Clip_{}_ThetaAct_{}_Lr_{}_Bsz_{}_RhoSize_{}_L_{}_minDF_{}_trainEmbeddings_{}'.format(
        args.dataset, args.num_topics, args.t_hidden_size, args.optimizer, args.clip, args.theta_act, 
            args.lr, args.batch_size, args.rho_size, args.eta_nlayers, args.min_df, args.train_embeddings))

In [62]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embeddings = torch.from_numpy(embeddings).to(device)
args.embeddings_dim = embeddings.size()

model = DETM(args, embeddings)
model.to(device)

DETM(
  (t_drop): Dropout(p=0.0, inplace=False)
  (theta_act): ReLU()
  (q_theta): Sequential(
    (0): Linear(in_features=4948, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=10, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=10, bias=True)
  (q_eta_map): Linear(in_features=4938, out_features=200, bias=True)
  (q_eta): LSTM(200, 200, num_layers=3)
  (mu_q_eta): Linear(in_features=210, out_features=10, bias=True)
  (logsigma_q_eta): Linear(in_features=210, out_features=10, bias=True)
)

In [63]:
if args.optimizer == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
elif args.optimizer == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
elif args.optimizer == 'adadelta':
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
elif args.optimizer == 'rmsprop':
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
elif args.optimizer == 'asgd':
    optimizer = optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay)
else:
    print('Defaulting to vanilla SGD')
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

In [64]:
def train(epoch):
    """Train DETM on data for one epoch.
    """
    model.train()
    acc_loss = 0
    acc_nll = 0
    acc_kl_theta_loss = 0
    acc_kl_eta_loss = 0
    acc_kl_alpha_loss = 0
    cnt = 0
    indices = torch.randperm(args.num_docs_train)
    indices = torch.split(indices, args.batch_size) 
    for idx, ind in enumerate(indices):
        optimizer.zero_grad()
        model.zero_grad()
        data_batch, times_batch = data.get_batch(
            train_tokens, train_counts, ind, args.vocab_size, args.emb_size, temporal=True, times=train_times)
        sums = data_batch.sum(1).unsqueeze(1)
        if args.bow_norm:
            normalized_data_batch = data_batch / sums
        else:
            normalized_data_batch = data_batch

        loss, nll, kl_alpha, kl_eta, kl_theta = model(data_batch, normalized_data_batch, times_batch, train_rnn_inp, args.num_docs_train)
        loss.backward()
        if args.clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        acc_loss += torch.sum(loss).item()
        acc_nll += torch.sum(nll).item()
        acc_kl_theta_loss += torch.sum(kl_theta).item()
        acc_kl_eta_loss += torch.sum(kl_eta).item()
        acc_kl_alpha_loss += torch.sum(kl_alpha).item()
        cnt += 1

        if idx % args.log_interval == 0 and idx > 0:
            cur_loss = round(acc_loss / cnt, 2) 
            cur_nll = round(acc_nll / cnt, 2) 
            cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 
            cur_kl_eta = round(acc_kl_eta_loss / cnt, 2) 
            cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2) 
            lr = optimizer.param_groups[0]['lr']
            print('Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} .. KL_eta: {} .. KL_alpha: {} .. Rec_loss: {} .. NELBO: {}'.format(
                epoch, idx, len(indices), lr, cur_kl_theta, cur_kl_eta, cur_kl_alpha, cur_nll, cur_loss))
    
    cur_loss = round(acc_loss / cnt, 2) 
    cur_nll = round(acc_nll / cnt, 2) 
    cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 
    cur_kl_eta = round(acc_kl_eta_loss / cnt, 2) 
    cur_kl_alpha = round(acc_kl_alpha_loss / cnt, 2) 
    lr = optimizer.param_groups[0]['lr']
    print('*'*100)
    print('Epoch----->{} .. LR: {} .. KL_theta: {} .. KL_eta: {} .. KL_alpha: {} .. Rec_loss: {} .. NELBO: {}'.format(
            epoch, lr, cur_kl_theta, cur_kl_eta, cur_kl_alpha, cur_nll, cur_loss))
    print('*'*100)

In [65]:
%%time
## train model on data by looping through multiple epochs
best_epoch = 0
best_val_ppl = 1e9
all_val_ppls = []
for epoch in range(1, args.epochs):
    train(epoch)
    # if epoch % args.visualize_every == 0:
    #     visualize()
    # val_ppl = get_completion_ppl('val')
    # print('val_ppl: ', val_ppl)
    # if val_ppl < best_val_ppl:
    #     with open(ckpt, 'wb') as f:
    #         torch.save(model, f)
    #     best_epoch = epoch
    #     best_val_ppl = val_ppl
    # else:
    ## check whether to anneal lr
    lr = optimizer.param_groups[0]['lr']
    if args.anneal_lr and (len(all_val_ppls) > args.nonmono and val_ppl > min(all_val_ppls[:-args.nonmono]) and lr > 1e-5):
        optimizer.param_groups[0]['lr'] /= args.lr_factor
    #all_val_ppls.append(val_ppl)
model.eval()
with torch.no_grad():
    print('saving topic matrix beta...')
    alpha = model.mu_q_alpha
    beta = model.get_beta(alpha).cpu().numpy()
    scipy.io.savemat(ckpt+'_beta.mat', {'values': beta}, do_compression=True)
    if args.train_embeddings:
        print('saving word embedding matrix rho...')
        rho = model.rho.weight.cpu().numpy()
        scipy.io.savemat(ckpt+'_rho.mat', {'values': rho}, do_compression=True)
    # print('computing validation perplexity...')
    # val_ppl = get_completion_ppl('val')
    # print('computing test perplexity...')
    # test_ppl = get_completion_ppl('test')

Epoch: 1 .. batch: 10/17 .. LR: 0.005 .. KL_theta: 16055.03 .. KL_eta: 2699.23 .. KL_alpha: 11959263.09 .. Rec_loss: 19987842.91 .. NELBO: 31965860.55
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 18886.62 .. KL_eta: 2048.93 .. KL_alpha: 11799006.18 .. Rec_loss: 20247932.59 .. NELBO: 32067874.71
****************************************************************************************************
Epoch: 2 .. batch: 10/17 .. LR: 0.005 .. KL_theta: 21484.79 .. KL_eta: 503.92 .. KL_alpha: 11008494.0 .. Rec_loss: 20054444.91 .. NELBO: 31084928.0
****************************************************************************************************
Epoch----->2 .. LR: 0.005 .. KL_theta: 19715.92 .. KL_eta: 408.92 .. KL_alpha: 10860758.82 .. Rec_loss: 19759509.41 .. NELBO: 30640393.29
****************************************************************************************************
Epoch: 3 .. batch: 1

In [68]:
args.num_words = 10

In [69]:
with torch.no_grad():
    alpha = model.mu_q_alpha
    beta = model.get_beta(alpha) 
    print('beta: ', beta.size())
    print('\n')
    print('#'*100)
    print('Visualize topics...')
    times = [0, 2]
    topics_words = []
    for k in range(args.num_topics):
        for t in times:
            gamma = beta[k, t, :]
            top_words = list(gamma.cpu().numpy().argsort()[-args.num_words+1:][::-1])
            topic_words = [vocab[a] for a in top_words]
            topics_words.append(' '.join(topic_words))
            print('Topic {} .. Time: {} ===> {}'.format(k, t, topic_words)) 

beta:  torch.Size([10, 4, 4938])


####################################################################################################
Visualize topics...
Topic 0 .. Time: 0 ===> ['proactive', 'model', 'antidote', 'summarize', 'adaptive', 'criticize', 'stringent', 'pause', 'naturally']
Topic 0 .. Time: 2 ===> ['model', 'proactive', 'adaptive', 'antidote', 'proceeding', 'summarize', '1920', 'criticize', 'manipulation']
Topic 1 .. Time: 0 ===> ['verify', 'vice', 'concerted', 'profitably', 'respect', 'memorial', 'copper', 'ecb', 'denominate']
Topic 1 .. Time: 2 ===> ['verify', 'concerted', 'profitably', 'vice', 'copper', 'innovator', 'ecb', 'stride', 'fruitful']
Topic 2 .. Time: 0 ===> ['entire', 'correlation', 'master', 'impending', 'climb', 'experience', 'bulk', 'government', 'erosion']
Topic 2 .. Time: 2 ===> ['correlation', 'climb', 'impending', 'master', 'prerequisite', 'bulk', 'sudden', 'entire', 'erosion']
Topic 3 .. Time: 0 ===> ['crisis', 'memorial', 'overwhelming', 'burden', 'e

In [18]:
from src.utils import get_topic_coherence

In [19]:
def _diversity_helper(beta, num_tops):
    list_w = np.zeros((args.num_topics, num_tops))
    for k in range(args.num_topics):
        gamma = beta[k, :]
        top_words = gamma.cpu().numpy().argsort()[-num_tops:][::-1]
        list_w[k, :] = top_words
    list_w = np.reshape(list_w, (-1))
    list_w = list(list_w)
    n_unique = len(np.unique(list_w))
    diversity = n_unique / (args.num_topics * num_tops)
    return diversity

model.eval()
with torch.no_grad():
    alpha = model.mu_q_alpha
    beta = model.get_beta(alpha) 
    print('beta: ', beta.size())

    print('\n')
    print('#'*100)
    print('Get topic diversity...')
    num_tops = 25
    TD_all = np.zeros((args.num_times,))
    for tt in range(args.num_times):
        TD_all[tt] = _diversity_helper(beta[:, tt, :], num_tops)
    TD = np.mean(TD_all)
    print('Topic Diversity is: {}'.format(TD))

    print('\n')
    print('Get topic coherence...')
    print('train_tokens: ', train_tokens[0])
    TC_all = []
    cnt_all = []
    for tt in range(args.num_times):
        tc, cnt = get_topic_coherence(beta[:, tt, :].cpu().numpy(), train_tokens, vocab)
        TC_all.append(tc)
        cnt_all.append(cnt)
    print('TC_all: ', TC_all)
    TC_all = torch.tensor(TC_all)
    print('TC_all: ', TC_all.size())
    print('\n')
    print('Get topic quality...')
    quality = tc * TD
    print('Topic Quality is: {}'.format(quality))
    print('#'*100)

beta:  torch.Size([10, 4, 4938])


####################################################################################################
Get topic diversity...
Topic Diversity is: 0.7849999999999999


Get topic coherence...
train_tokens:  [  53   81  223  315  374  630  752  828 1104 1106 1142 1405 1511 1639
 1642 1698 1743 1780 1948 2153 2231 2367 2797 2825 2826 3006 3030 3251
 3273 3357 3503 3553 3752 4026 4081 4091 4185 4363 4398 4456 4512]
D:  1650
k: 0/10
k: 1/10
k: 2/10
k: 3/10
k: 4/10
k: 5/10
k: 6/10
k: 7/10
k: 8/10
k: 9/10
counter:  55
num topics:  10
Topic Coherence is: [4.2989044858211765, 5.444078348380458, 5.494838475015099, 4.9122535477502955, 5.637867237361495, 4.212042889671402, 3.92523855649542, 6.5435163527632625, 6.894564392415401, 4.829185300933675]
D:  1650
k: 0/10
k: 1/10
k: 2/10
k: 3/10
k: 4/10
k: 5/10
k: 6/10
k: 7/10
k: 8/10
k: 9/10
counter:  55
num topics:  10
Topic Coherence is: [6.003318624617093, 5.343364194291605, 4.903805150637484, 4.5720706152579265, 2.5845

TypeError: can't multiply sequence by non-int of type 'numpy.float64'

## see if i can get a measure of topic coherence

In [70]:
id2word = Dictionary.load('dict_save')
df = pd.read_parquet('data/combined_clean.parquet')
split_text = df['filtered_text'].str.split().values

In [71]:
for topic in get_detm_topics(model=model, time=0, num_words=10, vocab=vocab, num_topics=10):
    print(topic)

['proactive', 'model', 'antidote', 'summarize', 'adaptive', 'criticize', 'stringent', 'pause', 'naturally', 'compelling']
['verify', 'vice', 'concerted', 'profitably', 'respect', 'memorial', 'copper', 'ecb', 'denominate', 'college']
['entire', 'correlation', 'master', 'impending', 'climb', 'experience', 'bulk', 'government', 'erosion', 'prerequisite']
['crisis', 'memorial', 'overwhelming', 'burden', 'engaged', 'skepticism', 'detrimental', 'barely', 'possibly', 'virtuous']
['fulfil', 'friend', 'encompass', 'legitimate', 'birth', 'immense', 'statistic', 'pickup', 'mitigation', 'akin']
['unsettled', 'revitalization', 'call', 'productive', 'model', 'president', 'battle', 'reassessment', 'fulfil', 'sounder']
['fulfil', 'subordinate', 'introductory', 'clearing', 'native', 'manager', 'crude', 'revitalization', 'author', 'specifically']
['manipulation', 'comprehend', 'license', 'positively', 'unit', 'purely', 'peak', 'harness', 'burden', 'ecb']
['proceeding', 'contentious', 'appeal', 'barely',

In [72]:
coherences = []
for t in tqdm(range(args.num_times)):
    coherences.append(
        CoherenceModel(
            topics=get_detm_topics(model=model, time=t, num_words=20, vocab=vocab, num_topics=args.num_topics), # use 20 words to standardize with DTM
            texts=split_text, 
            dictionary=id2word, 
            coherence='c_v'
        ).get_coherence()
    )

coherences = np.array(coherences)

100%|██████████| 4/4 [00:43<00:00, 10.88s/it]


In [73]:
diversities = []
for t in tqdm(range(args.num_times)):
    diversities.append(
        topic_diversity(topics=get_topics(model=model, time=t, num_words=20, vocab=vocab, num_topics=args.num_topics))
    )
    
diversities = np.array(diversities)

100%|██████████| 4/4 [00:00<00:00, 180.88it/s]


In [74]:
qualities = diversities * coherences
qualities.mean(), qualities.std()

(0.43268138217844765, 0.008716614120983622)

In [75]:
qualities

array([0.42344675, 0.44412863, 0.43810009, 0.42505006])

In [76]:
np.savez_compressed(
    'detm_stats.npz',
    coherence=coherences,
    diversity=diversities
)