In [15]:
import numpy as np
import tensorflow as tf
import random  
from collections import Counter
import datetime, time, json
from copy import deepcopy

In [2]:
def create_word_pairs(int_corpus, window_size, stop_size):
    idx_pairs = []
    tokens = 0
    # for each snetence 
    for sentence in int_corpus:
        # for each center word
        for center_word_pos in range(len(sentence)):
            center_word_idx = sentence[center_word_pos]
            tokens += 1
            if tokens >= stop_size:
                return idx_pairs, tokens
            else:
                # for each context word within window
                for w in range(-window_size, window_size + 1):
                    context_word_pos = center_word_pos + w
                    # make soure not jump out sentence
                    if context_word_pos < 0 or context_word_pos >= len(sentence) or center_word_pos == context_word_pos:
                        continue
                    context_word_idx = sentence[context_word_pos]
                    idx_pairs.append((center_word_idx, context_word_idx))

                    
    return idx_pairs, tokens
        

In [3]:
def get_batches(idx_pairs, batch_size):
    n_batches = len(idx_pairs) // batch_size
    idx_pairs = idx_pairs[:n_batches*batch_size]
    for idx in range(0, len(idx_pairs), batch_size):
        x, y = [], []
        batch = idx_pairs[idx:idx+batch_size]
        for ii in range (len(batch)):
            x.append(batch[ii][0])
            y.append(batch[ii][1])        
        yield x, y  

### create word pairs

In [4]:
#corpus = np.load("/Users/zhang/MscProject_tweak2vec/corpus/quora_corpus_int5.npy").tolist()
corpus = np.load("/Users/zhang/MscProject_tweak2vec/corpus/pubmed_corpus_int5.npy").tolist()

corpus_shuffle = corpus[:]

random.shuffle(corpus_shuffle)
pubmed_idx_pairs, tokens = create_word_pairs(corpus_shuffle, window_size = 5, stop_size=7000000)
print('totally {0} word pairs'.format(len(pubmed_idx_pairs)))
print('totally {0} tokens'.format(tokens))


totally 26324522 word pairs
totally 3258438 tokens


In [5]:
tokens_lst = [4,3,2,1,0.5,0.1,0.05,0.01]
idx_pairs = []
for i in tokens_lst:
    random.shuffle(corpus_shuffle)
    pairs, tokens = create_word_pairs(corpus_shuffle, window_size = 5, stop_size = i * 1000000)
    idx_pairs.append(pairs)
    print('totally {0} word pairs'.format(len(pairs)))
    print('totally {0} tokens'.format(tokens))

totally 26324522 word pairs
totally 3258438 tokens
totally 24235359 word pairs
totally 3000000 tokens
totally 16158048 word pairs
totally 2000000 tokens
totally 8080080 word pairs
totally 1000000 tokens
totally 4041333 word pairs
totally 500000 tokens
totally 807589 word pairs
totally 100000 tokens
totally 404625 word pairs
totally 50000 tokens
totally 80163 word pairs
totally 10000 tokens


In [6]:
# wordlist = np.load('/Users/zhang/MscProject_tweak2vec/corpus/quora_vocab5.npy').tolist()
# wordlist.append(['UNK',0])
# word2idx = {w[0]: wordlist.index(w) for w in wordlist }
# idx2word = {wordlist.index(w): w[0] for w in wordlist }

wordlist = np.load('/Users/zhang/MscProject_tweak2vec/corpus/pubmed_vocab5.npy').tolist()
wordlist.append('UNK')
word2idx = {w: wordlist.index(w) for w in wordlist }
idx2word = {wordlist.index(w): w for w in wordlist }

### load pivot word vectors

In [16]:
# f = open('/Users/zhang/MscProject_tweak2vec/corpus/quora_pivots_google_10000.txt','r')
f = open('/Users/zhang/MscProject_tweak2vec/corpus/pubmed_pivots_google_5000.txt','r')
a = f.read()
pivots_dict = eval(a)
f.close()
print('load {0} pivot words'.format(len(pivots_dict.keys())))

load 5000 pivot words


In [17]:
dict_slice = lambda adict, start, end: dict((k, adict[k]) for k in list(adict.keys())[start:end])
def get_pivots_slice(pivots_dict, size):
    pivots = deepcopy(pivots_dict)
    pivots_slice = dict_slice(pivots, 0, size)
    pivots_idx = []
    pivots_vec = []
    for i in pivots_slice.keys():
        pivots_idx.append(i)
        pivots_vec.append(pivots_slice[i])
    return pivots_idx, pivots_vec

