# Skip-Gram Word Embedding with Negative Sampling

This notebook implements the Word2Vec algorithm using the **Skip-gram** architecture with **Negative Sampling** from scratch using TensorFlow/Keras.
We use the **text8** dataset (Wikipedia dump) to learn dense vector representations of words.

## Project Structure
- **Data Loading**: Download and extract the text8 dataset.
- **Preprocessing**: 
  - Tokenization, Lowercasing.
  - Stopwords removal.
  - Subsampling frequent words.
- **Data Generation**: 
  - Skip-gram pair generation (Target, Context).
  - Negative Sampling.
- **Model**: Custom Keras Model subclassing `tf.keras.Model`.
- **Training**: Custom training loop with `BinaryCrossentropy` loss.
- **Visualization**: Exporting embeddings for the TensorFlow Embedding Projector.

In [None]:
import re
import io
import math
import gzip
import nltk
import time
import random
import numpy as np
import tensorflow as tf
import gensim.downloader as api
import tensorflow_datasets as tfds
from collections import Counter
from nltk.corpus import stopwords
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import skipgrams

# Fetch the NLTK stopword corpus; `stopwords.words('english')` in the
# preprocessing step reads it.
nltk.download('stopwords')

## 1. Data Loading
Loading the **text8** dataset which consists of the first 100MB of clean Wikipedia text.

In [None]:
def load_data():
    """Return the text8 corpus as one decoded string.

    The archive path is obtained from gensim's downloader (``api.load`` with
    ``return_path=True``); the gzip file is read fully into memory and the
    bytes are decoded with ``bytes.decode``'s default codec (UTF-8).
    """
    archive_path = api.load('text8', return_path=True)
    with gzip.open(archive_path, 'rb') as archive:
        raw_bytes = archive.read()
    return raw_bytes.decode()

# Load the corpus once; the preprocessing cell reads the module-level `wiki`.
wiki = load_data()
print(f"Data Loaded. Length: {len(wiki)} characters.")

## 2. Preprocessing
We define a preprocessing pipeline that:
1. Cleans punctuation.
2. Removes stopwords (common words like "the", "is").
3. Removes rare words (< 5 occurrences).
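4. Subsamples frequent words with the discard probability used by the reference word2vec implementation, $P(w_i) = 1 - \left(\sqrt{\frac{t}{f(w_i)}} + \frac{t}{f(w_i)}\right)$ with $t = 10^{-5}$ (Mikolov et al.'s paper states the simpler $1 - \sqrt{\frac{t}{f(w_i)}}$).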

In [None]:
def preprocess_text(text, stop_words=None, min_count=5, threshold=1e-5, seed=None):
    """Tokenize *text* and filter it down to the tokens used for training.

    Pipeline:
      1. Replace every character that is neither a word character nor
         whitespace with a space.
      2. Lowercase, strip and split on whitespace.
      3. Drop stopwords.
      4. Drop every token whose word occurs fewer than ``min_count`` times.
      5. Randomly subsample frequent words: each remaining token is
         discarded with probability ``1 - (sqrt(t / f) + t / f)``, where
         ``f`` is the word's relative frequency and ``t`` is ``threshold``.

    Args:
        text: Raw corpus as a single string.
        stop_words: Iterable of words to drop in step 3. ``None`` (default)
            loads NLTK's English stopword list.
        min_count: Minimum number of occurrences for a word to be kept.
        threshold: Subsampling threshold ``t``.
        seed: Optional seed for a private random generator, which makes
            step 5 reproducible. ``None`` (default) draws from the
            module-level ``random`` generator.

    Returns:
        A ``(subsampled_words, word_counts)`` tuple. ``subsampled_words`` is
        the list of surviving tokens in corpus order. ``word_counts`` is a
        ``Counter`` over every non-stopword token, tallied *before* steps 4
        and 5, so it also contains words that were filtered out.
    """
    # Steps 1-2: clean punctuation, normalise case, tokenize.
    text = re.sub(r'[^\w\s]', ' ', text)
    words = text.lower().strip().split()

    # Step 3: stopword removal (set membership keeps this O(1) per token).
    if stop_words is None:
        stop_words = stopwords.words('english')
    stop_words = set(stop_words)
    words = [word for word in words if word not in stop_words]

    word_counts = Counter(words)
    # Frequencies are relative to all non-stopword tokens, rare words included.
    total_words = sum(word_counts.values())

    # Steps 4-5 setup: the discard probability depends only on the word, so
    # compute it once per distinct word instead of once per token. Words
    # below `min_count` get no entry and are dropped outright.
    # The formula is the variant used by the reference word2vec C code
    # (keep probability sqrt(t/f) + t/f), not the simpler one in the paper.
    discard_probability = {}
    for word, count in word_counts.items():
        if count >= min_count:
            ratio = threshold / (count / total_words)
            discard_probability[word] = 1 - (math.sqrt(ratio) + ratio)

    draw = random.random if seed is None else random.Random(seed).random

    # One random draw per token that passed the frequency filter; the
    # short-circuit `and` ensures rare words consume no draws.
    subsampled_words = [
        word for word in words
        if word in discard_probability and draw() > discard_probability[word]
    ]

    return subsampled_words, word_counts

# Run preprocessing (might take a moment)
preprocessed_words, word_counts = preprocess_text(wiki)
# `word_counts` is tallied inside preprocess_text before the rare-word filter
# and subsampling, so its length overstates the vocabulary. Report the number
# of distinct words that actually survived instead.
print(f"Preprocessing complete. Vocabulary size: {len(set(preprocessed_words))}")
print(f"Sample words: {preprocessed_words[:10]}")

## 3. Dataset Generation
We use Keras `skipgrams` to generate positive pairs (Target, Context) and Negative pairs.

In [None]:
# Hyperparameters
EMBEDDING_DIM = 100   # length of each embedding vector
BUFFER_SIZE = 10000   # shuffle-buffer size (in examples) for the training set
BATCH_SIZE = 256
EPOCHS = 5
WINDOW_SIZE = 5       # passed to skipgrams() as window_size

# Tokenization: map every distinct word to an integer id.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(preprocessed_words)
word_index = tokenizer.word_index
# +1 leaves room for id 0, which the Keras Tokenizer never assigns to a word.
vocab_size = len(word_index) + 1

# Generate Sequences: the whole corpus becomes one sequence of word ids.
sequences = tokenizer.texts_to_sequences([preprocessed_words])[0]

# Generate Skip-grams: positive (target, context) pairs get label 1, randomly
# drawn negative pairs get label 0.
print("Generating skip-gram pairs...")
pairs, labels = skipgrams(
    sequences,
    vocabulary_size=vocab_size,
    window_size=WINDOW_SIZE,
    negative_samples=0.75,
)

# Prepare Arrays
# Convert the list of [target, context] couples in one step instead of
# `zip(*pairs)`, which builds two Python tuples with one element per pair.
# reshape(-1, 2) keeps the column slicing valid even when `pairs` is empty.
pair_array = np.asarray(pairs, dtype=np.int32).reshape(-1, 2)
targets = np.ascontiguousarray(pair_array[:, 0])
contexts = np.ascontiguousarray(pair_array[:, 1])
labels = np.asarray(labels, dtype=np.int32)

# Train/Test Split: random 80/20 split over the generated pairs.
indices = np.arange(len(targets))
np.random.shuffle(indices)
split_idx = int(len(targets) * 0.8)

train_indices = indices[:split_idx]
test_indices = indices[split_idx:]

# TF Datasets: each element is ((target_id, context_id), label).
train_dataset = tf.data.Dataset.from_tensor_slices(
    ((targets[train_indices], contexts[train_indices]), labels[train_indices])
)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices(
    ((targets[test_indices], contexts[test_indices]), labels[test_indices])
)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"Train/Test batches: {len(train_dataset)}, {len(test_dataset)}")

## 4. Model Definition
We define a custom Keras Model.
- **Target Embedding**: Embedding for the center word.
- **Context Embedding**: Embedding for the context word.
- **Dot Product**: Similarity score.
- **Sigmoid**: Probability output.

In [None]:
class Word2Vec(tf.keras.Model):
    """Skip-gram scorer: sigmoid of the dot product of two embedding lookups.

    Two separate embedding tables are kept, one looked up with the target
    (center) word id and one with the context word id.
    """

    def __init__(self, vocab_size, embedding_dim):
        """Create both embedding tables plus the dot-product and sigmoid layers.

        Args:
            vocab_size: ``input_dim`` of each embedding table.
            embedding_dim: ``output_dim`` (vector length) of each table.
        """
        super().__init__()
        table_kwargs = dict(input_dim=vocab_size, output_dim=embedding_dim)
        self.target_embedding = tf.keras.layers.Embedding(
            name="target_embedding", **table_kwargs
        )
        self.context_embedding = tf.keras.layers.Embedding(
            name="context_embedding", **table_kwargs
        )
        self.dot = tf.keras.layers.Dot(axes=-1, normalize=False, name="dot_product")
        self.output_layer = tf.keras.layers.Activation("sigmoid", name="output")

    def call(self, inputs):
        """Score a batch of (target, context) id pairs.

        Args:
            inputs: A two-element ``(target, context)`` sequence of word ids.

        Returns:
            The sigmoid of the dot product between each target embedding and
            its paired context embedding.
        """
        target_ids, context_ids = inputs
        embedded_pair = [
            self.target_embedding(target_ids),
            self.context_embedding(context_ids),
        ]
        return self.output_layer(self.dot(embedded_pair))

# Instantiate the model with the vocabulary size and embedding width defined
# in the dataset-generation cell.
model = Word2Vec(vocab_size=vocab_size, embedding_dim=EMBEDDING_DIM)

## 5. Training
Custom training loop using `GradientTape`.

In [None]:
# Adam with its default settings.
optimizer = tf.keras.optimizers.Adam()
# Binary cross-entropy with default arguments; the model's output layer
# already applies a sigmoid, so predictions are passed in as probabilities.
loss_fn = tf.keras.losses.BinaryCrossentropy()
# Separate accuracy trackers for the training and validation passes; both are
# updated inside train_step / test_step and reset at the start of each epoch.
train_acc_metric = tf.keras.metrics.BinaryAccuracy()
val_acc_metric = tf.keras.metrics.BinaryAccuracy()

@tf.function
def train_step(model, optimizer, loss_fn, targets, x, y):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        model: Model called as ``model((targets, x), training=True)``.
        optimizer: Optimizer whose ``apply_gradients`` performs the update.
        loss_fn: Callable invoked as ``loss_fn(y, predictions)``.
        targets: Batch of target word ids.
        x: Batch of context word ids.
        y: Batch of labels.

    Returns:
        The loss computed for this batch.

    Side effects:
        Updates the model's trainable weights and the module-level
        ``train_acc_metric``.
    """
    with tf.GradientTape() as tape:
        probabilities = model((targets, x), training=True)
        batch_loss = loss_fn(y, probabilities)

    # Read the weight list only after the forward pass above, matching the
    # point at which the weights are known to exist.
    weights = model.trainable_weights
    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    train_acc_metric.update_state(y, probabilities)
    return batch_loss

@tf.function
def test_step(model, loss_fn, targets, x, y):
    """Evaluate a single batch without updating any weights.

    Args:
        model: Model called as ``model((targets, x), training=False)``.
        loss_fn: Callable invoked as ``loss_fn(y, predictions)``.
        targets: Batch of target word ids.
        x: Batch of context word ids.
        y: Batch of labels.

    Returns:
        The loss computed for this batch.

    Side effects:
        Updates the module-level ``val_acc_metric``.
    """
    probabilities = model((targets, x), training=False)
    batch_loss = loss_fn(y, probabilities)
    val_acc_metric.update_state(y, probabilities)
    return batch_loss

print("Starting training...")

# Running means of the per-batch losses. Using metrics instead of
# `loss.numpy()` avoids copying the loss to the host on every batch, and an
# epoch over an empty dataset reports 0.0 instead of failing on an undefined
# step counter.
train_loss_metric = tf.keras.metrics.Mean()
val_loss_metric = tf.keras.metrics.Mean()

for epoch in range(EPOCHS):
    start_time = time.time()
    train_acc_metric.reset_state()
    val_acc_metric.reset_state()
    train_loss_metric.reset_state()
    val_loss_metric.reset_state()

    # Train
    # Batch variables carry a `batch_` prefix so they do not overwrite the
    # module-level `targets` / `contexts` / `labels` arrays built earlier.
    for (batch_targets, batch_contexts), batch_labels in train_dataset:
        loss = train_step(model, optimizer, loss_fn, batch_targets, batch_contexts, batch_labels)
        train_loss_metric.update_state(loss)
    train_loss = float(train_loss_metric.result())

    # Validation
    for (batch_targets, batch_contexts), batch_labels in test_dataset:
        loss = test_step(model, loss_fn, batch_targets, batch_contexts, batch_labels)
        val_loss_metric.update_state(loss)
    val_loss = float(val_loss_metric.result())

    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc_metric.result():.4f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc_metric.result():.4f} | "
          f"Time: {time.time() - start_time:.2f}s")

# Save a checkpoint containing both the optimizer and the model, using
# `checkpoint_path` as the file prefix.
checkpoint_path = "./checkpoints/word2vec_ckpt"
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
checkpoint.save(file_prefix=checkpoint_path)
print("Model saved.")

## 6. Embedding Export
Exporting vectors for visualization in [TensorFlow Embedding Projector](http://projector.tensorflow.org/).

In [None]:
# The target-embedding matrix is indexed by word id: row `idx` is the vector
# written out for the word whose tokenizer id is `idx`.
weights = model.target_embedding.get_weights()[0]

# vecs.tsv gets one tab-separated vector per line; meta.tsv gets the matching
# word on the same line number.
with open("vecs.tsv", "w", encoding='utf-8') as vecs_file, \
        open("meta.tsv", "w", encoding='utf-8') as meta_file:
    for word, idx in tokenizer.word_index.items():
        if idx >= vocab_size:
            # No row in the embedding matrix for this id; skip it.
            continue
        vec = weights[idx]
        row = "\t".join(map(str, vec))
        meta_file.write(f"{word}\n")
        vecs_file.write(row + "\n")

print("Export complete: vecs.tsv, meta.tsv")