In [326]:
import collections
import d2lzh as d2l
import math
from mxnet import autograd,gluon,nd
from mxnet.gluon import data as gdata,loss as gloss,nn
import random
import sys
import time
import zipfile
import numpy as np

In [478]:
def skip_gram(center, context_and_negatives, embed_v, embed_u):
    """Score each center word against its context/negative candidates.

    Looks up `center` in the center-word embedding `embed_v` and
    `context_and_negatives` in the context-word embedding `embed_u`, then
    returns their batched dot products ``nd.batch_dot(v, u.swapaxes(1, 2))``.
    With the (batch, 1) centers produced by ``batchify`` the result is
    presumably (batch, 1, max_len) — TODO confirm.
    """
    center_vecs = embed_v(center)
    candidate_vecs_t = embed_u(context_and_negatives).swapaxes(1, 2)
    return nd.batch_dot(center_vecs, candidate_vecs_t)

In [479]:
# Binary cross-entropy on raw scores (Gluon's from_sigmoid defaults to False);
# train() passes the padding mask as the third (sample_weight) argument.
loss=gloss.SigmoidBinaryCrossEntropyLoss()

In [480]:
def sigmd(x):
    """Return -log(sigmoid(x)), i.e. softplus(-x) = log(1 + exp(-x)).

    Evaluated in the numerically stable form
    ``max(-x, 0) + log1p(exp(-|x|))``. For large negative x this no longer
    overflows ``math.exp``. For large positive x it keeps the tiny positive
    value (about exp(-x)) instead of rounding to 0.
    """
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))

In [481]:
def read_file_data(filePath, min_count=5):
    """Load a whitespace-tokenized corpus and build its vocabulary.

    Each line of the file is one sentence whose tokens are separated by
    whitespace. Tokens occurring fewer than ``min_count`` times are dropped
    from the vocabulary and from the index-encoded dataset.

    Args:
        filePath: path of the corpus file to read.
        min_count: minimum occurrence count for a token to be kept
            (default 5, the previously hard-coded threshold).

    Returns:
        Tuple ``(idx_to_token, token_to_idx, dataset, counter, num_tokens)``:
        - ``idx_to_token`` lists the kept tokens in first-occurrence order.
        - ``token_to_idx`` is its inverse mapping.
        - ``dataset`` is a list with one list of token indices per input line
          (possibly empty).
        - ``counter`` maps each kept token to its count.
        - ``num_tokens`` is the total number of kept token occurrences.
    """
    # Honour the filePath argument (it used to be ignored in favour of a
    # hard-coded path) and make sure the handle is closed.
    with open(filePath) as f:
        raw_dataset = [line.split() for line in f]
    counts = collections.Counter(tk for st in raw_dataset for tk in st)
    counter = {tk: c for tk, c in counts.items() if c >= min_count}
    idx_to_token = list(counter)
    token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
    dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]
               for st in raw_dataset]
    num_tokens = sum(len(st) for st in dataset)
    return idx_to_token, token_to_idx, dataset, counter, num_tokens

In [482]:
# Build the vocabulary (tokens seen at least 5 times) and the index-encoded corpus.
idx_to_token,token_to_idx,dataset,counter,num_tokens=read_file_data("data/wordTrain/train_cnn.lex")
print(num_tokens)
print(dataset[0:10])

235634
[[0, 1, 2], [3, 4], [5, 1, 6, 2, 7, 8], [5, 1, 6, 2, 7, 8], [5, 1, 8], [5, 1, 8], [5, 1, 8], [9, 8], [3, 4], [3, 4, 10, 8]]


In [483]:
def discard(idx):
    """Randomly decide whether to drop token index `idx` during subsampling.

    The drop probability is ``1 - sqrt(1e-4 * num_tokens / count)``, where
    ``count`` is the token's count in the module-level `counter`. A token whose
    relative frequency is at or below 1e-4 gets a non-positive probability
    and is always kept. Reads the globals `counter`, `idx_to_token` and
    `num_tokens`.
    """
    draw = random.uniform(0, 1)
    count = counter[idx_to_token[idx]]
    drop_prob = 1 - math.sqrt(1e-4 / count * num_tokens)
    return draw < drop_prob

