# Training the Model

- https://jaketae.github.io/study/word2vec/
- https://www.geeksforgeeks.org/word-embeddings-in-nlp/
- https://www.youtube.com/playlist?list=PLhWB2ZsrULv-wEM8JDKA1zk8_2Lc88I-s
- https://towardsdatascience.com/skip-gram-neural-network-from-scratch-485f2e688238
- https://stackoverflow.com/questions/4576077/how-can-i-split-a-text-into-sentences
- https://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/

In [23]:
import os
import re
import math
import itertools
import time
import random

import pickle as pkl
from collections import defaultdict

import numpy as np
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

seed = 1_234  # reseeds NumPy's global RNG before every sampling / initialization step

## Data Processing

In [1]:
# Split a string into a list of words
def tokenize(document):
    pattern = re.compile(r'[A-Za-z]+[\w^\']*|[\w^\']*[A-Za-z]+[\w^\']*')
    return pattern.findall(document.lower())

# Tokenize every string in `documents`.
# With by_sentence=True each document is first split into sentences and the
# result is a flat list with one token list per sentence (across all
# documents); otherwise there is one token list per document.
def batch_tokenize(documents, by_sentence):
    if not by_sentence:
        return [tokenize(document) for document in documents]
    return [
        tokenize(sentence)
        for document in documents
        for sentence in split_into_sentences(document)
    ]

# Count how often each word occurs across a list of tokenized documents.
# Returns a defaultdict(int) whose keys appear in first-occurrence order.
def generate_vocab_counts(documents):
    counts = defaultdict(int)
    for token in itertools.chain.from_iterable(documents):
        counts[token] += 1
    return counts

In [63]:
alphabets= "([A-Za-z])"
prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = "(Inc|Ltd|Jr|Sr|Co|Bros)"
starters = "(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = "[.](com|net|org|io|gov|edu|me)"
digits = "([0-9])"
multiple_dots = r'\.{2,}'

# Split a string into sentences
# https://stackoverflow.com/questions/4576077/how-can-i-split-a-text-into-sentences
def split_into_sentences(text: str) -> list[str]:
    """
    Split the text into sentences.

    If the text contains substrings "<prd>" or "<stop>", they would lead 
    to incorrect splitting because they are used as markers for splitting.

    :param text: text to be split into sentences
    :type text: str

    :return: list of sentences
    :rtype: list[str]
    """
    text = " " + text + "  "
    text = text.replace("\n"," ")
    text = re.sub(prefixes,"\\1<prd>",text)
    text = re.sub(websites,"<prd>\\1",text)
    text = re.sub(digits + "[.]" + digits,"\\1<prd>\\2",text)
    text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
    if "Ph.D" in text: text = text.replace("Ph.D.","Ph<prd>D<prd>")
    text = re.sub("\s" + alphabets + "[.] "," \\1<prd> ",text)
    text = re.sub(acronyms+" "+starters,"\\1<stop> \\2",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>\\3<prd>",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>",text)
    text = re.sub(" "+suffixes+"[.] "+starters," \\1<stop> \\2",text)
    text = re.sub(" "+suffixes+"[.]"," \\1<prd>",text)
    text = re.sub(" " + alphabets + "[.]"," \\1<prd>",text)
    if "”" in text: text = text.replace(".”","”.")
    if "\"" in text: text = text.replace(".\"","\".")
    if "!" in text: text = text.replace("!\"","\"!")
    if "?" in text: text = text.replace("?\"","\"?")
    text = text.replace(".",".<stop>")
    text = text.replace("?","?<stop>")
    text = text.replace("!","!<stop>")
    text = text.replace("<prd>",".")
    sentences = text.split("<stop>")
    sentences = [s.strip() for s in sentences]
    if sentences and not sentences[-1]: sentences = sentences[:-1]
    return sentences

In [64]:
# Read the raw text files and tokenize them two ways: one token list per file
# (mario_documents) and one token list per sentence (mario_documents_sentences).
# To load every .txt file instead, iterate over os.listdir(data_dir) and keep
# names ending in ".txt".
data_dir = 'data/'
mario_texts = []
for source_name in ("Mario.txt", "Boo.txt"):
    with open(os.path.join(data_dir, source_name), 'r', encoding='utf-8') as file:
        mario_texts.append(file.read())

