# Introduction

We tackle the problem of OCR post-processing. In OCR, we map the image of a document into the text domain. This is done first using a CNN+LSTM+CTC model, in our case based on Tesseract. Since this output maps only image to text, we need something on top to validate and correct language semantics.

The idea is to build a language model that takes the OCRed text and corrects it based on language knowledge. The language model could be:
- Char level: the aim is to capture word morphology, in which case it acts like a spelling correction system.
- Word level: the aim is to capture sentence semantics. But such systems suffer from the out-of-vocabulary (OOV) problem.
- Fusion: to capture both semantic and morphological language rules. The output has to be at char level to avoid the OOV problem. However, the input can be char, word or both.

The fusion model target is to learn:

    p(char | char_context, word_context)

In this notebook we use a vanilla Keras seq2seq implementation, adapted from the lstm_seq2seq example for the English–French translation task. The adaptation involves:

- Adapt to spelling correction, on char level
- Pre-train on noisy medical sentences
- Fine-tune a residual to correct the mistakes of Tesseract
- Limit the input and output sequence lengths
- Ensure a teacher-forcing, auto-regressive model in the decoder
- Limit the padding per batch
- Learning rate schedule
- Bi-directional LSTM encoder
- Bi-directional GRU encoder


# Imports

In [1]:
from __future__ import print_function
import tensorflow as tf
import keras.backend as K
from keras.backend.tensorflow_backend import set_session
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Bidirectional, Concatenate, GRU
from keras import optimizers
from keras.callbacks import ModelCheckpoint, TensorBoard, LearningRateScheduler
from keras.models import load_model
import numpy as np
import os
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from autocorrect import spell
import re
%matplotlib inline

Using TensorFlow backend.


# Utility functions

In [2]:
# Limit gpu allocation. allow_growth, or gpu_fraction
def gpu_alloc():
    '''Install a Keras TF session that grabs GPU memory on demand.

    Builds a ``tf.ConfigProto`` with ``gpu_options.allow_growth`` enabled and
    registers a session created from it as the Keras backend session.
    '''
    session_config = tf.ConfigProto()
    gpu_options = session_config.gpu_options
    gpu_options.allow_growth = True
    session = tf.Session(config=session_config)
    set_session(session)

In [3]:
# Switch the Keras/TF session to on-demand GPU memory growth before any model is loaded below.
gpu_alloc()

In [4]:
def calculate_WER_sent(gt, pred):
    '''Return the word-level edit distance between two sentences.

    Both sentences are lower-cased and split on single spaces, then the
    Levenshtein distance (substitution / insertion / deletion, each cost 1)
    between the two word lists is returned as a Python int.

    Example:
        calculate_WER_sent('calculating wer between two sentences',
                           'calculate wer between two sentences')  # -> 1
    '''
    gt_words = gt.lower().split(' ')
    pred_words = pred.lower().split(' ')

    # Only two DP rows are needed. Python ints are used so distances cannot wrap.
    # A uint8 matrix overflows once either sentence has more than 255 words.
    prev_row = list(range(len(pred_words) + 1))
    for i, gt_word in enumerate(gt_words, start=1):
        curr_row = [i] + [0] * len(pred_words)
        for j, pred_word in enumerate(pred_words, start=1):
            if gt_word == pred_word:
                curr_row[j] = prev_row[j - 1]
            else:
                substitution = prev_row[j - 1]
                insertion = curr_row[j - 1]
                deletion = prev_row[j]
                curr_row[j] = 1 + min(substitution, insertion, deletion)
        prev_row = curr_row
    return prev_row[-1]

In [5]:
def calculate_WER(gt, pred):
    '''Return the corpus-level word error rate of ``pred`` against ``gt``.

    WER = (sum of per-sentence word edit distances) / (total number of
    ground-truth words). Words are counted with the same tokenisation used by
    calculate_WER_sent (split on single spaces).

    :param gt: list of sentences of the ground truth
    :param pred: list of sentences of the predictions, aligned with ``gt``
    :return: accumulated WER (float)
    :raises ValueError: if the lists differ in length or contain no words
    '''
    if len(gt) != len(pred):
        raise ValueError('gt and pred must have the same length ({} != {})'.format(len(gt), len(pred)))

    total_errors = 0
    total_words = 0
    for gt_sentence, pred_sentence in zip(gt, pred):
        # int() guards against small numpy integer types wrapping while accumulating.
        total_errors += int(calculate_WER_sent(gt_sentence, pred_sentence))
        # Normalise by the number of reference WORDS, not characters.
        total_words += len(gt_sentence.split(' '))

    if total_words == 0:
        raise ValueError('cannot compute WER of an empty corpus')
    return total_errors / total_words