In [24]:
n_pivots = 1000
pivots_idx, pivots_vec = get_pivots_slice(pivots_dict, n_pivots)

### a small tf lab :)

In [10]:
embed = tf.Variable([[0,0],[1,1]])
embed_2 = tf.Variable(tf.identity(embed))
ao = tf.scatter_update(embed_2,[0],[[-5,5]])
diff = tf.reduce_sum((embed-embed_2)**2)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(diff))
sess.run(ao)
print(sess.run(diff))

0
50


### build graph with negative sampling

In [83]:
google_pretrain = np.load('/Users/zhang/MscProject_tweak2vec/word2vecModel/quora/w2v_google_50d.npy')

In [19]:
n_vocab = len(word2idx)
n_embedding = 50
reg_constant = 0.0001
n_sampled = 100
learning_rate = 0.001
epochs = 10
batch_size = 1000 # number of samples each iteration

In [20]:
train_graph = tf.Graph()
with train_graph.as_default():
    # input layer
    inputs = tf.placeholder(tf.int32, [batch_size], name='inputs')
    # labels is 2 dimensional as required by tf.nn.sampled_softmax_loss used for negative sampling.
    labels = tf.placeholder(tf.int32, [None, None], name='labels')
    
    # embedding layer
    init_width = 0.5 / n_embedding
    embedding = tf.Variable(tf.random_uniform((n_vocab, n_embedding), -init_width, init_width))
#     embedding = tf.Variable(google_pretrain)
    embed = tf.nn.embedding_lookup(embedding, inputs)

    # add regularization term
    embedding_copy = tf.Variable(tf.identity(embedding), trainable=False)
    update_embed_op = tf.scatter_update(embedding_copy,pivots_idx,pivots_vec)
    embed_copy = tf.nn.embedding_lookup(embedding_copy, inputs)
    reg_loss = reg_constant * tf.reduce_sum((embed-embed_copy)**2)
    
    # sampled softmax layer
    softmax_w = tf.Variable(tf.truncated_normal((n_vocab, n_embedding)), name="softmax_weights")
    softmax_b = tf.Variable(tf.zeros(n_vocab), name="softmax_bias")
    # Calculate the loss using negative sampling
    loss = tf.nn.sampled_softmax_loss(
        weights=softmax_w,
        biases=softmax_b,
        labels=labels,
        inputs=embed,
        num_sampled=n_sampled,
        num_classes=n_vocab)
    cost = tf.reduce_mean(loss)
    
#     total_cost = cost 
    total_cost = cost + reg_loss


    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(total_cost)

### training

In [25]:
for i in range(1,len(idx_pairs)):
    current_tokens = tokens_lst[i] * 1000000
    
    print("Tokens: ", current_tokens)
    print("Starting training at ", datetime.datetime.now())
    t0 = time.time()

    with train_graph.as_default():
        saver = tf.train.Saver()

    with tf.Session(graph=train_graph) as sess:
        iteration = 1
        loss = 0
        regular_loss = 0
        loss_best = 100
        loss_list = []
        iteration_best = 0
        sess.run(tf.global_variables_initializer())

        for e in range(1, epochs + 1):
            batches = get_batches(idx_pairs[i], batch_size)
            start = time.time()
            for x, y in batches:
                feed = {inputs: x,
                        labels: np.array(y)[:, None]}
                sess.run(update_embed_op)
                train_loss, _, regu_loss = sess.run([total_cost, optimizer, reg_loss], feed_dict=feed)
#                 train_loss, _ = sess.run([total_cost, optimizer], feed_dict=feed)

                loss += train_loss
                regular_loss += regu_loss

                if loss < loss_best:
                    W = sess.run(embedding).tolist()
                    iteration_best = iteration
                    loss_best = loss

                if iteration % 1000 == 0:
                    end = time.time()
                    loss_list.append(loss / 1000)
                    print("Epoch {}/{}".format(e, epochs),
                          "Iteration: {}".format(iteration),
                          "Avg. Training loss: {:.4f}".format(loss / 1000),
                          "Avg. Reg. loss: {:.4f}".format(regular_loss / 100),
                          "{:.4f} sec/batch".format((end - start) / 1000))


                    loss = 0
                    regular_loss = 0
                    start = time.time()
                iteration += 1
                
        np.save('w2v_pivots1000_'+str(tokens_lst[i])+'m.npy',np.array(W))        
        print("Finish training at ", datetime.datetime.now()) 
        print("-------------------------------------------------------------------------") 
        print("-------------------------------------------------------------------------")


