In [1]:
import os, sys
import warnings
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import warnings
import traceback
from path import Path
import random
import argparse
import json

import tensorflow as tf
from tensorflow import keras

import cv2
from PIL import Image
from imutils import contours
import editdistance

import pytesseract
from pytesseract import Output

from DataLoaderIAM import DataLoaderIAM, Batch
from Model import Model, DecoderType
from SamplePreprocessor import preprocess

In [2]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)

warnings.filterwarnings('ignore')

In [3]:
class FilePaths:
    "filenames and paths to data"
    fnCharList = '../model/charList.txt'
    fnSummary = '../model/summary.json'
    fnInfer = '../data/test.png'
    fnCorpus = '../data/corpus.txt'


def write_summary(charErrorRates, wordAccuracies):
    with open(FilePaths.fnSummary, 'w') as f:
        json.dump({'charErrorRates': charErrorRates, 'wordAccuracies': wordAccuracies}, f)


def train(model, loader):
    "train NN"
    epoch = 0  # number of training epochs since start
    summaryCharErrorRates = []
    summaryWordAccuracies = []
    bestCharErrorRate = float('inf')  # best valdiation character error rate
    noImprovementSince = 0  # number of epochs no improvement of character error rate occured
    earlyStopping = 25  # stop training after this number of epochs without improvement
    while True:
        epoch += 1
        print('Epoch:', epoch)

        # train
        print('Train NN')
        loader.trainSet()
        while loader.hasNext():
            iterInfo = loader.getIteratorInfo()
            batch = loader.getNext()
            loss = model.trainBatch(batch)
            print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}')

        # validate
        charErrorRate, wordAccuracy = validate(model, loader)

        # write summary
        summaryCharErrorRates.append(charErrorRate)
        summaryWordAccuracies.append(wordAccuracy)
        write_summary(summaryCharErrorRates, summaryWordAccuracies)

        # if best validation accuracy so far, save model parameters
        if charErrorRate < bestCharErrorRate:
            print('Character error rate improved, save model')
            bestCharErrorRate = charErrorRate
            noImprovementSince = 0
            model.save()
        else:
            print(f'Character error rate not improved, best so far: {charErrorRate * 100.0}%')
            noImprovementSince += 1

        # stop training if no more improvement in the last x epochs
        if noImprovementSince >= earlyStopping:
            print(f'No more improvement since {earlyStopping} epochs. Training stopped.')
            break


def validate(model, loader):
    "validate NN"
    print('Validate NN')
    loader.validationSet()
    numCharErr = 0
    numCharTotal = 0
    numWordOK = 0
    numWordTotal = 0
    while loader.hasNext():
        iterInfo = loader.getIteratorInfo()
        print(f'Batch: {iterInfo[0]} / {iterInfo[1]}')
        batch = loader.getNext()
        (recognized, _) = model.inferBatch(batch)

        print('Ground truth -> Recognized')
        for i in range(len(recognized)):
            numWordOK += 1 if batch.gtTexts[i] == recognized[i] else 0
            numWordTotal += 1
            dist = editdistance.eval(recognized[i], batch.gtTexts[i])
            numCharErr += dist
            numCharTotal += len(batch.gtTexts[i])
            print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' + batch.gtTexts[i] + '"', '->',
                  '"' + recognized[i] + '"')

    # print validation result
    charErrorRate = numCharErr / numCharTotal
    wordAccuracy = numWordOK / numWordTotal
    print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.')
    return charErrorRate, wordAccuracy

# Create ground truth reference
gt = pd.read_csv('../../collectedData/labels.csv')
gt.text = gt.text.apply(lambda x: x.split())

In [4]:
def split(file):
    # configure tesseract
    ## oem 3 == Engine Mode: Default, based on what is available (Legacy, LSTM, or both)
    ## psm 3 == Page Segmentation Mode: Fully automatic page segmentation, but no OSD
    ## psm 11 == Page Segmentation Mode: Sparse text - find as much text as possible in no particular order, no OSD
    custom_config = r'--oem 3 --psm 11'

#     inPath = '../../collectedData/raw/'
    outParent = '../../collectedData/sliced/'

