In [None]:
from __future__ import division
import datetime
import os
import sys
from imp import reload
from itertools import islice
from random import shuffle

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import tensorflow as tf

import corpus_tools as ct
ct = reload(ct)

In [None]:
def plot_with_labels(low_dim_embs, labels, filename='tsne.png', size=(100, 100)):
    """Scatter-plot 2-D points, annotate each with its label, and save the figure.

    Args:
        low_dim_embs: 2-D array-like indexable as ``low_dim_embs[i, :]`` giving an
            ``(x, y)`` pair per row and exposing ``.shape`` — presumably the
            ``(n, 2)`` output of t-SNE (TODO confirm for other callers).
        labels: sequence of annotation strings; ``labels[i]`` is drawn next to
            row ``i``. Only the first ``len(labels)`` rows are plotted.
        filename: path the figure is written to via ``savefig``.
        size: figure size in inches, ``(width, height)``.

    Raises:
        AssertionError: if there are more labels than embedding rows.
    """
    assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
    n_points = len(labels)
    fig, ax = plt.subplots(figsize=size)  # figsize is in inches
    try:
        # One scatter call for every labelled point, rather than one artist per
        # point (thousands of separate collections are slow to build and render).
        ax.scatter(low_dim_embs[:n_points, 0], low_dim_embs[:n_points, 1])
        for i, label in enumerate(labels):
            x, y = low_dim_embs[i, :]
            # Label sits slightly up/right of the point, right-aligned to that offset.
            ax.annotate(label,
                        xy=(x, y),
                        xytext=(5, 2),
                        textcoords='offset points',
                        ha='right',
                        va='bottom')
        fig.savefig(filename)
    finally:
        # Always release the figure, even if savefig fails; this runs once per
        # epoch, so leaked figures would otherwise accumulate in memory.
        plt.close(fig)

def output_tsne(embeddings, filename, size=(100, 100), plot_only=1000, labels=None):
    """Project the leading embedding rows to 2-D with t-SNE and save a labelled plot.

    Args:
        embeddings: 2-D array of embeddings; row ``k`` is assumed to correspond
            to ``labels[k]`` (TODO confirm the corpus vocabulary ordering matches).
        filename: output image path, forwarded to ``plot_with_labels``.
        size: figure size in inches.
        plot_only: number of leading rows (and labels) to project and plot.
        labels: optional sequence of row labels. When ``None`` (the default,
            matching the original behaviour) the notebook-global
            ``reddit_corpus.words`` is used.
    """
    if labels is None:
        # Backward-compatible fallback to the corpus loaded in a later cell.
        # Resolved before the (slow) t-SNE fit so a missing corpus fails fast.
        labels = reddit_corpus.words
    # NOTE(review): newer scikit-learn renamed `n_iter` to `max_iter`; kept as-is
    # to match the library version this notebook was written against.
    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
    low_dim_embs = tsne.fit_transform(embeddings[:plot_only, :])
    plot_with_labels(low_dim_embs, labels[:plot_only], filename, size)
    
def batchify(batch_size, *sequences):
    """Yield aligned, consecutive slices of several sequences in lockstep.

    Each yielded item is a tuple containing one slice per input sequence, all
    taken over the same index window ``[i, i + batch_size)``. The final batch
    may be shorter than ``batch_size``. Windows are driven by the length of the
    first sequence; a shorter sequence simply yields shorter (possibly empty)
    slices. Nothing is yielded when no sequences are given.

    Args:
        batch_size: positive number of elements per batch.
        *sequences: sliceable sequences (lists, tuples, strings, ...).

    Yields:
        tuple: ``(seq0[i:i+batch_size], seq1[i:i+batch_size], ...)``.
    """
    if not sequences:
        return
    # `range` instead of the Python 2-only `xrange` so this also runs on Python 3.
    for i in range(0, len(sequences[0]), batch_size):
        yield tuple(sequence[i:i + batch_size] for sequence in sequences)

In [None]:
# --- Loss weighting: f(x) = min(1, (x / COUNT_MAX) ** SCALING_FACTOR) ---
COUNT_MAX = 100          # co-occurrence count at which the weight saturates at 1
SCALING_FACTOR = 0.75    # exponent of the weighting function (same value as 3 / 4)

