In [65]:
import math

import numpy as np
import tensorflow as tf

In [2]:
text = ['he is the king', 'the king is royal', 'she is the royal queen']

In [3]:
window_size = 2
embedding_size = 5

In [27]:
def __is_bounded(direction, range, index, tokens_leng):
    cover = range * direction
    if cover + index < 0 or cover + index >= tokens_leng:
        return True
    else:
        return False

def get_context(tokens, window_size):
    context_pair = []
    for i, token in enumerate(tokens):
        for j in range(1, window_size+1):
            if not __is_bounded(1, j, i, len(tokens)):
                context_pair.append((tokens[i], tokens[i+j]))
            if not __is_bounded(-1, j, i, len(tokens)):
                context_pair.append((tokens[i], tokens[i-j]))
    return context_pair


def __get_word_set(tokens):
    word_set = set()
    for token in tokens:
        word_set.add(token)
    return word_set

In [28]:
context_pair = []
word_set = set()
for sentence in text:
    tokens = sentence.lower().split()
    context_pair += get_context(tokens, window_size)
    tmp_word_set = __get_word_set(tokens)
    for word in tmp_word_set:
        word_set.add(word)

In [31]:
text

['he is the king', 'the king is royal', 'she is the royal queen']

In [29]:
word_set

{'he', 'is', 'king', 'queen', 'royal', 'she', 'the'}

In [30]:
context_pair

[('he', 'is'),
 ('he', 'the'),
 ('is', 'the'),
 ('is', 'he'),
 ('is', 'king'),
 ('the', 'king'),
 ('the', 'is'),
 ('the', 'he'),
 ('king', 'the'),
 ('king', 'is'),
 ('the', 'king'),
 ('the', 'is'),
 ('king', 'is'),
 ('king', 'the'),
 ('king', 'royal'),
 ('is', 'royal'),
 ('is', 'king'),
 ('is', 'the'),
 ('royal', 'is'),
 ('royal', 'king'),
 ('she', 'is'),
 ('she', 'the'),
 ('is', 'the'),
 ('is', 'she'),
 ('is', 'royal'),
 ('the', 'royal'),
 ('the', 'is'),
 ('the', 'queen'),
 ('the', 'she'),
 ('royal', 'queen'),
 ('royal', 'the'),
 ('royal', 'is'),
 ('queen', 'royal'),
 ('queen', 'the')]

In [33]:
def __get_word_index(word_set):
    word_index_dic = dict()
    inverse_word_dic = dict()
    for i, word in enumerate(word_set):
        word_index_dic[word] = i
        inverse_word_dic[i] = word
    return word_index_dic, inverse_word_dic

In [34]:
word_index_dic, inverse_word_dic = __get_word_index(word_set)
word_size = len(word_set)
batch_size = len(context_pair)

In [35]:
word_index_dic

{'king': 0, 'is': 1, 'she': 2, 'the': 3, 'he': 4, 'queen': 5, 'royal': 6}

In [36]:
inverse_word_dic

{0: 'king', 1: 'is', 2: 'she', 3: 'the', 4: 'he', 5: 'queen', 6: 'royal'}

In [37]:
batch_size

34

In [38]:
word_size

7

In [39]:
inputs = [word_index_dic[x[0]] for x in context_pair]
labels = [word_index_dic[x[1]] for x in context_pair]

In [40]:
inputs

[4,
 4,
 1,
 1,
 1,
 3,
 3,
 3,
 0,
 0,
 3,
 3,
 0,
 0,
 0,
 1,
 1,
 1,
 6,
 6,
 2,
 2,
 1,
 1,
 1,
 3,
 3,
 3,
 3,
 6,
 6,
 6,
 5,
 5]

In [41]:
train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
train_labels = tf.placeholder(tf.int32, shape=[batch_size, ])

In [43]:
embeddings = tf.Variable(tf.random_uniform([word_size, embedding_size], -1.0, -1.0))

In [44]:
embed = tf.nn.embedding_lookup(embeddings, train_inputs)

In [47]:
nce_weights = tf.Variable(tf.truncated_normal([word_size, embedding_size], stddev=1.0/math.sqrt(embedding_size)))
nce_biases = tf.Variable(tf.zeros([word_size]))

In [49]:
prediction = tf.add(tf.matmul(embed, tf.transpose(nce_weights)), nce_biases)

In [50]:
train_labels_vector = tf.one_hot(train_labels, word_size)

In [51]:
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=train_labels_vector))

In [52]:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0).minimize(loss)

In [59]:
init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)

