# 10.3 word2vec的实现

In [267]:
import collections
import math
import random
import sys
import time
import os
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as Data

sys.path.append("..") 
import d2lzh_pytorch as d2l
print(torch.__version__)

1.9.0


## 10.3.1 处理数据集

In [268]:
assert 'ptb.train.txt' in os.listdir("../../data/ptb")

In [269]:
with open('../../data/ptb/ptb.train.txt', 'r') as f:
    lines = f.readlines()
    # st是sentence的缩写
    raw_dataset = [st.split() for st in lines]

'# sentences: %d' % len(raw_dataset)

'# sentences: 42068'

In [70]:
# raw_dataset[0]

In [71]:
for st in raw_dataset[:3]:
    print('# tokens:', len(st), st[:])

# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter']
# tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'nov.', 'N']
# tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of', '<unk>', 'n.v.', 'the', 'dutch', 'publishing', 'group']


### 10.3.1.1 建立词语索引

In [72]:
# tk是token的缩写
counter = collections.Counter([tk for st in raw_dataset for tk in st])
# counter.items()

In [73]:
# 为了计算简单，我们只保留在数据集中至少出现5次的词
# 除去出现次数少于5的字符
counter = dict(filter(lambda x: x[1] >= 5, counter.items()))
# counter

In [74]:
idx_to_token = [tk for tk, _ in counter.items()]
# idx_to_token

In [75]:
token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
list(token_to_idx.items())[-1], len(token_to_idx), token_to_idx

(('unilab', 9857),
 9858,
 {'pierre': 0,
  '<unk>': 1,
  'N': 2,
  'years': 3,
  'old': 4,
  'will': 5,
  'join': 6,
  'the': 7,
  'board': 8,
  'as': 9,
  'a': 10,
  'nonexecutive': 11,
  'director': 12,
  'nov.': 13,
  'mr.': 14,
  'is': 15,
  'chairman': 16,
  'of': 17,
  'n.v.': 18,
  'dutch': 19,
  'publishing': 20,
  'group': 21,
  'rudolph': 22,
  'and': 23,
  'former': 24,
  'consolidated': 25,
  'gold': 26,
  'fields': 27,
  'plc': 28,
  'was': 29,
  'named': 30,
  'this': 31,
  'british': 32,
  'industrial': 33,
  'conglomerate': 34,
  'form': 35,
  'asbestos': 36,
  'once': 37,
  'used': 38,
  'to': 39,
  'make': 40,
  'kent': 41,
  'cigarette': 42,
  'filters': 43,
  'has': 44,
  'caused': 45,
  'high': 46,
  'percentage': 47,
  'cancer': 48,
  'deaths': 49,
  'among': 50,
  'workers': 51,
  'exposed': 52,
  'it': 53,
  'more': 54,
  'than': 55,
  'ago': 56,
  'researchers': 57,
  'reported': 58,
  'fiber': 59,
  'unusually': 60,
  'enters': 61,
  'with': 62,
  'even': 63,


In [76]:
dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]
           for st in raw_dataset]
dataset[1]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 2]

In [77]:
num_tokens = sum([len(st) for st in dataset])
'# tokens: %d' % num_tokens

'# tokens: 887100'

### 10.3.1.2 二次采样

In [78]:
def discard(idx):
    return random.uniform(0, 1) < 1 - math.sqrt(
        1e-4 / counter[idx_to_token[idx]] * num_tokens)


subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
'# tokens: %d' % sum([len(st) for st in subsampled_dataset])

'# tokens: 375370'

In [79]:
[idx_to_token[idx] for idx in subsampled_dataset[2]]

['n.v.', 'dutch', 'publishing']

In [80]:
def compare_counts(token):
    return '# %s: before=%d, after=%d' % (token, sum(
        [st.count(token_to_idx[token]) for st in dataset]), sum(
        [st.count(token_to_idx[token]) for st in subsampled_dataset]))

compare_counts('the')

'# the: before=50770, after=2107'

In [81]:
compare_counts('join')

'# join: before=45, after=45'

In [82]:
compare_counts('mom')

'# mom: before=6, after=6'

### 10.3.1.3 提取中心词和背景词

In [83]:
def get_centers_and_contexts(dataset, max_window_size):
    centers, contexts = [], []
    for st in dataset:
        if len(st) < 2:  # 每个句子至少要有2个词才可能组成一对“中心词-背景词”
            continue
        centers += st
        for center_i in range(len(st)):
            window_size = random.randint(1, max_window_size)
            indices = list(range(max(0, center_i - window_size),
                                 min(len(st), center_i + 1 + window_size)))
            indices.remove(center_i)  # 将中心词排除在背景词之外
            contexts.append([st[idx] for idx in indices])
    return centers, contexts

In [84]:
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2, 3]
center 2 has contexts [0, 1, 3, 4]
center 3 has contexts [2, 4]
center 4 has contexts [3, 5]
center 5 has contexts [3, 4, 6]
center 6 has contexts [4, 5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [7, 8]


In [85]:
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)

