In [None]:
from transformers import AutoModelForTokenClassification

# Load the fine-tuned token-classification checkpoint from local disk.
# The Auto class resolves the concrete model class from the checkpoint itself
# (expected to be LayoutLMv3ForTokenClassification - TODO confirm).
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
MODEL_PATH = "C:/Projects/IDP/watercare/model_output/23_11_03_03/checkpoint-150"
model = AutoModelForTokenClassification.from_pretrained(MODEL_PATH)

In [None]:
from transformers import AutoProcessor

# Use the Auto API to load the processor stored alongside the local checkpoint
# at MODEL_PATH (expected to resolve to LayoutLMv3Processor - TODO confirm).
# apply_ocr=False: words and boxes are passed in explicitly when the processor
# is called later in this notebook.
processor = AutoProcessor.from_pretrained(MODEL_PATH, apply_ocr=False)

In [None]:
# Convert pdf to images
from pdf2image import convert_from_path
from pathlib import Path

# Render the source PDF into a list of images, one per page.
PDF_PATH = Path("C:/Projects/IDP/watercare/dataset/pdfs/23_10_25.pdf")
images = convert_from_path(PDF_PATH, fmt="png")


In [None]:
import pytesseract
from idp.annotations.bbox_utils import normalize_box

def extract_text_from_image(image):
    """OCR one page image and return its words with normalized boxes.

    Args:
        image: Page image exposing ``.size`` as ``(width, height)`` and
            accepted by ``pytesseract.image_to_data``.

    Returns:
        A ``(texts, boxes)`` tuple of parallel lists. ``texts`` holds the
        non-blank entries reported at level 5 (Tesseract's word level);
        ``boxes`` holds the matching ``[left, top, right, bottom]`` pixel box
        after it has been passed through ``normalize_box`` with the image size.
    """
    img_width, img_height = image.size
    ocr = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    texts = []
    boxes = []
    columns = zip(ocr["level"], ocr["text"], ocr["left"], ocr["top"], ocr["width"], ocr["height"])
    for level, word, left, top, width, height in columns:
        # Keep word-level rows only, and drop rows whose text is blank.
        if level != 5 or not word.strip():
            continue
        pixel_box = [left, top, left + width, top + height]
        texts.append(word)
        boxes.append(normalize_box(pixel_box, img_width, img_height))

    return (texts, boxes)

# Run OCR on every page, then split the per-page (words, boxes) pairs into
# two parallel lists indexed by page.
pages = len(images)
text_box_pairs = list(map(extract_text_from_image, images))
texts = [page_words for page_words, _ in text_box_pairs]
boxes = [page_boxes for _, page_boxes in text_box_pairs]
# For each page, its page number repeated once per word (aligned with `texts`).
page_indexes = [[page_no for _ in page_words] for page_no, page_words in enumerate(texts)]

In [None]:
def is_box_a_within_box_b(box_a, box_b):
    """Return True when ``box_b`` fully contains ``box_a``.

    Both boxes are ``(left, top, right, bottom)`` sequences. Shared edges
    count as contained.
    """
    a_left, a_top, a_right, a_bottom = box_a
    b_left, b_top, b_right, b_bottom = box_b

    fits_horizontally = b_left <= a_left and a_right <= b_right
    fits_vertically = b_top <= a_top and a_bottom <= b_bottom
    return fits_horizontally and fits_vertically
    
# Assume all pages have the same size: the first page's dimensions are reused
# for every page by the later normalize_box / unnormalize_box calls.
img_width, img_height = images[0].size

# HACK: drop words and boxes that lie outside the labelling area.
# Hand-picked pixel rectangles (left, top, right, bottom), keyed by page index.
_LABELLING_AREAS_PX = {
    0: (1298, 828, 1536, 1550),
    1: (139, 119, 827, 1173),
}


def text_box_relevant(text_box_page):
    """Tell whether a word's box lies inside its page's labelling area.

    Args:
        text_box_page: ``(text, box, page)`` triple. ``box`` is compared
            against the page's area after that area has been passed through
            ``normalize_box`` with the notebook-level image size.

    Returns:
        The result of ``is_box_a_within_box_b`` for pages 0 and 1, and
        ``False`` for every other page.
    """
    _text, box, page = text_box_page

    area_px = _LABELLING_AREAS_PX.get(page)
    if area_px is None:
        return False

    outer_box = normalize_box(area_px, img_width, img_height)
    return is_box_a_within_box_b(box, outer_box)

# Per-page words and boxes that survive the labelling-area filter.
out_texts = []
out_boxes = []

for page in range(pages):
    kept = [
        (text, box)
        for text, box, page_no in zip(texts[page], boxes[page], page_indexes[page])
        if text_box_relevant((text, box, page_no))
    ]
    out_texts.append([text for text, _ in kept])
    out_boxes.append([box for _, box in kept])

