In [15]:
# Imports
from keras.models import Model
from keras.layers import Input, Dense, Reshape, dot
from keras.layers.embeddings import Embedding
from keras.preprocessing.sequence import skipgrams
from keras.preprocessing import sequence

from urllib.request import urlretrieve
import collections
import os
import zipfile

import numpy as np
import tensorflow as tf

In [10]:
class SimilarityCallback:
    def run_sim(self):
        for i in range(valid_size):
            valid_word = reverse_dictionary[valid_examples[i]]
            top_k = 8  # number of nearest neighbors
            sim = self._get_sim(valid_examples[i])
            nearest = (-sim).argsort()[1:top_k + 1]
            log_str = 'Nearest to %s:' % valid_word
            for k in range(top_k):
                close_word = reverse_dictionary[nearest[k]]
                log_str = '%s %s,' % (log_str, close_word)
            print(log_str)

    @staticmethod
    def _get_sim(valid_word_idx):
        sim = np.zeros((vocab_size,))
        in_arr1 = np.zeros((1,))
        in_arr2 = np.zeros((1,))
        in_arr1[0,] = valid_word_idx
        for i in range(vocab_size):
            in_arr2[0,] = i
            out = validation_model.predict_on_batch([in_arr1, in_arr2])
            sim[i] = out
        return sim

In [7]:
def maybe_download(filename, url, expected_bytes):
    """Download a file if not present, and make sure it's the right size."""
    if not os.path.exists(filename):
        filename, _ = urlretrieve(url + filename, filename)
    statinfo = os.stat(filename)
    if statinfo.st_size == expected_bytes:
        print('Found and verified', filename)
    else:
        print(statinfo.st_size)
        raise Exception(
            'Failed to verify ' + filename + '. Can you get to it with a browser?')
    return filename


# Read the data into a list of strings.
def read_data(filename):
    """Extract the first file enclosed in a zip file as a list of words."""
    with zipfile.ZipFile(filename) as f:
        data = tf.compat.as_str(f.read(f.namelist()[0])).split()
    return data


def build_dataset(words, n_words):
    """Process raw inputs into a dataset."""
    count = [['UNK', -1]]
    count.extend(collections.Counter(words).most_common(n_words - 1))
    dictionary = dict()
    for word, _ in count:
        dictionary[word] = len(dictionary)
    data = list()
    unk_count = 0
    for word in words:
        if word in dictionary:
            index = dictionary[word]
        else:
            index = 0  # dictionary['UNK']
            unk_count += 1
        data.append(index)
    count[0][1] = unk_count
    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
    return data, count, dictionary, reversed_dictionary

def collect_data(vocabulary_size=10000):
    url = 'http://mattmahoney.net/dc/'
    filename = maybe_download('text8.zip', url, 31344016)
    vocabulary = read_data(filename)
    print(vocabulary[:7])
    data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
                                                                vocabulary_size)
    del vocabulary  # Hint to reduce memory.
    return data, count, dictionary, reverse_dictionary

In [8]:
vocab_size = 10000
data, count, dictionary, reverse_dictionary = collect_data(vocabulary_size=vocab_size)
print(data[:7])

Found and verified text8.zip
['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse']
[5234, 3081, 12, 6, 195, 2, 3134]


In [9]:
window_size = 3
vector_dim = 300
epochs = 200000

valid_size = 16     # Random set of words to evaluate similarity on.
valid_window = 100  # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)

sampling_table = sequence.make_sampling_table(vocab_size)
couples, labels = skipgrams(data, vocab_size, window_size=window_size, sampling_table=sampling_table)
word_target, word_context = zip(*couples)
word_target = np.array(word_target, dtype="int32")
word_context = np.array(word_context, dtype="int32")

print(couples[:10], labels[:10])

[[5292, 140], [3025, 5493], [8534, 9733], [1844, 7], [6008, 5486], [1036, 8201], [616, 8809], [1732, 5], [139, 29], [1394, 9212]] [1, 0, 0, 1, 0, 0, 0, 1, 1, 0]


In [19]:
# create some input variables
input_target = Input((1,))
input_context = Input((1,))

embedding = Embedding(vocab_size, vector_dim, input_length=1, name='embedding')
target = embedding(input_target)
target = Reshape((vector_dim, 1))(target)
context = embedding(input_context)
context = Reshape((vector_dim, 1))(context)

# setup a cosine similarity operation which will be output in a secondary model
# similarity = dot([target, context], axes=0, normalize=True)
# similarity = concatenate([target, context], mode='cos', dot_axes=0)

# now perform the dot product operation to get a similarity measure
dot_product = dot([target, context], axes=1, normalize=False)
# dot_product = concatenate([target, context], mode='dot', dot_axes=1)
dot_product = Reshape((1,))(dot_product)
# add the sigmoid output layer
output = Dense(1, activation='sigmoid')(dot_product)
# create the primary training model
model = Model(inputs=[input_target, input_context], outputs=output)
model.compile(loss='binary_crossentropy', optimizer='rmsprop')

# create a secondary validation model to run our similarity checks during training
validation_model = Model(inputs=[input_target, input_context], outputs=output)

In [20]:
sim_cb = SimilarityCallback()

