# Evaluering av transkripsjoner fra evalueringsløypa

In [None]:
import os
from collections import ChainMap
from functools import lru_cache
from pathlib import Path
from difflib import HtmlDiff
import io
from math import floor, ceil


import lxml.etree
import matplotlib.pyplot as plt
import pandas as pd
from dotenv import load_dotenv
from ipywidgets import interact
from IPython.display import display, Markdown, HTML
from PIL import Image
import numpy as np
from tqdm.auto import tqdm

from samisk_ocr.clean_text_data import clean
import samisk_ocr.trocr
from samisk_ocr.metrics import compute_cer, compute_wer, SpecialCharacterF1

## Last data

In [None]:
# Read the complete metadata table, keep only the rows whose file name starts
# with "test", and drop the full table again once the subset has been taken.
metadata = pd.read_csv("../data/new_testset_with_newspapers/metadata.csv")
test_metadata = metadata[metadata["file_name"].str.startswith("test")]
del metadata

## Hent transkripsjonen fra ALTO-filene og sammenstill med datasettet

In [None]:
def load_file(path: Path) -> io.BytesIO:
    """Read the whole file at ``path`` and return it as an in-memory binary buffer."""
    raw_bytes = path.read_bytes()
    buffer = io.BytesIO(raw_bytes)
    return buffer

@lru_cache(maxsize=30)
def load_image(path: Path) -> Image.Image:
    """Open the image at ``path`` from an in-memory copy of the file.

    Results are memoised per path (at most 30 entries), so repeated calls with
    the same path return the very same ``Image`` object.
    """
    in_memory_file = load_file(path)
    return Image.open(in_memory_file)



In [None]:
# Display names for the language codes found in the ``langcodes`` metadata
# column; used below to build the ``language`` column.
# NOTE(review): the keys look like ISO 639-3 codes for Sami languages and the
# values like their Norwegian names -- confirm.
LANGUAGES = {
    "sma": "Sørsamisk",
    "sme": "Nordsamisk",
    "smj": "Lulesamisk",
    "smn": "Inaresamisk",
}

def urn_col_to_alto_dirname(urncol: str) -> str:
    """Derive the ALTO directory name from a value of the ``urn`` column.

    The ``URN_NBN_no-nb_`` and ``no-nb_digavis_`` prefixes and a trailing
    ``-1`` are stripped first. Names containing ``monografi`` then get the
    suffix ``_ocr``; every other name gets ``_ocr_xml``.
    """
    stem = urncol.removeprefix("URN_NBN_no-nb_")
    stem = stem.removeprefix("no-nb_digavis_")
    stem = stem.removesuffix("-1")

    if "monografi" in stem:
        suffix = "_ocr"
    else:
        if "digibok" not in stem:
            # Names that are neither "monografi" nor "digibok" get one more
            # attempt at stripping the ``no-nb_digavis_`` prefix.
            stem = stem.removeprefix("no-nb_digavis_")
        suffix = "_ocr_xml"
    return stem + suffix


def urn_col_to_alto_dir(urncol: str) -> Path:
    """Return the ALTO directory below ``../data/alto`` for a ``urn`` column value."""
    alto_root = Path("../data/alto")
    dirname = urn_col_to_alto_dirname(urncol)
    return alto_root / dirname


def add_page_to_urn(urn: str, page: int) -> str:
    """Build the ALTO file name for page ``page`` of the document ``urn``.

    Trailing ``_xml``/``_ocr`` and the leading ``URN_NBN_no-nb_`` are removed
    from ``urn``. Names containing ``pliktmonografi`` or ``digibok`` use a
    zero-padded four-digit page number (``<name>_0007.xml``); all other names
    additionally lose a leading ``no-nb_digavis_`` and use three digits plus a
    ``_null`` marker (``<name>_007_null.xml``).
    """
    stem = urn.removesuffix("_xml").removesuffix("_ocr").removeprefix("URN_NBN_no-nb_")
    four_digit_pages = any(marker in stem for marker in ("pliktmonografi", "digibok"))
    if not four_digit_pages:
        stem = stem.removeprefix("no-nb_digavis_")
        return f"{stem}_{page:03.0f}_null.xml"
    return f"{stem}_{page:04.0f}.xml"


