In [590]:
import os
import time
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from skipgrams import SkipGrams
from data_utils import DataManager
from utils import *
from model import Word2Vec, loss_fn
from functools import partial

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load dataset

In [614]:
# Window size handed (positionally) to DataManager.from_text_file.
WINDOW_SIZE = 11
# Starting value of the drop threshold; the training cell lowers it in
# factor-of-10 steps through dm.skg.set_threshold.
threshold = 1e-3
# Build the data manager from the training corpus (prints its progress).
dm = DataManager.from_text_file('train.txt', WINDOW_SIZE, threshold=threshold)

Window size: 11. Threshold: 1.00e-02
Getting dataset size...
There are 17556 total lines in the dataset.
Sorting words by frequency...% completed.
Creating keys and unigrams...
Initializing lookup tables...
Finished.
Finding number of tokens...
Total number of tokens: 2462984


## Create an analogies metric to print every so often

In [615]:
def test_common_analogies(model, dm):
    """Print the top-3 analogy completions for a fixed list of word triples.

    The embedding matrix is the element-wise mean of the first weight array
    of ``model.layers[0]`` and of ``model.layers[1]``. Word/id lookups are
    done by ``print_analogies`` using ``dm.vocab_table`` and
    ``dm.inv_vocab_table``. Nothing is returned; everything is printed.
    """
    first_w, second_w = (model.layers[idx].get_weights()[0] for idx in (0, 1))
    averaged = 0.5*(first_w+second_w)

    show = partial(print_analogies, U=averaged, vocab_table=dm.vocab_table,
                   inv_vocab_table=dm.inv_vocab_table, K=3)

    triples = [
        ('he', 'she', 'him'),
        ('see', 'saw', 'hear'),
        ('m', 'km', 'ft'),
        ('possible', 'impossible', 'able'),
        ('day', 'week', 'month'),
        ('daughter', 'son', 'mother'),
    ]

    print('~'*20+' Analogy task '+'~'*20)
    for position, triple in enumerate(triples):
        # A separator goes between consecutive analogies, not after the last.
        if position:
            print('-'*40)
        show(*triple)

## Train

In [616]:
# metrics
# Running mean of the loss; train_step feeds each batch loss into it.
train_loss = tf.keras.metrics.Mean()
# Vocabulary size plus one, passed to Word2Vec as vocab_size.
# NOTE(review): a later cell shows dm.tokenize mapping a string to id 28914
# while dm.skg.vocab_size is 28913, i.e. an id equal to VOCAB_SIZE -- confirm
# such ids can never be fed to the model.
VOCAB_SIZE = dm.skg.vocab_size+1
# The training loop appends the running mean loss here after every step.
history = {'loss':[]}

# model
# Clear Keras' global state before building a new model in this kernel.
tf.keras.backend.clear_session()
model = Word2Vec(vocab_size=VOCAB_SIZE, d_model=128)
# Adam with its default hyper-parameters.
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(inp, ctxt, lbl, mask):
    """Run one optimisation step on a single batch.

    Feeds ``(inp, ctxt)`` through the global ``model``, scores the output
    against ``lbl`` with ``loss_fn`` (which also receives ``mask``), applies
    the resulting gradients with the global ``optimizer`` and folds the loss
    into the ``train_loss`` metric. Returns nothing.
    """
    with tf.GradientTape() as tape:
        outputs = model((inp, ctxt))
        batch_loss = loss_fn(lbl, outputs, mask)
    # Read the variable list only after the forward pass has run.
    weights = model.trainable_variables
    gradients = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    train_loss(batch_loss)

In [627]:
NUM_NS = 5
BATCH_SIZE = 128
EPOCHS = 1
DS_SIZE = dm.num_tokens//BATCH_SIZE
BUFFER_SIZE = 5000 # required buffer memory can become large ~ roughly bsz*(num_ns+window_size-1)*buffer*32*3 bytes.
THRESHOLD_DECAY_STEPS = 500  # divide the drop threshold by 10 every this many steps
THRESHOLD_FLOOR = 1e-5       # stop dividing once the threshold is no longer above this
ANALOGY_EVERY_STEPS = 1000   # print the analogy report every this many steps

# Decay a local copy so the global `threshold` keeps its configured starting
# value, and push that starting value back into the skip-gram generator so
# that re-running this cell replays the same schedule from the beginning.
drop_threshold = threshold
dm.skg.set_threshold(drop_threshold)

train_ds = dm.batched_ds(BATCH_SIZE, NUM_NS, BUFFER_SIZE).prefetch(1)

for epoch in range(EPOCHS):
    print(f"----- Epoch {epoch+1}/{EPOCHS} -----")
    train_loss.reset_states()
    start = time.time()
    for step, ((inp, ctxt), lbl, mask) in enumerate(train_ds):

        train_step(inp, ctxt, lbl, mask)
        loss = train_loss.result().numpy()
        # average wall-clock seconds per step so far in this epoch
        diff = (time.time()-start)/(step+1)
        # Store a plain Python float: NumPy float32 scalars are rejected by
        # json.dumps, which would make `history` impossible to save as JSON.
        history['loss'].append(float(loss))
        print_bar(step, DS_SIZE, diff, loss)

        # we start the drop threshold off high to get some training for
        # common words, then decrease it over training to learn rare words
        if (step+1)%THRESHOLD_DECAY_STEPS==0 and drop_threshold > THRESHOLD_FLOOR:
            drop_threshold /= 10
            dm.skg.set_threshold(drop_threshold)

        if (step+1)%ANALOGY_EVERY_STEPS==0:
            test_common_analogies(model, dm)

