# Map OCR style information to CDE entities

* Author: docai-incubator@google.com

## Disclaimer

This tool is not supported by the Google engineering team or product team. It is provided and supported on a best-effort basis by the **DocAI Incubator Team**. No guarantees of performance are implied.


## Objective

This document provides a step-by-step guide to help you add OCR style information for every CDE entity.

## Prerequisites
* Vertex AI JupyterLab Environment
* Google Cloud Storage Bucket
* OCR Processor

## Step-by-step procedure

### 1.Importing Required Modules

In [None]:
!wget https://raw.githubusercontent.com/GoogleCloudPlatform/document-ai-samples/main/incubator-tools/best-practices/utilities/utilities.py

In [None]:
from google.cloud import documentai_v1beta3 as documentai
from google.cloud import storage
from typing import Dict, List, Tuple, Sequence, Any, Optional
from pathlib import Path
from PIL import Image
from google.api_core.client_options import ClientOptions
import json
import io
from io import BytesIO
from utilities import file_names, store_document_as_json

### 2.Setup the inputs

* `project_id` : This is the unique identifier for the Google Cloud project.
* `location` : This specifies the location or region where the resources are located.
* `processor_id` : This is the unique identifier for a processor in Google Cloud.
* `processor_version_id` : This identifies the specific version of the processor or model you are using (must be OCR v2.0 or later).
* `gcs_input_path` : This is the path to a Google Cloud Storage (GCS) bucket and folder where input documents are stored.
* `gcs_output_path` : This is the GCS path where output will be stored.

In [None]:
# Example input values for the tool (project, processor and GCS locations).
project_id = "rand-automl-project"
location = "us"
processor_id = "5cda88db4990c164"
# version should be greater or equal to ocr-v2.0
processor_version_id = "pretrained-ocr-v2.0-2023-06-02"

# '/' should be present at the end of the path.
gcs_input_path = "gs://nachinta/customer_testing/test_harness/async_batch_process/test_1/iteration_1/12512366521158903846/"
# '/' should be present at the end of the path.
gcs_output_path = "gs://nachinta/customer_testing/test_harness/sync_process/test_1/mapping_test/"

In [None]:
# Replace the placeholder values below with your own before running the tool.
project_id = "project-id"
location = "us"
processor_id = "processor-id"
# Named `processor_version_id` because that is the name the processing cell reads.
# version should be greater or equal to ocr-v2.0
processor_version_id = "pretrained-ocr-v2.0-2023-06-02"
gcs_input_path = "gs://{bucket-name}/{sub-folder}/{input-jsons-files}/"  # '/' should be present at the end of the path.
gcs_output_path = "gs://{bucket-name}/{sub-folder}/{output-json-path-to-store}/"  # '/' should be present at the end of the path.

### 3.Run the required functions

In [None]:
def create_pdf_bytes(json: str) -> bytes:
    """
    Build a single PDF from the page images embedded in a Document AI JSON string.

    The JSON is parsed into a Document, the image content of every page is
    decoded with PIL, and all page images are saved, in page order, into one
    in-memory PDF.

    Args:
        json (str): JSON string of a Document whose pages carry image content
        (for example a Document AI output / ground truth file).

    Returns:
        bytes: The PDF containing one page per decoded image.

    Raises:
        ValueError: If the document has no pages, so there is no image to save.

    Example:
        json_str = '{"pages": [{"image": {"content": "<image_bytes_in_base64>"}}]}'
        pdf_bytes = create_pdf_bytes(json_str)
    """
    from google.cloud import documentai_v1beta3

    def _open_image(raw: bytes) -> Image.Image:
        """Load raw image bytes fully into a PIL Image."""
        with io.BytesIO(raw) as stream:
            picture = Image.open(stream)
            # Force the pixel data to be read before the stream is closed.
            picture.load()
        return picture

    def _images_to_pdf(pictures: Sequence[Image.Image]) -> bytes:
        """Save the given images as consecutive pages of one PDF.

        Args:
            pictures: Images to place in the PDF, in page order.

        Returns:
            bytes: The serialized PDF.

        Raises:
            ValueError: If ``pictures`` is empty.
        """
        if not pictures:
            raise ValueError("At least one image is required to create a PDF")

        # PIL PDF saver does not support RGBA images
        pdf_ready = []
        for picture in pictures:
            if picture.mode == "RGBA":
                picture = picture.convert("RGB")
            pdf_ready.append(picture)

        first_page, *other_pages = pdf_ready
        with io.BytesIO() as sink:
            first_page.save(
                sink, format="PDF", save_all=True, append_images=other_pages
            )
            return sink.getvalue()

    parsed = documentai_v1beta3.Document.from_json(json)
    page_images = [_open_image(page.image.content) for page in parsed.pages]
    return _images_to_pdf(page_images)


