In [34]:
# Import relevant libraries
import os
import pickle
import jax
jax.config.update('jax_platforms', 'cpu')
import numpy
import trax.fastmath.numpy as np
import random 
import trax
from trax import fastmath
from trax import layers as tl

In [35]:
# Read the data
lines = []
# os.listdir returns names in arbitrary order; sorting makes the order of
# `lines` (and therefore any positional split of it) reproducible across runs.
for file in sorted(os.listdir('shakespeare_data')):
    # The context manager closes each file handle as soon as it has been read,
    # instead of leaving one open descriptor per file.
    with open(os.path.join('shakespeare_data', file), 'r') as f:
        for line in f:
            line = line.strip()
            # Keep only non-empty lines, lower-cased.
            if line:
                lines.append(line.lower())


In [36]:
# Hold out the final 2000 lines for evaluation and train on everything before them
train_lines, test_lines = lines[:-2000], lines[-2000:]

In [37]:
# Function to convert a line to a tensor
def line_to_tensor(line, EOS_int=1):
    """Encode a line as a list of integer code points followed by an end marker.

    Args:
        line: iterable of single characters (typically a str).
        EOS_int: integer appended after the last character.

    Returns:
        A new list holding ord(c) for every character of `line`, then EOS_int.
    """
    return [ord(char) for char in line] + [EOS_int]

In [38]:
# Function for data generator
def data_generator(batch_size, max_length, data_lines, line_to_tensor=line_to_tensor, shuffle=True):
    """Yield padded batches of encoded lines forever.

    Only lines strictly shorter than `max_length` are used. Each selected line
    is encoded with `line_to_tensor` and right-padded with zeros to
    `max_length`.

    Args:
        batch_size: number of lines per yielded batch.
        max_length: padded length of every row; lines with
            len(line) >= max_length are skipped.
        data_lines: sequence of lines. It is never modified.
        line_to_tensor: callable turning one line into a list of ints.
        shuffle: if True, lines are visited in a random order that is redrawn
            every time the whole sequence has been traversed.

    Yields:
        (inputs, targets, mask): `inputs` and `targets` are the same array of
        padded rows; `mask` has 1 at positions holding real tokens and 0 at
        padding positions.

    Raises:
        ValueError: on the first `next()` call if no line is shorter than
            `max_length` (a batch could never be filled).
    """
    # Without at least one usable line the fill loop below would never finish.
    if not any(len(line) < max_length for line in data_lines):
        raise ValueError('data_lines contains no line shorter than max_length')

    # Shuffle a permutation of positions rather than the data itself, so the
    # caller's list keeps its order and the very first pass is shuffled too.
    line_indexes = [*range(len(data_lines))]
    if shuffle:
        numpy.random.shuffle(line_indexes)

    index = 0
    while True:
        batch = []
        while len(batch) < batch_size:
            if index == len(data_lines):
                # End of an epoch: start over with a fresh visiting order.
                index = 0
                if shuffle:
                    numpy.random.shuffle(line_indexes)
            line = data_lines[line_indexes[index]]
            if len(line) < max_length:
                batch.append(line)
            index += 1

        tensors = [line_to_tensor(line) for line in batch]

        padded_tensor = []
        mask = []
        for tensor in tensors:
            if len(tensor) < max_length:
                n_pad = max_length - len(tensor)
                padded_tensor.append(tensor + [0] * n_pad)
                mask.append([1] * len(tensor) + [0] * n_pad)
            else:
                # Already at (or beyond) max_length: nothing to pad.
                padded_tensor.append(tensor)
                mask.append([1] * max_length)

        padded_tensor = np.array(padded_tensor)
        mask = np.array(mask)
        yield padded_tensor, padded_tensor, mask

