In [18]:
import os
import json
from keras.models import Sequential
from keras.layers import LSTM, Dropout, TimeDistributed, Dense, Activation, Embedding
import numpy as np
import pandas as pd
from matplotlib import pyplot

In [19]:
# Hyper-parameters shared by the batching code and the model definition
BATCH_SIZE = 32    # rows (parallel sequences) in every training batch
SEQ_LENGTH = 128   # characters per row of a batch

# Input side: folder and file name of the training corpus, plus the name of the
# JSON file (stored in data_dir) that holds the character -> index mapping
data_dir = "../Data/"
data_file = "Notting_OneillQuart"
charToIndex_json = "char_to_index.json"

# Output side: folder for the weight checkpoints and the training-log CSV
save_weights_dir = "../Trained_Weights/Weights_Notting_Oneill/"
log_dir = "../Data/log_Notting_Oneill.csv"  # a CSV file path, despite the "_dir" suffix

In [20]:
def get_batches(chars, unique_chars, batch_size = None, seq_length = None):
    """Yield successive (X, Y) batches for next-character prediction.

    The encoded text is cut into ``batch_size`` equal, contiguous segments, one
    per batch row. Each yielded batch advances every row by ``seq_length``
    characters, so row r of one batch continues exactly where row r of the
    previous batch stopped.

    Args:
        chars: 1-D NumPy integer array of character indices.
        unique_chars: size of the vocabulary (depth of the one-hot targets).
        batch_size: rows per batch; defaults to the module-level BATCH_SIZE.
        seq_length: characters per row; defaults to the module-level SEQ_LENGTH.

    Yields:
        Tuples ``(X, Y)`` where ``X`` is a float array of shape
        ``(batch_size, seq_length)`` holding character indices and ``Y`` is a
        float array of shape ``(batch_size, seq_length, unique_chars)`` holding
        the one-hot encoding of the character that follows each entry of ``X``.
        Nothing is yielded when a segment is not longer than ``seq_length``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if seq_length is None:
        seq_length = SEQ_LENGTH

    char_no = chars.shape[0]              # number of characters in the data
    batch_chars = char_no // batch_size   # characters available to each batch row

    # Index helpers that do not change between batches
    row_offsets = (np.arange(batch_size) * batch_chars)[:, None]  # segment start per row
    steps = np.arange(seq_length)[None, :]                        # column positions
    rows = np.arange(batch_size)[:, None]

    # One iteration per batch; about char_no / (batch_size * seq_length) batches in total
    for start in range(0, batch_chars - seq_length, seq_length):
        positions = row_offsets + start + steps   # (batch_size, seq_length) indices into chars
        X = chars[positions].astype(np.float64)
        Y = np.zeros((batch_size, seq_length, unique_chars))
        # A 1 marks the character that follows X[row, col] as the correct target
        Y[rows, steps, chars[positions + 1]] = 1
        yield X, Y

In [21]:
def build_model(batch_size, seq_length, unique_chars,
                embedding_dim = 16, lstm_units = (256, 128), dropout_rate = 0.2):
    """Build the (uncompiled) character-level LSTM network.

    Layer stack: Embedding -> [stateful LSTM -> Dropout] per entry of
    ``lstm_units`` -> TimeDistributed(Dense) -> softmax. With the default
    keyword values the architecture is identical to the original hard-coded one
    (embedding size 16, LSTM sizes 256 and 128, dropout 0.2).

    Args:
        batch_size: fixed number of sequences per batch; the input shape is
            pinned to it because the LSTM layers are created with stateful=True.
        seq_length: number of characters in each input sequence.
        unique_chars: vocabulary size; used as the embedding input dimension
            and as the width of the per-timestep output layer.
        embedding_dim: size of the character embedding vectors.
        lstm_units: iterable with the number of units of each stacked LSTM layer.
        dropout_rate: dropout rate applied after every LSTM layer.

    Returns:
        The assembled keras ``Sequential`` model.
    """
    model = Sequential()

    # All inputs share one fixed (batch_size, seq_length) shape, which the batch
    # generator guarantees; the embedding maps each character index to a vector.
    model.add(Embedding(input_dim = unique_chars, output_dim = embedding_dim,
                        batch_input_shape = (batch_size, seq_length)))

    # Stacked recurrent layers; Dropout after each one to limit overfitting
    for units in lstm_units:
        model.add(LSTM(units, return_sequences = True, stateful = True))
        model.add(Dropout(dropout_rate))

    # One score per vocabulary entry at every timestep, normalised by softmax
    model.add(TimeDistributed(Dense(unique_chars)))
    model.add(Activation("softmax"))

    return model

In [22]:
def train_model(data, epochs = 80):
    """Train the character-level LSTM on ``data`` and log the run.

    Side effects:
        * writes the character -> index mapping to ``charToIndex_json`` inside
          ``data_dir``;
        * saves the model weights to ``save_weights_dir`` after every 10th epoch;
        * writes a per-epoch CSV log (Epoch, Loss, Accuracy) to ``log_dir``;
        * shows a plot of the logged accuracy against the epoch number.

    Args:
        data: the training text; every distinct character becomes one class.
        epochs: number of passes over the data.

    Raises:
        ValueError: if ``data`` is too short to produce a single batch with the
            configured BATCH_SIZE and SEQ_LENGTH.
    """
    # get_batches yields nothing unless each of the BATCH_SIZE segments is longer
    # than SEQ_LENGTH; fail early instead of "training" on zero batches and
    # saving the weights of an untrained model.
    if len(data) // BATCH_SIZE <= SEQ_LENGTH:
        raise ValueError(
            "Training data has {} characters; more than {} are required for "
            "BATCH_SIZE = {} and SEQ_LENGTH = {}".format(
                len(data), BATCH_SIZE * SEQ_LENGTH, BATCH_SIZE, SEQ_LENGTH))

    # Mapping all unique characters to an index (sorted, so the mapping is deterministic)
    char_to_index = {char: x for (x, char) in enumerate(sorted(set(data)))}
    unique_chars = len(char_to_index)
    print("Unique characters in the training data = {}".format(unique_chars))
    # Save the mapping in a json file
    with open(os.path.join(data_dir, charToIndex_json), mode = "w") as f:
        json.dump(char_to_index, f)

    # Build the model
    model = build_model(BATCH_SIZE, SEQ_LENGTH, unique_chars)
    model.summary()
    # multi-class classification problem - using categorical cross entropy as loss function
    model.compile(loss = "categorical_crossentropy", optimizer = "adam", metrics = ["accuracy"])

    characters = np.asarray([char_to_index[c] for c in data], dtype = np.int32)
    print("Number of characters = " + str(characters.shape[0]))

    # Per-epoch history, kept for the log file and the plot
    saved_epoch, loss, accuracy = [], [], []
    for epoch in range(epochs):
        print("Epoch {}/{}".format(epoch+1, epochs))
        last_epoch_loss, last_epoch_accuracy = 0, 0
        saved_epoch.append(epoch+1)

        # The LSTM layers are stateful: their state carries over from batch to
        # batch, which is wanted inside an epoch because consecutive batches are
        # contiguous text. Each epoch restarts at the beginning of the text, so
        # the state left over from the end of the previous pass is cleared.
        model.reset_states()

        # reading the batches one by one and training the model on each one
        for i, (x, y) in enumerate(get_batches(characters, unique_chars)):
            last_epoch_loss, last_epoch_accuracy = model.train_on_batch(x, y)
            print("Batch No.: {}, Loss: {}, Accuracy: {}".format(i+1, last_epoch_loss, last_epoch_accuracy))
        # NOTE: the logged values are those of the final batch of the epoch,
        # not an average over all batches of the epoch.
        loss.append(last_epoch_loss)
        accuracy.append(last_epoch_accuracy)

        # Saving the computed weights every 10th epoch
        if (epoch + 1) % 10 == 0:
            if not os.path.exists(save_weights_dir):
                os.makedirs(save_weights_dir)
            model.save_weights(os.path.join(save_weights_dir, "Weights_{}.h5".format(epoch+1)))
            print('Saved weights computed at epoch {} to Weights_{}.h5'.format(epoch+1, epoch+1))

    # Logging the training history into a DataFrame that is saved to file after each training
    log_frame = pd.DataFrame({"Epoch": saved_epoch, "Loss": loss, "Accuracy": accuracy},
                             columns = ["Epoch", "Loss", "Accuracy"])
    log_frame.to_csv(log_dir, index = False)

    # Accuracy plot: epoch number on the x axis, accuracy on the y axis
    pyplot.plot(saved_epoch, accuracy)
    pyplot.xlabel("Epoch")
    pyplot.ylabel("Accuracy")
    pyplot.title("Training accuracy per epoch")
    pyplot.show()

In [None]:
# Read the whole training corpus into memory; the context manager closes the
# file even when reading raises.
with open(os.path.join(data_dir, data_file), mode = 'r') as f:
    data = f.read()

if __name__ == "__main__":
    train_model(data)

Unique characters in the training data = 93
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_3 (Embedding)      (32, 128, 16)             1488      
_________________________________________________________________
lstm_5 (LSTM)                (32, 128, 256)            279552    
_________________________________________________________________
dropout_5 (Dropout)          (32, 128, 256)            0         
_________________________________________________________________
lstm_6 (LSTM)                (32, 128, 128)            197120    
_________________________________________________________________
dropout_6 (Dropout)          (32, 128, 128)            0         
_________________________________________________________________
time_distributed_3 (TimeDist (32, 128, 93)             11997     
_________________________________________________________________
activation_3 (Activation)    (32

In [11]:
# Reload the per-epoch training log from the CSV at log_dir (the file
# train_model writes) and display it as the cell output
log = pd.read_csv(log_dir)
log

Unnamed: 0,Epoch,Loss,Accuracy
0,1,3.671352,0.161133
1,2,3.408862,0.158203
2,3,2.964650,0.232666
3,4,2.505745,0.350708
4,5,2.146323,0.435181
5,6,1.894940,0.494019
6,7,1.736072,0.533325
7,8,1.608731,0.558228
8,9,1.502052,0.583984
9,10,1.421677,0.606689
