In [98]:
import pandas as pd
import os
import json
import cv2
from PIL import Image
import sys
from lxml import etree

sys.path.append("..")
from kiebids.utils import crop_image

In [99]:
# Root folder of the TensorBoard logs and the timestamped run to analyse.
evaluation_path = "../data/evaluation/tensorboard"
# Named `run_id` (not `id`) so the builtin `id()` is not shadowed for the rest of the kernel session.
run_id = "20241220-112659"
experiment_path = f"{evaluation_path}/{run_id}"

In [111]:
def convert_tb_data(root_dir, sort_by=None):
    """Convert local TensorBoard scalar data into a wide Pandas DataFrame.

    Recursively walks ``root_dir`` and reads every file whose name contains
    ``events.out.tfevents``. Events without summary values are skipped; for
    the others only the first summary value is kept. The long-format records
    ``[wall_time, name, step, value]`` are then pivoted into one row per
    ``step`` and one column per summary tag.

    *Note* that the whole data is converted into a DataFrame.
    Depending on the data size this might take a while. If it takes
    too long then narrow it to some sub-directories.

    Parameters:
        root_dir: (str) path to root dir with tensorboard data.
        sort_by: (optional str) currently unused; kept for backward
            compatibility. Rows are always ordered by ascending ``step``
            because ``pd.pivot_table`` sorts its index.

    Returns:
        pandas.DataFrame with one column per summary tag and one row per
        step in ascending order. The ``step`` column itself is dropped, so
        the index is the positional range 0..n-1, not the step value. Tags
        not logged at a given step are NaN. Duplicate (step, tag) entries are
        averaged (``pd.pivot_table``'s default aggregation is the mean).

    Raises:
        ValueError: if no event file is found under ``root_dir`` or none of
            the event files contain summary values.
    """
    import os

    columns_order = ["wall_time", "name", "step", "value"]

    # Collect the event files first so that an empty or mistyped directory
    # fails fast with a clear message, before the heavy TensorFlow import.
    event_files = [
        os.path.join(root, filename)
        for root, _, filenames in os.walk(root_dir)
        for filename in filenames
        if "events.out.tfevents" in filename
    ]
    if not event_files:
        raise ValueError(f"No 'events.out.tfevents' files found under {root_dir!r}")

    from tensorflow.python.summary.summary_iterator import summary_iterator

    def parse_tfevent(tfevent):
        # Only the first summary value of each event is recorded.
        return dict(
            wall_time=tfevent.wall_time,
            name=tfevent.summary.value[0].tag,
            step=tfevent.step,
            value=float(tfevent.summary.value[0].simple_value),
        )

    def convert_tfevent(filepath):
        return pd.DataFrame(
            [
                parse_tfevent(e)
                for e in summary_iterator(filepath)
                if len(e.summary.value)
            ]
        )

    # Files without summary values yield column-less frames; drop them so the
    # column selection below cannot fail with a KeyError.
    frames = [convert_tfevent(path) for path in event_files]
    frames = [frame for frame in frames if not frame.empty]
    if not frames:
        raise ValueError(f"Event files under {root_dir!r} contain no summary values")

    # Concatenate all partial individual dataframes
    all_df = pd.concat(frames)[columns_order]

    pivoted = (
        pd.pivot_table(all_df, columns="name", values=["value"], index=["step"])
        .droplevel(0, axis=1)
        .reset_index()
        # NOTE(review): the step is discarded here, so later cells treat the
        # positional row index as `image_index`; this assumes steps are 0..n-1.
        .drop(columns=["step"])
        .rename_axis(None)
    )
    pivoted.columns.name = None
    return pivoted


