<a href="https://colab.research.google.com/github/MGJamJam/calamari_kurrent_model/blob/main/ColabNotebooks/LineSegmentationPageXML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import os
from PIL import Image, ImageDraw
import numpy as np
import xml.etree.ElementTree as ET
import copy

def parse_points(points_str):
    """Turn a PAGE XML ``points`` string into a list of integer (x, y) tuples.

    The string is split on whitespace into ``"x,y"`` tokens; each token is
    split on the comma and both parts are converted with ``int``.  A string
    with no tokens yields an empty list.  A token that does not consist of
    exactly two integer fields raises ``ValueError``.
    """
    raw_pairs = (token.split(',') for token in points_str.split())
    return [(int(raw_x), int(raw_y)) for raw_x, raw_y in raw_pairs]

def draw_polygon(image, points, rectangle=False):
    """Outline a shape in red (2 px wide) directly on ``image``.

    By default the polygon through ``points`` is outlined.  When
    ``rectangle`` is truthy, the axis-aligned bounding box of ``points`` is
    outlined instead.  The image is modified in place; nothing is returned.
    """
    canvas = ImageDraw.Draw(image)
    # Split the (x, y) pairs into per-axis lists for the bounding box.
    xs = [px for px, _ in points]
    ys = [py for _, py in points]
    if not rectangle:
        canvas.polygon(points, outline='red', width=2)
        return
    bounding_box = [min(xs), min(ys), max(xs), max(ys)]
    canvas.rectangle(bounding_box, outline="red", width=2)

def crop_polygon(image, polygon):
    """Cut the area enclosed by ``polygon`` out of ``image``.

    The alpha channel of the result is replaced by a polygon mask: 255 inside
    the polygon (outline included), 0 everywhere else.  The colour channels
    are left untouched.  The masked image is then cropped to the bounding box
    reported by ``Image.getbbox()``.

    Args:
        image: PIL image.  Images that are not already in ``RGBA`` mode are
            converted first, so RGB or greyscale input no longer fails when
            the alpha channel is written.
        polygon: sequence of ``(x, y)`` tuples in pixel coordinates of
            ``image``.

    Returns:
        ``(result, bbox)`` where ``bbox`` is the ``(left, upper, right,
        lower)`` tuple from ``getbbox()`` and ``result`` is the RGBA image
        cropped to it.  If ``getbbox()`` returns ``None``, ``bbox`` is
        ``None`` and ``result`` is the uncropped masked image.
    """
    # The alpha channel is written below, so four channels are required.
    if image.mode != "RGBA":
        image = image.convert("RGBA")

    # Rasterise the polygon into a single-channel mask of the same size.
    mask = Image.new('L', image.size, 0)
    ImageDraw.Draw(mask).polygon(polygon, outline=255, fill=255)

    # np.array() copies the pixel data, so the caller's image is not modified.
    pixels = np.array(image)
    pixels[:, :, 3] = np.array(mask)

    # A (height, width, 4) uint8 array is interpreted as RGBA by fromarray.
    result = Image.fromarray(pixels)
    bbox = result.getbbox()
    if bbox:
        result = result.crop(bbox)
    return result, bbox


def process_pagexml_draw_polygons(xml_path, output_dir):
    """Outline every TextLine of a PAGE XML file on its page image.

    The page image named in the ``imageFilename`` attribute of the Page
    element is loaded from the directory of ``xml_path``, the polygon of each
    TextLine is outlined with ``draw_polygon``, and the result is saved to
    ``output_dir`` as ``annotated_<image file name>``.

    Args:
        xml_path: path to the PAGE XML file.
        output_dir: existing directory that receives the annotated image.

    Raises:
        ValueError: if the document contains no Page element.
    """
    root = ET.parse(xml_path).getroot()

    # Qualified-name prefix ("{uri}") taken from the root tag.  It is empty
    # for documents without a namespace, so the paths below work either way.
    tag_prefix = root.tag[:root.tag.index('}') + 1] if root.tag.startswith('{') else ''

    # Find the Page element and get the image filename
    page_elem = root.find(f".//{tag_prefix}Page")
    if page_elem is None:
        raise ValueError(f"No Page element found in {xml_path}")
    image_filename = page_elem.attrib['imageFilename']

    # Assume the image is in the same directory as the XML.  convert() returns
    # a fully loaded copy, so the file handle can be closed right away.
    image_path = os.path.join(os.path.dirname(xml_path), image_filename)
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGBA")

    # Outline the polygon of every TextLine that has coordinates
    for textline in root.findall(f".//{tag_prefix}TextLine"):
        coords_elem = textline.find(f".//{tag_prefix}Coords")
        if coords_elem is None:
            continue
        draw_polygon(image, parse_points(coords_elem.attrib['points']))

    # Save the image with the drawn outlines
    output_image_path = os.path.join(output_dir, f"annotated_{os.path.basename(image_path)}")
    if os.path.splitext(output_image_path)[1].lower() in ('.jpg', '.jpeg'):
        # JPEG cannot store an alpha channel; saving an RGBA image would fail.
        image = image.convert("RGB")
    image.save(output_image_path)

def _shift_points(points, dx, dy):
    """Format ``points`` as a PAGE XML points string with ``(dx, dy)`` subtracted."""
    return ' '.join(f"{x - dx},{y - dy}" for x, y in points)