In [484]:
# Subsample frequent tokens. subsampled_dataset2 drops sentences that became
# empty and is used only for display.
subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
subsampled_dataset2 = [st for st in subsampled_dataset if st]
print(subsampled_dataset[0:10])
# Use a new name: discard() reads the global `num_tokens` (the pre-subsampling
# total), so overwriting it here would change subsampling on a re-run.
num_subsampled_tokens = sum(len(st) for st in subsampled_dataset)
print(subsampled_dataset2[0:10])
print(num_subsampled_tokens)

[[1], [], [], [], [], [5], [], [], [], []]
[[1], [5], [3], [4], [10], [12, 16], [20], [21], [24], [22]]
13759


In [485]:
def get_centers_and_contexts(dataset, max_windows_size):
    """Extract (center, context) pairs from sentences of token indices.

    Sentences shorter than two tokens are skipped. For every position in a
    remaining sentence, a window half-width is drawn uniformly from
    ``1..max_windows_size``. All tokens within that distance, excluding the
    center itself, become its context.

    Returns:
        (centers, contexts): flat list of center indices and a parallel list
        of context-index lists.
    """
    centers, contexts = [], []
    for sentence in dataset:
        if len(sentence) < 2:
            continue  # no context possible for a single token
        centers.extend(sentence)
        length = len(sentence)
        for pos in range(length):
            span = random.randint(1, max_windows_size)
            lo, hi = max(0, pos - span), min(length, pos + span + 1)
            contexts.append([sentence[j] for j in range(lo, hi) if j != pos])
    return centers, contexts

In [486]:
# Pair every center word with a random-width window (1..5) of surrounding words.
all_centers,all_contexts=get_centers_and_contexts(subsampled_dataset,5)

In [487]:
def get_negatives(all_contexts, sampling_weights, K):
    """Draw K negative samples per context word for every center word.

    Candidates are drawn from ``range(len(sampling_weights))`` with
    probability proportional to ``sampling_weights``. They are pre-sampled in
    chunks of 100000 to amortize the cost of ``random.choices``. A candidate
    that is one of the current center's context words is rejected and
    redrawn.

    Args:
        all_contexts: list of context-index lists, one per center word.
        sampling_weights: non-negative weight per vocabulary index.
        K: number of negatives per context word.

    Returns:
        List of negative-index lists; entry i has ``len(all_contexts[i]) * K``
        items.

    Note: if every index with non-zero weight is a context word of some
    center, the rejection loop for that center never terminates.
    """
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        # Build the rejection set once per center instead of once per draw.
        context_set = set(contexts)
        negatives = []
        while len(negatives) < len(contexts) * K:
            if i == len(neg_candidates):
                # Refill the candidate buffer with one batched call.
                i, neg_candidates = 0, random.choices(
                    population, sampling_weights, k=int(1e5))
            neg, i = neg_candidates[i], i + 1
            if neg not in context_set:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

In [488]:
# Negative-sampling distribution: each token's count raised to the 3/4 power.
sampling_weights = [counter[token] ** 0.75 for token in idx_to_token]
all_negatives = get_negatives(all_contexts, sampling_weights, 5)