for iteration in range(10000):
    total_loss = 0

    feed_dict = {train_inputs: inputs, train_labels: labels}
    _, cur_loss, pred = sess.run([optimizer, loss, prediction], feed_dict=feed_dict)
    print('{}: loss: {}'.format(iteration, cur_loss))


0: loss: 1.9196467399597168
1: loss: 1.7732106447219849
2: loss: 1.7572031021118164
3: loss: 1.7466468811035156
4: loss: 1.7375222444534302
5: loss: 1.7288062572479248
6: loss: 1.7200652360916138
7: loss: 1.7110892534255981
8: loss: 1.7017781734466553
9: loss: 1.6920934915542603
10: loss: 1.6820369958877563
11: loss: 1.6716371774673462
12: loss: 1.6609398126602173
13: loss: 1.6500028371810913
14: loss: 1.6388887166976929
15: loss: 1.627661108970642
16: loss: 1.616379737854004
17: loss: 1.6051015853881836
18: loss: 1.5938774347305298
19: loss: 1.5827549695968628
20: loss: 1.5717791318893433
21: loss: 1.5609922409057617
22: loss: 1.5504345893859863
23: loss: 1.5401452779769897
24: loss: 1.5301599502563477
25: loss: 1.5205105543136597
26: loss: 1.5112230777740479
27: loss: 1.502317190170288
28: loss: 1.4938054084777832
29: loss: 1.4856928586959839
30: loss: 1.4779773950576782
31: loss: 1.4706512689590454
32: loss: 1.463701844215393
33: loss: 1.4571126699447632
34: loss: 1.4508659839630127

432: loss: 1.3218287229537964
433: loss: 1.321824312210083
434: loss: 1.32181978225708
435: loss: 1.3218154907226562
436: loss: 1.3218109607696533
437: loss: 1.321806788444519
438: loss: 1.3218024969100952
439: loss: 1.3217980861663818
440: loss: 1.3217939138412476
441: loss: 1.3217896223068237
442: loss: 1.3217854499816895
443: loss: 1.3217812776565552
444: loss: 1.3217769861221313
445: loss: 1.321772813796997
446: loss: 1.3217687606811523
447: loss: 1.3217647075653076
448: loss: 1.3217605352401733
449: loss: 1.3217564821243286
450: loss: 1.3217524290084839
451: loss: 1.3217484951019287
452: loss: 1.321744441986084
453: loss: 1.3217405080795288
454: loss: 1.321736454963684
455: loss: 1.3217326402664185
456: loss: 1.3217287063598633
457: loss: 1.3217248916625977
458: loss: 1.321721076965332
459: loss: 1.3217171430587769
460: loss: 1.3217134475708008
461: loss: 1.3217095136642456
462: loss: 1.3217058181762695
463: loss: 1.321702003479004
464: loss: 1.3216981887817383
465: loss: 1.321694

829: loss: 1.3210159540176392
830: loss: 1.3210151195526123
831: loss: 1.321014165878296
832: loss: 1.321013331413269
833: loss: 1.3210123777389526
834: loss: 1.3210115432739258
835: loss: 1.321010708808899
836: loss: 1.321009635925293
837: loss: 1.3210088014602661
838: loss: 1.3210079669952393
839: loss: 1.321007251739502
840: loss: 1.321006178855896
841: loss: 1.3210054636001587
842: loss: 1.3210045099258423
843: loss: 1.321003794670105
844: loss: 1.321002721786499
845: loss: 1.3210020065307617
846: loss: 1.3210011720657349
847: loss: 1.3210004568099976
848: loss: 1.3209996223449707
849: loss: 1.3209986686706543
850: loss: 1.3209980726242065
851: loss: 1.3209972381591797
852: loss: 1.3209962844848633
853: loss: 1.320995807647705
854: loss: 1.3209949731826782
855: loss: 1.3209943771362305
856: loss: 1.3209935426712036
857: loss: 1.3209928274154663
858: loss: 1.320992350578308
859: loss: 1.3209917545318604
860: loss: 1.3209909200668335
861: loss: 1.3209904432296753
862: loss: 1.3209898