def process_document_sample(
    project_id: str,
    location: str,
    processor_id: str,
    file_path: bytes,
    processor_version_id: Optional[str] = None,
    mime_type: str = "application/pdf",
) -> documentai.Document:
    """
    Send in-memory document content to a Document AI processor with OCR style
    info enabled and return the processed document.

    Args:
        project_id (str): The Google Cloud project ID where the Document AI processor is located.
        location (str): The location/region of the Document AI processor (e.g., 'us', 'eu').
            It is also used to build the regional API endpoint.
        processor_id (str): The ID of the Document AI processor to use for processing.
        file_path (bytes): The raw content of the document to process. Despite its name,
            this value is not a path: it is passed unchanged as the request's raw document content.
        processor_version_id (Optional[str], optional): The specific processor version to use.
            When empty or None, the processor itself is addressed instead of a specific
            version. Defaults to None.
        mime_type (str, optional): The MIME type of the document. Defaults to 'application/pdf'.

    Returns:
        documentai.Document: The `document` field of the processor's response.
    """
    # You must set the `api_endpoint` if you use a location other than "us".
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    client = documentai.DocumentProcessorServiceClient(client_options=opts)
    if processor_version_id:
        name = client.processor_version_path(
            project_id, location, processor_id, processor_version_id
        )
    else:
        name = client.processor_path(project_id, location, processor_id)

    # The caller supplies the document content already loaded in memory.
    raw_document = documentai.RawDocument(content=file_path, mime_type=mime_type)

    process_options = documentai.ProcessOptions(
        ocr_config=documentai.OcrConfig(
            enable_native_pdf_parsing=False,
            enable_image_quality_scores=False,
            enable_symbol=False,
            # OCR Add Ons https://cloud.google.com/document-ai/docs/ocr-add-ons
            # Style info is the add-on this tool relies on.
            premium_features=documentai.OcrConfig.PremiumFeatures(
                compute_style_info=True,
            ),
        )
    )

    request = documentai.ProcessRequest(
        name=name,
        raw_document=raw_document,
        process_options=process_options,
    )
    result = client.process_document(request=request)

    # Only the processed document is returned, not the whole response wrapper.
    return result.document


def get_token_xy(token: Any) -> Tuple[float, float, float, float]:
    """
    Compute the axis-aligned bounding box of a token from its normalized vertices.

    Args:
    - token (Any): An object exposing `layout.bounding_poly.normalized_vertices`,
      where each vertex has `x` and `y` attributes.

    Returns:
    - Tuple[float, float, float, float]: (min_x, min_y, max_x, max_y) over all vertices.

    """
    corners = token.layout.bounding_poly.normalized_vertices
    x_values = [corner.x for corner in corners]
    y_values = [corner.y for corner in corners]

    return min(x_values), min(y_values), max(x_values), max(y_values)


def get_token_data(
    json_dict: documentai.Document,
    min_x: float,
    max_x: float,
    min_y: float,
    max_y: float,
    page_num: int,
    x_allowance: float = 0.005,
    y_allowance: float = 0.005,
) -> List[Any]:
    """
    Collects the style info of every token that lies inside a bounding box on one page.

    A token is selected when its own bounding box, shrunk by the allowances, fits
    inside the given box. Errors raised while scanning are reported with a printed
    message and whatever was collected up to that point is returned.

    Args:
    - json_dict (documentai.Document): The document whose pages and tokens are scanned.
    - min_x (float): Minimum x-coordinate of the bounding box.
    - max_x (float): Maximum x-coordinate of the bounding box.
    - min_y (float): Minimum y-coordinate of the bounding box.
    - max_y (float): Maximum y-coordinate of the bounding box.
    - page_num (int): 0-based page index; it is compared against `page.page_number - 1`.
    - x_allowance (float): Horizontal tolerance applied to the token box. Defaults to 0.005.
    - y_allowance (float): Vertical tolerance applied to the token box. Lower it if line
      items are close together and tokens of neighbouring lines are being picked up.
      Defaults to 0.005.

    Returns:
    - List[Any]: The `style_info` of each matching token, in token order. The list is
      empty when no token matches.
    """
    font_details: List[Any] = []
    try:
        for page in json_dict.pages:
            if page_num != page.page_number - 1:
                continue
            for token in page.tokens:
                minx_token, miny_token, maxx_token, maxy_token = get_token_xy(token)
                if (
                    min_y <= miny_token + y_allowance
                    and max_y >= maxy_token - y_allowance
                    and min_x <= minx_token + x_allowance
                    and max_x >= maxx_token - x_allowance
                ):
                    font_details.append(token.style_info)
    except Exception as exc:
        # Best effort: keep the styles gathered so far, but do not swallow
        # KeyboardInterrupt/SystemExit and do not hide the underlying error.
        print("No tokens found in the entity", exc)

    return font_details

