In [1]:
import math
import random
import sys
import time
import os
import numpy as np
import torch
from torch import nn
from collections import Counter
import torch.utils.data as Data

import d2l_pytorch as d2l
print(torch.__version__)

1.7.1


In [2]:
with open('Data/ptb/ptb.train.txt', 'r') as f:
    lines = f.readlines()
    raw_dataset = [st.split()
                   for st in lines]
print('# sentences: %d' % len(raw_dataset))

# sentences: 42068


In [3]:
for st in raw_dataset[:3]:
    print('# token:', len(st), st[:5])

# token: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']
# token: 15 ['pierre', '<unk>', 'N', 'years', 'old']
# token: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']


In [4]:
counter = Counter([tk for st in raw_dataset for tk in st])
counter = dict(filter(lambda x:x[1] >= 5, counter.items()))

In [6]:
idx2token = [tk for tk, _ in counter.items()]
token2idx = {tk:idx for idx, tk in enumerate(idx2token)}
dataset = [[token2idx[tk] for tk in st if tk in token2idx] for st in raw_dataset]
num_token = int(sum([int(len(st)) for st in dataset]))
print('# token: %d' % num_token)

# token: 887100


In [7]:
def discard(idx):
    return random.uniform(0, 1) < 1 - math.sqrt(
        1e-4 / counter[idx2token[idx]] * num_token)

subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
print('# token: %d' % (sum([int(len(st)) for st in subsampled_dataset])))

# token: 375932


In [9]:
def compare_counts(token):
    return '# %s: before=%d, after=%d' % (token,
                int(sum([st.count(token2idx[token]) for st in dataset])),
                int(sum([st.count(token2idx[token]) for st in subsampled_dataset])))

compare_counts('the')

'# the: before=50770, after=2073'

In [10]:
compare_counts('join')

'# join: before=45, after=45'

In [11]:
def get_centers_and_contexts(dataset, max_window_size):
    centers, contexts = [], []
    for st in dataset:
        if len(st) < 2:
            continue
        centers += st
        for center_i in range(len(st)):
            window_size = random.randint(1, max_window_size)
            indices = list(range(max(0, center_i - window_size),
                                 min(int(len(st)), center_i + window_size + 1)))
            indices.remove(center_i)
            contexts.append([st[idx] for idx in indices])
    print(contexts)
    return centers, contexts

In [12]:
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
# get_centers_and_contexts(tiny_dataset, 2)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
[[1, 2], [0, 2, 3], [1, 3], [2, 4], [3, 5], [3, 4, 6], [5], [8], [7, 9], [7, 8]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2, 3]
center 2 has contexts [1, 3]
center 3 has contexts [2, 4]
center 4 has contexts [3, 5]
center 5 has contexts [3, 4, 6]
center 6 has contexts [5]
center 7 has contexts [8]
center 8 has contexts [7, 9]
center 9 has contexts [7, 8]


In [13]:
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.


In [14]:
def get_negatives(all_contexts, sampling_weights, K):
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        negatives = []
        while int(len(negatives)) < int(len(contexts)) * K:
            if i == int(len(neg_candidates)):
                i, neg_candidates = 0, random.choices(population, sampling_weights, k=int(1e5) )
            neg, i = neg_candidates[i], i + 1

            if neg not in set(contexts):
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

In [15]:
sampling_weights = [counter[w] ** 0.75 for w in idx2token]
all_negatives = get_negatives(all_contexts, sampling_weights, 5)

In [16]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, centers, contexts, negatives):
        assert int(len(centers)) == int(len(contexts)) == int(len(negatives))
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, item):
        return (self.centers[item], self.contexts[item], self.negatives[item])

    def __len__(self):
        return len(self.centers)

In [17]:
def batchify(data):
    max_len = max(len(c)  + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = int(len(context)) + int(len(negative))
        centers += [center]
        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]
        masks += [[1] * cur_len + [0] * (max_len - cur_len)]
        labels += [[1] * int(len(context)) + [0] * (max_len - int(len(context)))]
    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
            torch.tensor(masks), torch.tensor(labels))