mario_documents = batch_tokenize(mario_texts, False)
mario_documents_sentences = batch_tokenize(mario_texts, True)
mario_vocab_counts = generate_vocab_counts(mario_documents)

## Non-Neural Representations

In [37]:
# One-hot encode a word using a word -> id mapping.
def one_hot_str(word, vocab_to_id):
    return one_hot_int(vocab_to_id[word], len(vocab_to_id))

# One-hot encode an integer id as a float vector of length vocab_size.
def one_hot_int(index, vocab_size):
    out = np.zeros(vocab_size)
    out[index] = 1
    return out

# One-hot encode a list of words using a word -> id mapping.
def batch_one_hot_str(words, vocab_to_id):
    return [one_hot_str(word, vocab_to_id) for word in words]

# One-hot encode a list of ids. Negative ids (the -1 padding used for CBOW
# contexts) are skipped, so the result can be shorter than the input.
def batch_one_hot_int(indices, vocab_size):
    return [one_hot_int(index, vocab_size) for index in indices if index >= 0]

def batch_one_hot(words, vocab, index=False):
    """One-hot encode a batch of words (index=False) or word ids (index=True).

    :param words: words, or integer ids when ``index`` is True
    :param vocab: a word -> id dict, a sequence of words (position = id), or,
        when ``index`` is True, an int vocabulary size
    :param index: treat ``words`` as integer ids
    :return: list of 1-D float arrays
    """
    if index:
        vocab_size = vocab if isinstance(vocab, int) else len(vocab)
        return batch_one_hot_int(words, vocab_size)
    vocab_to_id = vocab if isinstance(vocab, dict) else {word: i for i, word in enumerate(vocab)}
    return batch_one_hot_str(words, vocab_to_id)

In [53]:
# document and frequency based vectorization
def tfidf(vocab_to_id, documents):
    """Compute a tf-idf matrix for tokenized documents.

    tf is the length-normalized term count; idf is log(n_docs / doc_freq).
    Empty documents yield an all-zero row, and words that occur in no
    document get idf 0 (instead of NaN/inf).

    :param vocab_to_id: mapping word -> column index
    :param documents: list of token lists; every token must be in vocab_to_id
    :return: float array of shape (len(documents), len(vocab_to_id))
    :raises KeyError: if a token is missing from vocab_to_id
    """
    vocab_size = len(vocab_to_id)
    tf = np.zeros((len(documents), vocab_size))
    for row, document in enumerate(documents):
        for word in document:
            tf[row, vocab_to_id[word]] += 1
        if len(document) > 0:
            tf[row] /= len(document)

    doc_freq = np.count_nonzero(tf > 0, axis=0)
    idf = np.zeros(vocab_size)
    seen = doc_freq > 0
    idf[seen] = np.log(len(documents) / doc_freq[seen])
    return tf * idf

# Column index for the tf-idf matrix: words in the insertion order of
# mario_vocab_counts (the same order Model uses to assign word ids).
mario_vocab_to_id = {word: index for index, word in enumerate(mario_vocab_counts)}
mario_vocab = list(mario_vocab_to_id)
mario_tfidf = tfidf(mario_vocab_to_id, mario_documents)

## Skipgram and CBOW

In [57]:
# Probability of keeping a word given its relative frequency (word2vec subsampling):
# (sqrt(z / s) + 1) * (s / z); values >= 1 mean the word is always kept.
def subsample_prob(frequency, subsample_rate):
    return (math.sqrt(frequency / subsample_rate) + 1) * (subsample_rate / frequency)

def subsample(vocab_counts, documents, subsample_rate, random_seed=None):
    """Randomly drop frequent words from tokenized documents.

    Each token is kept with probability ``subsample_prob(count / total, subsample_rate)``.
    The input documents are not modified.

    :param vocab_counts: mapping word -> corpus count (must cover every token)
    :param documents: list of token lists
    :param subsample_rate: word2vec "sample" threshold; smaller drops more
    :param random_seed: seed for NumPy's global RNG; defaults to the module-level ``seed``
    :return: new list of (shorter) token lists, same length as ``documents``
    """
    total_count = sum(vocab_counts.values())
    keep_prob = {
        word: subsample_prob(count / total_count, subsample_rate)
        for word, count in vocab_counts.items()
    }
    np.random.seed(seed if random_seed is None else random_seed)
    # One uniform draw per token, in order; keep the token when the draw falls
    # below its keep-probability.
    return [
        [word for word in document if np.random.uniform(0, 1) < keep_prob[word]]
        for document in documents
    ]

