# Evaluating OCR Models

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
import time
import math
from collections import Counter
import unidecode

# Import Widgets
from ipywidgets import Button, Text, HBox, VBox
from IPython.display import display, clear_output

# Import custom functions, corresponding to notebooks
#from ocr import page, words, charSeg
from ocr.normalization import letterNorm
from ocr import charSeg
# Helpers
from ocr.helpers import implt, resize
from ocr.datahelpers import loadWordsData, char2idx, idx2char
from ocr.tfhelpers import Graph

Loading Segmantation model:
INFO:tensorflow:Restoring parameters from models/gap-clas/CNN-CG
INFO:tensorflow:Restoring parameters from models/gap-clas/large/CNN-CG
INFO:tensorflow:Restoring parameters from models/gap-clas/RNN/Bi-RNN
INFO:tensorflow:Restoring parameters from models/gap-clas/RNN/Bi-RNN-dense


### Global Variables

In [2]:
# Settings
# Language code used to build the model paths, the test-word directory and the
# word-frequency file name below; edits1() switches to the Czech alphabet
# when it is 'cz'.
LANG = 'en'

## Load Trained Model

In [3]:
# Character classifiers for the selected language. The second argument of
# Graph is passed through unchanged (presumably the name of the output
# operation -- confirm against ocr.tfhelpers.Graph).
CHAR_MODEL_DIR = 'models/char-clas/' + LANG
charClass_1 = Graph(CHAR_MODEL_DIR + '/CharClassifier')
charClass_2 = Graph(CHAR_MODEL_DIR + '/Bi-RNN/model_2', 'prediction')
charClass_3 = Graph(CHAR_MODEL_DIR + '/Bi-RNN/model_1', 'prediction')

# Word-level classifier for the selected language.
WORD_MODEL_DIR = 'models/word-clas/' + LANG
wordClass = Graph(WORD_MODEL_DIR + '/WordClassifier', 'prediction_infer')

INFO:tensorflow:Restoring parameters from models/char-clas/en/CharClassifier
INFO:tensorflow:Restoring parameters from models/char-clas/en/Bi-RNN/model_2
INFO:tensorflow:Restoring parameters from models/char-clas/en/Bi-RNN/model_1
INFO:tensorflow:Restoring parameters from models/word-clas/en/WordClassifier


## Load images

In [4]:
# Read the test word images and their text labels for the selected language.
images, labels = loadWordsData('data/test_words/' + LANG, loadGaplines=False)

# For English, replace every label in place with its ASCII transliteration.
if LANG == 'en':
    for pos, label in enumerate(labels):
        labels[pos] = unidecode.unidecode(label)

print('Number of chars:', sum(map(len, labels)))

Loading words...
-> Number of words: 250
Number of chars: 1228


# Testing

