# Line Items Improver (Post-Processing) User Guide

- Author: docai-incubator@google.com

## Purpose and Description

This script regroups child entities into the correct line items.

The tool takes parsed JSON files (prediction results) and a list of child entities that occur only once per line item in the `line_item` schema. It uses that list to regroup the child entities under the correct parent `line_item` entities.



## Prerequisites 
1. Access to a Google Cloud project to create Document AI processors.
   - Permissions on the Google Cloud project are needed to access the Document AI processors.
2. Python: a Jupyter notebook (Vertex AI) or Google Colab.
3. A list of child entities that occur once per line item (optional once or required once in the schema).
4. A Cloud Storage output folder for the updated JSON files.


## Tool Operation Procedure

### 1. Install required libraries

In [1]:
%pip install google-cloud-documentai
%pip install google-cloud-storage
%pip install google-api-core
%pip install gcsfs
%pip install tqdm
# json, collections, copy and pprint are part of the Python standard library and need no installation.
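To confirm after installation that every module imported in the next step is available, including the standard-library ones, you can use a check like the one below. It is a minimal sketch built on `importlib.util.find_spec`. The `missing_modules` helper and the module list are illustrative assumptions and are not part of the original tool.

```python
import importlib.util

# Modules imported in step 2 (illustrative list; adjust to your environment).
REQUIRED_MODULES = [
    "google.cloud.storage", "gcsfs", "tqdm",
    "json", "collections", "copy", "pprint",
]


def missing_modules(names):
    """Return the module names that cannot be found in this environment."""
    missing = []
    for name in names:
        try:
            if importlib.util.find_spec(name) is None:
                missing.append(name)
        except ModuleNotFoundError:
            # For a dotted name, find_spec imports the parent package first,
            # which raises when that parent is not installed.
            missing.append(name)
    return missing


print(missing_modules(REQUIRED_MODULES))  # an empty list means nothing is missing
```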


### 2. Import Packages

In [2]:
from google.cloud import storage
import gcsfs
import json
from collections import Counter
import copy
from pprint import pprint
from tqdm import tqdm

### 3. Input Details

In [3]:
# Input
Gcs_input_path = 'gs://xxxx/xxxx/'  # Folder containing the parsed JSON files; the trailing '/' is mandatory
project_id = 'xxxx-xxxx-xxxx'  # Google Cloud project ID
Gcs_output_path = 'gs://xxxxxxx/xxxxx/xxxxxxx/'  # Folder where the updated JSON files are saved; the trailing '/' is mandatory
# Child entities of line_item that are optional_once or required_once in the schema, for example:
unique_entities = [
    'line_item/product_code', 'line_item/unit_price', 'line_item/quantity'
]  # example only
# For Custom Document Extractor (CDE) output, drop the 'line_item/' prefix, because CDE child entity names do not include the parent name.
Desc_merge_update = 'Yes'  # 'Yes' to merge the description within each line item, otherwise 'No'
line_item_across_pages = 'Yes'  # 'Yes' to group line items that continue across pages
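Both folder paths are read as `gs://<bucket>/<prefix>/`. The sketch below shows how such a path splits into a bucket and a prefix, which gives the same result as the `split('/')` logic in `file_names` below. It also shows how an output object name could be built from that prefix. `split_gcs_path` and `output_blob_name` are hypothetical helpers and are not part of the script.

```python
def split_gcs_path(gcs_path):
    """Hypothetical helper: split 'gs://bucket/a/b/' into ('bucket', 'a/b/')."""
    if not gcs_path.startswith("gs://"):
        raise ValueError(f"not a gs:// path: {gcs_path!r}")
    if not gcs_path.endswith("/"):
        # The input and output folder paths above must end with '/'.
        raise ValueError(f"folder path must end with '/': {gcs_path!r}")
    bucket, _, prefix = gcs_path[len("gs://"):].partition("/")
    return bucket, prefix


def output_blob_name(gcs_output_path, file_name):
    """Hypothetical helper: object name for one file under the output folder."""
    _, prefix = split_gcs_path(gcs_output_path)
    return prefix + file_name


print(split_gcs_path("gs://my-bucket/parsed/jsons/"))  # ('my-bucket', 'parsed/jsons/')
print(output_blob_name("gs://my-bucket/out/", "invoice_1.json"))  # out/invoice_1.json
```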

### 4. Run the functions

In [4]:
def file_names(file_path):
    """This Function will load the bucket and get the list of files
    in the gs path given
    args: gs path
    output: file names as list and dictionary with file names as keys and file path as values"""

    bucket = file_path.split("/")[2]
    file_names_list = []
    file_dict = {}
    storage_client = storage.Client()
    source_bucket = storage_client.get_bucket(bucket)
    filenames = [
        filename.name for filename in list(
            source_bucket.list_blobs(
                prefix=(('/').join(file_path.split('/')[3:]))))
    ]
    for i in range(len(filenames)):
        x = filenames[i].split('/')[-1]
        if x != "":
            file_names_list.append(x)
            file_dict[x] = filenames[i]
    return file_names_list, file_dict
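

# A storage-free sketch of the name-mapping step in file_names, useful for
# checking the logic without a bucket. `map_blob_names` is a hypothetical
# helper that the rest of this script does not call. It takes object names
# (as returned by list_blobs) and skips "folder" placeholder objects whose
# names end in '/'. As in file_names, two files with the same base name in
# different subfolders collide in the dictionary, and the later one wins.
def map_blob_names(blob_names):
    file_names_list = []
    file_dict = {}
    for name in blob_names:
        base = name.split('/')[-1]
        if base:  # empty for folder placeholders such as 'prefix/'
            file_names_list.append(base)
            file_dict[base] = name
    return file_names_list, file_dict

# Example (not executed here):
# map_blob_names(['in/', 'in/a.json', 'in/sub/b.json'])
# -> (['a.json', 'b.json'], {'a.json': 'in/a.json', 'b.json': 'in/sub/b.json'})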


def load_json(path):
    """ This Function will load the json from the gs path 
    
    args: gs path
    output: loaded json """

    gcs_file_system = gcsfs.GCSFileSystem(project=project_id)
    with gcs_file_system.open(path) as f:
        json_l = json.load(f)
    return json_l


def line_item_check(entities, unique_entities):
    """ 
        This Function is to check the number of line items in a json
    args:entities and unique entities list
    output: line item count
    
    """

    entity_types = []
    for entity in entities:
        if 'properties' in entity.keys() and entity['properties'] != []:
            for subentity in entity['properties']:
                entity_types.append(subentity['type'])
        else:
            pass
            #entity_types.append(entity['type'])

    line_items_count = 0
    entity_unique_count = {}

    for unique in unique_entities:
        entity_unique_count[unique] = entity_types.count(unique)
    li_en = []
    #print(entity_unique_count)
    for i in list(entity_unique_count.values()):
        if i > 1:
            li_en.append(i)

    #print(li_en)
    for unique in unique_entities:
        if entity_types.count(unique) < 1:
            continue
        elif entity_types.count(unique) == 1:
            line_items_count = 1
        elif len(unique_entities) >= 3:
            if entity_types.count(unique) > 1 and len(li_en) >= 1:
                line_items_count = 2
                break
        elif entity_types.count(unique) > 1 and len(li_en) >= 1:
            line_items_count = 2
            break

    #print(line_items_count)
    return line_items_count
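

# A compact reading of line_item_check, for understanding only (this script
# does not call it). The loop above returns 2 as soon as any unique child type
# occurs more than once. Otherwise it returns 1 if some unique type occurs
# exactly once, and 0 if none is present. The helper name is hypothetical.
from collections import Counter


def classify_line_item_count(entities, unique_entities):
    child_counts = Counter(
        subentity['type']
        for entity in entities
        if entity.get('properties')
        for subentity in entity['properties'])
    highest = max((child_counts[u] for u in unique_entities), default=0)
    # Cap at 2 to match line_item_check's 0 / 1 / 2 return values.
    return min(highest, 2)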


def single_line_item_merge(entities, page):
    """ 
    This Function will group the line item if there is only single line item in a document 
    
    args: entities and page number
    output: grouped entities
            """

    line_item_entity = []
    line_item_sub_entities = []
    for entity in entities:
        if entity['type'] == 'line_item':
            for subentity in entity['properties']:
                line_item_sub_entities.append(subentity)
            line_item_entity.append(entity)

    line_item = {
        'mentionText': '',
        'pageAnchor': {
            'pageRefs': [{
                'boundingPoly': {
                    'normalizedVertices': []
                },
                'page': page
            }]
        },
        'properties': [],
        'textAnchor': {
            'textSegments': []
        },
        'type': 'line_item'
    }

    text_anchors_sub_entities = []
    for item in line_item_sub_entities:
        text_anchors_sub_entities.append(item['textAnchor']['textSegments'][0])
    line_item['textAnchor']['textSegments'] = text_anchors_sub_entities
    sorted_list_text1 = sorted(text_anchors_sub_entities,
                               key=lambda x: int(x['endIndex']))

    line_item_mention_text = ''
    line_item_properties = []
    line_item_text_segments = []
    line_item_normalizedvertices = []
    subentities_classified = []
    for index in sorted_list_text1:
        for item in line_item_sub_entities:
            if index in item['textAnchor']['textSegments']:
                if item not in subentities_classified:
                    subentities_classified.append(item)
                    line_item_mention_text = line_item_mention_text + ' ' + item[
                        'mentionText']
                    line_item_properties.append(item)
                    line_item_text_segments.append(index)
                    for i in item['pageAnchor']['pageRefs'][0]['boundingPoly'][
                            'normalizedVertices']:
                        line_item_normalizedvertices.append(i)
    min_x_ln = min(line_item_normalizedvertices, key=lambda d: d['x'])['x']
    max_x_ln = max(line_item_normalizedvertices, key=lambda d: d['x'])['x']
    min_y_ln = min(line_item_normalizedvertices, key=lambda d: d['y'])['y']
    max_y_ln = max(line_item_normalizedvertices, key=lambda d: d['y'])['y']
    # Corners clockwise from the top-left, so the box is a simple polygon.
    line_item_normalizedvertices_final = [{
        'x': min_x_ln,
        'y': min_y_ln
    }, {
        'x': max_x_ln,
        'y': min_y_ln
    }, {
        'x': max_x_ln,
        'y': max_y_ln
    }, {
        'x': min_x_ln,
        'y': max_y_ln
    }]

    line_item['pageAnchor']['pageRefs'][0]['boundingPoly'][
        'normalizedVertices'] = line_item_normalizedvertices_final
    line_item['mentionText'] = line_item_mention_text.strip()  # drop the leading space from the concatenation
    line_item['properties'] = line_item_properties
    line_item['textAnchor']['textSegments'] = line_item_text_segments

    return line_item
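

# single_line_item_merge (like create_lineitem further down) builds the
# parent box from the min/max x and y over every child vertex. The
# hypothetical helper below shows that step on its own. It returns the four
# corners clockwise from the top-left, which is the usual order for Document
# AI normalized vertices. It also reads missing coordinates as 0.0, because
# the JSON output can omit zero values. Both points are assumptions; the
# script itself only relies on the min/max values.
def union_bounding_box(vertex_lists):
    xs = [v.get('x', 0.0) for vertices in vertex_lists for v in vertices]
    ys = [v.get('y', 0.0) for vertices in vertex_lists for v in vertices]
    if not xs:
        raise ValueError('no vertices to enclose')
    min_x, max_x, min_y, max_y = min(xs), max(xs), min(ys), max(ys)
    return [{'x': min_x, 'y': min_y}, {'x': max_x, 'y': min_y},
            {'x': max_x, 'y': max_y}, {'x': min_x, 'y': max_y}]

# Example (not executed here):
# union_bounding_box([[{'x': 0.1, 'y': 0.2}, {'x': 0.3, 'y': 0.25}],
#                     [{'x': 0.5, 'y': 0.21}, {'x': 0.6, 'y': 0.24}]])
# -> [{'x': 0.1, 'y': 0.2}, {'x': 0.6, 'y': 0.2},
#     {'x': 0.6, 'y': 0.25}, {'x': 0.1, 'y': 0.25}]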


#line_items_across pages grouping
def get_lineitems_grouped_pages_across(json_dict, schema={}):
    """ 
    This Function will group the line items which are across pages 
    line items continued into next page
    
    args: loaded json and optional line item schems
    
    output: json with grouped entities
    
    """

    #this function will be used to group the line item entities across pages
    def get_line_items_temp_schema(json_dict):
        """ 
        This Function will get the schema of the line item based on the line items available in the json
        (if there is no schema provided to the  main function)

        args: loaded json

        output: line item schema

        """

        line_items = []
        for entity in json_dict['entities']:
            if 'properties' in entity.keys():
                if entity['type'] == 'line_item':
                    line_items.append(entity)
        x = 1
        line_types = {}
        for i in line_items:
            line_types['line' + '_' + str(x)] = {}
            for child in i['properties']:
                if child['type'] in line_types['line' + '_' + str(x)].keys():
                    line_types['line' + '_' + str(x)][
                        child['type']] = line_types['line' + '_' +
                                                    str(x)][child['type']] + 1
                else:
                    line_types['line' + '_' + str(x)][child['type']] = 1
            x += 1

        all_child = []
        for key, val in line_types.items():
            #print(val)
            all_child.append(val)
        counts = {
            k: Counter([d[k] for d in all_child if k in d])
            for k in set().union(*all_child)
        }
        #print(line_items)
        # Take the most common (majority) count of each child type, with no
        # minimum-support threshold; outliers or parser errors can skew this.
        temp_schema = {k: max(counts[k], key=counts[k].get) for k in counts}

        # Alternative: keep a count only if at least half of the line items share it
        #temp_schema = {key: value for key, counter in counts.items() for value in counter.keys() if counter[value] >=(round((len(all_child)/2),0))}

        # Alternative: keep a count only if at least 3 line items share it
        #temp_schema = {key: value for key, counter in counts.items() for value in counter.keys() if counter[value] >2}
        #print(temp_schema)

        return temp_schema

    if schema == {}:
        schema = get_line_items_temp_schema(json_dict)

    line_items_1 = []
    for entity in json_dict['entities']:
        if 'properties' in entity.keys() and len(entity['properties']) != 0:
            line_items_1.append(entity)

    line_item_sorted_first = {}
    line_item_sorted_last = {}
    line_item_across_pages_first = {}
    line_item_across_pages_last = {}
    for li_1 in line_items_1:
        try:
            page = li_1['pageAnchor']['pageRefs'][0]['page']
        except KeyError:
            # Document AI omits 'page' for the first page.
            page = str(0)
        max_y = max(vertex['y'] for vertex in li_1['pageAnchor']['pageRefs'][0]
                    ['boundingPoly']['normalizedVertices'])
        min_y = min(vertex['y'] for vertex in li_1['pageAnchor']['pageRefs'][0]
                    ['boundingPoly']['normalizedVertices'])
        if page in line_item_sorted_last.keys():
            if line_item_sorted_last[page] >= max_y:
                pass
            else:
                line_item_sorted_last[page] = max_y
                line_item_across_pages_last[page] = li_1
        else:
            line_item_sorted_last[page] = max_y
            line_item_across_pages_last[page] = li_1
        if page in line_item_sorted_first.keys():
            if line_item_sorted_first[page] <= min_y:
                pass
            else:
                line_item_sorted_first[page] = min_y
                line_item_across_pages_first[page] = li_1
        else:
            line_item_sorted_first[page] = min_y
            line_item_across_pages_first[page] = li_1

    groups_across = {}
    p = 0

    for page_last, ent_last in line_item_across_pages_last.items():
        try:
            groups_across[p] = [
                line_item_across_pages_last[page_last],
                line_item_across_pages_first[str(int(page_last) + 1)]
            ]
            p += 1
        except KeyError:
            pass
    #getting schema of each line item in groups
    schema_across = {}
    for group, match in groups_across.items():
        for i in range(len(match)):
            for subitem in match[i]['properties']:
                if group in schema_across.keys():
                    if i in schema_across[group].keys():
                        if subitem['type'] in schema_across[group][i].keys():
                            schema_across[group][i][subitem['type']] += 1
                            #print('yes')
                        else:
                            schema_across[group][i][subitem['type']] = 1
                            #print(subitem['type'])
                    else:
                        schema_across[group][i] = {subitem['type']: 1}
                        #print('yes')
                else:
                    schema_across[group] = {i: {subitem['type']: 1}}
                    #print('yes')
    group_entites_spread = {}
    for selected_group, schema_ent in schema_across.items():
        missing_ent_0 = {}
        for key in schema.keys():
            if key not in schema_ent[0]:
                missing_ent_0[key] = schema[key]
            else:
                missing_ent_0[key] = schema[key] - schema_ent[0][key]
        for k1, v1 in missing_ent_0.items():
            if v1 > 0:
                if k1 in schema_ent[1]:
                    if schema_ent[1][k1] > schema[k1] or len(
                            schema_ent[1]) < (len(schema) / 2):
                        if selected_group in group_entites_spread.keys():
                            group_entites_spread[selected_group][
                                k1] = schema_ent[1][k1] - schema[k1]
                        else:
                            if len(schema_ent[1]) < (len(schema) / 2):
                                group_entites_spread[selected_group] = {
                                    k1: schema_ent[1][k1]
                                }
                            else:
                                group_entites_spread[selected_group] = {
                                    k1: schema_ent[1][k1] - schema[k1]
                                }

    def get_ent_schffle(group, index_1):
        """ 
        
        Internal function to find the right grouping of line items 
        
        """
        ent_min_y = {}
        ent_sort_ent = {}
        for sube1 in groups_across[group][index_1]['properties']:
            if sube1['type'] in group_entites_spread[group].keys():
                min_y_temp = min(vertex['y']
                                 for vertex in sube1['pageAnchor']['pageRefs']
                                 [0]['boundingPoly']['normalizedVertices'])
                if sube1['type'] in ent_min_y.keys():
                    ent_min_y[sube1['type']].append(min_y_temp)
                    if sube1['type'] in ent_sort_ent.keys():
                        ent_sort_ent[sube1['type']][min_y_temp] = sube1
                    else:
                        ent_sort_ent[sube1['type']] = {min_y_temp: sube1}
                else:
                    ent_min_y[sube1['type']] = [min_y_temp]
                    if sube1['type'] in ent_sort_ent.keys():
                        ent_sort_ent[sube1['type']][min_y_temp] = sube1
                    else:
                        ent_sort_ent[sube1['type']] = {min_y_temp: sube1}

        sorted_ent_min_y = {
            key: sorted(values)
            for key, values in ent_min_y.items()
        }
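        ent_shuffle = []
        for en1, sorted_min_y in sorted_ent_min_y.items():
            # Take the topmost N children of this type, where N is the surplus
            # computed for the group (a non-positive surplus moves nothing).
            # The previous nested loop re-appended the topmost child N times.
            n_move = max(group_entites_spread[group][en1], 0)
            for min_y_val in sorted_min_y[:n_move]:
                ent_shuffle.append(ent_sort_ent[en1][min_y_val])
        return ent_shuffle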

    def ent_move(group, index_1, index_0):
        """
        Internal function: moves the entities in ent_shuffle from the line item
        at index_1 to the line item at index_0 (on a copy of the group) and
        extends the target's page anchors, mention text and text anchors.
        """

        temp_group = copy.deepcopy(groups_across[group])
        for ent_sh in ent_shuffle:
            for sub_en3 in temp_group[index_1]['properties']:
                if ent_sh['type'] == sub_en3['type'] and ent_sh[
                        'textAnchor'] == sub_en3['textAnchor']:
                    # Remove the matching copy and stop, so the list is not
                    # modified while it is still being iterated.
                    temp_group[index_1]['properties'].remove(sub_en3)
                    break
            temp_group[index_0]['properties'].append(ent_sh)
            for t1 in ent_sh['pageAnchor']['pageRefs']:
                temp_group[index_0]['pageAnchor']['pageRefs'].append(t1)
            temp_group[index_0]['mentionText'] = temp_group[index_0][
                'mentionText'] + ' ' + ent_sh['mentionText']
            for t2 in ent_sh['textAnchor']['textSegments']:
                temp_group[index_0]['textAnchor']['textSegments'].append(t2)
        return temp_group

    def correct_page_text(temp_group_1, index_1):
        """ 
        
        Internal function to fiix the text anchors and page anchors of parent items 
        
        """
        temp_x = []
        temp_y = []
        temp_text_anc = []
        temp_mention_text = ''
        for suben in temp_group_1[index_1]['properties']:
            for tex_an1 in suben['textAnchor']['textSegments']:
                temp_text_anc.append(tex_an1)
            for page_an1 in suben['pageAnchor']['pageRefs'][0]['boundingPoly'][
                    'normalizedVertices']:
                temp_x.append(page_an1['x'])
                temp_y.append(page_an1['y'])

        # Corners clockwise from the top-left, so the box is a simple polygon.
        updated_ver = [{
            'x': min(temp_x),
            'y': min(temp_y)
        }, {
            'x': max(temp_x),
            'y': min(temp_y)
        }, {
            'x': max(temp_x),
            'y': max(temp_y)
        }, {
            'x': min(temp_x),
            'y': max(temp_y)
        }]
        sorted_temp_text_anc = sorted(temp_text_anc,
                                      key=lambda x: int(x['endIndex']))
        temp_group_1[index_1]['textAnchor'][
            'textSegments'] = sorted_temp_text_anc
        temp_group_1[index_1]['pageAnchor']['pageRefs'][0]['boundingPoly'][
            'normalizedVertices'] = updated_ver
        for t5 in sorted_temp_text_anc:
            s1 = t5['startIndex']
            e1 = t5['endIndex']
            temp_mention_text = temp_mention_text + ' ' + json_dict['text'][
                int(s1):int(e1)]
        temp_group_1[index_1]['mentionText'] = temp_mention_text.strip()  # drop the leading space from the concatenation

        return temp_group_1

    if len(group_entites_spread) > 0:
        for group, entity_move in group_entites_spread.items():
            # Document AI omits 'page' for the first page, so default to 0.
            try:
                page_1 = int(groups_across[group][0]['pageAnchor']['pageRefs']
                             [0]['page'])
            except KeyError:
                page_1 = 0
            try:
                page_2 = int(groups_across[group][1]['pageAnchor']['pageRefs']
                             [0]['page'])
            except KeyError:
                page_2 = 0
            # Compare page numbers as integers: as strings, '10' < '9'.
            if page_1 < page_2:
                ent_shuffle = get_ent_schffle(group, 1)
                temp_group_1 = ent_move(group, 1, 0)
                if len(temp_group_1[1]['properties']) > 0:
                    temp_group_updated = correct_page_text(temp_group_1, 1)
                else:
                    temp_group_updated = [temp_group_1[0]]

            elif page_1 > page_2:
                ent_shuffle = get_ent_schffle(group, 0)
                temp_group_1 = ent_move(group, 0, 1)
                if len(temp_group_1[0]['properties']) > 0:
                    temp_group_updated = correct_page_text(temp_group_1, 0)
                else:
                    temp_group_updated = [temp_group_1[1]]

            else:
                # Both line items are on the same page: nothing to regroup
                # (and temp_group_updated would otherwise be undefined or stale).
                continue

            for ent_remove in groups_across[group]:
                json_dict['entities'].remove(ent_remove)

            for ent_add in temp_group_updated:
                json_dict['entities'].append(ent_add)

    return json_dict
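

# A hypothetical, simplified sketch of the two ideas behind
# get_lineitems_grouped_pages_across; nothing in this script calls it.
#   infer_majority_schema: the expected count of each child type per line
#     item is the count seen most often across line items (the same majority
#     vote as get_line_items_temp_schema, so outliers can skew it).
#   pair_line_items_across_pages: the bottom-most line item on page p is
#     paired with the top-most line item on page p + 1, because a line item
#     split by a page break ends up as exactly that pair.
# Pages are compared as integers here. Document AI JSON stores `page` as a
# string and omits it for the first page; zero coordinates may be omitted too.
from collections import Counter


def infer_majority_schema(line_items):
    per_item = [Counter(child['type'] for child in item.get('properties', []))
                for item in line_items]
    schema = {}
    for child_type in set().union(*per_item):
        counts = Counter(c[child_type] for c in per_item if child_type in c)
        schema[child_type] = counts.most_common(1)[0][0]
    return schema


def pair_line_items_across_pages(line_items):
    first, last = {}, {}  # page -> (y, line item)
    for item in line_items:
        ref = item['pageAnchor']['pageRefs'][0]
        page = int(ref.get('page', 0))
        ys = [v.get('y', 0.0)
              for v in ref['boundingPoly']['normalizedVertices']]
        if page not in first or min(ys) < first[page][0]:
            first[page] = (min(ys), item)
        if page not in last or max(ys) > last[page][0]:
            last[page] = (max(ys), item)
    return [(last[p][1], first[p + 1][1]) for p in sorted(last)
            if p + 1 in first]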


def get_page_wise_entities(json_dict):
    """ 
    This Function will provide the entities page wise
    
    args: loaded json 
    
    output: dictionary with page as key and entities as values
    
    """

    entities_page = {}

    for entity in json_dict['entities']:
        page = '0'
        try:
            if 'page' in entity['pageAnchor']['pageRefs'][0].keys():
                page = entity['pageAnchor']['pageRefs'][0]['page']

            if page in entities_page.keys():
                entities_page[page].append(entity)
            else:
                entities_page[page] = [entity]
        except (KeyError, IndexError):
            # Entities without a page anchor are skipped.
            pass
    return entities_page
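

# get_page_wise_entities keys pages by the raw `page` value and defaults to
# '0', because Document AI JSON omits `page` for the first page. The
# hypothetical helper below expresses the same grouping with defaultdict.
# Unlike the function above, it keeps entities without a page anchor under a
# None key instead of silently dropping them.
from collections import defaultdict


def group_entities_by_page(entities):
    by_page = defaultdict(list)
    for entity in entities:
        page_refs = entity.get('pageAnchor', {}).get('pageRefs', [])
        page = page_refs[0].get('page', '0') if page_refs else None
        by_page[page].append(entity)
    return dict(by_page)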


def multi_page_entites(entities_pagewise, page):
    """ 
    This Function will group the entities in case of multiple line items page wise
    
    args: entities page wise and page number
    
    output: list of entities grouped
    
    """

    entity_types = []
    line_item_sub_entities = []
    for entity in entities_pagewise:
        if 'properties' in entity.keys() and entity['properties'] != []:
            if entity['type'] == 'line_item':
                for subentity in entity['properties']:
                    entity_types.append(subentity['type'])
                    line_item_sub_entities.append(subentity)
        else:
            entity_types.append(entity['type'])
    #print(line_item_sub_entities)
    line_items_multi_dict = {}
    for unique in unique_entities:
        if entity_types.count(unique) > 1:
            line_items_multi_dict[unique] = entity_types.count(unique)
    #print(line_items_multi_dict)

    value_counts = Counter(line_items_multi_dict.values())
    max_count = max(value_counts.values())
    entity_types_keys = [
        key for key, value in line_items_multi_dict.items()
        if value_counts[value] == max_count
    ]

    #print(entity_types_keys)
    dict_unique_ent = {}
    for entity_type in entity_types_keys:
        for entity in entities_pagewise:
            if 'properties' in entity.keys():
                for subentity in entity['properties']:
                    if subentity['type'] == entity_type:
                        if entity_type in dict_unique_ent.keys():
                            dict_unique_ent[entity_type].append(subentity)
                        else:
                            dict_unique_ent[entity_type] = [subentity]

    #print(dict_unique_ent)
    region_line_items = {}
    region_line_items_x = {}
    region_line_items_y = {}
    opt_region = {}
    for ent in dict_unique_ent:
        region = []
        dict_x_y = []
        count_product_code = 0
        min_x_1 = []
        min_y_1 = []
        for item in dict_unique_ent[ent]:
            x_y = {}
            x = []
            y = []
            x_min = ''
            y_min = ''
            count_product_code += 1
            for i in item['pageAnchor']['pageRefs'][0]['boundingPoly'][
                    'normalizedVertices']:
                y.append(i['y'])
                x.append(i['x'])
            x_min = min(x)
            y_min = min(y)
            diff = max(y) - min(y)
            x_y[count_product_code] = [{
                'x': min(x),
                'y': min(y)
            }, {
                'x': max(x),
                'y': max(y)
            }]
            dict_x_y.append(x_y)
            min_x_1.append(x_min)
            min_y_1.append(y_min)
        region_line_items_x[ent] = min_x_1
        region_line_items_y[ent] = min_y_1
        sorted_lst_x_y = sorted(dict_x_y,
                                key=lambda x: list(x.values())[0][0]['y'])
        region_line_items[ent] = sorted_lst_x_y
        opt_region[ent] = diff
    sorted_region_line_y = {
        key: sorted(values)
        for key, values in region_line_items_y.items()
    }
    sorted_region_line_x = {
        key: sorted(values)
        for key, values in region_line_items_x.items()
    }

    #print(region_line_items)
    regions_line_y_final = {}
    opt_region_ent = {}
    for region_line in sorted_region_line_y:
        line_no = 0
        line_range = {}
        for i in range(len(sorted_region_line_y[region_line])):
            line_no += 1
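            try:
                pair = (sorted_region_line_y[region_line][i],
                        sorted_region_line_y[region_line][i + 1])
            except IndexError:
                # Last row: there is no next row, so only its top edge is
                # stored (a float, not a tuple). The TypeError handlers
                # below rely on this.
                pair = sorted_region_line_y[region_line][i]
            line_range[line_no] = pair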
        if region_line in regions_line_y_final.keys():
            regions_line_y_final[region_line].append(line_range)
        else:
            regions_line_y_final[region_line] = [line_range]
        opt_region_ent[region_line] = max(
            sorted_region_line_y[region_line]) - min(
                sorted_region_line_y[region_line])

    max_value = max(opt_region_ent.values())  # Find the maximum value

    selected_values = [
        key for key, value in opt_region_ent.items()
        if abs(value - max_value) < 0.005
    ]  # Select values satisfying the condition

    if len(selected_values) > 1:
        considered_boundry_ent = ''
        final_y = 1
        for selected_ent in selected_values:
            if final_y > min(sorted_region_line_y[selected_ent]):
                final_y = min(sorted_region_line_y[selected_ent])
                considered_boundry_ent = selected_ent
            else:
                pass
    else:
        considered_boundry_ent = selected_values[0]

    sub_entities_list = copy.deepcopy(line_item_sub_entities)
    line_item_dict_final = {}
    sub_entities_categorized = []
    count = 0

    for subentity in sub_entities_list:
        y_ent = []
        for ver in subentity['pageAnchor']['pageRefs'][0]['boundingPoly'][
                'normalizedVertices']:
            y_ent.append(ver['y'])
        for line, region in regions_line_y_final[considered_boundry_ent][
                0].items():

            try:
                if (min(y_ent) >= region[0] or max(y_ent)
                        >= region[0]) and (max(y_ent) < region[1]):

                    if line in line_item_dict_final.keys():
                        count = count + 1
                        if subentity not in sub_entities_categorized:
                            line_item_dict_final[line].append(subentity)
                            sub_entities_categorized.append(subentity)
                    else:
                        count = count + 1

                        if subentity not in sub_entities_categorized:
                            line_item_dict_final[line] = [subentity]
                            sub_entities_categorized.append(subentity)

            except TypeError:
                if min(y_ent) >= region:  #-0.005:
                    if line in line_item_dict_final.keys():
                        count = count + 1
                        if subentity not in sub_entities_categorized:
                            line_item_dict_final[line].append(subentity)
                            sub_entities_categorized.append(subentity)
                    else:
                        count = count + 1
                        if subentity not in sub_entities_categorized:
                            line_item_dict_final[line] = [subentity]
                            sub_entities_categorized.append(subentity)

    for item in sub_entities_list:
        if item not in sub_entities_categorized:
            y_ent = []
            #print(regions_line_y_final[considered_boundry_ent])
            for ver in item['pageAnchor']['pageRefs'][0]['boundingPoly'][
                    'normalizedVertices']:
                y_ent.append(ver['y'])
            diff_line = {}
            for line1, y1 in regions_line_y_final[considered_boundry_ent][
                    0].items():
                try:
                    diff = abs(min(y_ent) - y1[0])
                    diff_line[line1] = diff
                except TypeError:
                    diff = abs(min(y_ent) - y1)
                    diff_line[line1] = diff

            min_dist = min(diff_line.values())
            line_item_2 = [
                key for key, value in diff_line.items() if value == min_dist
            ]
            if line_item_2[0] in line_item_dict_final.keys():
                line_item_dict_final[line_item_2[0]].append(item)
                sub_entities_categorized.append(item)
            else:
                line_item_dict_final[line_item_2[0]] = [item]
                sub_entities_categorized.append(item)

    temp3 = []
    for element in line_item_sub_entities:
        if element not in sub_entities_categorized:
            temp3.append(element)
    if len(sub_entities_categorized) < len(line_item_sub_entities):
        left_out = len(line_item_sub_entities) - len(sub_entities_categorized)
        print(
            f'{left_out} of {len(line_item_sub_entities)} child entities are not yet classified'
        )
    elif len(sub_entities_categorized) == len(line_item_sub_entities):
        print('All line item child entities are classified')
    else:
        print('Unexpected state: more child entities were classified than exist')

    def create_lineitem(line_item_sub_entities, page):
        """ 
         
            This Function will create a parent item and add the child items into it 
    
            args: list of child entities and page number
    
            output: grouped parent items 
            
        """
        line_item = {
            'mentionText': '',
            'pageAnchor': {
                'pageRefs': [{
                    'boundingPoly': {
                        'normalizedVertices': []
                    },
                    'page': page
                }]
            },
            'properties': [],
            'textAnchor': {
                'textSegments': []
            },
            'type': 'line_item'
        }

        text_anchors_sub_entities = []
        for item in line_item_sub_entities:
            text_anchors_sub_entities.append(
                item['textAnchor']['textSegments'][0])
        # Sort the children's first text segments into reading order.
        sorted_list_text1 = sorted(text_anchors_sub_entities,
                                   key=lambda x: int(x['endIndex']))
        line_item_mention_text = ''
        line_item_properties = []
        line_item_text_segments = []
        line_item_normalizedvertices = []
        subentities_classified = []
        # Add each child once, in reading order, and collect its text and box.
        for index in sorted_list_text1:
            for item in line_item_sub_entities:
                if index in item['textAnchor']['textSegments']:
                    if item not in subentities_classified:
                        subentities_classified.append(item)
                        line_item_mention_text = (line_item_mention_text +
                                                  ' ' + item['mentionText'])
                        line_item_properties.append(item)
                        line_item_text_segments.append(index)
                        for i in item['pageAnchor']['pageRefs'][0][
                                'boundingPoly']['normalizedVertices']:
                            line_item_normalizedvertices.append(i)

        # Enclosing box of all child boxes, listed clockwise from the top-left
        # corner. Coordinates equal to 0 may be omitted in the JSON, hence .get().
        min_x_ln_1 = min(v.get('x', 0.0) for v in line_item_normalizedvertices)
        max_x_ln_1 = max(v.get('x', 0.0) for v in line_item_normalizedvertices)
        min_y_ln_1 = min(v.get('y', 0.0) for v in line_item_normalizedvertices)
        max_y_ln_1 = max(v.get('y', 0.0) for v in line_item_normalizedvertices)
        line_item_normalizedvertices_final = [{
            'x': min_x_ln_1,
            'y': min_y_ln_1
        }, {
            'x': max_x_ln_1,
            'y': min_y_ln_1
        }, {
            'x': max_x_ln_1,
            'y': max_y_ln_1
        }, {
            'x': min_x_ln_1,
            'y': max_y_ln_1
        }]

        line_item['pageAnchor']['pageRefs'][0]['boundingPoly'][
            'normalizedVertices'] = line_item_normalizedvertices_final
        # Drop the leading space left by the concatenation above.
        line_item['mentionText'] = line_item_mention_text.strip()
        line_item['properties'] = line_item_properties
        line_item['textAnchor']['textSegments'] = line_item_text_segments

        return line_item

    line_items_classified = []

    for key, line_item_1 in line_item_dict_final.items():

        line_item = create_lineitem(line_item_1, page)

        line_items_classified.append(line_item)
    return line_items_classified, considered_boundry_ent


def merge_entities(json_dict):
    """Regroup the child entities on each page into the correct `line_item` parents.

    Args:
        json_dict: loaded Document AI JSON.

    Returns:
        The updated JSON with regrouped line items.
    """
    entities_page_wise = get_page_wise_entities(json_dict)

    line_entities_classified_pagewise = []
    for page, entities in entities_page_wise.items():
        line_item_count = line_item_check(entities, unique_entities)
        if line_item_count == 1:
            # Only one line item on this page: merge every child into it.
            line_entities_classified_pagewise.append(
                single_line_item_merge(entities, page))
        elif line_item_count > 1:
            # Several line items on this page: split the children between them.
            line_entities_temp, considered_boundry_ent = multi_page_entites(
                entities, page)
            line_entities_classified_pagewise.extend(line_entities_temp)
        elif line_item_count == 0:
            print('no line items')

    # Replace the original line items only if regrouping produced any.
    if line_entities_classified_pagewise:
        final_entities = [
            entity for entity in json_dict['entities']
            if entity['type'] != 'line_item'
        ]
        final_entities.extend(line_entities_classified_pagewise)
        json_dict['entities'] = final_entities

    return json_dict


def desc_merge_update(json_dict):
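    """Merge multiple descriptions inside a line item into a single description.

    Useful when a description spans several lines and each line was labeled
    as a separate `line_item/description` entity.

    Args:
        json_dict: loaded Document AI JSON.

    Returns:
        The updated JSON with one merged description per line item.
    """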

    def desc_merge_1(ent_desc):
        desc_merge = {
            'mentionText': '',
            'pageAnchor': {
                'pageRefs': ''
            },
            'textAnchor': {
                'textSegments': []
            },
            'type': 'line_item/description'
        }

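        text_anchors_desc_merge = []
        pagerefs = []
        for item in ent_desc:
            text_anchors_desc_merge.append(
                item['textAnchor']['textSegments'][0])
            # Copy the page refs (including boundingPoly) so that writing the
            # merged box below does not overwrite the source entity's box.
            pagerefs = [{
                **ref, 'boundingPoly': dict(ref.get('boundingPoly', {}))
            } for ref in item['pageAnchor']['pageRefs']]
        desc_merge['pageAnchor']['pageRefs'] = pagerefs
        desc_merge['textAnchor']['textSegments'] = text_anchors_desc_merge
        # Sort the description lines into reading order.
        sorted_list_text1 = sorted(text_anchors_desc_merge,
                                   key=lambda x: int(x['endIndex']))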
        desc_mention_text = ''
        desc_text_segments = []
        desc_normalizedvertices = []
        subentities_classified = []
        for index in sorted_list_text1:
            for item in ent_desc:
                if index in item['textAnchor']['textSegments']:
                    if item not in subentities_classified:
                        subentities_classified.append(item)
                        desc_mention_text = (desc_mention_text + ' ' +
                                             item['mentionText'])
                        desc_text_segments.append(index)
                        for i in item['pageAnchor']['pageRefs'][0][
                                'boundingPoly']['normalizedVertices']:
                            desc_normalizedvertices.append(i)

        # Enclosing box of all description lines, clockwise from the top-left
        # corner. Coordinates equal to 0 may be omitted in the JSON, hence .get().
        min_x_dsc = min(v.get('x', 0.0) for v in desc_normalizedvertices)
        max_x_dsc = max(v.get('x', 0.0) for v in desc_normalizedvertices)
        min_y_dsc = min(v.get('y', 0.0) for v in desc_normalizedvertices)
        max_y_dsc = max(v.get('y', 0.0) for v in desc_normalizedvertices)
        desc_normalizedvertices_final = [{
            'x': min_x_dsc,
            'y': min_y_dsc
        }, {
            'x': max_x_dsc,
            'y': min_y_dsc
        }, {
            'x': max_x_dsc,
            'y': max_y_dsc
        }, {
            'x': min_x_dsc,
            'y': max_y_dsc
        }]

        desc_merge['pageAnchor']['pageRefs'][0]['boundingPoly'][
            'normalizedVertices'] = desc_normalizedvertices_final
        # Drop the leading space left by the concatenation above.
        desc_merge['mentionText'] = desc_mention_text.strip()
        desc_merge['textAnchor']['textSegments'] = desc_text_segments

        return desc_merge

    for entity in json_dict['entities']:
        if entity['type'] != 'line_item':
            continue
        # Collect every description child of this line item.
        ent_desc = [
            ent1 for ent1 in entity.get('properties', [])
            if ent1['type'] == 'line_item/description'
        ]
        if len(ent_desc) > 1:
            # Replace the separate description lines with one merged entity.
            desc_merge = desc_merge_1(ent_desc)
            entity['properties'] = [
                ent1 for ent1 in entity['properties'] if ent1 not in ent_desc
            ]
            entity['properties'].append(desc_merge)

    return json_dict
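`create_lineitem` and `desc_merge_1` above follow the same pattern: order the child entities by the `endIndex` of their first text segment, join their `mentionText` in that order, and wrap them in one box that encloses every child box. The self-contained sketch below shows that pattern in isolation. `group_children`, `enclosing_box`, and `make_child` are hypothetical names used only for illustration (they are not part of this tool), the toy entities mimic only the Document AI JSON fields the tool reads, and the clockwise-from-top-left corner order is an assumption about the preferred vertex order.

```python
from typing import Any, Dict, List

Entity = Dict[str, Any]


def _vertices(entity: Entity) -> List[Dict[str, float]]:
    """normalizedVertices of an entity's first page reference."""
    page_ref = entity['pageAnchor']['pageRefs'][0]
    return page_ref['boundingPoly']['normalizedVertices']


def enclosing_box(children: List[Entity]) -> List[Dict[str, float]]:
    """Smallest box covering all children, corners clockwise from top-left."""
    vertices = [v for child in children for v in _vertices(child)]
    # Coordinates equal to 0 may be omitted in the JSON, so default to 0.0.
    xs = [v.get('x', 0.0) for v in vertices]
    ys = [v.get('y', 0.0) for v in vertices]
    x0, x1, y0, y1 = min(xs), max(xs), min(ys), max(ys)
    return [{'x': x0, 'y': y0}, {'x': x1, 'y': y0},
            {'x': x1, 'y': y1}, {'x': x0, 'y': y1}]


def group_children(children: List[Entity], parent_type: str) -> Entity:
    """Hypothetical helper: one parent entity with children in reading order."""
    # endIndex is serialized as a string in the JSON, hence int().
    ordered = sorted(
        children,
        key=lambda c: int(c['textAnchor']['textSegments'][0]['endIndex']))
    page = children[0]['pageAnchor']['pageRefs'][0].get('page', '0')
    return {
        'type': parent_type,
        'mentionText': ' '.join(c['mentionText'] for c in ordered),
        'textAnchor': {
            'textSegments': [c['textAnchor']['textSegments'][0] for c in ordered]
        },
        'pageAnchor': {'pageRefs': [{
            'page': page,
            'boundingPoly': {'normalizedVertices': enclosing_box(ordered)},
        }]},
        'properties': ordered,
    }


def make_child(kind, text, start, end, box):
    """Toy child entity; box is (x0, y0, x1, y1) in normalized coordinates."""
    x0, y0, x1, y1 = box
    corners = [{'x': x0, 'y': y0}, {'x': x1, 'y': y0},
               {'x': x1, 'y': y1}, {'x': x0, 'y': y1}]
    return {
        'type': kind,
        'mentionText': text,
        'textAnchor': {'textSegments': [{'startIndex': str(start),
                                         'endIndex': str(end)}]},
        'pageAnchor': {'pageRefs': [{'page': '0', 'boundingPoly': {
            'normalizedVertices': corners}}]},
    }


# The amount is passed first, but the description comes first in the text.
amount = make_child('line_item/amount', '12.00', 40, 45, (0.8, 0.5, 0.9, 0.52))
desc = make_child('line_item/description', 'Widget', 10, 16, (0.1, 0.5, 0.3, 0.52))
item = group_children([amount, desc], 'line_item')
print(item['mentionText'])  # Widget 12.00
```

Joining with `' '.join` avoids the leading space that repeated `text + ' ' + ...` concatenation produces; for a merged description, the same steps apply without the `properties` list.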

In [5]:
# Call the functions for every JSON file in the input folder

fs = gcsfs.GCSFileSystem(project=project_id)
file_names_list, file_dict = file_names(Gcs_input_path)
input_bucket_name = Gcs_input_path.split('/')[2]
for filename, filepath in tqdm(file_dict.items(), desc='Progress'):
    if not filepath.endswith('.json'):
        continue  # skip non-JSON objects in the input folder
    filepath = 'gs://' + input_bucket_name + '/' + filepath
    print(filename)
    json_dict = load_json(filepath)
    json_dict_updated = merge_entities(json_dict)

    # Optional passes, controlled by flags defined earlier in the notebook.
    if line_item_across_pages == 'Yes':
        json_dict_updated = get_lineitems_grouped_pages_across(
            json_dict_updated)

    if Desc_merge_update == 'Yes':
        json_dict_updated = desc_merge_update(json_dict_updated)

    fs.pipe(Gcs_output_path.rstrip('/') + '/' + filename,
            json.dumps(json_dict_updated, ensure_ascii=False).encode('utf-8'),
            content_type='application/json')

Progress:   0%|          | 0/5 [00:00<?, ?it/s]

1. Client X - Format 1 - Dummy Invoice - Sprint 4.json


Progress:  20%|██        | 1/5 [00:00<00:01,  2.31it/s]

2. Client X - Format 3 - Dummy Invoice - Sprint 4 Demo.json


Progress:  40%|████      | 2/5 [00:00<00:01,  2.43it/s]

21dcad73e9467e51-Client Y Format 1 - Multi-line - v1.2.json
All lineitems are classified


Progress:  60%|██████    | 3/5 [00:01<00:00,  2.52it/s]

3. Client Y - Format 1 - Dummy Invoice - Sprint 4 Demo.json


Progress:  80%|████████  | 4/5 [00:01<00:00,  2.55it/s]

4. Multi-Party Dummy Invoice - Sprint 4 Demo.json
All lineitems are classified


Progress: 100%|██████████| 5/5 [00:02<00:00,  2.47it/s]
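After the run, it can help to spot-check an output file. The sketch below uses two hypothetical helpers, `summarize_line_items` and `repeated_unique_children`, that are not part of this tool. It counts the child types inside each top-level `line_item` and flags line items where a type that should occur only once per line item still appears more than once. The `unique_types` argument plays the role of the `unique_entities` list and is assumed here to hold full type names such as `line_item/amount`. An output file could be loaded with, for example, `json.loads(fs.cat(path))` using the `gcsfs` filesystem from the cell above; the inline document is toy data.

```python
import json
from collections import Counter
from typing import Any, Dict, List


def summarize_line_items(doc: Dict[str, Any]) -> List[Dict[str, Any]]:
    """One row per top-level line_item: its text and a count of child types."""
    rows = []
    for entity in doc.get('entities', []):
        if entity.get('type') != 'line_item':
            continue
        child_types = Counter(p.get('type') for p in entity.get('properties', []))
        rows.append({'mentionText': entity.get('mentionText', ''),
                     'children': dict(child_types)})
    return rows


def repeated_unique_children(rows: List[Dict[str, Any]],
                             unique_types: List[str]) -> List[int]:
    """Indices of line items where a once-only child type occurs more than once."""
    return [i for i, row in enumerate(rows)
            if any(row['children'].get(t, 0) > 1 for t in unique_types)]


doc = json.loads('''{"entities": [
  {"type": "line_item", "mentionText": "Widget 12.00",
   "properties": [{"type": "line_item/description"},
                  {"type": "line_item/amount"}]},
  {"type": "line_item", "mentionText": "Bolt 1.00 2.00",
   "properties": [{"type": "line_item/description"},
                  {"type": "line_item/amount"},
                  {"type": "line_item/amount"}]},
  {"type": "invoice_id", "mentionText": "INV-1"}]}''')

rows = summarize_line_items(doc)
print(repeated_unique_children(rows, ['line_item/amount']))  # [1]
```

A non-empty result points at line items whose children may still be grouped incorrectly and are worth reviewing by hand.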