1237: loss: 1.321377158164978
1238: loss: 1.3213893175125122
1239: loss: 1.3213876485824585
1240: loss: 1.32140052318573
1241: loss: 1.3213986158370972
1242: loss: 1.3214120864868164
1243: loss: 1.3214101791381836
1244: loss: 1.321423888206482
1245: loss: 1.3214218616485596
1246: loss: 1.3214361667633057
1247: loss: 1.3214341402053833
1248: loss: 1.3214490413665771
1249: loss: 1.3214468955993652
1250: loss: 1.3214622735977173
1251: loss: 1.321460247039795
1252: loss: 1.3214759826660156
1253: loss: 1.3214739561080933
1254: loss: 1.3214901685714722
1255: loss: 1.3214881420135498
1256: loss: 1.3215047121047974
1257: loss: 1.3215025663375854
1258: loss: 1.3215198516845703
1259: loss: 1.3215175867080688
1260: loss: 1.3215354681015015
1261: loss: 1.321532964706421
1262: loss: 1.3215514421463013
1263: loss: 1.3215489387512207
1264: loss: 1.3215678930282593
1265: loss: 1.3215652704238892
1266: loss: 1.3215850591659546
1267: loss: 1.3215820789337158
1268: loss: 1.3216023445129395
1269: loss: 1.

1653: loss: 1.3214139938354492
1654: loss: 1.3214256763458252
1655: loss: 1.321415901184082
1656: loss: 1.321427822113037
1657: loss: 1.3214181661605835
1658: loss: 1.321429967880249
1659: loss: 1.3214203119277954
1660: loss: 1.3214325904846191
1661: loss: 1.3214226961135864
1662: loss: 1.3214349746704102
1663: loss: 1.321425199508667
1664: loss: 1.3214373588562012
1665: loss: 1.3214274644851685
1666: loss: 1.3214401006698608
1667: loss: 1.3214302062988281
1668: loss: 1.3214426040649414
1669: loss: 1.3214328289031982
1670: loss: 1.3214455842971802
1671: loss: 1.321435570716858
1672: loss: 1.3214482069015503
1673: loss: 1.3214384317398071
1674: loss: 1.3214513063430786
1675: loss: 1.3214411735534668
1676: loss: 1.3214540481567383
1677: loss: 1.321444034576416
1678: loss: 1.3214572668075562
1679: loss: 1.3214471340179443
1680: loss: 1.321460247039795
1681: loss: 1.321450114250183
1682: loss: 1.3214634656906128
1683: loss: 1.3214530944824219
1684: loss: 1.3214665651321411
1685: loss: 1.32

2060: loss: 1.3213318586349487
2061: loss: 1.321320652961731
2062: loss: 1.3213306665420532
2063: loss: 1.3213194608688354
2064: loss: 1.3213295936584473
2065: loss: 1.3213183879852295
2066: loss: 1.3213284015655518
2067: loss: 1.3213173151016235
2068: loss: 1.3213273286819458
2069: loss: 1.3213163614273071
2070: loss: 1.3213263750076294
2071: loss: 1.3213152885437012
2072: loss: 1.321325421333313
2073: loss: 1.3213144540786743
2074: loss: 1.321324348449707
2075: loss: 1.321313500404358
2076: loss: 1.3213235139846802
2077: loss: 1.3213127851486206
2078: loss: 1.3213227987289429
2079: loss: 1.3213118314743042
2080: loss: 1.321321964263916
2081: loss: 1.3213109970092773
2082: loss: 1.3213211297988892
2083: loss: 1.32131028175354
2084: loss: 1.3213205337524414
2085: loss: 1.3213094472885132
2086: loss: 1.3213196992874146
2087: loss: 1.321308970451355
2088: loss: 1.3213189840316772
2089: loss: 1.3213082551956177
2090: loss: 1.321318507194519
2091: loss: 1.3213077783584595
2092: loss: 1.321

2479: loss: 1.3212454319000244
2480: loss: 1.3212552070617676
2481: loss: 1.321244239807129
2482: loss: 1.321254014968872
2483: loss: 1.3212430477142334
2484: loss: 1.3212528228759766
2485: loss: 1.3212419748306274
2486: loss: 1.321251630783081
2487: loss: 1.321240782737732
2488: loss: 1.321250557899475
2489: loss: 1.3212395906448364
2490: loss: 1.3212493658065796
2491: loss: 1.3212385177612305
2492: loss: 1.3212480545043945
2493: loss: 1.321237325668335
2494: loss: 1.321246862411499
2495: loss: 1.3212361335754395
2496: loss: 1.321245789527893
2497: loss: 1.321234941482544
2498: loss: 1.321244716644287
2499: loss: 1.3212339878082275
2500: loss: 1.3212435245513916
2501: loss: 1.321232795715332
2502: loss: 1.321242332458496
2503: loss: 1.321231722831726
2504: loss: 1.3212412595748901
2505: loss: 1.3212306499481201
2506: loss: 1.3212400674819946
2507: loss: 1.3212295770645142
2508: loss: 1.3212388753890991
2509: loss: 1.3212286233901978
2510: loss: 1.3212379217147827
2511: loss: 1.3212273

