In [2]:
import collections
import math
import random
import sys
import time
import os
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data

sys.path.append("..")
import d2lzh_pytorch as d2l
print(torch.__version__)

1.13.1+cu117


In [3]:
assert 'ptb.train.txt' in os.listdir('../ptb')

with open('../ptb/ptb.train.txt','r') as f:
    lines = f.readlines()
    raw_dataset = [st.split() for st in lines]

'# sentences: %d' % len(raw_dataset)

'# sentences: 42068'

In [4]:
lines

[' aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim snack-food ssangyong swapo wachter \n',
 ' pierre <unk> N years old will join the board as a nonexecutive director nov. N \n',
 ' mr. <unk> is chairman of <unk> n.v. the dutch publishing group \n',
 ' rudolph <unk> N years old and former chairman of consolidated gold fields plc was named a nonexecutive director of this british industrial conglomerate \n',
 ' a form of asbestos once used to make kent cigarette filters has caused a high percentage of cancer deaths among a group of workers exposed to it more than N years ago researchers reported \n',
 ' the asbestos fiber <unk> is unusually <unk> once it enters the <unk> with even brief exposures to it causing symptoms that show up decades later researchers said \n',
 ' <unk> inc. the unit of new york-based <unk> corp. that makes kent cigarettes stopped using <unk> in its <unk> cigarette filters in

In [5]:
raw_dataset

[['aer',
  'banknote',
  'berlitz',
  'calloway',
  'centrust',
  'cluett',
  'fromstein',
  'gitano',
  'guterman',
  'hydro-quebec',
  'ipo',
  'kia',
  'memotec',
  'mlx',
  'nahb',
  'punts',
  'rake',
  'regatta',
  'rubens',
  'sim',
  'snack-food',
  'ssangyong',
  'swapo',
  'wachter'],
 ['pierre',
  '<unk>',
  'N',
  'years',
  'old',
  'will',
  'join',
  'the',
  'board',
  'as',
  'a',
  'nonexecutive',
  'director',
  'nov.',
  'N'],
 ['mr.',
  '<unk>',
  'is',
  'chairman',
  'of',
  '<unk>',
  'n.v.',
  'the',
  'dutch',
  'publishing',
  'group'],
 ['rudolph',
  '<unk>',
  'N',
  'years',
  'old',
  'and',
  'former',
  'chairman',
  'of',
  'consolidated',
  'gold',
  'fields',
  'plc',
  'was',
  'named',
  'a',
  'nonexecutive',
  'director',
  'of',
  'this',
  'british',
  'industrial',
  'conglomerate'],
 ['a',
  'form',
  'of',
  'asbestos',
  'once',
  'used',
  'to',
  'make',
  'kent',
  'cigarette',
  'filters',
  'has',
  'caused',
  'a',
  'high',
  'percen

In [6]:
for st in raw_dataset[:3]:
    print('# tokens:', len(st), st[:5])

# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']
# tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old']
# tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']


In [7]:
#建立词语索引
counter = collections.Counter([tk for st in raw_dataset for tk in st])
counter = dict(filter(lambda x: x[1] >=5,counter.items()))#只保留在数据集中出现至少5次的词
counter

{'pierre': 6,
 '<unk>': 45020,
 'N': 32481,
 'years': 1241,
 'old': 268,
 'will': 3270,
 'join': 45,
 'the': 50770,
 'board': 612,
 'as': 4833,
 'a': 21196,
 'nonexecutive': 6,
 'director': 359,
 'nov.': 259,
 'mr.': 4326,
 'is': 7337,
 'chairman': 635,
 'of': 24400,
 'n.v.': 13,
 'dutch': 28,
 'publishing': 64,
 'group': 928,
 'rudolph': 8,
 'and': 17474,
 'former': 306,
 'consolidated': 37,
 'gold': 165,
 'fields': 44,
 'plc': 114,
 'was': 4073,
 'named': 210,
 'this': 2438,
 'british': 337,
 'industrial': 243,
 'conglomerate': 22,
 'form': 115,
 'asbestos': 27,
 'once': 219,
 'used': 372,
 'to': 23638,
 'make': 646,
 'kent': 11,
 'cigarette': 18,
 'filters': 11,
 'has': 3494,
 'caused': 110,
 'high': 400,
 'percentage': 142,
 'cancer': 107,
 'deaths': 32,
 'among': 444,
 'workers': 247,
 'exposed': 19,
 'it': 6112,
 'more': 2065,
 'than': 1731,
 'ago': 468,
 'researchers': 85,
 'reported': 430,
 'fiber': 8,
 'unusually': 20,
 'enters': 9,
 'with': 4585,
 'even': 773,
 'brief': 34,
 

In [8]:
idx_to_token = [tk for tk,_ in counter.items()]
token_to_index = {tk : idx for idx,tk in enumerate(idx_to_token)}
token_to_index

{'pierre': 0,
 '<unk>': 1,
 'N': 2,
 'years': 3,
 'old': 4,
 'will': 5,
 'join': 6,
 'the': 7,
 'board': 8,
 'as': 9,
 'a': 10,
 'nonexecutive': 11,
 'director': 12,
 'nov.': 13,
 'mr.': 14,
 'is': 15,
 'chairman': 16,
 'of': 17,
 'n.v.': 18,
 'dutch': 19,
 'publishing': 20,
 'group': 21,
 'rudolph': 22,
 'and': 23,
 'former': 24,
 'consolidated': 25,
 'gold': 26,
 'fields': 27,
 'plc': 28,
 'was': 29,
 'named': 30,
 'this': 31,
 'british': 32,
 'industrial': 33,
 'conglomerate': 34,
 'form': 35,
 'asbestos': 36,
 'once': 37,
 'used': 38,
 'to': 39,
 'make': 40,
 'kent': 41,
 'cigarette': 42,
 'filters': 43,
 'has': 44,
 'caused': 45,
 'high': 46,
 'percentage': 47,
 'cancer': 48,
 'deaths': 49,
 'among': 50,
 'workers': 51,
 'exposed': 52,
 'it': 53,
 'more': 54,
 'than': 55,
 'ago': 56,
 'researchers': 57,
 'reported': 58,
 'fiber': 59,
 'unusually': 60,
 'enters': 61,
 'with': 62,
 'even': 63,
 'brief': 64,
 'exposures': 65,
 'causing': 66,
 'symptoms': 67,
 'that': 68,
 'show': 69,

In [9]:
dataset = [[token_to_index[tk] for tk in st if tk in token_to_index] for st in raw_dataset]
dataset

[[],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 2],
 [14, 1, 15, 16, 17, 1, 18, 7, 19, 20, 21],
 [22,
  1,
  2,
  3,
  4,
  23,
  24,
  16,
  17,
  25,
  26,
  27,
  28,
  29,
  30,
  10,
  11,
  12,
  17,
  31,
  32,
  33,
  34],
 [10,
  35,
  17,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  10,
  46,
  47,
  17,
  48,
  49,
  50,
  10,
  21,
  17,
  51,
  52,
  39,
  53,
  54,
  55,
  2,
  3,
  56,
  57,
  58],
 [7,
  36,
  59,
  1,
  15,
  60,
  1,
  37,
  53,
  61,
  7,
  1,
  62,
  63,
  64,
  65,
  39,
  53,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  57,
  73],
 [1,
  74,
  7,
  75,
  17,
  76,
  77,
  1,
  78,
  68,
  79,
  41,
  80,
  81,
  82,
  1,
  83,
  84,
  1,
  42,
  43,
  83,
  2],
 [85,
  86,
  87,
  88,
  58,
  54,
  55,
  10,
  89,
  56,
  7,
  90,
  91,
  92,
  83,
  93,
  94,
  76,
  95,
  96,
  17,
  97,
  10,
  98,
  99,
  39,
  100,
  76,
  101,
  39,
  7,
  102],
 [10, 1, 1, 73, 31, 15, 103, 4, 104],
 [105, 106, 107, 108, 3, 56, 109, 110,

In [10]:
num_tokens = sum([len(st) for st in dataset])
num_tokens

887100

二次采样
文本数据中一般会出现一些高频词，如英文中的“the”“a”和“in”。通常来说，在一个背景窗口中，一个词（如“chip”）和较低频词（如“microprocessor”）同时出现比和较高频词（如“the”）同时出现对训练词嵌入模型更有益。因此，训练词嵌入模型时可以对词进行二次采样 [2]。 具体来说，数据集中每个被索引词wiw
i
​
 将有一定概率被丢弃，该丢弃概率为
 其中 f(wi)f(w
i
​
 ) 是数据集中词wiw
i
​
 的个数与总词数之比，常数tt是一个超参数（实验中设为10−410
−4
 ）。可见，只有当f(wi)>tf(w
i
​
 )>t时，我们才有可能在二次采样中丢弃词wiw
i
​
 ，并且越高频的词被丢弃的概率越大。

In [11]:
def discard(idx):
    return random.uniform(0,1) < 1 -math.sqrt(1e-4/counter[idx_to_token[idx]]*num_tokens)

subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
'# tokens: %d' % sum([len(st) for st in subsampled_dataset])

'# tokens: 375436'

In [13]:
def compare_counts(token):
    return '# %s: before=%d, after=%d' % (token, sum(
        [st.count(token_to_index[token]) for st in dataset]), sum(
        [st.count(token_to_index[token]) for st in subsampled_dataset]))

compare_counts('the') # '# the: before=50770, after=2013'

'# the: before=50770, after=2097'

10.3.1.3 提取中心词和背景词
我们将与中心词距离不超过背景窗口大小的词作为它的背景词。下面定义函数提取出所有中心词和它们的背景词。它每次在整数1和max_window_size（最大背景窗口）之间随机均匀采样一个整数作为背景窗口大小。

In [14]:
def get_centers_and_contexts(dataset,max_window_size):
    centers,contexts = [],[]
    for st in dataset:
        if len(st) < 2:
            continue
        centers += st
        for center_i in range(len(st)):
            window_size = random.randint(1,max_window_size)
            indices = list(range(max(0,center_i-window_size),min(len(st),center_i+1+window_size)))
            indices.remove(center_i)
            contexts.append([st[idx] for idx in indices])
    return centers,contexts


In [15]:
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2, 3]
center 2 has contexts [1, 3]
center 3 has contexts [2, 4]
center 4 has contexts [3, 5]
center 5 has contexts [3, 4, 6]
center 6 has contexts [5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [8]


In [16]:
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)

In [17]:
all_centers

[0,
 5,
 6,
 11,
 13,
 16,
 18,
 19,
 20,
 22,
 2,
 3,
 4,
 16,
 25,
 26,
 27,
 28,
 10,
 11,
 12,
 32,
 34,
 35,
 36,
 37,
 40,
 41,
 42,
 43,
 45,
 47,
 17,
 48,
 49,
 21,
 51,
 52,
 53,
 57,
 58,
 36,
 59,
 60,
 37,
 61,
 64,
 65,
 53,
 66,
 67,
 71,
 57,
 77,
 79,
 41,
 80,
 81,
 83,
 42,
 43,
 86,
 87,
 58,
 90,
 91,
 92,
 93,
 95,
 96,
 17,
 97,
 98,
 99,
 100,
 101,
 102,
 73,
 104,
 105,
 107,
 56,
 109,
 110,
 111,
 36,
 112,
 113,
 114,
 115,
 117,
 36,
 83,
 119,
 121,
 122,
 57,
 124,
 51,
 125,
 126,
 128,
 41,
 80,
 130,
 131,
 133,
 134,
 136,
 137,
 138,
 140,
 141,
 57,
 145,
 48,
 146,
 147,
 17,
 148,
 149,
 139,
 149,
 150,
 36,
 29,
 38,
 152,
 153,
 43,
 158,
 159,
 160,
 161,
 144,
 41,
 80,
 43,
 73,
 165,
 166,
 167,
 168,
 169,
 55,
 171,
 173,
 174,
 176,
 51,
 129,
 177,
 170,
 179,
 48,
 7,
 49,
 181,
 1,
 182,
 48,
 183,
 184,
 57,
 186,
 187,
 189,
 190,
 177,
 141,
 47,
 182,
 48,
 49,
 51,
 191,
 192,
 155,
 193,
 194,
 196,
 113,
 36,
 51,
 124,
 197,


In [18]:
all_contexts

[[5, 6, 11, 13],
 [0, 6],
 [0, 5, 11, 13],
 [0, 5, 6, 13],
 [6, 11],
 [18, 19, 20],
 [16, 19, 20],
 [16, 18, 20],
 [16, 18, 19],
 [2],
 [22, 3, 4],
 [22, 2, 4, 16, 25, 26],
 [3, 16],
 [22, 2, 3, 4, 25, 26, 27, 28, 10],
 [3, 4, 16, 26, 27, 28],
 [25, 27],
 [26, 28],
 [4, 16, 25, 26, 27, 10, 11, 12, 32, 34],
 [16, 25, 26, 27, 28, 11, 12, 32, 34],
 [25, 26, 27, 28, 10, 12, 32, 34],
 [26, 27, 28, 10, 11, 32, 34],
 [11, 12, 34],
 [11, 12, 32],
 [36, 37, 40],
 [35, 37, 40],
 [35, 36, 40, 41, 42, 43, 45],
 [35, 36, 37, 41, 42, 43, 45, 47],
 [36, 37, 40, 42, 43, 45],
 [40, 41, 43, 45],
 [36, 37, 40, 41, 42, 45, 47, 17, 48, 49],
 [42, 43, 47, 17],
 [42, 43, 45, 17, 48, 49],
 [47, 48],
 [47, 17, 49, 21],
 [45, 47, 17, 48, 21, 51, 52, 53],
 [48, 49, 51, 52],
 [47, 17, 48, 49, 21, 52, 53, 57, 58],
 [17, 48, 49, 21, 51, 53, 57, 58],
 [21, 51, 52, 57, 58],
 [53, 58],
 [52, 53, 57],
 [59, 60, 37],
 [36, 60, 37],
 [36, 59, 37, 61],
 [36, 59, 60, 61, 64, 65, 53, 66],
 [60, 37, 64, 65],
 [37, 61, 65, 53

负采样我们使用负采样来进行近似训练。对于一对中心词和背景词，我们随机采样KK个噪声词（实验中设K=5K=5）。根据word2vec论文的建议，噪声词采样概率P(w)P(w)设为ww词频与总词频之比的0.75次方 [2]。

In [25]:
def get_negatives(all_contexts, sampling_weights, K):
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        negatives = []
        while len(negatives) < len(contexts) * K:
            if i == len(neg_candidates):
                # 根据每个词的权重（sampling_weights）随机生成k个词的索引作为噪声词。
                # 为了高效计算，可以将k设得稍大一点
                i, neg_candidates = 0, random.choices(
                    population, sampling_weights, k=int(1e5))
            neg, i = neg_candidates[i], i + 1
            # 噪声词不能是背景词
            if neg not in set(contexts):
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

sampling_weights = [counter[w]**0.75 for w in idx_to_token]
all_negatives = get_negatives(all_contexts, sampling_weights, 5)
print(all_centers)
len(all_centers)



[0, 5, 6, 11, 13, 16, 18, 19, 20, 22, 2, 3, 4, 16, 25, 26, 27, 28, 10, 11, 12, 32, 34, 35, 36, 37, 40, 41, 42, 43, 45, 47, 17, 48, 49, 21, 51, 52, 53, 57, 58, 36, 59, 60, 37, 61, 64, 65, 53, 66, 67, 71, 57, 77, 79, 41, 80, 81, 83, 42, 43, 86, 87, 58, 90, 91, 92, 93, 95, 96, 17, 97, 98, 99, 100, 101, 102, 73, 104, 105, 107, 56, 109, 110, 111, 36, 112, 113, 114, 115, 117, 36, 83, 119, 121, 122, 57, 124, 51, 125, 126, 128, 41, 80, 130, 131, 133, 134, 136, 137, 138, 140, 141, 57, 145, 48, 146, 147, 17, 148, 149, 139, 149, 150, 36, 29, 38, 152, 153, 43, 158, 159, 160, 161, 144, 41, 80, 43, 73, 165, 166, 167, 168, 169, 55, 171, 173, 174, 176, 51, 129, 177, 170, 179, 48, 7, 49, 181, 1, 182, 48, 183, 184, 57, 186, 187, 189, 190, 177, 141, 47, 182, 48, 49, 51, 191, 192, 155, 193, 194, 196, 113, 36, 51, 124, 197, 198, 73, 201, 202, 203, 204, 40, 42, 43, 187, 209, 210, 123, 211, 68, 213, 214, 215, 36, 217, 36, 218, 219, 147, 221, 73, 212, 223, 198, 224, 225, 184, 227, 228, 156, 229, 230, 231, 232

374502

10.3.3 读取数据
我们从数据集中提取所有中心词all_centers，以及每个中心词对应的背景词all_contexts和噪声词all_negatives。我们先定义一个Dataset类。

In [33]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self,centers,contexts,negatives):

        assert len(centers) == len(contexts) == len(negatives)
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        return (self.centers[index],self.contexts[index],self.negatives[index])

    def __len__(self):
        return len(self.centers)

我们将通过随机小批量来读取它们。在一个小批量数据中，第ii个样本包括一个中心词以及它所对应的nin
i
​
 个背景词和mim
i
​
 个噪声词。由于每个样本的背景窗口大小可能不一样，其中背景词与噪声词个数之和ni+min
i
​
 +m
i
​
 也会不同。在构造小批量时，我们将每个样本的背景词和噪声词连结在一起，并添加填充项0直至连结后的长度相同，即长度均为maxini+mimax
i
​
 n
i
​
 +m
i
​
 （max_len变量）。为了避免填充项对损失函数计算的影响，我们构造了掩码变量masks，其每一个元素分别与连结后的背景词和噪声词contexts_negatives中的元素一一对应。当contexts_negatives变量中的某个元素为填充项时，相同位置的掩码变量masks中的元素取0，否则取1。为了区分正类和负类，我们还需要将contexts_negatives变量中的背景词和噪声词区分开来。依据掩码变量的构造思路，我们只需创建与contexts_negatives变量形状相同的标签变量labels，并将与背景词（正类）对应的元素设1，其余清0。

下面我们实现这个小批量读取函数batchify。它的小批量输入data是一个长度为批量大小的列表，其中每个元素分别包含中心词center、背景词context和噪声词negative。该函数返回的小批量数据符合我们需要的格式，例如，包含了掩码变量。

In [26]:
all_contexts

[[5, 6, 11, 13],
 [0, 6],
 [0, 5, 11, 13],
 [0, 5, 6, 13],
 [6, 11],
 [18, 19, 20],
 [16, 19, 20],
 [16, 18, 20],
 [16, 18, 19],
 [2],
 [22, 3, 4],
 [22, 2, 4, 16, 25, 26],
 [3, 16],
 [22, 2, 3, 4, 25, 26, 27, 28, 10],
 [3, 4, 16, 26, 27, 28],
 [25, 27],
 [26, 28],
 [4, 16, 25, 26, 27, 10, 11, 12, 32, 34],
 [16, 25, 26, 27, 28, 11, 12, 32, 34],
 [25, 26, 27, 28, 10, 12, 32, 34],
 [26, 27, 28, 10, 11, 32, 34],
 [11, 12, 34],
 [11, 12, 32],
 [36, 37, 40],
 [35, 37, 40],
 [35, 36, 40, 41, 42, 43, 45],
 [35, 36, 37, 41, 42, 43, 45, 47],
 [36, 37, 40, 42, 43, 45],
 [40, 41, 43, 45],
 [36, 37, 40, 41, 42, 45, 47, 17, 48, 49],
 [42, 43, 47, 17],
 [42, 43, 45, 17, 48, 49],
 [47, 48],
 [47, 17, 49, 21],
 [45, 47, 17, 48, 21, 51, 52, 53],
 [48, 49, 51, 52],
 [47, 17, 48, 49, 21, 52, 53, 57, 58],
 [17, 48, 49, 21, 51, 53, 57, 58],
 [21, 51, 52, 57, 58],
 [53, 58],
 [52, 53, 57],
 [59, 60, 37],
 [36, 60, 37],
 [36, 59, 37, 61],
 [36, 59, 60, 61, 64, 65, 53, 66],
 [60, 37, 64, 65],
 [37, 61, 65, 53

In [27]:
len(all_contexts)

374502

In [28]:
all_negatives

[[6970,
  770,
  127,
  2605,
  9137,
  3233,
  3613,
  3834,
  8651,
  7880,
  2928,
  35,
  7,
  9344,
  476,
  2,
  170,
  2737,
  1644,
  3237],
 [2012, 7037, 1098, 7, 1725, 2909, 1346, 1243, 391, 4471],
 [2,
  5573,
  1056,
  4708,
  2233,
  527,
  1,
  443,
  10,
  805,
  24,
  1466,
  824,
  3398,
  7566,
  1248,
  3549,
  3298,
  127,
  1858],
 [1964,
  70,
  1751,
  527,
  6229,
  9,
  827,
  84,
  3686,
  23,
  909,
  312,
  7959,
  1983,
  1828,
  886,
  8944,
  90,
  6615,
  1],
 [3290, 1, 3383, 713, 435, 3342, 15, 349, 226, 1629],
 [643,
  461,
  8027,
  2933,
  8852,
  1783,
  39,
  4028,
  4115,
  5715,
  108,
  314,
  1585,
  21,
  2127],
 [9120,
  4365,
  392,
  9629,
  402,
  2412,
  95,
  8818,
  3492,
  7,
  1205,
  1461,
  325,
  226,
  6026],
 [3253,
  9353,
  6591,
  1028,
  7510,
  116,
  423,
  898,
  3413,
  1,
  953,
  885,
  2437,
  129,
  519],
 [3529,
  15,
  390,
  1348,
  2,
  1700,
  7071,
  2638,
  9267,
  39,
  2774,
  2493,
  1292,
  9632,
  8700],
 

In [29]:
len(all_negatives)

374502

In [30]:
def batchify(data):
    """用作DataLoader的参数collate_fn: 输入是个长为batchsize的list,
    list中的每个元素都是Dataset类调用__getitem__得到的结果
    """
    max_len = max(len(c) + len(n) for _,c,n in data)
    centers,contexts_negatives,masks,labels = [],[],[],[]
    for center,context,negative in data:
        cur_len = len(context) + len(negative)
        centers +=[center]
        contexts_negatives +=[context + negative + [0]*(max_len-cur_len)]
        masks +=[[1]*cur_len+[0]*(max_len-cur_len)]
        labels +=[[1]*len(context) + [0]*(max_len-len(context))]
    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
            torch.tensor(masks), torch.tensor(labels))

In [35]:
batch_size = 32
num_workers = 0 if sys.platform.startswith('win32') else 4
dataset = MyDataset(all_centers,
                    all_contexts,
                    all_negatives)
data_iter = Data.DataLoader(dataset,batch_size,shuffle=True,collate_fn=batchify,num_workers=num_workers)

for batch in data_iter:
     for name, data in zip(['centers', 'contexts_negatives', 'masks',
                           'labels'], batch):
        print(name, 'shape:', data.shape)
     break

centers shape: torch.Size([32, 1])
contexts_negatives shape: torch.Size([32, 60])
masks shape: torch.Size([32, 60])
labels shape: torch.Size([32, 60])


我们将通过使用嵌入层和小批量乘法来实现跳字模型。它们也常常用于实现其他自然语言处理的应用。

获取词嵌入的层称为嵌入层，在PyTorch中可以通过创建nn.Embedding实例得到。嵌入层的权重是一个矩阵，其行数为词典大小（num_embeddings），列数为每个词向量的维度（embedding_dim）。我们设词典大小为20，词向量的维度为4。

In [36]:
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
embed.weight

Parameter containing:
tensor([[-0.4339,  2.2959, -0.2563, -0.1267],
        [-0.7528, -0.6395,  0.4897,  1.5668],
        [ 0.3544,  0.2492,  0.6429,  0.0135],
        [ 0.8456, -0.0415, -0.4465,  1.2168],
        [ 1.3540, -1.4556, -0.5895, -0.1792],
        [ 0.9285,  0.9146,  1.0979, -1.0584],
        [-0.3171,  0.3416,  1.8385,  0.0076],
        [-1.8653,  0.7264, -0.8505, -0.9399],
        [-1.4109,  0.7909, -0.6108, -0.7085],
        [ 0.8495,  2.0391, -1.4961, -1.8778],
        [ 0.3412, -1.7167, -0.0840,  0.6822],
        [ 0.6801,  1.7708,  1.1606,  0.6201],
        [ 0.0692, -0.7331,  0.4125,  1.3196],
        [ 0.4397, -0.5218, -0.6722, -0.2058],
        [-0.5475,  0.6126, -0.8376, -0.2054],
        [ 0.7671,  1.8827,  0.0449,  0.0324],
        [-1.8180, -0.4105, -0.4899, -2.3536],
        [ 1.3216, -1.2274,  0.3201, -0.1502],
        [-0.0486,  0.2915,  0.3544, -1.4220],
        [ 0.4435, -2.0156, -0.2989, -0.3479]], requires_grad=True)

嵌入层的输入为词的索引。输入一个词的索引ii，嵌入层返回权重矩阵的第ii行作为它的词向量。下面我们将形状为(2, 3)的索引输入进嵌入层，由于词向量的维度为4，我们得到形状为(2, 3, 4)的词向量。

In [37]:
x = torch.tensor([[1,2,3],[4,5,6]],dtype=torch.long)
x

tensor([[1, 2, 3],
        [4, 5, 6]])

In [38]:
embed(x)

tensor([[[-0.7528, -0.6395,  0.4897,  1.5668],
         [ 0.3544,  0.2492,  0.6429,  0.0135],
         [ 0.8456, -0.0415, -0.4465,  1.2168]],

        [[ 1.3540, -1.4556, -0.5895, -0.1792],
         [ 0.9285,  0.9146,  1.0979, -1.0584],
         [-0.3171,  0.3416,  1.8385,  0.0076]]], grad_fn=<EmbeddingBackward0>)

10.3.4.2 小批量乘法
我们可以使用小批量乘法运算bmm对两个小批量中的矩阵一一做乘法。假设第一个小批量中包含nn个形状为a×ba×b的矩阵X1,…,XnX
1
​
 ,…,X
n
​
 ，第二个小批量中包含nn个形状为b×cb×c的矩阵Y1,…,YnY
1
​
 ,…,Y
n
​
 。这两个小批量的矩阵乘法输出为nn个形状为a×ca×c的矩阵X1Y1,…,XnYnX
1
​
 Y
1
​
 ,…,X
n
​
 Y
n
​
 。因此，给定两个形状分别为(nn, aa, bb)和(nn, bb, cc)的Tensor，小批量乘法输出的形状为(nn, aa, cc)。

In [39]:
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape

torch.Size([2, 1, 6])

10.3.4.3 跳字模型前向计算
在前向计算中，跳字模型的输入包含中心词索引center以及连结的背景词与噪声词索引contexts_and_negatives。其中center变量的形状为(批量大小, 1)，而contexts_and_negatives变量的形状为(批量大小, max_len)。这两个变量先通过词嵌入层分别由词索引变换为词向量，再通过小批量乘法得到形状为(批量大小, 1, max_len)的输出。输出中的每个元素是中心词向量与背景词向量或噪声词向量的内积。

In [40]:
def skip_gram(center,context_and_negatives,embed_v,embed_u):
    v = embed_v(center)
    u = embed_u(context_and_negatives)
    pred = torch.bmm(v,u.permute(0,2,1))
    return pred


In [42]:
#二元交叉熵损失函数
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    def __init__(self): # none mean sum
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()
    def forward(self, inputs, targets, mask=None):
        """
        input – Tensor shape: (batch_size, len)
        target – Tensor of the same shape as input
        """
        inputs, targets, mask = inputs.float(), targets.float(), mask.float()
        res = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none", weight=mask)
        return res.mean(dim=1)

loss = SigmoidBinaryCrossEntropyLoss()

值得一提的是，我们可以通过掩码变量指定小批量中参与损失函数计算的部分预测值和标签：当掩码为1时，相应位置的预测值和标签将参与损失函数的计算；当掩码为0时，相应位置的预测值和标签则不参与损失函数的计算。我们之前提到，掩码变量可用于避免填充项对损失函数计算的影响。

In [52]:
pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])
# 标签变量label中的1和0分别代表背景词和噪声词
label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # 掩码变量)
print(loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1))
print(mask.shape)
print(mask.shape[1])
print(mask.float().sum(dim=1))

tensor([0.8740, 1.2100])
torch.Size([2, 4])
4
tensor([4., 3.])


In [44]:
#初始化模型参数
embed_size = 100
net = nn.Sequential(
    nn.Embedding(num_embeddings=len(idx_to_token),embedding_dim=embed_size),
    nn.Embedding(num_embeddings=len(idx_to_token),embedding_dim=embed_size)
)

In [45]:
#定义训练函数
def train(net, lr, num_epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = [d.to(device) for d in batch]

            pred = skip_gram(center, context_negative, net[0], net[1])

            # 使用掩码变量mask来避免填充项对损失函数计算的影响
            l = (loss(pred.view(label.shape), label, mask) *
                 mask.shape[1] / mask.float().sum(dim=1)).mean() # 一个batch的平均loss
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum += l.cpu().item()
            n += 1
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, l_sum / n, time.time() - start))

In [46]:
train(net,0.01,10)

train on cuda
epoch 1, loss 1.13, time 23.88s
epoch 2, loss 0.57, time 21.21s
epoch 3, loss 0.54, time 21.54s
epoch 4, loss 0.53, time 20.84s
epoch 5, loss 0.53, time 22.88s
epoch 6, loss 0.52, time 22.25s
epoch 7, loss 0.52, time 22.38s
epoch 8, loss 0.52, time 22.34s
epoch 9, loss 0.52, time 21.53s
epoch 10, loss 0.52, time 21.31s


In [54]:
def get_similar_tokens(query_token, k, embed):
    W = embed.weight.data
    x = W[token_to_index[query_token]]
    # 添加的1e-9是为了数值稳定性
    cos = torch.matmul(W, x) / (torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9).sqrt()
    _, topk = torch.topk(cos, k=k+1)
    topk = topk.cpu().numpy()
    for i in topk[1:]:  # 除去输入词
        print('cosine sim=%.3f: %s' % (cos[i], (idx_to_token[i])))

get_similar_tokens('chip', 3, net[0])

cosine sim=0.434: microprocessors
cosine sim=0.432: models
cosine sim=0.422: butcher
