In [None]:
import math

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv1D, Conv2D, MaxPooling2D, Activation, Reshape, Bidirectional, LSTM, Dense, Lambda, Layer, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from spellchecker import SpellChecker
%run data_loader.ipynb
%matplotlib inline
# Seed both NumPy and TensorFlow: np.random.seed alone does not make tf.keras weight
# initialisation, dropout or dataset shuffling reproducible across runs.
np.random.seed(1)
tf.random.set_seed(1)

In [None]:
# Load the word-level split. Y_train / Y_test hold the plain-text words (used for the final
# string comparison); the *_encoded arrays are the label sequences fed to the CTC layer.
(X_train, X_test,
 Y_train, Y_test,
 Y_train_encoded, Y_test_encoded) = data_loader('words', 'words.txt')

# Scale pixels to [0, 1] (assumes 8-bit grey images -- TODO confirm in data_loader) and add
# the trailing channel axis expected by the model's (32, 128, 1) image input.
X_train = np.reshape(X_train / 255, (-1, 32, 128, 1))
X_test = np.reshape(X_test / 255, (-1, 32, 128, 1))

In [None]:
class CTCLayer(Layer):
    """Identity layer that computes the CTC loss and registers it with ``add_loss``.

    The encoded labels arrive as a second input tensor so the loss can be computed inside the
    graph; the predictions are passed through unchanged.

    Args:
        name: Layer name.
        padding_value: Optional label value used to right-pad ``y_true``. When ``None``
            (default, original behaviour) every label is treated as spanning the full padded
            width of ``y_true``. When set, each sample's label length is the number of entries
            different from this value, so padding is not scored as characters. Padding must
            sit at the end of each row.
        **kwargs: Forwarded to ``Layer`` (e.g. ``dtype``, ``trainable``) so Keras can rebuild
            the layer from ``get_config()`` when a saved model is reloaded.
    """

    def __init__(self, name=None, padding_value=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.padding_value = padding_value
        self.loss_fn = tf.keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        """Add the CTC loss for this batch and return ``y_pred`` unchanged.

        Args:
            y_true: Encoded label sequences, (batch, max_label_len); assumed right-padded --
                TODO confirm the padding scheme used by data_loader.
            y_pred: Per-time-step class probabilities, (batch, time_steps, num_classes).

        Returns:
            ``y_pred``, untouched.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")

        # Every sample uses the full time axis of the predictions.
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        if self.padding_value is None:
            # Original behaviour: the whole padded label row counts as the target.
            label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
            label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        else:
            # Count only real (non-padding) characters per sample; shape (batch, 1).
            is_label = tf.not_equal(y_true, tf.cast(self.padding_value, y_true.dtype))
            label_length = tf.reduce_sum(tf.cast(is_label, "int64"), axis=1, keepdims=True)

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        return y_pred

    def get_config(self):
        """Return the layer config, including ``padding_value``, for model serialisation."""
        config = super().get_config()
        config.update({"padding_value": self.padding_value})
        return config

In [None]:
def build_model(num_classes=80, dropout_rate=0.2, learning_rate=0.001):
    """Build and compile the CNN + BiLSTM handwriting model trained with a CTC loss.

    The model takes two inputs, the image ("image") and its encoded label sequence ("label").
    The label is needed inside the graph by ``CTCLayer``, which adds the CTC loss. For
    inference, use the sub-model from the "image" input to the "dense1" output.

    Args:
        num_classes: Softmax width per time step. Defaults to the original hard-coded 80;
            presumably len(char_list) + 1 for the CTC blank -- TODO confirm against data_loader.
        dropout_rate: Dropout applied to the CNN feature sequence before the BiLSTMs.
        learning_rate: Adam learning rate (0.001 matches ``Adam()``'s default).

    Returns:
        A compiled ``Model`` named "htr_model_1". Its output is the per-time-step softmax,
        shape (batch, 32, num_classes). The loss is attached by ``CTCLayer`` rather than
        passed to ``compile``.
    """
    input_img = Input(shape=(32, 128, 1), name="image", dtype="float32")
    labels = Input(name="label", shape=(None,), dtype="float32")

    # Convolutional feature extractor. pool1/pool2 halve height and width. pool3-pool5 pool
    # only the height, so the width (which becomes the time axis) stays at 32.
    conv1 = Conv2D(16, (5, 5), activation="relu", kernel_initializer="he_normal", padding="same", name="Conv1")(input_img)
    pool1 = MaxPooling2D((2, 2), name="pool1")(conv1)  # (16, 64, 16)

    conv2 = Conv2D(32, (5, 5), activation="relu", kernel_initializer="he_normal", padding="same", name="Conv2")(pool1)
    pool2 = MaxPooling2D((2, 2), name="pool2")(conv2)  # (8, 32, 32)

    conv3 = Conv2D(64, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same", name="Conv3")(pool2)
    pool3 = MaxPooling2D((2, 1), name="pool3")(conv3)  # (4, 32, 64)

    conv4 = Conv2D(128, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same", name="Conv4")(pool3)
    pool4 = MaxPooling2D((2, 1), name="pool4")(conv4)  # (2, 32, 128)

    conv5 = Conv2D(256, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same", name="Conv5")(pool4)
    pool5 = MaxPooling2D((2, 1), name="pool5")(conv5)  # (1, 32, 256)

    # Drop the unit height axis: 32 time steps (one per width column) x 256 features.
    reshape = Reshape(target_shape=(32, 256), name="reshape")(pool5)
    # Fix: this dropout was previously created but never connected (the BiLSTM read `reshape`
    # directly), so it was not part of the model. It now feeds the recurrent stack.
    dropout_layer = Dropout(dropout_rate)(reshape)

    blstm1 = Bidirectional(LSTM(256, return_sequences=True, dropout=0.25))(dropout_layer)
    blstm2 = Bidirectional(LSTM(256, return_sequences=True, dropout=0.25))(blstm1)

    dense = Dense(num_classes, activation="softmax", name="dense1")(blstm2)

    # Identity layer that registers the CTC loss; its output is `dense` unchanged.
    output = CTCLayer(name="ctc_loss")(labels, dense)

    model = Model(inputs=[input_img, labels], outputs=output, name="htr_model_1")

    # No `loss=` (CTCLayer adds it) and no metric. Training passes the labels as an input
    # rather than as `y`, so a compiled Accuracy metric would receive no targets. Exact-match
    # accuracy between label indices and softmax probabilities would be meaningless anyway.
    model.compile(optimizer=Adam(learning_rate=learning_rate))
    return model


In [None]:
# Instantiate the two-input training model and print its layer-by-layer summary.
model = build_model()
model.summary()

In [None]:
# The encoded labels go in as the model's second *input* (CTCLayer consumes them to compute
# the loss), which is why no separate `y` target is passed.
history = model.fit(
    x=[X_train, Y_train_encoded],
    epochs=2,
    verbose=1,
)

In [None]:
# Inference sub-model: image input -> per-time-step softmax, without the label input or CTC layer.
prediction_model = Model(model.get_layer(name="image").input, model.get_layer(name="dense1").output)
prediction_model.summary()

# Fix: the inference model has exactly one input, so pass only the images. The previous
# [X_test, Y_test_encoded] list supplied two inputs to a one-input model.
prediction = prediction_model.predict(X_test)

# Greedy CTC decoding with the full time axis as every sample's sequence length. Decoded rows
# are padded with -1, and those entries are skipped when rebuilding the words below.
input_lengths = np.full(prediction.shape[0], prediction.shape[1])
decoded = tf.keras.backend.ctc_decode(prediction, input_length=input_lengths, greedy=True)[0][0]
out = tf.keras.backend.get_value(decoded)[:, :32]
print(np.shape(out))

spell = SpellChecker()
correct = 0
for act_word, indices in zip(Y_test, out):
    raw_word = ''.join(char_list[int(p)] for p in indices if int(p) != -1)
    # SpellChecker.correction returns None when it finds no candidate; keep the raw decode
    # instead of crashing on the string concatenation below.
    # NOTE(review): the spell checker appears to return lower-case words, so capitalised
    # ground-truth words may never match -- confirm whether comparison should ignore case.
    pred_word = spell.correction(raw_word) or raw_word
    equality = pred_word == act_word
    if equality:
        correct += 1
    print('||' + act_word + '||' + pred_word + '||' + str(equality))
    print('\n')

n_samples = len(out)
print('Total correct matches : {}'.format(correct))
# Guard against an empty test set instead of raising ZeroDivisionError.
print('Accuracy : {} %'.format(correct / n_samples * 100 if n_samples else 0.0))