2920: loss: 1.3211551904678345
2921: loss: 1.32114577293396
2922: loss: 1.3211547136306763
2923: loss: 1.3211452960968018
2924: loss: 1.321154236793518
2925: loss: 1.321144938468933
2926: loss: 1.3211537599563599
2927: loss: 1.3211443424224854
2928: loss: 1.321153163909912
2929: loss: 1.3211438655853271
2930: loss: 1.321152687072754
2931: loss: 1.3211435079574585
2932: loss: 1.3211523294448853
2933: loss: 1.3211430311203003
2934: loss: 1.321151852607727
2935: loss: 1.321142554283142
2936: loss: 1.3211512565612793
2937: loss: 1.3211419582366943
2938: loss: 1.321150779724121
2939: loss: 1.3211416006088257
2940: loss: 1.321150302886963
2941: loss: 1.321141004562378
2942: loss: 1.3211498260498047
2943: loss: 1.3211405277252197
2944: loss: 1.321149230003357
2945: loss: 1.3211400508880615
2946: loss: 1.3211488723754883
2947: loss: 1.3211394548416138
2948: loss: 1.32114839553833
2949: loss: 1.3211389780044556
2950: loss: 1.3211477994918823
2951: loss: 1.3211383819580078
2952: loss: 1.32114732

3361: loss: 1.3210550546646118
3362: loss: 1.3210629224777222
3363: loss: 1.3210548162460327
3364: loss: 1.321062684059143
3365: loss: 1.3210546970367432
3366: loss: 1.3210625648498535
3367: loss: 1.321054458618164
3368: loss: 1.3210620880126953
3369: loss: 1.3210543394088745
3370: loss: 1.3210618495941162
3371: loss: 1.3210539817810059
3372: loss: 1.3210617303848267
3373: loss: 1.3210537433624268
3374: loss: 1.321061611175537
3375: loss: 1.3210535049438477
3376: loss: 1.321061372756958
3377: loss: 1.321053385734558
3378: loss: 1.321061134338379
3379: loss: 1.321053147315979
3380: loss: 1.3210608959197998
3381: loss: 1.3210530281066895
3382: loss: 1.3210606575012207
3383: loss: 1.3210527896881104
3384: loss: 1.3210605382919312
3385: loss: 1.3210524320602417
3386: loss: 1.3210601806640625
3387: loss: 1.3210521936416626
3388: loss: 1.3210599422454834
3389: loss: 1.321052074432373
3390: loss: 1.3210595846176147
3391: loss: 1.3210517168045044
3392: loss: 1.3210595846176147
3393: loss: 1.32

3797: loss: 1.3209919929504395
3798: loss: 1.320999026298523
3799: loss: 1.3209917545318604
3800: loss: 1.3209986686706543
3801: loss: 1.3209915161132812
3802: loss: 1.320998191833496
3803: loss: 1.3209912776947021
3804: loss: 1.3209980726242065
3805: loss: 1.3209909200668335
3806: loss: 1.3209978342056274
3807: loss: 1.3209905624389648
3808: loss: 1.3209975957870483
3809: loss: 1.3209905624389648
3810: loss: 1.3209973573684692
3811: loss: 1.3209902048110962
3812: loss: 1.3209969997406006
3813: loss: 1.3209898471832275
3814: loss: 1.3209967613220215
3815: loss: 1.320989727973938
3816: loss: 1.3209965229034424
3817: loss: 1.3209892511367798
3818: loss: 1.3209962844848633
3819: loss: 1.3209891319274902
3820: loss: 1.3209960460662842
3821: loss: 1.3209888935089111
3822: loss: 1.320995807647705
3823: loss: 1.320988655090332
3824: loss: 1.320995569229126
3825: loss: 1.320988416671753
3826: loss: 1.3209953308105469
3827: loss: 1.3209881782531738
3828: loss: 1.3209949731826782
3829: loss: 1.3

4211: loss: 1.320948600769043
4212: loss: 1.3209550380706787
4213: loss: 1.3209483623504639
4214: loss: 1.3209547996520996
4215: loss: 1.3209481239318848
4216: loss: 1.32095468044281
4217: loss: 1.3209478855133057
4218: loss: 1.3209543228149414
4219: loss: 1.3209476470947266
4220: loss: 1.3209542036056519
4221: loss: 1.3209476470947266
4222: loss: 1.3209538459777832
4223: loss: 1.320947289466858
4224: loss: 1.3209537267684937
4225: loss: 1.3209470510482788
4226: loss: 1.320953607559204
4227: loss: 1.3209466934204102
4228: loss: 1.320953369140625
4229: loss: 1.3209466934204102
4230: loss: 1.3209530115127563
4231: loss: 1.320946455001831
4232: loss: 1.3209527730941772
4233: loss: 1.320946216583252
4234: loss: 1.3209525346755981
4235: loss: 1.3209458589553833
4236: loss: 1.3209525346755981
4237: loss: 1.3209457397460938
4238: loss: 1.3209521770477295
4239: loss: 1.3209455013275146
4240: loss: 1.3209518194198608
4241: loss: 1.320945382118225
4242: loss: 1.3209517002105713
4243: loss: 1.320