Tokens:  3000000
Starting training at  2018-07-18 01:19:00.819679
Epoch 1/10 Iteration: 1000 Avg. Training loss: 6.3035 Avg. Reg. loss: 1.2262 0.0301 sec/batch
Epoch 1/10 Iteration: 2000 Avg. Training loss: 5.2347 Avg. Reg. loss: 1.3099 0.0212 sec/batch
Epoch 1/10 Iteration: 3000 Avg. Training loss: 4.8618 Avg. Reg. loss: 1.2842 0.0207 sec/batch
Epoch 1/10 Iteration: 4000 Avg. Training loss: 4.6888 Avg. Reg. loss: 1.2694 0.0207 sec/batch
Epoch 1/10 Iteration: 5000 Avg. Training loss: 4.6066 Avg. Reg. loss: 1.2544 0.0201 sec/batch
Epoch 1/10 Iteration: 6000 Avg. Training loss: 4.5476 Avg. Reg. loss: 1.2344 0.0380 sec/batch
Epoch 1/10 Iteration: 7000 Avg. Training loss: 4.5044 Avg. Reg. loss: 1.2177 0.0207 sec/batch
Epoch 1/10 Iteration: 8000 Avg. Training loss: 4.4524 Avg. Reg. loss: 1.2100 0.0203 sec/batch
Epoch 1/10 Iteration: 9000 Avg. Training loss: 4.4250 Avg. Reg. loss: 1.1858 0.0204 sec/batch
Epoch 1/10 Iteration: 10000 Avg. Training loss: 4.4005 Avg. Reg. loss: 1.1625 0.0227 sec

Epoch 4/10 Iteration: 87000 Avg. Training loss: 3.9761 Avg. Reg. loss: 0.7035 0.0203 sec/batch
Epoch 4/10 Iteration: 88000 Avg. Training loss: 3.9883 Avg. Reg. loss: 0.7013 0.0196 sec/batch
Epoch 4/10 Iteration: 89000 Avg. Training loss: 3.9719 Avg. Reg. loss: 0.6986 0.0205 sec/batch
Epoch 4/10 Iteration: 90000 Avg. Training loss: 3.9816 Avg. Reg. loss: 0.7007 0.0206 sec/batch
Epoch 4/10 Iteration: 91000 Avg. Training loss: 3.9648 Avg. Reg. loss: 0.6937 0.0205 sec/batch
Epoch 4/10 Iteration: 92000 Avg. Training loss: 3.9649 Avg. Reg. loss: 0.6954 0.0205 sec/batch
Epoch 4/10 Iteration: 93000 Avg. Training loss: 3.9704 Avg. Reg. loss: 0.6931 0.0205 sec/batch
Epoch 4/10 Iteration: 94000 Avg. Training loss: 3.9683 Avg. Reg. loss: 0.6922 0.0204 sec/batch
Epoch 4/10 Iteration: 95000 Avg. Training loss: 3.9588 Avg. Reg. loss: 0.6942 0.0204 sec/batch
Epoch 4/10 Iteration: 96000 Avg. Training loss: 3.9608 Avg. Reg. loss: 0.6906 0.0204 sec/batch
Epoch 5/10 Iteration: 97000 Avg. Training loss: 3.

