In [1]:
import os

from gensim.models import Word2Vec

In [None]:
from gensim.models.callbacks import CallbackAny2Vec

## Load data

In [None]:
# Data directory, relative to the kernel's working directory.
# Despite the name, this is a directory prefix (note the trailing '/'), not a single file.
INPUT_FILE = 'data/'

In [2]:
# Read the pre-tokenized corpus: one sentence per line, tokens separated by spaces.
# splitlines() (unlike split('\n')) does not produce a trailing '' when the file
# ends with a newline, and it also strips '\r' from CRLF line endings.
with open(INPUT_FILE + 'unhyphened_tokenized_sent.txt', 'r', encoding='utf-8') as f:
    data = f.read().splitlines()

In [3]:
# Tokenize each line. str.split() with no argument splits on any run of
# whitespace and never yields empty strings, so it drops the '' pieces left by
# repeated spaces as well as stray '\t'/'\r' characters. Blank lines are skipped
# so they do not become empty "sentences" in the training corpus.
sentences = [line.split() for line in data if line.strip()]

In [4]:
# Sanity check: display the first tokenized sentence.
sentences[0]

['head',
 'down',
 'to',
 'the',
 'river',
 'from',
 'the',
 'museum',
 'of',
 'drinking',
 'water',
 '(',
 'p',
 ')',
 ',',
 'turn',
 'leave',
 ',',
 'and',
 '-PRON-',
 'will',
 'soon',
 'come',
 'across',
 'this',
 'charming',
 'art',
 'village']

## Train model

In [16]:
# Preview the per-epoch loss message: epoch padded to width 3,
# loss padded to width 10 with 3 decimal places.
print(str.format('Loss after epoch {:3d}: {:10.3f}', 3, 100.324))

Loss after epoch   3:    100.324


In [20]:
class callback(CallbackAny2Vec):
    '''Callback to print and record the training loss after each epoch.

    The code treats ``model.get_latest_training_loss()`` as a running total
    for the current ``train()`` call, so the loss of a single epoch is the
    difference between the current total and the total at the previous
    epoch end.

    Parameters
    ----------
    loss_record : list, optional
        List that each per-epoch loss is appended to, so the values remain
        available after training. If omitted, a fresh list is created for
        this instance. (The previous ``[]`` default was shared across all
        instances and leaked losses between them.)
    '''
    def __init__(self, loss_record=None):
        self.epoch = 0        # index of the next epoch to report
        self.total_loss = 0   # running total observed at the previous epoch end
        # Fresh list per instance unless the caller passes one in to collect results.
        self.record = [] if loss_record is None else loss_record

    def on_train_begin(self, model):
        '''Reset the running-total baseline when a train() call starts.'''
        # NOTE(review): gensim resets its running loss to 0 at the start of each
        # train() call; without this reset, reusing this callback for a second
        # train() would report a negative loss for the first epoch.
        self.total_loss = 0

    def on_epoch_end(self, model):
        '''Compute, record and print the loss of the epoch that just finished.'''
        cumulative = model.get_latest_training_loss()
        loss = cumulative - self.total_loss
        self.record.append(loss)
        self.total_loss = cumulative
        print('Loss after epoch {:3d}: {:10.3f}'.format(self.epoch, loss))
        self.epoch += 1

In [21]:
# Container for per-epoch losses; the training callback appends into it.
loss = list()

In [22]:
# Start from an empty list so that re-running this cell does not append the new
# per-epoch losses onto those of a previous run (the callback appends into it).
loss = []
# gensim 3.x parameter names: `size` is the vector dimensionality, `iter` the
# number of epochs. compute_loss=True is needed so that
# get_latest_training_loss() (used by the callback) reports a loss.
model = Word2Vec(sentences,
                 size=300,
                 iter=20,
                 compute_loss=True,
                 callbacks=[callback(loss)])

Loss after epoch   0:  66192.406
Loss after epoch   1:  52306.836
Loss after epoch   2:  50540.539
Loss after epoch   3:  50112.609
Loss after epoch   4:  49121.359
Loss after epoch   5:  48657.875
Loss after epoch   6:  47740.719
Loss after epoch   7:  47004.062
Loss after epoch   8:  46127.906
Loss after epoch   9:  45738.438
Loss after epoch  10:  44718.625
Loss after epoch  11:  44370.062
Loss after epoch  12:  43691.875
Loss after epoch  13:  43214.375
Loss after epoch  14:  42383.125
Loss after epoch  15:  41873.062
Loss after epoch  16:  41094.562
Loss after epoch  17:  40803.062
Loss after epoch  18:  40108.125
Loss after epoch  19:  39268.000
Loss after epoch  20:  38634.562
Loss after epoch  21:  38128.688
Loss after epoch  22:  37861.312
Loss after epoch  23:  37026.438
Loss after epoch  24:  36318.250
Loss after epoch  25:  36126.875
Loss after epoch  26:  35323.500
Loss after epoch  27:  35088.375
Loss after epoch  28:  34372.000
Loss after epoch  29:  34115.125
Loss after

In [23]:
# Per-epoch losses recorded by the training callback.
loss

