In [2]:
import numpy as np
import pandas as pd
import nltk
from nltk import word_tokenize
from nltk.corpus import stopwords
from datasets import load_dataset
import string
import jax
import jax.numpy as jnp
import re

In [3]:
nltk.download('brown')
nltk.download('stopwords')

[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\Vincent\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Vincent\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [4]:
brown = nltk.corpus.brown
sents = brown.sents()
stopwords = stopwords.words("english")

processed_sents = []
for sent in sents:
    processed_sents.append([word.lower() for word in sent if word.isalnum() and word not in stopwords])

print('Number of Sentences:', len(sents))

Number of Sentences: 57340


In [22]:
def create_dictionary(texts):
    # record known words
    tmp = set()
    for text in texts:
        for word in text:
            tmp.add(word)
            
    # create dictionary mapping from word to unique id
    dct = {}
    i = 0
    for word in tmp:
        dct[word] = i
        i += 1
    return dct

def one_hot_encode_from_word(word, dct):
    arr = np.zeros(len(dct))
    arr[dct[word]] = 1
    return arr

def one_hot_encode_from_id(index, dct):
    arr = np.zeros(len(dct))
    arr[index] = 1
    return arr

def softmax(x):
    # Shift x to avoid numerical instability
    shifted_x = x - jnp.max(x, axis=0, keepdims=True)
    
    # Compute softmax values
    exp_x = jnp.exp(shifted_x)
    softmax_values = exp_x / jnp.sum(exp_x, axis=0, keepdims=True)
    
    return softmax_values

def init_weights(dct_size, embedding_size):
    lst = []
    lst.append(np.random.normal(0, 1, size = (embedding_size, dct_size)))
    lst.append(np.random.normal(0, 1, size = (dct_size, embedding_size)))
    return lst

@jax.jit
def forward(weights, x):
    return jnp.matmul(weights[1], jnp.matmul(weights[0], x))

@jax.jit
def loss(weights, x, y):
    x_max = jnp.max(x, axis=0, keepdims=True)
    log_softmax = x - x_max - jnp.log(jnp.sum(jnp.exp(x - x_max), axis=0, keepdims=True))
    return -jnp.sum(log_softmax * y)

loss_value_and_grad = jax.jit(jax.value_and_grad(loss))

In [26]:
# create dictionary for known words
dct = create_dictionary(processed_sents)
print(len(dct))

# initialize weights randomly
W = init_weights(len(dct), 300)

context_amount = 5
context_target_pair = []
for review in processed_sents:
    for i in range(context_amount//2, len(review)-context_amount//2):
        context_target_pair.append(((*review[(i-context_amount//2):i], *review[(i+1):(i+context_amount//2+1)]), review[i]))
print(len(context_target_pair))
        
# gradient_descent
N = len(context_target_pair)
lr = 0.01
n_epochs = 100
batch_size = 256
n_batches = N // batch_size

losses = []

for epoch in range(n_epochs):
    # shuffle the data
    perm = np.random.permutation(N)
    data = np.array(context_target_pair, dtype=object)
    data_x = data[perm, 0]
    data_y = data[perm, 1]

    # stores all the losses for this epoch
    epoch_losses = []
    for batch in range(n_batches):
        x_batch = data_x[batch*batch_size:(batch+1)*batch_size]
        y_batch = data_y[batch*batch_size:(batch+1)*batch_size]
        # one_hot_encode each batch
        tmp = []
        for context in x_batch:
            vec = np.zeros(len(dct))
            for word in context:
                vec += one_hot_encode_from_word(word, dct)
            tmp.append(vec)
        x_batch = jnp.array(tmp)
        y_batch = jnp.array([one_hot_encode_from_word(word, dct) for word in y_batch])
        # print(x_batch)
        # print(y_batch)
        
        loss_value, grad = loss_value_and_grad(W, x_batch.T, y_batch.T)
        losses.append(loss_value)
        #  print(grad)
        #  print(jax.grad(loss)(W, x_batch.T, y_batch.T))
        #  break
        W = [W[0]-lr*grad[0], W[1]-lr*grad[1]]
        # a slightly better / more-readable way to do this update is to use the 
        # `jax.tree_map`: https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html
        # and the update would read as follows:
        #
        # params = jax.tree_map(lambda param, grad: param - lr*grad,params, grad)
        #
        epoch_losses.append(loss_value)
#         if batch%500 == 0: print("batch: ", batch, "/", n_batches, " ", loss_value)
        
    # display the loss every 100 epochs
#     if (epoch+1) % 5 == 0:
#         print(f"Epoch {epoch+1}/{n_epochs}, loss = {np.mean(epoch_losses)}")
    print(f"Epoch {epoch+1}/{n_epochs}, loss = {np.mean(epoch_losses)}")


41007
343371


KeyboardInterrupt: 