In [1]:
# http://adventuresinmachinelearning.com/word2vec-keras-tutorial/

In [2]:
from keras.models import Model
from keras.layers import Input, Dense, Reshape, Dot, Lambda
from keras.layers.embeddings import Embedding
from keras.preprocessing.sequence import skipgrams
from keras.preprocessing import sequence
from keras import backend as K
from keras.layers import dot

import urllib
import collections
import os
import zipfile

import numpy as np
import tensorflow as tf

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [65]:
def maybe_download(filename, url, expected_bytes):
    """Download a file if not present, and make sure it's the right size."""
    if not os.path.exists(filename):
        filename, _ = urllib.request.urlretrieve(url + filename, os.path.join(os.path.join(os.path.abspath(''),"dataset"),filename))
    statinfo = os.stat(os.path.join(os.path.join(os.path.abspath(''),"dataset"),filename))
    if statinfo.st_size == expected_bytes:
        print('Found and verified', filename)
    else:
        print(statinfo.st_size)
        raise Exception(
            'Failed to verify ' + filename + '. Can you get to it with a browser?')
    return os.path.join(os.path.join(os.path.abspath(''),"dataset"),filename)

In [5]:
# Read the data into a list of strings.
def read_data(filename):
    """Extract the first file enclosed in a zip file as a list of words."""
    with zipfile.ZipFile(filename) as f:
        data = tf.compat.as_str(f.read(f.namelist()[0])).split()
    return data

In [6]:
def build_dataset(words, n_words):
    """Process raw inputs into a dataset."""
    count = [['UNK', -1]]
    count.extend(collections.Counter(words).most_common(n_words - 1))
    dictionary = dict()
    for word, _ in count:
        dictionary[word] = len(dictionary)
    data = list()
    unk_count = 0
    for word in words:
        if word in dictionary:
            index = dictionary[word]
        else:
            index = 0  # dictionary['UNK']
            unk_count += 1
        data.append(index)
    count[0][1] = unk_count
    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
    return data, count, dictionary, reversed_dictionary


In [7]:
def collect_data(vocabulary_size=10000):
    url = 'http://mattmahoney.net/dc/'
    filename = maybe_download('text8.zip', url, 31344016)
    vocabulary = read_data(filename)
    print(vocabulary[:7])
    data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
                                                                vocabulary_size)
    del vocabulary  # Hint fto reduce memory.
    return data, count, dictionary, reverse_dictionary

In [66]:
vocab_size = 10000
data, count, dictionary, reverse_dictionary = collect_data(vocabulary_size=vocab_size)
print(data[:7])

Found and verified C:\Users\yoshi\Documents\GitHub\ML-Playground\keras-word2vec\dataset\text8.zip
['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse']
[5234, 3081, 12, 6, 195, 2, 3134]


In [9]:
window_size = 3
vector_dim = 300
epochs = 200000
valid_size = 16     # Random set of words to evaluate similarity on.
valid_window = 100  # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)

In [10]:
sampling_table = sequence.make_sampling_table(vocab_size)
couples, labels = skipgrams(data, vocab_size, window_size=window_size, sampling_table=sampling_table)
word_target, word_context = zip(*couples)
word_target = np.array(word_target, dtype="int32")
word_context = np.array(word_context, dtype="int32")

In [11]:
print(couples[:10], labels[:10])

[[1185, 4926], [807, 3053], [931, 2598], [3104, 7284], [6276, 51], [33, 1068], [4147, 45], [3552, 31], [4388, 2939], [5537, 10]] [0, 1, 0, 0, 1, 0, 1, 1, 1, 1]


In [12]:
# create some input variables
input_target = Input((1,))
input_context = Input((1,))

W0922 17:22:04.804660 19256 deprecation_wrapper.py:119] From c:\users\yoshi\appdata\local\programs\python\python37\lib\site-packages\keras\backend\tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0922 17:22:05.025494 19256 deprecation_wrapper.py:119] From c:\users\yoshi\appdata\local\programs\python\python37\lib\site-packages\keras\backend\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.