In [6]:
def load_data_with_gt(file_name, num_samples, max_sent_len, min_sent_len, delimiter='\t', gt_index=1, prediction_index=0):
    '''Load (OCR text, ground truth) pairs from a delimited UTF-8 text file.

    Each line holds fields separated by ``delimiter``. Field ``prediction_index``
    is the OCR text and field ``gt_index`` is the ground truth. The line
    terminator is stripped first, so it does not leak into whichever field is
    last. Lines without enough fields are skipped.

    The decoder target is the ground truth wrapped as '\\t' + gt + '\\n'
    (start and stop triggers). A pair is kept only if both the input text and
    the target have a length strictly between ``min_sent_len`` and
    ``max_sent_len``. Reading stops once ``num_samples`` pairs are collected.

    :return: (input_texts, target_texts, gt_texts) as three aligned lists
    '''
    input_texts = []
    gt_texts = []
    target_texts = []
    with open(file_name, encoding='utf8') as data_file:
        for row in data_file:
            if len(input_texts) >= num_samples:
                break
            sents = row.rstrip('\r\n').split(delimiter)
            try:
                input_text = sents[prediction_index]
                gt_text = sents[gt_index]
            except IndexError:
                # Blank or malformed line: not enough fields.
                continue

            target_text = '\t' + gt_text + '\n'
            if (min_sent_len < len(input_text) < max_sent_len
                    and min_sent_len < len(target_text) < max_sent_len):
                input_texts.append(input_text)
                target_texts.append(target_text)
                gt_texts.append(gt_text)
    return input_texts, target_texts, gt_texts

In [7]:
def load_data(file_name, num_samples, max_sent_len, min_sent_len, encoding=None):
    '''Load raw text lines (no ground truth) from a text file.

    Every line is kept exactly as read, including its trailing newline. A line
    is kept only if its length is strictly between ``min_sent_len`` and
    ``max_sent_len``. Reading stops once ``num_samples`` lines are collected.

    :param encoding: text encoding passed to ``open`` (None = platform default,
        which was the previous behaviour)
    :return: list of kept lines
    '''
    input_texts = []
    with open(file_name, encoding=encoding) as data_file:
        for row in data_file:
            if len(input_texts) >= num_samples:
                break
            if min_sent_len < len(row) < max_sent_len:
                input_texts.append(row)
    return input_texts

In [8]:
def vectorize_data(input_texts, max_encoder_seq_length, num_encoder_tokens, vocab_to_int):
    '''Encode texts as a zero-padded float32 matrix of vocabulary indices.

    Row i holds ``vocab_to_int[c]`` for the first ``max_encoder_seq_length``
    characters of ``input_texts[i]``. Positions past the end of the text stay 0.
    All texts are encoded. A character missing from ``vocab_to_int`` raises KeyError.

    :param input_texts: list of strings
    :param max_encoder_seq_length: number of columns (longer texts are truncated)
    :param num_encoder_tokens: unused; kept for call compatibility
    :param vocab_to_int: char -> vocabulary index
    :return: array of shape (len(input_texts), max_encoder_seq_length)
    '''
    encoder_input_data = np.zeros(
        (len(input_texts), max_encoder_seq_length),
        dtype='float32')

    for row, text in enumerate(input_texts):
        for col, char in enumerate(text[:max_encoder_seq_length]):
            encoder_input_data[row, col] = vocab_to_int[char]

    return encoder_input_data

