In [1]:
import pandas as pd
import numpy as np
import scipy.sparse as sp
import tensorflow as tf
import itertools as it
import pickle
import os

from tensorflow.contrib.tensorboard.plugins import projector

from time import time

from importlib import reload
import sys
# Make the repository root importable so the local `sertf` package resolves,
# without appending a duplicate entry when this cell is re-run.
_PROJECT_ROOT = '..'
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

from sertf import core
# Re-import sertf.core so edits to the module are picked up without a kernel restart.
reload(core)

# --- Output --------------------------------------------------------------
# TensorBoard log directory: checkpoints, projector config and metadata go here.
LOG_DIR = '/tmp/tensorboard-logs/semantic/'

# --- Inputs --------------------------------------------------------------
PATH_DATA = '../data/amazon/food/reviews_df.msg'              # review dataframe
PATH_ENC_TXT = '../data/amazon/food/reviews_txt_enc_s.msg'    # encoded review text
PATH_VOCAB = '../data/amazon/food/vocab.p'                    # pickled vocabulary

# Categorical column whose categories are treated as the entities to embed.
entity_col = 'ProductId'

In [2]:
# Load the review table, the encoded review text and the vocabulary.
# NOTE(review): pd.read_msgpack was deprecated in pandas 0.25 and removed in 1.0;
# this cell needs an older pandas (or the data re-saved in another format).
df = pd.read_msgpack(PATH_DATA)
# Use the configured constant rather than repeating the literal path, so the two cannot drift apart.
data_words_enc = pd.read_msgpack(PATH_ENC_TXT)

# `with` closes the file handle (the original open() was never closed).
# NOTE(review): pickle.load can execute arbitrary code - only load vocab files you produced yourself.
with open(PATH_VOCAB, 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [3]:
# Entity count = number of declared categories of `entity_col`
# (includes any category that no row happens to use).
entity_categories = df[entity_col].cat.categories
n_entities = entity_categories.size

In [4]:
# One integer category code per review row, as a numpy array.
entity_codes = df[entity_col].cat.codes.values
# pandas encodes missing categorical values as -1. These codes are passed to core.win_gen together with
# n_entities (presumably as entity indices in [0, n_entities) - confirm in sertf.core), so fail fast here.
assert (entity_codes >= 0).all(), f'{entity_col} has missing values (category code -1)'

In [5]:
# Build the model from the vocabulary and entity count. core.Model is project code (sertf.core);
# later cells read its train_op, embs_d['word'], ph_d, batch_size and n_negs_per_pos attributes.
model = core.Model(vocab, n_entities)

In [6]:
# Training batch source from sertf.core. The training cell calls next(gen) and passes the result as
# feed_dict, so each item is presumably a {placeholder: batch} mapping keyed by model.ph_d - confirm.
# NOTE(review): the generator is created once here; re-running the training cell continues from
# wherever it stopped rather than starting over.
gen = core.win_gen(
    data_words_enc,
    entity_codes,
    n_entities,
    model.n_negs_per_pos,
    model.ph_d,
    model.batch_size,
)

In [8]:
# Configure the TensorBoard embedding projector for the word embedding matrix.
proj_config = projector.ProjectorConfig()

word_metadata_path = os.path.join(LOG_DIR, 'word_metadata.tsv')

word_proj = proj_config.embeddings.add()
word_proj.tensor_name = model.embs_d['word'].name
word_proj.metadata_path = word_metadata_path

# pandas.to_csv does not create parent directories, and the FileWriter below (which would)
# is only constructed afterwards - so make sure LOG_DIR exists first.
os.makedirs(LOG_DIR, exist_ok=True)

# Single-column metadata must not have a header row.
# NOTE(review): each label is the str() of an (index, token) tuple; row i must correspond to
# embedding row i, which assumes vocab iterates in embedding-row order - confirm in sertf.core.
pd.Series(list(enumerate(vocab))).to_csv(word_metadata_path, sep='\t', index=False, header=False)

summary_writer = tf.summary.FileWriter(LOG_DIR)

In [10]:
%%time
max_steps = 10000

print(f'Approx # epochs: {max_steps*model.batch_size/len(df)}')

gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
saver = tf.train.Saver()
summary_writer = tf.summary.FileWriter(LOG_DIR)

with tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts)) as sess:
    sess.run(tf.global_variables_initializer())
    tic = time()
    for step in range(max_steps):
        feed = next(gen)
        sess.run(model.train_op, feed_dict=feed)
    
        if (step%1000) == 0:
            toc = time() - tic
            print(step, toc)
            tic = time()

            saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"), step)
            projector.visualize_embeddings(summary_writer, proj_config)

Approx # epochs: 18.013770683291877
0 0.03285670280456543
1000 16.049452304840088
2000 15.728713274002075
3000 15.971048355102539
4000 16.119661331176758
5000 15.836517572402954
6000 16.446090936660767
7000 17.299010276794434
8000 15.81524109840393
9000 15.67464303970337
CPU times: user 3min 35s, sys: 8.35 s, total: 3min 43s
Wall time: 2min 41s
