# Line Items Improver (Post-Processing) User Guide

## Disclaimer

This tool is not supported by the Google engineering team or product team. It is provided and supported on a best-effort basis by the DocAI Incubator Team. No guarantees of performance are implied.

## Objective

This tool takes parsed JSON files (Document AI prediction results) and a list of the `line_item` child entities that occur only once per line item in the schema. It uses these entities to regroup the child entities under the correct parent `line_item`.

## Prerequisites

* Vertex AI Notebook
* Parsed JSON files in a GCS folder
* List of child entities that occur at most once per line item (optional once or required once in the schema)
* Output GCS folder for the updated JSON files
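
To shortlist candidates for the list of once-per-line-item child entities, a small sketch like the following could count how often each child type appears inside every `line_item` and keep the types that never appear more than once. It assumes the standard Document AI JSON keys (`entities`, `properties`, `type`); `find_once_per_line_item_types` is a hypothetical helper, not part of this tool. Because mis-grouped predictions can inflate the counts, run it on a correctly labeled document and confirm the result against the processor schema.

```python
from collections import Counter
from typing import Dict, List


def find_once_per_line_item_types(document_json: Dict) -> List[str]:
    """Return the child types that never occur more than once in any line_item.

    Hypothetical helper; assumes Document AI JSON with top-level "entities",
    line items typed "line_item", and their children under "properties".
    """
    max_per_item: Counter = Counter()
    for entity in document_json.get("entities", []):
        if entity.get("type") != "line_item":
            continue
        # How many times each child type occurs inside this one line item
        counts = Counter(child.get("type") for child in entity.get("properties", []))
        for child_type, n in counts.items():
            max_per_item[child_type] = max(max_per_item[child_type], n)
    return sorted(t for t, n in max_per_item.items() if n == 1)


sample = {
    "entities": [
        {"type": "line_item", "properties": [
            {"type": "line_item/quantity"},
            {"type": "line_item/description"},
            {"type": "line_item/description"},
        ]},
        {"type": "line_item", "properties": [
            {"type": "line_item/quantity"},
            {"type": "line_item/amount"},
        ]},
        {"type": "invoice_id"},
    ]
}
print(find_once_per_line_item_types(sample))
# ['line_item/amount', 'line_item/quantity']
```

Here `line_item/description` is excluded because one line item contains two descriptions.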

## Line item improver logic

<img src="./images/image_1.png" width=800 height=400></img>

<img src="./images/image_2.png" width=800 height=400></img>

## Choosing the anchor entity and grouping

1. In the example above, the majority count is 5, shared by `line_item/quantity` and `line_item/amount`. Of these entity types, the one with the minimum `y` is chosen as the anchor entity.
2. If there is no majority count, the maximum count is taken as the majority count, and its entity type is used as the anchor entity.
3. The region from the minimum `y` of one occurrence of the anchor entity to the minimum `y` of its next occurrence is treated as a line item region.
4. All `line_item/` child entities that fall inside this region are grouped into a single line item.
5. This is repeated for every occurrence of the anchor entity.
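
The steps above can be sketched on plain `(type, min_y)` pairs for a single page. This is a simplified illustration, not the notebook's implementation: `choose_anchor` and `group_by_anchor` are hypothetical names, step 1 is read as "among the types with the majority count, take the one that appears highest on the page", and the last region is assumed to extend to the bottom of the page. Children above the first anchor are left ungrouped here. Note that the notebook's `multi_page_entites` instead picks, among the majority-count types, the one with the widest vertical spread.

```python
from collections import Counter
from typing import Dict, List, Tuple

# (child entity type, minimum y of its bounding box on the page)
Child = Tuple[str, float]


def choose_anchor(children: List[Child], unique_types: List[str]) -> str:
    """Steps 1-2: pick the anchor type among the once-per-line-item types."""
    counts = Counter(t for t, _ in children if t in unique_types)
    # How many types share each occurrence count, e.g. {5: 2, 4: 1}
    count_freq = Counter(counts.values())
    top_freq = max(count_freq.values())
    # Majority count = the count shared by the most types; on a tie
    # (no majority), fall back to the largest count.
    majority_count = max(c for c, f in count_freq.items() if f == top_freq)
    candidates = [t for t, c in counts.items() if c == majority_count]
    # Among the candidates, choose the type that appears highest on the page
    return min(candidates, key=lambda t: min(y for ct, y in children if ct == t))


def group_by_anchor(children: List[Child], anchor: str) -> Dict[int, List[Child]]:
    """Steps 3-5: each anchor's min y opens a region ending at the next anchor."""
    starts = sorted(y for t, y in children if t == anchor)
    ends = starts[1:] + [float("inf")]  # assumed: last region runs to page bottom
    groups: Dict[int, List[Child]] = {}
    for child in children:
        for line_no, (start, end) in enumerate(zip(starts, ends), 1):
            if start <= child[1] < end:
                groups.setdefault(line_no, []).append(child)
                break
    return groups


children = [
    ("line_item/description", 0.31), ("line_item/quantity", 0.31),
    ("line_item/amount", 0.32), ("line_item/description", 0.34),
    ("line_item/description", 0.40), ("line_item/quantity", 0.40),
    ("line_item/amount", 0.41),
]
anchor = choose_anchor(children, ["line_item/quantity", "line_item/amount"])
groups = group_by_anchor(children, anchor)
print(anchor, {line_no: len(items) for line_no, items in groups.items()})
# line_item/quantity {1: 4, 2: 3}
```

The wrapped description at `y = 0.34` falls inside the first region, so it joins the first line item rather than starting a new one.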

## Step-by-step procedure

## 1. Input Details

In [None]:
# input
gcs_input_path = "gs://<bucket_name>/<input_folder_name>/"  # folder with the parsed JSON files; the trailing '/' is mandatory
project_id = "<your_project_id>"  # project ID
gcs_output_path = "gs://<bucket_name>/<output_folder_name>/"  # folder for the updated JSON files; the trailing '/' is mandatory

# child entity types that occur at most once per line item (optional once or
# required once in the schema); must match the child `type` values in the JSON
unique_entities = ["product_code", "unit_price", "quantity"]
desc_merge_update = "Yes"  # "Yes" to merge multiple descriptions within a line item into one, otherwise "No"
line_item_across_pages = "Yes"  # "Yes" to regroup line items that are split across pages, otherwise "No"
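
The trailing `/` matters because the notebook builds each output object name by appending the file name directly to the folder prefix (`"/".join(gcs_output_path.split("/")[3:]) + filename`): without it, `gs://my-bucket/output` plus `invoice_1.json` becomes the object `outputinvoice_1.json`. As a sketch only, a hypothetical helper such as `split_gcs_folder` (not part of `utilities`) could validate both paths before the run:

```python
from typing import Tuple


def split_gcs_folder(gcs_path: str) -> Tuple[str, str]:
    """Split "gs://bucket/folder/" into ("bucket", "folder/").

    Hypothetical helper: rejects paths that are not gs:// URIs or that lack
    the trailing '/' the notebook relies on.
    """
    if not gcs_path.startswith("gs://"):
        raise ValueError(f"not a gs:// path: {gcs_path}")
    if not gcs_path.endswith("/"):
        raise ValueError(f"folder path must end with '/': {gcs_path}")
    bucket, _, prefix = gcs_path[len("gs://"):].partition("/")
    return bucket, prefix


bucket, prefix = split_gcs_folder("gs://my-bucket/output/")
print(bucket, prefix + "invoice_1.json")  # my-bucket output/invoice_1.json
```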

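With `desc_merge_update = "Yes"`, the tool merges all `line_item/description` children of a line item into one entity, ordering the fragments by the end index of their first text segment and joining their text. The sketch below shows that idea on plain JSON dictionaries. The function name is hypothetical, it assumes the camelCase keys used in Document AI JSON (where int64 indexes are serialized as strings), and it leaves out the bounding-box merge that the notebook also performs.

```python
from typing import Dict, List

DESCRIPTION = "line_item/description"


def merge_descriptions(children: List[Dict]) -> List[Dict]:
    """Collapse all description children of one line item into a single child."""
    descriptions = [c for c in children if c.get("type") == DESCRIPTION]
    if len(descriptions) < 2:
        return children
    # Order fragments by the end index of their first text segment; int64
    # values such as "endIndex" are serialized as strings in the JSON.
    descriptions.sort(
        key=lambda c: int(c["textAnchor"]["textSegments"][0].get("endIndex", 0))
    )
    merged = {
        "type": DESCRIPTION,
        "mentionText": " ".join(c.get("mentionText", "") for c in descriptions),
        "textAnchor": {
            "textSegments": [
                seg for c in descriptions for seg in c["textAnchor"]["textSegments"]
            ]
        },
    }
    others = [c for c in children if c.get("type") != DESCRIPTION]
    return others + [merged]


children = [
    {"type": DESCRIPTION, "mentionText": "steel",
     "textAnchor": {"textSegments": [{"startIndex": "20", "endIndex": "25"}]}},
    {"type": "line_item/quantity", "mentionText": "4",
     "textAnchor": {"textSegments": [{"startIndex": "8", "endIndex": "9"}]}},
    {"type": DESCRIPTION, "mentionText": "Hex bolt",
     "textAnchor": {"textSegments": [{"startIndex": "10", "endIndex": "18"}]}},
]
print([c["mentionText"] for c in merge_descriptions(children)])
# ['4', 'Hex bolt steel']
```
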
## 2. Run the Code

Run the source code below to generate the updated JSON files.

### Source Code

In [None]:
# Download incubator-tools utilities module to present-working-directory
!wget https://raw.githubusercontent.com/GoogleCloudPlatform/document-ai-samples/main/incubator-tools/best-practices/utilities/utilities.py

In [None]:
!pip install tqdm google-cloud-storage google-cloud-documentai -q

In [None]:
import copy
import json
from collections import Counter, defaultdict
from typing import Any, Dict, Iterable, List, Optional, Tuple

from google.cloud import documentai_v1beta3 as documentai
from google.cloud import storage
from tqdm import tqdm

import utilities

output_bucket_name = gcs_output_path.split("/")[2]  # Extract the output bucket name

storage_client = storage.Client()

# Get the list of file names and their paths from the input bucket
file_names_list, file_dict = utilities.file_names(gcs_input_path)


def get_page_wise_entities(document) -> Dict[int, List]:
    """
    Organizes entities in a document by the page they appear on.

    Args:
        document: The document object containing entities with page information.

    Returns:
        A dictionary where the keys are zero-based page indexes and the values
        are lists of entities found on those pages. Entities without a page
        reference are skipped.
    """
    entities_page = {}
    for entity in document.entities:
        if not entity.page_anchor.page_refs:
            continue
        # `page` defaults to 0 when it is not set in the JSON, so every key is an int
        page = entity.page_anchor.page_refs[0].page
        entities_page.setdefault(page, []).append(entity)
    return entities_page


def merge_entities(document):
    """
    Merges entities in a document, especially handling line items across multiple pages.

    Args:
        document: The document object containing entities.

    Returns:
        The document object with merged entities.
    """
    entities_page_wise = get_page_wise_entities(document)
    line_entities_classified_pagewise = []

    for page, entities in entities_page_wise.items():
        line_item_count = line_item_check(entities, unique_entities)

        if line_item_count == 1:
            line_entities_temp = single_line_item_merge(entities, page)
            line_entities_classified_pagewise.append(line_entities_temp)
        elif line_item_count > 1:
            line_entities_temp, _ = multi_page_entites(entities, page)
            line_entities_classified_pagewise.extend(line_entities_temp)
        else:
            # None of the unique child types were found on this page: keep the
            # page's line items unchanged instead of dropping them below
            line_entities_classified_pagewise.extend(
                entity for entity in entities if entity.type == "line_item"
            )

    final_entities = [
        entity for entity in document.entities if entity.type != "line_item"
    ]
    final_entities.extend(line_entities_classified_pagewise)
    document.entities = final_entities

    return document


def line_item_check(entities: List, unique_entities: List[str]) -> int:
    """
    Checks line items within a list of entities based on unique entity types.

    Args:
        entities: A list of entities to be checked.
        unique_entities: A list of unique entity types to be considered for checking.

    Returns:
        An integer representing the type of line item found.
        1 for single line item, 2 for multiple, and 0 for none.
    """
    entity_types = [
        subentity.type
        for entity in entities
        if hasattr(entity, "properties")
        for subentity in entity.properties
    ]

    entity_counts = {unique: entity_types.count(unique) for unique in unique_entities}

    # 1: at least one unique child type occurs exactly once
    # 2: otherwise, at least one unique child type occurs more than once
    # 0: no unique child type occurs at all
    if any(count == 1 for count in entity_counts.values()):
        return 1
    if any(count > 1 for count in entity_counts.values()):
        return 2
    return 0


def single_line_item_merge(entities: List, page: str) -> Dict:
    """
    Merges single line item entities into a unified line item.

    Args:
        entities: A list of entities to be merged.
        page: The page number as a string where these entities are found.

    Returns:
        A dictionary representing the merged line item.
    """
    line_item_sub_entities = [
        subentity
        for entity in entities
        if entity.type == "line_item"
        for subentity in entity.properties
    ]

    text_anchors = [
        item.text_anchor.text_segments[0] for item in line_item_sub_entities
    ]
    normalized_vertices = [
        vertex
        for item in line_item_sub_entities
        for vertex in item.page_anchor.page_refs[0].bounding_poly.normalized_vertices
    ]

    min_x = min(normalized_vertices, key=lambda d: d.x).x
    max_x = max(normalized_vertices, key=lambda d: d.x).x
    min_y = min(normalized_vertices, key=lambda d: d.y).y
    max_y = max(normalized_vertices, key=lambda d: d.y).y
    # Corners in clockwise order: top-left, top-right, bottom-right, bottom-left
    vertices_final = [
        {"x": min_x, "y": min_y},
        {"x": max_x, "y": min_y},
        {"x": max_x, "y": max_y},
        {"x": min_x, "y": max_y},
    ]

    line_item = {
        "mention_text": " ".join(item.mention_text for item in line_item_sub_entities),
        "page_anchor": {
            "page_refs": [
                {"bounding_poly": {"normalized_vertices": vertices_final}, "page": page}
            ]
        },
        "properties": line_item_sub_entities,
        "text_anchor": {
            "text_segments": sorted(text_anchors, key=lambda x: int(x.end_index))
        },
        "type": "line_item",
    }

    return line_item


def group_line_items(document, schema: Optional[Any] = None) -> Any:
    """
    Groups line items in a document, potentially across multiple pages, based on a schema.

    Args:
        document: The document object containing entities to be grouped.
        schema: (Optional) A schema to guide the grouping. If None, a temporary schema is generated.

    Returns:
        The modified document with grouped line items.
    """
    if schema is None:
        schema = generate_temp_schema(document)

    line_items_by_page = sort_line_items_by_page(document)
    groups_across = find_line_item_groups_across_pages(line_items_by_page)
    schema_across = get_schema_across_groups(groups_across)
    entity_spread = calculate_entity_spread(schema_across, schema)

    if entity_spread:
        document = move_entities_across_pages(entity_spread, groups_across, document)

    return document


def generate_temp_schema(document) -> dict:
    """
    Generates a temporary schema for line items in a document based on the most common types.

    Args:
        document: The document object containing entities.

    Returns:
        A dictionary representing the majority schema of line item types.
    """
    # For each child type, count how many line items contain it once, twice, ...
    # e.g. {"line_item/quantity": Counter({1: 12, 2: 1})}
    count_frequencies = defaultdict(Counter)
    for entity in document.entities:
        if entity.type != "line_item":
            continue
        child_counts = Counter(child.type for child in entity.properties)
        for child_type, count in child_counts.items():
            count_frequencies[child_type][count] += 1
    # Majority schema: each child type mapped to its most common count per line item
    majority_schema = {
        child_type: frequencies.most_common(1)[0][0]
        for child_type, frequencies in count_frequencies.items()
    }
    return majority_schema


def sort_line_items_by_page(document):
    """
    Sorts line items in a document by their page and y-coordinate.

    Args:
        document: The document object containing entities.

    Returns:
        A mapping each page to its first and last line items based on y-coordinates.
    """
    line_items_by_page = defaultdict(lambda: {"first": None, "last": None})
    for entity in document.entities:
        if entity.type == "line_item":
            page = str(
                entity.page_anchor.page_refs[0].page
                if entity.page_anchor.page_refs
                else 0
            )
            y_coords = [
                vertex.y
                for vertex in entity.page_anchor.page_refs[
                    0
                ].bounding_poly.normalized_vertices
            ]
            update_page_sort(line_items_by_page[page], "last", max(y_coords), entity)
            update_page_sort(line_items_by_page[page], "first", min(y_coords), entity)
    return line_items_by_page


def update_page_sort(
    page_dict: dict, position: str, y_value: float, entity: Any
) -> None:
    """
    Updates a page dictionary with a new line item based on y-coordinate and position.

    Args:
        page_dict: The dictionary representing a page's line items.
        position: The position to update ('first' or 'last').
        y_value: The y-coordinate of the line item.
        entity: The line item entity.
    """
    if page_dict[position] is None or compare_y(page_dict[position], y_value, position):
        page_dict[position] = {"y_value": y_value, "entity": entity}


def compare_y(current_dict: dict, new_y: float, position: str) -> bool:
    """
    Compares y-coordinates to determine if a new line item should update the current one.

    Args:
        current_dict: The current line item information.
        new_y: The new y-coordinate.
        position: The position being compared ('first' or 'last').

    Returns:
        A boolean indicating if the new line item should replace the current one.
    """
    return (
        (new_y >= current_dict["y_value"])
        if position == "last"
        else (new_y <= current_dict["y_value"])
    )


def find_line_item_groups_across_pages(sorted_items: dict) -> dict:
    """
    Identifies groups of line items across pages in a sorted items structure.

    Args:
        sorted_items: A dictionary of sorted line items by page.

    Returns:
        A dictionary mapping pages to their associated group of line items across pages.
    """
    groups_across = {}
    for page, items in sorted_items.items():
        next_page = str(int(page) + 1)
        if next_page in sorted_items:
            groups_across[next_page] = [
                items["last"]["entity"],
                sorted_items[next_page]["first"]["entity"],
            ]
    return groups_across


def get_schema_across_groups(
    groups_across: Dict[str, Any]
) -> Dict[str, Dict[int, Counter]]:
    """
    Generates a schema across groups of line items, counting property types for each group.

    Args:
        groups_across: A dictionary mapping group identifiers to lists of line item entities.

    Returns:
        A dictionary representing the schema across groups, with counts of property types.
    """
    schema_across = {}
    for group, matches in groups_across.items():
        for i, match in enumerate(matches):
            for property in match.properties:
                property_type = property.type
                schema_across.setdefault(group, {}).setdefault(i, Counter())[
                    property_type
                ] += 1
    return schema_across


def calculate_entity_spread(
    schema_across: Dict[str, Dict[int, Counter]], schema: Dict[str, int]
) -> Dict[str, Dict[str, int]]:
    """
    Calculates the spread of entities across the schema, identifying missing and excess entities.

    Args:
        schema_across: A dictionary representing the schema across groups.
        schema: A dictionary representing the majority schema of line item types.

    Returns:
        A dictionary indicating the spread of entities in terms of missing and excess counts.
    """
    entity_spread = {}
    for group, schemas in schema_across.items():
        missing_entities = {key: schema[key] - schemas[0].get(key, 0) for key in schema}
        for entity_type, count in missing_entities.items():
            if count > 0 and entity_type in schemas[1]:
                excess_count = schemas[1][entity_type] - schema[entity_type]
                if excess_count > 0 or len(schemas[1]) < (len(schema) / 2):
                    entity_spread.setdefault(group, {})[entity_type] = excess_count
    return entity_spread


def move_entities_across_pages(
    entity_spread: Dict[str, Dict[str, int]],
    groups_across: Dict[str, Any],
    document: Any,
) -> Any:
    """
    Moves entities across pages in a document based on the entity spread.

    Args:
        entity_spread: A dictionary indicating the spread of entities.
        groups_across: A dictionary of line item groups across pages.
        document: The document object containing entities to be moved.

    Returns:
        The modified document with entities moved across pages.
    """
    for group in entity_spread:
        page_1 = get_page_number(groups_across[group][0])
        page_2 = get_page_number(groups_across[group][1])
        if page_1 < page_2:
            entities_to_move = get_entities_to_shuffle(
                group, 1, entity_spread, groups_across
            )
            temp_group = move_entities(group, 1, 0, entities_to_move, groups_across)
        else:
            entities_to_move = get_entities_to_shuffle(
                group, 0, entity_spread, groups_across
            )
            temp_group = move_entities(group, 0, 1, entities_to_move, groups_across)
        # update_json_entities swaps the original line items for the updated
        # ones in place, so assign the list back instead of extending the
        # document with a second copy of every entity
        entities = list(document.entities)
        update_json_entities(entities, groups_across[group], temp_group)
        document.entities = entities
    return document


def get_page_number(entity: Any) -> int:
    """
    Retrieves the page number from a given entity.

    Args:
        entity: The entity from which the page number is to be retrieved.

    Returns:
        The zero-based page index. Returns 0 if no page reference is found.
    """
    return entity.page_anchor.page_refs[0].page if entity.page_anchor.page_refs else 0


def get_entities_to_shuffle(
    group: str,
    index: int,
    entity_spread: Dict[str, Dict[str, int]],
    groups_across: Dict[str, Any],
) -> list:
    """
    Determines which entities need to be shuffled within a group.

    Args:
        group: The identifier of the group where entities are to be shuffled.
        index: The index within the group to consider for shuffling.
        entity_spread: A dictionary indicating the spread of entities.
        groups_across: A dictionary of line item groups across pages.

    Returns:
        A list of entities that need to be shuffled.
    """
    entities = []
    for entity_type, count in entity_spread[group].items():
        entities.extend(
            get_sorted_entities_by_y(group, index, entity_type, count, groups_across)
        )
    return entities


def get_sorted_entities_by_y(
    group: str,
    index: int,
    entity_type: str,
    count: int,
    groups_across: Dict[str, Any],
) -> List[Any]:
    """
    Retrieves and sorts entities of a certain type within a group based on their y-coordinate.

    Args:
        group: The identifier of the group to search within.
        index: The index within the group to consider.
        entity_type: The type of entity to filter and sort.
        count: The number of entities to retrieve.
        groups_across: A dictionary of line item groups across pages.

    Returns:
        A list of entities sorted by their y-coordinate.
    """
    entities_by_y = {}
    for property in groups_across[group][index].properties:
        if property.type == entity_type:
            y_value = min(
                vertex.y
                for vertex in property.page_anchor.page_refs[
                    0
                ].bounding_poly.normalized_vertices
            )
            entities_by_y.setdefault(y_value, []).append(property)
    return [
        entity for y in sorted(entities_by_y)[:count] for entity in entities_by_y[y]
    ]


def move_entities(
    group: str,
    from_index: int,
    to_index: int,
    entities_to_move: List[Any],
    groups_across: Dict[str, Any],
) -> Iterable:
    """
    Moves entities from one index to another within a group.

    Args:
        group: The identifier of the group where entities are to be moved.
        from_index: The index from which entities are to be moved.
        to_index: The index to which entities are to be moved.
        entities_to_move: A list of entities to move.
        groups_across: A dictionary of groups with their associated entities.

    Returns:
        An iterable containing the updated group entities.
    """
    temp_group = copy.deepcopy(groups_across[group])
    remove_entities_from_group(entities_to_move, temp_group[from_index].properties)
    temp_group[to_index].properties.extend(entities_to_move)
    update_page_text(temp_group[to_index])
    return temp_group


def remove_entities_from_group(
    entities_to_remove: List[Any], properties: List[Any]
) -> None:
    """
    Removes specified entities from a list of properties.

    Args:
        entities_to_remove: A list of entities to be removed.
        properties: A list of properties/entities from which to remove the specified entities.
    """
    for entity in entities_to_remove:
        properties.remove(entity)


def update_page_text(entity: Any) -> None:
    """
    Updates the text of an entity based on its properties.

    Args:
        entity: The entity whose text is to be updated.
    """
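    segments = [
        segment
        for property in entity.properties
        for segment in property.text_anchor.text_segments
    ]
    entity.text_anchor.text_segments = sorted(
        segments, key=lambda x: int(x.start_index)
    )
    # TextSegment only holds start/end indexes, so rebuild the mention text from
    # the children's own mention_text, ordered by where each child starts
    ordered = sorted(
        entity.properties,
        key=lambda p: min(
            (int(s.start_index) for s in p.text_anchor.text_segments), default=0
        ),
    )
    entity.mention_text = " ".join(p.mention_text for p in ordered)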


def update_json_entities(
    current_entities: List[Any], original_group: Iterable, updated_group: Iterable
) -> List[Any]:
    """
    Updates a list of entities by removing the original group and adding the updated group.

    Args:
        current_entities: The current list of entities.
        original_group: The original group of entities to be removed.
        updated_group: The updated group of entities to be added.

    Returns:
        The updated list of entities.
    """
    for entity in original_group:
        current_entities.remove(entity)
    current_entities.extend(updated_group)
    return current_entities


def extract_entity_types_and_subentities(
    entities_pagewise: List[Any],
) -> Tuple[List[str], List[Any]]:
    """
    Extracts entity types and subentities from a list of entities.

    Args:
        entities_pagewise: A list of entities, potentially including line items with subentities.

    Returns:
        A tuple containing a list of entity types and a list of line item subentities.
    """
    entity_types = []
    line_item_sub_entities = []
    for entity in entities_pagewise:
        if hasattr(entity, "properties"):
            if entity.type == "line_item":
                for subentity in entity.properties:
                    entity_types.append(subentity.type)
                    line_item_sub_entities.append(subentity)
        else:
            entity_types.append(entity.type)
    return entity_types, line_item_sub_entities


def count_unique_entities(entity_types: List[str]) -> Dict[str, int]:
    """
    Counts the occurrences of each unique entity type in a given list and returns a dictionary
    containing only those entities which appear more than once.

    Args:
    entity_types (List[str]): A list of entity types.

    Returns:
    Dict[str, int]: A dictionary where keys are entity types that occur more than once, and
                    values are the counts of these occurrences.
    """
    entity_type_counts = Counter(entity_types)
    return {entity: count for entity, count in entity_type_counts.items() if count > 1}


def get_entity_types_with_max_count(line_items_multi_dict: Dict[str, int]) -> List[str]:
    """
    Identifies the entity types with the maximum occurrence count in a dictionary.

    Args:
    line_items_multi_dict (Dict[str, int]): A dictionary where keys are entity types and
                                            values are their occurrence counts.

    Returns:
    List[str]: A list of entity types that have the maximum count in the given dictionary.
    """
    value_counts = Counter(line_items_multi_dict.values())
    max_count = max(value_counts.values())
    return [
        entity_type
        for entity_type, count in line_items_multi_dict.items()
        if value_counts[count] == max_count
    ]


def find_optimal_region(
    dict_unique_ent: Dict[str, List[object]]
) -> Tuple[Dict[str, List[float]], Dict[str, float]]:
    """
    Calculates the optimal region for different entity types based on their bounding box coordinates.

    Args:
    dict_unique_ent (Dict[str, List[object]]): A dictionary where keys are entity types and
                                               values are lists of sub-entity objects.

    Returns:
    Tuple[Dict[str, List[float]], Dict[str, float]]: A tuple containing two dictionaries.
                                                     The first dictionary maps entity types to sorted y-coordinates.
                                                     The second dictionary maps entity types to the maximum difference
                                                     in y-coordinates.
    """
    region_line_items_x = {}
    region_line_items_y = {}
    opt_region = {}
    for ent_type, sub_entities in dict_unique_ent.items():
        min_x_1 = [
            min(
                vertex.x
                for vertex in item.page_anchor.page_refs[
                    0
                ].bounding_poly.normalized_vertices
            )
            for item in sub_entities
        ]
        min_y_1 = [
            min(
                vertex.y
                for vertex in item.page_anchor.page_refs[
                    0
                ].bounding_poly.normalized_vertices
            )
            for item in sub_entities
        ]
        region_line_items_x[ent_type] = min_x_1
        region_line_items_y[ent_type] = min_y_1
        opt_region[ent_type] = max(min_y_1) - min(min_y_1)
    return {k: sorted(v) for k, v in region_line_items_y.items()}, opt_region


def classify_line_items(
    sub_entities: List[object], regions_line_y: Dict[str, List[float]]
) -> Tuple[Dict[str, List[object]], List[object]]:
    """
    Classifies sub-entities into different line items based on their y-coordinate regions.

    Args:
    sub_entities (List[object]): A list of sub-entity objects.
    regions_line_y (Dict[str, List[float]]): A dictionary mapping line numbers to y-coordinate regions.

    Returns:
    Tuple[Dict[str, List[object]], List[object]]: A tuple containing a dictionary of classified line items and
                                                  a list of categorized sub-entities.
    """
    line_item_dict_final = {}
    sub_entities_categorized = []

    for subentity in sub_entities:
        y_ent = [
            vertex.y
            for vertex in subentity.page_anchor.page_refs[
                0
            ].bounding_poly.normalized_vertices
        ]
        for line_no, region in regions_line_y.items():
            # Check if region is a tuple (meaning it has a start and end)
            if isinstance(region, tuple):
                # If y_ent is within the region bounds
                if region[0] <= min(y_ent) < region[1]:
                    line_item_dict_final.setdefault(line_no, []).append(subentity)
                    sub_entities_categorized.append(subentity)
            else:
                # If y_ent is greater than or equal to the single value region
                if min(y_ent) >= region[0]:  # Assuming region is a tuple with one value
                    line_item_dict_final.setdefault(line_no, []).append(subentity)
                    sub_entities_categorized.append(subentity)
    return line_item_dict_final, sub_entities_categorized


def create_lineitem(sub_entities: List[object], page: int) -> Dict[str, object]:
    """
    Creates a line item from a list of sub-entities.

    Args:
    sub_entities (List[object]): A list of sub-entity objects.
    page (int): The page number where the line item is located.

    Returns:
    Dict[str, object]: A dictionary representing a line item with various properties.
    """
    mention_text = " ".join(e.mention_text for e in sub_entities)
    bounding_vertices = [
        vertex
        for e in sub_entities
        for vertex in e.page_anchor.page_refs[0].bounding_poly.normalized_vertices
    ]
    text_segments = [
        segment for e in sub_entities for segment in e.text_anchor.text_segments
    ]

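    x_coords = [vertex.x for vertex in bounding_vertices]
    y_coords = [vertex.y for vertex in bounding_vertices]
    # Four corners of the enclosing box, clockwise from the top-left
    bounding_poly = [
        {"x": min(x_coords), "y": min(y_coords)},
        {"x": max(x_coords), "y": min(y_coords)},
        {"x": max(x_coords), "y": max(y_coords)},
        {"x": min(x_coords), "y": max(y_coords)},
    ]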

    line_item = {
        "mention_text": mention_text,
        "page_anchor": {
            "page_refs": [
                {"bounding_poly": {"normalized_vertices": bounding_poly}, "page": page}
            ]
        },
        "properties": sub_entities,
        "text_anchor": {"text_segments": text_segments},
        "type": "line_item",
    }
    return line_item


def multi_page_entites(
    entities_pagewise: List[object], page: int
) -> Tuple[List[Dict], str]:
    """
    Processes entities page-wise to classify and create line items.

    Args:
    entities_pagewise (List[object]): A list of entity objects per page.
    page (int): The page number of the entities.

    Returns:
    Tuple[List[Dict], str]: A tuple containing a list of classified line items and the entity type
                             with the maximum boundary.
    """
    entity_types, line_item_sub_entities = extract_entity_types_and_subentities(
        entities_pagewise
    )

    line_items_multi_dict = count_unique_entities(entity_types)
    entity_types_keys = get_entity_types_with_max_count(line_items_multi_dict)

    dict_unique_ent = {key: [] for key in entity_types_keys}
    for entity in line_item_sub_entities:
        entity_type = entity.type
        if entity_type in dict_unique_ent:
            dict_unique_ent[entity_type].append(entity)

    sorted_region_line_y, opt_region = find_optimal_region(dict_unique_ent)

    # Anchor: among the majority-count types, the one with the widest vertical spread
    considered_boundary_ent = max(opt_region, key=opt_region.get)
    anchor_ys = sorted_region_line_y[considered_boundary_ent]
    # Each region runs from one anchor's min y to the next anchor's min y; the
    # last region is left open-ended so the final line item is not dropped
    region_ends = anchor_ys[1:] + [float("inf")]
    regions_line_y_final = {
        i: (start, end)
        for i, (start, end) in enumerate(zip(anchor_ys, region_ends), 1)
    }

    line_item_dict_final, sub_entities_categorized = classify_line_items(
        line_item_sub_entities, regions_line_y_final
    )

    line_items_classified = [
        create_lineitem(sub_entities, page)
        for sub_entities in line_item_dict_final.values()
    ]

    return line_items_classified, considered_boundary_ent


def merge_line_item_descriptions(document: object) -> object:
    """
    Merges multiple line_item/description children of each line item into one.

    Args:
    document (object): The document containing entities to be processed.

    Returns:
    object: The updated document with merged entity descriptions.
    """

    def calculate_bounding_box(vertices):
        # Smallest box around all vertices, clockwise from the top-left corner
        xs = [vertex.x for vertex in vertices]
        ys = [vertex.y for vertex in vertices]
        return [
            documentai.NormalizedVertex(x=min(xs), y=min(ys)),
            documentai.NormalizedVertex(x=max(xs), y=min(ys)),
            documentai.NormalizedVertex(x=max(xs), y=max(ys)),
            documentai.NormalizedVertex(x=min(xs), y=max(ys)),
        ]

    for entity in document.entities:
        if entity.type != "line_item":
            continue
        ent_desc = [
            prop for prop in entity.properties if prop.type == "line_item/description"
        ]
        if len(ent_desc) < 2:
            continue
        # Order the description fragments by the end of their first text segment
        ent_desc.sort(key=lambda prop: int(prop.text_anchor.text_segments[0].end_index))

        desc_merge = documentai.Document.Entity(
            type_="line_item/description",
            mention_text=" ".join(prop.mention_text for prop in ent_desc),
        )
        desc_merge.text_anchor.text_segments.extend(
            [prop.text_anchor.text_segments[0] for prop in ent_desc]
        )
        vertices = [
            vertex
            for prop in ent_desc
            for vertex in prop.page_anchor.page_refs[0].bounding_poly.normalized_vertices
        ]
        if vertices:
            desc_merge.page_anchor.page_refs.append(
                documentai.Document.PageAnchor.PageRef(
                    page=ent_desc[0].page_anchor.page_refs[0].page,
                    bounding_poly=documentai.BoundingPoly(
                        normalized_vertices=calculate_bounding_box(vertices)
                    ),
                )
            )
        # Replace the description fragments with the single merged description
        entity.properties = [
            prop for prop in entity.properties if prop.type != "line_item/description"
        ] + [desc_merge]

    return document


input_bucket = storage_client.bucket(gcs_input_path.split("/")[2])
output_prefix = "/".join(gcs_output_path.split("/")[3:])

for filename, filepath in tqdm(file_dict.items(), desc="Progress"):
    if not filepath.endswith(".json"):
        continue
    print(filename)
    # Download the parsed JSON and load it as a Document AI Document
    blob = input_bucket.blob(filepath)
    document = documentai.Document.from_json(
        blob.download_as_bytes().decode("utf-8"), ignore_unknown_fields=True
    )
    json_dict_updated = merge_entities(document)

    if line_item_across_pages == "Yes":
        json_dict_updated = group_line_items(json_dict_updated)

    if desc_merge_update == "Yes":
        json_dict_updated = merge_line_item_descriptions(json_dict_updated)

    # Full path within the output bucket where the file should be saved
    output_path_within_bucket = output_prefix + filename
    temp_dict = json.loads(documentai.Document.to_json(json_dict_updated))

    # Save the updated JSON to the output bucket
    updated_document_json = json.dumps(temp_dict, ensure_ascii=False)
    utilities.store_document_as_json(
        updated_document_json, output_bucket_name, output_path_within_bucket
    )

## 3. Output

The tool moves each child entity of `line_item` under the correct parent `line_item`.

### Before post-processing

The parsed JSON contains multiple line items whose child entities are not grouped correctly.

<img src="./images/line_item_improver_pre_sample.png" width=800 height=400></img>

### After post-processing

The line items are grouped correctly.

<img src="./images/line_item_improver_pre_sample.png" width=800 height=400></img>
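
To spot-check results beyond the screenshots, a small sketch like the following could summarize the child types of each line item before and after post-processing. It assumes the standard Document AI JSON keys; `summarize_line_items` and `flag_suspect_line_items` are hypothetical helpers, and a flagged line item only means that a type you expect once per line item appears more than once in it.

```python
from collections import Counter
from typing import Dict, List


def summarize_line_items(document_json: Dict) -> List[Dict[str, int]]:
    """Count the child types of every line_item, in the order they appear in the JSON."""
    return [
        dict(Counter(child.get("type") for child in entity.get("properties", [])))
        for entity in document_json.get("entities", [])
        if entity.get("type") == "line_item"
    ]


def flag_suspect_line_items(
    summary: List[Dict[str, int]], unique_types: List[str]
) -> List[int]:
    """Indexes of line items where a once-per-line-item type occurs more than once."""
    return [
        i
        for i, counts in enumerate(summary)
        if any(counts.get(t, 0) > 1 for t in unique_types)
    ]


before = {"entities": [
    {"type": "line_item", "properties": [
        {"type": "line_item/quantity"}, {"type": "line_item/quantity"},
        {"type": "line_item/amount"},
    ]},
    {"type": "line_item", "properties": [{"type": "line_item/amount"}]},
]}
summary = summarize_line_items(before)
print(summary)  # [{'line_item/quantity': 2, 'line_item/amount': 1}, {'line_item/amount': 1}]
print(flag_suspect_line_items(summary, ["line_item/quantity", "line_item/amount"]))  # [0]
```

For real files, load an input JSON and its counterpart in the output folder with `json.load` and compare the two summaries.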