In [5]:
# Load the word-frequency list used by the spelling corrector.
# Every non-blank line is expected to look like "<word> <count>".
WORDS = Counter()
# The encoding is given explicitly so the result does not depend on the
# platform default (assumes the list is stored as UTF-8 -- confirm).
with open('data/' + LANG + '_50k.txt', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines, e.g. a stray newline at the end of the file.
            continue
        fields = line.split(" ")
        word, count = fields[0], int(fields[1])
        if LANG == 'en':
            # English words are keyed by their ASCII transliteration.
            word = unidecode.unidecode(word)
        # Accumulate instead of overwriting, so that entries which collapse to
        # the same key (e.g. after transliteration) keep their combined count.
        WORDS[word] += count

def P(word, N=sum(WORDS.values())):
    """Return the relative frequency of `word` in WORDS.

    N is the normalising total. Its default is the sum of all counts in WORDS,
    evaluated once, at the moment this function is defined.
    """
    occurrences = WORDS[word]
    return occurrences / N

def correction(word):
    """Return `word` unchanged if it is in WORDS, else its most probable candidate."""
    if word not in WORDS:
        # Rank the candidate spellings by their relative frequency.
        return max(candidates(word), key=P)
    return word

def candidates(word):
    """Return the possible spelling corrections for `word`.

    The first non-empty group wins, in this order: the word itself if known,
    known words one edit away, known words two edits away, and finally the
    unchanged word wrapped in a list.
    """
    exact = known([word])
    if exact:
        return exact

    one_edit_away = known(edits1(word))
    if one_edit_away:
        return one_edit_away

    two_edits_away = known(edits2(word))
    if two_edits_away:
        return two_edits_away

    return [word]

def known(words):
    """Return, as a set, those items of `words` that are keys of WORDS."""
    return {candidate for candidate in words if candidate in WORDS}

def edits1(word):
    """Return the set of all strings exactly one edit away from `word`.

    An edit is the deletion of one character, the swap of two adjacent
    characters, the replacement of one character or the insertion of one
    character. The alphabet used for replacements and insertions depends on
    LANG (Czech letters for 'cz', plain a-z otherwise).
    """
    if LANG == 'cz':
        alphabet = 'aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzž'
    else:
        alphabet = 'abcdefghijklmnopqrstuvwxyz'

    variants = set()
    # Visit every split point, including the one after the last character.
    for cut in range(len(word) + 1):
        head, tail = word[:cut], word[cut:]

        # Insertion is possible at every split point.
        for ch in alphabet:
            variants.add(head + ch + tail)

        if not tail:
            continue

        # Deletion and replacement need at least one character after the cut.
        variants.add(head + tail[1:])
        for ch in alphabet:
            variants.add(head + ch + tail[1:])

        # Transposition needs two characters after the cut.
        if len(tail) > 1:
            variants.add(head + tail[1] + tail[0] + tail[2:])

    return variants

def edits2(word):
    """Return a generator over all strings reachable from `word` by two edits.

    Duplicates are not removed.
    """
    first_pass = edits1(word)
    return (twice for once in first_pass for twice in edits1(once))

In [6]:
class WordCycler:
    """Evaluate a sequence (word-level) classifier on a set of word images.

    Creating an instance immediately runs the evaluation and prints the
    statistics (see `evaluateAll`).
    """
    def __init__(self, images, labels, charClass):
        """Store the data and the model, then run the evaluation.

        Args:
            images: indexable collection of word images (2-D arrays).
            labels: ground-truth strings, one per image.
            charClass: model object providing `eval_feed(feed_dict)`.
        """
        self.images = images
        self.labels = labels
        self.charClass = charClass

        # Denominator of the reported per-character accuracy.
        self.totalChars = sum(len(l) for l in labels)

        self.evaluateAll()

    def recogniseWord(self, img):
        """Return the lower-cased text the model predicts for one word image.

        The image is cut into consecutive, non-overlapping vertical strips of
        `slider[1]` columns; each flattened strip is one step of the input
        sequence. Columns left over on the right-hand side are dropped.
        """
        # (strip height, strip width). A flattened strip must hold exactly
        # height * width values, so img is assumed to be 60 rows high.
        slider = (60, 15)
        length = img.shape[1]//slider[1]

        input_seq = np.zeros((1, length, slider[0] * slider[1]), dtype=np.float32)
        input_seq[0][:] = [img[:, loc * slider[1]: (loc+1) * slider[1]].flatten()
                           for loc in range(length)]
        # Swap the first two axes: (1, steps, features) -> (steps, 1, features).
        input_seq = input_seq.swapaxes(0, 1)

        targets = np.zeros((1, 1), dtype=np.int32)
        pred = self.charClass.eval_feed({'encoder_inputs:0': input_seq,
                                         'encoder_inputs_length:0': [length],
                                         'decoder_targets:0': targets,
                                         'keep_prob:0': 1})[0]

        word = ''
        for i in pred:
            # Stop at the terminator. The check has to look at the predicted
            # index, not at the accumulated string (a string never equals 1).
            # Index 1 is presumably the end-of-sequence token -- confirm
            # against the model's vocabulary.
            if i == 1:
                break
            word += idx2char(i)

        return word.lower()


    def countCorrect(self, pred, label):
        """Count positions at which `pred` equals the lower-cased `label`.

        Only the overlapping prefix is compared; surplus characters in either
        string are ignored.
        """
        target = label.lower()
        return sum(1 for got, expected in zip(pred, target) if got == expected)


    def evaluateAll(self):
        """Run the (single) evaluation configuration."""
        self.evaluate()


    def evaluate(self):
        """Print character accuracy, with and without spelling correction, and the run time."""
        print()
        print("STATS: Seq2Seq")
        # Show one sample prediction before the timed run.
        print(self.labels[0], ':', self.recogniseWord(self.images[0]))
        start_time = time.time()
        correct = 0
        correctWithCorrection = 0
        for i in range(len(self.images)):
            word = self.recogniseWord(self.images[i])
            correct += self.countCorrect(word,
                                         self.labels[i])
            correctWithCorrection += self.countCorrect(correction(word),
                                                       self.labels[i])
        print("Correct/Total: %s / %s" % (correct, self.totalChars))
        print("Accuracy: %s %%" % round(correct/self.totalChars * 100, 4))
        print("Accuracy with correction: %s %%" % round(correctWithCorrection/self.totalChars * 100, 4))
        print("--- %s seconds ---" % round(time.time() - start_time, 2))

In [13]:
class Cycler:
    """Evaluate a character classifier on word images split by gap segmentation.

    Creating an instance immediately runs every configuration listed in
    `evaluateAll` and prints the statistics.
    """
    def __init__(self, images, labels, charClass, charRNN=False):
        """Store the data and the model, then run the evaluation.

        Args:
            images: indexable collection of word images (2-D arrays).
            labels: ground-truth strings, one per image.
            charClass: character model; must provide `eval_feed(feed_dict)`
                when `charRNN` is true and `run(chars)` otherwise.
            charRNN: selects which of the two model interfaces is used.
        """
        self.images = images
        self.labels = labels
        self.charClass = charClass
        self.charRNN = charRNN

        # Denominator of the reported per-character accuracy.
        self.totalChars = sum(len(l) for l in labels)

        self.evaluateAll()

    def recogniseWord(self, img, gapRNN=False, norm1=False, norm2=False, large=False):
        """Return the lower-cased text recognised in one word image.

        Args:
            img: word image.
            gapRNN: forwarded to `charSeg.segmentation` as `RNN`.
            norm1: standardise the image before it is passed to segmentation.
            norm2: standardise every character crop before classification.
            large: use the wider 60 px side border (15 px otherwise); also
                forwarded to `charSeg.segmentation`.

        Returns:
            The recognised word, or '' when no character crop passes the
            size filter.
        """
        border = 60 if large else 15
        # Pad left and right only.
        img = cv2.copyMakeBorder(img,
                                 0, 0, border, border,
                                 cv2.BORDER_CONSTANT,
                                 value=[0, 0, 0])
        if norm1:
            # Zero mean / unit variance; the lower bound on the divisor avoids
            # division by zero on constant images.
            gapImg = (img - np.mean(img)) / max(np.std(img), 1.0 / math.sqrt(img.size))
        else:
            gapImg = img
        gaps = charSeg.segmentation(gapImg, RNN=gapRNN, large=large)

        chars = []
        # Consecutive gap positions delimit one character crop each.
        for left, right in zip(gaps[:-1], gaps[1:]):
            char = img[:, left:right]
            # TODO None type error after threshold
            char, dim = letterNorm(char, dim=True)
            # TODO Test different values
            if dim[0] > 4 and dim[1] > 4:
                if norm2:
                    char = (char - np.mean(char)) / max(np.std(char), 1.0 / math.sqrt(char.size))
                chars.append(char.flatten())

        if not chars:
            # No segment passed the size filter, so there is nothing to classify.
            return ''

        chars = np.array(chars)

        if self.charRNN:
            pred = self.charClass.eval_feed({'inputs:0': [chars],
                                             'length:0': [len(chars)],
                                             'keep_prob:0': 1})[0]
        else:
            pred = self.charClass.run(chars)

        word = ''.join(idx2char(c) for c in pred)

        return word.lower()


    def countCorrect(self, pred, label):
        """Count positions at which `pred` equals the lower-cased `label`.

        Only the overlapping prefix is compared; surplus characters in either
        string are ignored.
        """
        target = label.lower()
        return sum(1 for got, expected in zip(pred, target) if got == expected)


    def evaluateAll(self):
        """Run the evaluation for each (gapRNN, norm1, norm2, large) configuration."""
        self.evaluate(True, True, False, False)
        self.evaluate(True, False, False, True)
        self.evaluate(False, False, False, False)
        self.evaluate(False, False, False, True)


    def evaluate(self, gapRNN, norm1, norm2, large=False):
        """Print character accuracy, with and without spelling correction, and the run time.

        The arguments are forwarded unchanged to `recogniseWord`.
        """
        print()
        print("STATS: gapRNN - %s, charRNN - %s, gapNorm - %s, charNorm - %s, large - %s" %
              (gapRNN, self.charRNN, norm1, norm2, large))
        # Show one sample prediction before the timed run.
        print(self.labels[0], ':', self.recogniseWord(self.images[0], gapRNN, norm1, norm2, large))
        start_time = time.time()
        correct = 0
        correctWithCorrection = 0
        for i in range(len(self.images)):
            word = self.recogniseWord(self.images[i], gapRNN, norm1, norm2, large)
            correct += self.countCorrect(word,
                                         self.labels[i])
            correctWithCorrection += self.countCorrect(correction(word),
                                                       self.labels[i])
        print("Correct/Total: %s / %s" % (correct, self.totalChars))
        print("Accuracy: %s %%" % round(correct/self.totalChars * 100, 4))
        print("Accuracy with correction: %s %%" %
              round(correctWithCorrection/self.totalChars * 100, 4))
        print("--- %s seconds ---" % round(time.time() - start_time, 2))

In [14]:
# Run the evaluations: constructing a cycler is sufficient, because its
# constructor performs the whole evaluation and prints the statistics.

# The word-level model run is currently switched off:
# WordCycler(images, labels, wordClass)

Cycler(images, labels, charClass_1, charRNN=False)

Cycler(images, labels, charClass_2, charRNN=True)

Cycler(images, labels, charClass_3, charRNN=True)


STATS: gapRNN - True, charRNN - False, gapNorm - True, charNorm - False, large - False
urges : srpes
Correct/Total: 736 / 1228
Accuracy: 59.9349 %
Accuracy with correction: 64.1694 %
--- 52.37 seconds ---

STATS: gapRNN - True, charRNN - False, gapNorm - False, charNorm - False, large - True
urges : sgee
Correct/Total: 156 / 1228
Accuracy: 12.7036 %
Accuracy with correction: 10.0977 %
--- 81.72 seconds ---

STATS: gapRNN - False, charRNN - False, gapNorm - False, charNorm - False, large - False
urges : trrps
Correct/Total: 673 / 1228
Accuracy: 54.8046 %
Accuracy with correction: 59.5277 %
--- 45.93 seconds ---

STATS: gapRNN - False, charRNN - False, gapNorm - False, charNorm - False, large - True
urges : zrqes
Correct/Total: 642 / 1228
Accuracy: 52.2801 %
Accuracy with correction: 53.5831 %
--- 83.13 seconds ---

STATS: gapRNN - True, charRNN - True, gapNorm - True, charNorm - False, large - False
urges : srqoz
Correct/Total: 571 / 1228
Accuracy: 46.4984 %
Accuracy with correction: 5

KeyboardInterrupt: 