In [58]:
# Returns indices in the vocab
def skipgram(documents, vocab_counts, vocab_to_id, left_context_window, right_context_window, k_negsample, random_seed=None):
    """Generate skip-gram training examples as vocabulary ids.

    For every (target, context) pair within the window, one positive example
    (label 1) is emitted, followed by ``k_negsample`` negative contexts
    (label 0). Negatives are drawn without replacement from the unigram
    distribution raised to the 3/4 power. They are not filtered against the
    true context.

    :param documents: list of token lists
    :param vocab_counts: mapping word -> count (defines the negative-sampling distribution)
    :param vocab_to_id: mapping word -> id
    :param left_context_window: words to the left of the target used as context
    :param right_context_window: words to the right of the target used as context
    :param k_negsample: negatives per positive pair (0 for none)
    :param random_seed: seed for NumPy's global RNG; defaults to the module-level ``seed``
    :return: (contexts, targets, labels) as 1-D int32 jax arrays of equal length
    """
    # Negative-sampling distribution, computed once (not per positive pair).
    words = list(vocab_counts)
    weights = np.array([vocab_counts[word] ** (3 / 4) for word in words], dtype=float)
    negsample_p = weights / weights.sum() if len(words) else weights
    candidate_ids = np.array([vocab_to_id[word] for word in words], dtype=int)

    np.random.seed(seed if random_seed is None else random_seed)
    contexts = []
    targets = []
    labels = []
    for document in documents:
        n_tokens = len(document)
        for i, target in enumerate(document):
            target_index = vocab_to_id[target]
            for j in range(max(0, i - left_context_window), min(i + right_context_window + 1, n_tokens)):
                if i == j:
                    continue

                targets.append(target_index)
                contexts.append(vocab_to_id[document[j]])
                labels.append(1)

                if k_negsample > 0:
                    negatives = np.random.choice(candidate_ids, size=k_negsample, replace=False, p=negsample_p)
                    targets.extend([target_index] * k_negsample)
                    contexts.extend(int(neg) for neg in negatives)
                    labels.extend([0] * k_negsample)

    return (jnp.array(contexts, dtype=jnp.int32),
            jnp.array(targets, dtype=jnp.int32),
            jnp.array(labels, dtype=jnp.int32))

In [59]:
# context filled with -1 to reach homogeneous shape
def cbow(documents, vocab_counts, left_context_window, right_context_window, k_negsample, vocab_to_id=None, random_seed=None):
    """Generate CBOW training examples as vocabulary ids.

    Every token is a target; its context is the ids of the surrounding words,
    padded with -1 to ``left_context_window + right_context_window`` entries.
    Each positive example (label 1) is followed by ``k_negsample`` negative
    targets (label 0) sharing the same context. Negatives are drawn without
    replacement from the unigram distribution raised to the 3/4 power.

    :param documents: list of token lists
    :param vocab_counts: mapping word -> count (defines the negative-sampling distribution)
    :param left_context_window: words to the left of the target used as context
    :param right_context_window: words to the right of the target used as context
    :param k_negsample: negatives per positive example (0 for none)
    :param vocab_to_id: mapping word -> id; defaults to ids in ``vocab_counts`` key order
    :param random_seed: seed for NumPy's global RNG; defaults to the module-level ``seed``
    :return: (contexts of shape (N, window), targets (N,), labels (N,)) as int32 jax arrays
    """
    if vocab_to_id is None:
        vocab_to_id = {word: i for i, word in enumerate(vocab_counts)}
    width = left_context_window + right_context_window

    # Negative-sampling distribution, computed once.
    words = list(vocab_counts)
    weights = np.array([vocab_counts[word] ** (3 / 4) for word in words], dtype=float)
    negsample_p = weights / weights.sum() if len(words) else weights
    candidate_ids = np.array([vocab_to_id[word] for word in words], dtype=int)

    np.random.seed(seed if random_seed is None else random_seed)
    contexts = []
    targets = []
    labels = []
    for document in documents:
        n_tokens = len(document)
        for i, target in enumerate(document):
            context = [
                vocab_to_id[document[j]]
                for j in range(max(0, i - left_context_window), min(i + right_context_window + 1, n_tokens))
                if j != i
            ]
            # Windows cut off by the document edge are padded so all rows have equal length.
            context += [-1] * (width - len(context))

            targets.append(vocab_to_id[target])
            contexts.append(context)
            labels.append(1)

            if k_negsample > 0:
                negatives = np.random.choice(candidate_ids, size=k_negsample, replace=False, p=negsample_p)
                targets.extend(int(neg) for neg in negatives)
                contexts.extend([context] * k_negsample)
                labels.extend([0] * k_negsample)

    # reshape keeps a 2-D shape even when there are no examples.
    context_array = jnp.array(contexts, dtype=jnp.int32).reshape(len(contexts), width)
    return (context_array,
            jnp.array(targets, dtype=jnp.int32),
            jnp.array(labels, dtype=jnp.int32))

