In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import itertools
from collections import defaultdict

# Sample corpus

In [2]:
# Toy corpus: one short, lowercase sentence per line. Each line is stripped so
# the indentation used for readability here never ends up in the data.
sentences = [
    line.strip()
    for line in """
        machine learning is fun
        deep learning is part of machine learning
        natural language processing is a field of ai
        word embeddings are learned representations
        tensorflow makes it easy to build models
    """.strip().splitlines()
]

# Tokenize corpus

In [3]:
# Fit a Keras tokenizer on the corpus to learn the word -> integer-id vocabulary.
tokenizer = tf.keras.preprocessing.text.Tokenizer()
tokenizer.fit_on_texts(sentences)

# Forward (word -> id) and reverse (id -> word) lookup tables.
word2idx = tokenizer.word_index
idx2word = dict(zip(word2idx.values(), word2idx.keys()))

# One extra slot beyond the number of distinct words: the Keras Tokenizer
# assigns ids starting at 1 and never uses id 0, so the largest id equals
# len(word2idx) and layers sized by vocab_size must cover ids 0..len(word2idx).
vocab_size = 1 + len(word2idx)

# Generate skip-gram pairs

In [4]:
# Number of neighbours taken on each side of the centre word.
window_size = 2

# Each sentence as a list of integer word ids.
sequences = tokenizer.texts_to_sequences(sentences)

# Build (target, context) id pairs: for every position in every sentence, pair
# the word there with each word at most `window_size` positions to its left or
# right (clipped at the sentence boundaries), skipping the position itself.
# Neighbours are visited left-to-right, so pair order is deterministic.
pairs = [
    (center_id, sentence_ids[neighbor_pos])
    for sentence_ids in sequences
    for center_pos, center_id in enumerate(sentence_ids)
    for neighbor_pos in range(
        max(center_pos - window_size, 0),
        min(center_pos + window_size + 1, len(sentence_ids)),
    )
    if neighbor_pos != center_pos
]

# Convert to numpy arrays

In [5]:
# Unzip the (target, context) pairs into two parallel columns and turn each
# column into its own NumPy array: targets[k] and contexts[k] belong to pair k.
targets, contexts = (np.array(column) for column in zip(*pairs))

# One-hot encode context labels

In [6]:
# Turn each context word id into a one-hot row of width vocab_size. These rows
# are the training labels (the `y` passed to model.fit), in the format expected
# by the 'categorical_crossentropy' loss the model is compiled with.
context_labels = tf.keras.utils.to_categorical(contexts, num_classes=vocab_size)

# Define skip-gram model

In [7]:
# Seed the Python, NumPy and TensorFlow RNGs so weight initialisation and
# training are repeatable from a fresh kernel (up to any nondeterministic
# device ops).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Size of each learned word vector.
embedding_dim = 64

# Skip-gram network: one word id in -> embedding lookup -> softmax over the
# whole vocabulary (probability of each word being a context word).
input_word = tf.keras.Input(shape=(1,))

# Name the layer explicitly so it can be retrieved by a stable name. Keras'
# auto-generated name gains a numeric suffix ('embedding_1', 'embedding_2', ...)
# every time this cell is re-run in the same kernel, which would break any
# lookup that hard-codes 'embedding'.
embedding = tf.keras.layers.Embedding(
    input_dim=vocab_size,
    output_dim=embedding_dim,
    name='embedding',
)(input_word)

# Drop the length-1 sequence axis: (batch, 1, embedding_dim) -> (batch, embedding_dim).
x = tf.keras.layers.Reshape((embedding_dim,))(embedding)
output = tf.keras.layers.Dense(vocab_size, activation='softmax')(x)

model = tf.keras.Model(inputs=input_word, outputs=output)
# Labels are one-hot rows, hence the non-sparse categorical cross-entropy.
model.compile(optimizer='adam', loss='categorical_crossentropy')

# Train the model

In [8]:
# Train silently (verbose=0) for 100 passes over all (target, context) pairs:
# the input is the target word id, the label is the one-hot context word.
model.fit(x=targets, y=context_labels, epochs=100, verbose=0)

<keras.src.callbacks.history.History at 0x23198180ec0>

# Extract and display learned embeddings

In [9]:
# Locate the embedding layer by type instead of by the hard-coded name
# 'embedding': Keras' auto-generated layer names gain a numeric suffix whenever
# the model-definition cell is re-run in the same kernel, so a name-based
# lookup is only reliable on a fresh kernel.
embedding_layer = next(
    layer for layer in model.layers
    if isinstance(layer, tf.keras.layers.Embedding)
)

# First (and only) weight of an Embedding layer: one row per word id,
# embedding_dim columns.
embedding_weights = embedding_layer.get_weights()[0]

for word, idx in word2idx.items():
    vec = embedding_weights[idx][:5]  # Show first 5 dims
    print(f"{word}: {vec.round(3)}")

learning: [-0.009 -0.205  0.206 -0.239  0.193]
is: [-0.12  -0.12  -0.042 -0.338  0.253]
machine: [-0.15  -0.409  0.322 -0.168  0.271]
of: [-0.2   -0.18   0.124 -0.124  0.241]
fun: [-0.235 -0.246  0.29  -0.157  0.18 ]
deep: [-0.273 -0.202  0.277 -0.229  0.182]
part: [-0.276 -0.171  0.399 -0.286  0.236]
natural: [-0.234  0.276 -0.207 -0.051  0.103]
language: [-0.165  0.178 -0.184 -0.273 -0.022]
processing: [-0.245  0.04  -0.067  0.047  0.258]
a: [-0.206  0.039 -0.143 -0.238  0.24 ]
field: [-0.143 -0.178  0.169  0.054  0.282]
ai: [ 0.155 -0.224 -0.032 -0.222  0.258]
word: [ 0.29   0.113 -0.129  0.257 -0.277]
embeddings: [ 0.141 -0.274 -0.235  0.165 -0.228]
are: [ 0.238  0.082 -0.05   0.181 -0.301]
learned: [ 0.321  0.159  0.005  0.278 -0.246]
representations: [ 0.288 -0.208 -0.184  0.179 -0.219]
tensorflow: [-0.232  0.296  0.158 -0.305 -0.296]
makes: [-0.117  0.32   0.25   0.119 -0.137]
it: [ 0.176  0.165  0.212 -0.002 -0.165]
easy: [-0.041  0.126  0.178 -0.205 -0.287]
to: [-0.066  0.363 