In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm_notebook as tqdm

**Let op**: De originele ALTO's en de nieuwe ground truth gebruiken verschillende namespaces voor de ALTO-tags:

* OCR: http://schema.ccs-gmbh.com/ALTO, heeft meerdere alternatieven
* GT: http://www.loc.gov/standards/alto/ns-v2#

In [None]:
# read the ground-truth (GT) files
# convert the GT file names to OCR file names
# read the OCR files
import os

from nlppln.utils import get_files

# Input directories (hard-coded, machine-specific paths): the ground-truth ALTO
# files and the original OCR ALTO files they are compared against.
gs_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Ground-truth/'
ocr_dir = '/home/jvdzwaan/ownCloud/Shared/OCR/Originele ALTOs/'


def gt_fname2ocr_fname(fname):
    """Return the OCR file name that belongs to a ground-truth file.

    Args:
        fname: path (or bare name) of a ground-truth file.

    Returns:
        The base name of ``fname`` (directory stripped) with every occurrence
        of ``'GT'`` replaced by ``'alto'``.
    """
    _directory, base_name = os.path.split(fname)
    return base_name.replace('GT', 'alto')

# List the ground-truth files and report how many there are.
gs_files = get_files(gs_dir)
print(len(gs_files))

# Show, then drop, the file with "extra" in its name: it is the same as the
# file without "extra" in the name.
print([path for path in gs_files if 'extra' in path])
gs_files = [path for path in gs_files if 'extra' not in path]
print(len(gs_files))

# Show the OCR file name derived from the first ground-truth file.
gt_fname2ocr_fname(gs_files[0])

In [None]:
# Pair every ground-truth file with its OCR counterpart.
#
# Ground-truth files without an OCR file are reported and dropped from
# `gs_files`, so that `gs_files[i]` and `ocr_files[i]` always describe the same
# page. Without this, a single missing OCR file would shift every later pair
# when the two lists are zipped or indexed together.
paired_gs_files = []
ocr_files = []
for gs_file in gs_files:
    ocr_file = os.path.join(ocr_dir, gt_fname2ocr_fname(gs_file))
    if os.path.isfile(ocr_file):
        paired_gs_files.append(gs_file)
        ocr_files.append(ocr_file)
    else:
        print('File not found:', ocr_file)
        print('GS file:', gs_file)
gs_files = paired_gs_files
print(len(ocr_files))

In [None]:
from collections import OrderedDict

from lxml import etree

def get_words_in_textlines(fname, alto_ns):
    """Collect the words of every ``TextLine`` element in an ALTO file.

    Args:
        fname: path of the ALTO XML file to parse.
        alto_ns: namespace prefix in lxml ``{uri}`` notation; it is prepended
            to the ``TextLine`` and ``String`` tag names.

    Returns:
        An ``OrderedDict`` mapping each ``TextLine`` ``ID`` to the list of
        words on that line, in document order. For a ``String`` whose
        ``SUBS_TYPE`` is ``HypPart1`` the ``SUBS_CONTENT`` attribute is used
        (presumably the complete hyphenated word), ``HypPart2`` strings are
        skipped, and every other ``String`` contributes its ``CONTENT``.

    Raises:
        KeyError: if a required ``ID``, ``CONTENT`` or ``SUBS_CONTENT``
            attribute is missing.
    """
    string_tag = alto_ns + 'String'
    words_per_line = OrderedDict()

    text_lines = etree.iterparse(fname, events=('end', ), tag=alto_ns + 'TextLine')
    for _event, text_line in text_lines:
        # The attributes must be read before the element is cleared below.
        line_id = text_line.attrib['ID']
        words = words_per_line[line_id] = []

        for child in text_line:
            if child.tag != string_tag:
                continue
            subs_type = child.attrib.get('SUBS_TYPE')
            if subs_type == 'HypPart1':
                words.append(child.attrib['SUBS_CONTENT'])
            elif subs_type != 'HypPart2':
                words.append(child.attrib['CONTENT'])

        # Make iteration over the parse events fast and consume less memory:
        # drop the processed element and the siblings that precede it.
        # https://www.ibm.com/developerworks/xml/library/x-hiperfparse
        text_line.clear()
        while text_line.getprevious() is not None:
            del text_line.getparent()[0]

    return words_per_line

