Training a deep learning model for spell correction:

- Training data is stored in phx2 HDFS in Parquet format
- Run with the "Python 3 (General DS)" kernel
- Train with the default Docker image, deeplearning_examples
- Training cluster: phx4-prod02
- Resource pool: root/ProductPlatform/Michelangelo
- The trained model is saved in the 'containers' of Instance 0

In [None]:
import michelangelo.malib.lambdadl as ldl
from pyspark import SparkContext, SparkConf
import numpy as np
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM, CuDNNLSTM, Dense, Dropout, TimeDistributed, Masking, Lambda
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
import datetime
import json

In [None]:
# See https://docs.google.com/spreadsheets/d/1-XldpTa11qHMYJDR7ZTh27Vmo9HcUdqDXAqgMp4N6uU/edit#gid=2084303453 for the performance of various setting combinations.
# Bi-grams with 256 LSTM units show good performance.

# Length of the character n-grams used as tokens (passed to buildVocabulary).
NGRAM = 2

# When True, samples carry an extra "s2cell_id" column in addition to the
# raw and corrected query.
USE_GPS = False

if USE_GPS:
    # Alphabet additionally contains '.' and ','.
    BASE_CHARS = "abcdefghijklmnopqrstuvwxyz0123456789 -.,"
    DATA_COLS = ["s2cell_id", "raw_query", "corrected_query"]
    QUERY_ENCODED_LEN = 52
    LABEL_ENCODED_LEN = 54
else:
    BASE_CHARS = "abcdefghijklmnopqrstuvwxyz0123456789 -"
    DATA_COLS = ["raw_query", "corrected_query"]
    QUERY_ENCODED_LEN = 30
    LABEL_ENCODED_LEN = 32

# Network hidden units
LSTM_UNITS = 256

BATCH_SIZE = 256
EPOCH_CNT = 200

LEARNING_RATE = 1.0
# File-name tag for the learning rate, derived from LEARNING_RATE so the two
# cannot drift apart (1.0 -> 'lr1_0').
LEARNING_RATE_STR = 'lr' + str(LEARNING_RATE).replace('.', '_')

PETASTORM_HDFS_DRIVER = 'libhdfs'
ZONE = 'phx4-prod02'

# Job sizing: number of workers, memory per worker (MB) and job timeout (s).
WORKER_CNT = 40
MEMORY_SIZE_MB = 61440
TIME_OUT_S = 20000

# To obtain the HDFS delegation token, run the following in a terminal on the development machine:
# $ drogon token --renew -user <username> --cluster phx4
# $ drogon token-view --user <username> --cluster phx4 
# Then replace the placeholder below with the token.
HDFS_DELEGATION_TOKEN = <hdfs_delegation_token>

In [None]:
# Output directory and common file-name prefix of every result file.
model_path = '/mnt/mesos/sandbox/'
model_file_prefix = 'sc_lstm'

use_gps_str = '_gps' if USE_GPS else "_nogps"

# Suffix that encodes the hyper-parameters of this run.
file_parts = "".join([
    "_ngram", str(NGRAM),
    "_", LEARNING_RATE_STR,
    "_batch", str(BATCH_SIZE),
    "_lstmunit", str(LSTM_UNITS),
    use_gps_str,
])

# Result files
model_file = "".join([model_path, model_file_prefix, file_parts, ".h5"])
model_weights_file = "".join([model_path, model_file_prefix, "weights_", file_parts, ".h5"])
model_config_json = "".join([model_path, model_file_prefix, "config_", file_parts, ".json"])
# The {epoch}/{acc}/{val_acc} placeholders are left unformatted here; the
# template is handed to ModelCheckpoint as-is.
model_checkpoint_file = "".join([
    model_path, model_file_prefix,
    "{epoch:02d}_{acc:.4f}_{val_acc:.4f}_",
    file_parts, ".h5",
])

training_history_pickle = "".join([model_path, model_file_prefix, "training_history_", file_parts, ".pickle"])

