https://www.tensorflow.org/text/tutorials/word2vec#compile_all_steps_into_one_function

In [2]:
import io
import re
import string
import tqdm

import numpy as np

import tensorflow as tf
from tensorflow.keras import layers

In [2]:
%load_ext tensorboard

In [12]:
SEED = 42
AUTOTUNE = tf.data.AUTOTUNE

In [4]:
sentence = "The wide road shimmered in the hot sun"
tokens = list(sentence.lower().split())
len(tokens)

8

In [5]:
vocab, index = {}, 1

vocab['<pad>'] = 0
for token in tokens:
    if token not in vocab:
        vocab[token] = index
        index += 1

vocab_size = len(vocab)
vocab_size

8

In [6]:
inverse_vocab = {index: token for token, index in vocab.items()}
inverse_vocab

{0: '<pad>',
 1: 'the',
 2: 'wide',
 3: 'road',
 4: 'shimmered',
 5: 'in',
 6: 'hot',
 7: 'sun'}

In [7]:
example_sequence = [vocab[word] for word in tokens]
print(example_sequence)

[1, 2, 3, 4, 5, 1, 6, 7]


In [11]:
window_size = 2
positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
      example_sequence,
      vocabulary_size=vocab_size,
      window_size=window_size,
      negative_samples=0,
      seed=SEED,)
print(len(positive_skip_grams))

26


In [15]:
[[inverse_vocab[i], inverse_vocab[j]] for i, j in positive_skip_grams]

[['in', 'hot'],
 ['shimmered', 'the'],
 ['shimmered', 'wide'],
 ['sun', 'hot'],
 ['the', 'hot'],
 ['road', 'wide'],
 ['road', 'the'],
 ['shimmered', 'road'],
 ['in', 'the'],
 ['hot', 'the'],
 ['shimmered', 'in'],
 ['the', 'in'],
 ['the', 'road'],
 ['in', 'shimmered'],
 ['hot', 'sun'],
 ['in', 'road'],
 ['wide', 'the'],
 ['the', 'shimmered'],
 ['sun', 'the'],
 ['wide', 'shimmered'],
 ['hot', 'in'],
 ['road', 'shimmered'],
 ['road', 'in'],
 ['the', 'wide'],
 ['wide', 'road'],
 ['the', 'sun']]

In [14]:
positive_skip_grams

[[5, 6],
 [4, 1],
 [4, 2],
 [7, 6],
 [1, 6],
 [3, 2],
 [3, 1],
 [4, 3],
 [5, 1],
 [6, 1],
 [4, 5],
 [1, 5],
 [1, 3],
 [5, 4],
 [6, 7],
 [5, 3],
 [2, 1],
 [1, 4],
 [7, 1],
 [2, 4],
 [6, 5],
 [3, 4],
 [3, 5],
 [1, 2],
 [2, 3],
 [1, 7]]

In [17]:
targe_word, context_word = positive_skip_grams[0]

num_ns = 4

context_class = tf.reshape(tf.constant(context_word, dtype="int64"), (1, 1))

context_class

<tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[6]])>

In [18]:
context_class.shape

TensorShape([1, 1])

In [22]:
negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
    true_classes=context_class,
    num_true=1,
    num_sampled=num_ns,
    unique=True,
    range_max=vocab_size,
    seed=SEED,
    name="negative_sampling_"
)
negative_sampling_candidates

<tf.Tensor: shape=(4,), dtype=int64, numpy=array([2, 1, 4, 3])>

In [23]:
[inverse_vocab[index.numpy()] for index in negative_sampling_candidates]

['wide', 'the', 'shimmered', 'road']

In [27]:
squeezed_context_class = tf.squeeze(context_class, 1)
squeezed_context_class

<tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>

In [33]:
context = tf.concat([squeezed_context_class, negative_sampling_candidates], 0)

In [29]:
label = tf.constant([1] + [0]*num_ns, dtype="int64")
target = targe_word

In [31]:
target

5

In [34]:
inverse_vocab[targe_word]

'in'

In [35]:
context

<tf.Tensor: shape=(5,), dtype=int64, numpy=array([6, 2, 1, 4, 3])>

In [36]:
[inverse_vocab[c.numpy()] for c in context]

['hot', 'wide', 'the', 'shimmered', 'road']

In [37]:
label

<tf.Tensor: shape=(5,), dtype=int64, numpy=array([1, 0, 0, 0, 0])>

In [39]:
print(f"target_index    : {target}")
print(f"target_word     : {inverse_vocab[targe_word]}")
print(f"context_indices : {context}")
print(f"context_words   : {[inverse_vocab[c.numpy()] for c in context]}")
print(f"label           : {label}")

