In [1]:
import math
import os
import random
import torch
from d2l import torch as d2l
import requests
import hashlib
import zipfile
import collections
import errno
import numpy as np
from six.moves import urllib
from six.moves import xrange

In [2]:
data_dir= "word2vec_data/words/words"
data_url= "http://mattmahoney.net/dc/text8.zip"

In [3]:
def fetch_words_data(url=data_url, words_data=data_dir):
    os.makedirs(words_data, exist_ok=True)
    zip_path= os.path.join(words_data, "words.zip")
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as f:
        data=f.read(f.namelist()[0])
    return data.decode("ascii").split()
sentences=fetch_words_data()
f'# sentences: {len(sentences)}'

'# sentences: 17005207'

In [4]:
class Vocab:
    def __init__(self, tokens=[], min_frequency=0, reserved_tokens=[]):

        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        counter = collections.Counter(tokens)
        self.token_frequency = sorted(counter.items(), key=lambda x: x[1],reverse=True)
        self.idx_to_token = list(sorted(set(['<unk>'] + reserved_tokens + [token for token, frequency in self.token_frequency if frequency >= min_frequency])))
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}
    def __len__(self):
        return len(self.idx_to_token)
    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]
    def to_tokens(self, indices):
        if hasattr(indices, '__len__') and len(indices) > 1:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[indices]
    @property
    def unk(self): #for uknown token
        return self.token_to_idx['<unk>']

In [6]:
vocab = Vocab(sentences, min_frequency=10)
f'vocab size: {len(vocab)}'

'vocab size: 47135'

In [7]:
def corpus_count(tokens):
    if len(tokens) == 0 or isinstance(tokens[0], list):
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)

In [9]:
def subsampling(sentences, vocab):
    sentences = [[token for token in line if vocab[token] != vocab.unk] for line in sentences]
    counter = corpus_count(sentences)
    num_tokens = sum(counter.values())
    def keep(token):
        return (random.uniform(0, 1) < math.sqrt(1e-4 / counter[token] * num_tokens))
    return ([[token for token in line if keep(token)] for line in sentences], counter)
subsampled, counter = subsampling(sentences, vocab)

In [15]:
corpus = [vocab[line] for line in subsampled]
corpus[:4]

[[], [], [], []]

In [11]:
def get_centers_and_contexts(corpus, max_window_size):
    centers, contexts = [], []
    for line in corpus:
        if len(line) < 2:
            continue
        centers += line
        for i in range(len(line)): 
            window_size = max_window_size 
            indices = list(range(max(0, i - window_size),min(len(line), i + 1 + window_size)))
            indices.remove(i)
            contxt = [line[idx] for idx in indices]
            contexts.append(contxt)
    return centers, contexts

In [13]:
small_dataset = [list(range(6)), list(range(6, 9))]
print('dataset', small_dataset)
for center, context in zip(*get_centers_and_contexts(small_dataset, 2)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5], [6, 7, 8]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2, 3]
center 2 has contexts [0, 1, 3, 4]
center 3 has contexts [1, 2, 4, 5]
center 4 has contexts [2, 3, 5]
center 5 has contexts [3, 4]
center 6 has contexts [7, 8]
center 7 has contexts [6, 8]
center 8 has contexts [6, 7]


In [14]:
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'# center-context pairs: {sum([len(contexts) for contexts in all_contexts])}'

'# center-context pairs: 919980'

In [18]:
#Negative Sampling
class RandomGenerator:
    def __init__(self, sampling_weights):
        self.population = list(range(1, len(sampling_weights) + 1))
        self.sampling_weights = sampling_weights
        self.candidates = []
        self.i = 0
    def draw(self):
        if self.i == len(self.candidates):
            self.candidates = random.choices(self.population, self.sampling_weights, k=10000)
            self.i = 0
        self.i += 1
        return self.candidates[self.i - 1]

In [19]:
def get_negatives(all_contexts, vocab, counter, N):
    sampling_weights = [ counter[vocab.to_tokens(i)]**0.75 for i in range(1, len(vocab))]
    all_negatives, generator = [], RandomGenerator(sampling_weights)
    for contexts in all_contexts:
        negatives = []
        while len(negatives) < len(contexts) * N:
            neg = generator.draw()
            if neg not in contexts:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives
all_negatives = get_negatives(all_contexts, vocab, counter, 5)

In [20]:
def generate_batches(data):
    max_len = max(len(c) + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        centers += [center]
        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]
        masks += [[1] * cur_len + [0] * (max_len - cur_len)]
        labels += [[1] * len(context) + [0] * (max_len - len(context))]
    return (torch.tensor(centers).reshape((-1, 1)), torch.tensor(contexts_negatives), torch.tensor(masks),torch.tensor(labels))

In [21]:
x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3])
batch = generate_batches((x_1, x_2))
names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
    print(name, '=', data)

