In [1]:
import sys
sys.path.append("/home/jaxmao/JaxMao")
sys.path

['/home/jaxmao/JaxMao/Example',
 '/usr/lib/python310.zip',
 '/usr/lib/python3.10',
 '/usr/lib/python3.10/lib-dynload',
 '',
 '/home/jaxmao/.local/lib/python3.10/site-packages',
 '/usr/local/lib/python3.10/dist-packages',
 '/usr/lib/python3/dist-packages',
 '/home/jaxmao/JaxMao']

In [2]:
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax import random

import re    
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

seed = 42
key = random.PRNGKey(seed)

# def generate_dataset(tokens, window_size=5):
#     half_window = window_size // 2
#     dataset = []

#     for i in range(len(tokens)):
#         # Extract the window centered at the current token
#         start_idx = max(0, i - half_window)
#         end_idx = min(len(tokens), i + half_window + 1)
#         window = tokens[start_idx:end_idx]

#         # If the window size is correct
#         if len(window) == window_size:
#             input_word = window[half_window]
#             output_words = [word for idx, word in enumerate(window) if idx != half_window]
#             dataset.append((input_word, output_words))

#     return dataset

def generate_dataset(tokens, window_size=5):
    half_window = window_size // 2
    dataset = []

    for i in range(len(tokens)):
        # Extract the window centered at the current token
        start_idx = max(0, i - half_window)
        end_idx = min(len(tokens), i + half_window + 1)
        window = tokens[start_idx:end_idx]

        # If the window size is correct
        if len(window) == window_size:
            input_word = window[half_window]
            output_words = [word for idx, word in enumerate(window) if idx != half_window]
            dataset.append((input_word, output_words))

    return dataset

def generate_dataset2(tokens, encoder, window_size=5):
    half_window = window_size // 2
    dataset = []

    for i in range(len(tokens)):
        # Extract the window centered at the current token
        start_idx = max(0, i - half_window)
        end_idx = min(len(tokens), i + half_window + 1)
        window = tokens[start_idx:end_idx]

        # If the window size is correct
        if len(window) == window_size:
            input_word = encoder([[window[half_window]]])
            output_words = [encoder([[word]]) for idx, word in enumerate(window) if idx != half_window]
            dataset.append((input_word, output_words))

    return dataset

In [3]:
with open('datasets/randomlyGeneratedText.txt') as f:
    sentences = f.read()
    sentences = re.sub('[^A-Za-z0-9]+', ' ', sentences)
    sentences = re.sub(r'\[ *\d+ *\]', ' ', sentences)
    sentences = re.sub('ï»¿', '', sentences)
    tokens = sentences.lower().split(' ')

patterns_to_be_removed = ['\n', '(', ')']
for pattern in patterns_to_be_removed:
    tokens = [token.replace(pattern, '') for token in tokens]
    
vocab = list(set(tokens))
vocab_size = len(vocab)

len(tokens), vocab_size

(48611, 1648)

In [5]:
window_size = 5

onehot_encoder = OneHotEncoder()
onehot_encoder.fit(np.reshape(vocab, (-1,1)))

# Kernel keep crashing when use full dataset
dataset = generate_dataset2(tokens[:40000], onehot_encoder.transform, window_size) 

In [7]:
len(dataset)

39996

##### idea

In [8]:
X = np.array([[e.toarray() for e in entry[1]] for entry in dataset]).reshape(-1, window_size-1, vocab_size)
y = np.array([entry[0].toarray() for entry in dataset]).reshape(-1, vocab_size)

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.85)

In [9]:
np.save('randomlyGeneratedTextX.npy', X)
np.save('randomlyGeneratedTexty.npy', y)

X.shape, y.shape

((39996, 4, 1648), (39996, 1648))

In [10]:
onehot_encoder.inverse_transform(X_train[2]), onehot_encoder.inverse_transform(y_train[2].reshape(-1, vocab_size))

(array([['to'],
        ['preserved'],
        ['mr'],
        ['cordially']], dtype='<U14'),
 array([['be']], dtype='<U14'))

# Build Model

In [11]:
# Model
from jaxmao.Modules import Module
from jaxmao.Layers import FC, SimpleRNN, Flatten
from jaxmao.Activations import ReLU, StableSoftmax

# Training
from jaxmao.Optimizers import GradientDescent, Adam
from jaxmao.Losses import CategoricalCrossEntropy

class Embedding2(Module):
    def __init__(
            self
        ):
        self.flatten = Flatten()
        self.fc1 = FC(vocab_size, embedding_dim)
        self.fc_out = FC(embedding_dim*4, vocab_size)

        self.relu = ReLU()
        self.softmax = StableSoftmax()

    def __call__(self, x):
        x = self.fc1(x)

        x = x.reshape(x.shape[0], -1)
        x = self.softmax(self.fc_out(x))
        return x
    
key = random.PRNGKey(42)

embedding_dim = 16
PNWmodel = Embedding2()
PNWmodel.init_params(key)
best_params = {'params' : PNWmodel.params, 'loss' : np.inf}

In [12]:
loss_crossentropy = CategoricalCrossEntropy()

def loss_params(params, x, y):
    pred = PNWmodel.forward(params, x)
    return loss_crossentropy(pred, y)

grad_loss = jit(value_and_grad(loss_params))

In [13]:
optimizer = GradientDescent()

def training_loop(x, y, epochs=20, lr=0.01, batch_size=32, initial_epoch=0):
    num_batches = len(x) // batch_size

    for epoch in range(epochs):
        lr = lr*(0.99**initial_epoch)
        x, y = shuffle(x, y)
        for batch_idx in range(num_batches):
            starting_idx = batch_idx * batch_size
            ending_idx = (batch_idx + 1) * batch_size
            batch_x = x[starting_idx:ending_idx]
            batch_y = y[starting_idx:ending_idx]

            losses, gradients = grad_loss(PNWmodel.params, batch_x, batch_y)
            PNWmodel.params = optimizer(PNWmodel.params, gradients, lr=lr)
        
        loss = losses / batch_size
        if loss < best_params['loss']:
            best_params['params'] = PNWmodel.params
            best_params['loss'] = loss
        if (epoch+1) % 1 == 0:
            print("Epoch: {}\tbatch loss: {}".format(epoch+1, loss))
        
        
    return PNWmodel.params

params = training_loop(
        X_train, y_train,
        epochs=50, lr=0.01, batch_size=32,
        initial_epoch=0
    )

: 

In [None]:
best_params['loss']

Array(0.11528338, dtype=float32)

In [None]:
outputs = PNWmodel.forward(best_params['params'], X_train)
output = outputs[:500]
trainY = y_train[:500]

output.shape, trainY.shape

((500, 1272), (500, 1272))

In [None]:
from sklearn.metrics import accuracy_score

accuracy = accuracy_score(output.argmax(axis=1), trainY.argmax(axis=1))
print("Accuray: ", accuracy)

Accuray:  1.0


In [None]:
outputs = PNWmodel.forward(best_params['params'], X_test)
output = outputs[:500]
testY = y_test[:500].reshape(-1, vocab_size)

output.shape, testY.shape

((500, 1272), (500, 1272))

In [None]:
from sklearn.metrics import accuracy_score

accuracy = accuracy_score(output.argmax(axis=1), testY.argmax(axis=1))
print("Accuray: ", accuracy)

Accuray:  0.458