In [None]:
# Training & Validation datasets
#
# Every snapshot is a single snappy-compressed Parquet part file below the
# same HDFS directory; only the snapshot folder and the part UUID differ.
_SPELL_CHECK_DATA_ROOT = '/app/geosearch/spell_check/misspelling_corrected_s2cell_lvl6_parquet'


def _snapshot_file(snapshot_dir, part_uuid):
    """Return the HDFS path of the single part file inside ``snapshot_dir``."""
    return f"{_SPELL_CHECK_DATA_ROOT}/{snapshot_dir}/part-00000-{part_uuid}-c000.snappy.parquet"


# NOTE(review): the 2021-03-05 training/validation files are defined here but
# are not referenced by get_training_configuration in this notebook.
training_data202010 = _snapshot_file('US_2020-10_training4M', 'de71c9bc-1905-49e6-8c48-89c1c40890db')
training_data202011 = _snapshot_file('US_2020-11_training4M', 'ab29cacb-79a6-4f5e-b5fe-76a6f588555b')
training_data202012 = _snapshot_file('US_2020-12_training4M', '5c86a121-63e2-48c9-85c4-1eaec3250a19')
training_data202101 = _snapshot_file('US_2021-01_training4M', 'bedd01f2-6061-4ee8-90e8-bb36ff8c26a7')
training_data202102 = _snapshot_file('US_2021-02_training4M', '6f651532-b1fe-426a-b9b8-9ba046d9c2d6')
training_data20210301 = _snapshot_file('US_2021-03-01_training4M', 'e0141b8c-3ce7-4ccb-903e-294db6bb7431')
training_data20210302 = _snapshot_file('US_2021-03-02_training4M', '83f7938f-4247-4fc0-8f28-e7740bc1f867')
training_data20210303 = _snapshot_file('US_2021-03-03_training4M', '344ab067-ceb5-4f14-becc-ee6939dbaccf')
training_data20210304 = _snapshot_file('US_2021-03-04_training4M', 'eb4d0113-67ee-481a-b9a9-3014fc7b1249')
training_data20210305 = _snapshot_file('US_2021-03-05_training4M', '056c0013-3a2c-472a-8241-bb05834fb70e')

val_data202010 = _snapshot_file('US_2020-10_val4M', 'caede584-91ef-416c-8081-f0c01dad0fa4')
val_data202011 = _snapshot_file('US_2020-11_val4M', 'bcc46407-4df3-41b3-a3cb-bd2ff976c0e9')
val_data202012 = _snapshot_file('US_2020-12_val4M', 'c19c9487-8158-492c-b217-ac28c31447d0')
val_data202101 = _snapshot_file('US_2021-01_val4M', 'f7b6dd0d-2620-42c7-a79c-8a19a2165115')
val_data202102 = _snapshot_file('US_2021-02_val4M', '0028f347-42bb-458e-98b1-0978ab283c16')
val_data20210301 = _snapshot_file('US_2021-03-01_val4M', 'f83bb908-1d30-4e6a-8027-21ee275b8ac9')
val_data20210302 = _snapshot_file('US_2021-03-02_val4M', '77dae0c1-004e-4422-8c4d-e1f6221c0448')
val_data20210303 = _snapshot_file('US_2021-03-03_val4M', 'c8218082-88a5-4df1-9e9e-9eafde61519e')
val_data20210304 = _snapshot_file('US_2021-03-04_val4M', '3cd137db-493b-4ba4-8c2b-b398f132e16e')
val_data20210305 = _snapshot_file('US_2021-03-05_val4M', '68f5aa31-39b3-4b38-95f6-465b41c3b58c')

