In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm_notebook as tqdm

In [None]:
from collections import OrderedDict

from lxml import etree

from ochre.matchlines import get_ns

def get_textblocks_lines(fname):
    """Collect the words of an ALTO XML file per text block and text line.

    The file is read incrementally with ``etree.iterparse`` so that only one
    ``TextBlock`` element has to be held in memory at a time.

    Hyphenated words are stored once: for a ``String`` whose ``SUBS_TYPE`` is
    ``HypPart1`` the value of ``SUBS_CONTENT`` is used, and a ``String`` whose
    ``SUBS_TYPE`` is ``HypPart2`` is skipped. For every other ``String`` the
    ``CONTENT`` attribute is used.

    Args:
        fname: Path of the ALTO XML file.

    Returns:
        dict mapping the ``ID`` of every ``TextBlock`` to an ``OrderedDict``
        that maps the ``ID`` of each ``TextLine`` in that block (in document
        order) to the list of words on that line.

    Raises:
        KeyError: If an element lacks a required attribute (``ID``,
            ``CONTENT`` or, for ``HypPart1``, ``SUBS_CONTENT``).
    """
    # get_ns() supplies the prefix that is concatenated with the tag names.
    alto_ns = get_ns(fname)
    text_line_tag = alto_ns + 'TextLine'
    string_tag = alto_ns + 'String'

    blocks = {}

    context = etree.iterparse(fname, events=('end', ), tag=(alto_ns+'TextBlock'))
    for _event, elem in context:
        lines = OrderedDict()
        blocks[elem.attrib['ID']] = lines

        # Iterate over the element directly; getchildren() is deprecated.
        for line in elem:
            if line.tag != text_line_tag:
                continue

            words = []
            lines[line.attrib['ID']] = words
            for string in line:
                if string.tag != string_tag:
                    continue

                subs_type = string.attrib.get('SUBS_TYPE')
                if subs_type == 'HypPart1':
                    words.append(string.attrib['SUBS_CONTENT'])
                elif subs_type != 'HypPart2':
                    words.append(string.attrib['CONTENT'])

        # Make iteration over context fast and consume less memory:
        # https://www.ibm.com/developerworks/xml/library/x-hiperfparse
        elem.clear()
        while elem.getprevious() is not None:
            del elem.getparent()[0]

    return blocks

# Example page used for interactive inspection. Original note: "blocks equal
# length" (presumably the OCR and ground-truth files contain the same number
# of text blocks -- compare the counts printed below).
in_file_ocr = '/home/jvdzwaan/ownCloud/Shared/OCR/Originele ALTOs/DDD_010017911_002_alto.xml'
in_file_gt = '/home/jvdzwaan/ownCloud/Shared/OCR/Ground-truth/DDD_010017911_002_GT.xml'

# Alternative example page. Original note: "blocks not equal length".
# Uncomment these two lines to inspect that page instead.
#in_file_ocr = '/home/jvdzwaan/ownCloud/Shared/OCR/Originele ALTOs/DDD_010007697_002_alto.xml'
#in_file_gt = '/home/jvdzwaan/ownCloud/Shared/OCR/Ground-truth/DDD_010007697_002_GT.xml'


# Parse both files into {text block id: {text line id: [words]}}.
blocks_gs = get_textblocks_lines(in_file_gt)
blocks_ocr = get_textblocks_lines(in_file_ocr)

# Number of text blocks in the ground truth and in the OCR output.
print(len(blocks_gs), len(blocks_ocr))

In [None]:
# Display the parsed ground-truth text blocks.
blocks_gs

In [None]:
# Display the parsed OCR text blocks.
blocks_ocr

In [None]:
import edlib

from collections import OrderedDict

from ochre.matchlines import Match, count_unknown