def get_alto_file(df: pd.DataFrame) -> pd.Series:
    """Return the ALTO file path for every row of ``df``.

    Args:
        df: Frame with the columns ``alto_dir`` (objects supporting ``/`` with
            a string, e.g. ``Path``), ``urn`` and ``page``.

    Returns:
        A Series named ``alto_file``, indexed like ``df``, holding
        ``alto_dir / <file name for urn and page>`` for each row.
    """
    file_names = df.apply(
        lambda row: add_page_to_urn(row["urn"], row["page"]), axis="columns"
    )
    # Rename the *result* of the division: a binary operation between two
    # Series with different names yields an unnamed Series, so renaming only
    # the right-hand operand would have no effect on what is returned.
    return (df["alto_dir"] / file_names).rename("alto_file")


def get_jp2_file(alto_file: str | Path) -> str:
    alto_file = Path(alto_file)
    filename = alto_file.with_suffix(".jp2").name
    folder_name = alto_file.parent.name.removesuffix("ocr_xml") + "jp2"
    return Path("../data") / "alto-jp2" / folder_name / filename

def get_bbox(tag) -> tuple[int, int, int, int]:
    """Return the bounding box of ``tag`` as ``(left, top, right, bottom)``.

    Args:
        tag: Any object with a ``get(name)`` method that returns the
            ``HPOS``, ``VPOS``, ``WIDTH`` and ``HEIGHT`` attributes as
            integer strings (e.g. an lxml element).

    Returns:
        Integer coordinates; ``right`` is ``HPOS + WIDTH`` and ``bottom`` is
        ``VPOS + HEIGHT``.

    Raises:
        TypeError: If one of the attributes is missing (``get`` returns None).
        ValueError: If an attribute is not an integer literal.
    """
    # Parse each attribute once instead of re-reading HPOS/VPOS for the far corner.
    left = int(tag.get("HPOS"))
    top = int(tag.get("VPOS"))
    right = left + int(tag.get("WIDTH"))
    bottom = top + int(tag.get("HEIGHT"))
    return left, top, right, bottom

def get_surrounding_bbox(bboxes: list[tuple[int, int, int, int]]) -> tuple[int, int, int, int]:
    """Return the smallest ``(left, top, right, bottom)`` box enclosing all ``bboxes``.

    Raises ``ValueError`` when ``bboxes`` is empty.
    """
    lefts = [box[0] for box in bboxes]
    tops = [box[1] for box in bboxes]
    rights = [box[2] for box in bboxes]
    bottoms = [box[3] for box in bboxes]
    return min(lefts), min(tops), max(rights), max(bottoms)
        

def get_scales(text_block) -> tuple[float, float, float, float]:
    """Return per-coordinate ratios between a block's box and the box around its lines.

    The block's own bounding box is divided, coordinate by coordinate, by the
    box that surrounds all of its direct ``TextLine`` children.

    Args:
        text_block: Element exposing ``get`` for its position attributes and
            ``xpath("TextLine")`` for its lines.

    Returns:
        ``(left, top, right, bottom)`` ratios. A coordinate whose line-box
        value is 0 has no defined ratio and is reported as ``1.0`` (no scaling)
        instead of raising ``ZeroDivisionError``.

    Raises:
        ValueError: If the block has no ``TextLine`` children.
    """
    block_bbox = get_bbox(text_block)
    lines_bbox = get_surrounding_bbox(
        [get_bbox(line) for line in text_block.xpath("TextLine")]
    )
    return tuple(
        block / lines if lines else 1.0
        for block, lines in zip(block_bbox, lines_bbox)
    )
    
    

def _text_line_content(text_line) -> str:
    """Join the CONTENT of a TextLine's String elements, keeping a trailing HYP."""
    words = [string.get("CONTENT", "") for string in text_line.xpath(".//String")]
    content = " ".join(words)

    # If the last String/HYP element of the line is a HYP (hyphenation mark),
    # append its content directly to the last word.
    tokens = text_line.xpath(".//String|.//HYP")
    if tokens and tokens[-1].tag == "HYP":
        content += tokens[-1].get("CONTENT", "")
    return content