In [None]:
def buildVocabulary(gramLen):
    """Build the n-gram <-> index lookup tables over ``BASE_CHARS``.

    Every string of ``gramLen`` characters drawn from ``BASE_CHARS`` gets a
    consecutive index starting at 1; index 0 is never assigned. N-grams are
    enumerated with the last character varying fastest, i.e. in the same
    order as nested loops over ``BASE_CHARS``.

    Args:
        gramLen: n-gram length. Any value >= 1 is supported; values below 1
            yield two empty tables.

    Returns:
        A tuple ``(encodeVoc, decodeVoc)`` where ``encodeVoc`` maps
        n-gram -> index and ``decodeVoc`` maps index -> n-gram.
    """
    from itertools import product

    encodeVoc = dict()
    decodeVoc = dict()

    if gramLen < 1:
        return encodeVoc, decodeVoc

    # product(..., repeat=n) replaces the hand-written 1/2/3-deep loops and
    # keeps their enumeration order, so existing indices are unchanged.
    for gramIndex, chars in enumerate(product(BASE_CHARS, repeat=gramLen), start=1):
        gram = "".join(chars)
        encodeVoc[gram] = gramIndex
        decodeVoc[gramIndex] = gram

    return encodeVoc, decodeVoc

In [None]:
# Lookup tables for the configured n-gram length.
encodeVoc, decodeVoc = buildVocabulary(gramLen=NGRAM)
# One class per possible n-gram, plus index 0 which buildVocabulary never assigns.
vocabulary_size = 1 + len(BASE_CHARS) ** NGRAM

In [None]:
def init(mactx):
    """Job-initialisation hook: install the pinned petastorm release.

    Args:
        mactx: object exposing ``execute(command)``; the pip command is
            passed to it as a single string.
    """
    install_cmd = 'pip install petastorm==0.10.0'
    mactx.execute(install_cmd)
    
# Worker memory and HDFS credentials.
system_config = ldl.SystemConfig(
    memory_size_mb=MEMORY_SIZE_MB,
    hdfs_delegation_token=HDFS_DELEGATION_TOKEN,
)
# Job-level settings: zone, timeout, worker count and one GPU.
job_config = ldl.DLJobConfig(
    zone=ZONE,
    timeout=TIME_OUT_S,
    num_workers=WORKER_CNT,
    num_gpus=1,
    system_config=system_config,
)

# Job context; `init` is registered as the on_job_init hook.
mactx = ldl.DLJobContext(job_config, on_job_init=init)

In [None]:
def getDataSetRows(data_paths):
    """Count the rows and row groups of a Parquet dataset.

    Args:
        data_paths: path or list of paths handed to
            ``pyarrow.parquet.ParquetDataset``.

    Returns:
        A tuple ``(nrows, nrow_groups)`` summed over all dataset pieces.
    """
    from pyarrow.parquet import ParquetDataset

    nrows = 0
    nrow_groups = 0
    for piece in ParquetDataset(data_paths).pieces:
        # Fetch the metadata once per piece instead of once per counter;
        # presumably each get_metadata() call re-reads the file footer.
        metadata = piece.get_metadata()
        nrows += metadata.num_rows
        nrow_groups += metadata.num_row_groups
    return nrows, nrow_groups


def get_training_configuration(batchSize):
    """Resolve the dataset paths and count the available samples.

    Args:
        batchSize: accepted for interface compatibility; not used here.

    Returns:
        ``(trainingDataSet, valDataSet, trainingSampleCnt, valSampleCnt)``:
        the two lists of ``hdfs.path.abspath``-resolved Parquet paths and the
        total row count of each list as reported by ``getDataSetRows``.
    """
    import pydoop.hdfs as hdfs
    from pyarrow.parquet import ParquetDataset  # noqa: F401 - not referenced in this function

    # March 2021 daily snapshots first, then the monthly snapshots.
    training_files = (
        training_data20210301,
        training_data20210302,
        training_data20210303,
        training_data20210304,
        training_data202010,
        training_data202011,
        training_data202012,
        training_data202101,
        training_data202102,
    )
    trainingDataSet = [hdfs.path.abspath(path) for path in training_files]
    trainingSampleCnt, _ = getDataSetRows(trainingDataSet)

    validation_files = (
        val_data20210301,
        val_data20210302,
        val_data20210303,
        val_data20210304,
        val_data202010,
        val_data202011,
        val_data202012,
        val_data202101,
        val_data202102,
    )
    valDataSet = [hdfs.path.abspath(path) for path in validation_files]
    valSampleCnt, _ = getDataSetRows(valDataSet)

    return trainingDataSet, valDataSet, trainingSampleCnt, valSampleCnt

