In [None]:
import os
import re
import zipfile
import requests
import xml.sax.saxutils

import numpy as np
import pandas as pd

from joblib import hash
from copy import deepcopy
from rich.pretty import pprint

from aymurai.utils.json_data import load_json
from aymurai.text.docx2xml import DocxXMLExtract

In [None]:
# Project helper used below both to unzip the .docx (`unzip_document`) and to
# extract its paragraphs (calling it with {"path": ...}).
docx_xml_extract = DocxXMLExtract()

In [None]:
# Sample docx
# doc_path = "/resources/data/restricted/ar-juz-pcyf-10/RESOLUCIONES DEL JUZGADO/Aclaratoria/15473 AMARILLA 89 82 y 80 aclaratoria duracion mediddas de protección - hace lugar.docx"
# doc_path = "/resources/data/restricted/ar-juz-pcyf-10/RESOLUCIONES DEL JUZGADO/Allanamiento/92939 VADERROSA rechaza allanamiento.docx"
doc_path = "/resources/data/restricted/ar-juz-pcyf-10/RESOLUCIONES DEL JUZGADO/Allanamiento/Hace lugar/54765 ARDITO 149 allanamiento armas hace lugar.docx"
# Name the unzip directory after the file stem. `splitext` strips only the final
# extension, so stems that themselves contain dots (e.g. "Expte. 12.docx") are
# preserved instead of being cut at the first dot.
output_dir = os.path.splitext(os.path.basename(doc_path))[0]

In [None]:
# Unpack the .docx into `output_dir`; its word/*.xml parts are edited in place
# further down by `replace_text_in_xml`.
docx_xml_extract.unzip_document(doc_path, output_dir)

## /document-extract endpoint output

In [None]:
# # /document-extract endpoint output
# Load a saved /document-extract response; its "document" entry is the list of
# paragraph texts that is later sent, one by one, to the anonymizer.
# extracted_document = load_json("response-document-extract.json")
# extracted_document = load_json("response_1716927762752.json")
extracted_document = load_json("response_1716928496834.json")
extracted_document

In [None]:
# Number of paragraphs returned by /document-extract
len(extracted_document["document"])

## /docx2xml endpoint output

In [None]:
# /docx2xml endpoint output, wrapped as {"paragraphs": [...]} so later cells can
# attach per-paragraph keys (hash, indices, labels) to it.
xml_document = {"paragraphs": docx_xml_extract({"path": doc_path})}

In [None]:
# Number of paragraphs found directly in the docx XML
len(xml_document["paragraphs"])

In [None]:
# Preview the first XML paragraphs
xml_document["paragraphs"][:10]

## Matching

In [None]:
from unicodedata import normalize

In [None]:
# joblib hash of each extracted paragraph exactly as returned (no normalization)
hashes = [hash(paragraph) for paragraph in extracted_document["document"]]
hashes

In [None]:
# NOTE(review): despite its name this maps index -> hash; `hash2idx` is
# reassigned to the real hash -> indices mapping in the matching cell below.
hash2idx = {i: _hash for i, _hash in enumerate(hashes)}
hash2idx

In [None]:
# Inspect a single extracted paragraph (debugging aid)
extracted_document["document"][29]

In [None]:
def _normalized_text_hash(text: str) -> str:
    """Hash of ``text`` after stripping surrounding whitespace and NFKC normalization.

    ``hash`` here is ``joblib.hash`` (imported at the top), not the builtin.
    """
    return hash(normalize("NFKC", text.strip()))


def _build_hash2idx(xml_paragraphs: list[dict], extracted_paragraphs: list) -> dict:
    """Map each XML paragraph's normalized-text hash to matching extracted indices.

    Both sides go through ``_normalized_text_hash`` so paragraphs that differ
    only in surrounding whitespace or Unicode compatibility forms still match.
    Hashes with no extracted counterpart map to an empty list.
    """
    # Index the extracted paragraphs once: O(n + m) instead of a full scan per
    # XML paragraph.
    indices_by_hash: dict = {}
    for index, text in enumerate(extracted_paragraphs):
        indices_by_hash.setdefault(_normalized_text_hash(text), []).append(index)

    return {
        key: list(indices_by_hash.get(key, []))
        for key in (_normalized_text_hash(p["plain_text"]) for p in xml_paragraphs)
    }


