In [1]:
import numpy as np
import pandas as pd
from nltk import word_tokenize
from datasets import load_dataset
import string
import jax
import jax.numpy as jnp
import re

In [2]:
# get dataset from huggingface
data = load_dataset("imdb")

# get the reviews as array
movie_reviews = np.array(data["unsupervised"].data["text"])

# tokenize the dataset and remove punctuations
movie_reviews_clean = []
for review in movie_reviews:
    movie_reviews_clean.append([re.sub("[^\w]", "", word.lower()) for word in word_tokenize(review)])

In [13]:
context_amount = 5
context_target_pair = []
for review in movie_reviews_clean:
    for i in range(context_amount//2, len(review)-context_amount//2):
        context_target_pair.append(((*review[(i-context_amount//2):i], *review[(i+1):(i+context_amount//2+1)]), review[i]))

print(context_target_pair[:10])
print(len(context_target_pair))

[(('this', 'is', 'a', 'precious'), 'just'), (('is', 'just', 'precious', 'little'), 'a'), (('just', 'a', 'little', 'diamond'), 'precious'), (('a', 'precious', 'diamond', ''), 'little'), (('precious', 'little', '', 'the'), 'diamond'), (('little', 'diamond', 'the', 'play'), ''), (('diamond', '', 'play', ''), 'the'), (('', 'the', '', 'the'), 'play'), (('the', 'play', 'the', 'script'), ''), (('play', '', 'script', 'are'), 'the')]
13967876


In [37]:
def create_dictionary(texts):
    # record known words
    tmp = set()
    for text in texts:
        for word in text:
            tmp.add(word)
            
    # create dictionary mapping from word to unique id
    dct = {}
    i = 0
    for word in tmp:
        dct[word] = i
        i += 1
    return dct

def one_hot_encode_from_word(word, dct):
    arr = np.zeros(len(dct))
    arr[dct[word]] = 1
    return arr

def one_hot_encode_from_id(index, dct):
    arr = np.zeros(len(dct))
    arr[index] = 1
    return arr

def softmax(z):
    # shift input values for numerical stability
    z -= jnp.max(z)
    
    # compute softmax
    exp_z = jnp.exp(z)
    s = exp_z / jnp.sum(exp_z, axis=0)
    
    return s

def init_weights(dct_size, embedding_size):
    lst = []
    lst.append(np.random.normal(0, 1, size = (embedding_size, dct_size)))
    lst.append(np.random.normal(0, 1, size = (dct_size, embedding_size)))
    return lst

def forward(x, weights):
    return softmax(jnp.matmul(weights[1], jnp.matmul(weights[0], x)))

def loss(weights, x, y):
    return -jnp.mean(jnp.sum(y * jnp.log(forward(x, weights)), axis=1))

loss_value_and_grad = jax.jit(jax.value_and_grad(loss))

In [31]:
# create dictionary for known words
dct = create_dictionary(movie_reviews_clean)

# initialize weights randomly
W = init_weights(len(dct), 300)

# gradient_descent
N = len(context_target_pair)
lr = 0.01
n_epochs = 10000
batch_size = 256
n_batches = N // batch_size

In [41]:
losses = []

for epoch in range(n_epochs):
    # shuffle the data
    perm = np.random.permutation(N)
    data = np.array(context_target_pair, dtype=object)
    data_x = data[perm, 0]
    data_y = data[perm, 1]

    # stores all the losses for this epoch
    epoch_losses = []
    for batch in range(n_batches):
        x_batch = data_x[batch*batch_size:(batch+1)*batch_size]
        y_batch = data_y[batch*batch_size:(batch+1)*batch_size]
        # one_hot_encode each batch
        tmp = []
        for context in x_batch:
            vec = np.zeros(len(dct))
            for word in context:
                vec += one_hot_encode_from_word(word, dct)
            tmp.append(vec)
        x_batch = jnp.array(tmp)
        y_batch = jnp.array([one_hot_encode_from_word(word, dct) for word in y_batch])
        # print(x_batch)
        # print(y_batch)
        
        loss_value, grad = loss_value_and_grad(W, x_batch[..., jnp.newaxis], y_batch[..., jnp.newaxis])
        losses.append(loss_value)
        # print(grad)
        W = [W[0]-grad[0], W[1]-grad[1]]
        # a slightly better / more-readable way to do this update is to use the 
        # `jax.tree_map`: https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html
        # and the update would read as follows:
        #
        # params = jax.tree_map(lambda param, grad: param - lr*grad,params, grad)
        #
        epoch_losses.append(loss_value)
        
    # display the loss every 100 epochs
    if epoch % 500 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, loss = {np.mean(epoch_losses)}")


KeyboardInterrupt: 