In [1]:
import math
import random
import sys
import time
import os
import numpy as np
import torch
from torch import nn
from collections import Counter
import torch.utils.data as Data

import d2l_pytorch as d2l
print(torch.__version__)

1.7.1


In [2]:
with open('Data/data/ptb.train.txt', 'r') as f:
    lines = f.readlines()
    raw_dataset = [st.split() for st in lines]
print('# sentences: %d' % len(raw_dataset))

# sentences: 42068


In [3]:
for st in raw_dataset[:3]:
    print('# token:', len(st), st[:5])

# token: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']
# token: 15 ['pierre', '<unk>', 'N', 'years', 'old']
# token: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']


In [4]:
counter = Counter([tk for st in raw_dataset for tk in st])
counter = dict(filter(lambda x:x[1] >= 5, counter.items()))

In [5]:
idx2token = [tk for tk, _ in counter.items()]
token2idx = {tk:idx for idx, tk in enumerate(idx2token)}
dataset = [[token2idx[tk] for tk in st if tk in token2idx] for st in raw_dataset]
num_token = int(sum([int(len(st)) for st in dataset]))
print('# token: %d' % num_token)

# token: 887100


In [6]:
def discard(idx):
    return random.uniform(0, 1) < 1 - math.sqrt(
        1e-4 / counter[idx2token[idx]] * num_token)

subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
print('# token: %d' % (sum([int(len(st)) for st in subsampled_dataset])))

# token: 375592


In [7]:
def compare_counts(token):
    return '# %s: before=%d, after=%d' % (token,
                int(sum([st.count(token2idx[token]) for st in dataset])),
                int(sum([st.count(token2idx[token]) for st in subsampled_dataset])))

compare_counts('the')

'# the: before=50770, after=2198'

In [8]:
compare_counts('join')

'# join: before=45, after=45'

In [9]:
def get_centers_and_contexts(dataset, max_window_size):
    centers, contexts = [], []
    for st in dataset:
        if len(st) < 2:
            continue
        centers += st
        for center_i in range(len(st)):
            window_size = random.randint(1, max_window_size)
            indices = list(range(max(0, center_i - window_size),
                                 min(int(len(st)), center_i + window_size + 1)))
            indices.remove(center_i)
            contexts.append([st[idx] for idx in indices])
    print(contexts)
    return centers, contexts

In [10]:
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
# get_centers_and_contexts(tiny_dataset, 2)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
[[1], [0, 2], [0, 1, 3, 4], [2, 4], [2, 3, 5, 6], [4, 6], [4, 5], [8, 9], [7, 9], [8]]
center 0 has contexts [1]
center 1 has contexts [0, 2]
center 2 has contexts [0, 1, 3, 4]
center 3 has contexts [2, 4]
center 4 has contexts [2, 3, 5, 6]
center 5 has contexts [4, 6]
center 6 has contexts [4, 5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [8]


In [11]:
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.


In [12]:
def get_negatives(all_contexts, sampling_weights, K):
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        negatives = []
        while int(len(negatives)) < int(len(contexts)) * K:
            if i == int(len(neg_candidates)):
                i, neg_candidates = 0, random.choices(population, sampling_weights, k=int(1e5) )
            neg, i = neg_candidates[i], i + 1

            if neg not in set(contexts):
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

In [15]:
sampling_weights = [counter[w] ** 0.75 for w in idx2token]
all_negatives = get_negatives(all_contexts, sampling_weights, 5)

In [21]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, centers, contexts, negatives):
        assert int(len(centers)) == int(len(contexts)) == int(len(negatives))
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, item):
        return (self.centers[item], self.contexts[item], self.negatives[item])

    def __len__(self):
        return len(self.centers)

In [19]:
def batchify(data):
    max_len = max(len(c)  + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = int(len(context)) + int(len(negative))
        centers += [center]
        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]
        masks += [[1] * cur_len + [0] * (max_len - cur_len)]
        labels += [[1] * int(len(context)) + [0] * (max_len - int(len(context)))]
    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
            torch.tensor(masks), torch.tensor(labels))

In [23]:
batch_size = 512
num_workers = 4

dataset = MyDataset(all_centers,
                    all_contexts,
                    all_negatives)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,
                            collate_fn=batchify,
                            num_workers=num_workers)

for batch in data_iter:
    for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'],
                          batch):
        print(name, 'shape:', data.shape)
    break

centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])


In [25]:
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
embed.weight

Parameter containing:
tensor([[ 0.1019, -0.2441,  0.6801, -0.1640],
        [ 0.5275, -0.3540, -0.9393,  1.6713],
        [-0.8598,  1.2049,  0.1575,  0.8977],
        [ 0.2519, -0.8882,  0.9017,  0.2163],
        [ 0.7473,  1.3696,  1.2898,  0.3270],
        [-2.2228, -0.8946,  1.7971, -1.2397],
        [ 0.7047, -0.2511,  1.3355, -0.2577],
        [ 1.9745,  0.3108,  0.0765, -0.2504],
        [ 0.4052,  1.5951,  0.2573, -0.2861],
        [-0.3335, -0.8295,  0.1491,  0.3676],
        [-0.0200,  0.9105,  0.6291,  0.1826],
        [ 2.0091,  0.8152,  0.8912,  0.8021],
        [ 0.1772, -0.0799, -0.7612,  0.3121],
        [ 0.1588, -0.7941,  0.5032,  0.7222],
        [-1.0796,  0.0996,  1.3165,  1.7903],
        [-1.0000, -0.5204,  0.3152,  0.1836],
        [ 0.2163, -0.0470, -0.3259,  0.7394],
        [ 1.5086, -2.0289, -1.1218, -0.2524],
        [ 0.4902,  0.6436, -1.2170, -2.2503],
        [-0.6141,  0.6577, -0.1521, -0.3868]], requires_grad=True)

In [27]:
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)
embed(x)

tensor([[[ 0.5275, -0.3540, -0.9393,  1.6713],
         [-0.8598,  1.2049,  0.1575,  0.8977],
         [ 0.2519, -0.8882,  0.9017,  0.2163]],

        [[ 0.7473,  1.3696,  1.2898,  0.3270],
         [-2.2228, -0.8946,  1.7971, -1.2397],
         [ 0.7047, -0.2511,  1.3355, -0.2577]]], grad_fn=<EmbeddingBackward>)

In [29]:
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape

torch.Size([2, 1, 6])

In [37]:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    v = embed_v(center)
    u = embed_u(contexts_and_negatives)
    pred = torch.bmm(v, u.permute(0, 2, 1))
    return pred

In [35]:
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()

    def forward(self, inputs, targets, mask=None):
        '''
        :param inputs: Tensor shape: (batch_size, len)
        :param targets: Tensor of the same shape as input
        :param mask:
        :return:
        '''
        inputs, targets, mask = inputs.float(), targets.float(), mask.float()
        res = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none", weight=mask)
        return res.mean(dim=1)
loss = SigmoidBinaryCrossEntropyLoss()

In [36]:
pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])

label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
loss(pred, label, mask) * mask.shape[1] / mask.float().sum(dim=1)

tensor([0.8740, 1.2100])