In [1]:
import os
import math
import torch
import random
import collections
from typing import Optional, Union

In [22]:
# NOTE(review): interactive namespace inspection left over from a later point
# in the session (execution count 22 appears before cell 2). On a fresh
# kernel this cell would list only the imports; consider removing it.
%whos

Variable                   Type               Data/Info
-------------------------------------------------------
Optional                   _SpecialForm       typing.Optional
RandomGenerator            type               <class '__main__.RandomGenerator'>
Union                      _SpecialForm       typing.Union
Vocab                      type               <class '__main__.Vocab'>
all_centers                list               n=343109
all_contexts               list               n=343109
all_negatives              list               n=343109
batch                      tuple              n=4
batchify                   function           <function batchify at 0x0000014964F22840>
center                     int                9
collections                module             <module 'collections' fro<...>ollections\\__init__.py'>
compare_counts             function           <function compare_counts at 0x0000014961CD1DA0>
context                    list               n=2
corpus            

In [2]:
def read_ptb(data_dir: str = "../../Data/ptb/") -> list[list[str]]:
    """
    Load the PTB training split as a list of tokenized lines.

    Args:
        data_dir: Directory that contains 'ptb.train.txt'. Defaults to the
            notebook's original relative location.

    Returns:
        One list of whitespace-separated tokens per '\\n'-delimited line of
        the file. A trailing newline in the file therefore yields a final
        empty list.
    """
    with open(os.path.join(data_dir, 'ptb.train.txt')) as f:
        raw_text = f.read()
    # Split on '\n' explicitly (not splitlines) so the line count, including
    # a possible trailing empty entry, matches the original behaviour.
    return [line.split() for line in raw_text.split('\n')]

In [3]:
# Load the PTB training split; each element of `sentences` is the token list
# of one line of the file.
sentences = read_ptb()
# NOTE(review): the label below is misspelled ("senteces"); left unchanged
# here so the code still matches the recorded output.
print(f"senteces: {len(sentences)}")

senteces: 42069