#get_words_in_textlines('/home/jvdzwaan/ownCloud/Shared/OCR/Originele ALTOs/DDD_000010534_002_alto.xml', 
#             alto_ns='{http://schema.ccs-gmbh.com/ALTO}')

In [None]:
from ochre.matchlines import get_ns

# Load one ground-truth/OCR page pair to experiment with.
# The namespace is read from the very file that is parsed: files (the OCR ones
# in particular) do not all use the same ALTO namespace, so the namespace of
# another file may not match any tag.
ns = get_ns(gs_files[1])
gs_lines = get_words_in_textlines(gs_files[1], ns)
ns = get_ns(ocr_files[1])
ocr_lines = get_words_in_textlines(ocr_files[1], ns)

In [None]:
import edlib

from collections import OrderedDict

from ochre.matchlines import Match, count_unknown

def initialize_matches(gs_lines, ocr_lines, window=50):
    """Create one ``Match`` per ground-truth line and score nearby OCR lines.

    For the ground-truth line at position ``i`` the OCR lines at positions
    ``i - window`` up to (but excluding) ``i + window`` are candidates. The
    edit distance between the space-joined words of both lines is stored per
    candidate. ``edlib`` is only called when both texts are non-empty; if one
    of them is empty the distance is the length of the other text.

    Args:
        gs_lines: ordered mapping of ground-truth line id -> list of words.
        ocr_lines: ordered mapping of OCR line id -> list of words.
        window: number of OCR positions to look back and ahead; the default
            of 50 is the value that used to be hard-coded.

    Returns:
        List of ``Match`` objects in ground-truth order. Each has ``eds`` set
        to an ``OrderedDict`` (OCR line id -> edit distance) and is linked to
        its neighbours through ``previous`` and ``next`` where they exist.
    """
    matches = [Match(label, i) for i, label in enumerate(gs_lines.keys())]

    # Loop-invariant work: the list of OCR ids and the joined OCR texts used to
    # be rebuilt for every ground-truth line.
    ocr_ids = list(ocr_lines.keys())
    ocr_texts = {ocr_id: ' '.join(words) for ocr_id, words in ocr_lines.items()}

    last_index = len(matches) - 1
    for i, m in enumerate(matches):
        gs = ' '.join(gs_lines[m.gs_label])

        eds = OrderedDict()
        for ocr_id in ocr_ids[max(0, i - window):i + window]:
            ocr = ocr_texts[ocr_id]
            if gs and ocr:
                eds[ocr_id] = edlib.align(gs, ocr)['editDistance']
            else:
                # At least one text is empty: the distance is the length of
                # the other one (0 when both are empty).
                eds[ocr_id] = max(len(gs), len(ocr))
        m.eds = eds

        # Link the match to its neighbours.
        if i > 0:
            m.previous = matches[i - 1]
        if i < last_index:
            m.next = matches[i + 1]

    return matches


def set_zero(matches, used):
    """Assign OCR lines that are identical to their ground-truth line.

    For every match, the first OCR label (in ``eds`` order) with edit distance
    0 becomes the match's ``ocr_label`` and is appended to ``used``. Matches
    without such a label are left untouched, including matches whose ``eds``
    is empty (which previously raised ``ValueError`` from ``min()``).

    Args:
        matches: iterable of match objects with an ``eds`` mapping
            (OCR label -> edit distance).
        used: list of OCR labels that have been assigned; extended in place.
    """
    for m in matches:
        for ocr_label, distance in m.eds.items():
            if distance == 0:
                m.ocr_label = ocr_label
                used.append(ocr_label)
                break

In [None]:
# Build the initial matches (one per ground-truth line, with edit distances to
# the nearby OCR lines) for the page that is currently in gs_lines/ocr_lines.
matches = initialize_matches(gs_lines, ocr_lines)

In [None]:
from ochre.matchlines import UNKNOWN, EMPTY

