### Imports

In [None]:
import warnings
warnings.filterwarnings("ignore")
import spacy
import os
import glob
from pathlib import Path
import json
from google.cloud import vision
import io
import time

import math
import cv2
import os
import boto3
import logging
from PIL import Image
from io import BytesIO
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torch
import gc
from database import update_doc_type


from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg

# Set the credentials file location for the Google Cloud client used below.
# The path is relative, so it resolves against the current working directory,
# and this assignment replaces any value already present in the environment.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'gcp/googlecreds.json'

### KeyValue (extract and map)

In [None]:
def extract(text, nlp=None):
    """Pair up KEY and VALUE entities that the NER model finds in ``text``.

    Newlines are removed from ``text`` before it is handed to the model.
    Entities are walked in the order the model returns them: every KEY opens
    an entry with an empty value, and the next VALUE fills the most recent
    KEY that is still unfilled. A VALUE with no pending KEY is dropped,
    except when the text holds exactly one KEY and one VALUE and the VALUE
    came first; in that case the two are paired.

    Args:
        text: Raw text of one detected region.
        nlp: Callable taking a string and returning an object whose ``ents``
            iterable yields items with ``label_`` and ``text`` attributes.
            When omitted, the module-level name ``nlp_ner`` is used, which
            must have been assigned before the call.

    Returns:
        dict mapping key text to value text ("" for keys without a value).
    """
    if nlp is None:
        # Backward-compatible fallback for callers that rely on the global.
        nlp = nlp_ner

    doc = nlp(text.replace("\n", ""))

    pairs = {}
    pending_key = ""
    last_value = ""
    key_count = 0
    value_count = 0
    for ent in doc.ents:
        if ent.label_ == "KEY":
            key_count += 1
            pending_key = ent.text
            pairs[pending_key] = ""
        elif ent.label_ == "VALUE":
            value_count += 1
            last_value = ent.text
            if pending_key:
                pairs[pending_key] = last_value
                pending_key = ""

    # Single VALUE seen before the single KEY: pair them after the fact.
    if key_count == 1 and value_count == 1 and pending_key:
        pairs[pending_key] = last_value
    return pairs


def get_master_key_val(key_val_map):
    """Flatten per-box key/value dicts, keeping only keys that occur once.

    ``key_val_map`` maps an outer key to an inner dict. An inner key is
    copied to the result the first time it is seen; if it shows up again it
    is removed from the result and ignored from then on. Keys listed in the
    module-level ``duplicates_label_list`` (not defined in this section of
    the file) are never added.

    Returns:
        dict of the inner keys that appeared exactly once, with their values.
    """
    merged = {}
    repeated = set()
    for fields in key_val_map.values():
        for name, value in fields.items():
            if name in repeated or name in duplicates_label_list:
                continue
            if name in merged:
                # Second sighting: drop it and remember not to re-add it.
                del merged[name]
                repeated.add(name)
            else:
                merged[name] = value
    return merged
                

### Image (Bbox detection to text) 