def initialize_matches(gs_lines, ocr_lines, window=50):
    """Create one Match per ground-truth line and fill in its edit distances.

    For the ground-truth line at position ``i`` the edit distance is computed
    to the OCR lines at positions ``max(0, i - window)`` up to (but not
    including) ``i + window``. Lines are compared as the space-joined text of
    their words. edlib is only called when both texts are non-empty;
    otherwise the distance is the length of the non-empty text (0 when both
    are empty).

    Args:
        gs_lines: Ordered mapping of ground-truth line id to list of words.
        ocr_lines: Ordered mapping of OCR line id to list of words.
        window: Number of OCR line positions to look back and ahead
            (default 50, the value that used to be hard-coded).

    Returns:
        list of Match objects, in ground-truth order. Each has ``eds`` set to
        an OrderedDict of OCR line id -> edit distance, and ``previous`` /
        ``next`` set to its neighbours where they exist.
    """
    matches = [Match(label, i) for i, label in enumerate(gs_lines.keys())]

    # Loop-invariant work: the order of the OCR ids and the joined OCR texts
    # do not depend on the ground-truth line being processed.
    ocr_ids = list(ocr_lines.keys())
    ocr_texts = {ocr_id: ' '.join(words) for ocr_id, words in ocr_lines.items()}

    last = len(matches) - 1
    for i, m in enumerate(matches):
        gs = ' '.join(gs_lines[m.gs_label])

        eds = OrderedDict()
        for ocr_id in ocr_ids[max(0, i-window):i+window]:
            ocr = ocr_texts[ocr_id]
            if len(gs) != 0 and len(ocr) != 0:
                eds[ocr_id] = edlib.align(gs, ocr)['editDistance']
            else:
                # At least one text is empty, so the distance is simply the
                # length of the other one.
                eds[ocr_id] = max(len(gs), len(ocr))
        m.eds = eds

        # Link the match to its neighbours.
        if i > 0:
            m.previous = matches[i-1]
        if i < last:
            m.next = matches[i+1]

    return matches


def set_zero(matches, used):
    """Assign OCR lines that are identical to their ground-truth line.

    For every match whose smallest edit distance is 0, the first OCR label
    with distance 0 becomes the match's ``ocr_label`` and is appended to
    ``used``. Matches without any edit distances are reported on stdout.

    Args:
        matches: Iterable of match objects with an ``eds`` mapping of OCR
            label -> edit distance.
        used: List of assigned OCR labels; extended in place.
    """
    for match in matches:
        if len(match.eds) == 0:
            print('zero length eds', match)
            continue

        distances = list(match.eds.values())
        if min(distances) != 0:
            continue

        # Position of the first exact hit, translated back to its label.
        exact_label = list(match.eds.keys())[distances.index(0)]
        used.append(exact_label)
        match.ocr_label = exact_label

In [None]:
from ochre.matchlines import UNKNOWN, EMPTY

def match_close(matches, ocr_lines, used):
    """Match the still-unknown lines with the 'close' method.

    Every match whose ``ocr_label`` is UNKNOWN is asked for a match with
    ``method='close'``. If afterwards every match is still UNKNOWN, all
    matches are asked again with ``method='best'``.

    Args:
        matches: List of Match objects; ``ocr_label`` is updated in place.
        ocr_lines: Mapping of OCR line id to words, passed on to
            ``Match.get_match``.
        used: List of assigned OCR labels; extended in place with every
            label that is neither UNKNOWN nor EMPTY.
    """
    def assign(match, method):
        # Ask the match for a label and record it when it is a real line.
        label = match.get_match(ocr_lines, matches, used, method=method)
        match.ocr_label = label
        if label != UNKNOWN and label != EMPTY:
            used.append(label)

    for match in matches:
        if match.ocr_label == UNKNOWN:
            assign(match, 'close')

    if count_unknown(matches) != len(matches):
        return

    # Nothing could be matched at all: fall back to the 'best' method.
    for match in matches:
        assign(match, 'best')
        

