In [1]:
import json
from tqdm import tqdm
from PIL import Image
from src.process import create_image_view

import warnings
warnings.filterwarnings("ignore")

from src.AnonymizationInference import AnonymizationInference
import os
from src.w_b import count_and_log_all_metrics
from src.config import funsd_label_list


def get_layoutlm_predictions(
    inference, images, path_to_image, path_to_gt_labeled_images, labels_saving_path, image_views_path, path_to_gt_labels=None,
):
    """Run inference on every ``.png`` in ``images`` and store predictions and comparison views.

    For each image, the value returned by ``inference.predict`` is dumped as JSON
    into ``labels_saving_path`` and ``create_image_view`` is called with the
    ground-truth labeled image, the image returned by ``inference.draw_bboxes``
    and a target path inside ``image_views_path``.

    Args:
        inference: Object exposing ``predict(image_path[, words, boxes])`` and
            ``draw_bboxes(image_path, predictions)``.
        images: Iterable of image file names (not full paths) located in
            ``path_to_image``. Names that do not end with ``.png`` are skipped.
        path_to_image: Directory containing the input images.
        path_to_gt_labeled_images: Directory containing, for each input image,
            a file with the same name that is opened as the ground-truth view.
        labels_saving_path: Existing directory where ``<image stem>.json`` is
            written for each processed image.
        image_views_path: Directory used to build the path handed to
            ``create_image_view`` (presumably where the view is saved -- confirm
            against ``src.process``).
        path_to_gt_labels: Optional directory with ``<image stem>.json`` files
            holding ``"tokens"`` and ``"bboxes"`` keys. When given, those values
            are passed to ``inference.predict`` as words and boxes; otherwise
            ``predict`` receives only the image path.

    Returns:
        None. Results are written to disk.
    """
    for image in tqdm(images):
        if not image.endswith(".png"):
            continue

        # Swap only the extension: str.replace would also rewrite a ".png"
        # that appears earlier in the file name.
        label_name = os.path.splitext(image)[0] + ".json"
        image_path = os.path.join(path_to_image, image)

        if path_to_gt_labels:
            with open(os.path.join(path_to_gt_labels, label_name)) as f:
                gt_labels = json.load(f)
            predictions = inference.predict(image_path, gt_labels["tokens"], gt_labels["bboxes"])
        else:
            predictions = inference.predict(image_path)

        img_with_pred_bboxes = inference.draw_bboxes(image_path, predictions)

        # Context manager releases the ground-truth image's file handle on each
        # iteration instead of leaving it to the garbage collector.
        with Image.open(os.path.join(path_to_gt_labeled_images, image)) as img_with_gt_bboxes:
            create_image_view(
                img_with_gt_bboxes,
                img_with_pred_bboxes,
                os.path.join(image_views_path, image),
            )

        with open(os.path.join(labels_saving_path, label_name), "w") as f:
            json.dump(predictions, f, indent=4)

## Getting predictions

In [2]:
# Folder name of the LayoutLM checkpoint under weights/.
name = "layoutlm_best"

# models
detection_model = "fast_base"
recognition_model = "master"
# Combined OCR identifier; used below to name the result folders.
ocr_model = f"{detection_model}_{recognition_model}"
layoutlm_model = "best_finetuned"
lm_model_weights = f"weights/{name}"
signature_model_weights = "weights/yolo_signatures.pt"

# paths
path_to_results = "results"
path_to_benchmark = "data/funsd_benchmark"

# benchmark
path_to_benchmark_images = os.path.join(path_to_benchmark, "images")
path_to_gt_benchmark_labeled_images = os.path.join(path_to_benchmark, "labeled_images")
path_to_gt_benchmark_labels = os.path.join(path_to_benchmark, "layoutlm_labels")
predicted_benchmark_image_view_path = os.path.join(path_to_results, f"benchmark_image_views_{ocr_model}")
predicted_benchmark_labels_folder = os.path.join(path_to_results, f"benchmark_layoutlm_labels_{ocr_model}")
# os.listdir returns entries in arbitrary order; sort so every run walks the
# benchmark images in the same order.
benchmark_images = sorted(os.listdir(path_to_benchmark_images))

# Output folders must exist before predictions and image views are written.
os.makedirs(path_to_results, exist_ok=True)
os.makedirs(predicted_benchmark_labels_folder, exist_ok=True)
os.makedirs(predicted_benchmark_image_view_path, exist_ok=True)

In [8]:
# Build the inference pipeline from the models and weights configured above.
inference = AnonymizationInference(
    detection_model=detection_model,
    recognition_model=recognition_model,
    path_to_layoutlm_weights=lm_model_weights,
    path_to_signature_weights=signature_model_weights,
    label_list=funsd_label_list,
)

# Predict on the benchmark images, passing the ground-truth label folder so
# its tokens and boxes are forwarded to the model.
get_layoutlm_predictions(
    inference,
    benchmark_images,
    path_to_benchmark_images,
    path_to_gt_benchmark_labeled_images,
    predicted_benchmark_labels_folder,
    predicted_benchmark_image_view_path,
    path_to_gt_labels=path_to_gt_benchmark_labels,
)

100%|██████████| 255/255 [00:29<00:00,  8.59it/s]


## Metrics

In [3]:
# Evaluation sets passed to the metrics logger. Each entry pairs a folder of
# ground-truth labels with the folder of predicted labels and the image views.
test_samples = [
    {
        "test_name": "benchmark",
        "gt_labels": path_to_gt_benchmark_labels,
        "predicted_labels": predicted_benchmark_labels_folder,
        "image_views": predicted_benchmark_image_view_path,
        "class_names": [
            "full_name",
            "phone_number",
            "address",
            "company_name",
            "email_address",
            "signature",
        ],
    },
]

In [1]:
# Compute and log metrics for every entry in test_samples, tagged with the
# LayoutLM and OCR model names used for this run.
count_and_log_all_metrics(
    samples=test_samples,
    lm_model_name=f"LayoutLM_{layoutlm_model}",
    ocr_model_name=ocr_model,
    run_specification="trained_on_invoices",
)