# --- Model / corpus ---
EMBEDDING_SIZE = 150     # dimensionality of each word vector
CONTEXT_SIZE = 10        # passed to RedditCorpus(size=...); presumably the context window
MIN_OCCURRENCES = 25     # passed to fit(min_occurrences=...); presumably a vocab frequency cutoff

# --- Optimisation / reporting ---
LEARNING_RATE = 0.05     # Adagrad learning rate
NUM_EPOCHS = 50          # full passes over the shuffled batches
BATCH_SIZE = 512         # fixed mini-batch size of the graph's input placeholders
REPORT_BATCH_SIZE = 10000  # print the average loss every this many batches
TSNE_EPOCH_FREQ = 1      # write a t-SNE plot every this many epochs

In [None]:
# Load the Reddit comment sample and fit it; afterwards `words` and
# `cooccurrence_matrix` on the corpus object are used by the rest of the notebook.
# The location can be overridden with the REDDIT_CORPUS_PATH environment variable
# instead of editing this cell; the default is the original author's local path.
CORPUS_PATH = os.environ.get(
    "REDDIT_CORPUS_PATH",
    "/media/grady/PrimeMover/Datasets/RC_2015-01-1m_sample",
)
reddit_corpus = ct.RedditCorpus(CORPUS_PATH, size=CONTEXT_SIZE)
reddit_corpus.fit(min_occurrences=MIN_OCCURRENCES)
vocab_size = len(reddit_corpus.words)
print("Unique tokens: {}".format(vocab_size))
print("Non-zero elems of cooccurrence matrix: {}".format(len(reddit_corpus.cooccurrence_matrix)))

In [None]:
def device_for_node(n):
    """Device-placement function passed to ``graph.device``.

    Args:
        n: a graph node exposing a ``type`` attribute (the op type name).

    Returns:
        ``"/gpu:0"`` for ``MatMul`` nodes and ``"/cpu:0"`` for everything else.

    NOTE(review): the graph built below appears to contain no ``MatMul`` ops
    (it uses element-wise multiply + reduce_sum), so in practice every op is
    placed on the CPU — confirm this is intended.
    """
    return "/gpu:0" if n.type == "MatMul" else "/cpu:0"

graph = tf.Graph()

# GloVe objective, summed over the co-occurring pairs (i, j) in a mini-batch:
#     f(X_ij) * (w_i . w~_j + b_i + b~_j - log X_ij) ** 2
# with weighting f(x) = min(1, (x / COUNT_MAX) ** SCALING_FACTOR).
with graph.as_default():
    with graph.device(device_for_node):
        count_max = tf.constant([COUNT_MAX], dtype=tf.float32)
        scaling_factor = tf.constant([SCALING_FACTOR], dtype=tf.float32)

        # One mini-batch of (focal index, context index, co-occurrence count).
        # The batch dimension is fixed, so only full batches can be fed.
        focal_input = tf.placeholder(tf.int32, shape=[BATCH_SIZE])
        context_input = tf.placeholder(tf.int32, shape=[BATCH_SIZE])
        cooccurrence_count = tf.placeholder(tf.float32, shape=[BATCH_SIZE])

        # Focal (w) and context (w~) embedding tables plus per-word biases,
        # initialised uniformly in [-1, 1). (Bounds were previously passed in
        # reversed (max, min) order.)
        focal_embeddings = tf.Variable(
            tf.random_uniform([vocab_size, EMBEDDING_SIZE], -1.0, 1.0)
        )

        context_embeddings = tf.Variable(
            tf.random_uniform([vocab_size, EMBEDDING_SIZE], -1.0, 1.0)
        )

        focal_biases = tf.Variable(
            tf.random_uniform([vocab_size], -1.0, 1.0)
        )

        context_biases = tf.Variable(
            tf.random_uniform([vocab_size], -1.0, 1.0)
        )

        focal_embedding = tf.nn.embedding_lookup([focal_embeddings], focal_input)
        context_embedding = tf.nn.embedding_lookup([context_embeddings], context_input)
        focal_bias = tf.nn.embedding_lookup([focal_biases], focal_input)
        context_bias = tf.nn.embedding_lookup([context_biases], context_input)

        # f(X_ij): grows as (count / COUNT_MAX) ** SCALING_FACTOR, capped at 1.
        weighting_factor = tf.minimum(
            1.0,
            tf.pow(cooccurrence_count / count_max, scaling_factor)
        )

        # Row-wise dot product w_i . w~_j. Arithmetic operators are used instead
        # of tf.mul / tf.neg / tf.div, which were removed in TensorFlow 1.0;
        # the operator overloads work across TF versions.
        embedding_product = tf.reduce_sum(focal_embedding * context_embedding, 1)

        # cooccurrence_count is already a float32 placeholder, so no cast is needed.
        log_cooccurrences = tf.log(cooccurrence_count)

        distance_expr = tf.square(tf.add_n([
            embedding_product,
            focal_bias,
            context_bias,
            -log_cooccurrences
        ]))

        single_losses = weighting_factor * distance_expr
        total_loss = tf.reduce_sum(single_losses)
        optimizer = tf.train.AdagradOptimizer(LEARNING_RATE).minimize(total_loss)

        # Final word vectors: the sum of the focal and context tables.
        combined_embeddings = focal_embeddings + context_embeddings