@lru_cache(1000)
def get_alto_textlines(
    alto_file: Path,
) -> tuple[list[str], list[tuple[float, ...]], tuple[int, int]]:
    """Extract the text lines of an ALTO file together with their bounding boxes.

    Args:
        alto_file: Path of the ALTO XML file to parse.

    Returns:
        A triple ``(lines, bboxes, (width, height))``:

        * ``lines``: one string per ``TextLine`` -- the CONTENT of its String
          elements joined by spaces, plus the CONTENT of a final HYP element.
        * ``bboxes``: one ``(left, top, right, bottom)`` box per line, i.e. the
          line's own box multiplied coordinate-wise by ``get_scales`` of its
          parent element.
        * ``(width, height)``: ``HPOS + WIDTH`` of the first ``RightMargin`` and
          ``VPOS + HEIGHT`` of the first ``BottomMargin``.

        The result is cached per path, so callers share (and must not mutate)
        the returned lists.

    Raises:
        StopIteration: If the file has no ``RightMargin`` or ``BottomMargin``.
    """
    tree = lxml.etree.parse(alto_file)

    line_texts = []
    bboxes = []

    prev_parent = None
    scales = None
    for text_line in tree.xpath("//TextLine"):
        line_texts.append(_text_line_content(text_line))

        # The scales depend only on the parent element, so they are computed
        # (and printed, once per parent) only when the parent changes.
        parent = text_line.getparent()
        if prev_parent != parent:
            scales = get_scales(parent)
            print(scales)
        prev_parent = parent

        bbox = tuple(s * x for s, x in zip(scales, get_bbox(text_line)))
        # A zero-width or zero-height line box is replaced by the box of its
        # parent when it is the parent's only line.
        is_degenerate = bbox[0] == bbox[2] or bbox[1] == bbox[3]
        if is_degenerate and len(parent.xpath("TextLine")) == 1:
            bbox = get_bbox(parent)
        bboxes.append(bbox)

    right_margin = next(iter(tree.xpath("//RightMargin")))
    bottom_margin = next(iter(tree.xpath("//BottomMargin")))
    width = int(right_margin.get("HPOS")) + int(right_margin.get("WIDTH"))
    height = int(bottom_margin.get("VPOS")) + int(bottom_margin.get("HEIGHT"))
    return line_texts, bboxes, (width, height)

def get_alto_transcription(alto_file, text):
    """Find the ALTO text line that best matches ``text``.

    Every line of ``alto_file`` is passed through ``clean`` and the line with
    the lowest ``compute_cer(text, line)`` is selected (the first one on ties).

    Args:
        alto_file: ``Path`` of the ALTO XML file.
        text: Text to match against the lines of the file.

    Returns:
        ``(line, bbox, size)`` -- the selected cleaned line, a bounding box and
        the ``(width, height)`` reported by ``get_alto_textlines``. When several
        lines clean to the same text, the box of the last of them is returned.
        If ``alto_file`` does not exist, ``(pd.NA, pd.NA, pd.NA)`` is returned,
        so the result can always be unpacked or indexed.

    Raises:
        ValueError: If the file contains no text lines.
    """
    if not alto_file.exists():
        # Callers index the result ([0], [1], [2]); a bare pd.NA is not
        # subscriptable, so return one NA per field instead.
        return pd.NA, pd.NA, pd.NA
    text_lines, bboxes, size = get_alto_textlines(alto_file)

    # Keys keep first-occurrence order, values keep the last box per text, and
    # each distinct text is scored only once.
    bbox_by_line = dict(zip((clean(line) for line in text_lines), bboxes))
    selected_line = min(bbox_by_line, key=lambda line: compute_cer(text, line))
    return selected_line, bbox_by_line[selected_line], size

def compute_alto_cer(row):
    """Return ``compute_cer`` of the row's ``text`` and ``transcription`` fields."""
    text = row["text"]
    transcription = row["transcription"]
    return compute_cer(text, transcription)

def compute_alto_wer(row):
    """Return ``compute_wer`` of the row's ``text`` and ``transcription`` fields."""
    text = row["text"]
    transcription = row["transcription"]
    return compute_wer(text, transcription)