In [9]:
def decode_sequence(input_seq, encoder_model, decoder_model, num_decoder_tokens, max_decoder_seq_length, vocab_to_int, int_to_vocab):
    '''Greedily decode one encoded sentence, one character per decoder step.

    The encoder runs once on ``input_seq``. The decoder is then fed its own
    previous argmax prediction, starting from the tab (start) character. It stops
    when it predicts the newline (stop) character or when the decoded string is
    longer than ``max_decoder_seq_length``. The stop step itself emits nothing.

    Digits and special characters are not left to the model. At step i, the input
    character at position i (cycling over the input width) is copied verbatim if
    it is a digit or special character. Otherwise the sampled character is
    emitted, unless it is itself a digit or special character, in which case
    nothing is emitted for that step.

    :param input_seq: vocabulary indices of ONE sentence, read as input_seq[0, t]
        (e.g. vectorize_data applied to a single text)
    :param encoder_model: model whose predict(input_seq) returns (encoder_outputs, h, c)
    :param decoder_model: model whose predict([target_seq, encoder_outputs, h, c])
        returns (output_tokens, attention, h, c)
    :param num_decoder_tokens: unused; kept for call compatibility
    :param max_decoder_seq_length: length limit on the decoded string
    :param vocab_to_int: char -> vocabulary index
    :param int_to_vocab: vocabulary index -> char
    :return: (decoded_sentence, attention_density), where attention_density is a
        numpy array stacking attention[0][0] of every decoder step (stop step included)
    '''
    special_chars = ['\\', '/', '-', '—' , ':', '[', ']', ',', '.', '"', ';', '%', '~', '(', ')', '{', '}', '$', '#']

    # Encode the input as state vectors.
    encoder_outputs, h, c = encoder_model.predict(input_seq)
    states_value = [h, c]
    # Target sequence of length 1, seeded with the start character.
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = vocab_to_int['\t']

    # Cycle the aligned input position over the actual input width. Do not
    # hard-code it: each model has its own width (max_sent_len - 1).
    input_width = input_seq.shape[1]

    decoded_sentence = ''
    attention_density = []
    i = 0
    while True:
        output_tokens, attention, h, c = decoder_model.predict(
            [target_seq, encoder_outputs] + states_value)
        # One decoder time step per call, so attention[0][0] is this step's attention row.
        attention_density.append(attention[0][0])

        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = int_to_vocab[sampled_token_index]

        # Exit on the stop character or when too long; emit nothing for this step.
        if sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length:
            break

        orig_char = int_to_vocab[int(input_seq[0, i])]
        if orig_char.isdigit() or orig_char in special_chars:
            # Copy digits/punctuation as-is: the corrector is not good at them.
            decoded_sentence += orig_char
        elif not (sampled_char.isdigit() or sampled_char in special_chars):
            decoded_sentence += sampled_char
        # else: a digit/special char sampled where the input has none is dropped.

        # Feed the prediction back in (auto-regressive decoding) and carry the state.
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = sampled_token_index
        states_value = [h, c]
        i = (i + 1) % input_width

    return decoded_sentence, np.array(attention_density)


In [10]:
def word_spell_correct(decoded_sentence):
    '''Apply the word-level ``spell`` corrector to each space-separated word.

    Words containing any digit, or equal to one of the special punctuation
    tokens, are passed through unchanged. Every output word is followed by a
    single space, so the result always ends with a space.
    '''
    protected_tokens = ('\\', '/', '-', '—' , ':', '[', ']', ',', '.', '"', ';', '%', '~', '(', ')', '{', '}', '$', '#')
    pieces = []
    for word in decoded_sentence.split(' '):
        keep_verbatim = re.search(r'\d', word) is not None or word in protected_tokens
        pieces.append(word if keep_verbatim else spell(word))
    return ''.join(piece + ' ' for piece in pieces)

# Load data

# Load model params

In [11]:
# Directory of the evaluation text files, relative to the notebook's working directory.
data_path = '../../dat/'

In [None]:
# Max sentence lengths (chars) of the trained models, in ascending order. Vocab/model
# files are named after these values. Each input is routed to the smallest value
# exceeding its length, falling back to the last (largest) one.
max_sent_lengths = [50, 100]

In [None]:
# Per-model lookup tables, each keyed by an entry of max_sent_lengths.
# The generator yields a fresh, independent dict for every name.
(vocab_file, model_file, encoder_model_file, decoder_model_file,
 model, encoder_model, decoder_model,
 vocab, vocab_to_int, int_to_vocab,
 max_sent_len, min_sent_len,
 num_decoder_tokens, num_encoder_tokens,
 max_encoder_seq_length, max_decoder_seq_length) = ({} for _ in range(16))

In [None]:

for i in max_sent_lengths:
    vocab_file[i] = 'vocab-{}.npz'.format(i)
    model_file[i] = 'best_model-{}.hdf5'.format(i)
    encoder_model_file[i] = 'encoder_model-{}.hdf5'.format(i)
    decoder_model_file[i] = 'decoder_model-{}.hdf5'.format(i)

    # The vocab dicts are stored as 0-d object arrays (hence .item()), which
    # NumPy only loads with allow_pickle=True. Only load vocab files you trust.
    # `with` closes the npz file handle. A separate local name is used so the
    # module-level `vocab` dict is not overwritten.
    with np.load(file=vocab_file[i], allow_pickle=True) as vocab_npz:
        vocab_to_int[i] = vocab_npz['vocab_to_int'].item()
        int_to_vocab[i] = vocab_npz['int_to_vocab'].item()
        max_sent_len[i] = int(vocab_npz['max_sent_len'])
        min_sent_len[i] = int(vocab_npz['min_sent_len'])

    # Token count of THIS model's vocabulary. The outer vocab_to_int dict is keyed
    # by model length, so its size is not the token count.
    num_decoder_tokens[i] = num_encoder_tokens[i] = len(vocab_to_int[i])
    # One less than max_sent_len. This presumably mirrors the loaders' strict
    # `len < max_sent_len` filter (longest kept sequence); confirm against training.
    max_encoder_seq_length[i] = max_decoder_seq_length[i] = max_sent_len[i] - 1

    model[i] = load_model(model_file[i])
    encoder_model[i] = load_model(encoder_model_file[i])
    decoder_model[i] = load_model(decoder_model_file[i])