## Softmax Architecture

In [60]:
@jax.jit
def softmax(x):
    """Row-wise softmax over axis 1 of a 2-D array."""
    # Shift by the row maximum first so large logits cannot overflow exp().
    shifted = x - x.max(axis=1, keepdims=True)
    numerators = jnp.exp(shifted)
    return numerators / numerators.sum(axis=1, keepdims=True)

In [61]:
@jax.jit
def softmax_skipgram_net(params, target):
    """Predict a probability distribution over context words for each target id.

    Looks up the target embeddings in ``params['E']``, projects them with
    ``params['Theta']`` and applies a row-wise softmax.
    """
    logits = params['E'][target] @ params['Theta']
    return softmax(logits)

@jax.jit
def softmax_skipgram_loss(params, target, y_true):
    """Mean cross-entropy between predicted distributions and ``y_true`` rows."""
    probs = softmax_skipgram_net(params, target)
    # 1e-8 keeps log() finite when a probability underflows to zero.
    per_example = -(y_true * jnp.log(probs + 1e-8)).sum(axis=1)
    return per_example.mean()
softmax_skipgram_loss_value_and_grad = jax.jit(jax.value_and_grad(softmax_skipgram_loss))

In [62]:
@jax.jit
def softmax_cbow_net(params, context):
    """Predict a distribution over target words from averaged context embeddings.

    ``context`` is a 2-D id array; -1 entries are padding and are excluded from
    the mean (rows that are all padding average to the zero vector).
    """
    valid = context >= 0
    # Index padding with a dummy id 0, then zero it out via the mask.
    embeddings = params['E'][jnp.where(valid, context, 0)]
    weights = valid[..., None].astype(embeddings.dtype)
    counts = jnp.maximum(jnp.sum(weights, axis=1), 1.0)
    context_embedding = jnp.sum(embeddings * weights, axis=1) / counts
    output = context_embedding @ params['Theta']
    output = softmax(output)
    return output

@jax.jit
def softmax_cbow_loss(params, context, y_true):
    """Mean cross-entropy between predicted distributions and ``y_true`` rows."""
    y_pred = softmax_cbow_net(params, context)
    return jnp.mean(-jnp.sum(jnp.log(y_pred + 1e-8) * y_true, axis=1))
softmax_cbow_loss_value_and_grad = jax.jit(jax.value_and_grad(softmax_cbow_loss))

## Negative Sampling Architecture

In [None]:
@jax.jit
def sigmoid(X):
    """Elementwise logistic function 1 / (1 + exp(-X))."""
    denominator = 1 + jnp.exp(-X)
    return 1 / denominator

In [None]:
def _negsample_skipgram_score(params, context, target):
    # Raw logit per example: dot(E[target], Theta[context]).
    return jnp.sum(params['E'][target] * params['Theta'][context], axis=1)

@jax.jit
def negsample_skipgram_net(params, context, target):
    """Probability that each (context, target) id pair is a true co-occurrence."""
    return sigmoid(_negsample_skipgram_score(params, context, target))