def map_image_to_index(path):
    """Map each ``image_index`` found in the JSON files of ``path`` to a file stem.

    Every ``*.json`` file directly inside ``path`` is loaded. If its content is
    a dict with an ``"image_index"`` key, that index is mapped to the file name
    without its ``.json`` extension. Other files and other JSON payloads are
    skipped.

    Args:
        path (str): directory containing the per-image JSON files.

    Returns:
        dict: ``{image_index: file_stem}``. If several files share an index,
        the one whose name sorts last wins.
    """
    mapping = {}
    # sorted() makes the result independent of os.listdir's arbitrary order.
    for file in sorted(os.listdir(path)):
        if not file.endswith(".json"):
            continue
        with open(os.path.join(path, file), "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict) and "image_index" in data:
            # splitext keeps inner dots ("scan.v2.json" -> "scan.v2"); split(".")[0] would not.
            mapping[data["image_index"]] = os.path.splitext(file)[0]
    return mapping


def read_xml(file_path: str) -> dict:
    """
    Parses an XML file and extracts information about pages, text regions, and text lines.

    Args:
        file_path (str): The path to the XML file to be parsed.

    Returns:
        dict: A dictionary containing the extracted information with the following structure:
            {
                "image_filename": str,  # The filename of the image associated with the page
                "image_width": str,     # The width of the image
                "image_height": str,    # The height of the image
                "text_regions": [       # A list of text regions
                    {
                        "id": str,           # The ID of the text region
                        "orientation": str,  # The orientation of the text region
                        "coords": str,       # The coordinates of the text region
                        "text": str,         # The text content of the whole text region ("" if absent)
                        "text_lines": [      # The TextLine children of the text region
                            {
                                "id": str,        # The ID of the text line
                                "coords": str,    # The coordinates of the text line (None if absent)
                                "baseline": str,  # The baseline coordinates of the text line (None if absent)
                                "text": str       # The text content of the text line ("" if absent)
                            }
                        ]
                    }
                ]
            }

    Raises:
        ValueError: If the document contains no ``Page`` element.
    """

    tree = etree.parse(file_path)  # noqa: S320  # Using `lxml` to parse untrusted data is known to be vulnerable to XML attacks
    ns = {"ns": tree.getroot().nsmap.get(None, "")}

    def _unicode_text(equiv) -> str:
        # Text of the Unicode element under a TextEquiv, or "" when either is missing or empty.
        if equiv is None:
            return ""
        unicode_el = equiv.find(".//ns:Unicode", namespaces=ns)
        return (unicode_el.text if unicode_el is not None else None) or ""

    def _points(element, tag: str):
        # `points` attribute of the first direct `tag` child of `element`, or None when absent.
        child = element.find(f"./ns:{tag}", namespaces=ns)
        return child.get("points") if child is not None else None

    page = tree.find(".//ns:Page", namespaces=ns)
    if page is None:
        raise ValueError(f"No Page element found in {file_path!r}")

    output = {
        "image_filename": page.get("imageFilename"),
        "image_width": page.get("imageWidth"),
        "image_height": page.get("imageHeight"),
        "text_regions": [],
    }

    for region in page.findall(".//ns:TextRegion", namespaces=ns):
        # The last TextEquiv in document order is taken as the region text; presumably
        # the region-level one, which follows the TextLine children in PAGE XML — confirm.
        equivs = region.findall(".//ns:TextEquiv", namespaces=ns)
        text_region = {
            "id": region.get("id"),
            "orientation": region.get("orientation"),
            "coords": region.find(".//ns:Coords", namespaces=ns).get("points"),
            "text": _unicode_text(equivs[-1] if equivs else None),
            # Direct children only, so lines of nested regions are not counted twice.
            "text_lines": [
                {
                    "id": line.get("id"),
                    "coords": _points(line, "Coords"),
                    "baseline": _points(line, "Baseline"),
                    "text": _unicode_text(line.find("./ns:TextEquiv", namespaces=ns)),
                }
                for line in region.findall("./ns:TextLine", namespaces=ns)
            ],
        }

        output["text_regions"].append(text_region)

    return output

In [112]:
# Load every scalar logged for this run into one wide frame (one column per summary tag).
df = convert_tb_data(root_dir=experiment_path)

In [None]:
# Rows that actually have a CER value (NaN CERs are dropped before counting).
# NOTE(review): the message says "!=", but the count is of rows WITH a CER — confirm the label.
rows_with_cer = len(df.dropna(subset="Text_recognition/_average_CER"))
print(f"Data with num ground_truth != num text predictions: {rows_with_cer} / {len(df)}")

In [None]:
df.plot.hist(column="Layout_analysis/_average_ious", bins=50, title="Layout analysis: average IoU")

In [None]:
df.plot.hist(column="Text_recognition/_average_CER", bins=100, title="Text recognition: average CER")

In [None]:
# Check the relatively "good" results: zoom the CER histogram to the [0, 1] range
df.plot.hist(
    column="Text_recognition/_average_CER",
    bins=100,
    range=(0, 1),
    title="Text recognition: average CER (range 0-1)",
)

In [117]:
# CER thresholds used to pick examples for manual inspection. Rows in between
# (or with a NaN CER) end up in neither subset.
GOOD_CER_THRESHOLD = 0.6
BAD_CER_THRESHOLD = 5

cer = df["Text_recognition/_average_CER"]
good_results = df[cer < GOOD_CER_THRESHOLD]
bad_results = df[cer > BAD_CER_THRESHOLD]

In [118]:
# Debug JSON files carrying an `image_index`, used to look up image file stems.
text_recognition_debug_dir = "../data/debug/text_recognition"
image_mapping = map_image_to_index(text_recognition_debug_dir)

In [119]:
# NOTE(review): assumes each row label of `df` equals an `image_index` key in
# `image_mapping`; a missing key raises KeyError — confirm.
good_images = [image_mapping[row_label] for row_label in good_results.index]
bad_images = [image_mapping[row_label] for row_label in bad_results.index]

# "Good" results 

In [195]:
# Look at the good results: pick one example by position
i = 8

# Image: the same file is loaded twice, once with PIL (for display) and once with cv2 (for cropping)
image_path = "../data/debug/preprocessing"
good_image_file = os.path.join(image_path, good_images[i] + ".JPG")
image = Image.open(good_image_file)
array_image = cv2.imread(good_image_file)

In [None]:
# Show the full image loaded from the preprocessing debug folder.
display(image)

In [197]:
# Parse the pipeline's XML output (text regions and their recognised text) for the same image
results_path = "../data/output"
good_xml_file = os.path.join(results_path, good_images[i] + ".xml")
xml = read_xml(good_xml_file)

In [None]:
# Show every text region: its crop from the image, then the recognised text.
regions = xml["text_regions"]
for region in regions:
    # Image crop
    # NOTE(review): assumes `coords` is a whitespace-separated list of integers as
    # expected by `crop_image`, not PAGE-style "x,y" point pairs — confirm.
    coords = [int(coord) for coord in region["coords"].split()]
    cropped_image = crop_image(array_image, coords)
    # cv2.imread returns BGR, while Image.fromarray assumes RGB: swap channels for display.
    display(Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)))
    print(region["text"])
    print()

