# Word 2 Vec

In [1]:
import io
import itertools
import numpy as np
import os
import re
import string
import tensorflow as tf
import tqdm

from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Activation, Dense, Dot, Embedding, Flatten, GlobalAveragePooling1D, Reshape
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

In [2]:
SEED = 42
AUTOTUNE = tf.data.experimental.AUTOTUNE

## Train Dataset Generator


We start by creating a sampling table. This object will help us choose negative examples from the most common words in the vocabulary such as "the", "is", "on".

In [4]:
sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(size=10)
print(sampling_table)

[0.00315225 0.00315225 0.00547597 0.00741556 0.00912817 0.01068435
 0.01212381 0.01347162 0.01474487 0.0159558 ]


In [10]:
# Function to Generate Training Data.

def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
    
    targets, contexts, labels = [], [] , []
    
    sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(size=vocab_size)
    
    for sequence in tqdm.tqdm(sequences):
        
        positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
          sequence, 
          vocabulary_size=vocab_size,
          sampling_table=sampling_table,
          window_size=window_size,
          negative_samples=0)
        
        for target_word, context_word in positive_skip_grams:
            
            context_class = tf.expand_dims(tf.constant([context_word], dtype="int64"), 1)
            
            negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
                                                                  true_classes=context_class,
                                                                  num_true=1, 
                                                                  num_sampled=num_ns, 
                                                                  unique=True, 
                                                                  range_max=vocab_size, 
                                                                  seed=SEED, 
                                                                  name="negative_sampling")
            
            negative_sampling_candidates = tf.expand_dims(negative_sampling_candidates, 1)
            context = tf.concat([context_class, negative_sampling_candidates], 0)
            label = tf.constant([1]+[0]*num_ns, dtype="int64")
            
            targets.append(target_word)
            contexts.append(context)
            labels.append(label)
            
    
    return targets, contexts, labels
            
        
        
        
        

## Download Data

In [6]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt


In [8]:
with open(path_to_file) as f: 
    lines = f.read().splitlines()
    print(len(lines))
for line in lines[:20]:
    print(line)

40000
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.


## Prepare Data


The steps are:

1. Remove empty lines and create and object wo which tf can operate.
2. Remove punctuation.
3. Lower case.
4. Vectorize sentences. 
5. Save Created Vocabulary for reference
6. Generate vectors for each element in corpus
7. Generate training examples from vector sequences

In [12]:
# We use the TextLineDataset function from Tensorflow.

# We load only non-empty lines

text_ds = tf.data.TextLineDataset(path_to_file).filter(lambda x: tf.cast(tf.strings.length(x), bool))

In [13]:
# We create a custom standardization function to lowercase the text and  remove punctuation.

# This is implemented with the TextVectorization object

def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    return tf.strings.regex_replace(lowercase,
                                  '[%s]' % re.escape(string.punctuation), '')

# Define the vocabulary size and number of words in a sequence.

vocab_size = 4096
sequence_length = 10

# Use the text vectorization layer to normalize, split, and map strings to
# integers. Set output_sequence_length length to pad all samples to same length.

vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size,
    output_mode='int',
    output_sequence_length=sequence_length)


In [14]:
vectorize_layer.adapt(text_ds.batch(1024))

In [15]:
inverse_vocab = vectorize_layer.get_vocabulary()
print(inverse_vocab[:20])

['', '[UNK]', 'the', 'and', 'to', 'i', 'of', 'you', 'my', 'a', 'that', 'in', 'is', 'not', 'for', 'with', 'me', 'it', 'be', 'your']


In [16]:
def vectorize_text(text):
    text = tf.expand_dims(text, -1)
    return tf.squeeze(vectorize_layer(text))

# Vectorize the data in text_ds.

text_vector_ds = text_ds.batch(1024).prefetch(AUTOTUNE).map(vectorize_text).unbatch()


In [17]:
text_vector_ds

<_UnbatchDataset shapes: <unknown>, types: tf.int64>

