In [1]:
import numpy as np
import tensorflow as tf
import math
import os
import utils as utils

In [2]:
# Reproducibility: seed NumPy's global RNG and the TF graph-level RNG.
np.random.seed(0)
tf.set_random_seed(0)

# Location of the 20 Newsgroups feature files.
data_dir = 'data/20news/'

# Model shape.
vocab_size = 2000          # width of the bag-of-words input
n_topic = 50               # latent (topic) dimensions
n_hidden = 500             # encoder hidden-layer width
non_linearity = tf.nn.tanh

# Optimisation.
batch_size = 64
learning_rate = 5e-5

In [3]:
def record(epoch = 0, top_k=10, weights=None, vocab_path=None):
    """Print the ``top_k`` highest-weighted vocabulary words for every topic.

    Each row of the decoder weight matrix (``nvdm.W_dec`` by default, shape
    ``[n_topic, vocab_size]``) is ranked, and the matching vocabulary words are
    printed in descending weight order, one line per topic.

    Args:
        epoch: Unused; kept so existing calls such as ``record(epoch)`` still work.
        top_k: Number of words printed per topic (default 10, the original behaviour).
        weights: Optional 2-D array to rank instead of fetching ``nvdm.W_dec``
            through the global ``sess``.
        vocab_path: Optional vocabulary file; defaults to ``<data_dir>/vocab.new``.
    """
    if vocab_path is None:
        vocab_path = os.path.join(data_dir, 'vocab.new')

    # Only the first space-separated token of each line is used as the word.
    with open(vocab_path) as f:
        vocab = [line.strip().split(' ')[0] for line in f]

    if weights is None:
        weights = sess.run(nvdm.W_dec)
    weights = np.asarray(weights)

    # One row per topic, so the row count is the topic count.
    for topic_idx, row in enumerate(weights):
        # argsort is ascending: reverse, then take the first top_k. This also
        # behaves correctly for top_k=0, unlike a negative slice such as [-k:].
        top_words = [vocab[i] for i in np.argsort(row)[::-1][:top_k]]
        print('Topic %2d :  %s' % (topic_idx + 1, ', '.join(top_words)))

In [4]:
# Scratch check of how zip() pairs two lists element-wise.
# Not used by the model below; it can be removed from the final notebook.
A = [1, 2, 3, 87]
B = [4, 5, 6, 7]

for pair in zip(A, B):
    print(*pair)

1 4
2 5
3 6
87 7