In [None]:
def _add_alto_matches(df: pd.DataFrame) -> pd.DataFrame:
    """Add ``transcription``, ``bbox`` and ``img_size`` using one ALTO lookup per row."""
    # get_alto_transcription scores every line of the page, so it is called
    # once per row and its three results are split into columns afterwards.
    matches = [
        get_alto_transcription(alto_file, text)
        for alto_file, text in zip(df["alto_file"], df["text"])
    ]
    columns = {
        name: pd.Series([match[position] for match in matches], index=df.index, dtype=object)
        for position, name in enumerate(("transcription", "bbox", "img_size"))
    }
    return df.assign(**columns)


# Resolve the ALTO file of every row, attach the best-matching ALTO line with
# its bounding box and page size, score it against the reference text, and add
# the ``pliktmonografi`` flag and the language display name.
test_metadata = (
    test_metadata.assign(alto_dir=test_metadata["urn"].map(urn_col_to_alto_dir))
    .assign(alto_file=get_alto_file)
    .pipe(_add_alto_matches)
    .assign(
        cer=lambda df: df.apply(compute_alto_cer, axis="columns"),
        wer=lambda df: df.apply(compute_alto_wer, axis="columns"),
    )
    .assign(pliktmonografi=lambda df: df["file_name"].str.contains("pliktmonografi"))
    .assign(
        # ``langcodes`` is stored as text of the form "['xxx']"; strip the
        # list syntax before looking the code up.
        language=lambda df: df["langcodes"].map(
            lambda s: LANGUAGES[s.removeprefix("['").removesuffix("']")]
        )
    )
)

In [None]:
def crop_image(path: Path, bbox: tuple[int, int, int, int], ocr_img_size: tuple[int, int]) -> Image.Image:
    """Cut the region ``bbox`` out of the image stored at ``path``.

    All four ``bbox`` coordinates are multiplied by the ratio between the
    image's width and ``ocr_img_size[0]`` before cropping; the heights are not
    consulted.
    """
    page = load_image(path)
    scale = page.size[0] / ocr_img_size[0]
    scaled_bbox = [coordinate * scale for coordinate in bbox]
    return page.crop(scaled_bbox)

In [None]:
# Export the line crops and their metadata to ``../data/baseline_huggingface``.
out_dir = Path("../data") / "baseline_huggingface"
img_dir = out_dir / "test"
img_dir.mkdir(parents=True, exist_ok=True)
# Keep only the rows that are not flagged as ``pliktmonografi``.
test_metadata = test_metadata.query("~pliktmonografi")
save_metadata = test_metadata[[
    "file_name",
    "text",
    "urn",
    "langcodes",
    "page",
    "line",
]]
# Two copies of the metadata are written: one next to the image folder with
# the original ``file_name`` values, and one inside the image folder where
# ``file_name`` is reduced to its base name.
save_metadata.to_csv(out_dir / "metadata.csv")
save_metadata.assign(file_name=save_metadata["file_name"].map(lambda x: Path(x).name)).to_csv(img_dir / "_metadata.csv")

tm = test_metadata#.groupby("alto_dir").apply(lambda x: pd.concat([x.head(1), x.tail(1)]))
# Number of rows whose ALTO transcription has a CER above 0.5.
count = 0
for row in tqdm(tm.itertuples(), total=len(tm)):
    # NOTE(review): only ``out_dir / "test"`` is created above; presumably every
    # ``file_name`` has the form ``test/<name>`` -- confirm.
    img_path = out_dir / row.file_name
    img = crop_image(get_jp2_file(row.alto_file), row.bbox, row.img_size)
    img.save(img_path)

    # Show poorly matching rows for manual inspection. The crop is shrunk only
    # after it has been saved, so the file on disk keeps its full size.
    if (cer := compute_cer(row.text, row.transcription)) > 0.5:
        img.thumbnail((1000, 200))
        display(row.alto_file, row.text, row.transcription, cer,img)
        count += 1
print(count)

In [None]:
# Open one JP2 file directly for manual inspection; as the last expression of
# the cell its value is what the notebook displays.
Image.open("../data/alto-jp2/avvir_null_null_20171230_10_248_1_jp2/avvir_null_null_20171230_10_248_1-1_004_null.jp2")