# LayoutLM Inference: Invoice Field Extraction

This script performs inference using a fine-tuned `LayoutLM` model to extract structured fields from scanned invoices.

---

### Workflow:
- Load trained model from checkpoint
- OCR invoice using `pytesseract`
- Tokenize OCR words + bounding boxes
- Predict labels (`B-COMPANY`, `B-DATE`, etc.)
- Return a dictionary of extracted fields

> Designed for use with trained LayoutLM model on the SROIE invoice dataset


In [4]:
# Import necessary libraries
import os
import torch
import pytesseract
from PIL import Image
from transformers import LayoutLMTokenizerFast, LayoutLMForTokenClassification
from pytesseract import Output

In [13]:
model = LayoutLMForTokenClassification.from_pretrained(
    "../models/layoutlm_invoice",
    local_files_only=True,
    use_safetensors=True
)

tokenizer = LayoutLMTokenizerFast.from_pretrained(
    "../models/layoutlm_invoice",
    local_files_only=True
)


In [14]:
# Label map
id2label = model.config.id2label

# Normalize boxes to 0–1000 scale
def normalize_box(box, width, height):
    return [
        int(box[0] * 1000 / width),
        int(box[1] * 1000 / height),
        int(box[2] * 1000 / width),
        int(box[3] * 1000 / height),
    ]

# Inference function
def predict_invoice_fields(image_path):
    image = Image.open(image_path).convert("RGB")
    width, height = image.size

    ocr_data = pytesseract.image_to_data(image, output_type=Output.DICT)

    words = []
    boxes = []

    for i in range(len(ocr_data["text"])):
        word = ocr_data["text"][i].strip()
        if word == "":
            continue
        x, y, w, h = ocr_data["left"][i], ocr_data["top"][i], ocr_data["width"][i], ocr_data["height"][i]
        words.append(word)
        boxes.append(normalize_box([x, y, x + w, y + h], width, height))

    # Tokenize
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512
    )

    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]
    bbox = boxes + [[0, 0, 0, 0]] * (512 - len(boxes))
    encoding["bbox"] = torch.tensor([bbox])

    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=encoding["bbox"]
        )
    predictions = torch.argmax(outputs.logits, dim=2)

    labels = predictions[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    extracted = {}

    for token, label_idx, word in zip(tokens, labels, words):
        label = id2label[label_idx]
        if label != 'O' and not token.startswith("##"):
            extracted.setdefault(label, []).append(word)

    # Combine multi-token fields
    final_output = {k: " ".join(v) for k, v in extracted.items()}
    return final_output

# Example usage
image_path = "../data/sroie/images/X51008123604.jpg"  # Change to your test image
result = predict_invoice_fields(image_path)
print("Extracted Fields:")
for k, v in result.items():
    print(f"{k}: {v}")


Extracted Fields:
B-COMPANY: OGN Jalan Dinar U3/G, Date PM
B-DATE: 4:37:37 Scissors Discount Rounding Adjustment: 0.00 TOTAL : GOODS Come