centers = tensor([[1],
        [1]])
contexts_negatives = tensor([[2, 2, 3, 3, 3, 3],
        [2, 2, 2, 3, 3, 0]])
masks = tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0]])
labels = tensor([[1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0]])


In [24]:
def words_data_load(batch_size, max_window_size, num_noise_words):
    num_workers = d2l.get_dataloader_workers()
    sentences = fetch_words_data()
    vocab = Vocab(sentences, min_frequency=10)
    subsampled, counter = subsampling(sentences, vocab)
    corpus = [vocab[line] for line in subsampled]
    all_centers, all_contexts = get_centers_and_contexts(corpus, max_window_size)
    all_negatives = get_negatives(all_contexts, vocab, counter,num_noise_words)
    class Words(torch.utils.data.Dataset):
        def __init__(self, centers, contexts, negatives):
            assert len(centers) == len(contexts) == len(negatives) #assert in front
            self.centers = centers
            self.contexts = contexts
            self.negatives = negatives
        def __getitem__(self, index):
            return (self.centers[index], self.contexts[index], self.negatives[index])
        def __len__(self):
            return len(self.centers)
    dataset = Words(all_centers, all_contexts, all_negatives)
    data_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, collate_fn=generate_batches, num_workers=num_workers)
    return data_iter, vocab

In [25]:
import math
import torch
from torch import nn
from d2l import torch as d2l
batch_size, max_window_size, num_noise_words = 422, 4, 5
data_iter, vocab = words_data_load(batch_size, max_window_size,num_noise_words)

In [26]:
embedd = nn.Embedding(num_embeddings=20, embedding_dim=4)
print(f'Parameter embedding_weight ({embedd.weight.shape}, 'f'dtype={embedd.weight.dtype})')

Parameter embedding_weight (torch.Size([20, 4]), dtype=torch.float32)


In [27]:
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
embedd(x)

tensor([[[-0.4135, -0.6137, -0.0303,  1.2562],
         [-1.0078, -2.1746,  0.6127, -0.1101],
         [ 0.4498,  0.0501,  0.4286, -0.5150]],

        [[ 1.1400,  0.1930,  0.6193, -0.5477],
         [ 0.3997, -0.3701, -0.5582,  0.3698],
         [ 1.4657,  1.4002, -0.0736, -0.6114]]], grad_fn=<EmbeddingBackward0>)

In [28]:
def skip_gram(center, contexts_and_negatives, embedd_v, embedd_u):
    v = embedd_v(center)
    u = embedd_u(contexts_and_negatives)
    pred = torch.bmm(v, u.permute(0, 2, 1))
    return pred

In [29]:
skip_gram(torch.ones((2, 1), dtype=torch.long),torch.ones((2, 4), dtype=torch.long), embedd, embedd).shape

torch.Size([2, 1, 4])

In [30]:
def sigmd(x):
    return -math.log(1 / (1 + math.exp(-x)))
print(f'{(sigmd(1.1) + sigmd(2.2) + sigmd(-3.3) + sigmd(4.4)) / 4:.4f}')
print(f'{(sigmd(-1.1) + sigmd(-2.2)) / 2:.4f}')

0.9352
1.8462


In [31]:
embedd_size = 50
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab), embedding_dim=embedd_size),nn.Embedding(num_embeddings=len(vocab), embedding_dim=embedd_size))

In [32]:
def train(net, data_iter, lr, num_epochs, device=d2l.try_gpu()):
    def init_weights(m):
        if type(m) == nn.Embedding:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[1, num_epochs])
    metric = d2l.Accumulator(2)
    for epoch in range(num_epochs):
        timer, num_batches = d2l.Timer(), len(data_iter)
        for i, batch in enumerate(data_iter):
            optimizer.zero_grad()
            center, context_negative, mask, label = [data.to(device) for data in batch]
            pred = skip_gram(center, context_negative, net[0], net[1])
            l = (loss(pred.reshape(label.shape).float(), label.float(), mask) / mask.sum(axis=1) * mask.shape[1]) 
            l.sum().backward()
            optimizer.step()
            metric.add(l.sum(), l.numel())
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, 'f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')

In [None]:
lr, num_epochs= 0.002, 5
train(net, data_iter,lr,num_epochs)

In [37]:
def get_similar_tokens(known_token, k, embedd):
    W = embedd.weight.data
    x = W[vocab[known_token]]
    cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)
    topk = torch.topk(cos, k=k + 1)[1].cpu().numpy().astype('int32')
    for i in topk[1:]: 
        print(f'cosine sim={float(cos[i]):.3f}: {vocab.to_tokens(i)}')

In [38]:
get_similar_tokens('baby', 3, net[0])

cosine sim=0.518: interglacial
cosine sim=0.508: dementia
cosine sim=0.506: junius