In [13]:
embedding = Embedding(vocab_size, vector_dim, input_length=1, name='embedding')
target = embedding(input_target)
target = Reshape((vector_dim, 1))(target)
context = embedding(input_context)
context = Reshape((vector_dim, 1))(context)

W0922 17:22:10.071061 19256 deprecation_wrapper.py:119] From c:\users\yoshi\appdata\local\programs\python\python37\lib\site-packages\keras\backend\tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.



In [20]:
# setup a cosine similarity operation which will be output in a secondary model
similarity = dot([target,context],axes=1,normalize=True)

In [21]:
# now perform the dot product operation to get a similarity measure
dot_product = dot([target, context], normalize=False, axes=1)
dot_product = Reshape((1,))(dot_product)
# add the sigmoid output layer
output = Dense(1, activation='sigmoid')(dot_product)
# create the primary training model
model = Model(input=[input_target, input_context], output=output)
model.compile(loss='binary_crossentropy', optimizer='rmsprop')

  import sys


In [22]:
# create a secondary validation model to run our similarity checks during training
validation_model = Model(input=[input_target, input_context], output=similarity)

  


In [39]:
class SimilarityCallback:
    def run_sim(self,word):
        valid_word = reverse_dictionary[dictionary[word]]
        top_k = 8  # number of nearest neighbors
        sim = self._get_sim(dictionary[word])
        nearest = (-sim).argsort()[1:top_k + 1]
        log_str = 'Nearest to %s:' % valid_word
        for k in range(top_k):
            close_word = reverse_dictionary[nearest[k]]
            log_str = '%s %s,' % (log_str, close_word)
        print(log_str)

    @staticmethod
    def _get_sim(valid_word_idx):
        sim = np.zeros((vocab_size,))
        in_arr1 = np.zeros((1,))
        in_arr2 = np.zeros((1,))
        in_arr1[0,] = valid_word_idx
        for i in range(vocab_size):
            in_arr2[0,] = i
            out = validation_model.predict_on_batch([in_arr1, in_arr2])
            sim[i] = out
        return sim
sim_cb = SimilarityCallback()

In [27]:
arr_1 = np.zeros((1,))
arr_2 = np.zeros((1,))
arr_3 = np.zeros((1,))
for cnt in range(epochs):
    idx = np.random.randint(0, len(labels)-1)
    arr_1[0,] = word_target[idx]
    arr_2[0,] = word_context[idx]
    arr_3[0,] = labels[idx]
    loss = model.train_on_batch([arr_1, arr_2], arr_3)
    if cnt % 100 == 0:
        print("Iteration {}, loss={}".format(cnt, loss))
    #if cnt % 10000 == 0:
        #sim_cb.run_sim()

Iteration 0, loss=0.7043582201004028
Iteration 100, loss=0.6977705955505371
Iteration 200, loss=0.6908437609672546
Iteration 300, loss=0.6888933777809143
Iteration 400, loss=0.69443678855896
Iteration 500, loss=0.6899911165237427
Iteration 600, loss=0.7138196229934692
Iteration 700, loss=0.6840954422950745
Iteration 800, loss=0.693039059638977
Iteration 900, loss=0.7091814875602722
Iteration 1000, loss=0.6933619379997253
Iteration 1100, loss=0.7059417963027954
Iteration 1200, loss=0.7038109302520752
Iteration 1300, loss=0.6979499459266663
Iteration 1400, loss=0.6843993067741394
Iteration 1500, loss=0.6782235503196716
Iteration 1600, loss=0.6714305281639099
Iteration 1700, loss=0.7106215357780457
Iteration 1800, loss=0.6892980337142944
Iteration 1900, loss=0.7039719223976135
Iteration 2000, loss=0.6949812769889832
Iteration 2100, loss=0.7141749858856201
Iteration 2200, loss=0.6904787421226501
Iteration 2300, loss=0.7187868356704712
Iteration 2400, loss=0.7043004631996155
Iteration 2500,