In [None]:
# Run get_training_configuration once through the job context and unpack
# the dataset paths and sample counts it returns.
(trainingDataSet, valDataSet,
 trainingSampleCnt, valSampleCnt) = mactx.run_single(get_training_configuration, args=(BATCH_SIZE,))

print(f"[Training config] trainingSampleCnt = {trainingSampleCnt}")
print(f"[Training config] valSampleCnt = {valSampleCnt}")

In [None]:
def isValidTrainingSample(trainingSample):
    """Return True when the sample's corrected query is long enough to train on.

    The corrected query is the last element of the sample in both layouts
    used by ``encodeSampleFromList``: ``[s2cell_id, raw, corrected]`` and
    ``[raw, corrected]``. Indexing from the end therefore works for either,
    whereas a fixed index of 2 raised IndexError on two-column samples.

    Args:
        trainingSample: indexable sample whose last element is a string.

    Returns:
        True if the corrected query, with leading whitespace stripped, has
        at least 3 characters; False otherwise.
    """
    corrected_query = trainingSample[-1]
    return len(corrected_query.lstrip()) >= 3
          

def getGramIndex(voc, xgram):
    """Look up the index of ``xgram`` in ``voc``; unknown n-grams map to 0."""
    return voc.get(xgram, 0)


def encodeString(strData, encodeVoc, gramLen, expectedLen):
    """Encode a string as the indices of its overlapping n-grams.

    Args:
        strData: text to encode.
        encodeVoc: n-gram -> index mapping; n-grams it lacks are encoded as 0.
        gramLen: n-gram length (the sliding window advances by one character).
        expectedLen: maximum number of n-grams to emit. A value below 1
            disables the cap.

    Returns:
        A 1-D ``np.uint16`` array with one index per encoded n-gram; empty
        when ``strData`` is shorter than ``gramLen``.
    """
    window_starts = range(len(strData) - gramLen + 1)
    if expectedLen > 0:
        # Only the first expectedLen windows are ever encoded.
        window_starts = window_starts[:expectedLen]

    indices = [
        getGramIndex(encodeVoc, strData[start:start + gramLen])
        for start in window_starts
    ]
    return np.asarray(indices, dtype=np.uint16)


def encodeSampleFromList(trainingSample, encodeVoc, gramLen, withGPS = False):
    """Encode one training sample into its n-gram index sequences.

    Args:
        trainingSample: ``[raw_query, corrected_query]``, or
            ``[s2cell_id, raw_query, corrected_query]`` when ``withGPS`` is
            True.
        encodeVoc: n-gram -> index mapping forwarded to ``encodeString``.
        gramLen: n-gram length forwarded to ``encodeString``.
        withGPS: when True, ``"<s2cell_id>,"`` is prepended to both the query
            and the label before encoding.

    Returns:
        ``[encodedInput, encodedOutput]``: the query capped at
        ``QUERY_ENCODED_LEN`` n-grams and the label capped at
        ``LABEL_ENCODED_LEN`` n-grams.
    """
    if withGPS:
        cell_prefix = trainingSample[0] + ","
        query_text = cell_prefix + trainingSample[1]
        label_text = cell_prefix + trainingSample[2]
    else:
        query_text, label_text = trainingSample[0], trainingSample[1]

    return [
        encodeString(query_text, encodeVoc, gramLen, QUERY_ENCODED_LEN),
        encodeString(label_text, encodeVoc, gramLen, LABEL_ENCODED_LEN),
    ]