# Poor results 

In [None]:
# Look at the poor results (high text-recognition CER): pick one example by position
i = 7

# Image: loaded with PIL (for display) and with cv2 (for cropping)
image_path = "../data/debug/preprocessing"
bad_image_file = os.path.join(image_path, bad_images[i] + ".JPG")
image = Image.open(bad_image_file)
print(bad_images[i])  # file stem, so the example can be located on disk
array_image = cv2.imread(bad_image_file)

# XML output for the same image. `results_path` is (re)defined here so this
# section does not depend on the "good" results cells having been run first.
results_path = "../data/output"
xml = read_xml(os.path.join(results_path, bad_images[i] + ".xml"))
regions = xml["text_regions"]

In [None]:
# Show the full poor-result image (rendered as the cell's output value).
image

In [None]:
# Show every text region of the poor-result image: its crop, then the recognised text.
for region in regions:
    # Image crop
    # NOTE(review): assumes `coords` is a whitespace-separated list of integers as
    # expected by `crop_image`, not PAGE-style "x,y" point pairs — confirm.
    coords = [int(coord) for coord in region["coords"].split()]
    cropped_image = crop_image(array_image, coords)
    # cv2.imread returns BGR, while Image.fromarray assumes RGB: swap channels for display.
    display(Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)))
    print(region["text"])
    print()