### 4.Run the code

In [None]:
def _page_index(page_ref: Dict[str, Any]) -> int:
    """Return the 0-based page index stored in a pageRef dictionary."""
    # "pageNumber" is honoured first so inputs that carry it behave as before.
    # Document AI itself serialises the index as "page" and leaves it out for
    # the first page, hence the fallback and the default of 0.
    return int(page_ref.get("pageNumber", page_ref.get("page", 0)))


def _style_to_dict(style: Any) -> Dict[str, Any]:
    """Convert one token style-info object into a JSON-serialisable dictionary."""
    return {
        "fontSize": style.font_size,
        "pixelFontSize": style.pixel_font_size,
        "fontType": style.font_type,
        "fontWeight": style.font_weight,
        "handWritten": style.handwritten,
        "textColor": {
            "red": style.text_color.red,
            "green": style.text_color.green,
            "blue": style.text_color.blue,
        },
        "backgroundColor": {
            "red": style.background_color.red,
            "green": style.background_color.green,
            "blue": style.background_color.blue,
        },
    }


def _attach_style_info(entity: Dict[str, Any], ocr_document: Any) -> None:
    """Set `entity["styleInfo"]` from the OCR tokens inside the entity's bounding box."""
    page_ref = entity["pageAnchor"]["pageRefs"][0]
    vertices = page_ref["boundingPoly"]["normalizedVertices"]
    # A coordinate equal to 0 can be missing from the JSON (proto3 JSON leaves
    # out default values), so a missing key is read as 0.0.
    x_values = [vertex.get("x", 0.0) for vertex in vertices]
    y_values = [vertex.get("y", 0.0) for vertex in vertices]
    styles = get_token_data(
        ocr_document,
        min(x_values),
        max(x_values),
        min(y_values),
        max(y_values),
        _page_index(page_ref),
    )
    entity["styleInfo"] = [_style_to_dict(style) for style in styles]


if __name__ == "__main__":
    # `file_names` comes from utilities.py; the values of its second result are
    # used below as blob paths inside the input bucket.
    file_name_list, file_dicts = file_names(gcs_input_path)
    storage_client = storage.Client()
    source_bucket = storage_client.bucket(gcs_input_path.split("/")[2])
    output_bucket_name = gcs_output_path.split("/")[2]
    output_prefix = "/".join(gcs_output_path.split("/")[3:])

    for blob_path in file_dicts.values():
        try:
            base_name = blob_path.split("/")[-1]
            print(
                "Processing File >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ",
                base_name,
            )
            js_json = source_bucket.blob(blob_path).download_as_string().decode("utf-8")
            # Re-run OCR on the page images so tokens carry style information.
            merged_pdf = create_pdf_bytes(js_json)
            js = json.loads(js_json)
            res = process_document_sample(
                project_id=project_id,
                location=location,
                processor_id=processor_id,
                file_path=merged_pdf,
                processor_version_id=processor_version_id,
                mime_type="application/pdf",
            )
            for entity in js["entities"]:
                # Parent entities are annotated through their child properties;
                # entities without children are annotated directly.
                targets = entity.get("properties") or [entity]
                for target in targets:
                    try:
                        _attach_style_info(target, res)
                    except Exception as e:
                        # Best effort: one bad entity must not stop the file.
                        print("Error:", e)
            print("Processed Successfully : ", base_name)
            store_document_as_json(
                json.dumps(js),
                output_bucket_name,
                output_prefix + base_name,
            )
        except Exception as e:
            # Best effort: report the failure and continue with the next file.
            print("Main Try Error:", e)

### Output

The updated JSONs, containing style information for each entity, will be saved to the specified output folder.

#### Before Tooling JSON file 
<img src="./images/before_tooling.png" width=400 height=200 ></img>
#### After Tooling JSON file 
<img src="./images/after_tooling.png" width=400 height=200 ></img>

In [None]:
"hi"