In [None]:
def createModel(vocabulary_size, hvd):
    """Build and compile the LSTM encoder-decoder training model.

    Both inputs are variable-length sequences of vectors of width
    ``vocabulary_size``. The encoder LSTM's final hidden and cell states
    initialise the decoder LSTM, whose per-step outputs pass through a
    softmax ``Dense`` layer of width ``vocabulary_size``.

    Args:
        vocabulary_size: width of the input vectors and of the softmax output.
        hvd: Horovod Keras module; supplies ``size()`` for learning-rate
            scaling and ``DistributedOptimizer``.

    Returns:
        The compiled Keras ``Model`` taking ``[encoder_inputs, decoder_inputs]``.
    """
    sequence_shape = (None, vocabulary_size)

    # Encoder: only the final states are kept, the output sequence is discarded.
    enc_in = Input(shape=sequence_shape, name='encoder_inputs')
    _, enc_h, enc_c = LSTM(LSTM_UNITS, return_state=True, name='encoder_lstm')(enc_in)

    # Decoder: starts from the encoder's final states.
    dec_in = Input(shape=sequence_shape, name='decoder_inputs')
    dec_lstm = LSTM(LSTM_UNITS, return_sequences=True, return_state=True, name='decoder_lstm')
    dec_seq, _, _ = dec_lstm(dec_in, initial_state=[enc_h, enc_c])

    token_probs = Dense(vocabulary_size, activation='softmax', name='decoder_dense')(dec_seq)

    model = Model([enc_in, dec_in], token_probs, name='encoder_decoder_training')

    # Horovod: scale the learning rate by the worker count, then wrap the
    # optimizer in Horovod's distributed optimizer.
    base_optimizer = keras.optimizers.Adadelta(learning_rate=LEARNING_RATE * hvd.size())
    distributed_optimizer = hvd.DistributedOptimizer(base_optimizer)

    model.compile(optimizer=distributed_optimizer, loss='categorical_crossentropy', metrics=['acc'])

    model.summary()
    return model