In [39]:
# Define the model
def GRU_Model(vocab_size=256, d_model=512, n_layers=2, mode='train'):
    """Assemble the GRU language model as a trax Serial layer.

    Args:
        vocab_size: size passed to the embedding and to the output Dense layer.
        d_model: embedding feature depth and number of units in each GRU.
        n_layers: how many GRU layers are stacked.
        mode: forwarded to tl.ShiftRight.

    Returns:
        tl.Serial of ShiftRight, Embedding, `n_layers` GRU layers, Dense and
        LogSoftmax, in that order.
    """
    shift = tl.ShiftRight(mode=mode)
    embedding = tl.Embedding(vocab_size=vocab_size, d_feature=d_model)
    gru_stack = [tl.GRU(n_units=d_model) for _ in range(n_layers)]
    projection = tl.Dense(n_units=vocab_size)
    log_softmax = tl.LogSoftmax()
    return tl.Serial(shift, embedding, gru_stack, projection, log_softmax)

In [40]:
# Create the model
# Builds a GRU_Model with all default arguments.
# NOTE(review): the training cell visible in this notebook constructs its own
# GRU_Model() instead of using this instance — confirm whether this cell is still needed.
model = GRU_Model()

In [41]:
# Define hyperparameters
# These globals are passed to data_generator when a batch is drawn for the
# perplexity check; train_model has its own defaults with the same values.
batch_size = 32  # number of lines per batch
max_length = 64  # padded row length; lines of this length or longer are skipped

In [42]:
# Define the training loop
from trax.supervised import training
import itertools
def _restarting_stream(make_stream):
    """Yield from make_stream() endlessly, rebuilding it whenever it runs out."""
    while True:
        produced = False
        for item in make_stream():
            produced = True
            yield item
        # An empty stream can never produce data; stop instead of spinning.
        if not produced:
            return


def train_model(model, data_generator, batch_size=32, max_length=64, lines=train_lines, eval_lines=test_lines, n_steps=5000): 
    """Train `model` and return the trax training loop that ran it.

    Args:
        model: trax layer to train.
        data_generator: callable invoked as
            data_generator(batch_size, max_length, data_lines=...) that returns
            an iterator of batches.
        batch_size: forwarded to `data_generator`.
        max_length: forwarded to `data_generator`.
        lines: lines used to build the training stream.
        eval_lines: lines used to build the evaluation stream.
        n_steps: number of training steps to run.

    Returns:
        The training.Loop after `run(n_steps=n_steps)` has completed.
    """
    # itertools.cycle keeps a copy of every element it has yielded so that it
    # can replay them; over a never-ending batch generator that copy list grows
    # with every step. The wrapper below keeps nothing and simply rebuilds the
    # generator if it is ever exhausted.
    train_stream = _restarting_stream(
        lambda: data_generator(batch_size, max_length, data_lines=lines))
    eval_stream = _restarting_stream(
        lambda: data_generator(batch_size, max_length, data_lines=eval_lines))

    train_task = training.TrainTask(
        labeled_data=train_stream,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(0.001)
    )

    eval_task = training.EvalTask(
        labeled_data=eval_stream,
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
        n_eval_batches=3
    )

    training_loop = training.Loop(model,
                                  train_task,
                                  eval_tasks=eval_task)

    training_loop.run(n_steps=n_steps)

    return training_loop

In [43]:
# Train the model
# A freshly constructed GRU_Model() is trained with train_model's default
# arguments; the returned loop is kept because later cells call
# training_loop.eval_model for evaluation and sampling.
training_loop = train_model(GRU_Model(), data_generator)



Will not write evaluation metrics, because output_dir is None.
Did not save checkpoint as output_dir is None

Step      1: Total number of trainable weights: 3411200
Step      1: Ran 1 train steps in 12.96 secs
Step      1: train CrossEntropyLoss |  5.54513550
Step      1: eval  CrossEntropyLoss |  5.54091152
Step      1: eval          Accuracy |  0.16486038
Did not save checkpoint as output_dir is None

Step    100: Ran 99 train steps in 212.41 secs
Step    100: train CrossEntropyLoss |  3.37593198
Step    100: eval  CrossEntropyLoss |  2.88648979
Step    100: eval          Accuracy |  0.18967202
Did not save checkpoint as output_dir is None

