# Introduction

We tackle the problem of OCR post-processing. OCR maps the image form of a document into the text domain. This is first done using a CNN+LSTM+CTC model, in our case based on Tesseract. Since this model only maps image to text, we need something on top to validate and correct the output against language semantics.

The idea is to build a language model that takes the OCRed text and corrects it based on language knowledge. The language model could be:
- Char level: the aim is to capture word morphology, in which case it behaves like a spelling correction system.
- Word level: the aim is to capture sentence semantics, but such systems suffer from the out-of-vocabulary (OOV) problem.
- Fusion: the aim is to capture both semantic and morphological language rules. The output has to be at char level to avoid the OOV problem. However, the input can be char, word or both.

The fusion model aims to learn:

    p(char | char_context, word_context)

In this workbook we use the vanilla Keras seq2seq implementation, adapted from the lstm_seq2seq example for the Eng-Fra translation task. The adaptation involves:

- Adapt to spelling correction at char level
- Pre-train on noisy medical sentences
- Fine-tune a residual to correct the mistakes of Tesseract
- Limit the input and output sequence lengths
- Ensure a teacher-forcing auto-regressive model in the decoder
- Limit the padding per batch
- Learning rate schedule
- Bi-directional LSTM encoder
- Bi-directional GRU encoder
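
To make the teacher-forcing item concrete, here is a minimal sketch of how a decoder input and its one-step-shifted target could be built for one sentence. The toy vocabulary, the helper name `make_decoder_pair` and the use of index 0 for padding are assumptions for illustration, not the workbook's training code.

```python
import numpy as np

# Toy vocabulary (an assumption for illustration); index 0 is reserved for padding.
chars = ['\t', '\n', ' ', 'a', 'c', 't']
vocab_to_int = {c: i + 1 for i, c in enumerate(chars)}


def make_decoder_pair(gt_text, max_len):
    """Return (decoder_input, decoder_target) index arrays for one sentence."""
    # '\t' is the start trigger and '\n' the stop trigger, as in the loaders below.
    target_text = '\t' + gt_text + '\n'
    decoder_input = np.zeros(max_len, dtype='int32')
    decoder_target = np.zeros(max_len, dtype='int32')
    for t, ch in enumerate(target_text[:max_len]):
        decoder_input[t] = vocab_to_int[ch]
        if t > 0:
            # The target is the input shifted one step to the left, so at
            # step t the decoder sees the true char t and must predict t + 1.
            decoder_target[t - 1] = vocab_to_int[ch]
    return decoder_input, decoder_target


dec_in, dec_out = make_decoder_pair('a cat', max_len=8)
print(dec_in)
print(dec_out)
```

During training the decoder is fed the ground-truth previous character; at inference time it is fed its own previous prediction, which is what `decode_sequence` does further down.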


# Imports

In [1]:
from __future__ import print_function
import tensorflow as tf
import keras.backend as K
from keras.backend.tensorflow_backend import set_session
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Bidirectional, Concatenate, GRU
from keras import optimizers
from keras.callbacks import ModelCheckpoint, TensorBoard, LearningRateScheduler
from keras.models import load_model
import numpy as np
import os
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from autocorrect import spell
import re
%matplotlib inline

Using TensorFlow backend.


# Utility functions

In [2]:
# Limit gpu allocation. allow_growth, or gpu_fraction
def gpu_alloc():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    set_session(tf.Session(config=config))

In [3]:
gpu_alloc()

In [4]:
def calculate_WER_sent(gt, pred):
    '''
    Word-level edit distance between a ground-truth and a predicted sentence, e.g.
    calculate_WER_sent('calculating wer between two sentences', 'calculate wer between two sentences')
    '''
    gt_words = gt.lower().split(' ')
    pred_words = pred.lower().split(' ')
    d = np.zeros(((len(gt_words) + 1), (len(pred_words) + 1)), dtype=np.int32)  # uint8 would overflow past 255 words
    # d = d.reshape((len(gt_words)+1, len(pred_words)+1))

    # Initializing error matrix
    for i in range(len(gt_words) + 1):
        for j in range(len(pred_words) + 1):
            if i == 0:
                d[0][j] = j
            elif j == 0:
                d[i][0] = i

    # computation
    for i in range(1, len(gt_words) + 1):
        for j in range(1, len(pred_words) + 1):
            if gt_words[i - 1] == pred_words[j - 1]:
                d[i][j] = d[i - 1][j - 1]
            else:
                substitution = d[i - 1][j - 1] + 1
                insertion = d[i][j - 1] + 1
                deletion = d[i - 1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)
    return d[len(gt_words)][len(pred_words)]
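
As a cross-check on the dynamic program above, the same word-level edit distance can be sketched with two rolling rows instead of a full matrix. The function name `word_edit_distance` is hypothetical; the recurrence is the standard Levenshtein one applied to words.

```python
def word_edit_distance(gt, pred):
    """Word-level Levenshtein distance between two sentences."""
    gt_words = gt.lower().split(' ')
    pred_words = pred.lower().split(' ')
    # prev[j] is the distance between the first i-1 gt words and the first j pred words.
    prev = list(range(len(pred_words) + 1))
    for i, gw in enumerate(gt_words, start=1):
        curr = [i]
        for j, pw in enumerate(pred_words, start=1):
            cost = 0 if gw == pw else 1
            curr.append(min(prev[j - 1] + cost,  # match or substitution
                            curr[j - 1] + 1,     # insertion
                            prev[j] + 1))        # deletion
        prev = curr
    return prev[-1]


print(word_edit_distance('calculating wer between two sentences',
                         'calculate wer between two sentences'))
print(word_edit_distance('Medical Provider Roles: Treating',
                         'Me dieal Provider Roles: Treating'))
```

The first pair differs by one substituted word; the second needs one substitution plus one insertion because the OCR split "Medical" into two tokens.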

In [5]:
def calculate_WER(gt, pred):
    '''

    :param gt: list of sentences of the ground truth
    :param pred: list of sentences of the predictions
    both lists must have the same length
    :return: accumulated WER
    '''
#    assert len(gt) == len(pred)
    WER = 0
    nb_w = 0
    for i in range(len(gt)):
        #print(gt[i])
        #print(pred[i])
        WER += calculate_WER_sent(gt[i], pred[i])
        nb_w += len(gt[i].split(' '))  # count reference words, not characters

    return WER / nb_w
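
Word error rate is conventionally the total number of word edits divided by the total number of reference words. The small sketch below uses two sample lines from the data shown later and edit counts worked out by hand (an assumption stated in the comments) to show how much the choice of denominator matters.

```python
gt_sentences = ['Postal Code: 53188', 'City: Waukesha']

# Word-level edit distances against the OCR lines 'Postal Code: 5 31 88' and
# 'City. W’aukesha', worked out by hand for this illustration:
# 1 substitution + 2 insertions, and 2 substitutions.
edit_distances = [3, 2]

total_words = sum(len(s.split(' ')) for s in gt_sentences)
total_chars = sum(len(s) for s in gt_sentences)

wer = sum(edit_distances) / total_words
per_char_ratio = sum(edit_distances) / total_chars

print('edits per reference word:', wer)
print('edits per reference char:', per_char_ratio)
```

Dividing by the character count gives a much smaller number that is not comparable with word error rates reported elsewhere.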

In [6]:
def load_data_with_gt(file_name, num_samples, max_sent_len, min_sent_len, delimiter='\t', gt_index=1, prediction_index=0):
    '''Load data from a txt file where each line has: <TXT><TAB><GT>. The target to the decoder must have \t as the start trigger and \n as the stop trigger.'''
    cnt = 0  
    input_texts = []
    gt_texts = []
    target_texts = []
    for row in open(file_name, encoding='utf8'):
        if cnt < num_samples :
            #print(row)
            sents = row.split(delimiter)
            input_text = sents[prediction_index]
            
            target_text = '\t' + sents[gt_index] + '\n'
            if len(input_text) > min_sent_len and len(input_text) < max_sent_len and len(target_text) > min_sent_len and len(target_text) < max_sent_len:
                cnt += 1
                
                input_texts.append(input_text)
                target_texts.append(target_text)
                gt_texts.append(sents[gt_index])
    return input_texts, target_texts, gt_texts
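
The per-line handling can be sketched in isolation as follows. The helper `parse_line` is hypothetical, and unlike the loader above it strips the line ending before wrapping the ground truth in the start and stop triggers; that stripping is an assumption of this sketch, not the workbook's behaviour.

```python
def parse_line(row, delimiter='\t', gt_index=1, prediction_index=0):
    """Split one '<TXT><TAB><GT>' line into input, decoder target and ground truth."""
    fields = row.rstrip('\n').split(delimiter)
    input_text = fields[prediction_index]
    gt_text = fields[gt_index]
    # '\t' starts the decoder sequence and '\n' ends it.
    target_text = '\t' + gt_text + '\n'
    return input_text, target_text, gt_text


parsed = parse_line('Me dieal Provider Roles: Treating\tMedical Provider Roles: Treating\n')
print(parsed)
```

Without the `rstrip`, the ground truth keeps the trailing newline of the file line, which is why the sample output further down shows extra blank lines after each target.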

In [7]:
def load_data(file_name, num_samples, max_sent_len, min_sent_len):
    '''Load input sentences from a txt file, one sentence per line (no ground truth).'''
    cnt = 0  
    input_texts = []   
    
    #for row in open(file_name, encoding='utf8'):
    for row in open(file_name):
        if cnt < num_samples :            
            input_text = row           
            if len(input_text) > min_sent_len and len(input_text) < max_sent_len:
                cnt += 1                
                input_texts.append(input_text)
    return input_texts

In [8]:
def vectorize_data(input_texts, max_encoder_seq_length, num_encoder_tokens, vocab_to_int):
    
    '''Prepares the input texts into the proper seq2seq numpy array.'''
    # Each text is truncated to max_encoder_seq_length in the loop below;
    # the list of texts itself must not be truncated.
    encoder_input_data = np.zeros(
        (len(input_texts), max_encoder_seq_length),
        dtype='float32')
    
    for i, input_text in enumerate(input_texts):
        for t, char in enumerate(input_text[:max_encoder_seq_length]):
            # c0..cn
            encoder_input_data[i, t] = vocab_to_int[char]
                
    return encoder_input_data
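
A toy run makes the resulting array layout easier to see. The vocabulary below is hypothetical, and the standalone `vectorize` helper only mirrors the per-character loop above.

```python
import numpy as np


def vectorize(texts, max_len, vocab_to_int):
    """Map each text to a row of character indices, zero-padded to max_len."""
    data = np.zeros((len(texts), max_len), dtype='float32')
    for i, text in enumerate(texts):
        # Texts longer than max_len are truncated; shorter ones keep trailing zeros.
        for t, ch in enumerate(text[:max_len]):
            data[i, t] = vocab_to_int[ch]
    return data


toy_vocab = {'a': 1, 'b': 2, 'c': 3}  # hypothetical; 0 is left for padding
batch = vectorize(['abc', 'cabba'], max_len=4, vocab_to_int=toy_vocab)
print(batch)
```

The first row is padded with one zero and the second row loses its fifth character.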

In [9]:
def decode_sequence(input_seq, encoder_model, decoder_model, num_decoder_tokens, max_decoder_seq_length, vocab_to_int, int_to_vocab):
    
    #print(max_decoder_seq_length)
    # Encode the input as state vectors.
    encoder_outputs, h, c  = encoder_model.predict(input_seq)
    states_value = [h,c]
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0] = vocab_to_int['\t']

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    #print(input_seq)
    attention_density = []
    i = 0
    special_chars = ['\\', '/', '-', '—' , ':', '[', ']', ',', '.', '"', ';', '%', '~', '(', ')', '{', '}', '$', '#']
    #special_chars = []
    while not stop_condition:
        #print(target_seq)
        output_tokens, attention, h, c  = decoder_model.predict(
            [target_seq, encoder_outputs] + states_value)
        #print(attention.shape)
        attention_density.append(attention[0][0])# attention is max_sent_len x 1 since we have num_time_steps = 1 for the output
        # Sample a token
        #print(output_tokens.shape)
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        
        #print(sampled_token_index)
        sampled_char = int_to_vocab[sampled_token_index]
        
        orig_char = int_to_vocab[int(input_seq[:,i][0])]
        
        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True
            #print('End', sampled_char, 'Len ', len(decoded_sentence), 'Max len ', max_decoder_seq_length)
            sampled_char = ''
        
        # Copy digits and special characters as is, since the spelling corrector is not good at correcting them
        
        if(orig_char.isdigit() or orig_char in special_chars):
            decoded_sentence += orig_char            
        else:
            if(sampled_char.isdigit() or sampled_char in special_chars):
                decoded_sentence += ''
            else:
                decoded_sentence += sampled_char
        
        #decoded_sentence += sampled_char


        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = sampled_token_index

        # Update states
        states_value = [h, c]
        
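        i += 1
        # Wrap around at the end of the encoder input (was hard-coded to 48 for the 50-char model)
        if(i >= input_seq.shape[1]):
            i = 0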
    attention_density = np.array(attention_density)
    
    # Word level spell correct
    '''
    corrected_decoded_sentence = ''
    for w in decoded_sentence.split(' '):
        corrected_decoded_sentence += spell(w) + ' '
    decoded_sentence = corrected_decoded_sentence
    '''
    return decoded_sentence, attention_density
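
The copy rule inside the decoding loop can be isolated as a small pure function. `merge_char` is a hypothetical name; the special-character list is the one used above.

```python
SPECIAL_CHARS = ['\\', '/', '-', '—', ':', '[', ']', ',', '.', '"', ';',
                 '%', '~', '(', ')', '{', '}', '$', '#']


def merge_char(orig_char, sampled_char, special_chars=SPECIAL_CHARS):
    """Decide which character to emit at one decoding step."""
    if orig_char.isdigit() or orig_char in special_chars:
        # Trust the OCR for digits and punctuation: copy the input as is.
        return orig_char
    if sampled_char.isdigit() or sampled_char in special_chars:
        # The model proposed a digit or punctuation where the input had none: drop it.
        return ''
    return sampled_char


print(repr(merge_char('7', 'A')))
print(repr(merge_char('a', '7')))
print(repr(merge_char('e', 'c')))
```

Note that this rule assumes output position `i` lines up with input position `i`. Once the decoder inserts or deletes a character the two drift apart, which may explain outputs such as `Address Line 1:7A25ent Admentine Ave` in the results below.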


In [10]:
def word_spell_correct(decoded_sentence):
    corrected_decoded_sentence = ''
    special_chars = ['\\', '/', '-', '—' , ':', '[', ']', ',', '.', '"', ';', '%', '~', '(', ')', '{', '}', '$', '#']
    for w in decoded_sentence.split(' '):
        if((len(re.findall(r'\d+', w))==0) and not (w in special_chars)):
            corrected_decoded_sentence += spell(w) + ' '
        else:
            corrected_decoded_sentence += w + ' '
    return corrected_decoded_sentence
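
The same word-by-word pattern can be tried without the `autocorrect` dependency by plugging in any word-level corrector. In the sketch below the tiny lexicon and the names `toy_spell` and `correct_words` are hypothetical stand-ins, not part of the workbook.

```python
import re

# Hypothetical lookup table standing in for a real spelling corrector.
LEXICON = {'dorsl': 'dorsal', 'lacration': 'laceration'}


def toy_spell(word):
    """Return the corrected word if it is in the lexicon, else the word unchanged."""
    return LEXICON.get(word, word)


def correct_words(sentence, spell_fn=toy_spell):
    """Apply spell_fn to every token that contains no digit."""
    corrected = []
    for w in sentence.split(' '):
        if re.search(r'\d', w):
            corrected.append(w)  # leave numbers, dates and measurements alone
        else:
            corrected.append(spell_fn(w))
    return ' '.join(corrected)


print(correct_words('the dorsl surface 99/64'))
```

Unlike `word_spell_correct`, this sketch joins tokens without a trailing space and does not special-case standalone punctuation.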

In [11]:
def clean_up_sentence(sentence, vocab):
    s = ''
    prev_char = ''
    for c in sentence.strip():
        if c not in vocab or (c == ' ' and prev_char == ' '):
            s += ''
        else:
            s += c
        prev_char = c
            
    return s
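
A short usage sketch of the clean-up step follows. The toy vocabulary is an assumption, and `clean_up` simply restates the loop above so that the example runs on its own.

```python
def clean_up(sentence, vocab):
    """Drop out-of-vocabulary characters and collapse repeated spaces."""
    kept = []
    prev_char = ''
    for c in sentence.strip():
        if c in vocab and not (c == ' ' and prev_char == ' '):
            kept.append(c)
        # prev_char tracks the raw input, including characters that were dropped.
        prev_char = c
    return ''.join(kept)


toy_vocab = set('abcdefghijklmnopqrstuvwxyzMP ')  # hypothetical vocabulary
cleaned = clean_up('  Me  dieal§ Provider ', toy_vocab)
print(repr(cleaned))
```

Because `prev_char` follows the raw input, a space that comes right after a dropped character is kept even if another space preceded that character.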

# Load data

# Load model params

In [12]:
data_path = '../../dat/'

In [13]:
max_sent_lengths = [50, 100]

In [14]:
vocab_file = {}
model_file = {}
encoder_model_file = {}
decoder_model_file = {}
model = {}
encoder_model = {}
decoder_model = {}
vocab = {}
vocab_to_int = {}
int_to_vocab = {}
max_sent_len = {}
min_sent_len = {}
num_decoder_tokens = {}
num_encoder_tokens = {}
max_encoder_seq_length = {}
max_decoder_seq_length = {}

In [15]:

for i in max_sent_lengths:
    vocab_file[i] = 'vocab-{}.npz'.format(i)
    model_file[i] = 'best_model-{}.hdf5'.format(i)
    encoder_model_file[i] = 'encoder_model-{}.hdf5'.format(i)
    decoder_model_file[i] = 'decoder_model-{}.hdf5'.format(i)
    
    vocab = np.load(file=vocab_file[i])
    vocab_to_int[i] = vocab['vocab_to_int'].item()
    int_to_vocab[i] = vocab['int_to_vocab'].item()
    max_sent_len[i] = vocab['max_sent_len']
    min_sent_len[i] = vocab['min_sent_len']
    input_characters = sorted(list(vocab_to_int[i]))
    num_decoder_tokens[i] = num_encoder_tokens[i] = len(input_characters) #int(encoder_model.layers[0].input.shape[2])
    max_encoder_seq_length[i] = max_decoder_seq_length[i] = max_sent_len[i] - 1#max([len(txt) for txt in input_texts])
    
    model[i] = load_model(model_file[i])
    encoder_model[i] = load_model(encoder_model_file[i])
    decoder_model[i] = load_model(decoder_model_file[i])
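
The loop above reads the vocabulary dictionaries back from an `.npz` archive with `.item()`. How the archive was written is not shown in this workbook; the sketch below assumes it was produced with `np.savez` and Python dicts, and the toy values (including `min_sent_len=4`) are made up. Recent NumPy versions need `allow_pickle=True` to load such object arrays.

```python
import os
import tempfile

import numpy as np

# Toy vocabulary, an assumption for illustration only.
vocab_to_int = {'\t': 1, '\n': 2, 'a': 3}
int_to_vocab = {i: c for c, i in vocab_to_int.items()}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'vocab-50.npz')
    # Dicts are stored as 0-d object arrays, which is why .item() is needed on load.
    np.savez(path, vocab_to_int=vocab_to_int, int_to_vocab=int_to_vocab,
             max_sent_len=50, min_sent_len=4)
    with np.load(path, allow_pickle=True) as archive:
        loaded_vocab_to_int = archive['vocab_to_int'].item()
        loaded_int_to_vocab = archive['int_to_vocab'].item()
        loaded_max_len = int(archive['max_sent_len'])

print(loaded_vocab_to_int, loaded_max_len)
```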



In [16]:
num_samples = 1000000
#tess_correction_data = os.path.join(data_path, 'test_data.txt')
#input_texts = load_data(tess_correction_data, num_samples, max_sent_len, min_sent_len)

OCR_data = os.path.join(data_path, 'new_trained_data.txt')
#input_texts, target_texts, gt_texts = load_data_with_gt(OCR_data, num_samples, max_sent_len, min_sent_len, delimiter='|',gt_index=0, prediction_index=1)
input_texts, target_texts, gt_texts = load_data_with_gt(OCR_data, num_samples, max_sent_len=10000, min_sent_len=0)

In [17]:
# Sample data
print(len(input_texts))
for i in range(10):
    print(input_texts[i], '\n', target_texts[i])

1951
Me dieal Provider Roles: Treating  
 	Medical Provider Roles: Treating


Provider First Name: Christine  
 	Provider First Name: Christine


Provider Last Name: Nolen, MD  
 	Provider Last Name: Nolen, MD


Address Line 1 : 7 25 American Avenue  
 	Address Line 1 : 725 American Avenue


City. W’aukesha  
 	City: Waukesha


StatefProvinee: ‘WI  
 	State/Province: WI


Postal Code: 5 31 88  
 	Postal Code: 53188


Country". US  
 	Country:  US


Business Telephone: (2 62) 92 8- 1000  
 	Business Telephone: (262) 928- 1000


Date ot‘Pirst Visit: 1 2/01f20 17  
 	Date of First Visit: 12/01/2017




In [18]:
# Spell correct before inference
'''
input_texts_ = []
for sent in input_texts:
    sent_ = ''
    for word in sent.split(' '):
        sent_ += spell(word) + ' '
    input_texts_.append(sent_)
input_texts = input_texts_
input_texts_ = []
# Sample data
print(len(input_texts))
for i in range(10):
    print(input_texts[i], '\n', target_texts[i])
'''

In [19]:
decoded_sentences = []
corrected_sentences = []

#for seq_index in range(len(input_texts)):
results = open('RESULTS.md', 'w')
results.write('|OCR sentence|GT sentence|Char decoded sentence|Word decoded sentence|Sentence length (chars)|\n')
results.write('|---------------|-----------|----------------|----------------|----------------|\n')
     

for i, input_text in enumerate(input_texts):
    #print(input_text)
    # Find the input length range to choose the proper model to use
    len_range = max_sent_lengths[-1] # Take the longest range
    for length in max_sent_lengths:
        if(len(input_text) < length):
            len_range = length
            break
    #print(len_range)
    
    input_text = clean_up_sentence(input_text, vocab_to_int[len_range])
    encoder_input_data = vectorize_data(input_texts=[input_text], max_encoder_seq_length=max_encoder_seq_length[len_range], num_encoder_tokens=num_encoder_tokens[len_range], vocab_to_int=vocab_to_int[len_range])
    
    

    target_text = gt_texts[i]
    
    input_seq = encoder_input_data
    #print(input_seq.shape)
    #print(max_decoder_seq_length[len_range])
    #print(max_decoder_seq_length)
    decoded_sentence,_  = decode_sequence(input_seq, encoder_model[len_range], decoder_model[len_range], num_decoder_tokens[len_range],  max_decoder_seq_length[len_range], vocab_to_int[len_range], int_to_vocab[len_range])
    corrected_sentence = word_spell_correct(input_text)
    print('-Lenght = ', len_range)
    print('Input sentence:', input_text)
    print('GT sentence:', target_text.strip())
    print('Char Decoded sentence:', decoded_sentence)   
    print('Word Decoded sentence:', corrected_sentence) 
    results.write(' | ' + input_text + ' | ' + target_text.strip() + ' | ' + decoded_sentence + ' | ' + corrected_sentence + ' | ' + str(len_range) + ' | \n')
    decoded_sentences.append(decoded_sentence)
    corrected_sentences.append(corrected_sentence)
results.close()    

    

-Lenght =  50
Input sentence: Me dieal Provider Roles: Treating
GT sentence: Medical Provider Roles: Treating
Char Decoded sentence: Medical Provider Roles:Treating
Word Decoded sentence: Me deal Provider Roles Treating 
-Lenght =  50
Input sentence: Provider First Name: Christine
GT sentence: Provider First Name: Christine
Char Decoded sentence: Provider First Name: Christine
Word Decoded sentence: Provider First Name Christine 
-Lenght =  50
Input sentence: Provider Last Name: Nolen, MD
GT sentence: Provider Last Name: Nolen, MD
Char Decoded sentence: Provider Last Name: Norle, MD
Word Decoded sentence: Provider Last Name Dolens MD 
-Lenght =  50
Input sentence: Address Line 1 : 7 25 American Avenue
GT sentence: Address Line 1 : 725 American Avenue
Char Decoded sentence: Address Line 1:7A25ent Admentine Ave
Word Decoded sentence: Address Line 1 : 7 25 American Avenue 
-Lenght =  50
Input sentence: City. W’aukesha
GT sentence: City: Waukesha
Char Decoded sentence: City.States
Word Dec

KeyboardInterrupt: 
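
The model selection step in the loop above amounts to length bucketing: pick the smallest model whose maximum length exceeds the sentence length, and fall back to the largest model otherwise. A minimal sketch, with the hypothetical helper name `pick_len_range`:

```python
def pick_len_range(text, max_sent_lengths):
    """Return the smallest length bucket that can hold the text."""
    for length in sorted(max_sent_lengths):
        if len(text) < length:
            return length
    # Longer than every bucket: use the largest model (the input gets truncated).
    return max(max_sent_lengths)


print(pick_len_range('Provider First Name: Christine', [50, 100]))
print(pick_len_range('x' * 75, [50, 100]))
print(pick_len_range('x' * 150, [50, 100]))
```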

In [25]:
input_texts = ['SUBJECTIVE: This is a S-year-old +@W his left great toe with the handleh lacration.',
               'Thera was no handlebarthe lacration.',
               'Patiet last tet is needing this for school at this',
               'OBJECTIVE : The temp is 99.8, the f tha blood pressure 99/64, O2 sat 94 8/10 at this time.',
               'Left great toe the dorsl surface, extending ta th active hemorrhage at this time.',
               'Th anaathetized with a cotton ball sat Left this in place for 20 minutes.',
               'with Betadine again and injected th he tolerated very well.',
               'The wound sutures. Antibiotic eintment and g',
               'Patient tolerated very well. Pat',
               'IMPRESSION: Lacration te left grs',
               'PLAN: Patent is to do dressing ch advised as far as checking tha waurn it with soap and water.',
               'Sutures oy have any problems prior te that tim ona teaspoon three times a day rer Ibuprofan far pain, discomfort. Cg',
               'hite male who accidently dropped a bike onto ar end hitting the left great toe,',
               'causing a guard to the end of the bike, which caused anus shot is more that three years ago and time.',
               'nlse ef 105 and regular, resprations 286,% on room air.',
               'Patient rates hia pain at — there i15 noted a 3-om laceration across a lateral aspect of tha toe.',
               'There is no e toa ir cleansed with Betadine.',
               'It is then urated with 5 cu of 2% Hylocaine plain.',
               'We then cleansed ae toa with 3 cc of 2% Xylacaina plain',
               'which was then clesed with five 5-0 Prolene ressure dressing was then applied to the tos',
               'paient is given DPT 0.5 ee intramucular (IM).at toe.',
               'Kefylex 250 mg per 5 ml, the next seven days.',
               'He may use Tylenol or 11 if any problems.',
               'Unum Life Insurance Company of America 2211',               
               'Congress Street Portland, Maine 04122',
               'APPLICATION FOR GROUP CRITICAL LLNESS INSURANCE',
               'I Evidence of Insurability',
               '',
               'Application Type: @ New Enrollee Change to',
               'Existing Coverage  Reinstatement  Internal',
               'Replacement  Late Applicant  Rehire SECTION 1:',
               'Employee(Applicant) Information  Always',
               'Complete Employee Name(First, Middle, Last)',
               'Social Security Number Nikolas J Jones',
               '123 - 456 - 7890 Home Address(Street/ PO Box)',
               'Gender 1634 Stewert St  F  M City Date of Birth',
               '(mm / dd / yyyy) Seattle 06 / 15 / 1991 State Zip',
               'Code Home Phone # Washington 98101 854-555-1212',
               'Are you Actively at Work? Employee ID / Payroll #',
               ' Yes  No55624 a.Are you a U.S.Citizen or',
               'Canadian Citizen working in the U.S.? b.Are you',
               'legally authorized to work in  Yes  No(If No',
               'reply to part b) the U.S.?  Yes  No Employer',
               'Name Group Number Date of Hire(mm/ dd / yyyy)',
               'Facebook 11 - 555566 11 / 30 / 2016 Occupation',
               'Eligibility Class Software Engineer 7 Scheduled',
               'Number of Work Hours per Week Work Phone # 35',
               '854-555-6622 SECTION 2: Spouse Information ',
               'Complete Only if applying for Spouse coverage Name',
               '(First, Middle, Last) Social Security Number',
               'Gender Date of Birth(mm / dd / yyyy) Does the',
               '1019 - 07 - AZ 1',
              'if claint is for a child, please state your relationship 10 the child',
              'date of accident 3d _ time of accident ram. 0 p.m.',
              'have you slopped working? (of yes [1 no if yes, what was the last day that you worked? (mm/ddryy)_| —3 | —{% cnslamegs bil =']
               
for input_text in input_texts:
    len_range = max_sent_lengths[-1] # Take the longest range
    for length in max_sent_lengths:
        if(len(input_text) < length):
            len_range = length
            break
    #print(len_range)
    pre_corrected_sentence = word_spell_correct(input_text)
    input_text = clean_up_sentence(input_text, vocab_to_int[len_range])
    encoder_input_data = vectorize_data(input_texts=[input_text], max_encoder_seq_length=max_encoder_seq_length[len_range], num_encoder_tokens=num_encoder_tokens[len_range], vocab_to_int=vocab_to_int[len_range])



    # No ground truth is available for these ad-hoc inputs, so no target_text is looked up here.

    input_seq = encoder_input_data
    #print(input_seq.shape)
    #print(max_decoder_seq_length[len_range])
    #print(max_decoder_seq_length)

    decoded_sentence,_  = decode_sequence(input_seq, encoder_model[len_range], decoder_model[len_range], num_decoder_tokens[len_range],  max_decoder_seq_length[len_range], vocab_to_int[len_range], int_to_vocab[len_range])
    corrected_sentence = word_spell_correct(input_text)
    #print('-Lenght = ', len_range)
    print('Input sentence:', input_text)
    #print('Spell Decoded sentence:', pre_corrected_sentence) 
    #print('Char Decoded sentence:', decoded_sentence)   
    print('Word Decoded sentence:', corrected_sentence) 
    print('\n')



Input sentence: SUBJECTIVE: This is a S-year-old +@W his left great toe with the handleh lacration.
Word Decoded sentence: SUBJECTIVES This is a S-year-old New his left great toe with the handled laceration 


Input sentence: Thera was no handlebarthe lacration.
Word Decoded sentence: Thera was no handlebarthe laceration 


Input sentence: Patiet last tet is needing this for school at this
Word Decoded sentence: Patient last tet is needing this for school at this 


Input sentence: OBJECTIVE : The temp is 99.8, the f tha blood pressure 99/64, O2 sat 94 8/10 at this time.
Word Decoded sentence: OBJECTIVE : The temp is 99.8, the f tha blood pressure 99/64, O2 sat 94 8/10 at this time 


Input sentence: Left great toe the dorsl surface, extending ta th active hemorrhage at this time.
Word Decoded sentence: Left great toe the dorsal surface extending ta th active hemorrhage at this time 


Input sentence: Th anaathetized with a cotton ball sat Left this in place for 20 minutes.
Word Decode

In [None]:
WER_spell_correction = calculate_WER(gt_texts, decoded_sentences)
print('WER_spell_correction |TEST= ', WER_spell_correction)

In [None]:
WER_spell_word_correction = calculate_WER(gt_texts, corrected_sentences)
print('WER_spell_word_correction |TEST= ', WER_spell_word_correction)

In [None]:
WER_OCR = calculate_WER(gt_texts, input_texts)
print('WER_OCR |TEST= ', WER_OCR)