@jax.jit
def negsample_skipgram_loss(params, context, target, label_true):
    """Mean binary cross-entropy for negative-sampling skip-gram.

    Computed from the raw score with log_sigmoid:
    -log(sigmoid(s)) for positives and -log(1 - sigmoid(s)) = -log(sigmoid(-s))
    for negatives. This stays finite when sigmoid saturates, where
    log(1 - sigmoid(s)) would give inf loss and NaN gradients.
    """
    score = _negsample_skipgram_score(params, context, target)
    signed_score = jnp.where(label_true == 1, score, -score)
    return -jnp.mean(jax.nn.log_sigmoid(signed_score))
negsample_skipgram_loss_value_and_grad = jax.jit(jax.value_and_grad(negsample_skipgram_loss))

In [None]:
def _negsample_cbow_score(params, context, target):
    # Raw logit per example: dot(mean of context embeddings, Theta[target]).
    # -1 context ids are padding and are excluded from the mean; rows that are
    # all padding average to the zero vector.
    valid = context >= 0
    embeddings = params['E'][jnp.where(valid, context, 0)]
    weights = valid[..., None].astype(embeddings.dtype)
    counts = jnp.maximum(jnp.sum(weights, axis=1), 1.0)
    context_embedding = jnp.sum(embeddings * weights, axis=1) / counts
    return jnp.sum(context_embedding * params['Theta'][target], axis=1)

@jax.jit
def negsample_cbow_net(params, context, target):
    """Probability that each target id truly occurs with its context row."""
    return sigmoid(_negsample_cbow_score(params, context, target))

@jax.jit
def negsample_cbow_loss(params, context, target, label_true):
    """Mean binary cross-entropy for negative-sampling CBOW.

    Uses log_sigmoid on the signed raw score so the loss and gradients stay
    finite when sigmoid saturates.
    """
    score = _negsample_cbow_score(params, context, target)
    signed_score = jnp.where(label_true == 1, score, -score)
    return -jnp.mean(jax.nn.log_sigmoid(signed_score))
negsample_cbow_loss_value_and_grad = jax.jit(jax.value_and_grad(negsample_cbow_loss))

## Model

