In [None]:
# Reload imported modules before executing code, so edits to local
# packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import codecs
import glob
import numpy as np
import os

# Load the train/val/test division of the aligned data files and show the
# files in the test split.
division_path = '/home/jvdzwaan/data/ocr/datadivision.json'
with codecs.open(division_path, encoding='utf-8') as division_file:
    division = json.load(division_file)
test_files = division.get('test')
print(test_files)
print(len(test_files))

In [None]:
seq_length = 25          # characters per input/output window
num_nodes = 256          # LSTM units per layer (per direction when bidirectional)
layers = 1               # number of stacked recurrent layers
batch_size = 100
lowercase = True         # lowercase text before building windows / vocabulary
bidirectional = True     # NOTE(review): not read by any visible cell
data_dir = '/home/jvdzwaan/data/ocr'                      # aligned JSON files
weights_dir = '/home/jvdzwaan/data/ocr-all-bidirect/'     # *.hdf5 checkpoints

# Predicted characters are sampled with np.random.choice; seed NumPy's global
# RNG so the sampled outputs are repeatable on Restart & Run All (given the
# same model weights).
seed = 42
np.random.seed(seed)

In [None]:
def read_texts(data_files, data_dir):
    """Load character-aligned OCR and gold-standard texts from JSON files.

    Each file in ``data_files`` (a path relative to ``data_dir``) is a JSON
    object whose ``'ocr'`` and ``'gs'`` entries are sequences of aligned
    characters (presumably lists of one-character strings, possibly with
    empty strings for alignment gaps).

    Returns:
      tuple: ``(raw_text, gs_text, ocr_text)``.
        raw_text: for every file, its joined OCR text followed by its joined
          gold-standard text, all joined with single spaces.
        gs_text / ocr_text: flat lists of aligned characters of all files,
          with a ``' '`` element appended after each file.
    """
    raw_parts = []
    gs_chars = []
    ocr_chars = []

    for data_file in data_files:
        path = os.path.join(data_dir, data_file)
        with codecs.open(path, encoding='utf-8') as json_file:
            aligned = json.load(json_file)

        # A single space separates consecutive texts in the flat arrays.
        ocr_chars.extend(aligned['ocr'])
        ocr_chars.append(' ')
        gs_chars.extend(aligned['gs'])
        gs_chars.append(' ')

        raw_parts.extend([''.join(aligned['ocr']), ''.join(aligned['gs'])])

    return ' '.join(raw_parts), gs_chars, ocr_chars

In [None]:
def get_char_to_int(chars):
    """Map every character in ``chars`` to its index (a later duplicate wins)."""
    return {char: index for index, char in enumerate(chars)}


def get_int_to_char(chars):
    """Map every index of ``chars`` to the character at that index."""
    return dict(enumerate(chars))

In [None]:
# Read every split; the vocabulary is built from all of them so that any
# character occurring in val/test/train gets an index.
raw_val, gs_val, ocr_val = read_texts(division.get('val'), data_dir)
raw_test, gs_test, ocr_test = read_texts(division.get('test'), data_dir)
raw_train, gs_train, ocr_train = read_texts(division.get('train'), data_dir)

raw_text = ''.join([raw_val, raw_test, raw_train])
if lowercase:
    raw_text = raw_text.lower()

# Sorted, so every character gets a deterministic index.
# NOTE(review): the model's input/output size is len(chars) and the loaded
# weights presumably expect exactly this vocabulary (size and order); do not
# change how it is built without retraining.
chars = sorted(list(set(raw_text)))
chars.append(u'\n')                      # padding character
# NOTE(review): if u'\n' already occurs in raw_text it appears in `chars`
# twice and char_to_int maps it to the last index only — TODO confirm the
# data contains no newlines.
char_to_int = get_char_to_int(chars)

n_chars = len(raw_text)
n_vocab = len(chars)

print('Total Characters: {}'.format(n_chars))
print('Total Vocab: {}'.format(n_vocab))