target_index    : 5
target_word     : in
context_indices : [6 2 1 4 3]
context_words   : ['hot', 'wide', 'the', 'shimmered', 'road']
label           : [1 0 0 0 0]


(target, context, label)

In [40]:
print("target  :", target)
print("context :", context)
print("label   :", label)

target  : 5
context : tf.Tensor([6 2 1 4 3], shape=(5,), dtype=int64)
label   : tf.Tensor([1 0 0 0 0], shape=(5,), dtype=int64)


In [3]:
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
    targets, contexts, labels = [], [], []
    sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

    for sequence in tqdm.tqdm(sequences):
        positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(sequence, vocabulary_size=vocab_size, window_size=window_size, negative_samples=0, seed=seed)
        for target_word, context_word in positive_skip_grams:
            context_class = tf.expand_dims(
                tf.constant([context_word], dtype="int64"), 1
            )
            negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
                true_classes=context_class,
                num_true=1,
                num_sampled=num_ns,
                unique=True,
                range_max=vocab_size,
                seed=seed,
                name="negative_sampling"
            )
            context = tf.concat([tf.squeeze(context_class, 1), negative_sampling_candidates], 0)
            label = tf.constant([1] + [0]*num_ns, dtype="int64")
            targets.append(target_word)
            contexts.append(context)
            labels.append(label)

    return targets, contexts, labels

In [4]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
[1m1115394/1115394[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1us/step


In [5]:
with open(path_to_file) as f:
    lines = f.read().splitlines()

print(lines[:10])

['First Citizen:', 'Before we proceed any further, hear me speak.', '', 'All:', 'Speak, speak.', '', 'First Citizen:', 'You are all resolved rather to die than to famish?', '', 'All:']


In [6]:
text_ds = tf.data.TextLineDataset(path_to_file).filter(lambda x: tf.cast(tf.strings.length(x), bool))

2025-12-14 11:31:47.230227: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    return tf.strings.regex_replace(lowercase, '[%s]' % re.escape(string.punctuation), '')

In [8]:
vocab_size = 4096
sequence_length = 10

vectorize_layer = layers.TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode='int',
    output_sequence_length=sequence_length
)

In [9]:
vectorize_layer.adapt(text_ds.batch(1024))

2025-12-14 11:35:55.695723: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [10]:
inverse_vocab = vectorize_layer.get_vocabulary()
print(inverse_vocab[:20])

['', '[UNK]', np.str_('the'), np.str_('and'), np.str_('to'), np.str_('i'), np.str_('of'), np.str_('you'), np.str_('my'), np.str_('a'), np.str_('that'), np.str_('in'), np.str_('is'), np.str_('not'), np.str_('for'), np.str_('with'), np.str_('me'), np.str_('it'), np.str_('be'), np.str_('your')]


In [13]:
text_vector_ds = text_ds.batch(1024).prefetch(AUTOTUNE).map(vectorize_layer).unbatch()

In [15]:
sequences = list(text_vector_ds.as_numpy_iterator())

2025-12-14 11:39:32.496594: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [16]:
len(sequences)

32777

In [17]:
sequences[:5]

[array([ 89, 270,   0,   0,   0,   0,   0,   0,   0,   0]),
 array([138,  36, 982, 144, 673, 125,  16, 106,   0,   0]),
 array([34,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([106, 106,   0,   0,   0,   0,   0,   0,   0,   0]),
 array([ 89, 270,   0,   0,   0,   0,   0,   0,   0,   0])]

In [18]:
targets, contexts, labels = generate_training_data(
    sequences, window_size=2, num_ns=4, vocab_size=vocab_size, seed=SEED
)

100%|██████████| 32777/32777 [02:20<00:00, 233.17it/s]


In [19]:
targets, contexts, labels = np.array(targets), np.array(contexts), np.array(labels)

In [20]:
targets.shape, contexts.shape, labels.shape

((620830,), (620830, 5), (620830, 5))

In [21]:
BATCH_SIZE = 1024
BUFFER_SIZE = 10000

In [22]:
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset

<_TensorSliceDataset element_spec=((TensorSpec(shape=(), dtype=tf.int64, name=None), TensorSpec(shape=(5,), dtype=tf.int64, name=None)), TensorSpec(shape=(5,), dtype=tf.int64, name=None))>

In [23]:
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

In [24]:
dataset

<_BatchDataset element_spec=((TensorSpec(shape=(1024,), dtype=tf.int64, name=None), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None)), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None))>

In [25]:
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)

In [26]:
dataset

<_PrefetchDataset element_spec=((TensorSpec(shape=(1024,), dtype=tf.int64, name=None), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None)), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None))>