In [None]:
def gaps(matches):
    """Yield every run of consecutive matches whose OCR label is UNKNOWN.

    Args:
        matches: Iterable of match objects with ``ocr_label`` and
            ``gs_label`` attributes.

    Yields:
        Tuples ``(run, labels)``: the list of matches in the run and the list
        of their ground-truth labels, in order.
    """
    run = []
    labels = []

    for match in matches:
        if match.ocr_label == UNKNOWN:
            run.append(match)
            labels.append(match.gs_label)
            continue

        # A known match closes the run that was being collected (if any).
        if run:
            yield run, labels
            run, labels = [], []

    # A run that extends to the last match is still pending here.
    if run:
        yield run, labels

def match_gaps(matches, used, ocr_lines):
    """Try to fill runs of unmatched ground-truth lines with OCR lines.

    For every run of consecutive UNKNOWN matches (as produced by ``gaps``):

    * ``options`` is the list of OCR line ids located between the OCR line of
      the match before the run and the OCR line of the match after the run;
    * ``p_options`` is the list of distinct candidate ids reported by
      ``get_options`` for the matches in the run, in first-seen order.

    The run is only filled when the number of candidates equals the number of
    lines in the run and none of ``options`` already occurs in ``used``. In
    that case the candidates are assigned to the run's matches in order.

    Args:
        matches: Sequence of Match objects; ``ocr_label`` is updated in place.
        used: List of OCR line ids already assigned; extended in place.
        ocr_lines: Mapping of OCR line id to words. Its key order defines the
            OCR line positions; it is also passed on to ``get_options``.

    NOTE(review): ``list.index`` raises ValueError when the label of the
    neighbouring match is not a key of ``ocr_lines``. A run ends at the first
    match that is not UNKNOWN, so this presumably can happen when that
    neighbour is EMPTY -- confirm against the Match implementation.
    """
    for unknowns, gs_labels in gaps(matches):
        if len(unknowns) > 0:
            if len(unknowns) > 0 and len(unknowns) == len(gs_labels):
                # Start right after the OCR line matched to the line that
                # precedes the run (or at the first OCR line).
                if unknowns[0].previous is not None:
                    start_index = list(ocr_lines.keys()).index(unknowns[0].previous.ocr_label) + 1
                else:
                    start_index = 0

                # Stop at the OCR line matched to the line after the run.
                if unknowns[-1].next is not None:
                    stop_index = list(ocr_lines.keys()).index(unknowns[-1].next.ocr_label)
                else:
                    # NOTE(review): ``index`` is presumably the position of
                    # the ground-truth line, used here as an OCR position --
                    # confirm this is intended.
                    stop_index = unknowns[-1].index + 1

                # OCR ids in the range; positions past the end are ignored.
                options = []
                for i in range(start_index, stop_index):
                    try:
                        options.append(list(ocr_lines.keys())[i])
                    except IndexError:
                        pass

                # Distinct candidates of all matches in the run, in order.
                p_options = []
                for u in unknowns:
                    for o in u.get_options(ocr_lines, used).keys():
                        if o not in p_options:
                            p_options.append(o)

                if len(p_options) == len(gs_labels):
                    # Only assign when no OCR line in the range is taken yet.
                    if len(set(options).intersection(set(used))) == 0:
                        for m, ocr_label in zip(unknowns, p_options):
                            m.ocr_label = ocr_label
                            used.append(ocr_label)

In [None]:
def repeat_match_best(matches, used, ocr_lines):
    """Apply the 'best' matching method until it stops making progress.

    A pass asks every UNKNOWN match for a match with ``method='best'``.
    Passes are repeated as long as the number of unknown matches dropped
    during the previous pass. The first pass only runs when fewer matches
    are unknown than there are matches in total.

    Args:
        matches: List of Match objects; ``ocr_label`` is updated in place.
        used: List of assigned OCR labels; extended in place with every
            label that is neither UNKNOWN nor EMPTY.
        ocr_lines: Mapping of OCR line id to words, passed on to
            ``Match.get_match``.
    """
    unknown_before = len(matches)
    unknown_now = count_unknown(matches)

    while unknown_now < unknown_before:
        for match in matches:
            if match.ocr_label != UNKNOWN:
                continue

            label = match.get_match(ocr_lines, matches, used, method='best')
            match.ocr_label = label
            if label != UNKNOWN and label != EMPTY:
                used.append(label)

        unknown_before, unknown_now = unknown_now, count_unknown(matches)