In [None]:
def to_string(char_list, lowercase):
    """Join a list of characters into a unicode string, lowercased on request."""
    text = u''.join(char_list)
    return text.lower() if lowercase else text


def create_synced_data(ocr_text, gs_text, char_to_int, n_vocab, seq_length=25,
                       batch_size=100, padding_char=u'\n', lowercase=False):
    """Create padded one-hot encoded data sets from aligned text.

    A sample is a window of ``seq_length`` consecutive elements of
    ``ocr_text`` (input) together with the window at the same offset in
    ``gs_text`` (output); windows start at every offset (stride 1).
    ocr_text and gs_text are aligned arrays of characters that may contain
    empty characters (''). Because those disappear when a window is joined
    into a string, input and output strings can be shorter than
    ``seq_length`` and differ in length; the generator pads them with
    ``padding_char`` (newline by default).

    Returns:
      int: the number of samples in the dataset
      generator: generator for one-hot encoded batches (so the data doesn't
        have to fit in memory)
    """
    window_starts = range(len(ocr_text) - seq_length + 1)
    dataX = [to_string(ocr_text[start:start + seq_length], lowercase)
             for start in window_starts]
    dataY = [to_string(gs_text[start:start + seq_length], lowercase)
             for start in window_starts]
    generator = synced_data_gen(dataX, dataY, seq_length, n_vocab,
                                char_to_int, batch_size, padding_char)
    return len(dataX), generator


def synced_data_gen(dataX, dataY, seq_length, n_vocab, char_to_int, batch_size,
                    padding_char):
    """Endlessly yield one-hot encoded ``(X, Y)`` batches for paired strings.

    Every string of ``dataX`` / ``dataY`` is one-hot encoded with
    ``char_to_int`` and right-padded with ``padding_char`` up to
    ``seq_length`` positions. Every batch has ``batch_size`` rows; in the
    last batch of a pass, rows beyond the end of the data stay all-zero.
    After the last batch the generator starts over from the first one.

    Args:
      dataX, dataY: equally long sequences of strings of at most
        ``seq_length`` characters (input and target windows).
      seq_length: number of time steps per sample.
      n_vocab: size of the one-hot dimension.
      char_to_int: maps every character (and padding_char) to an index.
      batch_size: rows per yielded batch.
      padding_char: character used to fill strings up to seq_length.

    Yields:
      tuple: boolean arrays ``(X, Y)`` of shape
        ``(batch_size, seq_length, n_vocab)``.

    Raises:
      ValueError: on first iteration if ``dataX`` is empty (otherwise the
        generator would loop forever without yielding anything).
    """
    if len(dataX) == 0:
        raise ValueError('dataX is empty; there are no samples to generate')

    def encode(array, row, text):
        # One-hot encode `text` into array[row], padding up to seq_length.
        padded = text + padding_char * (seq_length - len(text))
        for pos, char in enumerate(padded):
            array[row, pos, char_to_int[char]] = 1

    while True:
        for start in range(0, len(dataX), batch_size):
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            X = np.zeros((batch_size, seq_length, n_vocab), dtype=bool)
            Y = np.zeros((batch_size, seq_length, n_vocab), dtype=bool)
            pairs = zip(dataX[start:start + batch_size],
                        dataY[start:start + batch_size])
            for row, (sentence_x, sentence_y) in enumerate(pairs):
                encode(X, row, sentence_x)
                encode(Y, row, sentence_y)
            yield X, Y

In [None]:
# NOTE(review): the same generator is created again in the test-set cell
# further down before it is used; this cell could be dropped.
numTestSamples, testDataGen = create_synced_data(
    ocr_test, gs_test, char_to_int, n_vocab,
    seq_length=seq_length,
    batch_size=batch_size,
    lowercase=lowercase)