In [27]:
DATASET_PATH = "saved_word2vec_dataset"

dataset.save(DATASET_PATH)

In [28]:
loaded_dataset = tf.data.Dataset.load(DATASET_PATH)

print(loaded_dataset)


<_LoadDataset element_spec=((TensorSpec(shape=(1024,), dtype=tf.int64, name=None), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None)), TensorSpec(shape=(1024, 5), dtype=tf.int64, name=None))>


In [29]:
class Word2Vec(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        self.target_embedding = layers.Embedding(vocab_size, embedding_dim, name="w2v_embedding")
        self.context_embedding = layers.Embedding(vocab_size, embedding_dim)

    def call(self, pair):
        target, context = pair
        if len(target.shape) == 2:
            target = tf.squeeze(target, axis=1)
        word_emb = self.target_embedding(target)
        context_emb = self.context_embedding(context)
        dots = tf.einsum('be,bce->bc', word_emb, context_emb)
        return dots

In [30]:
def custom_loss(x_logit, y_true):
    return tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=y_true)

In [31]:
embedding_dim = 64
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

In [32]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

In [33]:
word2vec.fit(dataset, epochs=10, callbacks=[tensorboard_callback])

Epoch 1/10
[1m606/606[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 11ms/step - accuracy: 0.3314 - loss: 1.4961
Epoch 2/10
[1m606/606[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 9ms/step - accuracy: 0.4002 - loss: 1.3934
Epoch 3/10
[1m606/606[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 7ms/step - accuracy: 0.4310 - loss: 1.3433
Epoch 4/10
[1m606/606[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 10ms/step - accuracy: 0.4510 - loss: 1.3078
Epoch 5/10
[1m606/606[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 8ms/step - accuracy: 0.4660 - loss: 1.2792
Epoch 6/10
[1m606/606[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 7ms/step - accuracy: 0.4790 - loss: 1.2548
Epoch 7/10
[1m606/606[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 7ms/step - accuracy: 0.4901 - loss: 1.2330
Epoch 8/10
[1m606/606[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 7ms/step - accuracy: 0.4997 - loss: 1.2133
Epoch 9/10
[1m606/606[0m [32m━━━━━━

<keras.src.callbacks.history.History at 0x7b5a2645f770>

In [34]:
#docs_infra: no_execute
%tensorboard --logdir logs

UsageError: Line magic function `%tensorboard` not found.


In [35]:
weights = word2vec.get_layer('w2v_embedding').get_weights()[0]
vocab = vectorize_layer.get_vocabulary()

In [36]:
out_v = io.open('vectors.tsv', 'w', encoding='utf-8')
out_m = io.open('metadata.tsv', 'w', encoding='utf-8')

for index, word in enumerate(vocab):
  if index == 0:
    continue  # skip 0, it's padding.
  vec = weights[index]
  out_v.write('\t'.join([str(x) for x in vec]) + "\n")
  out_m.write(word + "\n")
out_v.close()
out_m.close()

In [37]:
try:
  from google.colab import files
  files.download('vectors.tsv')
  files.download('metadata.tsv')
except Exception:
  pass

In [38]:
vocab

['',
 '[UNK]',
 np.str_('the'),
 np.str_('and'),
 np.str_('to'),
 np.str_('i'),
 np.str_('of'),
 np.str_('you'),
 np.str_('my'),
 np.str_('a'),
 np.str_('that'),
 np.str_('in'),
 np.str_('is'),
 np.str_('not'),
 np.str_('for'),
 np.str_('with'),
 np.str_('me'),
 np.str_('it'),
 np.str_('be'),
 np.str_('your'),
 np.str_('his'),
 np.str_('this'),
 np.str_('but'),
 np.str_('he'),
 np.str_('have'),
 np.str_('as'),
 np.str_('thou'),
 np.str_('him'),
 np.str_('so'),
 np.str_('what'),
 np.str_('thy'),
 np.str_('will'),
 np.str_('no'),
 np.str_('by'),
 np.str_('all'),
 np.str_('king'),
 np.str_('we'),
 np.str_('shall'),
 np.str_('her'),
 np.str_('if'),
 np.str_('our'),
 np.str_('are'),
 np.str_('do'),
 np.str_('thee'),
 np.str_('now'),
 np.str_('lord'),
 np.str_('good'),
 np.str_('on'),
 np.str_('o'),
 np.str_('come'),
 np.str_('from'),
 np.str_('sir'),
 np.str_('or'),
 np.str_('which'),
 np.str_('more'),
 np.str_('then'),
 np.str_('well'),
 np.str_('at'),
 np.str_('would'),
 np.str_('was'),
 