def match_close(matches, ocr_lines, used):
    """Try the 'close' strategy on every match that is still unknown.

    Each match whose ``ocr_label`` equals ``UNKNOWN`` is asked for a label via
    ``get_match(..., method='close')``. The answer is stored on the match, and
    it is recorded in ``used`` unless it is ``UNKNOWN`` or ``EMPTY``.

    Args:
        matches: list of match objects; also handed to ``get_match``.
        ocr_lines: mapping of OCR line id -> list of words.
        used: list of OCR labels that have been assigned; extended in place.
    """
    for candidate in matches:
        if candidate.ocr_label != UNKNOWN:
            continue

        label = candidate.get_match(ocr_lines, matches, used, method='close')
        candidate.ocr_label = label

        is_real_label = not (label == UNKNOWN or label == EMPTY)
        if is_real_label:
            used.append(label)

In [None]:
from ochre.matchlines import print_match

# Print what print_match returns for every match of the current page.
for m in matches:
    print(print_match(m, gs_lines, ocr_lines))

In [None]:
# Inspect a single ground-truth line and the label the 'best' strategy picks.
#
# `used` was never defined at notebook level (only inside functions), so this
# cell failed with a NameError on a fresh kernel. It is rebuilt here from the
# labels that are already assigned, which is how the matching functions
# maintain it.
used = [x.ocr_label for x in matches
        if x.ocr_label != UNKNOWN and x.ocr_label != EMPTY]

# NOTE(review): index 259 and OCR line 'P2_TL00260' are specific to the page
# that was loaded while debugging; other pages may not have them.
m = matches[259]
print(print_match(m, gs_lines, ocr_lines))
lbl = m.get_match(ocr_lines, matches, used, method='best')
print(' '.join(ocr_lines['P2_TL00260']))
lbl

In [None]:
def gaps(matches):
    """Yield every run of consecutive matches whose label is ``UNKNOWN``.

    Args:
        matches: iterable of match objects with ``ocr_label`` and
            ``gs_label`` attributes.

    Yields:
        ``(unknowns, gs_labels)``: the matches of one run and their
        ground-truth labels, in order. A run that reaches the end of
        ``matches`` is yielded too (it used to be dropped silently).
    """
    unknowns = []
    gs_labels = []

    for m in matches:
        if m.ocr_label == UNKNOWN:
            unknowns.append(m)
            gs_labels.append(m.gs_label)
        elif unknowns:
            yield unknowns, gs_labels
            unknowns = []
            gs_labels = []

    # The page may end inside a gap.
    if unknowns:
        yield unknowns, gs_labels


def match_gaps(matches, used, ocr_lines):
    """Fill gaps for which exactly as many OCR options exist as lines.

    For every gap (see ``gaps``) the options of all its matches are collected
    in order of first appearance via ``get_options(ocr_lines, used)``. If the
    number of distinct options equals the number of lines in the gap, they
    are assigned one-to-one in order and appended to ``used``. Otherwise the
    gap is left untouched.

    The positional window (start/stop index in ``ocr_lines``) that used to be
    computed here has been removed: its result was never used, but computing
    it raised ``ValueError`` when a neighbouring match carried a label that is
    not an OCR line id, and it assigned a key instead of an index for a gap at
    the end of the page.

    Args:
        matches: list of match objects.
        used: list of OCR labels that have been assigned; extended in place.
        ocr_lines: mapping of OCR line id -> list of words.
    """
    for unknowns, _gs_labels in gaps(matches):
        candidates = []
        for unknown in unknowns:
            for option in unknown.get_options(ocr_lines, used).keys():
                if option not in candidates:
                    candidates.append(option)

        if len(candidates) == len(unknowns):
            for m, ocr_label in zip(unknowns, candidates):
                m.ocr_label = ocr_label
                used.append(ocr_label)

In [None]:
def repeat_match_best(matches, used, ocr_lines):
    """Apply the 'best' strategy repeatedly until it stops making progress.

    In every pass each match whose ``ocr_label`` equals ``UNKNOWN`` is asked
    for a label via ``get_match(..., method='best')``. The answer is stored on
    the match and recorded in ``used`` unless it is ``UNKNOWN`` or ``EMPTY``.
    Passes are repeated as long as the number of unknown matches decreases.

    A first pass is now also made when *every* match is still unknown.
    Previously the loop compared the unknown count with ``len(matches)`` and
    therefore never started in that case.

    Args:
        matches: list of match objects; also handed to ``get_match``.
        used: list of OCR labels that have been assigned; extended in place.
        ocr_lines: mapping of OCR line id -> list of words.
    """
    num_unknown = count_unknown(matches)

    while num_unknown > 0:
        for m in matches:
            if m.ocr_label == UNKNOWN:
                ocr_label = m.get_match(ocr_lines, matches, used, method='best')
                m.ocr_label = ocr_label
                if ocr_label != UNKNOWN and ocr_label != EMPTY:
                    used.append(ocr_label)

        remaining = count_unknown(matches)
        if remaining >= num_unknown:
            # No progress in this pass: further passes would not help.
            break
        num_unknown = remaining