In [None]:
def detection(model, path):
    """Run the detector on one image file and return its predictions.

    Args:
        model: Callable that takes the image array loaded by ``cv2.imread``
            and returns a mapping whose ``'instances'`` entry exposes
            ``scores``, ``pred_boxes.tensor`` and ``pred_classes``.
        path: Path of the image file to load.

    Returns:
        dict with ``'filename'`` (file stem of ``path``) and the parallel
        lists ``'boxes'``, ``'scores'`` and ``'labels'``.

    Raises:
        ValueError: If the image at ``path`` cannot be read.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with a message that names the file.
        raise ValueError('Could not read image: {}'.format(path))

    with torch.no_grad():
        instances = model(image)['instances']
    scores = instances.scores.cpu().numpy().tolist()
    boxes = instances.pred_boxes.tensor.cpu().numpy().tolist()
    labels = instances.pred_classes.cpu().numpy().tolist()

    # Drop references to the image and the prediction tensors before asking
    # the allocator to release cached memory.
    del image, instances
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        'filename': Path(path).stem,
        'boxes': boxes,
        'scores': scores,
        'labels': labels,
    }


def detectText(path):
    """Run Google Cloud Vision text detection on an image file.

    Args:
        path: Path of the image file to send to the Vision API.

    Returns:
        list of ``(description, vertices)`` tuples, one per text annotation
        after the first, where ``vertices`` is the annotation's
        ``bounding_poly.vertices``.

    Raises:
        Exception: If the API response carries an error message.
    """
    client = vision.ImageAnnotatorClient()

    with io.open(path, 'rb') as image_file:
        content = image_file.read()

    response = client.text_detection(image=vision.Image(content=content))

    # Check for an API error before touching the annotations.
    if response.error.message:
        raise Exception(
            '{}\n'.format(response.error.message))

    # The first annotation is skipped: it carries the complete detected text
    # rather than a single fragment.
    return [
        (annotation.description, annotation.bounding_poly.vertices)
        for index, annotation in enumerate(response.text_annotations)
        if index > 0
    ]


def _text_group_bounds(text_group):
    """Return ``(text, x1, y1, x2, y2)`` for one OCR text group.

    Two shapes are accepted: a dict with ``'text'`` and ``'border'`` (a
    mapping with ``minX``/``minY``/``maxX``/``maxY``, each defaulting to 0),
    or a ``(description, vertices)`` pair whose vertices expose ``x`` and
    ``y`` attributes.
    """
    if isinstance(text_group, dict):
        border = text_group['border']
        return (
            text_group['text'],
            border.get('minX', 0),
            border.get('minY', 0),
            border.get('maxX', 0),
            border.get('maxY', 0),
        )
    description, vertices = text_group
    xs = [vertex.x for vertex in vertices]
    ys = [vertex.y for vertex in vertices]
    # default=0 mirrors the dict form's fallback when no coordinates exist.
    return (
        description,
        min(xs, default=0),
        min(ys, default=0),
        max(xs, default=0),
        max(ys, default=0),
    )


def create_text_list(vertexList, data):
    """Collect, for every detected box, the OCR text that falls inside it.

    A text group is assigned to a box when the horizontal overlap between
    the two is at least a third of the text's width and the vertical overlap
    is at least a third of its height. Matching texts are concatenated in
    ``vertexList`` order, each followed by a single space.

    Args:
        vertexList: Iterable of OCR text groups, either dicts with
            ``'text'``/``'border'`` or ``(description, vertices)`` pairs.
        data: dict with parallel lists ``'boxes'`` (``[x1, y1, x2, y2]``)
            and ``'labels'``.

    Returns:
        dict with the parallel lists ``'label'``, ``'text_data'`` and
        ``'boxes'``, one entry per input box.

    Errors raised while processing the text groups for a box are logged via
    the module-level ``logger`` (not defined in this section of the file)
    and the text gathered so far for that box is kept.
    """
    detectionList = data['boxes']
    labelList = data['labels']
    labels = []
    text_array = []
    boxes = []
    for i, boundingBox in enumerate(detectionList):
        label = labelList[i]
        text = ''
        try:
            for textGroup in vertexList:
                vertexText, x1, y1, x2, y2 = _text_group_bounds(textGroup)
                height = y2 - y1
                width = x2 - x1

                if (not (max(x1, boundingBox[0]) + width / 3 > min(x2, boundingBox[2]))) and \
                        (not (max(y1, boundingBox[1]) + height / 3 > min(y2, boundingBox[3]))):
                    text += vertexText + ' '
        except KeyError as e:
            logger.info('Key Error:')
            logger.error(e)

        except Exception as e:
            logger.error(e)

        labels.append(label)
        text_array.append(text)
        boxes.append([boundingBox[0], boundingBox[1], boundingBox[2], boundingBox[3]])

    return {"label": labels, "text_data": text_array, "boxes": boxes}


### Load Detection model

In [None]:
def load_model():
    """Build the detection predictor, fetching its files first if needed.

    The device is ``'cuda'`` when available, otherwise ``'cpu'``. When the
    weights file is absent both the weights and the config are fetched with
    ``get_check_point_file_from_s3``; when only the config is absent it is
    fetched on its own, so cached weights never leave the config load
    failing. Both ``logger`` and ``get_check_point_file_from_s3`` are
    module-level names not defined in this section of the file.

    Returns:
        The predictor returned by ``get_model`` with a 0.5 score threshold.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'keyValue_detectron2.pth'
    config_path = 'config_101.yaml'
    if not os.path.exists(model_path):
        logger.info("Downloading model from s3")
        get_check_point_file_from_s3(model_path)
        get_check_point_file_from_s3(config_path)
    elif not os.path.exists(config_path):
        # Weights are cached locally but the config is not.
        logger.info("Downloading model config from s3")
        get_check_point_file_from_s3(config_path)
    return get_model(model_path, 0.5, device, config_path)
    
    
def get_model(model_path, threshold, device, config_path):
    """Create a ``DefaultPredictor`` from a config file plus overrides.

    Args:
        model_path: Value assigned to ``MODEL.WEIGHTS``.
        threshold: Value assigned to ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value assigned to ``MODEL.DEVICE``.
        config_path: Config file merged into the default configuration.

    Returns:
        A ``DefaultPredictor`` built from the resulting configuration.
    """
    config = get_cfg()
    config.merge_from_file(config_path)

    # Overrides are applied after the merge so they win over the file.
    config.MODEL.WEIGHTS = model_path
    config.MODEL.DEVICE = device
    config.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold

    predictor = DefaultPredictor(config)
    return predictor

In [None]:
def main():
    """Run detection, OCR and key/value extraction over a folder of PNGs.

    For every ``*.png`` file under ``dir/images/`` the detector and the OCR
    step are run, the OCR text is grouped per detected box, and each box's
    text is passed through ``extract``.

    Returns:
        list with one dict per image, mapping ``"BBOX<n>"`` (1-based, in box
        order) to the key/value dict extracted from that box.
    """
    # extract() resolves the NER pipeline through the module-level name
    # ``nlp_ner``, so it has to be bound globally rather than as a local.
    global nlp_ner

    model = load_model()
    nlp_ner = spacy.load("./model-best")
    base_path = 'dir/images/'
    json_list = []
    # Sorted so the output order is reproducible; glob gives no ordering.
    for file in sorted(glob.glob(os.path.join(base_path, '*.png'))):
        data = detection(model, file)
        vertexList = detectText(file)
        results = create_text_list(vertexList, data)
        final_dic = {}
        for c, line in enumerate(results['text_data'], start=1):
            line = line.replace("\n", "")
            final_dic["BBOX" + str(c)] = extract(line)
        json_list.append(final_dic)
    return json_list


# Run the pipeline only when this file is executed as a script; the list
# returned by main() is not used here.
if __name__ == '__main__':
    main()
