In [1]:
import collections
import d2ltorch as d2lt
import math
import torch
from torch import nn, optim
from torch.utils import data as tdata
import random
import sys
import time
import zipfile

In [2]:
# Unpack the PTB archive into data/; the next step reads data/ptb/ptb.train.txt from it.
with zipfile.ZipFile('data/ptb.zip', 'r') as zin:
    zin.extractall('data/')

# Training split: one sentence per line, tokens separated by whitespace.
with open('data/ptb/ptb.train.txt', 'r') as f:
    lines = f.readlines()
raw_dataset = [line.split() for line in lines]

'# sentences: %d' % len(raw_dataset)

'# sentences: 42068'

In [3]:
# Show the length and first five tokens of the first three sentences.
for st in raw_dataset[0:3]:
    print('# tokens:', len(st), st[:5])

# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']
# tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old']
# tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']


In [4]:
# Token frequencies over the training split, keeping only tokens seen at least 5 times.
counter = {
    token: freq
    for token, freq in collections.Counter(
        tk for st in raw_dataset for tk in st).items()
    if freq >= 5
}

In [5]:
# Vocabulary in first-occurrence order, plus the reverse lookup.
idx_to_token = list(counter)
token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
# Each sentence as token indices; tokens filtered out of the vocabulary are dropped.
dataset = [[token_to_idx[tk] for tk in sentence if tk in token_to_idx]
           for sentence in raw_dataset]
num_tokens = sum(len(sentence) for sentence in dataset)
'# tokens: %d' % num_tokens

'# tokens: 887100'

In [6]:
def discard(idx, t=1e-4):
    """Randomly decide whether to drop one occurrence of token index `idx`.

    The occurrence is discarded with probability ``1 - sqrt(t / f)``, where
    ``f = counter[token] / num_tokens`` is the token's relative frequency.
    Frequent tokens are therefore dropped more often, and tokens with
    ``f <= t`` are never dropped, because the probability is <= 0.

    Reads the module-level ``counter``, ``idx_to_token`` and ``num_tokens``.

    Args:
        idx: index into ``idx_to_token``.
        t: subsampling threshold. The default 1e-4 reproduces the original
            hard-coded value.

    Returns:
        True if this occurrence should be discarded.
    """
    # Evaluated as (t / count) * num_tokens == t / f, in the same order as before.
    ratio = t / counter[idx_to_token[idx]] * num_tokens
    return random.uniform(0, 1) < 1 - math.sqrt(ratio)

# Subsample: each occurrence is kept or dropped independently by `discard`.
subsampled_dataset = [[token for token in sentence if not discard(token)]
                      for sentence in dataset]
'# tokens: %d' % sum(map(len, subsampled_dataset))

'# tokens: 375793'

In [7]:
def compare_counts(token):
    """Report how often `token` occurs before and after subsampling.

    Counts occurrences of the token's index in the module-level ``dataset``
    (before) and ``subsampled_dataset`` (after). NOTE: this relies on the
    global ``dataset`` still being the list of token-index sentences.

    Returns:
        A string of the form '# <token>: before=<n>, after=<m>'.
    """
    def occurrences(corpus):
        return sum(sentence.count(token_to_idx[token]) for sentence in corpus)

    before = occurrences(dataset)
    after = occurrences(subsampled_dataset)
    return '# %s: before=%d, after=%d' % (token, before, after)

In [8]:
compare_counts(token='the')

'# the: before=50770, after=2177'

In [9]:
def get_centers_and_contexts(dataset, max_window_size):
    """Extract (center word, context words) pairs for skip-gram training.

    Sentences with fewer than 2 tokens are skipped. For every token of the
    remaining sentences, a window radius is drawn with
    ``random.randint(1, max_window_size)``. The context is every token within
    that radius, excluding the center position itself.

    Returns:
        (centers, contexts): parallel lists. ``centers[k]`` is a token, and
        ``contexts[k]`` is the list of its context tokens in sentence order.
    """
    centers, contexts = [], []
    for sentence in dataset:
        if len(sentence) >= 2:
            centers.extend(sentence)
            for pos in range(len(sentence)):
                radius = random.randint(1, max_window_size)
                lo = max(0, pos - radius)
                hi = min(len(sentence), pos + 1 + radius)
                contexts.append([sentence[j] for j in range(lo, hi) if j != pos])
    return centers, contexts

In [10]:
# Two toy sentences: token ids 0..6 and 7..9.
tiny_dataset = [[*range(7)], [*range(7, 10)]]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1]
center 1 has contexts [0, 2]
center 2 has contexts [0, 1, 3, 4]
center 3 has contexts [2, 4]
center 4 has contexts [2, 3, 5, 6]
center 5 has contexts [3, 4, 6]
center 6 has contexts [5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [7, 8]


In [11]:
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, max_window_size=5)