In [None]:
import re

def edlib2pair(query: str, ref: str, mode: str = "NW") -> tuple:
    """Align ``query`` to ``ref`` with edlib and expand the CIGAR string.

    Args:
        query: sequence passed to ``edlib.align`` as the query.
        ref: sequence passed to ``edlib.align`` as the target/reference.
        mode: edlib alignment mode (default ``"NW"``).

    Returns:
        A tuple ``(ref_aln, match_aln, query_aln)`` with one entry per
        alignment column:

        * ``ref_aln``: list with the reference character of each column, or
          ``''`` where the query has a character the reference lacks
          (CIGAR ``I``);
        * ``match_aln``: string with ``'|'`` for equal characters (``=``),
          ``'.'`` for a mismatch (``X``) and ``' '`` for a gap (``I``/``D``);
        * ``query_aln``: list with the query character of each column, or
          ``''`` where the reference has a character the query lacks
          (CIGAR ``D``).

        Note the order: the reference comes first, the query last.

        In ``"NW"`` mode an empty ``query`` or ``ref`` is handled without
        calling edlib: every character of the other sequence is paired with a
        gap. (``initialize_matches`` in this notebook also avoids calling
        edlib with empty text.)
    """
    if mode == "NW" and (not query or not ref):
        # At most one of the two sequences contributes characters.
        width = len(query) + len(ref)
        return [''] * len(query) + list(ref), " " * width, list(query) + [''] * len(ref)

    alignment = edlib.align(query, ref, mode=mode, task="path")
    # Start of the aligned region in the reference.
    ref_pos = alignment["locations"][0][0] or 0
    query_pos = 0

    symbols = {"=": "|", "X": ".", "D": " ", "I": " "}
    ref_aln = []
    query_aln = []
    match_parts = []

    for count, code in re.findall(r"(\d+)(\D)", alignment["cigar"]):
        if code not in symbols:
            # Unknown operations are ignored.
            continue
        step = int(count)

        # Every operation except an insertion consumes reference characters.
        if code == "I":
            ref_aln.extend([''] * step)
        else:
            ref_aln.extend(ref[ref_pos:ref_pos + step])
            ref_pos += step

        # Every operation except a deletion consumes query characters.
        if code == "D":
            query_aln.extend([''] * step)
        else:
            query_aln.extend(query[query_pos:query_pos + step])
            query_pos += step

        match_parts.append(symbols[code] * step)

    return ref_aln, "".join(match_parts), query_aln

In [None]:
import edlib
import json

from nlppln.utils import create_dirs, out_file_name
from ochre.utils import align_characters