Step    200: Ran 100 train steps in 212.30 secs
Step    200: train CrossEntropyLoss |  2.73634338
Step    200: eval  CrossEntropyLoss |  2.57715742
Step    200: eval          Accuracy |  0.26294392
Did not save checkpoint as output_dir is None

Step    300: Ran 100 train steps in 205.81 secs
Step    300: train CrossEntropyLoss |  2.42932844
Step 

In [44]:
# Function to get the log perplexity
def test_model(preds, target):
    """Return the negative mean log-probability of the non-padding targets.

    Args:
        preds: array of per-class scores; its last axis is the class axis.
            (assumed to hold log-probabilities — TODO confirm against the model output)
        target: integer class ids; positions equal to 0 are treated as padding
            and excluded from the average.

    Returns:
        -(sum of the target-class scores at non-padding positions) divided by
        the number of non-padding positions.
    """
    n_classes = preds.shape[-1]
    # Pick out, for each position, the score assigned to the target class.
    target_scores = np.sum(preds * tl.one_hot(target, n_classes), axis=-1)
    # 1.0 where the target is a real token, 0.0 where it is padding.
    keep = 1.0 - np.equal(target, 0)
    masked_total = np.sum(target_scores * keep)
    mean_score = masked_total / np.sum(keep)
    return -mean_score

In [45]:
# Get the log perplexity for a batch
# With shuffle=False the batch is the first batch_size lines of `lines` that are
# shorter than max_length. batch[0] is fed to the model as input and batch[1]
# (the same array) serves as the target.
# NOTE(review): `lines` contains the training lines as well as the held-out
# ones, so this figure is not a held-out score — confirm whether test_lines was intended.
batch = next(data_generator(batch_size, max_length, lines, shuffle=False))
preds = training_loop.eval_model(batch[0])
log_ppx = test_model(preds, batch[1])
print('The log perplexity and perplexity of your model are respectively', log_ppx, np.exp(log_ppx))

The log perplexity and perplexity of your model are respectively 2.7764776 16.062344


In [46]:
# Function to sample from the gumbel distribution
def gumbel_sample(log_probs, temperature=1.0):
    """Pick an index along the last axis by adding scaled Gumbel noise and taking argmax.

    Args:
        log_probs: array of scores; sampling happens over its last axis.
        temperature: multiplier applied to the noise. With 0 the noise vanishes
            and the result is the plain argmax of `log_probs`.

    Returns:
        Index (or array of indices) of the largest perturbed score along the
        last axis.
    """
    # Uniform draws are kept away from 0 and 1 so both logarithms stay finite.
    uniform_draws = numpy.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    gumbel_noise = -np.log(-np.log(uniform_draws))
    perturbed = log_probs + gumbel_noise * temperature
    return np.argmax(perturbed, axis=-1)

shoul to caniolanusus


In [None]:
# Function to predict the next characters
def predict(num_chars, prefix):
    """Extend `prefix` by sampling up to `num_chars` characters from the trained model.

    Args:
        num_chars: maximum number of characters to generate.
        prefix: starting text; its characters are encoded with ord().

    Returns:
        `prefix` followed by the sampled characters. Generation stops early,
        without emitting a character, when the sampled id is 1 (EOS).
    """
    token_ids = [ord(ch) for ch in prefix]
    generated = list(prefix)
    total_len = len(prefix) + num_chars
    for _ in range(num_chars):
        n_known = len(token_ids)
        # Zero-pad the tokens known so far to the full target length.
        padded = np.array(token_ids + [0] * (total_len - n_known))
        log_probs = training_loop.eval_model(padded[None, :])
        # The output at position n_known is used to choose the next token.
        sampled = int(gumbel_sample(log_probs[0, n_known]))
        token_ids.append(sampled)
        if sampled == 1:
            break  # EOS
        generated.append(chr(sampled))
    return "".join(generated)

In [50]:
# Predict the next characters for a given prefix
# Samples at most 32 characters after the prefix; the result is random, so the
# printed text differs between runs unless numpy's global RNG is seeded.
print(predict(32, "king henry"))

king henry serviand blear what me fair hum