# The graph is now defined; the next cells build the mini-batches and run training.


In [None]:
# Flatten the sparse co-occurrence mapping into (i, j, count) triples. Keys are
# indexed as key[0], key[1] — presumably (focal index, context index); TODO confirm.
cooccurrences = [
    (key[0], key[1], value)
    for key, value in reddit_corpus.cooccurrence_matrix.items()
]

# Transpose into three parallel tuples, then cut them into aligned mini-batches.
i_indices, j_indices, counts = zip(*cooccurrences)
batches = list(batchify(BATCH_SIZE, i_indices, j_indices, counts))

In [None]:
def _log(*lines):
    """Print each line, then flush stdout so progress appears immediately."""
    for line in lines:
        print(line)
    sys.stdout.flush()


# tf.initialize_all_variables was deprecated in favour of
# tf.global_variables_initializer; use the newer name when it exists.
init_op = getattr(tf, "global_variables_initializer", None) or tf.initialize_all_variables

_log("Begin training: {}".format(datetime.datetime.now().time()),
     "=================")
with tf.Session(graph=graph) as session:
    init_op().run()
    for epoch in range(NUM_EPOCHS):
        shuffle(batches)
        _log("Batches shuffled",
             "-----------------")
        accumulated_loss = 0
        accumulated_batches = 0
        for batch_index, batch in enumerate(batches):
            # Batch-local names, so the global `counts` from the previous cell
            # is not clobbered by the last batch of training.
            batch_i, batch_j, batch_counts = batch
            # The placeholders have a fixed batch dimension, so the short
            # trailing batch produced by batchify (if any) cannot be fed.
            if len(batch_counts) != BATCH_SIZE:
                continue
            feed_dict = {focal_input: batch_i,
                         context_input: batch_j,
                         cooccurrence_count: batch_counts}
            _, total_loss_ = session.run([optimizer, total_loss], feed_dict=feed_dict)
            accumulated_loss += total_loss_
            accumulated_batches += 1
            if (batch_index + 1) % REPORT_BATCH_SIZE == 0:
                # Average over the batches actually run in this window: a skipped
                # short batch would otherwise bias it, and a skip landing on a
                # report boundary would merge two windows into one report.
                _log("Epoch: {0}/{1}".format(epoch + 1, NUM_EPOCHS),
                     "Batch: {0}/{1}".format(batch_index + 1, len(batches)),
                     "Average loss: {}".format(accumulated_loss / max(accumulated_batches, 1)),
                     "-----------------")
                accumulated_loss = 0
                accumulated_batches = 0
        if (epoch + 1) % TSNE_EPOCH_FREQ == 0:
            _log("Outputting t-SNE: {}".format(datetime.datetime.now().time()),
                 "-----------------")
            current_embeddings = combined_embeddings.eval()
            output_tsne(current_embeddings, "epoch{:02d}.png".format(epoch + 1))
        _log("Epoch finished: {}".format(datetime.datetime.now().time()),
             "=================")
    final_embeddings = combined_embeddings.eval()
print("End: {}".format(datetime.datetime.now().time()))

In [None]:
# Larger t-SNE of the trained vectors: project and plot the first 4000 rows.
output_tsne(final_embeddings, filename="final_big-minibatch-1m.png", plot_only=4000)