In [19]:
sequences = list(text_vector_ds.as_numpy_iterator())
print(len(sequences))

for seq in sequences[:5]:
    print(f"{seq} => {[inverse_vocab[i] for i in seq]}")

32777
[ 89 270   0   0   0   0   0   0   0   0] => ['first', 'citizen', '', '', '', '', '', '', '', '']
[138  36 982 144 673 125  16 106   0   0] => ['before', 'we', 'proceed', 'any', 'further', 'hear', 'me', 'speak', '', '']
[34  0  0  0  0  0  0  0  0  0] => ['all', '', '', '', '', '', '', '', '', '']
[106 106   0   0   0   0   0   0   0   0] => ['speak', 'speak', '', '', '', '', '', '', '', '']
[ 89 270   0   0   0   0   0   0   0   0] => ['first', 'citizen', '', '', '', '', '', '', '', '']


In [20]:
%%time

targets, contexts, labels = generate_training_data(sequences=sequences,
                                                  window_size=2,
                                                  num_ns=4,
                                                  vocab_size=vocab_size,
                                                  seed=42)

100%|██████████| 32777/32777 [00:06<00:00, 5336.82it/s]

CPU times: user 5.87 s, sys: 179 ms, total: 6.05 s
Wall time: 6.16 s





In [21]:
print(len(targets), len(contexts), len(labels))

65812 65812 65812


## Configure Dataset For Performance

In [25]:
BATCH_SIZE = 1024
BUFFER_SIZE = 10000

dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
print(dataset)

<BatchDataset shapes: (((1024,), (1024, 5, 1)), (1024, 5)), types: ((tf.int32, tf.int64), tf.int64)>


In [26]:
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE)
print(dataset)

<PrefetchDataset shapes: (((1024,), (1024, 5, 1)), (1024, 5)), types: ((tf.int32, tf.int64), tf.int64)>


# Model And Training

The Word2Vec model can be implemented as a classifier to distinguish between true context words from skipgrams and false context words obtained through negatve sampling.

This can be done by calculating a dot product between the embeddings of target and contextwords. This allows us to obtain predictions for labels and compute the loss against the true labels.

___


We will use the following layers from Keras.
1. target embedding: Looks up he embedding of a word when appears as a target word.
2. context embedding: Looks up the embedding of a word when appears as a context word.
3. dots: keras.Dots. Computs the dot product between the embedding vectors of the target and context words.
4. flatten: Flattens the results of the dot product into logits.

Similar to pytorch's forward method, we will use the <code> call() </code> method to accept (target, context) pairs



In [32]:
class Word2Vec(Model):
    
    def __init__(self, vocab_size, embedding_dim, num_ns):
        super(Word2Vec, self).__init__()
        self.target_embedding = Embedding(vocab_size, embedding_dim, input_length=1, name="w2v_embedding")
        self.context_embedding = Embedding(vocab_size, embedding_dim, input_length=num_ns+1)
        
        self.dots = Dot(axes=(3,2))
        self.flatten = Flatten()
        
    def call(self, pair):
        
        target, context = pair
        we = self.target_embedding(target)
        ce = self.context_embedding(context)
        
        
        dots = self.dots([ce, we])
        
        return self.flatten(dots)
           
        

## Loss Function

In [33]:
def custom_loss(x_logit, y_true):
    return tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=y_true)

## Compile Model

In [34]:
embedding_dim = 128
word2vec = Word2Vec(vocab_size=vocab_size, embedding_dim=embedding_dim, num_ns=4)
word2vec.compile(optimizer="adam", loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                metrics=["accuracy"])

In [35]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")

## Fit the Model

In [39]:
word2vec.fit(dataset, epochs=50, callbacks=[tensorboard_callback])

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<tensorflow.python.keras.callbacks.History at 0x7fef988f4c70>

In [40]:
%load_ext tensorboard
%tensorboard --logdir logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6007 (pid 98832), started 0:00:55 ago. (Use '!kill 98832' to kill it.)