In [18]:
batch_size = 512
num_workers = 4

dataset = MyDataset(all_centers,
                    all_contexts,
                    all_negatives)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,
                            collate_fn=batchify,
                            num_workers=num_workers)

for batch in data_iter:
    for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'],
                          batch):
        print(name, 'shape:', data.shape)
    break

centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])


In [19]:
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
embed.weight

Parameter containing:
tensor([[ 1.1082,  0.0223, -0.2641, -0.9674],
        [ 1.0021, -0.1309,  0.9061,  1.6272],
        [ 0.5292, -0.4852,  0.5429,  1.2413],
        [ 1.5484,  1.0803, -2.3830, -0.6018],
        [-0.6238,  0.1711,  1.2688,  0.5511],
        [-0.1328, -1.3113, -0.4525, -0.4474],
        [ 0.7505,  1.4784,  1.8879,  0.0164],
        [-0.0052,  0.6375, -1.6961, -0.0902],
        [ 0.8441, -1.0053, -0.1840, -0.2559],
        [-1.1726,  2.0344, -0.4531, -0.1865],
        [ 2.0633, -0.2314, -0.8438, -0.3874],
        [-0.8828, -0.7985,  0.9833,  1.3108],
        [ 0.8046, -0.0068,  0.1218, -0.1726],
        [-1.9494, -0.9722,  0.8864, -0.4447],
        [ 0.5895, -2.0303,  1.0301, -1.2444],
        [-0.4792, -0.2503,  0.6525,  0.1051],
        [ 0.0525,  1.9913,  0.0883,  0.8583],
        [-0.3125,  1.1683, -0.1469,  0.5906],
        [-0.5531, -0.0248, -0.7794,  1.5271],
        [-0.6502, -0.7264, -0.6477, -0.4103]], requires_grad=True)

In [20]:
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)
embed(x)

tensor([[[ 1.0021, -0.1309,  0.9061,  1.6272],
         [ 0.5292, -0.4852,  0.5429,  1.2413],
         [ 1.5484,  1.0803, -2.3830, -0.6018]],

        [[-0.6238,  0.1711,  1.2688,  0.5511],
         [-0.1328, -1.3113, -0.4525, -0.4474],
         [ 0.7505,  1.4784,  1.8879,  0.0164]]], grad_fn=<EmbeddingBackward>)

In [29]:
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape

torch.Size([2, 1, 6])

In [21]:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    v = embed_v(center)
    u = embed_u(contexts_and_negatives)
    pred = torch.bmm(v, u.permute(0, 2, 1))
    return pred

In [22]:
from torch.nn import functional as F

class SigmoidBinaryCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()

    def forward(self, inputs, targets, mask=None):
        '''
        :param inputs: Tensor shape: (batch_size, len)
        :param targets: Tensor of the same shape as input
        :param mask:
        :return:
        '''
        inputs, targets, mask = inputs.float(), targets.float(), mask.float()
        res = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none", weight=mask)
        return res.mean(dim=1)
loss = SigmoidBinaryCrossEntropyLoss()

In [24]:
pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])

label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1)

tensor([0.8740, 1.2100])

In [26]:
embed_size = 100
net = nn.Sequential(
    nn.Embedding(num_embeddings=int(len(idx2token)), embedding_dim=embed_size),
    nn.Embedding(num_embeddings=int(len(idx2token)), embedding_dim=embed_size)
)

In [30]:
def train(net, lr, num_epochs, device):
    print('train on', device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr = lr)
    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = [d.to(device) for d in batch]

            pred = skip_gram(center, context_negative, net[0], net[1])
            l = (loss(pred.view(label.shape), label, mask) * mask.shape[1] \
                 / mask.float().sum(dim=1)).mean()
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum += l.cpu().item()
            n += 1
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, l_sum / n, time.time() - start))

In [31]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train(net, 0.01, 5, device)

train on cpu
epoch 1, loss 1.97, time 55.50s
epoch 2, loss 0.62, time 55.50s
epoch 3, loss 0.45, time 52.01s
epoch 4, loss 0.40, time 52.01s
epoch 5, loss 0.37, time 55.07s