def process_pagexml_crop_polygons(xml_path, output_dir):
    """Write one cropped image and one PAGE XML file per TextLine.

    For every TextLine inside every TextRegion of the document, the line
    polygon is cut out of the page image with ``crop_polygon`` and saved as
    ``<image stem>_<textline id>.png`` in ``output_dir``.  A PAGE XML file
    with the same stem is written next to it.  That file contains only the
    enclosing TextRegion and the single TextLine, with the page size, the
    region coordinates and all line coordinates expressed relative to the
    cropped image.

    TextLines without an ``id``, without Coords, or whose polygon yields no
    bounding box are skipped with a warning.

    Args:
        xml_path: path to the PAGE XML file; the page image is looked up in
            the same directory.
        output_dir: existing directory that receives the output files.

    Raises:
        ValueError: if the document contains no Page element.
    """
    root = ET.parse(xml_path).getroot()

    # Qualified-name prefix ("{uri}") taken from the root tag.  It is empty
    # for documents without a namespace, so the paths below work either way.
    tag_prefix = root.tag[:root.tag.index('}') + 1] if root.tag.startswith('{') else ''
    page_path = f".//{tag_prefix}Page"
    region_path = f".//{tag_prefix}TextRegion"
    line_path = f".//{tag_prefix}TextLine"
    coords_path = f".//{tag_prefix}Coords"
    geometry_tags = (f"{tag_prefix}Coords", f"{tag_prefix}Baseline")

    #TODO get rid of the namespace name in the new elements
    # NOTE(review): ET.register_namespace('', <uri>) before writing would
    # presumably achieve this, but it changes process-wide state - confirm.

    # Find the Page element and get the image filename
    page_elem = root.find(page_path)
    if page_elem is None:
        raise ValueError(f"No Page element found in {xml_path}")
    image_filename = page_elem.attrib['imageFilename']
    # splitext drops only the real extension; str.strip('.png') would remove
    # any leading or trailing '.', 'p', 'n' and 'g' characters instead.
    image_stem = os.path.splitext(os.path.basename(image_filename))[0]

    # Assume the image is in the same directory as the XML.  convert() returns
    # a fully loaded copy, so the file handle can be closed right away.
    image_path = os.path.join(os.path.dirname(xml_path), image_filename)
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGBA")

    for text_region in root.findall(region_path):
        textregion_id = text_region.attrib.get('id', None)

        # Template document: keep only this TextRegion, without its TextLines.
        # NOTE(review): remove() only works on direct children, so nested
        # TextRegions or TextLines raise ValueError here, as before.
        root_template = copy.deepcopy(root)
        page = root_template.find(page_path)
        for region in root_template.findall(region_path):
            if region.attrib.get('id') == textregion_id:
                # TODO also remove the general unicode or set it to the textline translation later?
                for tl in region.findall(line_path):
                    region.remove(tl)
            else:
                page.remove(region)

        for textline in text_region.findall(line_path):
            textline_id = textline.attrib.get('id', None)

            # Log the ID for debugging
            print(f"Processing TextLine with ID: {textline_id}")

            # The ID is part of the output file names, so check it before
            # doing any cropping work.
            if textline_id is None:
                print("Warning: No ID found for TextLine. Skipping this element.")
                continue

            coords_elem = textline.find(coords_path)
            if coords_elem is None:
                print(f"Warning: No Coords found for TextLine {textline_id}. Skipping this element.")
                continue
            coords_points = parse_points(coords_elem.attrib['points'])

            # Crop the image to the polygon defined in the coords of the TextLine
            cropped_image, bbox = crop_polygon(image, coords_points)
            if bbox is None:
                print(f"Warning: Empty crop for TextLine {textline_id}. Skipping this element.")
                continue
            left, top, right, bottom = bbox
            width, height = right - left, bottom - top

            cropped_image_name = f"{image_stem}_{textline_id}"

            # New document describing the cropped image
            cropped_xml = copy.deepcopy(root_template)
            cropped_page = cropped_xml.find(page_path)
            cropped_page.set('imageWidth', str(width))
            cropped_page.set('imageHeight', str(height))
            cropped_page.set('imageFilename', f"{cropped_image_name}.png")

            # Copy the TextLine and move every Coords/Baseline inside it
            # (including any nested ones) into the crop's coordinate system.
            new_text_line = copy.deepcopy(textline)
            for elem in new_text_line.iter():
                if elem.tag in geometry_tags and 'points' in elem.attrib:
                    shifted = _shift_points(parse_points(elem.attrib['points']), left, top)
                    elem.set('points', shifted)

            # The TextRegion now spans the whole cropped image.
            region_coords = f"0,0 0,{height} {width},{height} {width},0"
            for region in cropped_xml.findall(region_path):
                if region.attrib.get('id') == textregion_id:
                    region_coords_elem = region.find(coords_path)
                    if region_coords_elem is not None:
                        region_coords_elem.set('points', region_coords)
                    region.append(new_text_line)

            print('About to save the files')
            # New files are saved as <filename_textline_id>.png / .xml
            output_path = os.path.join(output_dir, f"{cropped_image_name}.png")
            cropped_image.save(output_path)
            print(f"Saved: {output_path}")
            output_path_xml = os.path.join(output_dir, f"{cropped_image_name}.xml")
            ET.ElementTree(cropped_xml).write(output_path_xml, encoding='utf-8', xml_declaration=True)
            print(f"Saved: {output_path_xml}")

In [None]:
# Path to the PAGE XML file. The page image named in its Page element is
# looked up in the same directory as this file.
xml_path = "bayerische-gesandtschaft-paepstlicher-stuhl-180-1824.xml"
# Directory that receives the annotated page image and the per-line outputs.
output_dir = "."

# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

# Process the PAGE XML
# Outline every TextLine polygon and save the result as annotated_<image name>
process_pagexml_draw_polygons(xml_path, output_dir)
# Crop every TextLine into its own PNG with a matching PAGE XML file
process_pagexml_crop_polygons(xml_path, output_dir)