Iteration 20400, loss=0.7127941846847534
Iteration 20500, loss=0.6654199361801147
Iteration 20600, loss=0.7026482224464417
Iteration 20700, loss=0.6934882998466492
Iteration 20800, loss=0.6718586087226868
Iteration 20900, loss=0.6840875148773193
Iteration 21000, loss=0.68257737159729
Iteration 21100, loss=0.7395837306976318
Iteration 21200, loss=0.7012734413146973
Iteration 21300, loss=0.6810333132743835
Iteration 21400, loss=0.6767929792404175
Iteration 21500, loss=0.7237265706062317
Iteration 21600, loss=0.6819280385971069
Iteration 21700, loss=0.7023845314979553
Iteration 21800, loss=0.7157949805259705
Iteration 21900, loss=0.6842852234840393
Iteration 22000, loss=0.7316397428512573
Iteration 22100, loss=0.6818546056747437
Iteration 22200, loss=0.6696129441261292
Iteration 22300, loss=0.7196038961410522
Iteration 22400, loss=0.7210505604743958
Iteration 22500, loss=0.7053542733192444
Iteration 22600, loss=0.7163170576095581
Iteration 22700, loss=0.7146627902984619
Iteration 22800, l

Iteration 40500, loss=0.709158718585968
Iteration 40600, loss=0.7007417678833008
Iteration 40700, loss=0.6854608654975891
Iteration 40800, loss=0.634636402130127
Iteration 40900, loss=0.6705198287963867
Iteration 41000, loss=0.8021155595779419
Iteration 41100, loss=0.7013018131256104
Iteration 41200, loss=0.7138341069221497
Iteration 41300, loss=0.6327673196792603
Iteration 41400, loss=0.7075983285903931
Iteration 41500, loss=0.7142137289047241
Iteration 41600, loss=0.7170511484146118
Iteration 41700, loss=0.6484302282333374
Iteration 41800, loss=0.48590517044067383
Iteration 41900, loss=0.7084977626800537
Iteration 42000, loss=0.5287415385246277
Iteration 42100, loss=0.6595121026039124
Iteration 42200, loss=0.7679147124290466
Iteration 42300, loss=0.6101247072219849
Iteration 42400, loss=0.6888533234596252
Iteration 42500, loss=0.6223827600479126
Iteration 42600, loss=0.7077337503433228
Iteration 42700, loss=0.7827364206314087
Iteration 42800, loss=0.6885812282562256
Iteration 42900, 

Iteration 60600, loss=0.6811394691467285
Iteration 60700, loss=0.7071220278739929
Iteration 60800, loss=0.6949946880340576
Iteration 60900, loss=0.6927408576011658
Iteration 61000, loss=0.7112575769424438
Iteration 61100, loss=0.75185227394104
Iteration 61200, loss=0.7036424279212952
Iteration 61300, loss=0.6770431399345398
Iteration 61400, loss=0.6046666502952576
Iteration 61500, loss=0.6528273224830627
Iteration 61600, loss=0.739270806312561
Iteration 61700, loss=0.37906867265701294
Iteration 61800, loss=0.6441942453384399
Iteration 61900, loss=0.7168317437171936
Iteration 62000, loss=0.6349290013313293
Iteration 62100, loss=0.6992717385292053
Iteration 62200, loss=0.7185357809066772
Iteration 62300, loss=0.705105185508728
Iteration 62400, loss=0.6641832590103149
Iteration 62500, loss=0.23473213613033295
Iteration 62600, loss=0.604440450668335
Iteration 62700, loss=0.006087915506213903
Iteration 62800, loss=0.7272270321846008
Iteration 62900, loss=1.0792044401168823
Iteration 63000, 