4638: loss: 1.32091224193573
4639: loss: 1.320906162261963
4640: loss: 1.32091224193573
4641: loss: 1.3209059238433838
4642: loss: 1.3209118843078613
4643: loss: 1.3209058046340942
4644: loss: 1.3209116458892822
4645: loss: 1.3209054470062256
4646: loss: 1.3209115266799927
4647: loss: 1.3209054470062256
4648: loss: 1.3209114074707031
4649: loss: 1.3209052085876465
4650: loss: 1.320911169052124
4651: loss: 1.320905089378357
4652: loss: 1.3209110498428345
4653: loss: 1.3209048509597778
4654: loss: 1.3209110498428345
4655: loss: 1.3209047317504883
4656: loss: 1.3209108114242554
4657: loss: 1.3209047317504883
4658: loss: 1.3209105730056763
4659: loss: 1.3209046125411987
4660: loss: 1.3209104537963867
4661: loss: 1.32090425491333
4662: loss: 1.3209102153778076
4663: loss: 1.32090425491333
4664: loss: 1.3209102153778076
4665: loss: 1.320904016494751
4666: loss: 1.3209099769592285
4667: loss: 1.3209037780761719
4668: loss: 1.3209097385406494
4669: loss: 1.3209037780761719
4670: loss: 1.320909

5087: loss: 1.320871114730835
5088: loss: 1.3208765983581543
5089: loss: 1.3208709955215454
5090: loss: 1.3208764791488647
5091: loss: 1.3208707571029663
5092: loss: 1.3208762407302856
5093: loss: 1.3208705186843872
5094: loss: 1.320876121520996
5095: loss: 1.3208703994750977
5096: loss: 1.3208760023117065
5097: loss: 1.320870280265808
5098: loss: 1.3208757638931274
5099: loss: 1.320870041847229
5100: loss: 1.320875644683838
5101: loss: 1.320870041847229
5102: loss: 1.320875644683838
5103: loss: 1.32086980342865
5104: loss: 1.3208754062652588
5105: loss: 1.32086980342865
5106: loss: 1.3208751678466797
5107: loss: 1.3208696842193604
5108: loss: 1.3208750486373901
5109: loss: 1.3208693265914917
5110: loss: 1.3208749294281006
5111: loss: 1.3208692073822021
5112: loss: 1.3208746910095215
5113: loss: 1.320868968963623
5114: loss: 1.3208744525909424
5115: loss: 1.320868968963623
5116: loss: 1.3208743333816528
5117: loss: 1.320868730545044
5118: loss: 1.3208742141723633
5119: loss: 1.32086849

5525: loss: 1.320841670036316
5526: loss: 1.3208467960357666
5527: loss: 1.3208414316177368
5528: loss: 1.3208465576171875
5529: loss: 1.3208413124084473
5530: loss: 1.320846438407898
5531: loss: 1.3208411931991577
5532: loss: 1.3208463191986084
5533: loss: 1.3208411931991577
5534: loss: 1.3208463191986084
5535: loss: 1.3208410739898682
5536: loss: 1.3208459615707397
5537: loss: 1.320840835571289
5538: loss: 1.3208459615707397
5539: loss: 1.3208407163619995
5540: loss: 1.3208458423614502
5541: loss: 1.32084059715271
5542: loss: 1.3208457231521606
5543: loss: 1.3208404779434204
5544: loss: 1.320845603942871
5545: loss: 1.3208403587341309
5546: loss: 1.320845603942871
5547: loss: 1.3208401203155518
5548: loss: 1.320845603942871
5549: loss: 1.3208401203155518
5550: loss: 1.3208452463150024
5551: loss: 1.3208398818969727
5552: loss: 1.320845365524292
5553: loss: 1.3208398818969727
5554: loss: 1.320845127105713
5555: loss: 1.320839762687683
5556: loss: 1.3208448886871338
5557: loss: 1.32083