hash2idx = _build_hash2idx(xml_document["paragraphs"], extracted_document["document"])

hash2idx

In [None]:
# Tag every XML paragraph with the joblib hash of its stripped, NFKC-normalized
# text, then with the extracted-document indices `hash2idx` gives for that hash.
xml_document["paragraphs"] = [
    {**paragraph, "hash": hash(normalize("NFKC", paragraph["plain_text"].strip()))}
    for paragraph in xml_document["paragraphs"]
]

xml_document["paragraphs"] = [
    {**paragraph, "extracted_document_indices": hash2idx[paragraph["hash"]]}
    for paragraph in xml_document["paragraphs"]
]

In [None]:
# Paragraph count before dropping the unmatched ones
len(xml_document["paragraphs"])

In [None]:
# Keep only XML paragraphs that matched at least one extracted paragraph.
xml_document["paragraphs"] = list(
    filter(
        lambda candidate: candidate["extracted_document_indices"],
        xml_document["paragraphs"],
    )
)

len(xml_document["paragraphs"])

## Inference

In [None]:
# Function to make inference
def get_predictions(
    sample: str,
    url: str = "http://localhost:8899/anonymizer/predict",
    timeout: "float | None" = 300.0,
) -> dict:
    """POST ``sample`` to the anonymizer service and return the decoded JSON.

    Args:
        sample: Paragraph text, sent as ``{"text": sample}``.
        url: Prediction endpoint; defaults to the local anonymizer service.
        timeout: Seconds to wait for the server (``None`` waits indefinitely).
            Without a timeout a stalled server would block the notebook forever.

    Returns:
        The response body parsed as JSON.

    Raises:
        requests.HTTPError: On a 4xx/5xx response, so failures surface here
            rather than later as a missing ``"labels"`` key.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    response = requests.post(url=url, json={"text": sample}, timeout=timeout)
    response.raise_for_status()
    return response.json()

In [None]:
# One request per extracted paragraph: preds[i] belongs to
# extracted_document["document"][i].
preds = list(map(get_predictions, extracted_document["document"]))
preds

In [None]:
# Matching: each kept XML paragraph takes the labels predicted for the first
# extracted paragraph sharing its hash.
# NOTE(review): those label offsets index the extracted paragraph text but are
# later applied to `plain_text`; this assumes both strings are identical
# (no leading whitespace / normalization differences) — confirm.
xml_document["paragraphs"] = [
    {**paragraph, "labels": preds[hash2idx[paragraph["hash"]][0]]["labels"]}
    for paragraph in xml_document["paragraphs"]
]

In [None]:
# Step through the matched paragraphs one at a time (debugging aid)
iter_paragraphs = iter(xml_document["paragraphs"])

In [None]:
# Each run shows the next paragraph; raises StopIteration once exhausted
paragraph = next(iter_paragraphs)
pprint(paragraph)

## Replace plain texts

In [None]:
from string import punctuation

# Character set used by `unify_consecutive_labels` to skip punctuation labels
punctuation

In [None]:
def unify_consecutive_labels(sample: dict, text_key: str = "document"):
    """Merge adjacent predicted labels that share the same ``aymurai_label``.

    Labels whose ``text`` is contained in ``string.punctuation`` (e.g. a lone
    "." or ",") are dropped. A kept label extends the current group when it has
    the same ``attrs["aymurai_label"]`` and starts at most one character after
    the group's ``end_char``; otherwise the group is closed and a new one starts.

    Args:
        sample: Prediction dict with a ``"labels"`` list and the text under
            ``text_key``. Each label needs ``text``, ``start_char``,
            ``end_char`` and ``attrs["aymurai_label"]``.
        text_key: Key of ``sample`` holding the text the offsets refer to.

    Returns:
        List of dicts with ``text``, ``start_char``, ``end_char`` and
        ``aymurai_label``, where ``text == sample[text_key][start_char:end_char]``.
    """
    document = sample[text_key]

    def _new_group(label: dict) -> dict:
        # A group that initially spans exactly one label.
        return {
            "text": label["text"],
            "start_char": label["start_char"],
            "end_char": label["end_char"],
            "aymurai_label": label["attrs"]["aymurai_label"],
        }

    def _close_group(group: dict) -> dict:
        # Offsets are half-open [start, end) — the same convention
        # `replace_labels_in_text` uses when it splices `doc[end_char:]` —
        # so `text` is exactly the span that gets replaced.
        group["text"] = document[group["start_char"] : group["end_char"]]
        return group

    unified_labels = []
    current_group = None

    for label in sample["labels"]:
        # TODO: make this a post-processing in prediction pipeline
        if label["text"] in punctuation:
            continue

        if current_group is None:
            current_group = _new_group(label)
        elif (
            current_group["aymurai_label"] == label["attrs"]["aymurai_label"]
            and (label["start_char"] - current_group["end_char"]) <= 1
        ):
            # Same label, separated by at most one character: extend the group.
            current_group["end_char"] = label["end_char"]
        else:
            unified_labels.append(_close_group(current_group))
            current_group = _new_group(label)

    if current_group is not None:
        unified_labels.append(_close_group(current_group))

    return unified_labels

In [None]:
# Raw prediction for one paragraph (debugging aid)
preds[4]

In [None]:
# The same prediction after merging adjacent labels
unify_consecutive_labels(preds[4])

In [None]:
# Merged labels for every extracted paragraph
[unify_consecutive_labels(pred) for pred in preds]

In [None]:
def replace_labels_in_text(pred: dict, text_key: str = "document"):
    """Return ``pred[text_key]`` with each entity span replaced by its label tag.

    Spans come from ``unify_consecutive_labels(pred, text_key)`` and are
    replaced left to right by the XML-escaped tag ``" <LABEL>"`` (i.e.
    ``" &lt;LABEL&gt;"``). A running shift keeps later spans aligned after
    earlier replacements change the text length. Finally, runs of spaces are
    collapsed and the result is stripped.
    """
    working_copy = deepcopy(pred)
    text = working_copy[text_key]
    shift = 0

    for span in unify_consecutive_labels(working_copy, text_key):
        begin = span["start_char"] + shift
        finish = span["end_char"] + shift
        tag = xml.sax.saxutils.escape(f" <{span['aymurai_label']}>")

        text = "".join((text[:begin], tag, text[finish:]))
        # Net length change introduced by this replacement.
        shift += len(tag) - (finish - begin)

    return re.sub(r" +", " ", text).strip()

In [None]:
# Anonymized text of every extracted paragraph (labels applied to "document")
replaced = [replace_labels_in_text(pred) for pred in preds]
replaced

In [None]:
# Anonymized text of every matched XML paragraph (labels applied to
# "plain_text"); this overwrites `replaced` from the previous cell.
replaced = [
    replace_labels_in_text(pred, text_key="plain_text")
    for pred in xml_document["paragraphs"]
]
replaced

## Replace source XMLs

In [None]:
from aymurai.utils.alignment.core import align_text, tokenize

In [None]:
def erase_duplicates_justseen(series: pd.Series) -> pd.Series:
    """Blank out every value that equals the value immediately before it.

    The first element is always kept. Each later element is compared with its
    original predecessor and replaced by ``""`` when they are equal. The
    returned Series has a fresh ``RangeIndex``.
    """
    kept = []
    previous = None
    for position in range(len(series)):
        current = series.iloc[position]
        kept.append("" if position and current == previous else current)
        previous = current
    return pd.Series(kept)

In [None]:
from collections import Counter


def _locate_tokens(text: str, tokens: list[str]) -> list[tuple[int, int]]:
    """Return the half-open ``(start, end)`` offset of each token within ``text``.

    Tokens are searched left to right, each search starting where the previous
    match ended. Repeated tokens, and tokens that also occur inside longer words
    (e.g. "de" inside "Fernandez"), therefore resolve to their own occurrence.
    A token that is not found verbatim is anchored at the current cursor.
    """
    spans = []
    cursor = 0
    for token in tokens:
        position = text.find(token, cursor)
        if position == -1:
            position = cursor
        else:
            cursor = position + len(token)
        spans.append((position, position + len(token)))
    return spans


def parse_token_indices(sample: dict) -> pd.DataFrame:
    """Build a token table aligning a paragraph's XML fragments with its anonymized text.

    The fragment texts (joined by single spaces) and the anonymized text from
    ``replace_labels_in_text(sample, text_key="plain_text")`` are aligned with
    ``align_text``, both wrapped in ``<START>``/``<END>`` sentinels. Consecutive
    repeated targets are then blanked with ``erase_duplicates_justseen``.

    Each fragment is tokenized, and every token gets one row. Its offsets are
    ``metadata["start"] + fragment["start"]`` plus the token's position inside
    the fragment text.

    Args:
        sample: XML paragraph with ``plain_text``, ``labels`` and ``metadata``.
            ``metadata`` holds ``xml_file``, ``start`` and ``fragments``; each
            fragment has ``text``, ``start``, ``end`` and ``paragraph_index``.

    Returns:
        One row per token (columns below) plus the aligned ``target`` column.
    """
    metadata = sample["metadata"]
    fragments = metadata["fragments"]
    xml_file = metadata["xml_file"]

    original_text = " ".join(fragment["text"] for fragment in fragments)
    anonymized_text = replace_labels_in_text(sample, text_key="plain_text")

    aligned = align_text(
        "<START> " + original_text + " <END>",
        "<START> " + anonymized_text + " <END>",
        # preserve_whitespaces=True,
    )
    aligned["target"] = erase_duplicates_justseen(aligned["target"])

    rows = []
    for fragment_index, fragment in enumerate(fragments):
        text = fragment["text"]
        paragraph_index = fragment["paragraph_index"]
        # Fragment offsets are constant for all of its tokens.
        fragment_start = metadata["start"] + fragment["start"]
        fragment_end = metadata["start"] + fragment["end"]

        # Materialize: the tokens are iterated twice below.
        tokenized_text = list(tokenize(text))
        token_spans = _locate_tokens(text, tokenized_text)

        for token_index, (token, (start, end)) in enumerate(
            zip(tokenized_text, token_spans)
        ):
            rows.append(
                (
                    xml_file,
                    paragraph_index,
                    fragment_index,
                    token_index,
                    token,
                    fragment_start + start,
                    fragment_start + end,
                    fragment_start,
                    fragment_end,
                    text,
                )
            )

    tokens = pd.DataFrame(
        rows,
        columns=[
            "xml_file",
            "paragraph_index",
            "fragment_index",
            "token_index",
            "token",
            "start_char",
            "end_char",
            "original_start_char",
            "original_end_char",
            "original_text",
        ],
    )

    # NOTE(review): this column-wise concat assumes `align_text` yields exactly
    # one row per token produced by `tokenize` on the fragments (once the
    # <START>/<END> sentinel rows are dropped) — confirm.
    tokens = pd.concat(
        [tokens, aligned["target"].iloc[1:-1].reset_index(drop=True)], axis=1
    )

    return tokens

In [None]:
# Inspect one matched XML paragraph and its token-level alignment
xml_document["paragraphs"][6]

In [None]:
parse_token_indices(xml_document["paragraphs"][6])

In [None]:
# NOTE: legacy ElementTree-based variants, superseded by the lxml-based
# `normalize_document` in the next cell. Kept for reference only; nothing in
# this cell is executed.
# import xml.etree.ElementTree as ET


# # Look for every w:t tag in the document and attach the whitespace_preserve attribute
# def preserve_whitespace(xml_content: str) -> str:
#     namespaces = {"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main"}

#     # Parse the XML content
#     tree = ET.ElementTree(ET.fromstring(xml_content))
#     root = tree.getroot()

#     # Register namespaces
#     for prefix, uri in namespaces.items():
#         ET.register_namespace(prefix, uri)

#     # Find all w:t elements and set the xml:space attribute
#     for wt in root.findall(".//w:t", namespaces):
#         wt.set("{http://www.w3.org/XML/1998/namespace}space", "preserve")

#     # Write back the XML content to a string
#     xml_str = ET.tostring(root, encoding="unicode", method="xml")

#     return xml_str


# Look for every w:t tag in the document, attach the whitespace_preserve attribute, and replace multiple spaces with a single space
# def preserve_whitespace_and_normalize_spaces(xml_content: str) -> str:
#     namespaces = {"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main"}

#     # Parse the XML content
#     tree = ET.ElementTree(ET.fromstring(xml_content))
#     root = tree.getroot()

#     # Register namespaces
#     for prefix, uri in namespaces.items():
#         ET.register_namespace(prefix, uri)

#     # Find all w:t elements
#     for wt in root.findall(".//w:t", namespaces):
#         # Set the xml:space attribute to preserve
#         wt.set("{http://www.w3.org/XML/1998/namespace}space", "preserve")

#         # Replace multiple spaces with a single space in the text content
#         if wt.text:
#             wt.text = re.sub(r"\s+", " ", wt.text)

#     # Write back the XML content to a string
#     xml_str = ET.tostring(root, encoding="unicode", method="xml")

#     return xml_str

In [None]:
import re
from lxml import etree


# Preserve whitespace in every w:t, collapse repeated whitespace, and drop runs left blank
def normalize_document(
    xml_content: str,
) -> str:
    """Normalize the text runs of a WordprocessingML part.

    For every ``w:r`` run that has direct ``w:t`` children:

    * each of those ``w:t`` gets ``xml:space="preserve"``, and whitespace runs
      in its text are collapsed to a single space;
    * the run is removed only when every ``w:t`` child is blank (empty or
      whitespace-only) and the run contains nothing besides ``w:rPr``/``w:t``.
      Runs that also carry breaks, tabs, drawings, etc. are kept, and so is
      any text held in sibling ``w:t`` elements.

    Args:
        xml_content: Serialized XML of the part (e.g. ``word/document.xml``).

    Returns:
        The re-serialized XML (``pretty_print=True``).
    """
    w_namespace_default = "http://schemas.openxmlformats.org/wordprocessingml/2006/main"
    xml_space_attr = "{http://www.w3.org/XML/1998/namespace}space"

    # Parse the XML content with lxml
    parser = etree.XMLParser(ns_clean=True)
    root = etree.fromstring(xml_content.encode("utf-8"), parser)

    # Prefixed namespaces declared on the root. Make sure `w` resolves even when
    # the root does not declare it, otherwise the XPath below raises.
    namespaces = {k: v for k, v in root.nsmap.items() if k}
    namespaces.setdefault("w", w_namespace_default)

    w_t_tag = f"{{{namespaces['w']}}}t"
    w_rpr_tag = f"{{{namespaces['w']}}}rPr"

    # xpath() returns a list snapshot, so removing runs while looping is safe.
    for wr in root.xpath("//w:r", namespaces=namespaces):
        # Direct children only: `.//w:t` would also reach text of nested runs
        # (e.g. inside text boxes) and judge the outer run by it.
        text_nodes = wr.findall("w:t", namespaces)
        if not text_nodes:
            continue

        for wt in text_nodes:
            wt.set(xml_space_attr, "preserve")
            if wt.text:
                wt.text = re.sub(r"\s+", " ", wt.text)

        all_blank = all(not wt.text or wt.text.strip() == "" for wt in text_nodes)
        # Non-string tags (comments / PIs) count as "other content", so the run is kept.
        only_text = all(child.tag in (w_t_tag, w_rpr_tag) for child in wr)
        parent = wr.getparent()
        if all_blank and only_text and parent is not None:
            parent.remove(wr)

    # Write back the XML content to a string
    return etree.tostring(root, encoding="unicode", pretty_print=True)

In [None]:
def replace_text_in_xml(paragraphs: list[dict], base_dir: str):
    """Write the anonymized text of ``paragraphs`` back into the unzipped DOCX.

    Token rows from ``parse_token_indices`` are grouped per XML fragment.
    Each fragment's span ``[original_start_char, original_end_char)`` in
    ``{base_dir}/word/{xml_file}`` is replaced by its space-joined target text.
    A leading/trailing space is kept when the original fragment had one, and
    whitespace is collapsed. Each touched file is then passed through
    ``normalize_document`` and rewritten in place.

    Args:
        paragraphs: Matched XML paragraphs (with ``labels`` and ``metadata``).
            An empty list is a no-op.
        base_dir: Directory the .docx was unzipped into.
    """
    if not paragraphs:
        # Nothing to rewrite (and `pd.concat` raises on an empty list).
        return

    tokens = pd.concat(
        [parse_token_indices(sample) for sample in paragraphs], ignore_index=True
    )
    fragments = (
        tokens.groupby(["xml_file", "paragraph_index", "fragment_index"])
        .agg(
            {
                "target": " ".join,
                "start_char": "min",
                "end_char": "max",
                "original_start_char": "min",
                "original_end_char": "max",
                "original_text": "first",
            }
        )
        .reset_index()
    )

    for xml_file, group in fragments.groupby("xml_file"):
        # Splice right-to-left, ordered by the very offsets being replaced, so
        # each replacement leaves the offsets of the remaining earlier spans valid.
        group = group.sort_values("original_start_char", ascending=False)

        # Be explicit about UTF-8 rather than relying on the platform default.
        with open(f"{base_dir}/word/{xml_file}", "r+", encoding="utf-8") as file:
            content = file.read()

            for _, row in group.iterrows():
                start_char = row["original_start_char"]
                end_char = row["original_end_char"]
                original_text = row["original_text"]
                target = row["target"]

                # Keep the fragment's surrounding spaces so neighbouring runs
                # are not glued together.
                if original_text.startswith(" ") and not target.startswith(" "):
                    target = " " + target
                if original_text.endswith(" ") and not target.endswith(" "):
                    target = target + " "

                target = re.sub(r"\s+", " ", target)

                content = content[:start_char] + target + content[end_char:]

            # MUST run after all splices: it re-serializes the XML, which
            # invalidates the character offsets used above.
            content = normalize_document(content)

            file.seek(0)  # Move the file pointer to the beginning
            file.write(content)
            file.truncate()

In [None]:
# Rewrites the XML parts under `output_dir/word/` in place
replace_text_in_xml(xml_document["paragraphs"], output_dir)

## Recreate anonymized docx

In [None]:
# Function to add files to a zip archive
def add_files_to_zip(zip_file, directory):
    """Add every file under ``directory`` to ``zip_file`` in sorted order.

    Archive names are paths relative to ``directory`` (e.g.
    ``word/document.xml``). Traversal is sorted, so the archive layout is
    reproducible across runs and filesystems.
    """
    for root, dirs, files in os.walk(directory):
        # Sorting `dirs` in place makes os.walk descend in a deterministic order.
        dirs.sort()
        for file in sorted(files):
            file_path = os.path.join(root, file)
            zip_file.write(file_path, os.path.relpath(file_path, directory))


# Function to create the DOCX file
def create_docx(xml_directory, output_file, compression=zipfile.ZIP_DEFLATED):
    """Zip the unpacked DOCX tree ``xml_directory`` into ``output_file``.

    Args:
        xml_directory: Directory holding the unzipped package
            (``[Content_Types].xml``, ``word/``, ...).
        output_file: Path of the .docx to write; overwritten if it exists.
        compression: ``zipfile`` compression method. Defaults to
            ``ZIP_DEFLATED``; pass ``zipfile.ZIP_STORED`` for an uncompressed
            archive.
    """
    with zipfile.ZipFile(output_file, "w", compression=compression) as docx:
        add_files_to_zip(docx, xml_directory)

In [None]:
# Example usage: package the edited XML tree as a new .docx
os.makedirs("anonymized-documents", exist_ok=True)

# `output_file` is the path actually written (it previously held an unused name).
output_file = f"anonymized-documents/{output_dir}-edited.docx"
create_docx(output_dir, output_file)