Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

![title](img/word2vec.png)

In [1]:
from collections import Counter

import tensorflow as tf
import numpy as np
import re

  return f(*args, **kwds)


In [2]:
# Hyper-parameters and shared state for the skip-gram model.
# preprocess_text() later adds 'word2idx', 'idx2word' and 'vocab_size' to this dict.
PARAMS = dict(
    min_freq=5,        # a word is kept only if it occurs MORE than this many times
    skip_window=5,     # upper bound of the randomly drawn context radius in get_y()
    n_sampled=100,     # num_sampled handed to the sampled softmax loss
    embed_dim=200,     # width of the embedding and softmax weight matrices
    sample_words=['six', 'gold', 'japan', 'college'],  # words whose neighbours get printed
    batch_size=1000,   # batch size of the training input function
    n_epochs=10,       # number of epochs of the training input function
)

In [3]:
def preprocess_text(text):
    """Convert raw corpus text into a list of vocabulary indices.

    The text is lower-cased and split on whitespace, words occurring
    ``PARAMS['min_freq']`` times or fewer are removed, a vocabulary is built
    from the remaining words, and the indexed corpus is finally passed through
    ``filter_high_freq``.

    Side effects:
        Stores ``'word2idx'``, ``'idx2word'`` and ``'vocab_size'`` in the
        module-level ``PARAMS`` dict and prints progress information.

    Args:
        text: The whole corpus as a single string.

    Returns:
        A list of int word indices, in corpus order, after both filters.
    """
    # Raw string: '\s' inside a plain literal is an invalid escape sequence that
    # newer Python versions warn about. '\s+' also matches newlines, so no
    # separate newline replacement is needed.
    text = re.sub(r'\s+', ' ', text).strip().lower()

    words = text.split()
    word2freq = Counter(words)
    # Strictly greater-than: a word seen exactly `min_freq` times is dropped.
    words = [word for word in words if word2freq[word] > PARAMS['min_freq']]
    print("Total words:", len(words))

    # Sort the vocabulary so the word <-> index mapping is identical on every
    # run; iterating a set of strings directly varies with hash randomisation.
    vocab = sorted(set(words))
    PARAMS['word2idx'] = {word: idx for idx, word in enumerate(vocab)}
    PARAMS['idx2word'] = dict(enumerate(vocab))
    PARAMS['vocab_size'] = len(PARAMS['idx2word'])
    print('Vocabulary size:', PARAMS['vocab_size'])

    indexed = [PARAMS['word2idx'][w] for w in words]
    indexed = filter_high_freq(indexed)
    print("Word preprocessing completed ...")

    return indexed

def filter_high_freq(int_words, t=1e-5, threshold=0.8):
    """Remove every occurrence of words that are too frequent.

    For each distinct word a drop score ``1 - sqrt(t / freq)`` is computed,
    where ``freq`` is the word's share of ``int_words``. The decision is
    deterministic: a word is kept only if its score is below ``threshold``,
    otherwise all of its occurrences are removed.

    Args:
        int_words: Sequence of word indices.
        t: Frequency scale used in the drop score.
        threshold: Words with a drop score at or above this value are removed.

    Returns:
        A new list holding the surviving indices in their original order.
    """
    total = len(int_words)
    kept = set()
    for word, count in Counter(int_words).items():
        drop_score = 1 - np.sqrt(t / (count / total))
        if drop_score < threshold:
            kept.add(word)

    return [word for word in int_words if word in kept]

def make_data(int_words):
    """Expand an indexed corpus into parallel (input, label) lists.

    For every position the context words are obtained from ``get_y``; the
    centre word is repeated once per context word so both lists line up.

    Args:
        int_words: List of word indices.

    Returns:
        A tuple ``(x, y)`` of equal-length lists: centre words and the
        corresponding context words.
    """
    x, y = [], []
    for position, center in enumerate(int_words):
        context = get_y(int_words, position)
        y += context
        x += [center] * len(context)
    return x, y


def get_y(words, idx):
    """Return the distinct context words around position ``idx``.

    A radius is drawn uniformly from 1 to ``PARAMS['skip_window']`` (inclusive);
    the words up to that many positions to the left and right of ``idx`` are
    collected, excluding the position ``idx`` itself, and de-duplicated.

    Args:
        words: List of word indices.
        idx: Position of the centre word within ``words``.

    Returns:
        A list of unique context words (order follows set iteration).
    """
    radius = np.random.randint(1, PARAMS['skip_window'] + 1)
    start = max(idx - radius, 0)  # clamp so the slice never wraps to the end
    stop = idx + radius + 1
    before, after = words[start:idx], words[idx + 1:stop]
    return list(set(before + after))

