# 第二课 词向量

第二课学习目标
- 学习词向量的概念
- 用Skip-gram模型训练词向量
- 学习使用PyTorch dataset和dataloader
- 学习定义PyTorch模型
- 学习torch.nn中常见的Module
    - Embedding
- 学习常见的PyTorch operations
    - bmm
    - logsigmoid
- 保存和读取PyTorch模型
    

第二课使用的训练数据可以从以下链接下载到。

链接:https://pan.baidu.com/s/1tFeK3mXuVXEy3EMarfeWvg  密码:v2z5

在这一份notebook中，我们会（尽可能）尝试复现论文[Distributed Representations of Words and Phrases and their Compositionality](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf)中训练词向量的方法。我们会实现Skip-gram模型，并且使用论文中negative sampling（noise contrastive estimation的简化版本）的目标函数。

这篇论文有很多模型实现的细节，这些细节对于词向量的好坏至关重要。我们虽然无法完全复现论文中的实验结果，主要是由于计算资源等各种细节原因，但是我们还是可以大致展示如何训练词向量。

以下是一些我们没有实现的细节
- subsampling：参考论文section 2.3

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from torch.nn.parameter import Parameter

from collections import Counter
import numpy as np
import random
import math

import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

# True when a CUDA device is available; later cells use it to decide whether
# the model and each batch are moved to the GPU.
USE_CUDA = torch.cuda.is_available()

# Pin every random seed (Python, NumPy, PyTorch CPU and, if present, CUDA)
# to the same value so that experiment results can be reproduced.
random.seed(53113)
np.random.seed(53113)
torch.manual_seed(53113)
if USE_CUDA:
    torch.cuda.manual_seed(53113)

# Hyperparameters

K = 100 # negative samples drawn per positive (context) word
C = 3 # context window: number of neighbouring words taken on each side of the center word
NUM_EPOCHS = 2 # number of passes over the training data
MAX_VOCAB_SIZE = 30000 # vocabulary size, including the "<unk>" entry
BATCH_SIZE = 128 # number of center words per mini-batch
LEARNING_RATE = 0.2 # learning rate handed to the SGD optimizer
EMBEDDING_SIZE = 100 # dimensionality of each word vector


# Training progress and evaluation results are appended to this file.
LOG_FILE = "word-embedding.log"

# Tokenizer: turns a piece of text into its individual words.
def word_tokenize(text):
    """Split ``text`` on runs of whitespace and return the list of words.

    Leading/trailing whitespace is ignored, so an empty or all-blank
    string yields an empty list.
    """
    tokens = text.split()
    return tokens

- 从文本文件中读取所有的文字，通过这些文本创建一个vocabulary
- 由于单词数量可能太大，我们只选取最常见的MAX_VOCAB_SIZE个单词
- 我们添加一个UNK单词表示所有不常见的单词
- 我们需要记录单词到index的mapping，以及index到单词的mapping，单词的count，单词的(normalized) frequency，以及单词总数。

In [2]:
# Read the whole training corpus into memory as one string.
with open("text8.train.txt", "r") as fin:
    text = fin.read()

# Lower-case the corpus and turn it into a flat list of word tokens.
text = list(word_tokenize(text.lower()))

# Keep the MAX_VOCAB_SIZE-1 most frequent words; the remaining slot goes to
# "<unk>", whose count is the number of tokens not covered by the kept words.
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
vocab["<unk>"] = len(text) - np.sum(list(vocab.values()))

# Both lookups follow the insertion order of `vocab`, so "<unk>" ends up at
# the last index (assuming the corpus itself does not contain that token).
idx_to_word = list(vocab)
word_to_idx = {word: i for i, word in enumerate(idx_to_word)}

# Raw counts, then unigram frequencies raised to the 3/4 power and
# renormalised: this is the distribution used for negative sampling.
word_counts = np.array(list(vocab.values()), dtype=np.float32)
word_freqs = (word_counts / np.sum(word_counts)) ** (3. / 4.)
word_freqs = word_freqs / np.sum(word_freqs)

VOCAB_SIZE = len(idx_to_word)
VOCAB_SIZE

30000

### 实现Dataloader

一个dataloader需要以下内容：

- 把所有text编码成数字（论文中还会用subsampling预处理这些文字，本notebook没有实现这一步）。
- 保存vocabulary，单词count，normalized word frequency
- 每个iteration sample一个中心词
- 根据当前的中心词返回context单词
- 根据中心词sample一些negative单词
- 返回单词的counts

这里有一个好的tutorial介绍如何使用[PyTorch dataloader](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html).
为了使用dataloader，我们需要定义以下两个function:

- ```__len__``` function需要返回整个数据集中有多少个item
- ```__getitem__``` 根据给定的index返回一个item

有了dataloader之后，我们可以轻松随机打乱整个数据集，拿到一个batch的数据等等。