Iteration 80600, loss=1.0416626930236816
Iteration 80700, loss=0.8675651550292969
Iteration 80800, loss=0.5860719084739685
Iteration 80900, loss=0.5942667722702026
Iteration 81000, loss=0.2819274961948395
Iteration 81100, loss=1.4117467403411865
Iteration 81200, loss=0.7814738750457764
Iteration 81300, loss=0.7356569766998291
Iteration 81400, loss=0.00010443275823490694
Iteration 81500, loss=0.311770498752594
Iteration 81600, loss=0.6337862014770508
Iteration 81700, loss=0.5046642422676086
Iteration 81800, loss=0.16390134394168854
Iteration 81900, loss=0.7723805904388428
Iteration 82000, loss=0.9477216005325317
Iteration 82100, loss=0.5903235673904419
Iteration 82200, loss=0.3322469890117645
Iteration 82300, loss=1.0542693138122559
Iteration 82400, loss=0.834091067314148
Iteration 82500, loss=0.29577910900115967
Iteration 82600, loss=0.8323885202407837
Iteration 82700, loss=0.5135364532470703
Iteration 82800, loss=0.6917175650596619
Iteration 82900, loss=0.5115829110145569
Iteration 83

Iteration 100500, loss=0.8505027294158936
Iteration 100600, loss=0.2353801280260086
Iteration 100700, loss=0.0034885387867689133
Iteration 100800, loss=0.5716490745544434
Iteration 100900, loss=0.4202783703804016
Iteration 101000, loss=0.12893301248550415
Iteration 101100, loss=0.44741469621658325
Iteration 101200, loss=0.19724026322364807
Iteration 101300, loss=0.4696182608604431
Iteration 101400, loss=0.8362036943435669
Iteration 101500, loss=0.4899890422821045
Iteration 101600, loss=0.43146318197250366
Iteration 101700, loss=0.41056734323501587
Iteration 101800, loss=0.7661129236221313
Iteration 101900, loss=0.5348192453384399
Iteration 102000, loss=0.35455697774887085
Iteration 102100, loss=0.5738062858581543
Iteration 102200, loss=0.41861605644226074
Iteration 102300, loss=0.499184250831604
Iteration 102400, loss=1.0542781352996826
Iteration 102500, loss=0.4950883984565735
Iteration 102600, loss=0.35108819603919983
Iteration 102700, loss=0.774634838104248
Iteration 102800, loss=0.

Iteration 119900, loss=0.5211132764816284
Iteration 120000, loss=0.8662559390068054
Iteration 120100, loss=0.6169233918190002
Iteration 120200, loss=0.7323141694068909
Iteration 120300, loss=0.5424603223800659
Iteration 120400, loss=1.4184539318084717
Iteration 120500, loss=0.6697059273719788
Iteration 120600, loss=0.3758188486099243
Iteration 120700, loss=0.45407259464263916
Iteration 120800, loss=0.008788946084678173
Iteration 120900, loss=0.5254298448562622
Iteration 121000, loss=0.8364971280097961
Iteration 121100, loss=1.1075677871704102
Iteration 121200, loss=0.38637667894363403
Iteration 121300, loss=0.5453944802284241
Iteration 121400, loss=1.2515851259231567
Iteration 121500, loss=0.564623236656189
Iteration 121600, loss=0.2003663033246994
Iteration 121700, loss=0.9724838733673096
Iteration 121800, loss=1.2536007165908813
Iteration 121900, loss=0.9824622869491577
Iteration 122000, loss=0.05236739292740822
Iteration 122100, loss=0.0012587481178343296
Iteration 122200, loss=0.47

Iteration 139300, loss=0.4236263334751129
Iteration 139400, loss=4.887701288680546e-05
Iteration 139500, loss=0.4798629581928253
Iteration 139600, loss=0.5452032685279846
Iteration 139700, loss=0.2750888764858246
Iteration 139800, loss=0.4683803915977478
Iteration 139900, loss=1.247666597366333
Iteration 140000, loss=0.02785526216030121
Iteration 140100, loss=0.1769830882549286
Iteration 140200, loss=1.0546889305114746
Iteration 140300, loss=0.06508862972259521
Iteration 140400, loss=0.40699708461761475
Iteration 140500, loss=0.28391510248184204
Iteration 140600, loss=1.192093321833454e-07
Iteration 140700, loss=0.4769325852394104
Iteration 140800, loss=0.7933003306388855
Iteration 140900, loss=0.9136980175971985
Iteration 141000, loss=1.2155762910842896
Iteration 141100, loss=0.8927676677703857
Iteration 141200, loss=0.05097449570894241
Iteration 141300, loss=0.5698312520980835
Iteration 141400, loss=0.3836348056793213
Iteration 141500, loss=0.37397530674934387
Iteration 141600, loss=