In [None]:
def initialize_model_bidirectional(n, dropout, seq_length, chars, output_size,
                                   layers, loss='categorical_crossentropy',
                                   optimizer='adam'):
    """Build and compile a stacked bidirectional-LSTM sequence model.

    Each layer is a ``Bidirectional(LSTM(n, return_sequences=True))``
    followed by ``Dropout(dropout)``; a time-distributed softmax layer over
    ``len(chars)`` classes produces one distribution per time step.

    Args:
      n: number of LSTM units per direction.
      dropout: dropout rate after every recurrent layer.
      seq_length: number of time steps per sample.
      chars: vocabulary; ``len(chars)`` is both input and output dimension.
      output_size: NOTE(review): accepted but not used by this function.
      layers: number of bidirectional LSTM layers.
      loss, optimizer: passed to ``model.compile``.

    Returns:
      The compiled Keras ``Sequential`` model.
    """
    vocab_size = len(chars)
    model = Sequential()

    # The first layer declares the input shape (seq_length, vocab_size).
    model.add(Bidirectional(LSTM(n, return_sequences=True),
                            input_shape=(seq_length, vocab_size)))
    model.add(Dropout(dropout))

    for _ in range(1, layers):
        model.add(Bidirectional(LSTM(n, return_sequences=True)))
        model.add(Dropout(dropout))

    model.add(TimeDistributed(Dense(vocab_size, activation='softmax')))
    model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
    return model

def load_weights(model, weights_dir, loss='categorical_crossentropy',
                 optimizer='adam'):
    """Load the first (by sorted file name) ``*.hdf5`` file in ``weights_dir``.

    When a weights file exists, its weights are loaded into ``model`` and the
    model is recompiled with ``loss`` and ``optimizer``. The epoch is parsed
    from a file name ending in ``-<digits>.hdf5``.

    Returns:
      tuple: ``(epoch, model)``; epoch is 0 when no weights file was found or
        the file name does not end in ``-<digits>.hdf5``.
    """
    import re  # local so this function does not depend on a later cell

    epoch = 0
    weight_files = glob.glob(os.path.join(weights_dir, '*.hdf5'))
    if weight_files:
        # NOTE(review): the lexicographically first file is used; whether
        # that is the best or the earliest checkpoint depends on the naming
        # scheme of the checkpoint files.
        fname = sorted(weight_files)[0]
        print('Loading weights from {}'.format(fname))

        model.load_weights(fname)
        model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])

        # The epoch is the digits between the last '-' and '.hdf5'; \d+
        # (not \d\d) also parses epochs >= 100, and the dot is escaped.
        m = re.match(r'.+-(\d+)\.hdf5$', fname)
        if m:
            epoch = int(m.group(1))

    return epoch, model


In [None]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.layers import TimeDistributed
from keras.layers import Bidirectional

import glob2
import re

# Build the network with the configured size, then restore trained weights.
model = initialize_model_bidirectional(n=num_nodes, dropout=0.5,
                                       seq_length=seq_length, chars=chars,
                                       output_size=n_vocab, layers=layers)
epoch, model = load_weights(model, weights_dir=weights_dir)

In [None]:
# Index -> character lookup used to decode model outputs.
int_to_char = dict(enumerate(chars))

In [None]:
import itertools

numTestSamples, testDataGen = create_synced_data(ocr_test, gs_test, char_to_int, n_vocab, seq_length=seq_length, batch_size=batch_size, lowercase=lowercase)

xTest = np.zeros((numTestSamples, seq_length, n_vocab))
yTest = np.zeros((numTestSamples, seq_length, n_vocab))