In [None]:
import re

def edlib2pair(query: str, ref: str, mode: str = "NW") -> tuple:
    """Align ``query`` to ``ref`` with edlib and return the aligned characters.

    The CIGAR string returned by ``edlib.align(..., task="path")`` is expanded
    into two equally long lists of characters, in which a gap is represented
    by the empty string ``''``.

    Example (gaps shown as ``''``)::

        query = "ab", ref = "b", cigar = "1I1="
        ref_aln   = ['',  'b']
        match_aln = ' |'
        query_aln = ['a', 'b']

    Args:
        query: Sequence to align.
        ref: Sequence to align against.
        mode: edlib alignment mode, passed on to ``edlib.align``.

    Returns:
        Tuple ``(ref_aln, match_aln, query_aln)``: the list of aligned
        reference characters, a string with one marker per alignment column
        (``'|'`` match, ``'.'`` mismatch, ``' '`` gap) and the list of aligned
        query characters.
    """
    a = edlib.align(query, ref, mode=mode, task="path")

    # Start of the alignment in ref; treat a missing start location as 0.
    ref_pos = a["locations"][0][0] or 0
    query_pos = 0

    ref_aln = []
    query_aln = []
    markers = []

    for step, code in re.findall(r"(\d+)(\D)", a["cigar"]):
        step = int(step)
        if code == "=" or code == "X":
            # (Mis)match: both sequences advance.
            ref_aln.extend(ref[ref_pos:ref_pos + step])
            query_aln.extend(query[query_pos:query_pos + step])
            ref_pos += step
            query_pos += step
            markers.append(("|" if code == "=" else ".") * step)
        elif code == "D":
            # Characters only present in ref: gap in query.
            ref_aln.extend(ref[ref_pos:ref_pos + step])
            query_aln.extend([''] * step)
            ref_pos += step
            markers.append(" " * step)
        elif code == "I":
            # Characters only present in query: gap in ref.
            ref_aln.extend([''] * step)
            query_aln.extend(query[query_pos:query_pos + step])
            query_pos += step
            markers.append(" " * step)
        # Any other operation code is ignored.

    return ref_aln, "".join(markers), query_aln

In [None]:
import edlib
import json

from nlppln.utils import create_dirs, out_file_name
from ochre.utils import align_characters