In [86]:
all_centers[100], all_contexts[100]

(17, [133, 134, 138, 139, 48, 140])

In [87]:
idx_to_token[all_centers[101]], [idx_to_token[idx] for idx in all_contexts[101]]

('boston', ['a.', 'of', 'cancer', 'institute'])

## 10.3.2 负采样

In [88]:
# idx_to_token

In [89]:
sampling_weights = [counter[w]**0.75 for w in idx_to_token]
# sampling_weights

In [90]:
def get_negatives(all_contexts, sampling_weights, K):
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        negatives = []
        while len(negatives) < len(contexts) * K:
            if i == len(neg_candidates):
                # 根据每个词的权重（sampling_weights）随机生成k个词的索引作为噪声词。
                # 为了高效计算，可以将k设得稍大一点
                i, neg_candidates = 0, random.choices(
                    population, sampling_weights, k=int(1e5))
            neg, i = neg_candidates[i], i + 1
            # 噪声词不能是背景词
            if neg not in set(contexts):
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

all_negatives = get_negatives(all_contexts, sampling_weights, 5)

In [91]:
all_centers[0], all_contexts[0], all_negatives[0]

(0,
 [4, 6, 11, 12],
 [222,
  6515,
  9735,
  3592,
  7507,
  7705,
  7,
  191,
  1739,
  2147,
  693,
  603,
  94,
  8308,
  442,
  385,
  4303,
  6247,
  182,
  163])

## 10.3.3 读取数据
* Basically, the collate_fn receives a list of tuples if your `__getitem__` function from a Dataset subclass returns a tuple, or just a normal list if your Dataset subclass returns only one element. Its main objective is to create your batch without spending much time implementing it manually. Try to see it as a glue that you specify the way examples stick together in a batch. If you don’t use it, PyTorch only put batch_size examples together as you would using torch.stack (not exactly it, but it is simple like that).

* Suppose for example, you want to create batches of a list of varying dimension tensors. The below code pads sequences with 0 until the maximum sequence size of the batch, that is why we need the `collate_fn`, because a standard batching algorithm (simply using `torch.stack`) won’t work in this case, and we need to manually pad different sequences with variable length to the same size before creating the batch.

In [92]:
idx_to_token[0]

'pierre'

In [128]:
def batchify(data):
    """用作DataLoader的参数collate_fn: 输入是个长为batchsize的list, list中的每个元素都是__getitem__得到的结果"""
    max_len = max(len(c) + len(n) for _, c, n in data)
#     print(f"max_len - {max_len}")
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        centers += [center]
#         print(f"centers - {centers}")
        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]
        masks += [[1] * cur_len + [0] * (max_len - cur_len)]
        labels += [[1] * len(context) + [0] * (max_len - len(context))]
    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
            torch.tensor(masks), torch.tensor(labels))

In [129]:
batch_size = 512

example_data = zip(
    all_centers[:batch_size], 
    all_contexts[:batch_size], 
    all_negatives[:batch_size]
)

# for center, context, negative in example_data:
#     print(f"{center}\n{context}\n{negative}")
#     break

In [130]:
# batchify(list(example_data))

In [131]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, centers, contexts, negatives):
        assert len(centers) == len(contexts) == len(negatives)
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        return (self.centers[index], self.contexts[index], self.negatives[index])

    def __len__(self):
        return len(self.centers)

In [132]:
batch_size = 512
num_workers = 0 if sys.platform.startswith('win32') else 16

dataset = MyDataset(
    all_centers,
    all_contexts,
    all_negatives
)

In [133]:
# len(dataset), dataset[0], num_workers

In [134]:
data_iter = Data.DataLoader(
    dataset, 
    batch_size, 
    shuffle=True,
    collate_fn=batchify,
#     num_workers=num_workers
)

In [122]:
# data_iter

In [136]:
for i, batch in enumerate(data_iter):
    for name, data in zip(['centers', 'contexts_negatives', 'masks',
                           'labels'], batch):
        print(name, 'shape:', data.shape)
#     print(batch)
    break

centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])


## 10.3.4 跳字模型
### 10.3.4.1 嵌入层

In [148]:
num_embeddings = len(idx_to_token)
embed_size = 100

embed = nn.Embedding(
    num_embeddings=num_embeddings, 
    embedding_dim=embed_size
)
embed.weight.shape

torch.Size([9858, 100])

In [168]:
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)
embed(x).shape
# embed(x)[0][0] == embed.weight[1]

torch.Size([2, 3, 100])

In [194]:
c = torch.zeros((3, 1), dtype=torch.long)
o = torch.ones((3, 10), dtype=torch.long)

v = embed(c) 
u = embed(o)

v.shape, u.shape

(torch.Size([3, 1, 100]), torch.Size([3, 10, 100]))

In [195]:
torch.bmm(v, u.permute(0, 2, 1)).shape