5929: loss: 1.320817470550537
5930: loss: 1.3208223581314087
5931: loss: 1.320817470550537
5932: loss: 1.3208223581314087
5933: loss: 1.3208173513412476
5934: loss: 1.3208222389221191
5935: loss: 1.320817232131958
5936: loss: 1.3208222389221191
5937: loss: 1.320816993713379
5938: loss: 1.32082200050354
5939: loss: 1.320816993713379
5940: loss: 1.3208218812942505
5941: loss: 1.3208168745040894
5942: loss: 1.320821762084961
5943: loss: 1.3208167552947998
5944: loss: 1.3208216428756714
5945: loss: 1.3208167552947998
5946: loss: 1.3208216428756714
5947: loss: 1.3208165168762207
5948: loss: 1.3208214044570923
5949: loss: 1.3208165168762207
5950: loss: 1.3208212852478027
5951: loss: 1.3208163976669312
5952: loss: 1.3208210468292236
5953: loss: 1.320816159248352
5954: loss: 1.3208210468292236
5955: loss: 1.3208160400390625
5956: loss: 1.3208208084106445
5957: loss: 1.320815920829773
5958: loss: 1.3208208084106445
5959: loss: 1.320815920829773
5960: loss: 1.320820689201355
5961: loss: 1.320815

6371: loss: 1.320794939994812
6372: loss: 1.3207995891571045
6373: loss: 1.3207948207855225
6374: loss: 1.320799469947815
6375: loss: 1.3207948207855225
6376: loss: 1.320799469947815
6377: loss: 1.320794701576233
6378: loss: 1.3207992315292358
6379: loss: 1.320794701576233
6380: loss: 1.3207992315292358
6381: loss: 1.3207944631576538
6382: loss: 1.3207989931106567
6383: loss: 1.3207942247390747
6384: loss: 1.3207988739013672
6385: loss: 1.3207942247390747
6386: loss: 1.3207988739013672
6387: loss: 1.3207941055297852
6388: loss: 1.320798635482788
6389: loss: 1.3207941055297852
6390: loss: 1.320798635482788
6391: loss: 1.3207941055297852
6392: loss: 1.320798635482788
6393: loss: 1.3207939863204956
6394: loss: 1.320798397064209
6395: loss: 1.320793867111206
6396: loss: 1.320798397064209
6397: loss: 1.320793867111206
6398: loss: 1.3207982778549194
6399: loss: 1.3207937479019165
6400: loss: 1.3207981586456299
6401: loss: 1.320793628692627
6402: loss: 1.3207981586456299
6403: loss: 1.3207933

6815: loss: 1.3207744359970093
6816: loss: 1.3207788467407227
6817: loss: 1.3207744359970093
6818: loss: 1.320778727531433
6819: loss: 1.3207743167877197
6820: loss: 1.320778727531433
6821: loss: 1.3207743167877197
6822: loss: 1.320778727531433
6823: loss: 1.3207740783691406
6824: loss: 1.320778489112854
6825: loss: 1.3207740783691406
6826: loss: 1.320778489112854
6827: loss: 1.320773959159851
6828: loss: 1.3207783699035645
6829: loss: 1.3207738399505615
6830: loss: 1.3207783699035645
6831: loss: 1.3207738399505615
6832: loss: 1.320778250694275
6833: loss: 1.3207736015319824
6834: loss: 1.3207778930664062
6835: loss: 1.3207736015319824
6836: loss: 1.3207778930664062
6837: loss: 1.3207736015319824
6838: loss: 1.3207777738571167
6839: loss: 1.3207734823226929
6840: loss: 1.3207777738571167
6841: loss: 1.3207732439041138
6842: loss: 1.3207776546478271
6843: loss: 1.3207732439041138
6844: loss: 1.3207776546478271
6845: loss: 1.3207731246948242
6846: loss: 1.3207775354385376
6847: loss: 1.3

7244: loss: 1.3207613229751587
7245: loss: 1.3207570314407349
7246: loss: 1.3207612037658691
7247: loss: 1.3207570314407349
7248: loss: 1.3207612037658691
7249: loss: 1.3207569122314453
7250: loss: 1.3207612037658691
7251: loss: 1.3207569122314453
7252: loss: 1.3207610845565796
7253: loss: 1.3207567930221558
7254: loss: 1.32076096534729
7255: loss: 1.3207566738128662
7256: loss: 1.3207608461380005
7257: loss: 1.3207566738128662
7258: loss: 1.3207608461380005
7259: loss: 1.3207565546035767
7260: loss: 1.3207606077194214
7261: loss: 1.320756435394287
7262: loss: 1.3207606077194214
7263: loss: 1.320756435394287
7264: loss: 1.3207606077194214
7265: loss: 1.320756435394287
7266: loss: 1.3207604885101318
7267: loss: 1.320756196975708
7268: loss: 1.3207603693008423
7269: loss: 1.3207563161849976
7270: loss: 1.3207603693008423
7271: loss: 1.3207560777664185
7272: loss: 1.3207602500915527
7273: loss: 1.320755958557129
7274: loss: 1.3207601308822632
7275: loss: 1.320755958557129
7276: loss: 1.32