def align_page(matches, gs_lines=None, ocr_lines=None):
    """Align the matched lines of a page character by character.

    For every match the ground-truth line and its OCR line (empty when the
    match's ``ocr_label`` is not a key of ``ocr_lines``) are joined with
    spaces and aligned with ``edlib2pair``. When one of the two texts is
    empty, the other one is paired with empty strings instead, so edlib is
    never called with an empty sequence.

    Args:
        matches: Iterable of match objects with ``gs_label`` and
            ``ocr_label`` attributes.
        gs_lines: Mapping of ground-truth line id to list of words. When
            omitted, the notebook-level variable ``gs_lines`` is used.
        ocr_lines: Mapping of OCR line id to list of words. When omitted,
            the notebook-level variable ``ocr_lines`` is used.

    Returns:
        Tuple ``(gs_result, ocr_result)`` of two equally long lists of
        characters, in which ``''`` marks a gap.

    Raises:
        NameError: If a mapping is neither passed nor defined globally.
    """
    def notebook_global(name):
        # Backward compatibility: this function used to read both mappings
        # from the notebook's global namespace.
        try:
            return globals()[name]
        except KeyError:
            raise NameError("name '{}' is not defined".format(name))

    if gs_lines is None:
        gs_lines = notebook_global('gs_lines')
    if ocr_lines is None:
        ocr_lines = notebook_global('ocr_lines')

    ocr_result = []
    gs_result = []

    for m in matches:
        print(m)

    for m in matches:
        gs = ' '.join(gs_lines[m.gs_label])
        ocr = ' '.join(ocr_lines.get(m.ocr_label, ''))

        print(' GS:', repr(gs))
        print('OCR:', repr(ocr))
        print(len(gs), len(ocr))

        if len(gs) == 0:
            # Only OCR text (or nothing at all): pair it with gaps.
            ocr_result.extend(ocr)
            gs_result.extend([''] * len(ocr))
        elif len(ocr) == 0:
            # Only ground-truth text: pair it with gaps.
            gs_result.extend(gs)
            ocr_result.extend([''] * len(gs))
        else:
            gs_a, match_a, ocr_a = edlib2pair(ocr, gs, mode='NW')
            if len(ocr_a) != len(gs_a):
                print('UNEQUAL!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

            for o, g in zip(ocr_a, gs_a):
                ocr_result.append(o)
                gs_result.append(g)

    return gs_result, ocr_result

In [None]:
def align_block(gs_text, ocr_text):
    """Align the ground-truth and OCR text of one text block.

    When both texts are non-empty they are aligned with ``edlib2pair``.
    Otherwise every character of the non-empty text (if any) is paired with
    an empty string, and the result is reported on stdout.

    Args:
        gs_text: Ground-truth text of the block.
        ocr_text: OCR text of the block.

    Returns:
        Tuple ``(gs_a, ocr_a)`` of two equally long lists of characters, in
        which ``''`` marks a gap.
    """
    both_present = gs_text != '' and ocr_text != ''
    if both_present:
        gs_a, _markers, ocr_a = edlib2pair(ocr_text, gs_text, mode='NW')
        return gs_a, ocr_a

    print('Empty stuff!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    # At least one text is empty here, so on each side exactly one of the two
    # terms contributes: the characters themselves or the matching gaps.
    gs_a = list(gs_text) + [''] * len(ocr_text)
    ocr_a = [''] * len(gs_text) + list(ocr_text)

    print('Returning')
    print(gs_a)
    print(ocr_a)

    return gs_a, ocr_a

In [None]:
from ochre.matchlines import get_ns, replace_entities
from ochre.utils import get_temp_file

from nlppln.utils import create_dirs, out_file_name

def do_matching(gs_lines, ocr_lines):
    """Match the ground-truth lines of a text block to its OCR lines.

    Runs the matching steps in order: initialise the matches with their edit
    distances, assign exact (distance 0) matches, and then apply the 'close'
    matching step. The ``match_gaps`` and ``repeat_match_best`` steps are
    not part of this pipeline.

    Args:
        gs_lines: Ordered mapping of ground-truth line id to list of words.
        ocr_lines: Ordered mapping of OCR line id to list of words.

    Returns:
        The list of Match objects produced by ``initialize_matches``, updated
        by the subsequent steps.
    """
    line_matches = initialize_matches(gs_lines, ocr_lines)

    # OCR line ids that have been assigned; shared by the matching steps.
    assigned_ocr_ids = []
    set_zero(line_matches, assigned_ocr_ids)
    match_close(line_matches, ocr_lines, assigned_ocr_ids)

    return line_matches

In [None]:
import os
import json
import glob

from nlppln.utils import get_files, create_dirs, out_file_name

from ochre.utils import get_temp_file
from ochre.matchlines import replace_entities

# For every json file with text block matches:
# - load the json (ground-truth text block id -> OCR text block id)
# - find the corresponding ground-truth and OCR ALTO files
# - build, per file, a data structure of text block id -> ordered dict of
#   text line id -> list of words
# - match and align the lines of every pair of matched text blocks and write
#   the ground-truth text, the OCR text and the character alignment to disk.

# NOTE(review): these are absolute paths on the original author's machine;
# adjust them before running this cell elsewhere.
json_dir = '/home/jvdzwaan/data/kb-ocr/textblock_matches-original-altos/'
gs_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Ground-truth/'
ocr_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Originele ALTOs/'

out_dir = '/home/jvdzwaan/data/kb-ocr/text_aligned_blocks-match_gs'

json_files = get_files(json_dir)
#json_files = ['/home/jvdzwaan/data/kb-ocr/textblock_matches-original-altos/DDD_010017911_002.json']
print(len(json_files))
for in_file in tqdm(json_files):
    print(in_file)
    bn = os.path.splitext(os.path.basename(in_file))[0]

    # Resolve and read all inputs before the temporary file is created, so a
    # missing input cannot leave a temporary file behind.
    gs_file = glob.glob('{}*'.format(os.path.join(gs_dir, bn)))[0]
    ocr_file = glob.glob('{}*'.format(os.path.join(ocr_dir, bn)))[0]

    with open(in_file) as f:
        tb_matches = json.load(f)

    # The ground truth is parsed from a temporary copy produced by
    # replace_entities(); the copy is removed even when parsing fails.
    gs_tmp = get_temp_file()
    try:
        with open(gs_tmp, 'w') as f:
            f.write(replace_entities(gs_file))
        gs_blocks = get_textblocks_lines(gs_tmp)
    finally:
        if os.path.exists(gs_tmp):
            os.remove(gs_tmp)

    ocr_blocks = get_textblocks_lines(ocr_file)

    print('# textblocks', len(gs_blocks), len(ocr_blocks))

    gs_text = []
    ocr_text = []
    gs_aligned = []
    ocr_aligned = []

    for gs_block_id, ocr_block_id in tb_matches.items():
        # Ground-truth blocks without a matching OCR block are skipped.
        if ocr_block_id is not None:
            gs_lines = gs_blocks[gs_block_id]
            ocr_lines = ocr_blocks[ocr_block_id]
            matches = do_matching(gs_lines, ocr_lines)

            gs_to_align = []
            ocr_to_align = []

            # Save gs text, ocr text and the text to align, for the lines
            # that were matched to an actual OCR line.
            for m in matches:
                if m.ocr_label != UNKNOWN and m.ocr_label != EMPTY:
                    gs_line = ' '.join(gs_lines[m.gs_label])
                    ocr_line = ' '.join(ocr_lines[m.ocr_label])

                    gs_text.append(gs_line)
                    ocr_text.append(ocr_line)

                    gs_to_align.append(gs_line)
                    ocr_to_align.append(ocr_line)

            gs_a, ocr_a = align_block(' '.join(gs_to_align), ' '.join(ocr_to_align))
            gs_aligned.append(gs_a)
            ocr_aligned.append(ocr_a)

    # write gs text
    out_file = out_file_name(os.path.join(out_dir, 'gs'), in_file, 'txt')
    print(out_file)
    create_dirs(out_file, is_file=True)
    with open(out_file, 'w') as f:
        f.write(' '.join(gs_text))

    # write ocr text
    out_file = out_file_name(os.path.join(out_dir, 'ocr'), in_file, 'txt')
    print(out_file)
    create_dirs(out_file, is_file=True)
    with open(out_file, 'w') as f:
        f.write(' '.join(ocr_text))

    # write aligned: concatenate the per-block alignments
    gs_a = [c for tb in gs_aligned for c in tb]
    ocr_a = [c for tb in ocr_aligned for c in tb]

    # Sanity check: processing stops at the first file whose alignment is
    # inconsistent, so that it can be inspected.
    if len(gs_a) != len(ocr_a):
        print('unequal lengths for aligned')
        break

    out_file = out_file_name(os.path.join(out_dir, 'aligned'), in_file, 'json')
    print(out_file)
    create_dirs(out_file, is_file=True)
    with open(out_file, 'w') as f:
        json.dump({'ocr': ocr_a, 'gs': gs_a}, f)

    print()