Epoch 8/10 Iteration: 173000 Avg. Training loss: 3.8880 Avg. Reg. loss: 0.6679 0.0202 sec/batch
Epoch 8/10 Iteration: 174000 Avg. Training loss: 3.8910 Avg. Reg. loss: 0.6692 0.0200 sec/batch
Epoch 8/10 Iteration: 175000 Avg. Training loss: 3.8870 Avg. Reg. loss: 0.6724 0.0205 sec/batch
Epoch 8/10 Iteration: 176000 Avg. Training loss: 3.8908 Avg. Reg. loss: 0.6697 0.0207 sec/batch
Epoch 8/10 Iteration: 177000 Avg. Training loss: 3.8881 Avg. Reg. loss: 0.6661 0.0208 sec/batch
Epoch 8/10 Iteration: 178000 Avg. Training loss: 3.8764 Avg. Reg. loss: 0.6679 0.0207 sec/batch
Epoch 8/10 Iteration: 179000 Avg. Training loss: 3.8769 Avg. Reg. loss: 0.6702 0.0206 sec/batch
Epoch 8/10 Iteration: 180000 Avg. Training loss: 3.8827 Avg. Reg. loss: 0.6700 0.0208 sec/batch
Epoch 8/10 Iteration: 181000 Avg. Training loss: 3.8791 Avg. Reg. loss: 0.6748 0.0207 sec/batch
Epoch 8/10 Iteration: 182000 Avg. Training loss: 3.8809 Avg. Reg. loss: 0.6707 0.0203 sec/batch
Epoch 8/10 Iteration: 183000 Avg. Traini

Epoch 1/10 Iteration: 14000 Avg. Training loss: 4.3335 Avg. Reg. loss: 1.1570 0.0198 sec/batch
Epoch 1/10 Iteration: 15000 Avg. Training loss: 4.3352 Avg. Reg. loss: 1.1509 0.0200 sec/batch
Epoch 1/10 Iteration: 16000 Avg. Training loss: 4.2969 Avg. Reg. loss: 1.1272 0.0201 sec/batch
Epoch 2/10 Iteration: 17000 Avg. Training loss: 4.2740 Avg. Reg. loss: 1.1070 0.0170 sec/batch
Epoch 2/10 Iteration: 18000 Avg. Training loss: 4.2693 Avg. Reg. loss: 1.1150 0.0200 sec/batch
Epoch 2/10 Iteration: 19000 Avg. Training loss: 4.2641 Avg. Reg. loss: 1.0948 0.0195 sec/batch
Epoch 2/10 Iteration: 20000 Avg. Training loss: 4.2543 Avg. Reg. loss: 1.0734 0.0200 sec/batch
Epoch 2/10 Iteration: 21000 Avg. Training loss: 4.2544 Avg. Reg. loss: 1.0659 0.0185 sec/batch
Epoch 2/10 Iteration: 22000 Avg. Training loss: 4.2245 Avg. Reg. loss: 1.0630 0.0200 sec/batch
Epoch 2/10 Iteration: 23000 Avg. Training loss: 4.2246 Avg. Reg. loss: 1.0291 0.0201 sec/batch
Epoch 2/10 Iteration: 24000 Avg. Training loss: 4.

Epoch 7/10 Iteration: 101000 Avg. Training loss: 3.9311 Avg. Reg. loss: 0.7392 0.0189 sec/batch
Epoch 7/10 Iteration: 102000 Avg. Training loss: 3.9398 Avg. Reg. loss: 0.7404 0.0200 sec/batch
Epoch 7/10 Iteration: 103000 Avg. Training loss: 3.9330 Avg. Reg. loss: 0.7407 0.0198 sec/batch
Epoch 7/10 Iteration: 104000 Avg. Training loss: 3.9346 Avg. Reg. loss: 0.7380 0.0202 sec/batch
Epoch 7/10 Iteration: 105000 Avg. Training loss: 3.9263 Avg. Reg. loss: 0.7355 0.0199 sec/batch
Epoch 7/10 Iteration: 106000 Avg. Training loss: 3.9318 Avg. Reg. loss: 0.7314 0.0201 sec/batch
Epoch 7/10 Iteration: 107000 Avg. Training loss: 3.9380 Avg. Reg. loss: 0.7406 0.0202 sec/batch
Epoch 7/10 Iteration: 108000 Avg. Training loss: 3.9236 Avg. Reg. loss: 0.7324 0.0198 sec/batch
Epoch 7/10 Iteration: 109000 Avg. Training loss: 3.9304 Avg. Reg. loss: 0.7323 0.0201 sec/batch
Epoch 7/10 Iteration: 110000 Avg. Training loss: 3.9250 Avg. Reg. loss: 0.7314 0.0200 sec/batch
Epoch 7/10 Iteration: 111000 Avg. Traini