7689: loss: 1.3207406997680664
7690: loss: 1.3207448720932007
7691: loss: 1.3207406997680664
7692: loss: 1.3207446336746216
7693: loss: 1.3207406997680664
7694: loss: 1.320744514465332
7695: loss: 1.3207405805587769
7696: loss: 1.320744514465332
7697: loss: 1.3207406997680664
7698: loss: 1.320744514465332
7699: loss: 1.3207405805587769
7700: loss: 1.320744514465332
7701: loss: 1.3207404613494873
7702: loss: 1.3207443952560425
7703: loss: 1.3207403421401978
7704: loss: 1.3207443952560425
7705: loss: 1.3207403421401978
7706: loss: 1.3207441568374634
7707: loss: 1.3207402229309082
7708: loss: 1.3207441568374634
7709: loss: 1.3207402229309082
7710: loss: 1.3207441568374634
7711: loss: 1.3207401037216187
7712: loss: 1.3207440376281738
7713: loss: 1.320739984512329
7714: loss: 1.3207440376281738
7715: loss: 1.3207398653030396
7716: loss: 1.3207439184188843
7717: loss: 1.3207398653030396
7718: loss: 1.3207439184188843
7719: loss: 1.32073974609375
7720: loss: 1.3207437992095947
7721: loss: 1.3

8110: loss: 1.3207308053970337
8111: loss: 1.3207271099090576
8112: loss: 1.3207308053970337
8113: loss: 1.320726990699768
8114: loss: 1.3207308053970337
8115: loss: 1.320726990699768
8116: loss: 1.3207306861877441
8117: loss: 1.320726990699768
8118: loss: 1.3207306861877441
8119: loss: 1.320726752281189
8120: loss: 1.3207306861877441
8121: loss: 1.320726752281189
8122: loss: 1.3207305669784546
8123: loss: 1.320726752281189
8124: loss: 1.3207305669784546
8125: loss: 1.320726752281189
8126: loss: 1.320730447769165
8127: loss: 1.3207266330718994
8128: loss: 1.3207303285598755
8129: loss: 1.3207265138626099
8130: loss: 1.3207303285598755
8131: loss: 1.3207265138626099
8132: loss: 1.3207303285598755
8133: loss: 1.3207263946533203
8134: loss: 1.320730209350586
8135: loss: 1.3207263946533203
8136: loss: 1.320730209350586
8137: loss: 1.3207261562347412
8138: loss: 1.3207300901412964
8139: loss: 1.3207261562347412
8140: loss: 1.3207300901412964
8141: loss: 1.3207261562347412
8142: loss: 1.3207

8527: loss: 1.3207145929336548
8528: loss: 1.3207181692123413
8529: loss: 1.3207145929336548
8530: loss: 1.3207181692123413
8531: loss: 1.3207145929336548
8532: loss: 1.3207181692123413
8533: loss: 1.3207143545150757
8534: loss: 1.3207179307937622
8535: loss: 1.3207143545150757
8536: loss: 1.3207180500030518
8537: loss: 1.3207142353057861
8538: loss: 1.3207179307937622
8539: loss: 1.3207142353057861
8540: loss: 1.3207178115844727
8541: loss: 1.3207141160964966
8542: loss: 1.3207178115844727
8543: loss: 1.3207141160964966
8544: loss: 1.3207178115844727
8545: loss: 1.3207141160964966
8546: loss: 1.3207178115844727
8547: loss: 1.320713996887207
8548: loss: 1.3207178115844727
8549: loss: 1.3207138776779175
8550: loss: 1.3207175731658936
8551: loss: 1.3207138776779175
8552: loss: 1.320717453956604
8553: loss: 1.3207138776779175
8554: loss: 1.320717453956604
8555: loss: 1.320713758468628
8556: loss: 1.320717453956604
8557: loss: 1.320713758468628
8558: loss: 1.3207173347473145
8559: loss: 1.