In [None]:
def make_generator(train_ds, sc, sess, batchSize, graph, trainingShuffleSize = None):
    """Yield one-hot encoded seq2seq batches read from a petastorm dataset.

    The dataset is unbatched, optionally shuffled, re-batched to
    ``batchSize`` and then pulled through ``sess`` indefinitely. Each batch
    is n-gram encoded through Spark, padded, wrapped with a leading and a
    trailing 0 token, and one-hot encoded.

    Args:
        train_ds: tf.data dataset whose batches expose ``raw_query`` and
            ``corrected_query`` (plus ``s2cell_id`` when ``USE_GPS`` is set).
        sc: SparkContext used to encode each batch.
        sess: TensorFlow session used to pull batches from the dataset.
        batchSize: number of samples per yielded batch.
        graph: TensorFlow graph in which the dataset ops are created.
        trainingShuffleSize: shuffle buffer size; ``None`` disables shuffling.

    Yields:
        ``([encoder_input_data, decoder_input_data], decoder_target_data)``
        where ``decoder_target_data`` is ``decoder_input_data`` shifted one
        step to the left and padded with an all-zero step.

    Note:
        Any exception raised while producing batches is logged and ends the
        generator instead of propagating (previously it was swallowed
        without a trace).
    """
    import logging

    def _as_text_column(values):
        # Decode a batch of byte strings to UTF-8 text, shaped as an (N, 1) column.
        text = np.char.decode(values.astype(np.bytes_), 'utf-8')
        return text.reshape(len(text), 1)

    try:
        with graph.as_default():
            train_ds = train_ds.apply(tf.data.experimental.unbatch())
            if trainingShuffleSize is not None:
                train_ds = train_ds.shuffle(trainingShuffleSize)
            train_ds = train_ds.batch(batchSize)

            mynext = train_ds.make_one_shot_iterator().get_next()

            while True:
                train_ds_batch = sess.run(mynext)

                columns = [
                    _as_text_column(train_ds_batch.raw_query),
                    _as_text_column(train_ds_batch.corrected_query),
                ]
                if USE_GPS:
                    # The s2cell id goes first, as encodeSampleFromList expects.
                    columns.insert(0, _as_text_column(train_ds_batch.s2cell_id))
                trainingSet = np.hstack(columns)

                # NOTE(review): every batch makes a Spark round trip; encoding
                # the rows locally may be cheaper for batches this small.
                trainingSetRDD = sc.parallelize(trainingSet)
                #trainingSetRDD = trainingSetRDD.filter(isValidTrainingSample)
                encodedTrainingSet = trainingSetRDD.map(lambda x: encodeSampleFromList(x, encodeVoc, NGRAM, USE_GPS))
                encodedTrainingSet = encodedTrainingSet.collect()
                encodedTrainingSet = np.array(encodedTrainingSet, dtype=object)

                X = encodedTrainingSet[:,0]
                Y = encodedTrainingSet[:,1]
                # Pad the n-gram vector to expected length
                # X shape:  #sample * QUERY_ENCODED_LEN
                X = keras.preprocessing.sequence.pad_sequences(X, padding='post', maxlen=QUERY_ENCODED_LEN)
                # Labels can hold up to LABEL_ENCODED_LEN n-grams but only
                # LABEL_ENCODED_LEN-2 fit next to the start/stop tokens. Cut
                # the excess from the END: pad_sequences' default
                # truncating='pre' would drop the beginning of the label
                # while the query keeps its beginning.
                Y = keras.preprocessing.sequence.pad_sequences(Y, padding='post', truncating='post', maxlen=LABEL_ENCODED_LEN-2)

                # Shift Y to the right by one position (for the starting token) and add stop token in the end
                # Y shape:  #sample * LABEL_ENCODED_LEN
                Y = np.insert(Y, 0, 0, axis=1)
                Y = np.hstack((Y, np.zeros((Y.shape[0], 1), dtype=np.uint16)))

                # One-hot encode the ngram vectors
                # encoder_input_data shape: #sample * QUERY_ENCODED_LEN * vocabulary_size
                # decoder_input_data shape: #sample * LABEL_ENCODED_LEN * vocabulary_size
                # decoder_target_data shape: #sample * LABEL_ENCODED_LEN * vocabulary_size
                encoder_input_data = to_categorical(X, num_classes=vocabulary_size, dtype='uint16')
                decoder_input_data = to_categorical(Y, num_classes=vocabulary_size, dtype='uint16')
                decoder_target_data = np.hstack((decoder_input_data[:,1:,:], np.zeros((decoder_input_data.shape[0],1,decoder_input_data.shape[2]), dtype=np.uint16)))

                yield [encoder_input_data, decoder_input_data], decoder_target_data
    except Exception:
        # Best effort: end the stream rather than crash the consumer, but
        # leave a trace so data or Spark failures are not silently lost.
        # GeneratorExit and KeyboardInterrupt are no longer intercepted.
        logging.getLogger(__name__).exception("make_generator stopped because of an error")