In [4]:
class NVDM(object):
    """ Neural Variational Document Model (NVDM) -- a bag-of-words VAE.

    Builds a TF1 graph with:
      * an MLP encoder mapping the input ``x`` to a diagonal Gaussian
        (``mean``, ``logsigm`` = log standard deviation) over ``n_topic`` dims,
      * a single linear + log-softmax decoder (``W_dec``) reconstructing ``x``,
      * two separate update ops, ``optim_enc`` and ``optim_dec``, so that encoder
        and decoder can be trained alternately.
    """
    def __init__(self, vocab_size, n_hidden, n_topic, learning_rate, batch_size, non_linearity):
        """Build the graph in the current default graph.

        Args:
            vocab_size: Width of the bag-of-words input placeholder ``x``.
            n_hidden: Width of the encoder hidden layer.
            n_topic: Number of latent (topic) dimensions.
            learning_rate: Adam learning rate used by both update ops.
            batch_size: Stored on ``self.batch_size`` for callers. The graph itself
                accepts any number of rows per feed.
            non_linearity: Activation applied to the encoder hidden layer.
        """
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden
        self.n_topic = n_topic
        self.non_linearity = non_linearity
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.x = tf.placeholder(tf.float32, [None, vocab_size], name='input')
        self.mask = tf.placeholder(tf.float32, [None], name='mask')  # mask paddings

        # encoder
        with tf.variable_scope('encoder'):
            self.W_enc = tf.get_variable('enc_W', [self.vocab_size, n_hidden], initializer = None)
            self.b_enc = tf.get_variable('enc_b', [n_hidden], initializer = None)
            self.enc_vec = self.non_linearity(tf.matmul(self.x, self.W_enc) + self.b_enc)

            self.W_mean = tf.get_variable('mean_W', [self.n_hidden, self.n_topic], initializer = None)
            self.b_mean = tf.get_variable('mean_b', [self.n_topic], initializer = None)
            self.mean = tf.matmul(self.enc_vec, self.W_mean) + self.b_mean

            # Zero-initialised, so log sigma starts at 0 (unit std) for every input.
            self.W_logsigm = tf.get_variable('logsigm_W', [self.n_hidden, self.n_topic], initializer = tf.constant_initializer(0))
            self.b_logsigm = tf.get_variable('logsigm_b', [self.n_topic], initializer = tf.constant_initializer(0))
            self.logsigm = tf.matmul(self.enc_vec, self.W_logsigm) + self.b_logsigm

            # Closed-form KL(N(mean, sigma^2) || N(0, I)) per document, with
            # logsigm = log sigma: -0.5 * sum(1 + 2 log sigma - mean^2 - sigma^2).
            self.kld = -0.5 * tf.reduce_sum(1 - tf.square(self.mean) + 2 * self.logsigm - tf.exp(2 * self.logsigm), 1)
            self.kld = self.mask * self.kld  # mask paddings

        with tf.variable_scope('decoder'):
            # Reparameterisation trick: h = mean + sigma * eps, with eps ~ N(0, I).
            # eps takes the runtime shape of `mean`, so feeds with any number of
            # rows work. A shape fixed at (batch_size, n_topic) would reject them.
            eps = tf.random_normal(tf.shape(self.mean), 0, 1)
            doc_vec = tf.multiply(tf.exp(self.logsigm), eps) + self.mean

            self.W_dec = tf.get_variable('dec_W', [self.n_topic, self.vocab_size], initializer = None)
            # NOTE(review): the decoder has no bias term (a 'dec_b' variable was
            # disabled). Without it, frequent words may dominate every topic's row.
            # Despite its name, `logits` holds log-probabilities (log_softmax output).
            self.logits = tf.nn.log_softmax(tf.matmul(doc_vec, self.W_dec))

            # Negative log-likelihood of the observed word counts. Padded rows
            # contribute 0 here, assuming padding rows of x are all zeros (TODO confirm in utils).
            self.recons_loss = - tf.reduce_sum(tf.multiply(self.logits, self.x), 1)

        # Per-document negative ELBO (vector over the batch).
        self.objective = self.recons_loss + self.kld

        optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        fullvars = tf.trainable_variables()

        # presumably selects the variables whose name contains the scope prefix — verify in utils
        enc_vars = utils.variable_parser(fullvars, 'encoder')
        dec_vars = utils.variable_parser(fullvars, 'decoder')

        # tf.gradients differentiates the sum of the per-document objective vector.
        enc_grads = tf.gradients(self.objective, enc_vars)
        dec_grads = tf.gradients(self.objective, dec_vars)

        self.optim_enc = optimizer.apply_gradients(zip(enc_grads, enc_vars))
        self.optim_dec = optimizer.apply_gradients(zip(dec_grads, dec_vars))

In [5]:
# Feature files for the two splits, resolved relative to data_dir.
train_url, test_url = [os.path.join(data_dir, fname) for fname in ('train.feat', 'test.feat')]

# presumably (documents, per-document word counts) — TODO confirm against utils.data_set
train_set, train_count = utils.data_set(train_url)
test_set, test_count = utils.data_set(test_url)

# Evaluation batches are built once, unshuffled, and reused for every test pass.
test_batches = utils.create_batches(len(test_set), batch_size, shuffle=False)

In [6]:
# Make this cell re-runnable. Building NVDM twice in the same graph makes
# tf.get_variable raise, because 'encoder/...' and 'decoder/...' already exist.
# Close any earlier session and start from a fresh graph. tf.set_random_seed
# stores the seed on the graph, so carry it across the reset.
if 'sess' in globals():
    sess.close()
graph_seed = tf.get_default_graph().seed
tf.reset_default_graph()
tf.set_random_seed(graph_seed)

nvdm = NVDM(vocab_size=vocab_size,
            n_hidden=n_hidden,
            n_topic=n_topic,
            learning_rate=learning_rate,
            batch_size=batch_size,
            non_linearity=non_linearity)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

training_epochs = 1000   # outer epochs
alternate_epochs = 10    # passes per phase (decoder, then encoder) in each epoch