#     for file in os.listdir(inPath):    
    outPath = os.path.join(outParent,'.'.join(file.split('.')[:-1]))
    print(f'\nSplitting < {file} > -> < {outPath} >')
    if not os.path.exists(outPath):
        os.mkdir(outPath)

    img = cv2.imread(os.path.join(inPath,file))
    imgMarked = cv2.imread(os.path.join(inPath,file))
    imgGray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    imgBinary = cv2.threshold(imgGray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]

    details = pytesseract.image_to_data(imgBinary, output_type=Output.DICT, config=custom_config, lang='eng')
    total_boxes = len(details['text'])
    for sequence_number in range(total_boxes):
        if int(details['conf'][sequence_number]) > -1:
            (x, y, w, h) = (details['left'][sequence_number], details['top'][sequence_number], details['width'][sequence_number],  details['height'][sequence_number])
            imgBinary = cv2.rectangle(imgBinary, (x, y), (x + w, y + h), (0, 0, 0), 2)
            imgMarked = cv2.rectangle(imgMarked, (x, y), (x + w, y + h), (random.randint(0,255), random.randint(0,255), random.randint(0,255)), 2)
            crop_img = img[y:y+h, x:x+w]
            cv2.imwrite(os.path.join(outPath,str(sequence_number).zfill(len(str(total_boxes)))+'.png'), crop_img)

    cv2.imwrite(os.path.join(outParent,'_markings',str(file)), imgMarked)

In [5]:
def show(img, text='img'):
    cv2.imshow(str(text), img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [6]:
def mergeCnts(cnts):
    corners = []
    for c in cnts:
        pts = [list(x[0]) for x in c]
        xVals = [x[0] for x in pts]
        yVals = [x[1] for x in pts]
        xMin = min(xVals)
        xMax = max(xVals)
        yMin = min(yVals)
        yMax = max(yVals)

        for i, shape in enumerate(corners):
            left = shape[0][0,0]
            right = shape[1][0,0]
            top = shape[0][0,1]
            bottom = shape[-1][0,1]
            if (xMin >= left and xMin <= right) or (xMax >= left and xMax <= right):
                xMin = min(xMin,left)
                xMax = max(xMax,right)
                yMin = min(yMin,top)
                yMax = max(yMax,bottom)
                del corners[i]

        corners.append(np.array([[[xMin,yMin]],[[xMax,yMin]],[[xMax,yMax]],[[xMin,yMax]]]))
    
    return corners

In [7]:
def main(path):
    tf.compat.v1.reset_default_graph()

    sys.argv = []
    parser = argparse.ArgumentParser(sys.argv)
    parser.add_argument('--train', help='train the NN', action='store_true')
    parser.add_argument('--validate', help='validate the NN', action='store_true')
    parser.add_argument('--decoder', choices=['bestpath', 'beamsearch', 'wordbeamsearch'], default='bestpath',
                        help='CTC decoder')
    parser.add_argument('--batch_size', help='batch size', type=int, default=100)
    parser.add_argument('--data_dir', help='directory containing IAM dataset', type=Path, required=False)
    parser.add_argument('--fast', help='use lmdb to load images', action='store_true')
    parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
    args = parser.parse_args()
    
#     args.validate = True
    args.data_dir = Path(path)
    print(f'Predicting on slices in < {path} >')
    debug = False
    if debug:
        print("\n********")
        print(args)
        print(vars(args))
        print(args.data_dir)
        print("********\n")

    
    # set chosen CTC decoder
    if args.decoder == 'bestpath':
        decoderType = DecoderType.BestPath
    elif args.decoder == 'beamsearch':
        decoderType = DecoderType.BeamSearch
    elif args.decoder == 'wordbeamsearch':
        decoderType = DecoderType.WordBeamSearch

    # train or validate on IAM dataset
    if args.train or args.validate:
        # load training data, create TF model
        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.imgSize, Model.maxTextLen, args.fast)

        # save characters of model for inference mode
        open(FilePaths.fnCharList, 'w').write(str().join(loader.charList))

        # save words contained in dataset into file
        open(FilePaths.fnCorpus, 'w').write(str(' ').join(loader.trainWords + loader.validationWords))

        # execute training or validation
        if args.train:
            model = Model(loader.charList, decoderType)
            train(model, loader)
        elif args.validate:
            model = Model(loader.charList, decoderType, mustRestore=True)
            validate(model, loader)

    # infer text on test images
    else:
        df = pd.DataFrame(columns=['Actual','Guess','Conf','Result','Dist','Full Text','ResMatch','Path'])
        model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True, dump=args.dump)
        gt = pd.read_csv('../../collectedData/labels.csv')
        gt.text = gt.text.apply(lambda x: x.split())
        for sample in os.listdir(args.data_dir):
            if 'breakdowns' in sample:
                continue
            
            guess, conf = infer(model, os.path.join(args.data_dir,sample))
            
            actual = re.sub('((\W+\(\d*\))?\.((png)|(jpg)))','',sample)
            actual = actual.split(' ',1)[-1]
            result = actual==guess
            fullText = gt[gt.filename.str.contains(path.split('/')[-1])].iloc[0,1]
            resMatch = guess in fullText
            if len(actual)>0:
                dist = editdistance.eval(guess, actual)/len(actual)
            else:
                dist=len(guess)
#                 print('Actual:\t', actual)
#                 print('Guess:\t', guess)
#                 print('Conf:\t ', conf)
#                 print('\t', result)
#                 print()
            
            df = df.append(pd.Series([actual, guess, conf, result, dist, fullText, resMatch, os.path.join(args.data_dir,sample)], index = df.columns), ignore_index=True)
            
        return df

