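## Word Vectors
- Learn the concept of word vectors
- Train word vectors with the Skip-Gram model
- Learn how to define a PyTorch model
- Learn common Modules in torch.nn
    - Embedding
- Learn common PyTorch operations
    - bmm
    - logsigmoid
- Save and load PyTorch models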

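**Distributed representation**
Represent a word by the other words that appear near it. This is one of the most original ideas in modern statistical natural language processing.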

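**Word2Vec: the Skip-Gram model**
The input is a center word w(t), where t is the word's position in the text. A one-layer neural network predicts the words around it: w(t-2), w(t-1), w(t+1), w(t+2). Predicting the surrounding words is not the real goal. It is a pretext task, and what we actually want are the learned parameters, which serve as the word vectors.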
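
To make the prediction task concrete, here is a minimal sketch that lists the (center word, surrounding word) training pairs for a short sentence. The helper `skipgram_pairs` and the example sentence are hypothetical and not part of the notebook. This sketch clips the window at the sentence boundaries, whereas the dataset class later in the notebook wraps around the text instead.

```python
def skipgram_pairs(tokens, window):
    """Return (center, context) pairs for every position in `tokens`."""
    pairs = []
    for t, center in enumerate(tokens):
        # Context positions t-window .. t+window, excluding t itself,
        # clipped to the sentence boundaries (an assumption of this sketch).
        lo = max(0, t - window)
        hi = min(len(tokens), t + window + 1)
        for j in range(lo, hi):
            if j != t:
                pairs.append((center, tokens[j]))
    return pairs


tokens = "the quick brown fox jumps".split()
pairs = skipgram_pairs(tokens, window=2)
brown_pairs = [p for p in pairs if p[0] == "brown"]
print(brown_pairs)
```

With a window of 2, the center word "brown" is paired with "the", "quick", "fox", and "jumps". Each such pair is one positive example for the model.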

![Skip-Gram](images/1.png)
![Skip-Gram-Loss](images/2.png)

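If we follow the model formula directly, then with a vocabulary of, say, 50,000 words, the denominator of the objective sums over every word in the vocabulary and becomes very expensive to compute. The objective is therefore modified to use negative sampling. Given a correct center word and a correct surrounding word, we want their score to be as high as possible. Given a center word and some incorrect surrounding words, we want their probability to be as low as possible.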

![Skip-Gram negative sampling](images/3.png)

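The second expression is the objective function, which should be as large as possible. Its first term scores the center word against a word that actually appears near it. Its second term scores the center word against k randomly sampled words that are not near it. Because the dot product in the second term carries a minus sign, that term should also be as large as possible. This fixes the training objective, and the parameters are trained against it.

The function applied inside the log is the sigmoid.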
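
The formulas themselves live in the images above and are not reproduced in the text. Assuming the images show the standard word2vec formulation, they can be written as follows. The symbols are assumptions of this sketch: $v_c$ is the center-word (input) vector, $u_o$ is the vector of a true surrounding word, $u_{w_j}$ are the vectors of sampled negative words, $V$ is the vocabulary size, and $P(w)$ is the noise distribution used for sampling.

$$
p(o \mid c) = \frac{\exp\left(u_o^{\top} v_c\right)}{\sum_{w=1}^{V} \exp\left(u_w^{\top} v_c\right)}
$$

$$
J(\theta) = \log \sigma\left(u_o^{\top} v_c\right) + \sum_{j=1}^{k} \mathbb{E}_{w_j \sim P(w)}\left[\log \sigma\left(-u_{w_j}^{\top} v_c\right)\right],
\qquad \sigma(x) = \frac{1}{1 + e^{-x}}
$$

The first formula needs $V$ dot products for its denominator, while the second needs only $k + 1$ per training pair, which is why negative sampling is so much cheaper.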

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud

from collections import Counter
import random
import math
import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

USE_CUDA = torch.cuda.is_available()

random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
if USE_CUDA:
    torch.cuda.manual_seed(1)

# Set some hyperparameters
C = 3 # context window
K = 100 # number of negative samples
NUM_EPOCHS = 2
MAX_VOCAB_SIZE = 30000
BATCH_SIZE = 128
LEARNING_RATE = 0.2
EMBEDDING_SIZE = 100

def word_tokenize(text):
    return text.split()

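- Read all the text from the file and build a vocabulary from it
- The number of distinct words may be too large, so keep only the most common ones, giving a vocabulary of MAX_VOCAB_SIZE entries in total
- Add an UNK token that stands for all the uncommon words
- We need to record the word-to-index mapping, the index-to-word mapping, the count of each word, the (normalized) frequency of each word, and the total number of words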
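
The steps above can be checked on a toy text before running them on the full corpus. The sketch below uses the same `Counter(...).most_common(...)` pattern as the notebook. The toy text and the `toy_`-prefixed names are hypothetical and chosen only for illustration.

```python
from collections import Counter

toy_text = "a b a c a b d e".split()
max_vocab_size = 3  # hypothetical small value for illustration

# Keep the (max_vocab_size - 1) most frequent words; the last slot is <unk>.
toy_vocab = dict(Counter(toy_text).most_common(max_vocab_size - 1))
# Every token that did not make the cut is counted under <unk>.
toy_vocab["<unk>"] = len(toy_text) - sum(toy_vocab.values())

toy_idx_to_word = list(toy_vocab.keys())
toy_word_to_idx = {word: i for i, word in enumerate(toy_idx_to_word)}

# Encode the text; unknown words fall back to the <unk> index.
encoded = [toy_word_to_idx.get(w, toy_word_to_idx["<unk>"]) for w in toy_text]
print(toy_vocab)
print(encoded)
```

Here "c", "d", and "e" are outside the vocabulary, so all three share the `<unk>` index and their occurrences add up to the `<unk>` count.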

In [2]:
with open("text8/text8.train.txt", "r") as fin:
    text = fin.read()
    
text = text.split()

# The MAX_VOCAB_SIZE-1 most frequent words, leaving one slot for <unk> (uncommon words)
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
vocab["<unk>"] = len(text) - np.sum(list(vocab.values()))

idx_to_word = [word for word in vocab.keys()]
word_to_idx = {word:i for i, word in enumerate(idx_to_word)}
word_to_idx

{'the': 0,
 'of': 1,
 'and': 2,
 'one': 3,
 'in': 4,
 'a': 5,
 'to': 6,
 'zero': 7,
 'nine': 8,
 'two': 9,
 'is': 10,
 'as': 11,
 'eight': 12,
 'for': 13,
 's': 14,
 'five': 15,
 'three': 16,
 'was': 17,
 'by': 18,
 'that': 19,
 'four': 20,
 'six': 21,
 'seven': 22,
 'with': 23,
 'on': 24,
 'are': 25,
 'it': 26,
 'from': 27,
 'or': 28,
 'his': 29,
 'an': 30,
 'be': 31,
 'this': 32,
 'which': 33,
 'at': 34,
 'he': 35,
 'not': 36,
 'also': 37,
 'have': 38,
 'were': 39,
 'has': 40,
 'but': 41,
 'other': 42,
 'their': 43,
 'its': 44,
 'they': 45,
 'first': 46,
 'some': 47,
 'had': 48,
 'more': 49,
 'all': 50,
 'can': 51,
 'most': 52,
 'been': 53,
 'such': 54,
 'many': 55,
 'who': 56,
 'new': 57,
 'there': 58,
 'used': 59,
 'after': 60,
 'when': 61,
 'time': 62,
 'into': 63,
 'these': 64,
 'only': 65,
 'american': 66,
 'see': 67,
 'may': 68,
 'than': 69,
 'i': 70,
 'world': 71,
 'would': 72,
 'b': 73,
 'no': 74,
 'd': 75,
 'however': 76,
 'between': 77,
 'about': 78,
 'over': 79,
 'states':

In [3]:
word_counts = np.array([count for count in vocab.values()], dtype = np.float32)
word_freqs = word_counts / np.sum(word_counts)
print(word_counts)
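# Raise every probability to the 3/4 power, then renormalize.
# This shifts some probability mass from high-frequency words to low-frequency words,
# because the 3/4 power scales small probabilities up by a larger factor than large ones.
word_freqs = word_freqs ** (3./4.)

# After the 3/4 power the frequencies no longer sum to 1, so renormalize
word_freqs = word_freqs / np.sum(word_freqs)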
VOCAB_SIZE = len(idx_to_word)


[9.59616e+05 5.37144e+05 3.76233e+05 ... 2.00000e+01 2.00000e+01
 6.17240e+05]
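
The effect of the 3/4 power is easier to see on a small made-up distribution. The counts below are hypothetical and are not taken from text8.

```python
import numpy as np

counts = np.array([900.0, 90.0, 10.0])  # hypothetical word counts
freqs = counts / counts.sum()           # plain unigram frequencies

# Same transformation as in the notebook: power 3/4, then renormalize.
smoothed = freqs ** 0.75
smoothed = smoothed / smoothed.sum()

print(freqs)
print(smoothed)
```

The most frequent word's share drops from 0.9 to about 0.83, while the rarest word's share rises from 0.01 to about 0.028, so rare words are drawn as negatives more often than their raw frequency would suggest.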


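### Implementing the Dataloader

A dataloader can generate batches automatically.

A dataloader needs the following:
- Encode all the text as numbers, then preprocess it with subsampling
- Store the vocabulary, word counts, and normalized word frequencies
- Sample one center word per iteration
- Return the context words for the current center word
- Sample some negative words for the center word
- Return the word counts

To use a dataloader, the dataset must define the following two functions:
-  ```__len__``` returns the number of items in the whole dataset
-  ```__getitem__``` returns one item for a given index

With a dataloader, it is easy to shuffle the whole dataset, fetch one batch of data, and so on.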


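### Dataset and Dataloader

In short, a class represents the dataset abstractly, and the Dataloader acts as an iterator that produces one batch of data at a time, which saves memory.

Dataset is the abstract class PyTorch uses to represent a dataset. Our dataset can be represented by subclassing it and overriding at least the following two methods:
- ```__len__```: the size of the dataset
- ```__getitem__```: once this method is implemented, the i-th item can be fetched by index (dataset[i])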
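
Before looking at the word-embedding dataset, the two methods can be tried on a toy example. `SquaresDataset` below is a hypothetical class invented for this sketch. It only shows how `__len__` and `__getitem__` interact with `DataLoader`.

```python
import torch
import torch.utils.data as tud


class SquaresDataset(tud.Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i * i)."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __len__(self):
        # Number of items in the dataset
        return self.n

    def __getitem__(self, idx):
        # One item: a tuple of two scalar tensors
        return torch.tensor(idx), torch.tensor(idx * idx)


toy_dataset = SquaresDataset(10)
toy_loader = tud.DataLoader(toy_dataset, batch_size=4, shuffle=False)

batches = list(toy_loader)
xs, ys = batches[0]  # each tuple field is stacked into its own batch tensor
print(len(toy_dataset), len(batches))
print(xs, ys)
```

Ten items with a batch size of 4 give three batches, the last one holding only two items. Each field of the returned tuple is batched separately, which is why the training loop later can unpack a batch into three tensors.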

In [27]:
class WordEmbeddingDataset(tud.Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts):
        super(WordEmbeddingDataset, self).__init__()
        self.text_encoded = [word_to_idx.get(word, word_to_idx["<unk>"]) for word in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freqs)
        self.word_counts = torch.Tensor(word_counts)
        
    def __len__(self):
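        # How many items the dataset contains
        return len(self.text_encoded)

    def __getitem__(self, idx):
        # First, look up the center word at position idx
        center_word = self.text_encoded[idx]

        # Positions of the surrounding words, C on each side of idx
        pos_indices = list(range(idx-C, idx)) + list(range(idx+1, idx+C+1))

        # The modulo wraps negative or too-large positions around to the other end of the text
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]

        pos_words = self.text_encoded[pos_indices]

        # Negative sampling: draw random words to act as non-surrounding words
        # Sample vocabulary indices according to word_freqs, K negatives for each positive word
        neg_words = torch.multinomial(self.word_freqs, K * pos_words.shape[0], True)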
        
        return center_word, pos_words, neg_words
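
Two details of `__getitem__` are worth checking in isolation: the wrap-around index arithmetic and the behavior of `torch.multinomial`. The sizes, weights, and the helper `context_indices` below are hypothetical toy values for this sketch.

```python
import torch

n_tokens, window = 10, 3  # hypothetical toy sizes


def context_indices(idx):
    # Same index arithmetic as __getitem__: positions outside [0, n_tokens)
    # wrap around to the other end of the text.
    raw = list(range(idx - window, idx)) + list(range(idx + 1, idx + window + 1))
    return [i % n_tokens for i in raw]


first = context_indices(0)
last = context_indices(9)
print(first)
print(last)

# Negative sampling: indices drawn with replacement, in proportion to the weights.
torch.manual_seed(0)
weights = torch.tensor([0.5, 0.3, 0.2, 0.0])
neg = torch.multinomial(weights, 12, replacement=True)
print(neg.shape)
```

For the first token the left context is taken from the end of the text, and for the last token the right context comes from the beginning. Note also that nothing in the sampling step excludes the true surrounding words: negatives are simply drawn from the noise distribution, so an occasional collision with a positive word is possible.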

Create the dataset and dataloader

In [35]:
dataset = WordEmbeddingDataset(text, word_to_idx, idx_to_word, word_freqs, word_counts)
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle = True, num_workers = 0)



In [42]:
dataset.text_encoded[:100]

tensor([ 4813,  3139,    11,     5,   194,     1,  3015,    46,    59,   155,
          127,   741,   461, 10485,   133,     0, 25752,     1,     0,   108,
          833,     2,     0, 16267, 29999,     1,     0,   152,   833,  3493,
            0,   194,    10,   186,    59,     4,     5, 10620,   213,     6,
         1332,   102,   437,    19,    59,  2764,   355,     6,  3625,     0,
          709,     1,   364,    26,    40,    37,    53,   527,    97,    11,
            5,  1398,  2929,    18,   562,   691,  6644,     0,   252,  4813,
           10,  1043,    27,     0,   316,   247, 29999,  2964,   789,   189,
         4813,    11,     5,   201,   569,    10,     0,  1107,    19,  2581,
           25,  8819,     2,   273,    31,  4089,   140,    58,    25,  6494])

In [36]:
next(iter(dataloader))

[tensor([ 2233,     4,   561,   119,  1586,  8470,   390,    15,  1390,   285,
             1,    21,    47, 29999,     1,  3021,  1247,   424,     4,  2663,
           408,     0,     0,   140,    11,    15,  1167,  2480,     3,   784,
          1665, 12015, 13367,     4,    16,  1764,  4905,    42,    71,     0,
           113,     7,    16,    45,  1086,   758,    95,  8825,     2,  1421,
          1465,     2,   706,  3333,    18,     0,   195,  4586,    18,    23,
            27,    27,    11,   522,  2629,    18,  1385,   156,  7448,     8,
           115,  1392,     0,  1077,   339,  2669,  3677,    15,    36,  3990,
          2588,    26,    88,  9680,    13,  3337,     1,    11,     6,  2965,
             6,  1255, 12705,     8,  4866,     0,  4560,   336,   376,    31,
         29999,    10,   728,  1730,  3385,  2609, 19043,   484,    15,    55,
          1057,    45,     0,    30,  2598,  2986,     8,    38,   353,  2336,
          5195, 29999,  1218,   196,     7,   808,  

### Defining the PyTorch Model


In [72]:
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        
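        # nn.Embedding
        # Given an index, the embedding layer returns the embedding vector for that index
        # Input is a tensor of indices; output is the corresponding embedding vectors
        # num_embeddings: size of the vocabulary
        # embedding_dim: dimension of each embedding vector, i.e. how many dimensions represent one symbol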
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)
    
    def forward(self, input_labels, pos_labels, neg_labels):
        # input_labels: [batch_size] center words
        # pos_labels: [batch_size, (window_size * 2)] surrounding words
        # neg_labels: [batch_size, (window_size * 2 * K)] negative (non-surrounding) words

        # Each batch holds BATCH_SIZE (128) word indices; embed each index as a vector
        input_embedding = self.in_embed(input_labels) # [batch_size, embed_size]
        pos_embedding = self.out_embed(pos_labels) # [batch_size, (window_size * 2), embed_size]
        neg_embedding = self.out_embed(neg_labels) # [batch_size, (window_size * 2 * K), embed_size]

        input_embedding = input_embedding.unsqueeze(2) # add a dimension: [batch_size, embed_size, 1]

        # The loss needs the dot product of the center vector with each context vector
        # bmm: (b,n,m) x (b,m,p) = (b,n,p)
        pos_dot = torch.bmm(pos_embedding, input_embedding).squeeze(2) # [batch_size, (window_size * 2)]
        neg_dot = torch.bmm(neg_embedding, -input_embedding).squeeze(2) # [batch_size, (window_size * 2 * K)]

        log_pos = F.logsigmoid(pos_dot).sum(1) # sum over dimension 1 (the context words)
        log_neg = F.logsigmoid(neg_dot).sum(1)

        loss = log_pos + log_neg

        # At this point `loss` is the objective to maximize, not a loss; return its negative
        return -loss # [batch_size]

    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()

Create a model and move it to the GPU (if one is available)

In [73]:
model = EmbeddingModel(VOCAB_SIZE, EMBEDDING_SIZE)
if USE_CUDA:
    model = model.cuda()
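
The three building blocks used in `forward` (`nn.Embedding`, `torch.bmm`, and `F.logsigmoid`) can be checked on small tensors. The sizes below are hypothetical toy values, and the comparison against an explicit product-and-sum is only there to confirm what `bmm` computes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, n_ctx, dim, vocab = 2, 6, 4, 10  # hypothetical toy sizes
torch.manual_seed(0)

embed = nn.Embedding(vocab, dim)
center = embed(torch.tensor([1, 3]))                      # [batch, dim]
context = embed(torch.randint(0, vocab, (batch, n_ctx)))  # [batch, n_ctx, dim]

# bmm multiplies matching batch entries: (b, n, m) x (b, m, p) -> (b, n, p)
dots = torch.bmm(context, center.unsqueeze(2)).squeeze(2)  # [batch, n_ctx]

# The same dot products written as an elementwise product followed by a sum.
manual = (context * center.unsqueeze(1)).sum(dim=2)

# logsigmoid(x) is log(sigmoid(x)), computed in a numerically stable way.
log_probs = F.logsigmoid(dots)
print(dots.shape, log_probs.shape)
```

Each row of `dots` holds the dot products between one center vector and its context vectors. Since the sigmoid is always below 1, every entry of `log_probs` is negative, which is why the model negates the summed objective to obtain a positive loss.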

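Below is the code for evaluating the model and the code for training it.

Training the model:
- A model usually needs to be trained for several epochs
- In each epoch, all the data is split into batches
- The tensors of each batch are moved to the GPU when CUDA is available
- Forward pass: compute the negative-sampling loss from the center words, surrounding words, and negative words
- Average the per-example losses over the batch
- Clear the model's current gradients
- Backward pass
- Update the model parameters
- Every fixed number of iterations, print the loss at the current iteration (the loop below prints every 100 iterations and does not run a validation pass)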
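
The order of the gradient steps in this list (clear gradients, forward pass, backward pass, parameter update) can be seen in isolation on a deliberately tiny problem. The sketch below fits a single linear layer to y = 2x. It is a hypothetical stand-in chosen so that it runs in well under a second, not the word-embedding model, and all names in it are invented for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy problem: fit y = 2x with a single linear layer.
toy_model = nn.Linear(1, 1).to(device)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
x = torch.linspace(-1, 1, 32).unsqueeze(1).to(device)
y = 2 * x

losses = []
for step in range(100):
    toy_optimizer.zero_grad()                # clear gradients from the previous step
    loss = ((toy_model(x) - y) ** 2).mean()  # forward pass and loss
    loss.backward()                          # backward pass: compute gradients
    toy_optimizer.step()                     # update the parameters
    losses.append(loss.item())

print(losses[0], losses[-1])
```

Skipping `zero_grad()` would make gradients accumulate across iterations, because PyTorch adds to `.grad` on every `backward()` call.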

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(NUM_EPOCHS):
    for i, [input_labels, pos_labels, neg_labels] in enumerate(dataloader):
        input_labels = input_labels.long()
        pos_labels = pos_labels.long()
        neg_labels = neg_labels.long()
        if USE_CUDA:
            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()

        # The update steps must run on CPU as well, so they sit outside the CUDA branch
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean() # the model returns one loss per example; average them
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print("epoch", e, "iteration", i, loss.item())

epoch 0 iteration 0 2400.27685546875
epoch 0 iteration 100 1235.7406005859375
epoch 0 iteration 200 873.1128540039062
epoch 0 iteration 300 781.069580078125
epoch 0 iteration 400 636.6002197265625
epoch 0 iteration 500 579.11865234375
epoch 0 iteration 600 507.3105773925781
epoch 0 iteration 700 453.44580078125
epoch 0 iteration 800 433.0643310546875
epoch 0 iteration 900 375.1482849121094
epoch 0 iteration 1000 313.5187683105469
epoch 0 iteration 1100 383.52899169921875
epoch 0 iteration 1200 292.67230224609375
epoch 0 iteration 1300 269.2372131347656
epoch 0 iteration 1400 284.4329833984375
epoch 0 iteration 1500 257.2510986328125
epoch 0 iteration 1600 307.3775939941406
epoch 0 iteration 1700 259.46124267578125
epoch 0 iteration 1800 173.3792266845703
epoch 0 iteration 1900 230.35145568847656
epoch 0 iteration 2000 199.69741821289062
epoch 0 iteration 2100 179.59371948242188
epoch 0 iteration 2200 175.15467834472656
epoch 0 iteration 2300 198.58416748046875
epoch 0 iteration 2400 24

epoch 0 iteration 19700 41.61482238769531
epoch 0 iteration 19800 44.519805908203125
epoch 0 iteration 19900 41.7851676940918
epoch 0 iteration 20000 44.28675079345703
epoch 0 iteration 20100 40.056663513183594
epoch 0 iteration 20200 50.26182556152344
epoch 0 iteration 20300 38.33570098876953
epoch 0 iteration 20400 40.405731201171875
epoch 0 iteration 20500 43.90251159667969
epoch 0 iteration 20600 39.8834228515625
epoch 0 iteration 20700 39.68968200683594
epoch 0 iteration 20800 44.30888748168945
epoch 0 iteration 20900 41.17680740356445
epoch 0 iteration 21000 45.64203643798828
epoch 0 iteration 21100 42.70426940917969
epoch 0 iteration 21200 41.81909942626953
epoch 0 iteration 21300 40.968772888183594
epoch 0 iteration 21400 51.406654357910156
epoch 0 iteration 21500 46.19879913330078
epoch 0 iteration 21600 41.08018493652344
epoch 0 iteration 21700 43.468955993652344
epoch 0 iteration 21800 45.41233825683594
epoch 0 iteration 21900 40.498992919921875
epoch 0 iteration 22000 40.92

epoch 0 iteration 39200 40.278167724609375
epoch 0 iteration 39300 36.57728576660156
epoch 0 iteration 39400 37.17780303955078
epoch 0 iteration 39500 37.56548309326172
epoch 0 iteration 39600 38.165802001953125
epoch 0 iteration 39700 35.748016357421875
epoch 0 iteration 39800 35.2784309387207
epoch 0 iteration 39900 37.46043014526367
epoch 0 iteration 40000 36.22037887573242
epoch 0 iteration 40100 38.66938018798828
epoch 0 iteration 40200 36.7360725402832
epoch 0 iteration 40300 38.93268966674805
epoch 0 iteration 40400 36.693824768066406
epoch 0 iteration 40500 38.67284393310547
epoch 0 iteration 40600 38.02044677734375
epoch 0 iteration 40700 36.06898880004883
epoch 0 iteration 40800 37.20315170288086
epoch 0 iteration 40900 35.51831817626953
epoch 0 iteration 41000 38.759864807128906
epoch 0 iteration 41100 36.873165130615234
epoch 0 iteration 41200 35.41002655029297
epoch 0 iteration 41300 36.7028694152832
epoch 0 iteration 41400 37.061561584472656
epoch 0 iteration 41500 36.659

epoch 0 iteration 58700 35.962364196777344
epoch 0 iteration 58800 36.669593811035156
epoch 0 iteration 58900 35.2177619934082
epoch 0 iteration 59000 34.8304328918457
epoch 0 iteration 59100 35.197505950927734
epoch 0 iteration 59200 36.11582946777344
epoch 0 iteration 59300 33.8350715637207
epoch 0 iteration 59400 36.75258255004883
epoch 0 iteration 59500 36.724266052246094
epoch 0 iteration 59600 36.119895935058594
epoch 0 iteration 59700 34.370948791503906
epoch 0 iteration 59800 34.38564682006836
epoch 0 iteration 59900 36.48630142211914
epoch 0 iteration 60000 36.19337463378906
epoch 0 iteration 60100 35.00898742675781
epoch 0 iteration 60200 35.00585174560547
epoch 0 iteration 60300 35.82815933227539
epoch 0 iteration 60400 36.136287689208984
epoch 0 iteration 60500 35.23573684692383
epoch 0 iteration 60600 36.856170654296875
epoch 0 iteration 60700 38.136287689208984
epoch 0 iteration 60800 36.537750244140625
epoch 0 iteration 60900 36.57293701171875
epoch 0 iteration 61000 35.

epoch 0 iteration 78100 32.973114013671875
epoch 0 iteration 78200 35.74901580810547
epoch 0 iteration 78300 34.127479553222656
epoch 0 iteration 78400 33.23461151123047
epoch 0 iteration 78500 34.9541015625
epoch 0 iteration 78600 34.61484909057617
epoch 0 iteration 78700 34.43072509765625
epoch 0 iteration 78800 35.10067367553711
epoch 0 iteration 78900 36.096527099609375
epoch 0 iteration 79000 35.383602142333984
epoch 0 iteration 79100 35.883087158203125
epoch 0 iteration 79200 34.12458038330078
epoch 0 iteration 79300 34.221046447753906
epoch 0 iteration 79400 34.15699005126953
epoch 0 iteration 79500 34.627479553222656
epoch 0 iteration 79600 34.65629577636719
epoch 0 iteration 79700 34.796722412109375
epoch 0 iteration 79800 34.737857818603516
epoch 0 iteration 79900 35.985530853271484
epoch 0 iteration 80000 34.37731170654297
epoch 0 iteration 80100 36.238075256347656
epoch 0 iteration 80200 35.358829498291016
epoch 0 iteration 80300 35.04269790649414
epoch 0 iteration 80400 34

epoch 0 iteration 97600 33.07356262207031
epoch 0 iteration 97700 34.38031005859375
epoch 0 iteration 97800 34.25047302246094
epoch 0 iteration 97900 33.69698715209961
epoch 0 iteration 98000 33.3869514465332
epoch 0 iteration 98100 33.9665412902832
epoch 0 iteration 98200 34.90050506591797
epoch 0 iteration 98300 33.79363250732422
epoch 0 iteration 98400 33.633811950683594
epoch 0 iteration 98500 35.04399871826172
epoch 0 iteration 98600 34.1281623840332
epoch 0 iteration 98700 34.10356903076172
epoch 0 iteration 98800 36.02113342285156
epoch 0 iteration 98900 33.382442474365234
epoch 0 iteration 99000 34.544837951660156
epoch 0 iteration 99100 33.66824722290039
epoch 0 iteration 99200 33.94880676269531
epoch 0 iteration 99300 33.37635040283203
epoch 0 iteration 99400 34.60924530029297
epoch 0 iteration 99500 34.0490608215332
epoch 0 iteration 99600 35.24097442626953
epoch 0 iteration 99700 33.96331787109375
epoch 0 iteration 99800 34.234310150146484
epoch 0 iteration 99900 33.8646011

epoch 0 iteration 116700 34.14668655395508
epoch 0 iteration 116800 33.42211151123047
epoch 0 iteration 116900 33.48313903808594
epoch 0 iteration 117000 34.55853271484375
epoch 0 iteration 117100 32.793792724609375
epoch 0 iteration 117200 34.10591125488281
epoch 0 iteration 117300 34.893959045410156
epoch 0 iteration 117400 33.57299041748047
epoch 0 iteration 117500 33.6443977355957
epoch 0 iteration 117600 34.455081939697266
epoch 0 iteration 117700 33.928794860839844
epoch 0 iteration 117800 33.8573112487793
epoch 0 iteration 117900 32.472232818603516
epoch 0 iteration 118000 34.4710807800293
epoch 0 iteration 118100 34.199440002441406
epoch 0 iteration 118200 33.7806396484375
epoch 0 iteration 118300 35.03395080566406
epoch 0 iteration 118400 33.204124450683594
epoch 0 iteration 118500 35.311119079589844
epoch 0 iteration 118600 32.63358688354492
epoch 0 iteration 118700 33.24555587768555
epoch 0 iteration 118800 35.32106018066406
epoch 0 iteration 118900 33.595619201660156
epoch 

epoch 1 iteration 16800 33.841217041015625
epoch 1 iteration 16900 33.0023193359375
epoch 1 iteration 17000 33.8380012512207
epoch 1 iteration 17100 33.542179107666016
epoch 1 iteration 17200 32.865867614746094
epoch 1 iteration 17300 33.40606689453125
epoch 1 iteration 17400 32.647796630859375
epoch 1 iteration 17500 34.47508239746094
epoch 1 iteration 17600 33.34884262084961
epoch 1 iteration 17700 33.444488525390625
epoch 1 iteration 17800 33.136451721191406
epoch 1 iteration 17900 34.18987274169922
epoch 1 iteration 18000 33.22206115722656
epoch 1 iteration 18100 34.205665588378906
epoch 1 iteration 18200 32.683349609375
epoch 1 iteration 18300 33.14904022216797
epoch 1 iteration 18400 34.65764617919922
epoch 1 iteration 18500 33.43492889404297
epoch 1 iteration 18600 33.035545349121094
epoch 1 iteration 18700 34.544517517089844
epoch 1 iteration 18800 32.98221969604492
epoch 1 iteration 18900 32.155250549316406
epoch 1 iteration 19000 33.0732536315918
epoch 1 iteration 19100 32.58

epoch 1 iteration 36300 32.999267578125
epoch 1 iteration 36400 33.047523498535156
epoch 1 iteration 36500 33.42100524902344
epoch 1 iteration 36600 34.011714935302734
epoch 1 iteration 36700 33.48255920410156
epoch 1 iteration 36800 33.760047912597656
epoch 1 iteration 36900 33.53578186035156
epoch 1 iteration 37000 33.38592529296875
epoch 1 iteration 37100 31.876462936401367
epoch 1 iteration 37200 32.43408203125
epoch 1 iteration 37300 32.0902099609375
epoch 1 iteration 37400 33.02770233154297
epoch 1 iteration 37500 32.59451675415039
epoch 1 iteration 37600 33.527523040771484
epoch 1 iteration 37700 33.32722854614258
epoch 1 iteration 37800 32.71855926513672
epoch 1 iteration 37900 33.07994842529297
epoch 1 iteration 38000 32.77999496459961
epoch 1 iteration 38100 34.05043029785156
epoch 1 iteration 38200 31.62129783630371
epoch 1 iteration 38300 34.070396423339844
epoch 1 iteration 38400 31.920671463012695
epoch 1 iteration 38500 33.43434143066406
epoch 1 iteration 38600 32.796722

In [71]:
print(len(text))

15304686