In [None]:
# Encode every page (image plus its filtered words and boxes) into padded,
# truncated PyTorch tensors for the model.
encoding = processor(
    images=images,
    text=out_texts,
    boxes=out_boxes,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

In [None]:
# HACK (disabled): manual trim of each tensor to its first 512 tokens. Left
# commented out; the processor call above already passes truncation=True.
# encoding['input_ids'] = encoding['input_ids'][:,:512]
# encoding['attention_mask'] = encoding['attention_mask'][:,:512]
# encoding['bbox'] = encoding['bbox'][:,:512]

# Sanity check: print the token ids of the second page (index 1).
print(encoding['input_ids'][1])

In [None]:
import torch

# Forward pass only: gradient tracking is switched off for inference.
with torch.no_grad():
    outputs = model(**encoding)

In [None]:
# Highest-scoring class id for every token of every page, as nested Python lists.
predictions = outputs.logits.argmax(dim=-1).tolist()

In [None]:
# Show the model's class-id -> label mapping (used later to decode predictions).
model.config.id2label

In [None]:
from idp.annotations.bbox_utils import unnormalize_box, merge_box_extremes
from idp.annotations.annotation_utils import Classes, CLASS_TO_LABEL_MAP


token_boxes = encoding.bbox.tolist()
pages_input_ids = encoding['input_ids'].tolist()
# Token id treated as padding. NOTE(review): presumably equal to
# processor.tokenizer.pad_token_id - confirm against the tokenizer.
PAD_ID = 1

# Index of the first padding token on each page. A page whose tokens fill the
# whole sequence (possible because the encoding uses truncation=True together
# with padding="max_length") contains no PAD_ID at all; fall back to the full
# sequence length there instead of letting list.index raise ValueError.
pad_indexes = [
    page_input_ids.index(PAD_ID) if PAD_ID in page_input_ids else len(page_input_ids)
    for page_input_ids in pages_input_ids
]

# Every slice starts at 1 because the first token is a special token, and
# stops at the first padding token so only real tokens are kept.
true_predictions = [
    [model.config.id2label[pred] for pred in page_preds[1:pad_indexes[page_index]]]
    for page_index, page_preds in enumerate(predictions)
]
true_boxes = [
    [unnormalize_box(box, img_width, img_height) for box in page_token_boxes[1:pad_indexes[page_index]]]
    for page_index, page_token_boxes in enumerate(token_boxes)
]
# Each token is decoded on its own so texts stay aligned one-to-one with
# predictions and boxes.
true_texts = [
    [processor.tokenizer.decode([token_id]) for token_id in page_input_ids[1:pad_indexes[page_index]]]
    for page_index, page_input_ids in enumerate(pages_input_ids)
]



In [None]:

# Labels to extract: every class label except the "other" class.
field_labels = [label for cls, label in CLASS_TO_LABEL_MAP.items() if cls != Classes.OTHER]

# One dict per page: label -> {'text': concatenated tokens, 'box': merged box}.
output = []
for page_indx in range(pages):
    page_fields = {}
    for label in field_labels:
        # Tokens predicted as this label, skipping those with an all-zero box.
        hits = [
            (text, box)
            for text, prediction, box in zip(true_texts[page_indx], true_predictions[page_indx], true_boxes[page_indx])
            if prediction == label and box != [0,0,0,0]
        ]
        page_fields[label] = {
            'text': ''.join([text for text, _ in hits]),
            'box': merge_box_extremes([box for _, box in hits]),
        }
    output.append(page_fields)


def item_not_empty(item):
    """Return True when a ``(label, field)`` pair carries non-empty text."""
    return len(item[1]['text']) != 0


# Drop labels for which no text was extracted on a page.
filtered_output = [
    {label: field for label, field in page_output.items() if item_not_empty((label, field))}
    for page_output in output
]

In [None]:
# Display the extracted fields per page (only labels with non-empty text).
filtered_output

In [None]:
from PIL import ImageDraw, ImageFont

# Font for the (currently disabled) text captions in the drawing loop.
# "arial.ttf" can only be opened where that font file is available, so fall
# back to Pillow's built-in default font instead of aborting the notebook
# with OSError.
try:
    font = ImageFont.truetype("arial.ttf", 20)
except OSError:
    font = ImageFont.load_default()

def iob_to_label(label):
    """Drop the two-character prefix from a tag and return the rest.

    Returns ``'other'`` when nothing remains after the first two characters
    (for example for the bare tag ``"O"``).
    """
    stripped = label[2:]
    return stripped if stripped else 'other'

# Outline colour used for each (lower-cased) field label.
label2color = {
    'other': 'pink',
    'balance_still_owing': 'red',
    'water_consumption': 'purple',
    'wastewater_consumption': 'green',
    'wastewater_fixed': 'orange',
    'balance_current_charges': 'violet',
    "total_due": "black",
    'water_consumption_details': 'red',
    'wastewater_consumption_details': 'purple',
    'wastewater_fixed_details': 'green',
    'this_reading': 'grey',
    'last_reading': 'black',
}

# Outline every extracted field on its page image. The rectangles are drawn
# directly onto the objects in `images`, so those images are modified in place.
for indx, page_output in enumerate(filtered_output):
    draw = ImageDraw.Draw(images[indx])
    for raw_label, field in page_output.items():
        predicted_label = iob_to_label(raw_label).lower()
        box = field['box']
        draw.rectangle(box, outline=label2color[predicted_label])
        # Optional caption next to the box (disabled):
        # draw.text((box[0] - 100, box[1]), text=f"{predicted_label}", fill=label2color[predicted_label], font=font)

In [None]:
# Show the first page, including any field outlines drawn onto it above.
images[0]

In [None]:
# Show the second page (assumes the PDF has at least two pages).
images[1]