In [489]:
def batchify(data):
    """Collate (center, context, negatives) examples into padded NDArrays.

    Each example's context and negative indices are concatenated and
    right-padded with 0 to the longest such list in the batch.

    Returns:
        Tuple ``(centers, contexts_negatives, masks, labels)``:
        - ``centers`` has shape (batch, 1).
        - ``contexts_negatives`` holds the padded context+negative indices.
        - ``masks`` is 1 for real entries and 0 for padding.
        - ``labels`` is 1 for context entries and 0 for negatives and padding.
    """
    max_len = max(len(ctx_ids) + len(neg_ids) for _, ctx_ids, neg_ids in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, ctx_ids, neg_ids in data:
        real_len = len(ctx_ids) + len(neg_ids)
        pad = max_len - real_len
        centers.append(center)
        contexts_negatives.append(ctx_ids + neg_ids + [0] * pad)
        masks.append([1] * real_len + [0] * pad)
        labels.append([1] * len(ctx_ids) + [0] * (max_len - len(ctx_ids)))
    return (nd.array(centers).reshape((-1, 1)), nd.array(contexts_negatives),
            nd.array(masks), nd.array(labels))

In [490]:
batch_size = 512
# Load in the main process (num_workers=0) on Windows, 4 workers elsewhere.
num_workers = 0 if sys.platform.startswith('win32') else 4
# Named train_dataset (not `dataset`) so the index-encoded corpus `dataset`
# from read_file_data, which the subsampling cell iterates, is not shadowed
# by an ArrayDataset on re-runs.
train_dataset = gdata.ArrayDataset(all_centers, all_contexts, all_negatives)
data_iter = gdata.DataLoader(train_dataset, batch_size, shuffle=True,
                             batchify_fn=batchify, num_workers=num_workers)

In [491]:
# Two embedding tables of the same size: net[0] embeds center words (v) and
# net[1] embeds context/negative words (u), as passed to skip_gram in train().
embed_size=20
net=nn.Sequential()
net.add(nn.Embedding(input_dim=len(idx_to_token),output_dim=embed_size),
        nn.Embedding(input_dim=len(idx_to_token),output_dim=embed_size))

In [492]:
def train(net, lr, num_epochs):
    """Train the skip-gram embeddings in `net` with Adam.

    (Re)initializes `net` on ``d2l.try_gpu()``. It then runs `num_epochs`
    passes over the module-level `data_iter`, scoring with `skip_gram` and the
    module-level `loss`, and prints the mean per-example loss and wall time
    of each epoch.

    Args:
        net: Sequential holding the center (net[0]) and context (net[1])
            embeddings.
        lr: Adam learning rate.
        num_epochs: number of passes over `data_iter`.
    """
    ctx = d2l.try_gpu()
    net.initialize(ctx=ctx, force_reinit=True)
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = [
                data.as_in_context(ctx) for data in batch]
            with autograd.record():
                pred = skip_gram(center, context_negative, net[0], net[1])
                # The masked loss is averaged over all max_len positions;
                # scaling by max_len / (number of real positions) turns it
                # into an average over the real (unpadded) positions only.
                l = (loss(pred.reshape(label.shape), label, mask)
                     * mask.shape[1] / mask.sum(axis=1))
            l.backward()
            # Normalize by the actual batch size: the final batch can be
            # smaller than the global `batch_size` (DataLoader keeps it).
            trainer.step(center.shape[0])
            l_sum += l.sum().asscalar()
            n += l.size
        print('epoch %d, loss %.2f,time %.2fs' % (epoch + 1, l_sum / n, time.time() - start))

In [493]:
# Adam with learning rate 0.005 for 100 epochs.
train(net,0.005,100)

epoch 1, loss 0.69,time 0.33s
epoch 2, loss 0.66,time 0.29s
epoch 3, loss 0.58,time 0.28s
epoch 4, loss 0.49,time 0.28s
epoch 5, loss 0.44,time 0.28s
epoch 6, loss 0.44,time 0.29s
epoch 7, loss 0.43,time 0.28s
epoch 8, loss 0.43,time 0.27s
epoch 9, loss 0.43,time 0.28s
epoch 10, loss 0.43,time 0.28s
epoch 11, loss 0.43,time 0.28s
epoch 12, loss 0.43,time 0.28s
epoch 13, loss 0.42,time 0.28s
epoch 14, loss 0.42,time 0.28s
epoch 15, loss 0.42,time 0.27s
epoch 16, loss 0.42,time 0.27s
epoch 17, loss 0.41,time 0.28s
epoch 18, loss 0.41,time 0.29s
epoch 19, loss 0.41,time 0.28s
epoch 20, loss 0.40,time 0.27s
epoch 21, loss 0.40,time 0.32s
epoch 22, loss 0.40,time 0.28s
epoch 23, loss 0.39,time 0.28s
epoch 24, loss 0.39,time 0.27s
epoch 25, loss 0.39,time 0.28s
epoch 26, loss 0.39,time 0.28s
epoch 27, loss 0.39,time 0.27s
epoch 28, loss 0.38,time 0.28s
epoch 29, loss 0.38,time 0.29s
epoch 30, loss 0.38,time 0.28s
epoch 31, loss 0.38,time 0.28s
epoch 32, loss 0.38,time 0.27s
epoch 33, loss 0.

In [494]:
def get_similar_tokens(query_token, k, embed):
    """Print the k tokens whose embeddings are most cosine-similar to query_token.

    Uses the weight matrix of the embedding block `embed` and the
    module-level `token_to_idx` / `idx_to_token` vocabulary. The top-ranked
    index is skipped on the assumption that it is the query token itself
    (cosine 1).
    """
    weights = embed.weight.data()
    query_vec = weights[token_to_idx[query_token]]
    # 1e-9 keeps the denominator positive for all-zero rows.
    norms = (nd.sum(weights * weights, axis=1) * nd.sum(query_vec * query_vec) + 1e-9).sqrt()
    cos = nd.dot(weights, query_vec) / norms
    ranked = nd.topk(cos, k=k + 1, ret_typ='indices').asnumpy().astype('int32')
    for i in ranked[1:]:
        print('cosine sim=%.3f:%s' % (cos[i].asscalar(), (idx_to_token[i])))

In [507]:
# 20 nearest neighbours of 'fread' in the center-word (net[0]) embedding space.
get_similar_tokens('fread',20,net[0])

cosine sim=0.869:fwrite
cosine sim=0.553:memset
cosine sim=0.545:=
cosine sim=0.541:stdin
cosine sim=0.478:assert
cosine sim=0.469:goto
cosine sim=0.447:feof
cosine sim=0.446:!
cosine sim=0.440:&
cosine sim=0.407:fflush
cosine sim=0.403:strcmp
cosine sim=0.399:struct_var
cosine sim=0.398:NULL
cosine sim=0.397:call_func
cosine sim=0.390:/
cosine sim=0.384:+
cosine sim=0.380:malloc
cosine sim=0.378:>=
cosine sim=0.372:<=
cosine sim=0.366:-


In [None]:
#net.save('data/params/word.txt')

In [508]:
# Center-word embedding matrix (row i is the vector for idx_to_token[i]), persisted below.
embedding_weight=net[0].weight.data()

61

In [510]:
def save_params(filePath, idx_to_token, embedding_weight):
    """Save the vocabulary and embedding matrix under the prefix `filePath`.

    Writes two files:
    - ``filePath + 'idx_to_token.npy'``: the tokens as a numpy array.
    - ``filePath + 'embedding_weight.txt'``: the weights, written by
      ``nd.save`` in MXNet's own serialization format despite the .txt
      extension.

    ``filePath`` is used as a plain string prefix, so pass a trailing
    separator to target a directory, e.g. ``'data/params/word2Vec/'``. The
    parent directory is created if it does not exist. The token-to-index
    mapping is not saved; read_params rebuilds it from the token list.
    """
    import os  # local import: only needed to create the target directory

    vocab_path = filePath + 'idx_to_token.npy'
    weight_path = filePath + 'embedding_weight.txt'
    target_dir = os.path.dirname(vocab_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    np.save(vocab_path, np.array(idx_to_token))
    nd.save(weight_path, embedding_weight)

In [354]:
def read_params(filePath):
    """Load the vocabulary and embedding weights written by save_params.

    Returns:
        Tuple ``(idx_to_token, token_to_idx, embedding_weight)``:
        - ``idx_to_token`` is the numpy array loaded from
          ``filePath + 'idx_to_token.npy'``.
        - ``token_to_idx`` maps each token to its position. It is rebuilt
          here because save_params does not store it.
        - ``embedding_weight`` is whatever ``nd.load`` returns for
          ``filePath + 'embedding_weight.txt'``.

    NOTE(review): nd.load returns a list (or dict) of NDArrays. For the single
    matrix saved by save_params this is likely ``[NDArray]``, so callers
    probably need ``embedding_weight[0]`` — confirm.
    """
    vocab = np.load(filePath + 'idx_to_token.npy')
    lookup = dict((token, position) for position, token in enumerate(vocab))
    weights = nd.load(filePath + 'embedding_weight.txt')
    return vocab, lookup, weights

In [511]:
# Writes idx_to_token.npy and embedding_weight.txt under data/params/word2Vec/.
save_params('data/params/word2Vec/',idx_to_token,embedding_weight)

In [356]:
# Reload the saved vocabulary and weights.
# NOTE(review): this cell last ran as In [356], before the save cell (In [511]),
# so the outputs below reflect an older save (71 tokens vs 73 in memory).
temp_idx_to_token,temp_token_to_idx,temp_embedding_weight=read_params('data/params/word2Vec/')

In [357]:
# Sanity check: the reloaded vocabulary should match the in-memory one.
print(temp_idx_to_token)
print(idx_to_token)

['#define' 'var' 'nums' 'struct' 'struct_name' 'data_type' '[' ']' ';' '}'
 'struct_var' '*' 'void' 'func' '(' ')' '{' 'printf' 'words' 'while'
 'switch' 'case' ':' 'call_func' 'break' 'fflush' 'stdin' 'gets' '=' 'if'
 '<' '|' '>' 'else' 'return' 'FILE' 'fopen' ',' 'for' 'fread' '+' 'sizeof'
 '!' '.' 'fclose' 'NULL' '-' 'do' 'continue' 'strcmp' '&' 'scanf' 'malloc'
 'fwrite' 'default' 'getch' 'feof' 'system' '<=' 'typedef' '>=' 'free' '/'
 'memset' 'static' 'goto' 'extern' '#endif' '#ifndef' 'assert' '#else']
['#define', 'var', 'nums', 'struct', 'struct_name', 'data_type', '[', ']', ';', '}', 'struct_var', '*', 'void', 'func', '(', ')', '{', 'printf', 'words', 'while', 'switch', 'case', ':', 'call_func', 'break', 'fflush', 'stdin', 'gets', '=', 'if', '<', '|', '>', 'else', 'return', 'FILE', 'fopen', ',', 'for', 'fread', '+', 'sizeof', '!', '.', 'fclose', 'NULL', '-', 'do', 'continue', 'strcmp', '&', 'scanf', 'malloc', 'fwrite', 'default', 'getch', 'feof', 'system', '<=', 'typedef', '>=

In [358]:
# read_params already rebuilds token_to_idx from the saved token list, so the
# manual rebuild on the next line is not needed.
#temp_token_to_idx={tk:idx for idx,tk in enumerate(temp_idx_to_token)}
print(temp_token_to_idx)

{'#define': 0, 'var': 1, 'nums': 2, 'struct': 3, 'struct_name': 4, 'data_type': 5, '[': 6, ']': 7, ';': 8, '}': 9, 'struct_var': 10, '*': 11, 'void': 12, 'func': 13, '(': 14, ')': 15, '{': 16, 'printf': 17, 'words': 18, 'while': 19, 'switch': 20, 'case': 21, ':': 22, 'call_func': 23, 'break': 24, 'fflush': 25, 'stdin': 26, 'gets': 27, '=': 28, 'if': 29, '<': 30, '|': 31, '>': 32, 'else': 33, 'return': 34, 'FILE': 35, 'fopen': 36, ',': 37, 'for': 38, 'fread': 39, '+': 40, 'sizeof': 41, '!': 42, '.': 43, 'fclose': 44, 'NULL': 45, '-': 46, 'do': 47, 'continue': 48, 'strcmp': 49, '&': 50, 'scanf': 51, 'malloc': 52, 'fwrite': 53, 'default': 54, 'getch': 55, 'feof': 56, 'system': 57, '<=': 58, 'typedef': 59, '>=': 60, 'free': 61, '/': 62, 'memset': 63, 'static': 64, 'goto': 65, 'extern': 66, '#endif': 67, '#ifndef': 68, 'assert': 69, '#else': 70}


In [320]:
# In-memory mapping, for comparison with the reloaded temp_token_to_idx.
print(token_to_idx)

{'#define': 0, 'var': 1, 'nums': 2, 'struct': 3, 'struct_name': 4, 'data_type': 5, '[': 6, ']': 7, ';': 8, '}': 9, 'struct_var': 10, '*': 11, 'void': 12, 'func': 13, '(': 14, ')': 15, '{': 16, 'printf': 17, 'words': 18, 'while': 19, 'switch': 20, 'case': 21, ':': 22, 'call_func': 23, 'break': 24, 'fflush': 25, 'stdin': 26, 'gets': 27, '=': 28, 'if': 29, '<': 30, '|': 31, '>': 32, 'else': 33, 'return': 34, 'FILE': 35, 'fopen': 36, ',': 37, 'for': 38, 'fread': 39, '+': 40, 'sizeof': 41, '!': 42, '.': 43, 'fclose': 44, 'NULL': 45, '-': 46, 'do': 47, 'continue': 48, 'strcmp': 49, '&': 50, 'scanf': 51, 'malloc': 52, 'fwrite': 53, 'default': 54, 'getch': 55, 'feof': 56, 'system': 57, '<=': 58, 'typedef': 59, '>=': 60, 'free': 61, '/': 62, 'memset': 63, 'static': 64, 'goto': 65, 'extern': 66, '#endif': 67, '#ifndef': 68, 'assert': 69, '#else': 70, '#undef': 71, 'union': 72}
