## A test of word2vec: using skip-gram

Please refer to Mikolov et al. (2013). "Efficient Estimation of Word Representations in Vector Space."

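The skip-gram training objective is to maximize the average log-probability of the context words given each target word: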

<div style="text-align: center;">
    <img src="https://tensorflow.org/text/tutorials/images/word2vec_skipgram_objective.png" width="400">
</div>

where `w_t` is the target (center) word, `T` is the number of words in the training text, and `c` is the window size. The basic skip-gram formulation defines this probability with the softmax function:

<div style="text-align: center;">
    <img src="https://tensorflow.org/text/tutorials/images/word2vec_full_softmax.png" width="400">
</div>

where `v` and `v'` are the target and context vector representations of words and `W` is the vocabulary size. The denominator sums over all `W` words, which is expensive for large vocabularies.
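To make the two formulas concrete, here is a small NumPy sketch that evaluates the full-softmax probability and the average log-probability objective on the 9-token example sentence used below. The vectors are random toy values (an assumption, not trained embeddings), and the helper names `full_softmax_prob` and `skipgram_objective` are hypothetical. Training would adjust `V_in` and `V_out` to make the objective larger, i.e. closer to 0.

```python
import numpy as np

def full_softmax_prob(v_in, V_out, o):
    """p(w_O = o | w_I): softmax of v'_w . v_{w_I} over every word w in the vocabulary."""
    scores = V_out @ v_in                    # (W,): one score per vocabulary word
    scores = scores - scores.max()           # shift for numerical stability; the softmax is unchanged
    exp_scores = np.exp(scores)
    return exp_scores[o] / exp_scores.sum()  # the denominator touches all W words

def skipgram_objective(sequence, c, V_in, V_out):
    """(1/T) * sum_t sum_{-c<=j<=c, j!=0} log p(w_{t+j} | w_t); positions outside the text are skipped."""
    T = len(sequence)
    total = 0.0
    for t in range(T):
        for j in range(-c, c + 1):
            if j == 0 or not 0 <= t + j < T:
                continue
            total += np.log(full_softmax_prob(V_in[sequence[t]], V_out, sequence[t + j]))
    return total / T

# Toy setup (assumption): W = 9 words as in the example vocabulary, random 4-d vectors.
rng = np.random.default_rng(42)
W, dim = 9, 4
V_in = rng.normal(scale=0.1, size=(W, dim))    # target vectors v
V_out = rng.normal(scale=0.1, size=(W, dim))   # context vectors v'
example_sequence = [1, 2, 3, 4, 5, 6, 7, 1, 8]

probs = [full_softmax_prob(V_in[4], V_out, o) for o in range(W)]
print(sum(probs))                                            # 1.0 up to rounding
print(skipgram_objective(example_sequence, 3, V_in, V_out))  # negative; training pushes it toward 0
```

Every call to `full_softmax_prob` loops over the whole vocabulary, which is the cost that negative sampling is designed to avoid.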

### Setup

In [1]:
import io
import re
import string
import tqdm

import numpy as np

import tensorflow as tf
from tensorflow.keras import layers

In [2]:
%load_ext tensorboard

In [3]:
SEED = 42
AUTOTUNE = tf.data.AUTOTUNE

### Init

> The cute ginger cat sits comfortably on the mat.

In [10]:
sentence = "The cute ginger cat sits comfortably on the mat"
tokens = list(sentence.lower().split())
len(tokens)

9

In [11]:
vocab, index = {}, 1  # start indexing from 1

vocab['<pad>'] = 0  # add a padding token
for token in tokens:
  if token not in vocab:
    vocab[token] = index
    index += 1
      
vocab

{'<pad>': 0,
 'the': 1,
 'cute': 2,
 'ginger': 3,
 'cat': 4,
 'sits': 5,
 'comfortably': 6,
 'on': 7,
 'mat': 8}

In [12]:
inverse_vocab = {index: token for token, index in vocab.items()}
inverse_vocab

{0: '<pad>',
 1: 'the',
 2: 'cute',
 3: 'ginger',
 4: 'cat',
 5: 'sits',
 6: 'comfortably',
 7: 'on',
 8: 'mat'}

In [13]:
example_sequence = [vocab[word] for word in tokens]
example_sequence

[1, 2, 3, 4, 5, 6, 7, 1, 8]

### Skip-grams

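The skip-gram model predicts the context words of a given central (target) word. With a window size of `n`, every word within `n` positions of the target is a context word, and each (target, context) pair is a positive training example.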
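A plain-Python sketch of the windowing step, assuming it matches what `tf.keras.preprocessing.sequence.skipgrams` does with `negative_samples=0` apart from shuffling the pairs (the helper name `skipgram_pairs` is hypothetical). It reproduces the 42 pairs counted below for the 9-token example sentence.

```python
def skipgram_pairs(sequence, window_size):
    """All (target, context) pairs whose positions are at most `window_size` apart."""
    pairs = []
    for i, target in enumerate(sequence):
        if target == 0:                          # index 0 is padding, never a target
            continue
        start = max(0, i - window_size)
        end = min(len(sequence), i + window_size + 1)
        for j in range(start, end):
            if j != i and sequence[j] != 0:      # skip the target itself and padding
                pairs.append((target, sequence[j]))
    return pairs

example_sequence = [1, 2, 3, 4, 5, 6, 7, 1, 8]
pairs = skipgram_pairs(example_sequence, window_size=3)
print(len(pairs))    # 42
print(pairs[:3])     # [(1, 2), (1, 3), (1, 4)]
```

Words near the ends of the sentence have fewer neighbors, which is why the per-position counts are 3, 4, 5, 6, 6, 6, 5, 4, 3.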

In [14]:
window_size = 3

In [17]:
# Create positive (target, context) skip-gram pairs
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
      example_sequence,
      vocabulary_size=len(vocab),
      window_size=window_size,
      negative_samples=0)

# Should be 3+4+5+6+6+6+5+4+3=42
len(positive_skip_grams)

42

In [18]:
# A few positive pairs
for target, context in positive_skip_grams[:5]:
  print(f"({target}, {context}): ({inverse_vocab[target]}, {inverse_vocab[context]})")

(4, 7): (cat, on)
(6, 3): (comfortably, ginger)
(3, 5): (ginger, sits)
(4, 5): (cat, sits)
(3, 2): (ginger, cute)


#### Negative sampling

When training skip-gram, the model tries to predict context words given a target word. For each real pair, instead of comparing against all vocabulary words (which is expensive), we sample a few 'fake' (negative) context words and train the model to distinguish the real context word from the fake ones. Here the negatives are drawn with `tf.random.log_uniform_candidate_sampler`, which samples vocabulary indices from a log-uniform (Zipf-like) distribution that favors small indices.
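The TensorFlow documentation for `tf.random.log_uniform_candidate_sampler` gives its base distribution as P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1), which suits vocabularies sorted by decreasing frequency, such as the one `TextVectorization` builds later. The sketch below computes that distribution and draws distinct negatives with NumPy's `Generator.choice`; the sampler is an approximation for illustration, not TensorFlow's algorithm, and the helper names are hypothetical.

```python
import numpy as np

def log_uniform_probs(range_max):
    """P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1), class in [0, range_max)."""
    classes = np.arange(range_max)
    return (np.log(classes + 2) - np.log(classes + 1)) / np.log(range_max + 1)

def sample_negatives(range_max, num_ns, rng):
    """Draw num_ns distinct indices from the log-uniform distribution (NumPy approximation).
    Like the TF op, this does not exclude the true context class."""
    return rng.choice(range_max, size=num_ns, replace=False, p=log_uniform_probs(range_max))

probs = log_uniform_probs(9)                 # the 9-entry example vocabulary
print(probs.round(3))                        # small (frequent) indices are the most likely
print(sample_negatives(9, 5, np.random.default_rng(42)))
```

The probabilities telescope, so they sum to exactly 1 for any `range_max`.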

In [30]:
target_word, context_word = positive_skip_grams[0] # [4, 7] in this example, '4' is the target word 'cat', '7' is the context_word: 'on'.
num_ns = 5

context_class = tf.reshape(tf.constant(context_word, dtype="int64"), (1, 1))
negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
    true_classes=context_class,  # positive ones
    num_true=1,  # each positive skip-gram has 1 positive context class
    num_sampled=num_ns,  # number of negative context words
    unique=True,  # all the negative samples should be unique
    range_max=len(vocab),  # sample indices from [0, len(vocab))
    seed=SEED,  # seed for reproducibility
    name="negative_sampling"  # name of this operation
)

# Note: the sampler does not exclude the true class, so index 7 ('on') could appear here as well.
negative_sampling_candidates

<tf.Tensor: shape=(5,), dtype=int64, numpy=array([2, 8, 0, 6, 5])>

#### The first example

For a given positive pair, we now have one positive context word and `num_ns` negatively sampled context words. Here we concatenate them into a single tensor and build the matching labels.

In [31]:
squeezed_context_class = tf.squeeze(context_class, 1)
context = tf.concat([squeezed_context_class, negative_sampling_candidates], 0)

# Label the first context word as `1` (positive) followed by `num_ns` `0`s (negative).
label = tf.constant([1] + [0]*num_ns, dtype="int64")
target = target_word

In [32]:
print(f"target_index    : {target}")
print(f"target_word     : {inverse_vocab[target_word]}")
print(f"context_indices : {context}")
print(f"context_words   : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label           : {label}")

target_index    : 4
target_word     : cat
context_indices : [7 2 8 0 6 5]
context_words   : ['on', 'cute', 'mat', '<pad>', 'comfortably', 'sits']
label           : [1 0 0 0 0 0]


In [33]:
print("target  :", target)
print("context :", context)
print("label   :", label)

target  : 4
context : tf.Tensor([7 2 8 0 6 5], shape=(6,), dtype=int64)
label   : tf.Tensor([1 0 0 0 0 0], shape=(6,), dtype=int64)


#### Skip-gram sampling table

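Subsampling randomly drops overly frequent words to improve embedding quality. `sampling_table[i]` is the probability of keeping the i-th most frequent word: the most frequent word (index 0) is kept with only about 0.3% probability, so it is almost always dropped, and the keep-probability grows with the index, so rarer words are dropped less often. This lets us drop words such as 'the' that carry little semantic information.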
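The values printed below can be reproduced from the formula in the `make_sampling_table` docstring, p(word) = min(1, sqrt(word_frequency / sampling_factor) / (word_frequency / sampling_factor)), where the word frequency is approximated from the rank with Zipf's law. The sketch re-implements that formula in NumPy; the function name `keep_probabilities` is hypothetical, and `gamma = 0.577` is the rounded Euler-Mascheroni constant, which reproduces the printed values.

```python
import numpy as np

def keep_probabilities(size, sampling_factor=1e-5):
    """Keep-probability for the word at each frequency rank (0 = most frequent)."""
    gamma = 0.577                             # Euler-Mascheroni constant, rounded
    rank = np.arange(size, dtype=np.float64)
    rank[0] = 1                               # rank 0 reuses rank 1's value (log(0) is undefined)
    inv_freq = rank * (np.log(rank) + gamma) + 0.5 - 1.0 / (12.0 * rank)   # Zipf estimate of 1 / frequency
    ratio = sampling_factor * inv_freq        # sampling_factor / word_frequency
    return np.minimum(1.0, np.sqrt(ratio))    # same as min(1, sqrt(f/s) / (f/s))

table = keep_probabilities(100)
print(table[:3])                        # close to the first three values printed below
print(keep_probabilities(100000)[-1])   # very rare words are always kept
```

For instance, `keep_probabilities(4096)[-1]` is roughly 0.60: with `vocab_size = 4096`, even the rarest word in the vocabulary is dropped as a target about 40% of the time, and the keep-probability only reaches 1 for ranks beyond roughly ten thousand.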

In [56]:
# word-frequency rank based probabilistic sampling table. 'sampling_table[i]' denotes the probability of sampling the i-th most common word in a dataset
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(size=100)
print(sampling_table)

[0.00315225 0.00315225 0.00547597 0.00741556 0.00912817 0.01068435
 0.01212381 0.01347162 0.01474487 0.0159558  0.0171136  0.01822533
 0.01929662 0.02033198 0.02133515 0.02230924 0.02325687 0.02418031
 0.02508148 0.02596208 0.02682359 0.02766731 0.02849441 0.02930593
 0.03010279 0.03088585 0.03165585 0.0324135  0.03315943 0.0338942
 0.03461837 0.03533241 0.03603678 0.0367319  0.03741815 0.03809591
 0.0387655  0.03942724 0.04008143 0.04072834 0.04136824 0.04200136
 0.04262794 0.0432482  0.04386234 0.04447055 0.04507302 0.04566992
 0.04626142 0.04684768 0.04742884 0.04800505 0.04857644 0.04914315
 0.04970529 0.05026299 0.05081636 0.0513655  0.05191052 0.05245153
 0.05298861 0.05352186 0.05405136 0.05457721 0.05509948 0.05561824
 0.05613359 0.05664558 0.05715429 0.05765979 0.05816214 0.05866141
 0.05915765 0.05965093 0.06014131 0.06062883 0.06111355 0.06159553
 0.06207481 0.06255144 0.06302548 0.06349696 0.06396593 0.06443243
 0.0648965  0.0653582  0.06581754 0.06627458 0.06672936 0.06718

### Compile everything together


Combine all the steps above into a function that can be called on a list of vectorized sentences obtained from any text dataset. Notice that the sampling table is built before sampling skip-gram word pairs and is passed to `skipgrams`, which uses it to subsample frequent target words.
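Our reading of the Keras source is that `skipgrams` skips a target `wi` when `sampling_table[wi] < random.random()`. The pure-Python sketch below assumes exactly that and adds the check to the plain windowing logic; the helper name `subsampled_skipgram_pairs` is hypothetical.

```python
import random

def subsampled_skipgram_pairs(sequence, window_size, sampling_table, rng):
    """Skip-gram pairs in which each target word is kept with probability sampling_table[target]."""
    pairs = []
    for i, target in enumerate(sequence):
        if target == 0:                              # index 0 is padding, never a target
            continue
        if sampling_table[target] < rng.random():    # small keep-probability: dropped more often
            continue
        start = max(0, i - window_size)
        end = min(len(sequence), i + window_size + 1)
        for j in range(start, end):
            if j != i and sequence[j] != 0:          # skip the target itself and padding
                pairs.append((target, sequence[j]))
    return pairs

example_sequence = [1, 2, 3, 4, 5, 6, 7, 1, 8]
rng = random.Random(42)
print(len(subsampled_skipgram_pairs(example_sequence, 3, [1.0] * 9, rng)))   # 42: every target kept
print(len(subsampled_skipgram_pairs(example_sequence, 3, [0.0] * 9, rng)))   # 0: every target dropped
half = subsampled_skipgram_pairs(example_sequence, 3, [0.5] * 9, rng)        # about half the targets survive
```

Because every target passes through this check, frequent words such as 'the' contribute far fewer training pairs than they would without subsampling.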

In [37]:
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
  targets, contexts, labels = [], [], []

  sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

  # Many sequences (sentences) in the dataset.
  for sequence in tqdm.tqdm(sequences):

    positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence,
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0)

    # For each positive pair, draw `num_ns` negative context words.
    for target_word, context_word in positive_skip_grams:
      context_class = tf.expand_dims(tf.constant([context_word], dtype="int64"), 1)
      negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
          true_classes=context_class,
          num_true=1,
          num_sampled=num_ns,
          unique=True,
          range_max=vocab_size,
          seed=seed,
          name="negative_sampling")

      context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)
      label = tf.constant([1] + [0]*num_ns, dtype="int64")

      targets.append(target_word)
      contexts.append(context)
      labels.append(label)

  return targets, contexts, labels

#### Prepare training data for word2vec

In [38]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
1115394/1115394 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step


Read the text from the file and print the first few lines:

In [44]:
with open(path_to_file) as f:
    lines = f.read().splitlines()
    print(len(lines))
print('\n')
for line in lines[:20]:
    print(line)

40000


First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.


Use the non-empty lines to construct a `tf.data.TextLineDataset` object for the next steps:

In [45]:
text_ds = tf.data.TextLineDataset(path_to_file).filter(lambda x: tf.cast(tf.strings.length(x), bool))

#### Vectorize sentences from the corpus

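The following cells convert the text into integer token IDs, since the model works with token indices rather than raw strings (the same preprocessing is needed for other models such as LSTMs and Transformers). `vocab_size = 4096` caps the vocabulary at 4096 entries, including the padding token `''` and the out-of-vocabulary token `'[UNK]'`, so it keeps the 4094 most frequent words. `sequence_length = 10` pads or truncates every line to exactly 10 token IDs. `text_ds.batch(1024)` only controls how many lines `adapt()` processes at a time while building the vocabulary; it is not the training batch size.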

In [51]:
def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  return tf.strings.regex_replace(lowercase, '[%s]' % re.escape(string.punctuation), '')

vocab_size = 4096
sequence_length = 10

vectorize_layer = layers.TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode='int',
    output_sequence_length=sequence_length)

In [52]:
vectorize_layer.adapt(text_ds.batch(1024))

In [54]:
vectorize_layer(["The cute ginger cat sits comfortably on the mat."])

<tf.Tensor: shape=(1, 10), dtype=int64, numpy=array([[   2,    1,    1, 3783, 1211,    1,   47,    2,    1,    0]])>

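In this output, `1` is the `[UNK]` token, which stands for out-of-vocabulary words such as 'cute', 'ginger', 'comfortably' and 'mat'; `2` is 'the', the most frequent word in the corpus; other values are vocabulary indices (for example, `47` is 'on'). The trailing `0` is padding, added because the sentence has only 9 tokens, fewer than `sequence_length = 10`.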
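The same mapping can be reproduced with a hypothetical pure-Python stand-in for `TextVectorization(output_mode='int')`. It assumes whitespace splitting, the same punctuation-stripping rule as `custom_standardization`, and a tiny made-up vocabulary laid out like `get_vocabulary()` (padding `''` at 0, `'[UNK]'` at 1); it sketches the behavior, not the layer's implementation.

```python
import re
import string

def toy_vectorize(text, vocabulary, sequence_length=10):
    """Hypothetical stand-in for TextVectorization(output_mode='int'):
    0 = padding, 1 = out-of-vocabulary ([UNK]), then one id per known word."""
    lowered = text.lower()
    cleaned = re.sub('[%s]' % re.escape(string.punctuation), '', lowered)  # same rule as custom_standardization
    word_to_id = {word: i for i, word in enumerate(vocabulary)}
    ids = [word_to_id.get(token, 1) for token in cleaned.split()]          # unknown words map to 1
    ids = ids[:sequence_length]                                            # truncate long lines
    return ids + [0] * (sequence_length - len(ids))                        # pad short lines with 0

# Assumption: a tiny vocabulary laid out like get_vocabulary(): '' and '[UNK]' first.
toy_vocab = ['', '[UNK]', 'the', 'cat', 'sits', 'on']
print(toy_vectorize("The cute ginger cat sits comfortably on the mat.", toy_vocab))
# [2, 1, 1, 3, 4, 1, 5, 2, 1, 0]
```

The pattern of `2`s, `1`s and the final `0` matches the real output above; only the IDs of the known words differ because the toy vocabulary is tiny.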

`vectorize_layer.get_vocabulary()` returns all vocabulary tokens sorted by descending frequency, preceded by the padding token `''` and the OOV token `'[UNK]'`.

In [58]:
# Save the created vocabulary for reference.
inverse_vocab = vectorize_layer.get_vocabulary()
print(inverse_vocab[:50])

['', '[UNK]', np.str_('the'), np.str_('and'), np.str_('to'), np.str_('i'), np.str_('of'), np.str_('you'), np.str_('my'), np.str_('a'), np.str_('that'), np.str_('in'), np.str_('is'), np.str_('not'), np.str_('for'), np.str_('with'), np.str_('me'), np.str_('it'), np.str_('be'), np.str_('your'), np.str_('his'), np.str_('this'), np.str_('but'), np.str_('he'), np.str_('have'), np.str_('as'), np.str_('thou'), np.str_('him'), np.str_('so'), np.str_('what'), np.str_('thy'), np.str_('will'), np.str_('no'), np.str_('by'), np.str_('all'), np.str_('king'), np.str_('we'), np.str_('shall'), np.str_('her'), np.str_('if'), np.str_('our'), np.str_('are'), np.str_('do'), np.str_('thee'), np.str_('now'), np.str_('lord'), np.str_('good'), np.str_('on'), np.str_('o'), np.str_('come')]


In [59]:
# Vectorize the data in text_ds.
text_vector_ds = text_ds.batch(1024).prefetch(AUTOTUNE).map(vectorize_layer).unbatch()

In [65]:
# flatten the dataset into a list of sentence vector sequences
sequences = list(text_vector_ds.as_numpy_iterator())
sequences[:20]

[array([ 89, 270,   0,   0,   0,   0,   0,   0,   0,   0]),
 array([138,  36, 982, 144, 673, 125,  16, 106,   0,   0]),
 array([34,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([106, 106,   0,   0,   0,   0,   0,   0,   0,   0]),
 array([ 89, 270,   0,   0,   0,   0,   0,   0,   0,   0]),
 array([   7,   41,   34, 1286,  344,    4,  200,   64,    4, 3690]),
 array([34,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([1286, 1286,    0,    0,    0,    0,    0,    0,    0,    0]),
 array([ 89, 270,   0,   0,   0,   0,   0,   0,   0,   0]),
 array([  89,    7,   93, 1187,  225,   12, 2442,  592,    4,    2]),
 array([34,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([  36, 2655,   36, 2655,    0,    0,    0,    0,    0,    0]),
 array([ 89, 270,   0,   0,   0,   0,   0,   0,   0,   0]),
 array([  72,   79,  506,   27,    3,   56,   24, 1390,   57,   40]),
 array([644,   9,   1,   0,   0,   0,   0,   0,   0,   0]),
 array([34,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([  32,   54, 2863,  885

#### Generate training examples from sequences

Call `generate_training_data` to iterate over every word of every sequence and collect positive and negative context words. `targets`, `contexts` and `labels` should have the same length, which is the total number of training examples.

In [66]:
targets, contexts, labels = generate_training_data(
    sequences=sequences,
    window_size=3,
    num_ns=5,
    vocab_size=vocab_size,
    seed=SEED)

targets = np.array(targets)
contexts = np.array(contexts)
labels = np.array(labels)

print('\n')
print(f"targets.shape: {targets.shape}")
print(f"contexts.shape: {contexts.shape}")
print(f"labels.shape: {labels.shape}")

100%|██████████████████████████████████████████████████████████████████████████| 32777/32777 [00:14<00:00, 2197.85it/s]




targets.shape: (89288,)
contexts.shape: (89288, 6)
labels.shape: (89288, 6)


Build a `tf.data.Dataset` whose elements are `((target, context), label)` tuples, where `context` holds the positive and `num_ns` negative context words, then shuffle and batch it to train the word2vec model.
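As a quick check on the training log below: with a batch size of 1024 and `drop_remainder=True` in the next cell, the 89,288 examples produced in this run give 87 full batches per epoch, matching the `87/87` steps reported by `fit`, and the last 200 examples do not fit into a full batch.

```python
num_examples = 89288                            # targets.shape[0] printed above
batch_size = 1024                               # BATCH_SIZE in the next cell
steps_per_epoch = num_examples // batch_size    # full batches only (drop_remainder=True)
left_out = num_examples % batch_size            # examples that do not fill a final batch
print(steps_per_epoch, left_out)                # 87 200
```

The `tf.data.Dataset.cache` documentation notes that a cached dataset produces the same elements on every iteration, so with `cache()` applied after `shuffle()` the batch order, and the left-out examples, should stay the same across epochs; calling `shuffle()` after `cache()` would reshuffle each epoch.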

In [67]:
BATCH_SIZE = 1024
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
print(dataset)

<_BatchDataset element_spec=((TensorSpec(shape=(1024,), dtype=tf.int64, name=None), TensorSpec(shape=(1024, 6), dtype=tf.int64, name=None)), TensorSpec(shape=(1024, 6), dtype=tf.int64, name=None))>


In [68]:
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)
print(dataset)

<_PrefetchDataset element_spec=((TensorSpec(shape=(1024,), dtype=tf.int64, name=None), TensorSpec(shape=(1024, 6), dtype=tf.int64, name=None)), TensorSpec(shape=(1024, 6), dtype=tf.int64, name=None))>


### Training

The word2vec model can be implemented as a classifier that distinguishes true context words from skip-grams from false context words obtained through negative sampling. Taking the dot product between the embeddings of the target and context words gives a prediction (logit) for each label, and the loss is computed against the true labels in the dataset.
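With `from_logits=True`, `CategoricalCrossentropy` applies a softmax over the `1 + num_ns` dot products of each example and takes the negative log-probability of the positive context word. A NumPy sketch of that computation (the helper name is hypothetical) also explains the first training epoch: with freshly initialized embeddings the dot products are close to 0, all six candidates get probability about 1/6, and the loss is about ln 6 ≈ 1.79, consistent with the epoch-1 loss of 1.7899 in the log below.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Categorical cross-entropy computed from raw logits, as with from_logits=True."""
    shifted = logits - logits.max()                        # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())    # log-softmax over the 1 + num_ns candidates
    return -(labels * log_probs).sum()

num_ns = 5
one_hot_labels = np.array([1.0] + [0.0] * num_ns)   # positive context first, then the negatives

untrained = np.zeros(num_ns + 1)                     # near-zero embeddings give near-zero dot products
print(softmax_cross_entropy(untrained, one_hot_labels), np.log(num_ns + 1))   # both ≈ 1.79

confident = np.array([5.0, -5.0, -5.0, -5.0, -5.0, -5.0])   # positive scored far above negatives
print(softmax_cross_entropy(confident, one_hot_labels))        # close to 0
```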

#### Subclassed word2vec model

In the `Word2Vec` model below:

`target_embedding` looks up the embedding of a word when it appears as a target word. This layer has `vocab_size * embedding_dim` parameters.

`context_embedding` looks up the embedding of a word when it appears as a context word. It has the same number of parameters as `target_embedding`.

Conceptually, this is a network with one hidden layer: `target_embedding` maps the one-hot encoding of the target word to its embedding (the hidden layer), and `context_embedding` acts as the output weights that score each candidate context word. The scores are then compared with the labels (`1` for the true context word, `0` for the negative samples).

`tf.einsum('be,bce->bc', ...)` computes the dot product between the target embedding and each context embedding of an example, producing one logit per candidate context word with shape `(batch, 1 + num_ns)`.

`call()` accepts `(target, context)` pairs, squeezes a target of shape `(batch, 1)` to `(batch,)`, and passes each element to its corresponding embedding layer.
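Shape by shape, `Word2Vec.call` can be mirrored in NumPy: plain arrays stand in for the two `layers.Embedding` layers (their toy sizes and the uniform(-0.05, 0.05) initialization are assumptions for illustration), indexing plays the role of the embedding lookup, and `np.einsum` uses the same `'be,bce->bc'` subscripts as the model.

```python
import numpy as np

# Toy sizes (assumption); the notebook uses vocab_size=4096 and embedding_dim=300.
toy_vocab_size, toy_dim = 9, 4
rng = np.random.default_rng(0)
# Plain lookup tables standing in for the two layers.Embedding layers.
target_table = rng.uniform(-0.05, 0.05, size=(toy_vocab_size, toy_dim))
context_table = rng.uniform(-0.05, 0.05, size=(toy_vocab_size, toy_dim))

toy_target = np.array([4, 3])                     # (batch,): 'cat', 'ginger'
toy_context = np.array([[7, 2, 8, 0, 6, 5],       # (batch, 1 + num_ns): positive first
                        [5, 1, 6, 8, 2, 0]])

word_emb = target_table[toy_target]               # (batch, embed)
context_emb = context_table[toy_context]          # (batch, context, embed)
dots = np.einsum('be,bce->bc', word_emb, context_emb)   # (batch, context): one logit per candidate
print(dots.shape)                                 # (2, 6)
```

Each entry `dots[b, c]` equals the ordinary dot product of the target vector of example `b` with its `c`-th context vector; `einsum` simply computes all of them in one call.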

In [69]:
class Word2Vec(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim):
    super(Word2Vec, self).__init__()
    self.target_embedding = layers.Embedding(vocab_size, embedding_dim, name="w2v_embedding")
    self.context_embedding = layers.Embedding(vocab_size, embedding_dim)

  def call(self, pair):
    target, context = pair
    if len(target.shape) == 2:
      target = tf.squeeze(target, axis=1)
    # target: (batch,)
    word_emb = self.target_embedding(target)
    # word_emb: (batch, embed)
    context_emb = self.context_embedding(context)
    # context_emb: (batch, context, embed)
    dots = tf.einsum('be,bce->bc', word_emb, context_emb)
    # dots: (batch, context)
    return dots

#### Define loss function and compile model


In [70]:
embedding_dim = 300
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam',
                 loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])

In [71]:
# Log training statistics for TensorBoard
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

In [72]:
word2vec.fit(dataset, epochs=50, callbacks=[tensorboard_callback])

Epoch 1/50
87/87 ━━━━━━━━━━━━━━━━━━━━ 5s 29ms/step - accuracy: 0.1907 - loss: 1.7899
Epoch 2/50
87/87 ━━━━━━━━━━━━━━━━━━━━ 3s 33ms/step - accuracy: 0.5677 - loss: 1.7312
Epoch 3/50
87/87 ━━━━━━━━━━━━━━━━━━━━ 3s 33ms/step - accuracy: 0.4730 - loss: 1.6179
Epoch 4/50
87/87 ━━━━━━━━━━━━━━━━━━━━ 2s 22ms/step - accuracy: 0.5071 - loss: 1.4885
Epoch 5/50
87/87 ━━━━━━━━━━━━━━━━━━━━ 3s 30ms/step - accuracy: 0.5728 - loss: 1.3499
Epoch 6/50
87/87 ━━━━━━━━━━━━━━━━━━━━ 3s 31ms/step - accuracy: 0.6382 - loss: 1.2139
Epoch 7/50
87/87 ━━━━━━━━━━━━━━━━━━━━ 3s 32ms/step - accuracy: 0.6914 - loss: 1.0867
Epoch 8/50
87/87 ━━━━━━━━━━━━━━━━━━━━ 3s 33ms/step - accuracy: 0.7335 - loss: 0.9711
Epoch 9/50
87/87 ━━━━━━━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x16a8ceabbe0>

In [73]:
%tensorboard --logdir logs


### Embedding lookup and analysis

Obtain the weights from the model and the vocabulary from `vectorize_layer`, then write them to `vectors.tsv` (one embedding per line) and `metadata.tsv` (one token per line).

In [74]:
weights = word2vec.get_layer('w2v_embedding').get_weights()[0]
vocab = vectorize_layer.get_vocabulary()

In [75]:
out_v = io.open('vectors.tsv', 'w', encoding='utf-8')
out_m = io.open('metadata.tsv', 'w', encoding='utf-8')

for index, word in enumerate(vocab):
  if index == 0:
    continue  # skip 0, it's padding.
  vec = weights[index]
  out_v.write('\t'.join([str(x) for x in vec]) + "\n")
  out_m.write(word + "\n")
out_v.close()
out_m.close()

Load `vectors.tsv` and `metadata.tsv` into the [Embedding Projector](https://projector.tensorflow.org/) to visualize and analyze the obtained embeddings.

### Using the embeddings

In [76]:
word2vec.target_embedding.weights

[<Variable path=word2_vec/w2v_embedding/embeddings, shape=(4096, 300), dtype=float32, value=[[-0.02471266 -0.0074613  -0.02244885 ... -0.0293352  -0.04548246
   -0.00054247]
  [ 0.05865565  0.29671758  0.08005863 ... -0.04820159  0.20211284
   -0.24353975]
  [ 0.03541387  0.10301269 -0.31272495 ... -0.01232275  0.05818801
    0.1957827 ]
  ...
  [-0.1862249  -0.1937608   0.03388905 ...  0.31373867 -0.12302848
    0.00505144]
  [-0.1522535  -0.10752758  0.27593774 ...  0.21866505  0.06104349
    0.20777278]
  [-0.16917834  0.07915179  0.09426652 ... -0.11175995 -0.03594727
    0.1652813 ]]>]

In [77]:
embedding_matrix = word2vec.get_layer("w2v_embedding").get_weights()[0]

In [78]:
def embed(word):
    index = inverse_vocab.index(word)  # find word index
    return embedding_matrix[index]     # return its embedding vector

In [80]:
v1 = embed("king")
v2 = embed("queen")
v3 = embed("woman")
v4 = embed("man")

In [85]:
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

In [88]:
# Analogy test: how close is king - man + woman to queen?
cosine_similarity((v1+v3-v4), v2)

np.float32(0.10787292)

In [89]:
def most_similar(query_word, top_k=10):
    query_vec = embed(query_word)
    sims = np.dot(embedding_matrix, query_vec) / (
        np.linalg.norm(embedding_matrix, axis=1) * np.linalg.norm(query_vec)
    )
    top_indices = sims.argsort()[-top_k:][::-1]
    return [(inverse_vocab[i], sims[i]) for i in top_indices]

most_similar("king")

[(np.str_('king'), np.float32(1.0000001)),
 (np.str_('richard'), np.float32(0.33785653)),
 (np.str_('conveyd'), np.float32(0.32946467)),
 (np.str_('xi'), np.float32(0.32419485)),
 (np.str_('isabel'), np.float32(0.31823367)),
 (np.str_('was'), np.float32(0.3180876)),
 (np.str_('iv'), np.float32(0.31232157)),
 (np.str_('ii'), np.float32(0.28394505)),
 (np.str_('killed'), np.float32(0.2832326)),
 (np.str_('storm'), np.float32(0.2812345))]