# testDataGen cycles forever, so take exactly ceil(numTestSamples / batch_size)
# batches. With zero samples no batch is requested; the generator would
# never yield one.
num_batches = -(-numTestSamples // batch_size)
idx = 0
for xBatch, yBatch in itertools.islice(testDataGen, num_batches):
    for x, y in zip(xBatch, yBatch):
        # The last batch is zero-padded up to batch_size; skip those rows.
        if idx < numTestSamples:
            xTest[idx, :, :] = x
            yTest[idx, :, :] = y
        idx += 1

In [None]:
def check_data(data):
    """Check that every time step of a 3-D encoded array sums to exactly 1.

    ``data`` is indexed as (sample, time step, vocabulary) — the sum is taken
    over axis 2. A sum of 1 everywhere is what a valid one-hot encoding with
    no empty time steps produces.

    Returns:
      numpy.bool_: True if all sums equal 1.
    """
    # Compare with the scalar 1 (broadcast); the np.int alias used before
    # was removed in NumPy 1.24.
    return (data.sum(axis=2) == 1).all()

In [None]:
# Both should be True: every time step of inputs and targets is one-hot.
print('{} {}'.format(check_data(xTest), check_data(yTest)))

In [None]:
# One softmax distribution per time step; expected shape presumably
# (numTestSamples, seq_length, n_vocab).
predicted = model.predict(xTest, verbose=1)
print(predicted.shape)

In [None]:
# Compare (sampled) predictions with the gold standard, window by window.
# '\n' (padding) is shown as '@' in the decoded strings.
match = no_match = in_is_out = change_correctly_predicted = 0
change = False
inputs, gs_strs, outputs = [], [], []

for i, sequence in enumerate(predicted):
    # Draw one character per time step from the predicted distribution
    # (stochastic unless NumPy's global RNG is seeded).
    predicted_indices = [np.random.choice(n_vocab, p=p) for p in sequence]
    # Positions of the one-hot entries = character indices of the window.
    indices = np.where(yTest[i:i+1, :, :] == True)[2]
    indices2 = np.where(xTest[i:i+1, :, :] == True)[2]

    pred_str = u''.join(int_to_char[j] for j in predicted_indices).replace(u'\n', u'@')
    gs = u''.join(int_to_char[j] for j in indices).replace(u'\n', u'@')
    inp = u''.join(int_to_char[j] for j in indices2).replace(u'\n', u'@')
    outputs.append(pred_str)
    gs_strs.append(gs)
    inputs.append(inp)

    # A window "needs a change" when input and gold standard differ.
    change = list(indices) != list(indices2)
    if not change:
        in_is_out += 1

    if predicted_indices == list(indices):
        match += 1
        if change:
            change_correctly_predicted += 1
    else:
        no_match += 1

print('Match {}'.format(match))
print('No match {}'.format(no_match))
print('Input == output {}'.format(in_is_out))
print('Correct when input != output {}'.format(change_correctly_predicted))

In [None]:
# Number of windows, number evaluated, and number where input != gold standard.
for value in (numTestSamples, match + no_match, numTestSamples - in_is_out):
    print(value)

In [None]:
# One row per window: "input"  "gold standard"  "prediction".
for gs_window, input_window, output_window in zip(gs_strs, inputs, outputs):
    print(u'"{}"\t"{}"\t"{}"'.format(input_window, gs_window, output_window))

In [None]:
from collections import Counter

num = 550  # number of leading windows to aggregate

# Window idx covers text positions idx .. idx + len(window) - 1. For every
# position, count how often each character was predicted (counters) and
# occurs in the gold standard (counters_gs); '@' (padding) is not counted.
idx = 0
counters = {}
counters_gs = {}

for input_str, output_str, gs_str in zip(inputs[:num], outputs[:num], gs_strs[:num]):
    print('{} {} {}'.format(len(input_str), len(output_str), len(gs_str)))
    for i, (inp, outp, gs) in enumerate(zip(input_str, output_str, gs_str)):
        pos = idx + i
        # `pos not in d` is a hash lookup; `d.keys()` is a list in Python 2.
        if outp != '@':
            if pos not in counters:
                counters[pos] = Counter()
            counters[pos][outp] += 1

        if gs != '@':
            if pos not in counters_gs:
                counters_gs[pos] = Counter()
            counters_gs[pos][gs] += 1
    idx += 1

In [None]:
# Predicted-character counts per text position.
print(counters)

In [None]:
# Number of counted (non-'@') predictions at each text position.
for position_counter in counters.values():
    print(sum(position_counter.values()))

In [None]:
# Gold-standard character counts per text position.
print(counters_gs)

In [None]:
# Positions where more than two distinct gold-standard characters were counted.
for position_counter in counters_gs.values():
    if len(position_counter) > 2:
        print(position_counter)

In [None]:
# Majority vote per text position. Iterate in position order explicitly:
# dict iteration order is not position order in every Python version.
agg_out = [counter.most_common(1)[0][0] for _, counter in sorted(counters.items())]

In [None]:
# First 25 aligned gold-standard characters of the test set.
print(u''.join(gs_test[:25]))

In [None]:
# Majority-vote output, raw test text, and aligned gold standard, separated
# by blank lines.
print(''.join(agg_out))
print('')
print(raw_test[:num * 2])
print('')
print(''.join(gs_test[:num * 2]))

In [None]:
# Majority vote over the gold-standard counts, in text-position order (dict
# iteration order is not position order in every Python version).
agg_out = [counter.most_common(1)[0][0] for _, counter in sorted(counters_gs.items())]
print(''.join(agg_out[:500]))

In [None]:
# Split the gold standard into chunks of at least `num_chars_per_text`
# characters, each ending at the first '. ' after that point; print each
# chunk and collect the split indices.
num_chars_per_text = 5000
i = num_chars_per_text
prev_i = -1
indices = []
# Stop one short of the end: the test below also reads gs_test[i + 1].
while i < len(gs_test) - 1:
    if gs_test[i] == '.' and gs_test[i + 1] == ' ':
        print(i)
        print(''.join(gs_test[prev_i + 1:i + 1]))
        indices.append(i + 1)
        prev_i = i
        i += num_chars_per_text
    i += 1
print(''.join(gs_test[prev_i + 1:]))
print(indices)

In [None]:
import edlib
import unicodedata
from ochre.char_align import align_characters

def align_output_to_input(input_str, output_str, empty_char=u'@'):
    """Align predicted characters to an input window.

    Computes an edit path between ``input_str`` and ``output_str`` with
    edlib, applies it with ``align_characters`` (which inserts
    ``empty_char`` for gaps), and right-pads the aligned output with
    ``empty_char`` to at least ``len(input_str)`` characters.

    Args:
      input_str (unicode): the input (OCR) window.
      output_str (unicode): the predicted characters without placeholders.
      empty_char (unicode): placeholder for gaps and padding.

    Returns:
      unicode: the aligned output string.

    Raises:
      Any exception from ``edlib.align``; the offending strings are printed
      before it propagates.
    """
    # edlib doesn't handle accented and other special characters (this might
    # be a Python 2 problem), so align ASCII stand-ins: 'replace' turns every
    # non-ASCII character into a single '?'.
    t_output_str = output_str.encode('ASCII', 'replace')
    t_input_str = input_str.encode('ASCII', 'replace')
    try:
        r = edlib.align(t_input_str, t_output_str, task='path')
    except Exception:
        # Show what failed, then re-raise: without a result there is nothing
        # to align.
        print(input_str)
        print(output_str)
        raise
    _, aligned_output = align_characters(input_str, output_str, r.get('cigar'),
                                         empty_char=empty_char,
                                         sanity_check=False)
    while len(aligned_output) < len(input_str):
        aligned_output.append(empty_char)
    return u''.join(aligned_output)
    

In [None]:
# Quick check with non-ASCII characters and an output shorter than the input.
print(align_output_to_input(u'«b««««àà', u'««««a'))

In [None]:
from collections import Counter
import edlib
from ochre.char_align import align_characters
from nlppln.utils import remove_ext

import itertools

# Directory the corrected texts are written to.
results_dir = '/home/jvdzwaan/data/results-1x256-bidirect-all'

inputs = []
gs_strs = []
outputs = []
outputs2 = []

for text in division.get('test'):
    # --- Encode all windows of this text ---------------------------------
    raw_test, gs_test, ocr_test = read_texts([text], data_dir)

    numTestSamples, testDataGen = create_synced_data(ocr_test, gs_test, char_to_int, n_vocab, seq_length=seq_length, batch_size=batch_size, lowercase=lowercase)
    if numTestSamples == 0:
        # Shorter than one window: nothing to predict, and the data
        # generator would never yield a batch.
        print(u'Skipping {}: fewer than {} characters'.format(text, seq_length))
        continue

    xTest = np.zeros((numTestSamples, seq_length, n_vocab))
    yTest = np.zeros((numTestSamples, seq_length, n_vocab))

    # testDataGen cycles forever: take exactly ceil(numTestSamples /
    # batch_size) batches and skip the zero rows padding the last batch.
    num_batches = -(-numTestSamples // batch_size)
    idx = 0
    for xBatch, yBatch in itertools.islice(testDataGen, num_batches):
        for x, y in zip(xBatch, yBatch):
            if idx < numTestSamples:
                xTest[idx, :, :] = x
                yTest[idx, :, :] = y
            idx += 1

    predicted = model.predict(xTest, verbose=1)

    # --- Decode predictions and compare with the gold standard ------------
    match = 0
    no_match = 0
    in_is_out = 0
    # Reset all per-window lists (outputs2 included) for every text, so that
    # inputs / gs_strs / outputs / outputs2 always describe the same windows.
    inputs = []
    gs_strs = []
    outputs = []
    outputs2 = []

    for i, sequence in enumerate(predicted):
        # Sample one character per time step from the predicted distribution.
        predicted_indices = [np.random.choice(n_vocab, p=p) for p in sequence]
        indices = np.where(yTest[i:i+1,:,:]==True)[2]
        indices2 = np.where(xTest[i:i+1,:,:]==True)[2]
        pred_str = u''.join([int_to_char[j] for j in predicted_indices])
        pred_str = pred_str.replace(u'\n', u'@')
        outputs.append(pred_str)

        gs = u''.join([int_to_char[j] for j in indices])
        gs = gs.replace(u'\n', u'@')
        gs_strs.append(gs)

        inp = u''.join([int_to_char[j] for j in indices2])
        inp = inp.replace(u'\n', u'@')
        inputs.append(inp)

        if predicted_indices != list(indices):
            no_match += 1
        else:
            match += 1

        if list(indices) == list(indices2):
            in_is_out += 1

    print('')
    print('Match {}'.format(match))
    print('No match {}'.format(no_match))
    print('Input == output {}'.format(in_is_out))

    # --- Majority vote over overlapping windows ---------------------------
    # Window idx covers text positions idx .. idx + seq_length - 1; count the
    # characters predicted (and in the gold standard) for each position.
    idx = 0
    counters = {}
    counters_gs = {}

    for input_str, output_str, gs_str in zip(inputs, outputs, gs_strs):
        if '@' in output_str:
            # Drop the '@' placeholders and re-align the remaining predicted
            # characters to the input window.
            output_str2 = align_output_to_input(input_str, output_str.replace('@', ''), empty_char=u'@')
        else:
            output_str2 = output_str
        outputs2.append(output_str2)
        for i, (inp, outp, gs) in enumerate(zip(input_str, output_str2, gs_str)):
            pos = idx + i
            # Hash lookup instead of scanning a dict.keys() list (Python 2).
            if pos not in counters:
                counters[pos] = Counter()
            counters[pos][outp] += 1

            if pos not in counters_gs:
                counters_gs[pos] = Counter()
            counters_gs[pos][gs] += 1
        idx += 1

    # Iterate in position order: dict order is not position order in every
    # Python version.
    agg_out = [c.most_common(1)[0][0] for _, c in sorted(counters.items())]
    agg_out_gs = [c.most_common(1)[0][0] for _, c in sorted(counters_gs.items())]

    # '@' marks empty/padding characters; they are not part of the text.
    new_text = u''.join(agg_out).replace(u'@', u'')

    fname = '{}-1x256-bidirect-all.txt'.format(remove_ext(text))
    with codecs.open(os.path.join(results_dir, fname), 'wb', encoding='utf-8') as f:
        f.write(new_text)

In [None]:
# One row per window of the last processed text:
# "input"  "gold standard"  "re-aligned prediction".
for gs_window, input_window, realigned_output in zip(gs_strs, inputs, outputs2):
    print(u'"{}"\t"{}"\t"{}"'.format(input_window, gs_window, realigned_output))