In [8]:
def infer1(model, img):
    scaleFactor = 20/max(img.shape[:2])
    width = int(img.shape[1] * scaleFactor)
    height = int(img.shape[0] * scaleFactor)
    dim = (width, height)

    scaled = cv2.resize(img,dim)

    canvas = np.zeros((28,28), dtype = "uint8")

    x = int((28 - scaled.shape[1])/2)
    y = int((28 - scaled.shape[0])/2)
    canvas[y:y+scaled.shape[0],x:x+scaled.shape[1]] = scaled

#     print(model.predict(canvas.reshape(1, 28, 28, 1)))
    return str(model.predict(canvas.reshape(1, 28, 28, 1)).argmax())

def breakdown(raw):
    model = keras.models.load_model('../model/charModel')
    
    blur = cv2.GaussianBlur(raw, (9, 9), 0)
    thresh = cv2.threshold(blur,0,255,cv2.THRESH_OTSU + cv2.THRESH_BINARY)[1]
    invert = cv2.bitwise_not(thresh)
    img = invert

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 1))
    dilate = cv2.dilate(img, kernel, iterations=5)

    contours, hierarchy = cv2.findContours(dilate, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = sorted(contours, key = cv2.contourArea, reverse = True)
    cnts = mergeCnts(cnts)

    guess = ''
    for i, c in enumerate(cnts):
        area = cv2.contourArea(c)
        if area > 10:
            x,y,w,h = cv2.boundingRect(c)
            ROI = 255 - raw[y:y+h, x:x+w]
            cv2.rectangle(raw, (x, y), (x + w, y + h), (random.randint(0,255),random.randint(0,255),random.randint(0,255)), 1)
            crop_img = img[y:y+h, x:x+w]           
            guess += infer1(model, crop_img)
                        
    if len(cnts) > 1:
        return(guess)
    else:
        return('')

def infer(model, imgPath):    
    "recognize text in image provided by file path"
    raw = cv2.imread(imgPath, cv2.IMREAD_GRAYSCALE)
    img = preprocess(raw, Model.imgSize)
    batch = Batch(None, [img])
    
    (recognized, probability) = model.inferBatch(batch, True)
    
    if probability < .3:
        newGuess = breakdown(raw)
        
        if len(newGuess)>0:
            probability[0] = -1
            recognized[0] = newGuess
        
#         print('Done', imgPath)
#         print(f'\tRecognized: "{recognized[0]}"')
#         print(f'\tProbability: {probability[0]}')
#         print()
        
    return (recognized[0],probability[0])

if __name__ == '__main__':
    out = {}
    #for path in [os.path.join('../../collectedData/sliced/',x) for x in os.listdir('../../collectedData/sliced/')]:
    inPath = '../../collectedData/raw/'
    for file in os.listdir(inPath):
        name = file[:-4]
        if '_markings' in file:
            continue
        
        split(file)
        try:  
            out[name] = main(os.path.join('../../collectedData/sliced/',name))
        except:
            print(f'\tERROR on: < {file} >')
#             traceback.print_exc()

SyntaxError: invalid syntax (<ipython-input-8-cb3970c5b870>, line 40)

In [None]:
for key in out.keys():
    print(key)
    df = out[key]
    actual = gt[gt.filename.str.contains(key)].iloc[0,1]
    orig = actual
    correct = []
    for word in df[df.ResMatch].Guess:
        if word in actual:
            correct.append(word)
            actual = actual[actual.index(word)+1:]
    print(correct)
    print(f'{len(correct)}/{len(orig)}: {len(correct)/len(orig)}')
    print()

#  Stats based on individually labeled slice folders

In [None]:
out['math1'] = main('../../collectedData/sliced/math1')

In [None]:
df = out['math1']

print(df.Result.value_counts())
print()

print(f'Accuracy: {sum(df.Result == True)/len(df.Result)}')
print(f'Character Error Rate: {df.Dist.mean()}')

In [None]:
plt.hist(df[df.Result==False].Conf, color='red', bins=20, alpha=0.75)
plt.hist(df[df.Result==True].Conf, color='green', bins=20, alpha=0.5)
plt.ylabel('# of Samples')
plt.xlabel('Confidence')
plt.legend(['Incorrect','Correct'])
plt.show()

In [None]:
%matplotlib inline

for tup in df.itertuples():
    im = Image.open(tup.Path)
    imshow(np.asarray(im))
    plt.show()
    print('Actual:\t', tup.Actual)
    print('Guess:\t', tup.Guess)
    print('Conf:\t ', tup.Conf)
    print('Dist:\t ', tup.Dist)
    print('\t', tup.Result)
    print('\n***************************************************************************************************************')