In [13]:
def get_negatives(all_contexts, sampling_weights, K, num_candidates=100000):
    """Draw K noise words per context word for every center word.

    Noise candidates are sampled in bulk with
    ``random.choices(population, sampling_weights)`` over indices
    ``0..len(sampling_weights)-1``. Any candidate that appears in the current
    context window is rejected. One candidate buffer is shared across all
    windows and refilled when it runs out.

    Args:
        all_contexts: list of context-index lists, one per center word.
        sampling_weights: unnormalized sampling weight for each vocabulary index.
        K: number of noise words per context word.
        num_candidates: size of each pre-drawn candidate buffer. The default
            matches the previous hard-coded ``int(1e5)``.

    Returns:
        A list parallel to ``all_contexts``. Entry k holds
        ``len(all_contexts[k]) * K`` indices, none of which occur in
        ``all_contexts[k]``.

    Raises:
        ValueError: if ``num_candidates`` < 1, which would never fill the buffer.

    NOTE: never terminates if every index with positive weight occurs in a
    non-empty context window.
    """
    if num_candidates < 1:
        raise ValueError('num_candidates must be >= 1, got %r' % (num_candidates,))
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        # Build the exclusion set once per window instead of on every draw.
        excluded = set(contexts)
        negatives = []
        while len(negatives) < len(contexts) * K:
            if i == len(neg_candidates):
                # Buffer exhausted: draw a fresh batch of candidates.
                i, neg_candidates = 0, random.choices(
                    population, sampling_weights, k=int(num_candidates))
            neg, i = neg_candidates[i], i + 1
            if neg not in excluded:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

# Noise words are sampled proportionally to count ** 0.75.
sampling_weights = [pow(counter[token], 0.75) for token in idx_to_token]
all_negatives = get_negatives(all_contexts, sampling_weights, K=5)

In [15]:
def batchify(data):
    """Collate (center, contexts, negatives) triples into padded tensors.

    For each example, its contexts followed by its negatives are right-padded
    with index 0 to the longest such list in the batch.

    Returns:
        centers: (batch, 1) tensor of center indices.
        contexts_negatives: (batch, max_len) tensor of context + noise indices.
        masks: float (batch, max_len) tensor; 1 for real entries, 0 for padding.
        labels: float (batch, max_len) tensor; 1 for context words, 0 otherwise.
    """
    max_len = max(len(ctx) + len(neg) for _, ctx, neg in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, ctx, neg in data:
        n_real = len(ctx) + len(neg)
        n_pad = max_len - n_real
        centers.append(center)
        contexts_negatives.append(ctx + neg + [0] * n_pad)
        masks.append([1] * n_real + [0] * n_pad)
        labels.append([1] * len(ctx) + [0] * (max_len - len(ctx)))
    return (
        torch.tensor(centers).reshape(-1, 1),
        torch.tensor(contexts_negatives),
        torch.tensor(masks, dtype=torch.float),
        torch.tensor(labels, dtype=torch.float),
    )

In [16]:
class ArrayDataset(tdata.Dataset):
    """A Dataset that zips one or more equally long array-likes together.

    Indexing returns the idx-th element of the single array when only one
    array was given. Otherwise it returns a tuple holding the idx-th element of
    every array, in the order the arrays were passed. One-dimensional torch
    tensors are converted to NumPy arrays on construction.
    """

    def __init__(self, *args):
        """Store the arrays.

        Raises:
            AssertionError: if no arrays are given or their lengths differ.
        """
        assert len(args) > 0, "Needs at least 1 arrays"
        self._length = len(args[0])
        self._data = []
        for i, data in enumerate(args):
            # `i` is already the position of `data` in args; the message
            # previously reported i + 1 and named the wrong array.
            assert len(data) == self._length, \
                "All arrays must have the same length; array[0] has length %d " \
                "while array[%d] has %d." % (self._length, i, len(data))
            if isinstance(data, torch.Tensor) and data.dim() == 1:
                # detach()/cpu() let tensors that require grad or live on a
                # GPU be converted too; for plain CPU tensors this is a no-op.
                data = data.detach().cpu().numpy()
            self._data.append(data)

    def __getitem__(self, idx):
        if len(self._data) == 1:
            return self._data[0][idx]
        else:
            return tuple(data[idx] for data in self._data)

    def __len__(self):
        return self._length

In [17]:
batch_size = 512
# NOTE(review): worker processes are disabled on Windows, presumably because
# the spawn start method there cannot re-import notebook-defined functions.
num_workers = 0 if sys.platform.startswith('win32') else 4
# Use a new name: rebinding `dataset` would replace the list of token-index
# sentences that compare_counts reads.
train_dataset = ArrayDataset(all_centers, all_contexts, all_negatives)
data_iter = tdata.DataLoader(train_dataset, batch_size, shuffle=True,
                             collate_fn=batchify, num_workers=num_workers)

# Peek at one batch to check the padded shapes.
batch = next(iter(data_iter))
for name, data in zip(['centers', 'contexts_negatives', 'masks',
                       'labels'], batch):
    print(name, 'shape:', data.shape)

centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])