In [4]:
class Vocab:
    """
    Vocabulary mapping between tokens and integer indices.

    Token frequencies are counted over `tokens`; tokens whose frequency is
    below `min_freq` are discarded. The surviving tokens, together with
    '<unk>' and any `reserved_tokens`, are assigned indices in sorted
    (lexicographic) order, so '<unk>' is NOT guaranteed to be index 0.

    Attributes:
        token_freqs: (token, count) pairs sorted by descending count.
        idx_to_token: list of tokens; position is the token's index.
        token_to_idx: dict mapping each token to its index.
    """
    def __init__(
        self,
        tokens: Optional[Union[list[str], list[list[str]]]] = None,
        min_freq: int = 0,
        reserved_tokens: Optional[list[str]] = None
    ):
        # Use None sentinels instead of mutable default arguments.
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                  reverse=True)
        # The list of unique tokens, in sorted order
        self.idx_to_token = sorted(set(['<unk>'] + list(reserved_tokens) + [
            token for token, freq in self.token_freqs if freq >= min_freq]))
        self.token_to_idx = {token: idx
                             for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        """Return the number of distinct tokens (including '<unk>')."""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """
        Map a token, or a (possibly nested) list/tuple of tokens, to indices.
        Unknown tokens map to the index of '<unk>'.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """
        Map an index, or a sequence of indices, back to token(s).

        A list or tuple of any length (including 0 or 1) yields a list of
        tokens. Other sized objects with more than one element are also
        treated as sequences; anything else is used directly as one index.
        """
        # Lists/tuples are always sequences; previously a list of length 0 or
        # 1 fell through to `list[list]` indexing and raised TypeError.
        if isinstance(indices, (list, tuple)):
            return [self.idx_to_token[int(index)] for index in indices]
        if hasattr(indices, '__len__') and len(indices) > 1:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[indices]

    @property
    def unk(self):  # Index for the unknown token
        return self.token_to_idx['<unk>']

In [5]:
# Build the vocabulary, keeping only tokens that occur at least 10 times.
vocab = Vocab(sentences, min_freq=10)
print(f"vocab size: {len(vocab)}")

vocab size: 6719


In [6]:
# Inspect the vocabulary: index of '<unk>', the first three tokens in index
# order, the three most and three least frequent (token, count) pairs, and
# the index assigned to '#'.
vocab.unk, vocab.idx_to_token[:3], vocab.token_freqs[:3], vocab.token_freqs[-3:], vocab.token_to_idx['#']

(26,
 ['#', '$', '&'],
 [('the', 50770), ('<unk>', 45020), ('N', 32481)],
 [('flat-rolled', 1), ('biscuits', 1), ('isi', 1)],
 0)

In [7]:
def subsample(sentences, vocab, threshold: float = 1e-4):
    """
    Subsample high-frequency words.

    Tokens that map to the vocabulary's unknown index are removed first.
    Each remaining token occurrence is then kept with probability
    min(1, sqrt(threshold / f(w))), where f(w) is the token's relative
    frequency among the remaining tokens.

    Args:
        sentences: list of token lists.
        vocab: object supporting `vocab[token]` and exposing `vocab.unk`.
        threshold: subsampling constant; larger values keep more tokens.

    Returns:
        (subsampled_sentences, counter) where `counter` holds the token
        counts after unknown-token removal but before random subsampling.
    """
    # Look up the unknown index once instead of once per token.
    unk_idx = vocab.unk
    # Exclude unknown tokens ('<unk>')
    sentences = [[token for token in line if vocab[token] != unk_idx]
                 for line in sentences]
    counter = collections.Counter(
        token for line in sentences for token in line)
    num_tokens = sum(counter.values())

    # Return True if `token` is kept during subsampling
    def keep(token):
        return (random.uniform(0, 1) <
                math.sqrt(threshold / counter[token] * num_tokens))

    return ([[token for token in line if keep(token)] for line in sentences],
            counter)

In [8]:
# Subsample the corpus. `counter` holds per-token counts computed after
# '<unk>' removal (but before the random dropping of frequent tokens).
subsampled, counter = subsample(sentences, vocab)
print(f"{len(subsampled)}, {len(counter)}")

42069, 6718


In [9]:
def compare_counts(token):
    """
    Report how often `token` occurs in the notebook-level `sentences`
    (before subsampling) and `subsampled` (after subsampling).
    """
    before = sum(line.count(token) for line in sentences)
    after = sum(line.count(token) for line in subsampled)
    return f'{token}: before={before}, after={after}'

# A very frequent word should be heavily thinned out by subsampling.
compare_counts('the')

'the: before=50770, after=2083'

In [10]:
# Map every subsampled sentence from tokens to vocabulary indices, then peek
# at the first three sentences.
corpus = [vocab[line] for line in subsampled]
corpus[:3]

[[], [4127, 3228], [3922, 1922, 4743]]

In [11]:
def get_centers_and_contexts(corpus, max_window_size):
    """
    Build skip-gram training pairs from an index-encoded corpus.

    For every position in every sentence of length >= 2, a window radius is
    drawn uniformly from 1..max_window_size; the words within that radius
    (excluding the position itself) form the context of the center word.

    Returns:
        (centers, contexts): `centers` is a flat list of center words and
        `contexts[k]` is the list of context words for `centers[k]`.
    """
    centers, contexts = [], []
    for sentence in corpus:
        length = len(sentence)
        # A sentence with fewer than two words cannot yield any
        # center/context pair.
        if length < 2:
            continue
        centers.extend(sentence)
        for pos in range(length):
            radius = random.randint(1, max_window_size)
            start = max(0, pos - radius)
            stop = min(length, pos + 1 + radius)
            # Every position in the window except the center itself.
            contexts.append(
                [sentence[j] for j in range(start, stop) if j != pos])
    return centers, contexts

In [12]:
# Toy corpus of two "sentences" (7 and 3 word indices) to illustrate the
# randomly sized context windows (maximum window size 2).
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2]
center 2 has contexts [1, 3]
center 3 has contexts [1, 2, 4, 5]
center 4 has contexts [2, 3, 5, 6]
center 5 has contexts [3, 4, 6]
center 6 has contexts [5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [7, 8]


In [13]:
# Extract center words and their context lists from the real corpus with a
# maximum window size of 5, then report the total number of
# (center, context) pairs.
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'center-context pairs: {sum([len(contexts) for contexts in all_contexts])}'

'center-context pairs: 1503713'

In [14]:
class RandomGenerator:
    """
    Randomly draw among {1, ..., n} according to n sampling weights.

    Draws are produced in batches with `random.choices` and served one at a
    time from the cached batch, which is refilled whenever it is exhausted.

    Args:
        sampling_weights: sequence of n non-negative relative weights;
            value k (1-based) is drawn with weight sampling_weights[k - 1].
        cache_size: number of samples generated per refill (default 10000).

    Raises:
        ValueError: if `cache_size` is smaller than 1.
    """
    def __init__(self, sampling_weights, cache_size: int = 10000):
        if cache_size < 1:
            raise ValueError("cache_size must be at least 1")
        # Values are 1-based: 1, 2, ..., n
        self.population = list(range(1, len(sampling_weights) + 1))
        self.sampling_weights = sampling_weights
        self.cache_size = cache_size
        self.candidates = []
        self.i = 0  # Position of the next unread cached sample

    def draw(self):
        """Return the next sample, refilling the cache when it runs out."""
        if self.i == len(self.candidates):
            # Cache `cache_size` random sampling results
            self.candidates = random.choices(
                self.population, self.sampling_weights, k=self.cache_size)
            self.i = 0
        self.i += 1
        return self.candidates[self.i - 1]

In [15]:
# With weights [2, 3, 4] the draw probabilities are
# P(x=1) = 2/9, P(x=2) = 3/9, P(x=3) = 4/9.
generator = RandomGenerator([2, 3, 4])
[generator.draw() for _ in range(10)]

[1, 3, 2, 1, 2, 2, 3, 2, 2, 1]

In [16]:
def get_negatives(all_contexts, vocab, counter, K):
    """
    Return noise words in negative sampling.

    For each list of context words, draws K noise words per context word.
    Noise words are vocabulary indices sampled with probability proportional
    to the token's count in `counter` raised to the power 0.75; a drawn
    index that appears in the context list is rejected and redrawn.

    Args:
        all_contexts: list of context-word index lists.
        vocab: vocabulary supporting `len(vocab)` and `vocab.to_tokens(i)`.
        counter: token counts; must return 0 for missing tokens (e.g. a
            `collections.Counter`).
        K: number of noise words per context word.

    Returns:
        A list with one list of noise-word indices per entry of
        `all_contexts`.
    """
    # Weight EVERY vocabulary index. The vocabulary's index order does not
    # place '<unk>' at 0, so skipping index 0 would wrongly exclude a real
    # token. Tokens absent from `counter` (such as '<unk>' when it was
    # filtered out before counting) get weight 0 and are never drawn.
    sampling_weights = [counter[vocab.to_tokens(i)]**0.75
                        for i in range(len(vocab))]
    all_negatives, generator = [], RandomGenerator(sampling_weights)
    for contexts in all_contexts:
        excluded = set(contexts)
        num_needed = len(contexts) * K
        negatives = []
        while len(negatives) < num_needed:
            # RandomGenerator draws from {1, ..., n}; shift to the 0-based
            # vocabulary index.
            neg = generator.draw() - 1
            # Noise words cannot be context words
            if neg not in excluded:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

# Draw K=5 noise words per context word for every entry of `all_contexts`.
all_negatives = get_negatives(all_contexts, vocab, counter, 5)

In [19]:
# Sanity check: there is one negatives list per contexts list; also show the
# first entry of each.
len(all_negatives), all_negatives[0], len(all_contexts), all_contexts[0]

(343109, [29, 2441, 788, 2282, 2131], 343109, [3228])

In [20]:
def batchify(data):
    """
    Collate (center, context, negative) examples into padded tensors.

    Each example's context and noise words are concatenated and right-padded
    with 0 to the longest combined length in the batch.

    Returns:
        A 4-tuple of tensors:
        centers (batch, 1), contexts_negatives (batch, max_len),
        masks (1 for real entries, 0 for padding) and
        labels (1 for context words, 0 for noise words and padding).
    """
    longest = max(len(ctx) + len(neg) for _, ctx, neg in data)
    center_ids, ctx_neg_rows, mask_rows, label_rows = [], [], [], []
    for center, ctx, neg in data:
        num_ctx = len(ctx)
        num_real = num_ctx + len(neg)
        num_pad = longest - num_real
        center_ids.append(center)
        ctx_neg_rows.append(ctx + neg + [0] * num_pad)
        mask_rows.append([1] * num_real + [0] * num_pad)
        label_rows.append([1] * num_ctx + [0] * (longest - num_ctx))
    centers = torch.tensor(center_ids).reshape((-1, 1))
    return (centers, torch.tensor(ctx_neg_rows),
            torch.tensor(mask_rows), torch.tensor(label_rows))

In [21]:
# Two toy examples of the form (center, context words, noise words) with
# different context/noise lengths, to show how `batchify` pads and masks.
x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3])
batch = batchify((x_1, x_2))

# Print each returned tensor next to its name.
names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
    print(name, '=', data)

centers = tensor([[1],
        [1]])
contexts_negatives = tensor([[2, 2, 3, 3, 3, 3],
        [2, 2, 2, 3, 3, 0]])
masks = tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0]])
labels = tensor([[1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0]])