def align_page(matches, gs_lines=None, ocr_lines=None):
    """Align ground truth and OCR of a page character by character.

    For every match the space-joined ground-truth words are aligned with the
    space-joined words of the matched OCR line (the empty string when the
    match's label is not an OCR line id), and the aligned characters of all
    lines are concatenated.

    Args:
        matches: iterable of match objects with ``gs_label`` and
            ``ocr_label`` attributes.
        gs_lines: mapping of ground-truth line id -> list of words. When
            omitted, the notebook-level ``gs_lines`` is used (the original
            behaviour).
        ocr_lines: mapping of OCR line id -> list of words. When omitted, the
            notebook-level ``ocr_lines`` is used.

    Returns:
        ``(gs_result, ocr_result)``: two equally long lists of characters,
        where ``''`` marks a gap.
    """
    if gs_lines is None:
        gs_lines = globals()['gs_lines']
    if ocr_lines is None:
        ocr_lines = globals()['ocr_lines']

    gs_result = []
    ocr_result = []

    for m in matches:
        gs = ' '.join(gs_lines[m.gs_label])
        ocr = ' '.join(ocr_lines.get(m.ocr_label, ''))

        # edlib2pair(query, ref) returns (ref_aln, match_aln, query_aln). With
        # query=gs and ref=ocr the OCR alignment therefore comes FIRST.
        # Unpacking it as (gs, match, ocr) swapped the two results.
        ocr_a, _match_a, gs_a = edlib2pair(gs, ocr, mode='NW')
        if len(ocr_a) != len(gs_a):
            print('UNEQUAL!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

        for o, g in zip(ocr_a, gs_a):
            ocr_result.append(o)
            gs_result.append(g)

    return gs_result, ocr_result

def doc_id(fname):
    """Return the document id encoded in a file name.

    The id is the base name of ``fname`` up to (not including) its last
    underscore; a base name without an underscore is returned unchanged.
    """
    base_name = os.path.basename(fname)
    head, separator, _tail = base_name.rpartition('_')
    return head if separator else base_name

In [None]:
from ochre.matchlines import get_ns, replace_entities
from ochre.utils import get_temp_file

from nlppln.utils import create_dirs, out_file_name

def do_matching(gs_lines, ocr_lines):
    """Match the ground-truth lines of a page to its OCR lines.

    Runs the matching phases in order (exact matches, 'close' matches, gap
    filling, repeated 'best' matches) and prints the number of still-unknown
    lines after each phase.

    Args:
        gs_lines: ordered mapping of ground-truth line id -> list of words.
        ocr_lines: ordered mapping of OCR line id -> list of words.

    Returns:
        The list of match objects created by ``initialize_matches``, updated
        in place by the matching phases.
    """
    print('Number of lines in gs', len(gs_lines))
    print('Number of lines in ocr', len(ocr_lines))

    matches = initialize_matches(gs_lines, ocr_lines)

    # OCR labels that have been assigned so far; shared by all phases.
    used = []

    set_zero(matches, used)
    print('Unknown after set zero:', count_unknown(matches))

    match_close(matches, ocr_lines, used)
    print('Unknown after match close:', count_unknown(matches))

    match_gaps(matches, used, ocr_lines)
    print('Unknown after match gaps:', count_unknown(matches))

    repeat_match_best(matches, used, ocr_lines)
    # This message used to repeat "match gaps", which made the log misleading.
    print('Unknown after match best:', count_unknown(matches))

    return matches


# Output directory for the aligned pages (hard-coded, machine-specific path).
out_dir = '/home/jvdzwaan/data/kb-ocr/text-not-aligned/aligned'
create_dirs(out_dir)

# Pair files by name instead of by position: `ocr_files` only contains the OCR
# files that exist, so zipping it with `gs_files` pairs the wrong files as soon
# as a single OCR file is missing.
ocr_by_name = {os.path.basename(path): path for path in ocr_files}

for gs_file in tqdm(gs_files, total=len(gs_files)):
    ocr_file = ocr_by_name.get(gt_fname2ocr_fname(gs_file))
    if ocr_file is None:
        print('No OCR file for GS file, skipping:', gs_file)
        print()
        continue

    print(gs_file)
    print(ocr_file)

    # The ground truth is parsed from a temporary copy in which
    # replace_entities has been applied; the OCR file is parsed as-is.
    gs_tmp = get_temp_file()
    try:
        with open(gs_tmp, 'w') as f:
            f.write(replace_entities(gs_file))
        gs_lines = get_words_in_textlines(gs_tmp, get_ns(gs_file))
    finally:
        # Remove the temporary file even when writing or parsing fails.
        if os.path.exists(gs_tmp):
            os.remove(gs_tmp)

    ocr_lines = get_words_in_textlines(ocr_file, get_ns(ocr_file))
    print(len(gs_lines), len(ocr_lines))

    matches = do_matching(gs_lines, ocr_lines)

    # Writing the character alignment is currently disabled:
    #gs_a, ocr_a = align_page(matches)
    #out_file = out_file_name(out_dir, doc_id(gs_file), 'json')
    #print('Writing', out_file)
    #with open(out_file, 'w') as f:
    #    json.dump({'ocr': ocr_a, 'gs': gs_a}, f)
    print()