8941: loss: 1.3207032680511475
8942: loss: 1.3207067251205444
8943: loss: 1.3207032680511475
8944: loss: 1.3207067251205444
8945: loss: 1.320703148841858
8946: loss: 1.3207067251205444
8947: loss: 1.320703148841858
8948: loss: 1.3207066059112549
8949: loss: 1.3207030296325684
8950: loss: 1.3207064867019653
8951: loss: 1.3207030296325684
8952: loss: 1.3207063674926758
8953: loss: 1.3207029104232788
8954: loss: 1.3207063674926758
8955: loss: 1.3207027912139893
8956: loss: 1.3207063674926758
8957: loss: 1.3207027912139893
8958: loss: 1.3207063674926758
8959: loss: 1.3207026720046997
8960: loss: 1.3207063674926758
8961: loss: 1.3207026720046997
8962: loss: 1.3207062482833862
8963: loss: 1.3207025527954102
8964: loss: 1.3207061290740967
8965: loss: 1.3207025527954102
8966: loss: 1.3207061290740967
8967: loss: 1.3207025527954102
8968: loss: 1.3207061290740967
8969: loss: 1.3207025527954102
8970: loss: 1.3207060098648071
8971: loss: 1.3207025527954102
8972: loss: 1.3207060098648071
8973: loss

9361: loss: 1.320692777633667
9362: loss: 1.320696234703064
9363: loss: 1.3206926584243774
9364: loss: 1.3206959962844849
9365: loss: 1.3206926584243774
9366: loss: 1.3206959962844849
9367: loss: 1.320692539215088
9368: loss: 1.3206958770751953
9369: loss: 1.3206924200057983
9370: loss: 1.3206959962844849
9371: loss: 1.3206924200057983
9372: loss: 1.3206957578659058
9373: loss: 1.3206924200057983
9374: loss: 1.3206958770751953
9375: loss: 1.3206924200057983
9376: loss: 1.3206957578659058
9377: loss: 1.3206921815872192
9378: loss: 1.3206957578659058
9379: loss: 1.3206923007965088
9380: loss: 1.3206956386566162
9381: loss: 1.3206921815872192
9382: loss: 1.3206955194473267
9383: loss: 1.3206921815872192
9384: loss: 1.3206955194473267
9385: loss: 1.3206921815872192
9386: loss: 1.3206955194473267
9387: loss: 1.3206921815872192
9388: loss: 1.3206952810287476
9389: loss: 1.3206920623779297
9390: loss: 1.3206952810287476
9391: loss: 1.3206920623779297
9392: loss: 1.320695161819458
9393: loss: 

9775: loss: 1.3206830024719238
9776: loss: 1.3206862211227417
9777: loss: 1.3206828832626343
9778: loss: 1.3206861019134521
9779: loss: 1.3206827640533447
9780: loss: 1.3206861019134521
9781: loss: 1.3206828832626343
9782: loss: 1.3206861019134521
9783: loss: 1.3206827640533447
9784: loss: 1.3206861019134521
9785: loss: 1.3206827640533447
9786: loss: 1.3206859827041626
9787: loss: 1.3206827640533447
9788: loss: 1.3206859827041626
9789: loss: 1.3206827640533447
9790: loss: 1.3206859827041626
9791: loss: 1.3206827640533447
9792: loss: 1.320685863494873
9793: loss: 1.3206827640533447
9794: loss: 1.320685863494873
9795: loss: 1.3206825256347656
9796: loss: 1.320685863494873
9797: loss: 1.3206826448440552
9798: loss: 1.320685625076294
9799: loss: 1.3206825256347656
9800: loss: 1.3206857442855835
9801: loss: 1.320682406425476
9802: loss: 1.320685625076294
9803: loss: 1.320682406425476
9804: loss: 1.320685625076294
9805: loss: 1.3206822872161865
9806: loss: 1.3206855058670044
9807: loss: 1.32

In [70]:
def get_vec(word, session):
    return session.run(embeddings[word_index_dic[word]])

def __sim(vec1, vec2):
    return (1 - math.acos(get_cos_similarity(vec1,vec2)) / math.pi)

def get_cos_similarity(vec1, vec2):
    return np.dot(vec1, vec2)/(np.linalg.norm(vec1)*np.linalg.norm(vec2))

def find_cloest_word(word_set, session, target_word):
    sim = 0.0
    vec1 = get_vec(target_word, session)
    result = ''
    for word in word_set:
        if word == target_word:
            continue
        vec2 = get_vec(word,session)
        tmp_sim = __sim(vec1, vec2)
        print('{} : {} : {}'.format(target_word, word, tmp_sim))
        if tmp_sim > sim:
            sim = tmp_sim
            result = word
    return result

In [71]:
find_cloest_word(word_set, sess, 'king')

king : is : 0.4993977721677135
king : she : 0.7711273152027727
king : the : 0.49085532737979043
king : he : 0.7711273152027727
king : queen : 0.779433053181117
king : royal : 0.562655044120014


'queen'