[66192.40625,
 52306.8359375,
 50540.5390625,
 50112.609375,
 49121.359375,
 48657.875,
 47740.71875,
 47004.0625,
 46127.90625,
 45738.4375,
 44718.625,
 44370.0625,
 43691.875,
 43214.375,
 42383.125,
 41873.0625,
 41094.5625,
 40803.0625,
 40108.125,
 39268.0,
 38634.5625,
 38128.6875,
 37861.3125,
 37026.4375,
 36318.25,
 36126.875,
 35323.5,
 35088.375,
 34372.0,
 34115.125,
 33606.5,
 33091.875,
 32668.375,
 32652.25,
 31966.125,
 31559.125,
 30820.375,
 30653.375,
 30317.0,
 29813.875,
 29546.25,
 29423.75,
 29024.5,
 28833.25,
 28413.5,
 28094.25,
 27856.75,
 27565.0,
 27565.0,
 27341.5,
 26864.375,
 26672.5,
 26423.25,
 25961.75,
 26150.25,
 26060.5,
 25510.5,
 25387.5,
 24352.75,
 23878.5,
 23503.0,
 23384.0,
 23310.0,
 23455.0,
 23156.25,
 22745.0,
 22749.5,
 22554.25,
 22138.25,
 22526.0,
 22148.5,
 22013.25,
 22211.25,
 21994.0,
 21949.0,
 21746.25,
 21851.75,
 21470.25,
 21343.75,
 21612.5,
 20938.75,
 21228.75,
 21162.75,
 21208.5,
 21012.0,
 21278.75,
 20800.5,
 20997.0

In [7]:
# Running loss total as reported by gensim; the callback's per-epoch values are
# successive differences of this number.
model.get_latest_training_loss()

3009968.25

In [None]:
# Continue training on the same corpus. Word2Vec.train() needs the corpus plus
# an explicit corpus size and epoch count; calling it with no arguments raises
# TypeError. compute_loss=True keeps get_latest_training_loss() meaningful,
# since train() otherwise switches loss tracking off.
model.train(sentences,
            total_examples=model.corpus_count,
            epochs=model.epochs,
            compute_loss=True)

In [26]:
# Score the vectors on the word-analogy question set. In the saved run most
# sections contain no usable questions for this vocabulary, so the overall
# score rests on only a handful of analogies.
model.wv.evaluate_word_analogies(analogies='./data/questions-words.txt')

(0.0,
 [{'section': 'capital-common-countries', 'correct': [], 'incorrect': []},
  {'section': 'capital-world', 'correct': [], 'incorrect': []},
  {'section': 'currency', 'correct': [], 'incorrect': []},
  {'section': 'city-in-state', 'correct': [], 'incorrect': []},
  {'section': 'family', 'correct': [], 'incorrect': []},
  {'section': 'gram1-adjective-to-adverb',
   'correct': [],
   'incorrect': [('MOST', 'MOSTLY', 'POSSIBLE', 'POSSIBLY'),
    ('POSSIBLE', 'POSSIBLY', 'MOST', 'MOSTLY')]},
  {'section': 'gram2-opposite', 'correct': [], 'incorrect': []},
  {'section': 'gram3-comparative', 'correct': [], 'incorrect': []},
  {'section': 'gram4-superlative', 'correct': [], 'incorrect': []},
  {'section': 'gram5-present-participle',
   'correct': [],
   'incorrect': [('READ', 'READING', 'SWIM', 'SWIMMING'),
    ('SWIM', 'SWIMMING', 'READ', 'READING')]},
  {'section': 'gram6-nationality-adjective',
   'correct': [],
   'incorrect': [('CHINA', 'CHINESE', 'JAPAN', 'JAPANESE'),
    ('JAPAN', 

In [27]:
# save() does not create missing parent directories, so create the target
# folder first (a no-op when it already exists).
os.makedirs('model0908', exist_ok=True)
model.save('model0908/model')

In [29]:
# Reload the model from the path it was just saved to, as a save/load round-trip check.
model2 = Word2Vec.load('model0908/model')

In [30]:
# Shortcut to the reloaded model's word vectors.
wv = model2.wv

In [31]:
# Inspect the learned vector for the token 'look' (its length is the `size`
# used when training the saved model).
wv['look']

array([ 9.51078117e-01,  5.59389174e-01,  6.59767866e-01, -6.04505062e-01,
       -1.22548878e-01, -8.66256714e-01, -9.38208401e-01,  1.36777031e+00,
       -2.99637944e-01, -1.21089637e+00,  1.38088739e+00,  2.19802427e+00,
       -4.63909626e-01, -6.70298338e-02,  1.08959460e+00,  4.38121080e-01,
        8.36748123e-01, -4.10193443e-01,  1.93357870e-01,  4.41122979e-01,
       -6.55516088e-01,  2.25814652e+00, -2.14351368e+00,  1.09360838e+00,
       -8.56462717e-01,  1.70992351e+00, -9.34902966e-01, -5.58280587e-01,
       -2.58142591e-01,  7.13660955e-01, -5.46232939e-01, -5.73939204e-01,
        6.34560645e-01, -8.69602859e-01,  3.42655599e-01, -1.34954047e+00,
        7.77718246e-01, -1.61491036e-01,  9.69480455e-01, -1.06611037e+00,
        4.67451453e-01, -9.23522651e-01,  4.00908232e-01, -1.15980327e+00,
       -2.98833668e-01,  1.32206410e-01,  1.26447988e+00,  3.63334827e-02,
        7.61256218e-01, -4.21543658e-01,  4.14517701e-01,  5.41409016e-01,
       -1.09076373e-01,  