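This blog post also implements word2vec from scratch (skip-gram with negative sampling). It does not use the gensim library. Instead it uses torch, numpy, pandas, scipy and sklearn.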


Blog post: https://blog.csdn.net/ParisCutie/article/details/109393772
That post also provides a download link for the training data:
> Link: https://pan.baidu.com/s/1tFeK3mXuVXEy3EMarfeWvg Password: v2z5

**Notes on training**

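With a 10,000-word vocabulary and 300-dimensional word vectors, each weight matrix (the input embeddings and the output weights) holds 10,000 × 300 = 3 million parameters. A full softmax output layer updates all 3 million output weights for every single training sample, which is prohibitively expensive.
The technique below is the main way to speed up training.
Negative sampling
With negative sampling, each training sample updates only a small fraction of the weights, which greatly reduces the computation in each gradient-descent step.
Suppose the vocabulary has 10,000 words and the training sample ("fox", "quick") is fed into the network. "fox" is one-hot encoded, and at the output layer we want the neuron for "quick" to output 1 and the other 9,999 neurons to output 0. The words belonging to those 9,999 neurons, whose target output is 0, are called negative words. The idea of negative sampling is simple: randomly pick a small number of negative words, say 10, and update only their weights together with the weights of the positive word "quick". That is 11 × 300 = 3,300 output weights per sample instead of 3 million.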
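To make the saving concrete, here is a minimal NumPy sketch of one skip-gram negative-sampling step for a single (center, context) pair. It is not the PyTorch code used later in this post: the indices standing in for "fox" and "quick" are hypothetical, and the negatives are drawn uniformly for simplicity (the code below draws them from a smoothed unigram distribution instead). What it shows is that the gradient touches only K + 1 of the V rows of the output matrix.

```python
import numpy as np

def sgns_step(W_in, W_out, center, context, negatives):
    """Loss and gradient w.r.t. W_out for one (center, context) pair
    under skip-gram negative sampling."""
    v = W_in[center]                                   # center-word vector, shape [d]
    rows = np.concatenate(([context], negatives))      # the only output rows involved
    labels = np.zeros(len(rows))
    labels[0] = 1.0                                    # 1 = real context word, 0 = negative
    probs = 1.0 / (1.0 + np.exp(-(W_out[rows] @ v)))   # sigmoid of each dot product
    loss = -np.log(probs[0]) - np.sum(np.log(1.0 - probs[1:]))
    grad_out = np.zeros_like(W_out)
    # d loss / d u_w = (sigmoid(u_w . v) - label_w) * v, nonzero only for `rows`
    np.add.at(grad_out, rows, (probs - labels)[:, None] * v)
    return loss, grad_out

rng = np.random.default_rng(0)
V, d, n_neg = 10_000, 300, 10
W_in = rng.normal(scale=0.01, size=(V, d))
W_out = rng.normal(scale=0.01, size=(V, d))
center, context = 42, 7  # hypothetical indices for "fox" and "quick"
negatives = rng.choice(np.setdiff1d(np.arange(V), [context]), size=n_neg, replace=False)

loss, grad = sgns_step(W_in, W_out, center, context, negatives)
touched = np.count_nonzero(np.any(grad != 0, axis=1))
print(touched, "of", V, "output rows receive a gradient")  # → 11 of 10000 output rows receive a gradient
```

A full softmax would instead produce a nonzero gradient for all 10,000 rows.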

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from torch.nn.parameter import Parameter

from collections import Counter
import numpy as np
import random
import math

import pandas as pd
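import scipy.stats    # spearmanr, used in evaluate()
import scipy.spatial  # distance.cosine, used in find_nearest()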
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

"""
先进行一些预处理，导包、用GPU运行、设置随机种子、超参数、还有分词
"""


USE_CUDA = torch.cuda.is_available()

# Fix every random seed so that experimental results are reproducible
random.seed(53113)
np.random.seed(53113)
torch.manual_seed(53113)
if USE_CUDA:
    torch.cuda.manual_seed(53113)
    
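# Hyperparameters
K = 100                 # number of negative samples per context word
C = 3                   # context window: C words on each side of the center word
NUM_EPOCHS = 2          # number of training epochs
MAX_VOCAB_SIZE = 30000  # maximum vocabulary size (including <unk>)
BATCH_SIZE = 128        # batch size
LEARNING_RATE = 0.2     # initial learning rate
EMBEDDING_SIZE = 100    # dimension of the word vectors

LOG_FILE = "word-embedding.log"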

# Tokenizer: split a text into a list of words on whitespace
def word_tokenize(text):
    return text.split()



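Read the text from the file, then build the vocabulary. The vocabulary is a dict capped at MAX_VOCAB_SIZE = 30000 entries: the 29,999 most frequent words plus an `<unk>` token that stands for every other (unknown) word. Finally, record both mappings: word to index and index to word.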

In [5]:
# Load the training split of the dataset provided by the original author (described in the readme.md of the dataset folder)
with open("F:\\自然语言处理数据集\\text8\\text8.train.txt", "r") as fin:
    text = fin.read()

# Lowercase, then tokenize
text = word_tokenize(text.lower())
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE-1))
# <unk> absorbs the occurrences of every word outside the top MAX_VOCAB_SIZE-1
vocab["<unk>"] = len(text) - np.sum(list(vocab.values()))
idx_to_word = list(vocab.keys())
word_to_idx = {word:i for i, word in enumerate(idx_to_word)}
word_to_idx


{'the': 0, 'of': 1, 'and': 2, 'one': 3, 'in': 4, 'a': 5, 'to': 6, 'zero': 7, 'nine': 8, 'two': 9,
 'is': 10, 'as': 11, 'eight': 12, 'for': 13, 's': 14, 'five': 15, 'three': 16, 'was': 17, 'by': 18, 'that': 19,
 'four': 20, 'six': 21, 'seven': 22, 'with': 23, 'on': 24, 'are': 25, 'it': 26, 'from': 27, 'or': 28, 'his': 29,
 'an': 30, 'be': 31, 'this': 32, 'which': 33, 'at': 34, 'he': 35, 'not': 36, 'also': 37, 'have': 38, 'were': 39,
 'has': 40, 'but': 41, 'other': 42, 'their': 43, 'its': 44, 'they': 45, 'first': 46, 'some': 47, 'had': 48, 'more': 49,
 'all': 50, 'can': 51, 'most': 52, 'been': 53, 'such': 54, 'many': 55, 'who': 56, 'new': 57, 'there': 58, 'used': 59,
 'after': 60, 'when': 61, 'time': 62, 'into': 63, 'these': 64, 'only': 65, 'american': 66, 'see': 67, 'may': 68, 'than': 69,
 'i': 70, 'world': 71, 'would': 72, 'b': 73, 'no': 74, 'd': 75, 'however': 76, 'between': 77, 'about': 78, 'over': 79,
 'states': ...
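The `<unk>` count is easy to misread, so here is the same vocabulary-building logic wrapped in a small helper and run on a toy corpus with the vocabulary capped at 3 entries. The helper name `build_vocab` and the toy sentence are hypothetical, not part of the original code.

```python
from collections import Counter

def build_vocab(tokens, max_vocab_size):
    """Keep the (max_vocab_size - 1) most frequent words; map everything else to <unk>."""
    vocab = dict(Counter(tokens).most_common(max_vocab_size - 1))
    vocab["<unk>"] = len(tokens) - sum(vocab.values())  # occurrences of all dropped words
    idx_to_word = list(vocab.keys())
    word_to_idx = {w: i for i, w in enumerate(idx_to_word)}
    return vocab, idx_to_word, word_to_idx

toy_tokens = "the cat sat on the mat the dog sat".split()
toy_vocab, toy_idx_to_word, toy_word_to_idx = build_vocab(toy_tokens, max_vocab_size=3)
print(toy_vocab)        # {'the': 3, 'sat': 2, '<unk>': 4}
print(toy_word_to_idx)  # {'the': 0, 'sat': 1, '<unk>': 2}
```

"cat", "on", "mat" and "dog" each occur once, so `<unk>` gets a count of 4 and, being inserted last, the last index.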


In [13]:
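# Word counts, and the normalized word frequencies used for negative sampling
word_counts = np.array([count for count in vocab.values()], dtype=np.float32)
word_freqs = word_counts / np.sum(word_counts)
word_freqs = word_freqs ** (3./4.)  # raise to the 3/4 power, as in the original word2vec
word_freqs = word_freqs / np.sum(word_freqs)  # renormalize: this is the negative-sampling distribution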
VOCAB_SIZE = len(idx_to_word)
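The 3/4 power in `word_freqs ** (3./4.)` flattens the sampling distribution, so rare words are drawn as negatives more often than their raw frequency alone would suggest. A toy example with three hypothetical word counts shows the effect: the rarest word's share rises from 1% to about 2.8%, while the most frequent word's share drops from 90% to about 82.5%.

```python
import numpy as np

toy_counts = np.array([900, 90, 10], dtype=np.float64)  # hypothetical counts: frequent, medium, rare
raw = toy_counts / toy_counts.sum()                      # unigram distribution
smoothed = raw ** 0.75                                    # same transform as word_freqs ** (3./4.)
smoothed /= smoothed.sum()                                # renormalize so it sums to 1
print(raw.round(3), smoothed.round(3))
```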


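Next comes the dataloader. It needs to:
1. Encode all of the text as integer indices, then preprocess it with subsampling (the dataset below skips the subsampling step).
1. Store the vocabulary, the word counts, and the normalized word frequencies.
1. Sample one center word per iteration.
1. Return the context words around the current center word.
1. Sample some negative words for that center word.
1. Return the word counts (the dataset below stores them in `self.word_counts`, but `__getitem__` does not return them).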
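Step 1 mentions subsampling, which the dataset below does not actually perform. As a sketch, here is the subsampling rule from the original word2vec paper (Mikolov et al., 2013): each occurrence of word w is kept with probability min(1, sqrt(t / f(w))), where f(w) is the word's relative frequency and t is a threshold, typically around 1e-5. The function name `subsample` and its parameters are hypothetical; the assumption is that it would be applied to the encoded text before the dataset is built.

```python
import numpy as np

def subsample(encoded, word_counts, t=1e-5, seed=0):
    """Drop frequent tokens at random: keep each occurrence of word w
    with probability min(1, sqrt(t / f(w)))."""
    freqs = word_counts / word_counts.sum()           # relative frequency f(w)
    keep_prob = np.minimum(1.0, np.sqrt(t / freqs))   # per-word keep probability
    rng = np.random.default_rng(seed)
    mask = rng.random(len(encoded)) < keep_prob[encoded]
    return encoded[mask]

# Toy check: word 0 makes up 90% of the text, word 1 the other 10%.
toy_encoded = np.array([0] * 9000 + [1] * 1000)
toy_counts = np.bincount(toy_encoded).astype(np.float64)
kept = subsample(toy_encoded, toy_counts, t=0.1)  # large t only so the toy result is easy to read
print(np.bincount(kept))  # roughly [3000, 1000]
```

With t = 0.1, word 1 has f(w) = t and is always kept, while word 0 is kept about a third of the time.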

In [14]:
class WordEmbeddingDataset(tud.Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts):
        ''' text: a list of words, all text from the training dataset
            word_to_idx: the dictionary from word to idx
            idx_to_word: idx to word mapping
            word_freqs: the (smoothed) word frequencies used for negative sampling
            word_counts: the word counts
        '''
        super(WordEmbeddingDataset, self).__init__()
        # Encode the text as indices; words outside the vocabulary map to <unk>, the last index
        self.text_encoded = torch.tensor([word_to_idx.get(t, VOCAB_SIZE - 1) for t in text], dtype=torch.long)
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freqs)
        self.word_counts = torch.Tensor(word_counts)
        
    def __len__(self):
        ''' Length of the dataset, i.e. the total number of words in the text
        '''
        return len(self.text_encoded)
        
    def __getitem__(self, idx):
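        ''' Returns the following training data for position idx:
            - the center word
            - the (positive) words within C positions of it
            - K * 2C randomly sampled words as negative samples
        '''
        center_word = self.text_encoded[idx]
        # C positions on each side; the modulo wraps around at the start and end of the text
        pos_indices = list(range(idx-C, idx)) + list(range(idx+1, idx+C+1))
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # K negatives per context word, drawn with replacement from word_freqs;
        # nothing prevents a "negative" from coinciding with the center or a context word
        neg_words = torch.multinomial(self.word_freqs, K * pos_words.shape[0], True)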
        
        return center_word, pos_words, neg_words

dataset = WordEmbeddingDataset(text, word_to_idx, idx_to_word, word_freqs, word_counts)
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
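`__getitem__` wraps context positions with `% len(self.text_encoded)`, so the first and last C words of the corpus borrow context from the other end of the text. The hypothetical helper below reproduces only that index arithmetic so the wrap-around is visible. With C = 3 and K = 100, each item holds 6 context indices and 600 negative indices.

```python
def context_indices(idx, C, n):
    """Positions of the 2*C context words around position idx in a text of n words,
    wrapped with modulo the same way as WordEmbeddingDataset.__getitem__."""
    pos = list(range(idx - C, idx)) + list(range(idx + 1, idx + C + 1))
    return [i % n for i in pos]

print(context_indices(5, 3, 10))  # [2, 3, 4, 6, 7, 8]
print(context_indices(0, 3, 10))  # [7, 8, 9, 1, 2, 3]: wraps to the end of the text
```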


In [15]:
# 定义pytorch模型
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        ''' Initialize the input and output embeddings
        '''
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        
        initrange = 0.5 / self.embed_size
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.out_embed.weight.data.uniform_(-initrange, initrange)
        
        
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.in_embed.weight.data.uniform_(-initrange, initrange)
        
        
    # Inputs: the center words, the correct (context) words, and the wrong (negative) words
    def forward(self, input_labels, pos_labels, neg_labels):
        '''
        input_labels: center words, [batch_size]
        pos_labels: words that appear in the center word's context window, [batch_size, (window_size * 2)]
        neg_labels: words drawn by negative sampling, [batch_size, (window_size * 2 * K)]

        return: loss, [batch_size]
        '''
        
        batch_size = input_labels.size(0)
        
        # Look up a vector for every index
        input_embedding = self.in_embed(input_labels)  # [B, embed_size]
        pos_embedding = self.out_embed(pos_labels)      # [B, 2*C, embed_size]
        neg_embedding = self.out_embed(neg_labels)      # [B, 2*C*K, embed_size]

        # Dot product of every context/negative vector with the center vector.
        # squeeze(2) instead of squeeze() so that a batch of size 1 keeps its batch dimension
        log_pos = torch.bmm(pos_embedding, input_embedding.unsqueeze(2)).squeeze(2)   # [B, 2*C]
        log_neg = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze(2)  # [B, 2*C*K]

        # log-sigmoid, then sum over dim 1 (the context / negative words)
        log_pos = F.logsigmoid(log_pos).sum(1)  # [B]
        log_neg = F.logsigmoid(log_neg).sum(1)  # [B]
       
        loss = log_pos + log_neg
        
        return -loss
    
    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()
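Written out, `forward` returns for each center word $c$ the loss below, where $\mathbf{v}$ denotes rows of `in_embed`, $\mathbf{u}$ denotes rows of `out_embed`, $o_1,\dots,o_{2C}$ are the context words and $n_1,\dots,n_{2CK}$ are the sampled negatives. Negating the input embedding before the second `bmm` is what turns $\log\sigma(\mathbf{u}^\top\mathbf{v})$ into $\log\sigma(-\mathbf{u}^\top\mathbf{v})$; the training loop then averages $\mathcal{L}(c)$ over the batch.

$$
\mathcal{L}(c) = -\sum_{j=1}^{2C} \log \sigma\left(\mathbf{u}_{o_j}^\top \mathbf{v}_c\right) - \sum_{k=1}^{2CK} \log \sigma\left(-\mathbf{u}_{n_k}^\top \mathbf{v}_c\right)
$$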


In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:")
print(device)

# Build the model. The original notebook ran on the GPU; moving it to `device` makes
# the same code run on either a GPU or a CPU-only machine
model = EmbeddingModel(VOCAB_SIZE, EMBEDDING_SIZE).to(device)

device:
cpu


In [17]:
# Evaluation code: measures how well the model's word similarities agree with human judgments
def evaluate(filename, embedding_weights):
    # .csv files are comma-separated; anything else is assumed to be tab-separated
    if filename.endswith(".csv"):
        data = pd.read_csv(filename, sep=",")
    else:
        data = pd.read_csv(filename, sep="\t")
    human_similarity = []
    model_similarity = []
    for i in range(len(data)):
        word1, word2 = data.iloc[i, 0], data.iloc[i, 1]
        # Skip pairs that contain a word outside the vocabulary
        if word1 not in word_to_idx or word2 not in word_to_idx:
            continue
        word1_embed = embedding_weights[[word_to_idx[word1]]]  # shape [1, embed_size]
        word2_embed = embedding_weights[[word_to_idx[word2]]]
        model_similarity.append(float(cosine_similarity(word1_embed, word2_embed)[0, 0]))
        human_similarity.append(float(data.iloc[i, 2]))

    # Spearman rank correlation between human scores and model cosine similarities
    return scipy.stats.spearmanr(human_similarity, model_similarity)

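def find_nearest(word):
    # Uses the global embedding_weights; returns the 10 words with the smallest
    # cosine distance to `word` (the word itself normally comes first)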
    index = word_to_idx[word]
    embedding = embedding_weights[index]
    cos_dis = np.array([scipy.spatial.distance.cosine(e, embedding) for e in embedding_weights])
    return [idx_to_word[i] for i in cos_dis.argsort()[:10]]
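`find_nearest` calls `scipy.spatial.distance.cosine` once per vocabulary word, i.e. 30,000 Python-level calls per query. As an alternative (not the original author's code), the same ranking can be computed with one matrix-vector product over row-normalized embeddings: ascending cosine distance is the same order as descending cosine similarity, so the results should match up to ties. The function name and the toy vectors are hypothetical.

```python
import numpy as np

def find_nearest_vectorized(word, word_to_idx, idx_to_word, embedding_weights, topn=10):
    """Top-n words by cosine similarity to `word` (the word itself normally ranks first)."""
    norms = np.linalg.norm(embedding_weights, axis=1, keepdims=True)
    unit = embedding_weights / np.maximum(norms, 1e-12)  # row-normalize; guard against zero rows
    sims = unit @ unit[word_to_idx[word]]                  # cosine similarity to every word
    return [idx_to_word[i] for i in np.argsort(-sims)[:topn]]

toy_words = ["cat", "dog", "car"]
toy_index = {w: i for i, w in enumerate(toy_words)}
toy_emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
nearest = find_nearest_vectorized("cat", toy_index, toy_words, toy_emb, topn=2)
print(nearest)  # ['cat', 'dog']
```

In the notebook it would be called as `find_nearest_vectorized("computer", word_to_idx, idx_to_word, embedding_weights)`.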


In [18]:
# Training

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(NUM_EPOCHS):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        # Move the batch to the same device as the model
        input_labels = input_labels.long().to(device)
        pos_labels = pos_labels.long().to(device)
        neg_labels = neg_labels.long().to(device)

        optimizer.zero_grad()
        # Average the per-sample losses over the batch
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            with open(LOG_FILE, "a") as fout:
                fout.write("epoch: {}, iter: {}, loss: {}\n".format(e, i, loss.item()))
                print("epoch: {}, iter: {}, loss: {}".format(e, i, loss.item()))
            
        
                
    # After each epoch, save the input embeddings (.npy) and the model state dict (.th)
    embedding_weights = model.input_embeddings()
    np.save("embedding-{}".format(EMBEDDING_SIZE), embedding_weights)
    torch.save(model.state_dict(), "embedding-{}.th".format(EMBEDDING_SIZE))


epoch: 0, iter: 0, loss: 420.04742431640625
epoch: 0, iter: 100, loss: 280.9455261230469
epoch: 0, iter: 200, loss: 212.32415771484375
epoch: 0, iter: 300, loss: 185.0164337158203
epoch: 0, iter: 400, loss: 141.21463012695312
epoch: 0, iter: 500, loss: 135.06698608398438
epoch: 0, iter: 600, loss: 138.35374450683594
epoch: 0, iter: 700, loss: 112.25232696533203
epoch: 0, iter: 800, loss: 144.89501953125
epoch: 0, iter: 900, loss: 94.30026245117188
epoch: 0, iter: 1000, loss: 93.244140625
epoch: 0, iter: 1100, loss: 97.6533203125
epoch: 0, iter: 1200, loss: 79.48892974853516
epoch: 0, iter: 1300, loss: 68.31893157958984
epoch: 0, iter: 1400, loss: 70.45720672607422
epoch: 0, iter: 1500, loss: 77.17955017089844
epoch: 0, iter: 1600, loss: 72.80960845947266
epoch: 0, iter: 1700, loss: 73.63546752929688
epoch: 0, iter: 1800, loss: 70.2663803100586
epoch: 0, iter: 1900, loss: 68.15692138671875
epoch: 0, iter: 2000, loss: 66.25592041015625
epoch: 0, iter: 2100, loss: 80.15396118164062
...

epoch: 0, iter: 17700, loss: 32.622859954833984
epoch: 0, iter: 17800, loss: 32.41484451293945
epoch: 0, iter: 17900, loss: 32.946754455566406
epoch: 0, iter: 18000, loss: 32.314414978027344
epoch: 0, iter: 18100, loss: 32.88111114501953
epoch: 0, iter: 18200, loss: 32.39853286743164
epoch: 0, iter: 18300, loss: 32.69651412963867
epoch: 0, iter: 18400, loss: 34.55544662475586
epoch: 0, iter: 18500, loss: 32.68207550048828
epoch: 0, iter: 18600, loss: 31.947376251220703
epoch: 0, iter: 18700, loss: 33.11784744262695
epoch: 0, iter: 18800, loss: 33.490089416503906
epoch: 0, iter: 18900, loss: 33.05300521850586
epoch: 0, iter: 19000, loss: 32.76217269897461
epoch: 0, iter: 19100, loss: 32.0904655456543
epoch: 0, iter: 19200, loss: 31.7235050201416
epoch: 0, iter: 19300, loss: 32.68367385864258
epoch: 0, iter: 19400, loss: 32.132896423339844
epoch: 0, iter: 19500, loss: 33.041786193847656
epoch: 0, iter: 19600, loss: 34.84003448486328
epoch: 0, iter: 19700, loss: 31.758575439453125
...

epoch: 0, iter: 35000, loss: 31.65372085571289
epoch: 0, iter: 35100, loss: 31.492177963256836
epoch: 0, iter: 35200, loss: 31.14501953125
epoch: 0, iter: 35300, loss: 31.604248046875
epoch: 0, iter: 35400, loss: 31.72936248779297
epoch: 0, iter: 35500, loss: 31.5876522064209
epoch: 0, iter: 35600, loss: 31.817596435546875
epoch: 0, iter: 35700, loss: 31.336750030517578
epoch: 0, iter: 35800, loss: 31.887115478515625
epoch: 0, iter: 35900, loss: 31.254798889160156
epoch: 0, iter: 36000, loss: 31.499473571777344
epoch: 0, iter: 36100, loss: 31.621349334716797
epoch: 0, iter: 36200, loss: 31.67698860168457
epoch: 0, iter: 36300, loss: 31.65187644958496
epoch: 0, iter: 36400, loss: 31.797264099121094
epoch: 0, iter: 36500, loss: 31.74847412109375
epoch: 0, iter: 36600, loss: 31.820606231689453
epoch: 0, iter: 36700, loss: 31.831567764282227
epoch: 0, iter: 36800, loss: 31.206192016601562
epoch: 0, iter: 36900, loss: 31.717723846435547
epoch: 0, iter: 37000, loss: 31.4779109954834
...

epoch: 0, iter: 52300, loss: 31.150901794433594
epoch: 0, iter: 52400, loss: 30.906295776367188
epoch: 0, iter: 52500, loss: 31.07915496826172
epoch: 0, iter: 52600, loss: 31.541339874267578
epoch: 0, iter: 52700, loss: 31.333927154541016
epoch: 0, iter: 52800, loss: 30.98295783996582
epoch: 0, iter: 52900, loss: 30.94872283935547
epoch: 0, iter: 53000, loss: 30.993457794189453
epoch: 0, iter: 53100, loss: 31.432682037353516
epoch: 0, iter: 53200, loss: 31.11025047302246
epoch: 0, iter: 53300, loss: 30.78118133544922
epoch: 0, iter: 53400, loss: 31.188953399658203
epoch: 0, iter: 53500, loss: 31.29664421081543
epoch: 0, iter: 53600, loss: 31.00973129272461
epoch: 0, iter: 53700, loss: 31.332082748413086
epoch: 0, iter: 53800, loss: 30.618459701538086
epoch: 0, iter: 53900, loss: 30.951351165771484
epoch: 0, iter: 54000, loss: 31.382783889770508
epoch: 0, iter: 54100, loss: 30.48594093322754
epoch: 0, iter: 54200, loss: 30.927156448364258
epoch: 0, iter: 54300, loss: 31.086322784423828


epoch: 0, iter: 69600, loss: 30.973575592041016
epoch: 0, iter: 69700, loss: 30.55046272277832
epoch: 0, iter: 69800, loss: 30.571338653564453
epoch: 0, iter: 69900, loss: 31.340660095214844
epoch: 0, iter: 70000, loss: 31.16770362854004
epoch: 0, iter: 70100, loss: 31.178417205810547
epoch: 0, iter: 70200, loss: 30.85662269592285
epoch: 0, iter: 70300, loss: 30.93984603881836
epoch: 0, iter: 70400, loss: 31.12394142150879
epoch: 0, iter: 70500, loss: 30.65265655517578
epoch: 0, iter: 70600, loss: 30.738887786865234
epoch: 0, iter: 70700, loss: 31.22895622253418
epoch: 0, iter: 70800, loss: 31.20877456665039
epoch: 0, iter: 70900, loss: 30.531333923339844
epoch: 0, iter: 71000, loss: 30.64777374267578
epoch: 0, iter: 71100, loss: 30.633207321166992
epoch: 0, iter: 71200, loss: 31.328046798706055
epoch: 0, iter: 71300, loss: 31.190513610839844
epoch: 0, iter: 71400, loss: 30.602359771728516
epoch: 0, iter: 71500, loss: 31.03752326965332
epoch: 0, iter: 71600, loss: 30.312740325927734
...

epoch: 0, iter: 86900, loss: 31.112844467163086
epoch: 0, iter: 87000, loss: 30.646249771118164
epoch: 0, iter: 87100, loss: 30.514963150024414
epoch: 0, iter: 87200, loss: 30.814855575561523
epoch: 0, iter: 87300, loss: 30.547473907470703
epoch: 0, iter: 87400, loss: 31.03626823425293
epoch: 0, iter: 87500, loss: 30.531963348388672
epoch: 0, iter: 87600, loss: 30.970291137695312
epoch: 0, iter: 87700, loss: 31.113658905029297
epoch: 0, iter: 87800, loss: 30.70583152770996
epoch: 0, iter: 87900, loss: 30.231983184814453
epoch: 0, iter: 88000, loss: 31.034385681152344
epoch: 0, iter: 88100, loss: 31.26152801513672
epoch: 0, iter: 88200, loss: 30.556272506713867
epoch: 0, iter: 88300, loss: 30.37435531616211
epoch: 0, iter: 88400, loss: 30.888999938964844
epoch: 0, iter: 88500, loss: 31.43544578552246
epoch: 0, iter: 88600, loss: 30.379802703857422
epoch: 0, iter: 88700, loss: 30.733320236206055
epoch: 0, iter: 88800, loss: 30.518768310546875
epoch: 0, iter: 88900, loss: 30.7299919128417

epoch: 0, iter: 104100, loss: 30.34078598022461
epoch: 0, iter: 104200, loss: 30.904800415039062
epoch: 0, iter: 104300, loss: 31.234813690185547
epoch: 0, iter: 104400, loss: 30.50078773498535
epoch: 0, iter: 104500, loss: 30.261173248291016
epoch: 0, iter: 104600, loss: 30.976444244384766
epoch: 0, iter: 104700, loss: 30.89337730407715
epoch: 0, iter: 104800, loss: 30.80858612060547
epoch: 0, iter: 104900, loss: 31.038652420043945
epoch: 0, iter: 105000, loss: 31.101478576660156
epoch: 0, iter: 105100, loss: 30.45868682861328
epoch: 0, iter: 105200, loss: 30.741971969604492
epoch: 0, iter: 105300, loss: 30.977130889892578
epoch: 0, iter: 105400, loss: 31.44770622253418
epoch: 0, iter: 105500, loss: 30.89352035522461
epoch: 0, iter: 105600, loss: 31.16472816467285
epoch: 0, iter: 105700, loss: 30.47923469543457
epoch: 0, iter: 105800, loss: 30.902729034423828
epoch: 0, iter: 105900, loss: 30.82274627685547
epoch: 0, iter: 106000, loss: 30.664066314697266
...

epoch: 1, iter: 1500, loss: 31.255069732666016
epoch: 1, iter: 1600, loss: 30.734466552734375
epoch: 1, iter: 1700, loss: 30.547897338867188
epoch: 1, iter: 1800, loss: 30.663860321044922
epoch: 1, iter: 1900, loss: 30.66433334350586
epoch: 1, iter: 2000, loss: 30.804462432861328
epoch: 1, iter: 2100, loss: 30.185583114624023
epoch: 1, iter: 2200, loss: 30.620018005371094
epoch: 1, iter: 2300, loss: 31.070510864257812
epoch: 1, iter: 2400, loss: 30.151710510253906
epoch: 1, iter: 2500, loss: 30.3049259185791
epoch: 1, iter: 2600, loss: 30.486413955688477
epoch: 1, iter: 2700, loss: 30.44532012939453
epoch: 1, iter: 2800, loss: 30.69796371459961
epoch: 1, iter: 2900, loss: 30.20358657836914
epoch: 1, iter: 3000, loss: 30.52279281616211
epoch: 1, iter: 3100, loss: 30.76050567626953
epoch: 1, iter: 3200, loss: 30.07227325439453
epoch: 1, iter: 3300, loss: 30.465713500976562
epoch: 1, iter: 3400, loss: 30.810346603393555
epoch: 1, iter: 3500, loss: 30.72740364074707
...

epoch: 1, iter: 18900, loss: 30.538433074951172
epoch: 1, iter: 19000, loss: 30.877334594726562
epoch: 1, iter: 19100, loss: 31.18408966064453
epoch: 1, iter: 19200, loss: 29.835643768310547
epoch: 1, iter: 19300, loss: 30.767580032348633
epoch: 1, iter: 19400, loss: 30.22886848449707
epoch: 1, iter: 19500, loss: 30.625030517578125
epoch: 1, iter: 19600, loss: 30.608631134033203
epoch: 1, iter: 19700, loss: 29.807720184326172
epoch: 1, iter: 19800, loss: 30.422557830810547
epoch: 1, iter: 19900, loss: 30.826189041137695
epoch: 1, iter: 20000, loss: 30.36678123474121
epoch: 1, iter: 20100, loss: 30.468250274658203
epoch: 1, iter: 20200, loss: 31.115556716918945
epoch: 1, iter: 20300, loss: 31.02297019958496
epoch: 1, iter: 20400, loss: 30.46123695373535
epoch: 1, iter: 20500, loss: 30.97983741760254
epoch: 1, iter: 20600, loss: 30.242549896240234
epoch: 1, iter: 20700, loss: 30.25586700439453
epoch: 1, iter: 20800, loss: 30.98520278930664
epoch: 1, iter: 20900, loss: 30.591732025146484


epoch: 1, iter: 36200, loss: 30.946800231933594
epoch: 1, iter: 36300, loss: 30.634357452392578
epoch: 1, iter: 36400, loss: 30.616987228393555
epoch: 1, iter: 36500, loss: 30.4654598236084
epoch: 1, iter: 36600, loss: 30.1270809173584
epoch: 1, iter: 36700, loss: 30.501789093017578
epoch: 1, iter: 36800, loss: 29.987722396850586
epoch: 1, iter: 36900, loss: 30.554763793945312
epoch: 1, iter: 37000, loss: 30.672256469726562
epoch: 1, iter: 37100, loss: 30.579334259033203
epoch: 1, iter: 37200, loss: 30.264326095581055
epoch: 1, iter: 37300, loss: 30.972774505615234
epoch: 1, iter: 37400, loss: 30.168289184570312
epoch: 1, iter: 37500, loss: 30.339183807373047
epoch: 1, iter: 37600, loss: 30.45969581604004
epoch: 1, iter: 37700, loss: 30.209287643432617
epoch: 1, iter: 37800, loss: 30.717679977416992
epoch: 1, iter: 37900, loss: 30.499910354614258
epoch: 1, iter: 38000, loss: 29.963001251220703
epoch: 1, iter: 38100, loss: 29.734071731567383
epoch: 1, iter: 38200, loss: 30.4552707672119

epoch: 1, iter: 53400, loss: 30.636764526367188
epoch: 1, iter: 53500, loss: 30.41672706604004
epoch: 1, iter: 53600, loss: 30.121938705444336
epoch: 1, iter: 53700, loss: 30.12906265258789
epoch: 1, iter: 53800, loss: 30.263370513916016
epoch: 1, iter: 53900, loss: 30.39374542236328
epoch: 1, iter: 54000, loss: 30.63083267211914
epoch: 1, iter: 54100, loss: 30.913244247436523
epoch: 1, iter: 54200, loss: 30.351009368896484
epoch: 1, iter: 54300, loss: 30.269506454467773
epoch: 1, iter: 54400, loss: 30.370874404907227
epoch: 1, iter: 54500, loss: 30.889720916748047
epoch: 1, iter: 54600, loss: 30.222816467285156
epoch: 1, iter: 54700, loss: 30.6479549407959
epoch: 1, iter: 54800, loss: 30.750118255615234
epoch: 1, iter: 54900, loss: 30.138599395751953
epoch: 1, iter: 55000, loss: 30.798246383666992
epoch: 1, iter: 55100, loss: 30.780776977539062
epoch: 1, iter: 55200, loss: 30.994680404663086
epoch: 1, iter: 55300, loss: 30.043785095214844
epoch: 1, iter: 55400, loss: 30.59393310546875

epoch: 1, iter: 70700, loss: 30.16600799560547
epoch: 1, iter: 70800, loss: 31.183349609375
epoch: 1, iter: 70900, loss: 30.591449737548828
epoch: 1, iter: 71000, loss: 30.225963592529297
epoch: 1, iter: 71100, loss: 30.788251876831055
epoch: 1, iter: 71200, loss: 29.75831413269043
epoch: 1, iter: 71300, loss: 30.115829467773438
epoch: 1, iter: 71400, loss: 30.414857864379883
epoch: 1, iter: 71500, loss: 30.734912872314453
epoch: 1, iter: 71600, loss: 30.05133056640625
epoch: 1, iter: 71700, loss: 29.749170303344727
epoch: 1, iter: 71800, loss: 30.840194702148438
epoch: 1, iter: 71900, loss: 30.2999324798584
epoch: 1, iter: 72000, loss: 30.25675392150879
epoch: 1, iter: 72100, loss: 30.026107788085938
epoch: 1, iter: 72200, loss: 30.221410751342773
epoch: 1, iter: 72300, loss: 31.193042755126953
epoch: 1, iter: 72400, loss: 30.52434730529785
epoch: 1, iter: 72500, loss: 30.21171760559082
epoch: 1, iter: 72600, loss: 30.864728927612305
epoch: 1, iter: 72700, loss: 29.95269203186035
...

epoch: 1, iter: 88000, loss: 30.622028350830078
epoch: 1, iter: 88100, loss: 30.535669326782227
epoch: 1, iter: 88200, loss: 30.221040725708008
epoch: 1, iter: 88300, loss: 30.317367553710938
epoch: 1, iter: 88400, loss: 30.19761848449707
epoch: 1, iter: 88500, loss: 30.48804473876953
epoch: 1, iter: 88600, loss: 30.860395431518555
epoch: 1, iter: 88700, loss: 31.044105529785156
epoch: 1, iter: 88800, loss: 30.652788162231445
epoch: 1, iter: 88900, loss: 30.680356979370117
epoch: 1, iter: 89000, loss: 30.222814559936523
epoch: 1, iter: 89100, loss: 30.2125244140625
epoch: 1, iter: 89200, loss: 30.532150268554688
epoch: 1, iter: 89300, loss: 30.421567916870117
epoch: 1, iter: 89400, loss: 30.806224822998047
epoch: 1, iter: 89500, loss: 30.464120864868164
epoch: 1, iter: 89600, loss: 30.079832077026367
epoch: 1, iter: 89700, loss: 30.178625106811523
epoch: 1, iter: 89800, loss: 30.632232666015625
epoch: 1, iter: 89900, loss: 30.290834426879883
epoch: 1, iter: 90000, loss: 30.194770812988

epoch: 1, iter: 105200, loss: 30.59814453125
epoch: 1, iter: 105300, loss: 31.09385108947754
epoch: 1, iter: 105400, loss: 30.71389389038086
epoch: 1, iter: 105500, loss: 30.132606506347656
epoch: 1, iter: 105600, loss: 30.67302703857422
epoch: 1, iter: 105700, loss: 30.170761108398438
epoch: 1, iter: 105800, loss: 30.473005294799805
epoch: 1, iter: 105900, loss: 30.78424835205078
epoch: 1, iter: 106000, loss: 30.167783737182617
epoch: 1, iter: 106100, loss: 30.344383239746094
epoch: 1, iter: 106200, loss: 30.43816375732422
epoch: 1, iter: 106300, loss: 29.755826950073242
epoch: 1, iter: 106400, loss: 30.154783248901367
epoch: 1, iter: 106500, loss: 30.50189208984375
epoch: 1, iter: 106600, loss: 30.070846557617188
epoch: 1, iter: 106700, loss: 30.227800369262695
epoch: 1, iter: 106800, loss: 30.58196449279785
epoch: 1, iter: 106900, loss: 30.158279418945312
epoch: 1, iter: 107000, loss: 29.916120529174805
epoch: 1, iter: 107100, loss: 30.114656448364258
...
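One sanity check on the log: the very first loss can be predicted. With `initrange = 0.5 / EMBEDDING_SIZE`, all dot products start near 0, so every log-sigmoid term is about $\log\sigma(0) = -\log 2$. Each center word contributes $2C$ positive terms and $2CK$ negative terms, so the initial per-sample loss should be close to $(2C + 2CK)\log 2$.

```python
import math

n_context, n_negative = 3, 100                        # C and K from the hyperparameters
n_terms = 2 * n_context + 2 * n_context * n_negative   # 6 positive + 600 negative log-sigmoid terms
initial_loss = n_terms * math.log(2)
print(round(initial_loss, 3))  # 420.047
```

This agrees with the logged `epoch: 0, iter: 0, loss: 420.047...`; the tiny remaining difference comes from the small but nonzero initial dot products.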

In [19]:
# Once training is finished, load the saved weights and check which words end up closest to a few probe words
model.load_state_dict(torch.load("embedding-{}.th".format(EMBEDDING_SIZE), map_location=device))
embedding_weights = model.input_embeddings()
for word in ["good", "fresh", "monster", "green", "like", "america", "chicago", "work", "computer", "language"]:
    print(word, find_nearest(word))


good ['good', 'bad', 'alone', 'perfect', 'experience', 'hard', 'truth', 'money', 'really', 'poor']
fresh ['fresh', 'grain', 'dense', 'lighter', 'waste', 'noise', 'cooling', 'sized', 'mild', 'rigid']
monster ['monster', 'giant', 'robot', 'blade', 'stone', 'clown', 'hammer', 'bull', 'ghost', 'finger']
green ['green', 'blue', 'yellow', 'white', 'cross', 'orange', 'red', 'black', 'snow', 'mountain']
like ['like', 'etc', 'unlike', 'animals', 'amongst', 'soft', 'whereas', 'rich', 'eat', 'similarly']
america ['america', 'africa', 'korea', 'india', 'turkey', 'australia', 'pakistan', 'indian', 'argentina', 'carolina']
chicago ['chicago', 'boston', 'london', 'texas', 'illinois', 'indiana', 'massachusetts', 'florida', 'ohio', 'pennsylvania']
work ['work', 'writing', 'job', 'marx', 'solo', 'nietzsche', 'writings', 'vision', 'appearance', 'songs']
computer ['computer', 'digital', 'audio', 'electronic', 'video', 'graphics', 'hardware', 'computers', 'software', 'program']
language ['language', 'langu