torch.Size([3, 1, 10])

### 10.3.4.2 小批量乘法

In [177]:
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape, X

(torch.Size([2, 1, 6]),
 tensor([[[1., 1., 1., 1.]],
 
         [[1., 1., 1., 1.]]]))

### 10.3.4.3 跳字模型前向计算

In [171]:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    v = embed_v(center)
    u = embed_u(contexts_and_negatives)
    pred = torch.bmm(v, u.permute(0, 2, 1))
    return pred

## 10.3.5 训练模型
### 10.3.5.1 二元交叉熵损失函数

In [218]:
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    def __init__(self): # none mean sum
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()
    def forward(self, inputs, targets, mask=None):
        """
        input – Tensor shape: (batch_size, len)
        target – Tensor of the same shape as input
        """
        inputs, targets, mask = inputs.float(), targets.float(), mask.float()
        res = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none", weight=mask)
        return res.mean(dim=1)

loss = SigmoidBinaryCrossEntropyLoss()

值得一提的是，我们可以通过掩码变量指定小批量中参与损失函数计算的部分预测值和标签：当掩码为1时，相应位置的预测值和标签将参与损失函数的计算；当掩码为0时，相应位置的预测值和标签则不参与损失函数的计算。我们之前提到，掩码变量可用于避免填充项对损失函数计算的影响。

In [233]:
pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]], dtype = torch.float)
# 标签变量label中的1和0分别代表背景词和噪声词
label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]], dtype = torch.float)
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # 掩码变量
loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1)

tensor([0.8740, 1.2100])

In [234]:
loss(pred, label, mask)* mask.shape[1] 

tensor([3.4960, 3.6299])

In [235]:
mask.shape[1], mask.float().sum(dim=1)

(4, tensor([4., 3.]))

In [241]:
F.binary_cross_entropy_with_logits(pred[0], label[0]), F.binary_cross_entropy_with_logits(pred[1], label[1], weight=mask[1])*4/3

(tensor(0.8740), tensor(1.2100))

### 10.3.5.2 初始化模型参数

In [237]:
len(idx_to_token)

9858

In [238]:
num_embeddings = len(idx_to_token)
embed_size = 100
net = nn.Sequential(
    nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embed_size),
    nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embed_size)
)

### 10.3.5.3 定义训练函数

In [239]:
def train(net, lr, num_epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = [d.to(device) for d in batch]
            
            pred = skip_gram(center, context_negative, net[0], net[1])
            
            # 使用掩码变量mask来避免填充项对损失函数计算的影响
            l = (loss(pred.view(label.shape), label, mask) *
                 mask.shape[1] / mask.float().sum(dim=1)).mean() # 一个batch的平均loss
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum += l.cpu().item()
            n += 1
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, l_sum / n, time.time() - start))

In [240]:
train(net, 0.01, 10)

train on cpu
epoch 1, loss 1.98, time 39.84s
epoch 2, loss 0.62, time 40.26s
epoch 3, loss 0.45, time 39.43s
epoch 4, loss 0.39, time 40.20s
epoch 5, loss 0.37, time 39.89s
epoch 6, loss 0.35, time 40.54s
epoch 7, loss 0.34, time 40.72s
epoch 8, loss 0.33, time 42.60s
epoch 9, loss 0.32, time 42.64s
epoch 10, loss 0.32, time 42.79s


In [244]:
for name, param in net.named_parameters():
    print(f"{name} - {param.shape}")

0.weight - torch.Size([9858, 100])
1.weight - torch.Size([9858, 100])


## 10.3.6 应用词嵌入模型

In [242]:
def get_similar_tokens(query_token, k, embed):
    W = embed.weight.data
    x = W[token_to_idx[query_token]]
    # 添加的1e-9是为了数值稳定性
    cos = torch.matmul(W, x) / (torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9).sqrt()
    _, topk = torch.topk(cos, k=k+1)
    topk = topk.cpu().numpy()
    for i in topk[:]:  # 除去输入词
        print('cosine sim=%.3f: %s' % (cos[i], (idx_to_token[i])))
        
get_similar_tokens('chip', 3, net[0])

cosine sim=1.000: chip
cosine sim=0.422: seeing
cosine sim=0.412: film
cosine sim=0.399: computers


In [262]:
man_idx = token_to_idx["man"]
woman_idx = token_to_idx["woman"]
father_idx = token_to_idx["father"]
mother_idx = token_to_idx["mather"]

In [263]:
man_vector = net[0].weight.data[man_idx]
woman_vector = net[0].weight.data[woman_idx]
father_vector = net[0].weight.data[father_idx]
mother_vector = net[0].weight.data[mother_idx]

In [264]:
v1 = man_vector-woman_vector
v2 = father_vector-mother_vector

In [265]:
cos = torch.matmul(v1, v2) / (torch.sum(v1 * v1) * torch.sum(v2 * v2)).sqrt()

In [266]:
cos

tensor(0.1114)