In [None]:
import xml.etree.ElementTree as ET
from PIL import Image, ImageDraw, ImageFilter
import os
import numpy as np
import subprocess
import torch
from transformers import VisionEncoderDecoderModel, AutoProcessor, TrOCRProcessor, AutoTokenizer, AutoModel
from torchvision import transforms
import sys

In [None]:
# Checkpoints used for recognition: the processor (image preprocessing and
# token decoding) is loaded from one checkpoint, the encoder-decoder weights
# from another.
PROCESSOR_CHECKPOINT = 'microsoft/trocr-base-handwritten'
MODEL_CHECKPOINT = 'dh-unibe/trocr-kurrent-XVI-XVII'

processor = AutoProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT)

In [None]:
# --- User settings: replace each 'path' placeholder before running ---

# Directory that holds the XML files and the page images.
# Each XML file contains the coordinates of the line masks and must have the
# same base name as its image (e.g. 12.xml next to 12.jpg). Base names are
# expected to be integers, because pages are processed in numeric order.
xml_directory = 'path'

# TXT file that receives all transcribed lines (new lines are appended).
output_txt_file = 'path'

# Cache file that stores the number of the last line that was transcribed and
# written. If the program is interrupted, the next run continues from there.
cache_path = 'path'

In [None]:
def transcribe_image(image):
    """Recognise the text on a single line image.

    Args:
        image: The line image handed to the module-level ``processor``.

    Returns:
        The decoded text of the first (and only) generated sequence, with
        special tokens removed.
    """
    encoded = processor(images=image, return_tensors="pt")
    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        token_ids = model.generate(encoded.pixel_values)
    decoded_batch = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded_batch[0]

# Kraken XML (ALTO v4 namespace) -- the reader that is currently active.
def get_coords_from_xml(xml_path, image_size):
    """Read the text-line polygons from a Kraken ALTO XML file, top to bottom.

    Only ``TextLine`` elements tagged ``TAGREFS="LINE_TYPE_1"`` are read.

    Args:
        xml_path: Path to the ALTO XML file.
        image_size: Size passed to ``Image.new`` for the scratch mask; the
            caller passes the ``size`` of the page image.

    Returns:
        A list of polygons, each a list of ``(x, y)`` integer tuples, ordered
        by the top edge of the polygon's bounding box inside the image.
        Lines without a usable polygon, and lines whose polygon leaves no
        pixel inside the image, are omitted.
    """
    namespace = {'alto': 'http://www.loc.gov/standards/alto/ns-v4#'}
    root = ET.parse(xml_path).getroot()
    coords_list = []
    for line in root.findall('.//alto:TextLine[@TAGREFS="LINE_TYPE_1"]', namespace):
        polygon = line.find('alto:Shape/alto:Polygon', namespace)
        # A line may lack a polygon or its POINTS attribute; skip it instead
        # of failing on a None value.
        points_attr = polygon.get('POINTS') if polygon is not None else None
        if not points_attr:
            continue
        # POINTS is a flat "x1 y1 x2 y2 ..." sequence: pair up the values.
        values = points_attr.split()
        coords = [(int(x), int(y)) for x, y in zip(values[::2], values[1::2])]
        if not coords:
            continue
        # Rasterise the polygon to learn where it sits inside the image.
        mask = Image.new('L', image_size, 0)
        ImageDraw.Draw(mask).polygon(coords, outline=1, fill=255)
        bbox = mask.getbbox()
        if bbox:
            coords_list.append((bbox, coords))
    # bbox is (left, upper, right, lower); sort on the upper edge.
    coords_list.sort(key=lambda entry: entry[0][1])
    return [coords for _, coords in coords_list]

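# Riksarkivet XML (PAGE format) variant of get_coords_from_xml: it reads the
# TextRegion polygons instead of the ALTO TextLine polygons. To use it,
# uncomment this definition and comment out the Kraken version above.
# Note that how_much() still counts ALTO lines, so the progress total would
# need the same change.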
# def get_coords_from_xml(xml_path, image_size):
#     tree = ET.parse(xml_path)
#     root = tree.getroot()
#     namespace = {'ns': 'http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15'}
#     coords_list = []
#     regions = root.findall('.//ns:TextRegion', namespace)
#     for region in regions:
#         coords_str = region.find('ns:Coords', namespace).get('points')
#         coords = [tuple(map(int, point.split(','))) for point in coords_str.split()]
#         mask = Image.new('L', image_size, 0)
#         draw = ImageDraw.Draw(mask)
#         draw.polygon(coords, outline=1, fill=255)
#         bbox = mask.getbbox()
#         if bbox:
#             coords_list.append((bbox, coords))
#     coords_list.sort(key=lambda x: x[0][1])
#     return [coords for _, coords in coords_list]

def extract_number_from_filename_xml(filename):
    """Return the integer that precedes the first dot of ``filename``.

    Raises:
        ValueError: If that leading part is not an integer.
    """
    leading_part, _, _ = filename.partition('.')
    return int(leading_part)

def get_average_background_color(image):
    """Return the mean colour of ``image`` as an ``(R, G, B)`` tuple of ints.

    The image is converted to RGB first and the mean is taken over all
    pixels; each channel mean is truncated to an integer.
    """
    rgb_pixels = np.array(image.convert('RGB'))
    # Average over rows and columns, leaving one value per colour channel.
    channel_means = rgb_pixels.mean(axis=(0, 1))
    return tuple(int(channel) for channel in channel_means)

def create_checkerboard_pattern(image_size, avg_color, square_size=8, blur_radius=5):
    """Build a blurred two-tone checkerboard derived from ``avg_color``.

    Args:
        image_size: Size of the pattern, indexed as ``[0]`` = width and
            ``[1]`` = height.
        avg_color: Base colour; the light squares are the base plus 70 and
            the dark squares the base minus 30, both clipped to 0..255.
        square_size: Edge length of one square in pixels.
        blur_radius: Radius of the Gaussian blur applied to the pattern.

    Returns:
        The blurred RGB pattern image.
    """
    width, height = image_size[0], image_size[1]
    center = np.array(avg_color)
    lighter = tuple(np.clip(center + 70, 0, 255).astype(int))
    darker = tuple(np.clip(center - 30, 0, 255).astype(int))

    board = Image.new('RGB', image_size, color=avg_color)
    painter = ImageDraw.Draw(board)
    for top in range(0, height, square_size):
        row_index = top // square_size
        for left in range(0, width, square_size):
            # Alternate the two tones by the parity of the grid position.
            on_light_square = (left // square_size + row_index) % 2 == 0
            painter.rectangle(
                [left, top, left + square_size, top + square_size],
                fill=lighter if on_light_square else darker,
            )
    return board.filter(ImageFilter.GaussianBlur(blur_radius))

# Most recently built background, kept so that consecutive lines of the same
# page reuse it instead of redrawing the full-page checkerboard every time.
_BACKGROUND_CACHE = {}


def _get_line_background(image_size, avg_color):
    """Return the RGBA checkerboard background for a page, building it at most once."""
    key = (tuple(image_size), tuple(avg_color))
    if key not in _BACKGROUND_CACHE:
        # Hold a single entry only: a full-page background is large.
        _BACKGROUND_CACHE.clear()
        _BACKGROUND_CACHE[key] = create_checkerboard_pattern(image_size, avg_color).convert('RGBA')
    return _BACKGROUND_CACHE[key]


def process_image_line(image, coords):
    """Cut one text line out of the page and put it on a checkerboard background.

    Pixels inside the polygon come from ``image``; everything else inside
    the polygon's bounding box is filled with a blurred checkerboard derived
    from the page's average colour.

    Args:
        image: The page image. It is composited with an RGBA background, so
            it is expected to be RGBA as well (final_script converts it).
        coords: The polygon outlining the line, as accepted by
            ``ImageDraw.polygon``.

    Returns:
        The cropped line as an RGB image, or ``None`` if the polygon leaves
        no pixel inside the image.
    """
    mask = Image.new('L', image.size, 0)
    ImageDraw.Draw(mask).polygon(coords, outline=1, fill=255)
    bbox = mask.getbbox()
    if not bbox:
        return None
    avg_color = get_average_background_color(image)
    background = _get_line_background(image.size, avg_color)
    # Compositing is a per-pixel operation, so cropping all three inputs
    # first gives the same pixels as cropping the composited page, without
    # copying the whole page for every line.
    line_image = Image.composite(image.crop(bbox), background.crop(bbox), mask.crop(bbox))
    return line_image.convert('RGB')

def how_much(xml_dir):
    """Count the text lines in all ``.xml`` files of ``xml_dir``.

    Only ALTO ``TextLine`` elements tagged ``TAGREFS="LINE_TYPE_1"`` are
    counted.

    Args:
        xml_dir: Directory to scan; other entries are ignored.

    Returns:
        The total number of matching lines over all XML files.
    """
    alto_ns = {'alto': 'http://www.loc.gov/standards/alto/ns-v4#'}
    line_query = './/alto:TextLine[@TAGREFS="LINE_TYPE_1"]'
    per_file_counts = (
        len(ET.parse(os.path.join(xml_dir, entry)).getroot().findall(line_query, alto_ns))
        for entry in os.listdir(xml_dir)
        if entry.endswith('.xml')
    )
    return sum(per_file_counts)

def _read_resume_count(cache_path):
    """Return the line count stored in the cache file, or 0 if there is none."""
    if not os.path.exists(cache_path):
        return 0
    with open(cache_path, 'r') as cache_file:
        try:
            return int(cache_file.read().strip())
        except ValueError:
            # Empty or unreadable cache content: start from the beginning.
            return 0


def _find_page_image(xml_dir, base_name, extensions=('.jpg', '.jpeg', '.png')):
    """Return the path of the first existing ``base_name + extension`` in ``xml_dir``."""
    for extension in extensions:
        candidate = os.path.join(xml_dir, base_name + extension)
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(
        f"No page image found for '{base_name}' in '{xml_dir}' "
        f"(tried: {', '.join(extensions)})"
    )


def final_script(xml_dir, output_txt_path, cache_path):
    """Transcribe every text line of every page and append the text to a file.

    Pages are the ``.xml`` files of ``xml_dir``, processed in the numeric
    order of their file names; all other directory entries are ignored.
    Each line is appended to the output file, and a blank separator follows
    each page. After every line the running line count is stored in the
    cache file, so an interrupted run resumes after the last stored line.
    Progress is printed on a single console line.

    Args:
        xml_dir: Directory with the XML files and their page images (same
            base name, ``.jpg`` preferred, then ``.jpeg``/``.png``).
        output_txt_path: Text file the transcriptions are appended to.
        cache_path: File holding the number of lines already written.

    Raises:
        FileNotFoundError: If a page has no image next to its XML file.
    """
    cached_num = _read_resume_count(cache_path)
    # Only XML files are pages; iterating every directory entry would hand
    # the images (and anything else stored there) to the XML parser.
    xml_files = sorted(
        (entry for entry in os.listdir(xml_dir) if entry.endswith('.xml')),
        key=extract_number_from_filename_xml,
    )
    total_lines = how_much(xml_dir)
    processed_lines = 0
    max_length = 80

    with open(output_txt_path, 'a', encoding='utf-8') as output_file:
        for xml_file in xml_files:
            base_name = os.path.splitext(xml_file)[0]
            # convert() returns an independent image, so the file can be closed.
            with Image.open(_find_page_image(xml_dir, base_name)) as page_file:
                image = page_file.convert('RGBA')
            coords_list = get_coords_from_xml(os.path.join(xml_dir, xml_file), image.size)

            transcribed_on_page = False
            for coords in coords_list:
                if processed_lines < cached_num:
                    # Already written by an earlier run.
                    processed_lines += 1
                    continue
                line_image = process_image_line(image, coords)
                transcription = transcribe_image(line_image)
                output_file.write(transcription + '\n')
                # Push the text out of the buffer before the cache claims it
                # was written; otherwise a killed process would lose lines
                # that the cache already counts.
                output_file.flush()
                processed_lines += 1
                transcribed_on_page = True
                with open(cache_path, 'w') as cache_file:
                    cache_file.write(str(processed_lines))
                current_line = f"{processed_lines}/{total_lines}: {transcription}"
                sys.stdout.write(f"\r{current_line.ljust(max_length)}")
                sys.stdout.flush()

            if processed_lines < cached_num:
                # Still catching up with the earlier run.
                continue
            if coords_list and not transcribed_on_page:
                # This page was completed by an earlier run, which also wrote
                # its separator; writing it again would duplicate it.
                continue
            output_file.write('\n\n')
            output_file.flush()

In [None]:
# Run the pipeline with the paths configured above.
final_script(
    xml_dir=xml_directory,
    output_txt_path=output_txt_file,
    cache_path=cache_path,
)