In [4]:
!apt-get install tesseract-ocr -y
!pip install pytesseract transformers torch torchvision --quiet

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tesseract-ocr is already the newest version (4.1.1-2.1build1).
0 upgraded, 0 newly installed, 0 to remove and 2 not upgraded.


In [5]:
import torch
import pytesseract
from pytesseract import Output
from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering
from PIL import Image
import json
import re

In [7]:
# Directory that holds the saved LayoutLMv3 processor and model files.
model_path = "/content"

# Pick the compute device in this cell so a fresh kernel does not depend on a
# `device` variable left behind by a cell that is no longer in the notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Words and boxes are produced by the notebook's own Tesseract step, so the
# processor's built-in OCR is switched off explicitly.
processor = LayoutLMv3Processor.from_pretrained(model_path, apply_ocr=False)
model = LayoutLMv3ForQuestionAnswering.from_pretrained(model_path)

model.to(device)
# Inference only: eval() disables dropout. As the last expression it also
# displays the module structure.
model.eval()

Loading weights:   0%|          | 0/216 [00:00<?, ?it/s]

LayoutLMv3ForQuestionAnswering(
  (layoutlmv3): LayoutLMv3Model(
    (embeddings): LayoutLMv3TextEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (x_position_embeddings): Embedding(1024, 128)
      (y_position_embeddings): Embedding(1024, 128)
      (h_position_embeddings): Embedding(1024, 128)
      (w_position_embeddings): Embedding(1024, 128)
    )
    (patch_embed): LayoutLMv3PatchEmbeddings(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (encoder): LayoutLMv3Encoder(


In [14]:
def get_ocr_words_boxes(image_path):
    """Run Tesseract OCR on an image and return its words with normalised boxes.

    Parameters
    ----------
    image_path : str or path-like
        Location of the image file to read.

    Returns
    -------
    tuple[list[str], list[list[int]]]
        ``words`` holds every non-blank OCR token whose confidence is above
        zero. ``boxes`` holds one ``[x0, y0, x1, y1]`` entry per word, scaled
        from pixels to an integer 0-1000 grid and clamped to that range.
    """
    # convert() returns an independent copy, so the file handle can be
    # released immediately instead of lingering until garbage collection.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    width, height = image.size

    data = pytesseract.image_to_data(image, output_type=Output.DICT)

    def to_grid(pixels, extent):
        # Scale a pixel coordinate to the 0-1000 grid and clamp it, so a box
        # that pokes past the image edge cannot produce an out-of-range value.
        return max(0, min(1000, int(1000 * pixels / extent)))

    words = []
    boxes = []

    columns = zip(
        data["text"],
        data["conf"],
        data["left"],
        data["top"],
        data["width"],
        data["height"],
    )
    for text, conf, left, top, box_w, box_h in columns:
        # Confidence may be an int, a float, or a decimal string such as
        # "96.5"; int() alone rejects the latter, so go through float() first.
        if int(float(conf)) <= 0 or text.strip() == "":
            continue

        words.append(text)
        boxes.append(
            [
                to_grid(left, width),
                to_grid(top, height),
                to_grid(left + box_w, width),
                to_grid(top + box_h, height),
            ]
        )

    return words, boxes

In [15]:
# Question put to the QA model for each field, keyed by the field's name in
# the extracted result.
field_questions = dict(
    hospital_name="What is the hospital name?",
    bill_number="What is the bill number?",
    bill_date="What is the bill date?",
    admit_date="What is the admit date?",
    discharge_date="What is the discharge date?",
    mrn="What is the MRN number?",
    total_amount="What is the total amount?",
)

In [16]:
import torch

def _select_answer_span(start_logits, end_logits, context_mask, max_answer_tokens=None):
    """Return ``(start, end)`` indices of the best span with ``end >= start``.

    Positions where ``context_mask`` is False are excluded. Every remaining
    (start, end) pair is scored as ``start_logit + end_logit`` and the highest
    scoring pair wins. When ``max_answer_tokens`` is given, spans longer than
    that many tokens are excluded as well.
    """
    neg_inf = float("-inf")
    start_scores = start_logits.masked_fill(~context_mask, neg_inf)
    end_scores = end_logits.masked_fill(~context_mask, neg_inf)

    # span_scores[s, e] is the score of the span that starts at s and ends at e.
    span_scores = start_scores[:, None] + end_scores[None, :]
    positions = torch.arange(span_scores.size(0))
    span_offset = positions[None, :] - positions[:, None]  # e - s

    valid = span_offset >= 0
    if max_answer_tokens is not None:
        valid = valid & (span_offset < max_answer_tokens)
    span_scores = span_scores.masked_fill(~valid, neg_inf)

    best_flat = torch.argmax(span_scores).item()
    return divmod(best_flat, span_scores.size(1))


def extract_field(image_path, field_name, max_answer_tokens=None):
    """Answer the question registered for ``field_name`` from a document image.

    Parameters
    ----------
    image_path : str or path-like
        Image to run OCR and question answering on.
    field_name : str
        Key into ``field_questions``; an unknown key raises ``KeyError``.
    max_answer_tokens : int, optional
        Positive upper bound on the answer length in tokens. ``None`` (the
        default) applies no bound.

    Returns
    -------
    str
        The predicted answer text with surrounding whitespace stripped, or an
        empty string when OCR finds no words or no document token survives
        truncation.
    """
    question = field_questions[field_name]
    words, boxes = get_ocr_words_boxes(image_path)

    # With no OCR words there is no context to extract an answer from.
    if not words:
        return ""

    # convert() yields an independent copy, so the file can be closed at once.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    encoding = processor(
        images=image,
        text=question,
        text_pair=words,
        boxes=boxes,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Read everything that needs the encoding object before it is turned into
    # a plain dict of device tensors below.
    sequence_ids = encoding.sequence_ids(0)
    input_ids = encoding["input_ids"][0].tolist()
    model_inputs = {k: v.to(device) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**model_inputs)

    # The model output may cover extra positions after the text tokens (image
    # patches); only the first len(sequence_ids) positions map to input_ids.
    text_len = len(sequence_ids)
    start_logits = outputs.start_logits[0, :text_len].float().cpu()
    end_logits = outputs.end_logits[0, :text_len].float().cpu()

    # Sequence id 1 marks the OCR words; the question (0) and special/padding
    # tokens (None) must never be chosen as part of the answer.
    context_mask = torch.tensor([seq_id == 1 for seq_id in sequence_ids], dtype=torch.bool)
    if not bool(context_mask.any()):
        return ""

    # Choose start and end jointly so the end can never precede the start,
    # which independent argmax calls do not guarantee.
    start_pred, end_pred = _select_answer_span(
        start_logits, end_logits, context_mask, max_answer_tokens
    )

    tokens = processor.tokenizer.convert_ids_to_tokens(input_ids)
    predicted_tokens = tokens[start_pred:end_pred + 1]
    predicted_text = processor.tokenizer.convert_tokens_to_string(predicted_tokens)

    return predicted_text.strip()

In [17]:
def extract_document_fields(image_path):
    """Extract every field listed in ``field_questions`` from one image.

    Extraction is best-effort: when a field fails with an ordinary exception
    its value is set to ``None`` and a ``RuntimeWarning`` describing the
    failure is issued, so the remaining fields are still processed.
    Interrupts such as ``KeyboardInterrupt`` are not intercepted.

    Parameters
    ----------
    image_path : str or path-like
        Image passed on to ``extract_field`` for each field.

    Returns
    -------
    dict
        Maps each field name to the extracted text, or ``None`` on failure.
    """
    import warnings  # local import keeps this cell self-contained

    extracted = {}

    for field in field_questions:
        try:
            extracted[field] = extract_field(image_path, field)
        except Exception as exc:
            # Catch Exception rather than everything so that stopping the
            # kernel still works, and surface the reason instead of hiding it.
            warnings.warn(
                f"Could not extract {field!r} from {image_path!r}: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            extracted[field] = None

    return extracted


In [18]:
# Sample image used to exercise the full extraction pipeline.
image_path = "/content/2048528-06.jpg"

result = extract_document_fields(image_path)

# json comes from the notebook's import cell; it is not re-imported here.
print(json.dumps(result, indent=2))

{
  "hospital_name": "18-Sep-2018 12:00AM To Admission No. IP/18/012383 Adm Date 18-Sep-2018",
  "bill_number": "18-Sep-2018 12:00AM To Admission No. IP/18/012383 Adm Date 18-Sep-2018",
  "bill_date": "18-Sep-2018 12:00AM To Admission No. IP/18/012383 Adm Date 18-Sep-2018",
  "admit_date": "18-Sep-2018 12:00AM To Admission No. IP/18/012383 Adm Date 18-Sep-2018",
  "discharge_date": "18-Sep-2018 12:00AM To Admission No. IP/18/012383 Adm Date 18-Sep-2018",
  "mrn": "18-Sep-2018 12:00AM To Admission No. IP/18/012383 Adm Date 18-Sep-2018",
  "total_amount": "18-Sep-2018 12:00AM To Admission No. IP/18/012383 Adm Date 18-Sep-2018"
}