In [None]:
# Upper bound on the number of evaluation pairs to load.
num_samples = 1000000
#tess_correction_data = os.path.join(data_path, 'test_data.txt')
#input_texts = load_data(tess_correction_data, num_samples, max_sent_len, min_sent_len)

# Evaluation set, parsed with load_data_with_gt's defaults: tab-delimited,
# OCR text in field 0, ground truth in field 1.
OCR_data = os.path.join(data_path, 'new_trained_data.txt')
#input_texts, target_texts, gt_texts = load_data_with_gt(OCR_data, num_samples, max_sent_len, min_sent_len, delimiter='|',gt_index=0, prediction_index=1)
# Limits of 10000 / 0 effectively disable length filtering here. Per-sentence
# model routing by length happens in the decoding loop.
input_texts, target_texts, gt_texts = load_data_with_gt(OCR_data, num_samples, max_sent_len=10000, min_sent_len=0)

In [None]:
# Sample data: number of loaded pairs, then (up to) the first 10 of them.
# Slicing avoids an IndexError when fewer than 10 pairs were loaded.
print(len(input_texts))
for sample_input, sample_target in zip(input_texts[:10], target_texts[:10]):
    print(sample_input, '\n', sample_target)

In [None]:
# Spell correct before inference
# Disabled: this is a bare string literal (a no-op). If enabled, it would replace
# each OCR input with its word-by-word `spell` correction and reprint the samples.
'''
input_texts_ = []
for sent in input_texts:
    sent_ = ''
    for word in sent.split(' '):
        sent_ += spell(word) + ' '
    input_texts_.append(sent_)
input_texts = input_texts_
input_texts_ = []
# Sample data
print(len(input_texts))
for i in range(10):
    print(input_texts[i], '\n', target_texts[i])
'''

In [None]:
decoded_sentences = []


def _md_cell(text):
    '''Return text made safe for one markdown table cell (pipes escaped, newlines flattened, ends trimmed).'''
    return text.replace('|', '\\|').replace('\n', ' ').strip()


# Decode every OCR line, print a side-by-side comparison and record it as a
# markdown table row. `with` guarantees RESULTS.md is closed even on errors.
with open('RESULTS.md', 'w', encoding='utf8') as results:
    results.write('|OCR sentence|GT sentence|Char decoded sentence|Word decoded sentence|Model length (chars)|\n')
    results.write('|---|---|---|---|---|\n')

    for input_text, target_text in zip(input_texts, gt_texts):
        # Route to the smallest model whose max length exceeds the input,
        # falling back to the longest model.
        len_range = next((length for length in max_sent_lengths if len(input_text) < length),
                         max_sent_lengths[-1])

        input_seq = vectorize_data(input_texts=[input_text], max_encoder_seq_length=max_encoder_seq_length[len_range], num_encoder_tokens=num_encoder_tokens[len_range], vocab_to_int=vocab_to_int[len_range])
        decoded_sentence, _ = decode_sequence(input_seq, encoder_model[len_range], decoder_model[len_range], num_decoder_tokens[len_range], max_decoder_seq_length[len_range], vocab_to_int[len_range], int_to_vocab[len_range])
        corrected_sentence = word_spell_correct(decoded_sentence)

        print('-Length = ', len_range)
        print('Input sentence:', input_text)
        print('GT sentence:', target_text.strip())
        print('Char Decoded sentence:', decoded_sentence)
        print('Word Decoded sentence:', corrected_sentence)

        cells = (input_text, target_text, decoded_sentence, corrected_sentence, str(len_range))
        results.write('|' + '|'.join(_md_cell(cell) for cell in cells) + '|\n')
        decoded_sentences.append(decoded_sentence)

    

In [None]:
# Corpus WER of the char-level decoded sentences against the ground truth.
WER_spell_correction = calculate_WER(gt=gt_texts, pred=decoded_sentences)
print('WER_spell_correction |TEST= ', WER_spell_correction)

In [None]:
# Baseline: corpus WER of the raw OCR text against the ground truth.
WER_OCR = calculate_WER(gt=gt_texts, pred=input_texts)
print('WER_OCR |TEST= ', WER_OCR)