## Extract the embedding as the average of the input and output embeddings

In [623]:
# First weight array of the model's first two layers, averaged element-wise
# into a single embedding matrix.
U, V = (model.layers[idx].get_weights()[0] for idx in range(2))
W_emb = (U + V) * 0.5
# Sanity check: display the words closest to 'fun' under that embedding.
find_closest('fun', W_emb, dm.vocab_table, dm.inv_vocab_table)

array([b'zhou', b'various', b'nations', b'less', b'wheeler', b'national',
       b'european', b'sold', b'royal', b'works'], dtype=object)

In [624]:
from functools import partial

# Both helpers share the same embedding, lookup tables and top-K setting,
# so bind them once and reuse the keyword set.
_lookup_kwargs = dict(
    U=W_emb,
    vocab_table=dm.vocab_table,
    inv_vocab_table=dm.inv_vocab_table,
    K=3,
)

print_analogy = partial(print_analogies, **_lookup_kwargs)
print_close = partial(print_closest, **_lookup_kwargs)

## (Have fun) testing some analogies

In [625]:
# Word triples (a, b, c) for "a is to b, as c is to ___?".
analogy_triples = (
    ('he', 'she', 'him'),
    ('some', 'many', 'few'),
    ('m', 'km', 'ft'),
    ('person', 'people', 'one'),
    ('day', 'week', 'month'),
    ('daughter', 'son', 'mother'),
)

print('~'*20+' Analogy task '+'~'*20)
for position, triple in enumerate(analogy_triples):
    # separator between consecutive analogies only
    if position > 0:
        print('-'*40)
    print_analogy(*triple)

~~~~~~~~~~~~~~~~~~~~ Analogy task ~~~~~~~~~~~~~~~~~~~~
he is to she, as him is to ___?
	Option 1: b'films'
	Option 2: b'wheeler'
	Option 3: b'department'
----------------------------------------
some is to many, as few is to ___?
	Option 1: b'weeks'
	Option 2: b'moved'
	Option 3: b'although'
----------------------------------------
m is to km, as ft is to ___?
	Option 1: b'h'
	Option 2: b'@,@'
	Option 3: b'/'
----------------------------------------
person is to people, as one is to ___?
	Option 1: b"'s"
	Option 2: b'that'
	Option 3: b'by'
----------------------------------------
day is to week, as month is to ___?
	Option 1: b'helped'
	Option 2: b'space'
	Option 3: b'arts'
----------------------------------------
daughter is to son, as mother is to ___?
	Option 1: b'back'
	Option 2: b'states'
	Option 3: b'its'


In [626]:
# Show the nearest neighbours of a few query words, each followed by a rule.
banner = '~'*20
print(banner+' Word grouping task '+banner)
for query in ('president', 'game', 'company'):
    print_close(query)
    print('-'*40)

~~~~~~~~~~~~~~~~~~~~ Word grouping task ~~~~~~~~~~~~~~~~~~~~
The nearest neighbors to president are:
	Neighbor 1: b'national'
	Neighbor 2: b'singles'
	Neighbor 3: b'army'
----------------------------------------
The nearest neighbors to game are:
	Neighbor 1: b'their'
	Neighbor 2: b'which'
	Neighbor 3: b'@-@'
----------------------------------------
The nearest neighbors to company are:
	Neighbor 1: b'which'
	Neighbor 2: b'but'
	Neighbor 3: b'their'
----------------------------------------


In [629]:
# Vocabulary size reported by the skip-gram generator.
dm.skg.vocab_size

28913

In [630]:
# Number of unigram entries; the recorded output matches dm.skg.vocab_size.
len(dm.skg.unigrams)

28913

In [631]:
# Check which id an arbitrary digit string is mapped to.
# NOTE(review): the recorded id (28914) is one past dm.skg.vocab_size (28913);
# presumably this is the out-of-vocabulary id -- confirm.
dm.tokenize(tf.constant('422523'))

<tf.Tensor: shape=(1,), dtype=int64, numpy=array([28914])>

In [632]:
# List the files in the current working directory.
!ls

[34m__pycache__[m[m          tests.py             valid.txt
data_utils.py        train.py             word2vec_colab.ipynb
model.py             train.txt
skipgrams.py         utils.py


In [633]:
import json

In [636]:
# Prepend a 0.0 entry so the list has vocab_size+1 values, one per id in
# range(vocab_size+1) used when building `words`.
probs = [0.0, *dm.skg.unigrams]

In [654]:
# Map every id in [0, vocab_size] back to text, split the result into
# individual tokens and decode the resulting bytes objects to str.
all_ids = tf.range(dm.skg.vocab_size+1)
token_bytes = tf.strings.split(dm.detokenize(all_ids)).numpy()
words = [token.decode() for token in token_bytes]

In [655]:
# word -> probability lookup, pairing the two lists position by position.
d = {word: prob for word, prob in zip(words, probs)}

In [653]:
# Preview the first five vocabulary entries. `words` already holds decoded
# str objects at this point, so calling .decode() on them again would raise
# AttributeError when the notebook is run top to bottom.
words[:5]

['<pad>', '<sos>', '<eos>', '<unk>', 'the']

In [657]:
# Demonstrate JSON serialisation on a small stand-in dict. A separate name is
# used so the real training `history` is not overwritten by dummy values.
example_history = {'loss':[4.5, 3.2, 2.1]}
json.dumps(example_history)

'{"loss": [4.5, 3.2, 2.1]}'