Iteration 158700, loss=0.2774644196033478
Iteration 158800, loss=0.9167948365211487
Iteration 158900, loss=0.1667308658361435
Iteration 159000, loss=0.8920900821685791
Iteration 159100, loss=0.36992770433425903
Iteration 159200, loss=0.3074275255203247
Iteration 159300, loss=0.34061184525489807
Iteration 159400, loss=0.49403631687164307
Iteration 159500, loss=1.291396141052246
Iteration 159600, loss=0.2726496160030365
Iteration 159700, loss=0.3534890413284302
Iteration 159800, loss=0.3899666368961334
Iteration 159900, loss=0.7648742198944092
Iteration 160000, loss=0.007979795336723328
Iteration 160100, loss=0.41552993655204773
Iteration 160200, loss=0.027951065450906754
Iteration 160300, loss=1.2431774139404297
Iteration 160400, loss=0.2892298400402069
Iteration 160500, loss=0.3268592953681946
Iteration 160600, loss=0.36033448576927185
Iteration 160700, loss=0.33832356333732605
Iteration 160800, loss=0.9831904172897339
Iteration 160900, loss=0.47812795639038086
Iteration 161000, loss=7

Iteration 178000, loss=0.44700244069099426
Iteration 178100, loss=1.2905287742614746
Iteration 178200, loss=0.2780383229255676
Iteration 178300, loss=1.4040205478668213
Iteration 178400, loss=1.7195014953613281
Iteration 178500, loss=0.2950812876224518
Iteration 178600, loss=0.3150540888309479
Iteration 178700, loss=1.0683499574661255
Iteration 178800, loss=0.2932606637477875
Iteration 178900, loss=0.8693981766700745
Iteration 179000, loss=0.36132383346557617
Iteration 179100, loss=1.194222092628479
Iteration 179200, loss=0.4800880551338196
Iteration 179300, loss=0.06794383376836777
Iteration 179400, loss=0.9477627873420715
Iteration 179500, loss=1.0984032154083252
Iteration 179600, loss=1.2759385108947754
Iteration 179700, loss=0.5409483313560486
Iteration 179800, loss=0.0010849159443750978
Iteration 179900, loss=0.28027480840682983
Iteration 180000, loss=0.5325177907943726
Iteration 180100, loss=0.40360260009765625
Iteration 180200, loss=0.5051113963127136
Iteration 180300, loss=0.06

Iteration 197300, loss=0.0004304976901039481
Iteration 197400, loss=0.9016216993331909
Iteration 197500, loss=0.13048531115055084
Iteration 197600, loss=1.3882907629013062
Iteration 197700, loss=1.1936968564987183
Iteration 197800, loss=0.5067706108093262
Iteration 197900, loss=0.33757483959198
Iteration 198000, loss=0.9337306618690491
Iteration 198100, loss=0.2706547975540161
Iteration 198200, loss=0.46250370144844055
Iteration 198300, loss=0.44139015674591064
Iteration 198400, loss=0.5243008732795715
Iteration 198500, loss=0.4418121874332428
Iteration 198600, loss=0.335632860660553
Iteration 198700, loss=1.0078961849212646
Iteration 198800, loss=0.27046817541122437
Iteration 198900, loss=0.4829268455505371
Iteration 199000, loss=1.052517294883728
Iteration 199100, loss=0.03819243609905243
Iteration 199200, loss=0.640683650970459
Iteration 199300, loss=0.5506829023361206
Iteration 199400, loss=2.86106405837927e-05
Iteration 199500, loss=0.3915008306503296
Iteration 199600, loss=1.3248

In [67]:
validation_model.save("word2vec-keras-model.h5")

In [60]:
sim_cb.run_sim("python")

Nearest to python: monty, or, viewed, circus, march, miles, award, television,