In [9]:
def run_pass(batches, docs, doc_counts, optim=None):
    """Run one pass over `batches` and return (perplexity, per-doc perplexity, mean KLD).

    Args:
        batches: Iterable of index batches (from utils.create_batches).
        docs, doc_counts: The dataset and its per-document counts (from utils.data_set).
        optim: Optional update op, run in the same sess.run as the loss fetch.
            When None, the pass only evaluates.
    """
    loss_sum = 0.0
    kld_sum = 0.0
    ppx_sum = 0.0
    word_count = 0
    doc_count = 0
    fetches = [nvdm.objective, nvdm.kld]
    # Each idx_batch holds up to batch_size document indices.
    for idx_batch in batches:
        data_batch, count_batch, mask = utils.fetch_data(docs, doc_counts, idx_batch, vocab_size)
        input_feed = {nvdm.x.name: data_batch, nvdm.mask.name: mask}
        if optim is None:
            loss, kld = sess.run(fetches, input_feed)
        else:
            _, (loss, kld) = sess.run((optim, fetches), input_feed)
        loss_sum += np.sum(loss)
        kld_sum += np.sum(kld) / np.sum(mask)
        word_count += np.sum(count_batch)
        # Avoid division by zero for padded (zero-count) documents.
        count_batch = np.add(count_batch, 1e-12)
        # Per-document loss normalised by that document's word count.
        ppx_sum += np.sum(np.divide(loss, count_batch))
        doc_count += np.sum(mask)
    return (np.exp(loss_sum / word_count),
            np.exp(ppx_sum / doc_count),
            kld_sum / len(batches))


for epoch in range(training_epochs):
    train_batches = utils.create_batches(len(train_set), batch_size, shuffle=True)

    #-------------------------------
    # train: update the decoder for `alternate_epochs` passes, then the encoder.
    # Only the last pass of each phase is reported, tagged with its phase.
    for print_mode, optim in (('decoder', nvdm.optim_dec), ('encoder', nvdm.optim_enc)):
        for _ in range(alternate_epochs):
            print_ppx, print_ppx_perdoc, print_kld = run_pass(
                train_batches, train_set, train_count, optim)
        print('Epoch train : %3d | %-7s | Perplexity : %10.5f | Per doc ppx : %10.5f | KLD : %10.5f'
              % (epoch + 1, print_mode, print_ppx, print_ppx_perdoc, print_kld))

    #-------------------------------
    # test: evaluation only, no parameter updates.
    print_ppx, print_ppx_perdoc, print_kld = run_pass(test_batches, test_set, test_count)
    print('Epoch test  : %3d | %-7s | Perplexity : %10.5f | Per doc ppx : %10.5f | KLD : %10.5f'
          % (epoch + 1, 'eval', print_ppx, print_ppx_perdoc, print_kld))
    print('')

Epoch train :   1 | Perplexity :  951.93697 | Per doc ppx : 1228.36851 | KLD :   18.25757
Epoch train :   1 | Perplexity :  943.20369 | Per doc ppx : 1223.29072 | KLD :   19.17136
Epoch test  :   1 | Perplexity : 1021.50917 | Per doc ppx : 1241.21377 | KLD :   17.16118

Epoch train :   2 | Perplexity :  909.58945 | Per doc ppx : 1180.28747 | KLD :   19.18850
Epoch train :   2 | Perplexity :  903.99683 | Per doc ppx : 1174.74989 | KLD :   19.81148
Epoch test  :   2 | Perplexity :  990.32191 | Per doc ppx : 1203.46759 | KLD :   17.69938

Epoch train :   3 | Perplexity :  876.05640 | Per doc ppx : 1140.34456 | KLD :   19.84391
Epoch train :   3 | Perplexity :  870.99740 | Per doc ppx : 1133.42457 | KLD :   20.30892
Epoch test  :   3 | Perplexity :  964.05298 | Per doc ppx : 1163.50348 | KLD :   17.93583

Epoch train :   4 | Perplexity :  849.45201 | Per doc ppx : 1102.90146 | KLD :   20.28076
Epoch train :   4 | Perplexity :  844.67123 | Per doc ppx : 1100.06844 | KLD :   20.85953
Epoch t

In [10]:
record(epoch=0)  # print the 10 highest-weighted words of every topic

Topic  1 :  netcom, proposal, braves, ball, pgp, cops, ripem, escrow, solar, austin
Topic  2 :  environment, umd, differences, minority, east, mph, guide, men, alan, higher
Topic  3 :  georgia, johnson, db, secretary, equipment, music, battery, volume, education, select
Topic  4 :  tax, rates, taxes, congress, bnr, printf, okay, col, authority, bmw
Topic  5 :  chi, iastate, van, church, andy, released, militia, ex, meeting, cal
Topic  6 :  pro, germany, german, gun, muslim, muslims, genocide, isc, armenians, select
Topic  7 :  like, just, out, know, don, who, how, only, which, up
Topic  8 :  pages, expansion, sets, brian, east, postscript, party, criminal, moon, described
Topic  9 :  rom, cd, animals, product, march, archive, att, products, gm, compression
Topic 10 :  door, her, she, apartment, search, announced, coverage, package, austin, zip
Topic 11 :  gm, stratus, button, xt, congress, em, art, microsoft, business, father
Topic 12 :  entries, russian, gif, objects, document, fonts,