def train_seq2seq(batchSize):
    """Train the seq2seq spell-correction model on one Horovod worker.

    Creates a SparkContext, initialises Horovod, builds the model through
    ``createModel``, streams the training and validation Parquet files
    through petastorm and ``make_generator``, fits the model and writes the
    model, weights, JSON config and training history to the configured
    result files. Reads the module-level dataset lists, sample counts and
    file names.

    The SparkContext and the auxiliary TensorFlow session are released when
    the function exits, also when training fails.

    Args:
        batchSize: number of samples per batch; also used to derive the
            per-worker steps per epoch.
    """
    import tensorflow as tf
    from tensorflow import keras
    import horovod
    import horovod.tensorflow.keras as hvd    
    import pydoop.hdfs as hdfs
    from pyarrow.parquet import ParquetDataset    
    import numpy as np
    import os
    import json
    import pickle
    from petastorm import make_batch_reader
    from petastorm.tf_utils import make_petastorm_dataset
    from petastorm.predicates import in_lambda
        
    graph = tf.get_default_graph()

    sc = SparkContext()
    sess = None
    try:
        hvd.init()

        os.environ['TF_CPP_MIN_LOG_LEVEL']='3'

        # Each worker shuffles within a buffer the size of its own share.
        trainingShuffleSize = int(trainingSampleCnt / hvd.size())

        # Horovod: pin GPU to be used to process local rank (one GPU per process)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        keras.backend.set_session(tf.Session(config=config))

        model = createModel(vocabulary_size, hvd)

        # Separate session used only to pull batches from the datasets.
        # It is closed explicitly in `finally` rather than used as a context
        # manager, because entering it would install it as the default
        # TensorFlow session, which Keras could then pick up instead of the
        # GPU-pinned one.
        sess = tf.Session()

        callbacks_list = [
            hvd.callbacks.BroadcastGlobalVariablesCallback(0),
            hvd.callbacks.MetricAverageCallback(),
            hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, initial_lr=LEARNING_RATE * hvd.size()),
            keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=10),
            keras.callbacks.EarlyStopping(monitor='val_loss', mode='auto', patience=20),
            keras.callbacks.TerminateOnNaN()
        ]

        # Only rank 0 writes checkpoints.
        if hvd.rank() == 0:
            model_checkpoint = ModelCheckpoint(model_checkpoint_file, monitor='val_acc', save_best_only=True, save_weights_only=True, mode='auto')
            callbacks_list.append(model_checkpoint)

        training_start_tm = datetime.datetime.now()

        # Optional row filter, currently not passed to the readers. To enable
        # it, add `predicate=predicate_filter` to make_batch_reader().
        predicate_filter = in_lambda(["raw_query"], lambda x: x.str.len() >= 3)

        with make_batch_reader(trainingDataSet, num_epochs=None,
                               cur_shard=hvd.rank(), shard_count=hvd.size(), 
                               schema_fields=DATA_COLS,
                               hdfs_driver=PETASTORM_HDFS_DRIVER) as train_reader:
            with make_batch_reader(valDataSet, num_epochs=None,
                                   cur_shard=hvd.rank(), shard_count=hvd.size(),
                                   schema_fields=DATA_COLS,
                                   hdfs_driver=PETASTORM_HDFS_DRIVER) as val_reader:

                train_ds = make_petastorm_dataset(train_reader)
                train_ds_gen = make_generator(train_ds, sc, sess, batchSize, graph, trainingShuffleSize)

                # Validation data is not shuffled.
                val_ds = make_petastorm_dataset(val_reader)
                val_ds_gen = make_generator(val_ds, sc, sess, batchSize, graph)

                history = model.fit(train_ds_gen,
                              validation_data=val_ds_gen,
                              epochs = EPOCH_CNT, 
                              steps_per_epoch=int(trainingSampleCnt/batchSize/hvd.size()),
                              validation_steps=int(valSampleCnt/batchSize/hvd.size()),
                              callbacks = callbacks_list,
                              verbose=2 if hvd.rank() == 0 else 0)

                training_end_tm = datetime.datetime.now()

                print(f"\nTraining Finished!!!  Time: {training_end_tm - training_start_tm}")

                # NOTE(review): every rank writes these files to its own
                # sandbox; confirm whether saving on rank 0 alone would do.
                print("Saving Model file...")
                model.save(model_file)
                model.save_weights(model_weights_file)

                json_config = model.to_json()
                with open(model_config_json, 'w') as json_file:
                    json_file.write(json_config)

                with open(training_history_pickle, 'wb') as history_file:
                    pickle.dump(history.history, history_file)

                print("Model file saved")
    finally:
        # Release the auxiliary session and the Spark context even when
        # training fails; previously neither was ever closed.
        try:
            if sess is not None:
                sess.close()
        finally:
            sc.stop()

In [None]:
# Run train_seq2seq through the job context with the configured batch size;
# the env mapping sets the libhdfs JVM heap options and the TensorFlow C++
# log level.
mactx.run(
    train_seq2seq,
    args=(BATCH_SIZE,),
    env={
        'LIBHDFS_OPTS': '-Xms8192m -Xmx8192m',
        'TF_CPP_MIN_LOG_LEVEL': '3',
    },
)

In [None]:
# Shut down the job context created earlier in this notebook.
mactx.shutdown()