In [4]:
def model_fn(features, labels, mode, params):
    """Estimator model function for the skip-gram embedding model.

    Creates the softmax weights, softmax biases and the embedding matrix, all
    sized from the module-level ``PARAMS`` (the ``params`` argument is unused).

    * TRAIN: minimises the mean sampled-softmax loss with Adam.
    * PREDICT: returns, for each id in ``features['x']``, the cosine similarity
      of its embedding against every row of the embedding matrix.
    * Any other mode falls through and returns ``None``.

    Args:
        features: Dict whose ``'x'`` entry holds the input word ids.
        labels: Target word ids passed to the sampled softmax loss.
        mode: One of ``tf.estimator.ModeKeys``.
        params: Unused.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for TRAIN and PREDICT, else ``None``.
    """
    vocab_size, embed_dim = PARAMS['vocab_size'], PARAMS['embed_dim']

    softmax_weights = tf.get_variable('softmax_W', [vocab_size, embed_dim])
    softmax_biases = tf.get_variable('softmax_b', [vocab_size])
    embedding = tf.get_variable('embedding', [vocab_size, embed_dim])

    # Built in every mode, although only the training loss consumes it.
    center_vectors = tf.nn.embedding_lookup(embedding, features['x'])

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Rows are scaled to unit length, so the dot products are cosines.
        unit_embedding = tf.nn.l2_normalize(embedding, -1)
        query_vectors = tf.nn.embedding_lookup(unit_embedding, features['x'])
        cosine = tf.matmul(query_vectors, unit_embedding, transpose_b=True)

        return tf.estimator.EstimatorSpec(mode, predictions=cosine)

    if mode == tf.estimator.ModeKeys.TRAIN:
        per_example_loss = tf.nn.sampled_softmax_loss(
            weights=softmax_weights,
            biases=softmax_biases,
            labels=labels,
            inputs=center_vectors,
            num_sampled=PARAMS['n_sampled'],
            num_classes=vocab_size)
        loss_op = tf.reduce_mean(per_example_loss)

        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(
            loss_op, global_step=tf.train.get_global_step())

        return tf.estimator.EstimatorSpec(mode=mode, loss=loss_op, train_op=train_op)
    

def print_neighbours(similarity, top_k=5):
    """Print the nearest neighbours of every word in ``PARAMS['sample_words']``.

    Row ``i`` of ``similarity`` is ranked in descending order. The top-ranked
    entry is skipped (presumably the query word itself, which has the highest
    similarity to itself) and the next ``top_k`` entries are printed.

    If fewer than ``top_k`` other words exist, only the available ones are
    printed instead of raising ``IndexError``.

    Args:
        similarity: Array-like with one row of similarity scores per sample
            word; column indices are keys of ``PARAMS['idx2word']``.
        top_k: Maximum number of neighbours to print per word.
    """
    # A negative top_k prints no neighbours, matching an empty range().
    limit = max(top_k, 0)
    for i, word in enumerate(PARAMS['sample_words']):
        # Negate so that argsort's ascending order yields highest score first.
        neighbours = (-similarity[i]).argsort()[1:limit + 1]
        log = 'Nearest to [%s]:' % word
        # Iterate over what is actually there; the slice may be shorter than top_k.
        for neighbour_idx in neighbours:
            log = '%s %s,' % (log, PARAMS['idx2word'][neighbour_idx])
        print(log)

In [5]:
# Read the corpus and expand it into (centre word, context word) training pairs.
with open('../temp/ptb_train.txt') as f:
    x_train, y_train = make_data(preprocess_text(f.read()))

estimator = tf.estimator.Estimator(model_fn)

# Labels get a trailing axis so that every example carries one target id.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': np.array(x_train)},
    y=np.expand_dims(y_train, -1),
    batch_size=PARAMS['batch_size'],
    num_epochs=PARAMS['n_epochs'],
    shuffle=True)
estimator.train(train_input_fn)

# Query the trained embedding with the sample words, in their listed order.
sample_ids = np.array([PARAMS['word2idx'][w] for w in PARAMS['sample_words']])
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': sample_ids},
    shuffle=False)
sim = np.array(list(estimator.predict(predict_input_fn)))

print_neighbours(sim)

Total words: 885720
Vocabulary size: 9582
Word preprocessing completed ...
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpkhu995ue', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x121748d68>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensor

INFO:tensorflow:loss = 2.3783972, step = 6901 (1.462 sec)
INFO:tensorflow:global_step/sec: 68.2986
INFO:tensorflow:loss = 3.0956774, step = 7001 (1.464 sec)
INFO:tensorflow:global_step/sec: 69.2588
INFO:tensorflow:loss = 2.9790354, step = 7101 (1.444 sec)
INFO:tensorflow:global_step/sec: 68.4006
INFO:tensorflow:loss = 3.6135902, step = 7201 (1.462 sec)
INFO:tensorflow:global_step/sec: 65.8242
INFO:tensorflow:loss = 2.4605467, step = 7301 (1.520 sec)
INFO:tensorflow:global_step/sec: 66.8466
INFO:tensorflow:loss = 2.7080777, step = 7401 (1.496 sec)
INFO:tensorflow:global_step/sec: 67.4517
INFO:tensorflow:loss = 2.3850899, step = 7501 (1.482 sec)
INFO:tensorflow:global_step/sec: 68.0477
INFO:tensorflow:loss = 2.744026, step = 7601 (1.470 sec)
INFO:tensorflow:global_step/sec: 68.1496
INFO:tensorflow:loss = 2.5234501, step = 7701 (1.467 sec)
INFO:tensorflow:global_step/sec: 67.7294
INFO:tensorflow:loss = 3.2040606, step = 7801 (1.476 sec)
INFO:tensorflow:Saving checkpoints for 7880 into /va