In [10]:
class WordEmbeddingDataset(tud.Dataset):
    """Skip-gram training set with negative sampling.

    Item ``idx`` is the triple ``(center_word, pos_words, neg_words)`` built
    around the ``idx``-th token of the corpus; all three are int64 tensors of
    vocabulary indices.
    """

    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts,
                 window_size=None, num_negatives=None):
        """Encode the corpus and store the sampling distribution.

        Args:
            text: list of word tokens making up the training corpus.
            word_to_idx: dict mapping word -> vocabulary index.
            idx_to_word: sequence mapping vocabulary index -> word.
            word_freqs: per-word sampling weights for negative sampling.
            word_counts: per-word raw counts.
            window_size: context words taken on each side of the center
                word; ``None`` (default) falls back to the module-level ``C``.
            num_negatives: negative samples drawn per context word;
                ``None`` (default) falls back to the module-level ``K``.
        """
        super().__init__()
        # Out-of-vocabulary tokens map to "<unk>"; if the mapping has no such
        # entry, fall back to the last index, which is where the vocabulary
        # built in this file places it.
        unk_idx = word_to_idx.get("<unk>", len(word_to_idx) - 1)
        encoded = [word_to_idx.get(token, unk_idx) for token in text]
        # Build the int64 tensor directly: going through a float32 tensor
        # first would corrupt indices above 2**24.
        self.text_encoded = torch.tensor(encoded, dtype=torch.long)
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freqs)
        self.word_counts = torch.Tensor(word_counts)
        self.window_size = window_size
        self.num_negatives = num_negatives

    def __len__(self):
        """Return the number of tokens in the encoded corpus."""
        return len(self.text_encoded)

    def __getitem__(self, idx):
        """Return the training triple for the token at position ``idx``.

        Returns:
            center_word: 0-d tensor, index of the center word.
            pos_words: tensor of ``2 * window`` context-word indices.
            neg_words: tensor of ``num_neg * 2 * window`` indices sampled
                (with replacement) from ``self.word_freqs``.
        """
        # Resolve the defaults lazily so the module-level C / K are read at
        # sampling time when no explicit value was given.
        window = C if self.window_size is None else self.window_size
        num_neg = K if self.num_negatives is None else self.num_negatives

        center_word = self.text_encoded[idx]
        corpus_len = len(self.text_encoded)
        # Positions on both sides of idx; the modulo wraps the window around
        # the corpus boundaries instead of truncating it.
        pos_indices = [
            i % corpus_len
            for i in range(idx - window, idx + window + 1)
            if i != idx
        ]
        pos_words = self.text_encoded[pos_indices]
        # NOTE: sampling is unconditional, so a negative sample can coincide
        # with the center word or with one of the context words.
        neg_words = torch.multinomial(self.word_freqs, num_neg * pos_words.shape[0], True)

        return center_word, pos_words, neg_words

创建dataset和dataloader

In [None]:
# Wrap the tokenized corpus and the vocabulary statistics in a Dataset.
dataset = WordEmbeddingDataset(
    text,
    word_to_idx,
    idx_to_word,
    word_freqs,
    word_counts,
)
# Serve shuffled mini-batches of BATCH_SIZE items using 4 loader workers.
dataloader = tud.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

### 定义PyTorch模型