In [None]:
# k_negsample = 0 trains a full-softmax output layer; k_negsample > 0 trains
# with negative sampling (k negatives per positive example).
class Model:
    """Word2vec-style embedding model (skip-gram or CBOW) trained with Adam.

    Word ids follow the key order of ``vocab_counts``. ``params['E']`` holds
    one embedding row per word id.
    """

    def __init__(self, mode, vocab_counts, n_embeddings, left_context_window, right_context_window, k_negsample):
        """
        :param mode: "skipgram" or "cbow"
        :param vocab_counts: mapping word -> corpus count (defines vocabulary and ids)
        :param n_embeddings: embedding dimension
        :param left_context_window: context words taken to the left of the target
        :param right_context_window: context words taken to the right of the target
        :param k_negsample: negatives per positive example; 0 selects full softmax
        :raises ValueError: if mode is not "skipgram" or "cbow"
        """
        if mode not in ("skipgram", "cbow"):
            raise ValueError(f"mode must be 'skipgram' or 'cbow', got {mode!r}")
        self.mode = mode

        self.vocab_counts = vocab_counts
        self.vocab = list(vocab_counts.keys())
        self.vocab_size = len(self.vocab)
        self.vocab_to_id = {word: i for i, word in enumerate(self.vocab)}
        self.id_to_vocab = dict(enumerate(self.vocab))

        self.n_embeddings = n_embeddings
        self.left_context_window = left_context_window
        self.right_context_window = right_context_window
        self.k_negsample = k_negsample

        self.initialize_params()

    def initialize_params(self):
        """(Re)draw E and Theta from zero-mean Gaussians after reseeding NumPy."""
        np.random.seed(seed)
        embedding_std = np.sqrt(1 / self.vocab_size)
        embeddings = np.random.normal(0, embedding_std, (self.vocab_size, self.n_embeddings))
        if self.k_negsample > 0:
            # Theta holds one output vector per word (looked up by id).
            theta = np.random.normal(0, embedding_std, (self.vocab_size, self.n_embeddings))
        else:
            # Theta is the dense projection from embedding space to vocabulary logits.
            theta = np.random.normal(0, np.sqrt(1 / self.n_embeddings), (self.n_embeddings, self.vocab_size))
        self.params = {'E': jnp.asarray(embeddings), 'Theta': jnp.asarray(theta)}

    def train(self, documents,
              lr, beta1, beta2, epsilon,
              n_epochs, batch_size, subsample_rate):
        """Train with bias-corrected Adam on examples generated from ``documents``.

        Appends the mean batch loss of each epoch to ``self.loss_history`` and
        the elapsed wall-clock seconds at the end of each epoch to
        ``self.time_history``. The trailing partial batch of each epoch is
        dropped, but at least one batch always runs.

        :raises ValueError: if no training examples were generated
        """
        contexts, targets, labels = self._make_examples(documents, subsample_rate)
        n_examples = len(targets)
        if n_examples == 0:
            raise ValueError("no training examples were generated from documents")
        n_batches = max(1, n_examples // batch_size)

        self.loss_history = []
        self.time_history = []
        m = {key: jnp.zeros_like(value) for key, value in self.params.items()}
        v = {key: jnp.zeros_like(value) for key, value in self.params.items()}
        step = 0
        start_time = time.time()
        np.random.seed(seed)
        for _ in range(n_epochs):
            order = np.random.permutation(n_examples)

            epoch_losses = []
            for batch in range(n_batches):
                batch_indices = order[batch * batch_size:(batch + 1) * batch_size]
                loss, grads = self._loss_and_grad(
                    contexts[batch_indices], targets[batch_indices], labels[batch_indices]
                )

                step += 1
                m = {key: beta1 * m[key] + (1 - beta1) * grads[key] for key in self.params}
                v = {key: beta2 * v[key] + (1 - beta2) * grads[key] ** 2 for key in self.params}
                # Bias correction compensates for the zero-initialized moments.
                m_scale = 1 / (1 - beta1 ** step)
                v_scale = 1 / (1 - beta2 ** step)
                self.params = {
                    key: self.params[key] - lr * (m[key] * m_scale) / (epsilon + jnp.sqrt(v[key] * v_scale))
                    for key in self.params
                }

                epoch_losses.append(loss)

            self.loss_history.append(float(jnp.mean(jnp.stack(epoch_losses))))
            self.time_history.append(time.time() - start_time)

    def _make_examples(self, documents, subsample_rate):
        # Subsample frequent words, then build (contexts, targets, labels) for this mode.
        documents_subsampled = subsample(self.vocab_counts, documents, subsample_rate)
        if self.mode == "skipgram":
            return skipgram(documents_subsampled, self.vocab_counts, self.vocab_to_id,
                            self.left_context_window, self.right_context_window, self.k_negsample)
        return cbow(documents_subsampled, self.vocab_counts,
                    self.left_context_window, self.right_context_window, self.k_negsample)

    def _loss_and_grad(self, context_batch, target_batch, label_batch):
        # Dispatch to the (loss, grads) function for this mode / output layer.
        if self.k_negsample > 0:
            if self.mode == "skipgram":
                return negsample_skipgram_loss_value_and_grad(self.params, context_batch, target_batch, label_batch)
            return negsample_cbow_loss_value_and_grad(self.params, context_batch, target_batch, label_batch)
        if self.mode == "skipgram":
            one_hot_context_batch = jax.nn.one_hot(context_batch, self.vocab_size)
            return softmax_skipgram_loss_value_and_grad(self.params, target_batch, one_hot_context_batch)
        one_hot_target_batch = jax.nn.one_hot(target_batch, self.vocab_size)
        return softmax_cbow_loss_value_and_grad(self.params, context_batch, one_hot_target_batch)

    # takes in the word for skipgram, list of words for cbow
    def predict(self, X):
        """Score every vocabulary word against ``X``.

        For skipgram ``X`` is a single word and the scores are for each word as
        its context. For cbow ``X`` is a list of context words and the scores
        are for each word as the target.

        :return: array of shape (1, vocab_size)
        """
        if self.k_negsample > 0:
            all_ids = jnp.arange(self.vocab_size)
            if self.mode == "skipgram":
                target_ids = jnp.full(self.vocab_size, self.vocab_to_id[X])
                return negsample_skipgram_net(self.params, all_ids, target_ids)[None, :]
            context_ids = jnp.array([self.vocab_to_id[context] for context in X])
            context_rows = jnp.tile(context_ids, (self.vocab_size, 1))
            return negsample_cbow_net(self.params, context_rows, all_ids)[None, :]
        if self.mode == "skipgram":
            return softmax_skipgram_net(self.params, jnp.array([self.vocab_to_id[X]]))
        return softmax_cbow_net(self.params, jnp.array([[self.vocab_to_id[context] for context in X]]))

    def get_embedding(self, word):
        """Return the row of ``params['E']`` for ``word`` (KeyError if unknown)."""
        return self.params['E'][self.vocab_to_id[word]]

In [None]:
# Train one model per (architecture, window size) and pickle each to models/.
# Model's positional order is (mode, vocab_counts, n_embeddings,
# left_context_window, right_context_window, k_negsample).
mario_model_skipgram2 = Model("skipgram", mario_vocab_counts, 300, 2, 2, 10)
mario_model_skipgram5 = Model("skipgram", mario_vocab_counts, 300, 5, 5, 10)
mario_model_cbow2 = Model("cbow", mario_vocab_counts, 100, 2, 2, 10)
mario_model_cbow5 = Model("cbow", mario_vocab_counts, 100, 5, 5, 10)

trained_models = {
    'mario_model_skipgram2': mario_model_skipgram2,
    'mario_model_skipgram5': mario_model_skipgram5,
    'mario_model_cbow2': mario_model_cbow2,
    'mario_model_cbow5': mario_model_cbow5,
}
train_kwargs = dict(lr=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8,
                    n_epochs=10000, batch_size=100, subsample_rate=0.0001)

model_path = "models/"
os.makedirs(model_path, exist_ok=True)
for model_name, trained_model in trained_models.items():
    trained_model.train(mario_documents_sentences, **train_kwargs)
    # Save right after training so a later failure does not lose finished models.
    with open(os.path.join(model_path, f'{model_name}.pkl'), 'wb') as file:
        pkl.dump(trained_model, file)

In [None]:
# Per-epoch mean training loss against wall-clock time for the four models.
for _curve_model in (mario_model_skipgram2, mario_model_skipgram5, mario_model_cbow2, mario_model_cbow5):
    plt.plot(_curve_model.time_history, _curve_model.loss_history)
plt.yscale("log")
plt.xlabel("Time (s)")
plt.ylabel("Log Loss")
plt.legend(["skipgram2", "skipgram5", "cbow2", "cbow5"])
plt.show()

In [None]:
# Ten highest-scoring context words for "princess" under the skip-gram (window 2) model.
# Word list comes from the model's own id mapping, so positions line up with p[0].
skipgram2_vocab = list(mario_model_skipgram2.vocab_to_id)
p = mario_model_skipgram2.predict("princess")
sorted_indices = np.asarray(jnp.flip(jnp.argsort(p[0])))
np.array(list(zip(skipgram2_vocab, p[0].tolist())))[sorted_indices][:10]

In [None]:
@jax.jit
def cosine_sim(a, b):
    """Cosine similarity of two vectors; 0.0 if either vector has zero norm (instead of NaN)."""
    denom = jnp.linalg.norm(a) * jnp.linalg.norm(b)
    nonzero = denom > 0
    # Divide by a safe denominator so the unused branch cannot produce NaN.
    safe_denom = jnp.where(nonzero, denom, 1.0)
    return jnp.where(nonzero, jnp.dot(a, b) / safe_denom, 0.0)

# Similarity of one vector (first arg) against every row of a matrix (second arg).
batch_cosine_sim = jax.jit(jax.vmap(cosine_sim, (None, 0), 0))

In [None]:
# Ten vocabulary words whose embeddings are most cosine-similar to "mario".
model = mario_model_skipgram2
model_vocab = list(model.vocab_to_id)  # position i is the word with id i

mario_embedding = model.get_embedding("mario")
sim = batch_cosine_sim(mario_embedding, model.params['E'])
sorted_indices = np.asarray(jnp.flip(jnp.argsort(sim)))
np.array(list(zip(model_vocab, sim.tolist())))[sorted_indices][:10]