In [18]:
# A 20-word vocabulary with 4-dimensional vectors.
embed = nn.Embedding(20, 4)
embed.weight.shape

torch.Size([20, 4])

In [19]:
# Look up a (2, 3) batch of indices -> (2, 3, 4) vectors.
x = torch.arange(1, 7).reshape(2, 3)
embed(x)

tensor([[[ 0.3835, -0.0390, -0.0791,  0.2224],
         [ 1.4607, -1.1632,  0.1373, -0.3028],
         [ 1.8696, -1.9402, -0.3591, -1.8156]],

        [[-0.0190, -0.0039,  0.8115, -0.3738],
         [ 0.0646, -0.5820, -0.8326, -1.3271],
         [ 1.0805,  1.6886, -0.1212, -0.0629]]], grad_fn=<EmbeddingBackward>)

In [20]:
# Batched matmul: (2, 1, 4) @ (2, 4, 6) -> (2, 1, 6).
X, Y = torch.ones(2, 1, 4), torch.ones(2, 4, 6)
X.bmm(Y).shape

torch.Size([2, 1, 6])

In [21]:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    """Score every context/noise word against its center word.

    Args:
        center: index tensor of center words; (batch, 1) when it comes from batchify.
        contexts_and_negatives: (batch, max_len) index tensor of context and noise words.
        embed_v: embedding layer for center words.
        embed_u: embedding layer for context/noise words.

    Returns:
        Tensor of dot products (logits), (batch, 1, max_len) for batchify inputs.
    """
    center_vecs = embed_v(center)
    word_vecs = embed_u(contexts_and_negatives)
    # Swap the last two axes so each center vector dots with every word vector.
    return center_vecs.bmm(word_vecs.transpose(1, 2))

In [22]:
# Per-element binary cross-entropy on logits; callers mask and average themselves.
loss = nn.BCEWithLogitsLoss(weight=None, reduction='none')

In [23]:
pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])
# In `label`, 1 marks a context (background) word and 0 a noise word.
label = torch.tensor([[1., 0., 0., 0.], [1., 1., 0., 0.]])
# Mask: 1 for real entries, 0 for padding (the last entry of row 2).
mask = torch.tensor([[1., 1., 1., 1.], [1., 1., 1., 0.]])
loss.weight = mask
# Sum the weighted per-element losses, then average over the real entries.
loss(pred, label).sum(dim=1).div(mask.sum(dim=1))

tensor([0.8740, 1.2100])

In [24]:
def sigmd(x):
    """Return -log(sigmoid(x)), the cross-entropy of logit `x` against label 1.

    Computed as log1p(exp(-|x|)) plus max(-x, 0). This is mathematically equal
    to the naive -log(1 / (1 + exp(-x))), but it never overflows: the naive
    form raises OverflowError in math.exp for x below about -709.
    """
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))

# Recompute the masked losses by hand: label-1 entries contribute sigmd(x),
# label-0 entries contribute sigmd(-x), and the padded entry is skipped.
print('{:.7f}'.format(sum(map(sigmd, (1.5, -0.3, 1, -2))) / 4))
print('{:.7f}'.format(sum(map(sigmd, (1.1, -0.6, -2.2))) / 3))


0.8739896
1.2099689


In [25]:
embed_size = 100
# Two same-shaped tables: net[0] embeds center words, net[1] context/noise words.
net = nn.Sequential(
    *[nn.Embedding(len(idx_to_token), embed_size) for _ in range(2)]
)

In [26]:
def train(net, lr, num_epochs):
    """Train the skip-gram embeddings in `net` with Adam.

    Iterates over the module-level ``data_iter`` for `num_epochs` epochs and
    prints the mean per-example loss and the wall time of each epoch.

    Args:
        net: container whose [0] embeds center words and [1] embeds
            context/noise words (as passed to ``skip_gram``).
        lr: Adam learning rate.
        num_epochs: number of passes over ``data_iter``.
    """
    device = d2lt.try_gpu()
    net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)

    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            optimizer.zero_grad()
            center, context_negative, mask, label = [
                data.to(device) for data in batch]

            pred = skip_gram(center, context_negative, net[0], net[1])
            # Weight each element by the mask so padding contributes nothing,
            # then average over each example's real entries. The mask is passed
            # per call instead of being written into the shared `loss` module.
            per_elem = nn.functional.binary_cross_entropy_with_logits(
                pred.reshape(label.shape), label, weight=mask, reduction='none')
            l = per_elem.sum(dim=1) / mask.sum(dim=1)
            l.sum().backward()
            optimizer.step()
            l_sum += l.sum().item()
            n += l.numel()
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, l_sum / n, time.time() - start))

In [27]:
train(net, lr=0.005, num_epochs=5)

epoch 1, loss 2.66, time 73.86s


KeyboardInterrupt: 