In [6]:
class EmbeddingModel(nn.Module):
    """Skip-gram model trained with the negative-sampling objective.

    Holds two embedding tables of identical shape: ``in_embed`` for center
    words and ``out_embed`` for context / negative words.
    """

    def __init__(self, vocab_size, embed_size):
        """Create the input and output embedding tables.

        Args:
            vocab_size: number of rows in each embedding table.
            embed_size: dimensionality of each word vector.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size

        # Both tables start uniformly in [-0.5/embed_size, 0.5/embed_size].
        initrange = 0.5 / self.embed_size
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.out_embed.weight.data.uniform_(-initrange, initrange)

        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.in_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, input_labels, pos_labels, neg_labels):
        """Compute the per-example negative-sampling loss.

        Args:
            input_labels: center words, shape ``[batch_size]``.
            pos_labels: words seen inside the context window,
                shape ``[batch_size, window_size * 2]``.
            neg_labels: words drawn by negative sampling,
                shape ``[batch_size, window_size * 2 * K]``.

        Returns:
            Tensor of shape ``[batch_size]`` holding the loss of each example.
        """
        input_embedding = self.in_embed(input_labels)  # B * embed_size
        pos_embedding = self.out_embed(pos_labels)  # B * (2*C) * embed_size
        neg_embedding = self.out_embed(neg_labels)  # B * (2*C*K) * embed_size

        center = input_embedding.unsqueeze(2)  # B * embed_size * 1

        # Dot product of every context / negative vector with the center
        # vector. Squeeze only the trailing singleton dimension: a bare
        # squeeze() would also drop the batch dimension when B == 1 and make
        # the sum over dim 1 below fail.
        log_pos = torch.bmm(pos_embedding, center).squeeze(2)  # B * (2*C)
        log_neg = torch.bmm(neg_embedding, -center).squeeze(2)  # B * (2*C*K)

        log_pos = F.logsigmoid(log_pos).sum(1)  # B
        log_neg = F.logsigmoid(log_neg).sum(1)  # B

        # The objective is a log-likelihood to maximise; negate it to get a loss.
        return -(log_pos + log_neg)

    def input_embeddings(self):
        """Return the input (center-word) embedding table as a NumPy array."""
        return self.in_embed.weight.detach().cpu().numpy()
        

定义一个模型以及把模型移动到GPU

In [None]:
# Build the skip-gram model; when a GPU is available, move its parameters there.
model = EmbeddingModel(vocab_size=VOCAB_SIZE, embed_size=EMBEDDING_SIZE)
if USE_CUDA:
    model = model.cuda()

下面是评估模型的代码，以及训练模型的代码

In [7]:
def evaluate(filename, embedding_weights):
    """Correlate model similarities with human similarity judgements.

    The file is read with pandas (comma-separated when the name ends in
    ``.csv``, tab-separated otherwise). Its first two columns are taken as a
    word pair and its third column as the human similarity score. Pairs in
    which either word is missing from the module-level ``word_to_idx`` are
    skipped.

    Args:
        filename: path of the word-similarity benchmark file.
        embedding_weights: 2-D array whose row ``i`` is the vector of the
            word with vocabulary index ``i``.

    Returns:
        The result of ``scipy.stats.spearmanr`` between the human scores and
        the cosine similarities of the corresponding embeddings.
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.stats` has been loaded.
    from scipy import stats

    separator = "," if filename.endswith(".csv") else "\t"
    data = pd.read_csv(filename, sep=separator)

    human_similarity = []
    model_similarity = []
    # Walk the three columns positionally instead of feeding index labels to
    # iloc, which is only correct for a default RangeIndex.
    for word1, word2, score in zip(data.iloc[:, 0], data.iloc[:, 1], data.iloc[:, 2]):
        if word1 not in word_to_idx or word2 not in word_to_idx:
            continue
        # Double brackets keep each embedding 2-D (1 x dim), as required by
        # sklearn's pairwise API.
        word1_embed = embedding_weights[[word_to_idx[word1]]]
        word2_embed = embedding_weights[[word_to_idx[word2]]]
        similarity = sklearn.metrics.pairwise.cosine_similarity(word1_embed, word2_embed)
        # The result is a 1x1 matrix; pick the scalar explicitly because
        # float() on a non-0-d array is deprecated in recent NumPy.
        model_similarity.append(float(similarity[0, 0]))
        human_similarity.append(float(score))

    return stats.spearmanr(human_similarity, model_similarity)

def find_nearest(word):
    """Return the 10 vocabulary words closest to ``word`` by cosine distance.

    Reads the module-level ``word_to_idx``, ``idx_to_word`` and
    ``embedding_weights``. The query word itself is part of the ranking.
    Raises ``KeyError`` when ``word`` is not in the vocabulary.
    """
    query = embedding_weights[word_to_idx[word]]
    # Cosine distance from the query vector to every row of the table.
    distances = np.array([
        scipy.spatial.distance.cosine(row, query) for row in embedding_weights
    ])
    closest = np.argsort(distances)[:10]
    return [idx_to_word[j] for j in closest]

训练模型：
- 模型一般需要训练若干个epoch
- 每个epoch我们都把所有的数据分成若干个batch
- 把每个batch的输入和输出都包装成cuda tensor
- forward pass，把中心词、周围的positive单词和采样得到的negative单词输入模型
- 模型直接返回每个样本的negative sampling loss，对整个batch取平均得到最终的loss
- 清空模型当前gradient
- backward pass
- 更新模型参数
- 每隔一定的iteration输出模型在当前iteration的loss，以及在验证数据集上做模型的评估

In [9]:
# Plain SGD over all model parameters (both embedding tables).
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(NUM_EPOCHS):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):

        # Embedding lookups need int64 indices; the cast is a no-op when the
        # batch already has that dtype.
        input_labels = input_labels.long()
        pos_labels = pos_labels.long()
        neg_labels = neg_labels.long()
        if USE_CUDA:
            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()

        # Standard step: clear old gradients, average the per-example loss
        # over the batch, back-propagate, update the parameters.
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()

        # Every 100 iterations: log the current batch loss to file and stdout.
        if i % 100 == 0:
            # Read the loss back from the device once and reuse the text.
            message = "epoch: {}, iter: {}, loss: {}".format(e, i, loss.item())
            with open(LOG_FILE, "a") as fout:
                fout.write(message + "\n")
            print(message)

        # Every 2000 iterations: score the current embeddings on the three
        # word-similarity benchmarks and show the neighbours of "monster".
        if i % 2000 == 0:
            # find_nearest reads this module-level name, so it must be
            # refreshed before the call below.
            embedding_weights = model.input_embeddings()
            sim_simlex = evaluate("simlex-999.txt", embedding_weights)
            sim_men = evaluate("men.txt", embedding_weights)
            sim_353 = evaluate("wordsim353.csv", embedding_weights)
            # Build the report once: find_nearest scans the whole vocabulary,
            # so computing it separately for print and for the log file
            # doubled the cost of every evaluation step.
            report = "epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\n".format(
                e, i, sim_simlex, sim_men, sim_353, find_nearest("monster"))
            print(report)
            with open(LOG_FILE, "a") as fout:
                fout.write(report)

    # End of epoch: checkpoint the input embeddings (NumPy) and the full
    # model state (PyTorch).
    embedding_weights = model.input_embeddings()
    np.save("embedding-{}".format(EMBEDDING_SIZE), embedding_weights)
    torch.save(model.state_dict(), "embedding-{}.th".format(EMBEDDING_SIZE))

epoch: 0, iter: 0, loss: 420.04736328125
epoch: 0, iteration: 0, simlex-999: SpearmanrResult(correlation=0.002806243285464091, pvalue=0.9309107582703205), men: SpearmanrResult(correlation=-0.03578915454199749, pvalue=0.06854012381329619), sim353: SpearmanrResult(correlation=0.02468906830123471, pvalue=0.6609497549092586), nearest to monster: ['monster', 'communism', 'bosses', 'microprocessors', 'infectious', 'debussy', 'unesco', 'tantamount', 'offices', 'tischendorf']

epoch: 0, iter: 100, loss: 278.9967041015625
epoch: 0, iter: 200, loss: 248.71990966796875
epoch: 0, iter: 300, loss: 202.95816040039062
epoch: 0, iter: 400, loss: 157.04776000976562
epoch: 0, iter: 500, loss: 137.83531188964844
epoch: 0, iter: 600, loss: 121.03585815429688
epoch: 0, iter: 700, loss: 105.300537109375
epoch: 0, iter: 800, loss: 114.10055541992188
epoch: 0, iter: 900, loss: 104.72723388671875
epoch: 0, iter: 1000, loss: 99.03569030761719
epoch: 0, iter: 1100, loss: 95.2179946899414
epoch: 0, iter: 1200, lo

epoch: 0, iter: 12100, loss: 33.59938430786133
epoch: 0, iter: 12200, loss: 32.594879150390625
epoch: 0, iter: 12300, loss: 32.42393493652344
epoch: 0, iter: 12400, loss: 32.8863410949707
epoch: 0, iter: 12500, loss: 39.303016662597656
epoch: 0, iter: 12600, loss: 33.103118896484375
epoch: 0, iter: 12700, loss: 36.31195068359375
epoch: 0, iter: 12800, loss: 33.8329963684082
epoch: 0, iter: 12900, loss: 32.499595642089844
epoch: 0, iter: 13000, loss: 33.224632263183594
epoch: 0, iter: 13100, loss: 33.931884765625
epoch: 0, iter: 13200, loss: 33.35892105102539
epoch: 0, iter: 13300, loss: 33.33966064453125
epoch: 0, iter: 13400, loss: 34.09075164794922
epoch: 0, iter: 13500, loss: 33.52397918701172
epoch: 0, iter: 13600, loss: 34.18444061279297
epoch: 0, iter: 13700, loss: 33.96720886230469
epoch: 0, iter: 13800, loss: 34.23271942138672
epoch: 0, iter: 13900, loss: 33.36094665527344
epoch: 0, iter: 14000, loss: 35.998287200927734
epoch: 0, iteration: 14000, simlex-999: SpearmanrResult(co

epoch: 0, iter: 24200, loss: 31.729236602783203
epoch: 0, iter: 24300, loss: 31.751216888427734
epoch: 0, iter: 24400, loss: 31.54802131652832
epoch: 0, iter: 24500, loss: 31.819448471069336
epoch: 0, iter: 24600, loss: 31.87582778930664
epoch: 0, iter: 24700, loss: 32.44230651855469
epoch: 0, iter: 24800, loss: 32.13909149169922
epoch: 0, iter: 24900, loss: 31.6838321685791
epoch: 0, iter: 25000, loss: 32.01523208618164
epoch: 0, iter: 25100, loss: 31.727489471435547
epoch: 0, iter: 25200, loss: 32.378543853759766
epoch: 0, iter: 25300, loss: 32.155052185058594
epoch: 0, iter: 25400, loss: 32.30049514770508
epoch: 0, iter: 25500, loss: 32.10628128051758
epoch: 0, iter: 25600, loss: 32.01287841796875
epoch: 0, iter: 25700, loss: 32.22496032714844
epoch: 0, iter: 25800, loss: 32.15202331542969
epoch: 0, iter: 25900, loss: 32.43567657470703
epoch: 0, iter: 26000, loss: 31.745975494384766
epoch: 0, iteration: 26000, simlex-999: SpearmanrResult(correlation=0.08715629365703002, pvalue=0.006

epoch: 0, iter: 36400, loss: 31.05801773071289
epoch: 0, iter: 36500, loss: 31.969802856445312
epoch: 0, iter: 36600, loss: 31.290489196777344
epoch: 0, iter: 36700, loss: 31.409465789794922
epoch: 0, iter: 36800, loss: 31.444076538085938
epoch: 0, iter: 36900, loss: 31.494474411010742
epoch: 0, iter: 37000, loss: 31.12554931640625
epoch: 0, iter: 37100, loss: 31.744049072265625
epoch: 0, iter: 37200, loss: 31.608917236328125
epoch: 0, iter: 37300, loss: 31.441722869873047
epoch: 0, iter: 37400, loss: 31.544227600097656
epoch: 0, iter: 37500, loss: 31.359806060791016
epoch: 0, iter: 37600, loss: 31.130847930908203
epoch: 0, iter: 37700, loss: 32.14916229248047
epoch: 0, iter: 37800, loss: 31.148212432861328
epoch: 0, iter: 37900, loss: 31.835248947143555
epoch: 0, iter: 38000, loss: 31.421974182128906
epoch: 0, iteration: 38000, simlex-999: SpearmanrResult(correlation=0.09401565185194706, pvalue=0.003602024110356835), men: SpearmanrResult(correlation=0.09723017395213002, pvalue=7.10171

epoch: 0, iter: 48500, loss: 31.48914909362793
epoch: 0, iter: 48600, loss: 31.45376205444336
epoch: 0, iter: 48700, loss: 30.948339462280273
epoch: 0, iter: 48800, loss: 30.842824935913086
epoch: 0, iter: 48900, loss: 30.931697845458984
epoch: 0, iter: 49000, loss: 31.468204498291016
epoch: 0, iter: 49100, loss: 31.04726791381836
epoch: 0, iter: 49200, loss: 31.148698806762695
epoch: 0, iter: 49300, loss: 31.295198440551758
epoch: 0, iter: 49400, loss: 31.415983200073242
epoch: 0, iter: 49500, loss: 31.53121566772461
epoch: 0, iter: 49600, loss: 30.391773223876953
epoch: 0, iter: 49700, loss: 31.365924835205078
epoch: 0, iter: 49800, loss: 30.920448303222656
epoch: 0, iter: 49900, loss: 30.881540298461914
epoch: 0, iter: 50000, loss: 31.272510528564453
epoch: 0, iteration: 50000, simlex-999: SpearmanrResult(correlation=0.10413335271622073, pvalue=0.0012554545146236879), men: SpearmanrResult(correlation=0.10361287469529604, pvalue=1.251734153196469e-07), sim353: SpearmanrResult(correla

epoch: 0, iter: 60600, loss: 31.264543533325195
epoch: 0, iter: 60700, loss: 31.218517303466797
epoch: 0, iter: 60800, loss: 31.23360824584961
epoch: 0, iter: 60900, loss: 30.85096549987793
epoch: 0, iter: 61000, loss: 30.768386840820312
epoch: 0, iter: 61100, loss: 31.50748634338379
epoch: 0, iter: 61200, loss: 30.46345329284668
epoch: 0, iter: 61300, loss: 30.543607711791992
epoch: 0, iter: 61400, loss: 30.628982543945312
epoch: 0, iter: 61500, loss: 31.45627784729004
epoch: 0, iter: 61600, loss: 31.070459365844727
epoch: 0, iter: 61700, loss: 30.569217681884766
epoch: 0, iter: 61800, loss: 30.83639907836914
epoch: 0, iter: 61900, loss: 31.005922317504883
epoch: 0, iter: 62000, loss: 31.41488265991211
epoch: 0, iteration: 62000, simlex-999: SpearmanrResult(correlation=0.11119875283206068, pvalue=0.0005685786512505508), men: SpearmanrResult(correlation=0.11318488733549789, pvalue=7.599257092187759e-09), sim353: SpearmanrResult(correlation=0.12779805415765372, pvalue=0.0226465488272404

epoch: 0, iter: 72700, loss: 30.91510581970215
epoch: 0, iter: 72800, loss: 30.70620346069336
epoch: 0, iter: 72900, loss: 30.421703338623047
epoch: 0, iter: 73000, loss: 30.53826141357422
epoch: 0, iter: 73100, loss: 30.770679473876953
epoch: 0, iter: 73200, loss: 31.04900360107422
epoch: 0, iter: 73300, loss: 30.795854568481445
epoch: 0, iter: 73400, loss: 31.299104690551758
epoch: 0, iter: 73500, loss: 30.484947204589844
epoch: 0, iter: 73600, loss: 30.79161834716797
epoch: 0, iter: 73700, loss: 30.636621475219727
epoch: 0, iter: 73800, loss: 31.00129508972168
epoch: 0, iter: 73900, loss: 30.91973114013672
epoch: 0, iter: 74000, loss: 31.55290985107422
epoch: 0, iteration: 74000, simlex-999: SpearmanrResult(correlation=0.11672803148915531, pvalue=0.0002961005658581428), men: SpearmanrResult(correlation=0.11817601695076835, pvalue=1.6031687449902205e-09), sim353: SpearmanrResult(correlation=0.15298232562148392, pvalue=0.006267834790300931), nearest to monster: ['monster', 'angel', 'l

epoch: 0, iter: 84800, loss: 30.84918975830078
epoch: 0, iter: 84900, loss: 30.95672035217285
epoch: 0, iter: 85000, loss: 31.12570571899414
epoch: 0, iter: 85100, loss: 31.057252883911133
epoch: 0, iter: 85200, loss: 30.39339828491211
epoch: 0, iter: 85300, loss: 30.523571014404297
epoch: 0, iter: 85400, loss: 30.765701293945312
epoch: 0, iter: 85500, loss: 30.65972137451172
epoch: 0, iter: 85600, loss: 30.2365779876709
epoch: 0, iter: 85700, loss: 31.060688018798828
epoch: 0, iter: 85800, loss: 31.084121704101562
epoch: 0, iter: 85900, loss: 30.77812957763672
epoch: 0, iter: 86000, loss: 30.55185890197754
epoch: 0, iteration: 86000, simlex-999: SpearmanrResult(correlation=0.12072190676944367, pvalue=0.00018154682975915078), men: SpearmanrResult(correlation=0.1252523395746619, pvalue=1.577244824410371e-10), sim353: SpearmanrResult(correlation=0.1690460146471711, pvalue=0.002490881483585671), nearest to monster: ['monster', 'blade', 'leg', 'angel', 'boat', 'tail', 'bird', 'mirror', 'le

epoch: 0, iter: 96900, loss: 30.507057189941406
epoch: 0, iter: 97000, loss: 30.755821228027344
epoch: 0, iter: 97100, loss: 30.22985076904297
epoch: 0, iter: 97200, loss: 30.947574615478516
epoch: 0, iter: 97300, loss: 30.583507537841797
epoch: 0, iter: 97400, loss: 30.67584991455078
epoch: 0, iter: 97500, loss: 31.08060073852539
epoch: 0, iter: 97600, loss: 30.564102172851562
epoch: 0, iter: 97700, loss: 30.59963607788086
epoch: 0, iter: 97800, loss: 31.315624237060547
epoch: 0, iter: 97900, loss: 31.017738342285156
epoch: 0, iter: 98000, loss: 30.729049682617188
epoch: 0, iteration: 98000, simlex-999: SpearmanrResult(correlation=0.1246043454563031, pvalue=0.00011121651888022881), men: SpearmanrResult(correlation=0.13216585436099, pvalue=1.4393399261301587e-11), sim353: SpearmanrResult(correlation=0.17839479356905732, pvalue=0.001401368292639592), nearest to monster: ['monster', 'blade', 'angel', 'leg', 'mirror', 'shield', 'bird', 'tail', 'boat', 'signature']

epoch: 0, iter: 98100, 

epoch: 0, iter: 108800, loss: 30.220489501953125
epoch: 0, iter: 108900, loss: 30.999284744262695
epoch: 0, iter: 109000, loss: 31.053329467773438
epoch: 0, iter: 109100, loss: 30.955081939697266
epoch: 0, iter: 109200, loss: 30.715665817260742
epoch: 0, iter: 109300, loss: 30.646869659423828
epoch: 0, iter: 109400, loss: 30.617048263549805
epoch: 0, iter: 109500, loss: 31.204490661621094
epoch: 0, iter: 109600, loss: 30.811479568481445
epoch: 0, iter: 109700, loss: 30.87088394165039
epoch: 0, iter: 109800, loss: 30.969287872314453
epoch: 0, iter: 109900, loss: 30.64400291442871
epoch: 0, iter: 110000, loss: 30.75538444519043
epoch: 0, iteration: 110000, simlex-999: SpearmanrResult(correlation=0.13088839031890218, pvalue=4.880473123942339e-05), men: SpearmanrResult(correlation=0.13896681910256206, pvalue=1.2053636316763994e-12), sim353: SpearmanrResult(correlation=0.20021881116883977, pvalue=0.0003271445558931211), nearest to monster: ['monster', 'blade', 'camera', 'leg', 'shield', 'el

epoch: 1, iter: 1100, loss: 30.46042251586914
epoch: 1, iter: 1200, loss: 30.88376235961914
epoch: 1, iter: 1300, loss: 30.545751571655273
epoch: 1, iter: 1400, loss: 30.541282653808594
epoch: 1, iter: 1500, loss: 30.788883209228516
epoch: 1, iter: 1600, loss: 30.412235260009766
epoch: 1, iter: 1700, loss: 30.570415496826172
epoch: 1, iter: 1800, loss: 30.742263793945312
epoch: 1, iter: 1900, loss: 30.20556640625
epoch: 1, iter: 2000, loss: 30.579498291015625
epoch: 1, iteration: 2000, simlex-999: SpearmanrResult(correlation=0.13750886561871162, pvalue=1.9667970854520583e-05), men: SpearmanrResult(correlation=0.14216903853907206, pvalue=3.5913225784003253e-13), sim353: SpearmanrResult(correlation=0.20737145549247832, pvalue=0.00019612168069552233), nearest to monster: ['monster', 'blade', 'camera', 'module', 'robot', 'boat', 'leg', 'elephant', 'harp', 'pen']

epoch: 1, iter: 2100, loss: 31.068511962890625
epoch: 1, iter: 2200, loss: 30.329666137695312
epoch: 1, iter: 2300, loss: 30.718

epoch: 1, iter: 13400, loss: 30.430782318115234
epoch: 1, iter: 13500, loss: 30.365447998046875
epoch: 1, iter: 13600, loss: 30.273536682128906
epoch: 1, iter: 13700, loss: 30.858108520507812
epoch: 1, iter: 13800, loss: 30.77298927307129
epoch: 1, iter: 13900, loss: 31.031143188476562
epoch: 1, iter: 14000, loss: 30.615827560424805
epoch: 1, iteration: 14000, simlex-999: SpearmanrResult(correlation=0.1403276597185061, pvalue=1.3185922971315998e-05), men: SpearmanrResult(correlation=0.14529215462232734, pvalue=1.0734198813575383e-13), sim353: SpearmanrResult(correlation=0.22418410664878283, pvalue=5.495482632416603e-05), nearest to monster: ['monster', 'blade', 'bird', 'boat', 'robot', 'mine', 'module', 'camera', 'giant', 'harp']

epoch: 1, iter: 14100, loss: 30.48816680908203
epoch: 1, iter: 14200, loss: 30.806354522705078
epoch: 1, iter: 14300, loss: 29.96129035949707
epoch: 1, iter: 14400, loss: 30.932781219482422
epoch: 1, iter: 14500, loss: 30.7196102142334
epoch: 1, iter: 14600, 

epoch: 1, iter: 25600, loss: 30.703933715820312
epoch: 1, iter: 25700, loss: 30.121395111083984
epoch: 1, iter: 25800, loss: 30.44470977783203
epoch: 1, iter: 25900, loss: 30.887786865234375
epoch: 1, iter: 26000, loss: 30.558914184570312
epoch: 1, iteration: 26000, simlex-999: SpearmanrResult(correlation=0.1440751174505626, pvalue=7.656767087120004e-06), men: SpearmanrResult(correlation=0.1491477742745481, pvalue=2.3302203512655484e-14), sim353: SpearmanrResult(correlation=0.23077736171791446, pvalue=3.247610265381441e-05), nearest to monster: ['monster', 'blade', 'bird', 'robot', 'mine', 'elephant', 'harp', 'triangle', 'pen', 'reed']

epoch: 1, iter: 26100, loss: 30.59500503540039
epoch: 1, iter: 26200, loss: 30.334857940673828
epoch: 1, iter: 26300, loss: 30.802188873291016
epoch: 1, iter: 26400, loss: 30.327043533325195
epoch: 1, iter: 26500, loss: 30.643577575683594
epoch: 1, iter: 26600, loss: 30.822498321533203
epoch: 1, iter: 26700, loss: 30.609739303588867
epoch: 1, iter: 2680

epoch: 1, iter: 37800, loss: 30.280288696289062
epoch: 1, iter: 37900, loss: 30.579071044921875
epoch: 1, iter: 38000, loss: 30.68809700012207
epoch: 1, iteration: 38000, simlex-999: SpearmanrResult(correlation=0.14842186148466413, pvalue=4.006345881913544e-06), men: SpearmanrResult(correlation=0.1535692327051543, pvalue=3.845826952077038e-15), sim353: SpearmanrResult(correlation=0.24061126807866595, pvalue=1.4397516474117237e-05), nearest to monster: ['monster', 'blade', 'bird', 'mine', 'robot', 'reed', 'giant', 'ghost', 'enigma', 'harp']

epoch: 1, iter: 38100, loss: 30.60862922668457
epoch: 1, iter: 38200, loss: 30.42845916748047
epoch: 1, iter: 38300, loss: 30.334047317504883
epoch: 1, iter: 38400, loss: 30.224014282226562
epoch: 1, iter: 38500, loss: 30.38711166381836
epoch: 1, iter: 38600, loss: 30.579326629638672
epoch: 1, iter: 38700, loss: 30.49921417236328
epoch: 1, iter: 38800, loss: 30.80820083618164
epoch: 1, iter: 38900, loss: 31.00635528564453
epoch: 1, iter: 39000, loss

epoch: 1, iter: 49900, loss: 30.408056259155273
epoch: 1, iter: 50000, loss: 30.49826431274414
epoch: 1, iteration: 50000, simlex-999: SpearmanrResult(correlation=0.1525143841833279, pvalue=2.140523243449846e-06), men: SpearmanrResult(correlation=0.15705781004352684, pvalue=8.938073731951767e-16), sim353: SpearmanrResult(correlation=0.24303933091018648, pvalue=1.1714453427503608e-05), nearest to monster: ['monster', 'robot', 'mine', 'ghost', 'blade', 'triangle', 'bird', 'mirror', 'pen', 'trilogy']

epoch: 1, iter: 50100, loss: 30.582521438598633
epoch: 1, iter: 50200, loss: 30.166404724121094
epoch: 1, iter: 50300, loss: 30.79269790649414
epoch: 1, iter: 50400, loss: 30.7398738861084
epoch: 1, iter: 50500, loss: 30.5670108795166
epoch: 1, iter: 50600, loss: 30.718910217285156
epoch: 1, iter: 50700, loss: 30.94159507751465
epoch: 1, iter: 50800, loss: 30.046207427978516
epoch: 1, iter: 50900, loss: 30.098331451416016
epoch: 1, iter: 51000, loss: 29.920578002929688
epoch: 1, iter: 51100,

epoch: 1, iter: 62000, loss: 30.466468811035156
epoch: 1, iteration: 62000, simlex-999: SpearmanrResult(correlation=0.1572838048897373, pvalue=1.0096640956970655e-06), men: SpearmanrResult(correlation=0.16137007881415613, pvalue=1.405264112416452e-16), sim353: SpearmanrResult(correlation=0.25097676215992787, pvalue=5.8801359002641125e-06), nearest to monster: ['monster', 'robot', 'pen', 'ghost', 'giant', 'cow', 'mine', 'storyline', 'bird', 'blade']

epoch: 1, iter: 62100, loss: 30.18429946899414
epoch: 1, iter: 62200, loss: 30.559833526611328
epoch: 1, iter: 62300, loss: 30.80440902709961
epoch: 1, iter: 62400, loss: 30.450206756591797
epoch: 1, iter: 62500, loss: 30.552818298339844
epoch: 1, iter: 62600, loss: 30.82094383239746
epoch: 1, iter: 62700, loss: 30.254344940185547
epoch: 1, iter: 62800, loss: 30.72846221923828
epoch: 1, iter: 62900, loss: 30.654434204101562
epoch: 1, iter: 63000, loss: 30.073328018188477
epoch: 1, iter: 63100, loss: 30.521087646484375
epoch: 1, iter: 63200,

epoch: 1, iteration: 74000, simlex-999: SpearmanrResult(correlation=0.1617221261012147, pvalue=4.916430156015046e-07), men: SpearmanrResult(correlation=0.16269528376899162, pvalue=7.876945371094953e-17), sim353: SpearmanrResult(correlation=0.2568775911038176, pvalue=3.4697706327940896e-06), nearest to monster: ['monster', 'giant', 'clown', 'robot', 'triangle', 'killer', 'horn', 'storyline', 'bird', 'pen']

epoch: 1, iter: 74100, loss: 30.254886627197266
epoch: 1, iter: 74200, loss: 29.888710021972656
epoch: 1, iter: 74300, loss: 30.417236328125
epoch: 1, iter: 74400, loss: 30.457595825195312
epoch: 1, iter: 74500, loss: 31.00020980834961
epoch: 1, iter: 74600, loss: 30.30846405029297
epoch: 1, iter: 74700, loss: 30.387718200683594
epoch: 1, iter: 74800, loss: 30.376087188720703
epoch: 1, iter: 74900, loss: 30.061664581298828
epoch: 1, iter: 75000, loss: 30.370288848876953
epoch: 1, iter: 75100, loss: 30.63956642150879
epoch: 1, iter: 75200, loss: 30.442768096923828
epoch: 1, iter: 7530

epoch: 1, iter: 86100, loss: 30.831403732299805
epoch: 1, iter: 86200, loss: 30.484277725219727
epoch: 1, iter: 86300, loss: 30.60747718811035
epoch: 1, iter: 86400, loss: 30.155363082885742
epoch: 1, iter: 86500, loss: 30.28110122680664
epoch: 1, iter: 86600, loss: 30.374900817871094
epoch: 1, iter: 86700, loss: 30.804969787597656
epoch: 1, iter: 86800, loss: 30.20755958557129
epoch: 1, iter: 86900, loss: 30.167919158935547
epoch: 1, iter: 87000, loss: 30.547744750976562
epoch: 1, iter: 87100, loss: 30.687185287475586
epoch: 1, iter: 87200, loss: 30.32683563232422
epoch: 1, iter: 87300, loss: 30.641101837158203
epoch: 1, iter: 87400, loss: 30.987831115722656
epoch: 1, iter: 87500, loss: 30.438377380371094
epoch: 1, iter: 87600, loss: 30.0216007232666
epoch: 1, iter: 87700, loss: 30.663925170898438
epoch: 1, iter: 87800, loss: 30.71135711669922
epoch: 1, iter: 87900, loss: 30.71870994567871
epoch: 1, iter: 88000, loss: 30.205699920654297
epoch: 1, iteration: 88000, simlex-999: Spearman

epoch: 1, iter: 98200, loss: 30.314851760864258
epoch: 1, iter: 98300, loss: 30.17748260498047
epoch: 1, iter: 98400, loss: 30.060449600219727
epoch: 1, iter: 98500, loss: 30.29900550842285
epoch: 1, iter: 98600, loss: 30.583925247192383
epoch: 1, iter: 98700, loss: 30.511886596679688
epoch: 1, iter: 98800, loss: 29.978679656982422
epoch: 1, iter: 98900, loss: 30.08024787902832
epoch: 1, iter: 99000, loss: 29.74579620361328
epoch: 1, iter: 99100, loss: 30.44879722595215
epoch: 1, iter: 99200, loss: 30.379261016845703
epoch: 1, iter: 99300, loss: 29.564411163330078
epoch: 1, iter: 99400, loss: 30.413551330566406
epoch: 1, iter: 99500, loss: 29.98810386657715
epoch: 1, iter: 99600, loss: 30.30841827392578
epoch: 1, iter: 99700, loss: 30.51578140258789
epoch: 1, iter: 99800, loss: 30.445234298706055
epoch: 1, iter: 99900, loss: 30.237821578979492
epoch: 1, iter: 100000, loss: 30.199050903320312
epoch: 1, iteration: 100000, simlex-999: SpearmanrResult(correlation=0.16715232584964468, pvalu

epoch: 1, iter: 110100, loss: 30.272464752197266
epoch: 1, iter: 110200, loss: 30.38793182373047
epoch: 1, iter: 110300, loss: 30.590267181396484
epoch: 1, iter: 110400, loss: 30.97867202758789
epoch: 1, iter: 110500, loss: 30.195693969726562
epoch: 1, iter: 110600, loss: 30.050588607788086
epoch: 1, iter: 110700, loss: 30.010971069335938
epoch: 1, iter: 110800, loss: 30.200347900390625
epoch: 1, iter: 110900, loss: 30.716394424438477
epoch: 1, iter: 111000, loss: 30.02122688293457
epoch: 1, iter: 111100, loss: 30.24693489074707
epoch: 1, iter: 111200, loss: 30.085987091064453
epoch: 1, iter: 111300, loss: 30.499698638916016
epoch: 1, iter: 111400, loss: 30.532825469970703
epoch: 1, iter: 111500, loss: 29.860715866088867
epoch: 1, iter: 111600, loss: 30.18459701538086
epoch: 1, iter: 111700, loss: 30.063079833984375
epoch: 1, iter: 111800, loss: 30.4438533782959
epoch: 1, iter: 111900, loss: 29.979290008544922
epoch: 1, iter: 112000, loss: 29.959312438964844
epoch: 1, iteration: 112000

In [11]:
# Restore the parameters saved at the end of training. map_location="cpu"
# lets a checkpoint written during a GPU run be read on a machine without
# CUDA; load_state_dict then copies the values into the model's own
# parameters, wherever those live.
model.load_state_dict(
    torch.load("embedding-{}.th".format(EMBEDDING_SIZE), map_location="cpu")
)

## 在 MEN、SimLex-999 和 WordSim353 数据集上做评估

In [12]:
# Score the trained input embeddings on each word-similarity benchmark.
embedding_weights = model.input_embeddings()
for name, path in (
    ("simlex-999", "simlex-999.txt"),
    ("men", "men.txt"),
    ("wordsim353", "wordsim353.csv"),
):
    print(name, evaluate(path, embedding_weights))

simlex-999 SpearmanrResult(correlation=0.17251697429101504, pvalue=7.863946056740345e-08)
men SpearmanrResult(correlation=0.1778096817088841, pvalue=7.565661657312768e-20)
wordsim353 SpearmanrResult(correlation=0.27153702278146635, pvalue=8.842165885381714e-07)


## 寻找nearest neighbors

In [13]:
# Print the ten nearest neighbours of a handful of probe words.
query_words = [
    "good", "fresh", "monster", "green", "like",
    "america", "chicago", "work", "computer", "language",
]
for word in query_words:
    print(word, find_nearest(word))

good ['good', 'bad', 'perfect', 'hard', 'questions', 'alone', 'money', 'false', 'truth', 'experience']
fresh ['fresh', 'grain', 'waste', 'cooling', 'lighter', 'dense', 'mild', 'sized', 'warm', 'steel']
monster ['monster', 'giant', 'robot', 'hammer', 'clown', 'bull', 'demon', 'triangle', 'storyline', 'slogan']
green ['green', 'blue', 'yellow', 'white', 'cross', 'orange', 'black', 'red', 'mountain', 'gold']
like ['like', 'unlike', 'etc', 'whereas', 'animals', 'soft', 'amongst', 'similarly', 'bear', 'drink']
america ['america', 'africa', 'korea', 'india', 'australia', 'turkey', 'pakistan', 'mexico', 'argentina', 'carolina']
chicago ['chicago', 'boston', 'illinois', 'texas', 'london', 'indiana', 'massachusetts', 'florida', 'berkeley', 'michigan']
work ['work', 'writing', 'job', 'marx', 'solo', 'label', 'recording', 'nietzsche', 'appearance', 'stage']
computer ['computer', 'digital', 'electronic', 'audio', 'video', 'graphics', 'hardware', 'software', 'computers', 'program']
language ['langu

## 单词之间的关系

In [14]:
# Analogy probe: which words lie closest to  woman - man + king ?
man_idx, king_idx, woman_idx = (word_to_idx[w] for w in ("man", "king", "woman"))
embedding = (
    embedding_weights[woman_idx]
    - embedding_weights[man_idx]
    + embedding_weights[king_idx]
)
# Cosine distance from the composed vector to every word vector.
cos_dis = np.array([
    scipy.spatial.distance.cosine(vec, embedding) for vec in embedding_weights
])
# Print the 20 closest words (the input words are not filtered out).
for i in np.argsort(cos_dis)[:20]:
    print(idx_to_word[i])

king
henry
charles
pope
queen
iii
prince
elizabeth
alexander
constantine
edward
son
iv
louis
emperor
mary
james
joseph
frederick
francis