Epoch 3/10 Iteration: 23000 Avg. Training loss: 4.1952 Avg. Reg. loss: 0.9943 0.0200 sec/batch
Epoch 3/10 Iteration: 24000 Avg. Training loss: 4.1840 Avg. Reg. loss: 0.9870 0.0202 sec/batch
Epoch 4/10 Iteration: 25000 Avg. Training loss: 4.1717 Avg. Reg. loss: 0.9819 0.0153 sec/batch
Epoch 4/10 Iteration: 26000 Avg. Training loss: 4.1643 Avg. Reg. loss: 0.9721 0.0201 sec/batch
Epoch 4/10 Iteration: 27000 Avg. Training loss: 4.1583 Avg. Reg. loss: 0.9663 0.0200 sec/batch
Epoch 4/10 Iteration: 28000 Avg. Training loss: 4.1369 Avg. Reg. loss: 0.9628 0.0201 sec/batch
Epoch 4/10 Iteration: 29000 Avg. Training loss: 4.1493 Avg. Reg. loss: 0.9569 0.0201 sec/batch
Epoch 4/10 Iteration: 30000 Avg. Training loss: 4.1436 Avg. Reg. loss: 0.9462 0.0203 sec/batch
Epoch 4/10 Iteration: 31000 Avg. Training loss: 4.1252 Avg. Reg. loss: 0.9365 0.0202 sec/batch
Epoch 4/10 Iteration: 32000 Avg. Training loss: 4.1156 Avg. Reg. loss: 0.9340 0.0203 sec/batch
Epoch 5/10 Iteration: 33000 Avg. Training loss: 4.

Epoch 7/10 Iteration: 27000 Avg. Training loss: 4.0798 Avg. Reg. loss: 1.0661 0.0207 sec/batch
Epoch 7/10 Iteration: 28000 Avg. Training loss: 4.0695 Avg. Reg. loss: 1.0595 0.0206 sec/batch
Epoch 8/10 Iteration: 29000 Avg. Training loss: 4.0536 Avg. Reg. loss: 1.0585 0.0147 sec/batch
Epoch 8/10 Iteration: 30000 Avg. Training loss: 4.0583 Avg. Reg. loss: 1.0484 0.0206 sec/batch
Epoch 8/10 Iteration: 31000 Avg. Training loss: 4.0450 Avg. Reg. loss: 1.0471 0.0203 sec/batch
Epoch 8/10 Iteration: 32000 Avg. Training loss: 4.0375 Avg. Reg. loss: 1.0405 0.0208 sec/batch
Epoch 9/10 Iteration: 33000 Avg. Training loss: 4.0184 Avg. Reg. loss: 1.0419 0.0139 sec/batch
Epoch 9/10 Iteration: 34000 Avg. Training loss: 4.0208 Avg. Reg. loss: 1.0361 0.0203 sec/batch
Epoch 9/10 Iteration: 35000 Avg. Training loss: 4.0178 Avg. Reg. loss: 1.0327 0.0207 sec/batch
Epoch 9/10 Iteration: 36000 Avg. Training loss: 4.0060 Avg. Reg. loss: 1.0264 0.0200 sec/batch
Epoch 10/10 Iteration: 37000 Avg. Training loss: 3

In [38]:
print(pivots_dict[1])
print(W[1])

[-0.696653425693512, -0.4016351103782654, -0.15490229427814484, -0.153431236743927, 0.11273664981126785, -0.20632797479629517, -0.05852844938635826, 0.18393464386463165, 0.04037215933203697, -0.15603983402252197, -0.12679000198841095, 0.10461419820785522, -0.03136150911450386, -0.09917640686035156, -0.21953696012496948, -0.06557910144329071, -0.3572455048561096, -0.07304935902357101, 0.2829059362411499, 0.25940605998039246, 0.18046262860298157, -0.18454191088676453, -0.13335512578487396, -0.11446908116340637, -0.09217895567417145, -0.028645846992731094, 0.07994083315134048, -0.3566879630088806, -0.16788771748542786, -0.09856567531824112, -0.05210083723068237, -0.06661748886108398, 0.09986916929483414, 0.1596103459596634, -0.1205173209309578, -0.03440592437982559, 0.028155574575066566, -0.17301133275032043, -0.17946180701255798, -0.0042143226601183414, -0.18912769854068756, -0.17107552289962769, -0.14589069783687592, -0.08563197404146194, -0.043947286903858185, -0.053388938307762146, 0.