arr_1 = np.zeros((1,))
arr_2 = np.zeros((1,))
arr_3 = np.zeros((1,))
for cnt in range(epochs):
    idx = np.random.randint(0, len(labels)-1)
    arr_1[0,] = word_target[idx]
    arr_2[0,] = word_context[idx]
    arr_3[0,] = labels[idx]
    loss = model.train_on_batch([arr_1, arr_2], arr_3)
    if cnt % 100 == 0:
        print("Iteration {}, loss={}".format(cnt, loss))
    if cnt % 10000 == 0:
        sim_cb.run_sim()

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Iteration 0, loss=0.6847344636917114
Nearest to so: harrison, sin, faithful, waiting, estonia, pirate, harbors, onset,
Nearest to for: italians, close, costa, holy, amount, applies, bread, alan,
Nearest to will: predecessor, dawn, insurance, nautical, spiritual, construct, martin, easter,
Nearest to such: vice, broader, cartridge, culturally, lawsuit, mathrm, targets, inability,
Nearest to history: demographics, excellent, perception, measuring, presentation, prince, rat, bias,
Nearest to if: at, bosnia, importantly, flows, interpol, dictatorship, trained, tradition,
Nearest to united: foul, recognition, gore, multiple, australians, touch, held, collaboration,
Nearest to nine: josef, precision, district, dreams, assembly, flag, nl, bible,
Nearest to d: communicate, tutorial, sport, homosexual, reported, he, assert, autonomy,
Nearest to to: mission, h, boys, journey, poll, emission, ithaca, motorcycle,
Nearest to time: equatorial, considering, adds, stanley, inaugural, fires, gibraltar,

Iteration 13600, loss=0.677899181842804
Iteration 13700, loss=0.6827317476272583
Iteration 13800, loss=0.7217070460319519
Iteration 13900, loss=0.7199193239212036
Iteration 14000, loss=0.6822579503059387
Iteration 14100, loss=0.6503889560699463
Iteration 14200, loss=0.6623252630233765
Iteration 14300, loss=0.7125308513641357
Iteration 14400, loss=0.7102271318435669
Iteration 14500, loss=0.7206653952598572
Iteration 14600, loss=0.703676700592041
Iteration 14700, loss=0.6847422122955322
Iteration 14800, loss=0.6674526333808899
Iteration 14900, loss=0.7035478353500366
Iteration 15000, loss=0.7229539752006531
Iteration 15100, loss=0.6818783283233643
Iteration 15200, loss=0.6796396970748901
Iteration 15300, loss=0.7094779014587402
Iteration 15400, loss=0.7026126384735107
Iteration 15500, loss=0.7254642248153687
Iteration 15600, loss=0.6578034162521362
Iteration 15700, loss=0.71762615442276
Iteration 15800, loss=0.6734054088592529
Iteration 15900, loss=0.6868272423744202
Iteration 16000, los

Nearest to history: the, of, and, as, science, zero, in, commodore,
Nearest to if: vice, not, with, nine, believe, conventions, could, compressed,
Nearest to united: four, between, held, list, the, representing, manhattan, leto,
Nearest to nine: one, six, seven, zero, the, of, a, three,
Nearest to d: nine, zero, one, saxony, ivory, inventor, three, communicate,
Nearest to to: the, of, a, in, and, is, nine, on,
Nearest to time: the, to, or, two, for, hopes, measurements, girls,
Nearest to between: one, nine, six, and, a, of, present, seven,
Nearest to people: name, new, whether, certain, theory, involving, riding, is,
Nearest to on: on, and, of, a, to, in, one, zero,
Nearest to which: in, to, of, the, state, two, my, be,
Nearest to can: the, of, to, in, and, a, make, on,
Iteration 30100, loss=0.7159796953201294
Iteration 30200, loss=0.7123920917510986
Iteration 30300, loss=0.6488069891929626
Iteration 30400, loss=0.6945034265518188
Iteration 30500, loss=0.7105246782302856
Iteration 3060

Iteration 46000, loss=0.6701359152793884
Iteration 46100, loss=0.6774842143058777
Iteration 46200, loss=0.6141042709350586
Iteration 46300, loss=0.19222037494182587
Iteration 46400, loss=0.8208033442497253
Iteration 46500, loss=0.6495958566665649
Iteration 46600, loss=0.7786031365394592
Iteration 46700, loss=0.7238121628761292
Iteration 46800, loss=0.7386847138404846
Iteration 46900, loss=0.7065117955207825
Iteration 47000, loss=1.4655158519744873
Iteration 47100, loss=0.5996567606925964
Iteration 47200, loss=0.5089228749275208
Iteration 47300, loss=0.5702613592147827
Iteration 47400, loss=0.7109951376914978
Iteration 47500, loss=0.6354799866676331
Iteration 47600, loss=0.7167547345161438
Iteration 47700, loss=0.5963373184204102
Iteration 47800, loss=0.6599898338317871
Iteration 47900, loss=0.6879804134368896
Iteration 48000, loss=0.7027568817138672
Iteration 48100, loss=0.7401965856552124
Iteration 48200, loss=0.42202243208885193
Iteration 48300, loss=0.58714359998703
Iteration 48400,

KeyboardInterrupt: 