In [4]:
# ===== CORE LIBRARIES =====
import os
import json
import re
import numpy as np
from PIL import Image
from tqdm import tqdm

# ===== OCR =====
import pytesseract
from pytesseract import Output

# ===== TORCH =====
import torch
from torch.utils.data import Dataset, DataLoader

# ===== TRANSFORMERS =====
from transformers import (
    LayoutLMv3Processor,
    LayoutLMv3ForTokenClassification,
    TrainingArguments,
    Trainer
)

# ===== METRICS =====
from sklearn.metrics import classification_report

print("Imports loaded successfully.")

Imports loaded successfully.


In [5]:
# ===== PATHS =====

IMAGE_ROOT = "/kaggle/input/datasets/yamunaaab/artivatic-train-data/OneDrive_2026-02-22/Sample data"
ANNOTATION_PATH = "/kaggle/input/datasets/yamunaaab/annotated-json/json.txt"  # ðŸ”´ change this to correct path


# ===== LOAD ANNOTATIONS =====

with open(ANNOTATION_PATH, "r") as f:
    annotations_raw = json.load(f)

print("Total annotations:", len(annotations_raw))

annotation_keys = sorted(annotations_raw.keys())


# ===== DEFINE FOLDER ORDER (Your Batch Order) =====

folder_order = [
    "Detailed hospital bill",
    "Diagnostic Bills",
    "Discharge Bills",
    "Final Bills",
    "Hospital Bills",
    "IPD Bills"
]


# ===== LOAD ORDERED IMAGE PATHS =====

ordered_image_paths = []

for folder_name in folder_order:
    folder_path = os.path.join(IMAGE_ROOT, folder_name)
    
    files = [
        os.path.join(folder_path, f)
        for f in os.listdir(folder_path)
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    ]
    
    files = sorted(files)  # ensure consistent order
    
    # Since you annotated 5 per batch
    ordered_image_paths.extend(files[:5])

print("Total ordered images:", len(ordered_image_paths))
print("First 3 ordered images:", ordered_image_paths[:3])

Total annotations: 30
Total ordered images: 30
First 3 ordered images: ['/kaggle/input/datasets/yamunaaab/artivatic-train-data/OneDrive_2026-02-22/Sample data/Detailed hospital bill/1981063-2.jpg', '/kaggle/input/datasets/yamunaaab/artivatic-train-data/OneDrive_2026-02-22/Sample data/Detailed hospital bill/2042173-10.jpg', '/kaggle/input/datasets/yamunaaab/artivatic-train-data/OneDrive_2026-02-22/Sample data/Detailed hospital bill/2042485-05.jpg']


In [6]:
image_annotation_pairs = list(zip(ordered_image_paths, annotation_keys))
print("Pairs created:", len(image_annotation_pairs))

Pairs created: 30


In [8]:
# ===== OCR WORD + BBOX EXTRACTION =====

def get_ocr_words_boxes(image_path):
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    
    data = pytesseract.image_to_data(image, output_type=Output.DICT)
    
    words = []
    boxes = []
    
    for i in range(len(data["text"])):
        word = data["text"][i].strip()
        if word == "":
            continue
        
        x = data["left"][i]
        y = data["top"][i]
        w = data["width"][i]
        h = data["height"][i]
        
        # Normalize to 0â€“1000 scale (LayoutLM requirement)
        box = [
            int(1000 * x / width),
            int(1000 * y / height),
            int(1000 * (x + w) / width),
            int(1000 * (y + h) / height),
        ]
        
        words.append(word)
        boxes.append(box)
    
    return words, boxes

In [9]:
test_image_path, test_key = image_annotation_pairs[0]

words, boxes = get_ocr_words_boxes(test_image_path)

print("Number of OCR words:", len(words))
print("First 20 words:", words[:20])
print("First 5 boxes:", boxes[:5])

Number of OCR words: 377
First 20 words: ['SUNSHINE', 'HOSPITALS', '(A', 'Unit', 'of', 'Sarvejana', 'Healthcare', 'Pvt.', 'Ltd.)', 'Laxmi', 'Sagar', 'Square,', 'Puri', '-', 'Cuttack', 'Road', 'Bhubaneswar,', 'Odisha', '-', '751006']
First 5 boxes: [[409, 115, 483, 125], [488, 116, 570, 126], [383, 128, 394, 136], [398, 128, 418, 135], [421, 129, 432, 136]]


In [10]:
# ===== TEXT NORMALIZATION =====

def normalize_text(text):
    if text is None:
        return None
    text = text.lower()
    text = re.sub(r'[^a-z0-9]', '', text)
    return text


# ===== ENTITY LIST =====

schema_entities = [
    "HOSPITAL_NAME",
    "PATIENT_NAME",
    "BILL_NUMBER",
    "BILL_DATE",
    "ADMIT_DATE",
    "DISCHARGE_DATE",
    "MRN",
    "TOTAL_AMOUNT"
]


# ===== BIO LABEL GENERATION =====

def generate_bio_labels(words, annotation_dict):
    labels = ["O"] * len(words)
    norm_words = [normalize_text(w) for w in words]

    for entity in schema_entities:
        value = annotation_dict.get(entity.lower())
        if value is None:
            continue
        
        norm_value = normalize_text(value)
        if not norm_value:
            continue
        
        # sliding window search
        for i in range(len(norm_words)):
            combined = ""
            for j in range(i, len(norm_words)):
                combined += norm_words[j]
                
                if combined == norm_value:
                    labels[i] = f"B-{entity}"
                    for k in range(i+1, j+1):
                        labels[k] = f"I-{entity}"
                    break
                
                if len(combined) > len(norm_value):
                    break
    
    return labels

In [11]:
# Get first pair
test_image_path, test_key = image_annotation_pairs[0]

annotation_data = annotations_raw[test_key]

words, boxes = get_ocr_words_boxes(test_image_path)
bio_labels = generate_bio_labels(words, annotation_data)

# Print matched tokens
for w, l in zip(words, bio_labels):
    if l != "O":
        print(w, "->", l)

SUNSHINE -> B-HOSPITAL_NAME
HOSPITALS -> I-HOSPITAL_NAME
03-Sep-2018 -> B-DISCHARGE_DATE


In [12]:
label_list = [
    "O",

    "B-HOSPITAL_NAME", "I-HOSPITAL_NAME",
    "B-BILL_NUMBER", "I-BILL_NUMBER",
    "B-BILL_DATE", "I-BILL_DATE",
    "B-ADMIT_DATE", "I-ADMIT_DATE",
    "B-DISCHARGE_DATE", "I-DISCHARGE_DATE",
    "B-MRN", "I-MRN",
    "B-TOTAL_AMOUNT", "I-TOTAL_AMOUNT"
]

label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for label, i in label2id.items()}

print(label2id)

{'O': 0, 'B-HOSPITAL_NAME': 1, 'I-HOSPITAL_NAME': 2, 'B-BILL_NUMBER': 3, 'I-BILL_NUMBER': 4, 'B-BILL_DATE': 5, 'I-BILL_DATE': 6, 'B-ADMIT_DATE': 7, 'I-ADMIT_DATE': 8, 'B-DISCHARGE_DATE': 9, 'I-DISCHARGE_DATE': 10, 'B-MRN': 11, 'I-MRN': 12, 'B-TOTAL_AMOUNT': 13, 'I-TOTAL_AMOUNT': 14}


In [13]:
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification

processor = LayoutLMv3Processor.from_pretrained(
    "microsoft/layoutlmv3-base",
    apply_ocr=False
)

model = LayoutLMv3ForTokenClassification.from_pretrained(
    "microsoft/layoutlmv3-base",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)

Loading weights:   0%|          | 0/212 [00:00<?, ?it/s]

[1mLayoutLMv3ForTokenClassification LOAD REPORT[0m from: microsoft/layoutlmv3-base
Key                                | Status     | 
-----------------------------------+------------+-
layoutlmv3.embeddings.position_ids | UNEXPECTED | 
classifier.dense.weight            | MISSING    | 
classifier.out_proj.weight         | MISSING    | 
classifier.out_proj.bias           | MISSING    | 
classifier.dense.bias              | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


In [14]:
from PIL import Image
import os

# ---- Pick first image ----
test_image_path = ordered_image_paths[0]

# ---- Load image ----
image = Image.open(test_image_path).convert("RGB")
print("Loaded image:", test_image_path)

# ---- Get index of image ----
idx = ordered_image_paths.index(test_image_path)

# ---- Get matching annotation key ----
annotation_key = annotation_keys[idx]

# ---- Load annotation ----
annotation_data = annotations_raw[annotation_key]

# ---- Run OCR ----
words, boxes = get_ocr_words_boxes(test_image_path)

# ---- Generate BIO ----
bio_labels = generate_bio_labels(words, annotation_data)

print("Annotation key used:", annotation_key)

Loaded image: /kaggle/input/datasets/yamunaaab/artivatic-train-data/OneDrive_2026-02-22/Sample data/Detailed hospital bill/1981063-2.jpg
Annotation key used: image_1


In [15]:
filename = os.path.basename(test_image_path)

# Get index of this image in ordered_image_paths
idx = ordered_image_paths.index(test_image_path)

# Get corresponding annotation key
annotation_key = annotation_keys[idx]

annotation_data = annotations_raw[annotation_key]

In [16]:
for w, l in zip(words, bio_labels):
    if l != "O":
        print(w, "->", l)

SUNSHINE -> B-HOSPITAL_NAME
HOSPITALS -> I-HOSPITAL_NAME
03-Sep-2018 -> B-DISCHARGE_DATE


In [17]:
# ---- OCR ----
words, boxes = get_ocr_words_boxes(test_image_path)

# ---- Map image to annotation using index (Option B) ----
idx = ordered_image_paths.index(test_image_path)
annotation_key = annotation_keys[idx]
annotation_data = annotations_raw[annotation_key]

# ---- Generate BIO ----
bio_labels = generate_bio_labels(words, annotation_data)

In [18]:
print("Words:", len(words))
print("Boxes:", len(boxes))
print("BIO labels:", len(bio_labels))

Words: 377
Boxes: 377
BIO labels: 377


In [19]:
encoding = processor(
    images=image,
    text=words,        # <-- FIXED (not words=)
    boxes=boxes,
    word_labels=[label2id[label] for label in bio_labels],
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)

for k, v in encoding.items():
    print(k, v.shape)

input_ids torch.Size([1, 512])
attention_mask torch.Size([1, 512])
bbox torch.Size([1, 512, 4])
labels torch.Size([1, 512])
pixel_values torch.Size([1, 3, 224, 224])


In [20]:
import torch
from torch.utils.data import Dataset

class HospitalDataset(Dataset):
    def __init__(self, image_paths, annotations_raw, annotation_keys, processor, label2id):
        self.image_paths = image_paths
        self.annotations_raw = annotations_raw
        self.annotation_keys = annotation_keys
        self.processor = processor
        self.label2id = label2id

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]

        # Load image
        image = Image.open(image_path).convert("RGB")

        # OCR
        words, boxes = get_ocr_words_boxes(image_path)

        # Get matching annotation (Option B mapping)
        annotation_key = self.annotation_keys[idx]
        annotation_data = self.annotations_raw[annotation_key]

        # Generate BIO labels
        bio_labels = generate_bio_labels(words, annotation_data)

        encoding = self.processor(
            images=image,
            text=words,
            boxes=boxes,
            word_labels=[self.label2id[label] for label in bio_labels],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Remove batch dimension
        item = {k: v.squeeze(0) for k, v in encoding.items()}
        return item

In [21]:
dataset = HospitalDataset(
    image_paths=ordered_image_paths,
    annotations_raw=annotations_raw,
    annotation_keys=annotation_keys,
    processor=processor,
    label2id=label2id
)

print("Dataset size:", len(dataset))

Dataset size: 30


In [22]:
sample = dataset[0]

for k, v in sample.items():
    print(k, v.shape)

input_ids torch.Size([512])
attention_mask torch.Size([512])
bbox torch.Size([512, 4])
labels torch.Size([512])
pixel_values torch.Size([3, 224, 224])


In [23]:
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

print("Batches per epoch:", len(train_loader))

Batches per epoch: 15


In [24]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("Using device:", device)

Using device: cuda


In [26]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

In [55]:
import torch.nn as nn

# Initialize all weights as 1
class_weights = torch.ones(len(label_list)).to(device)

# Down-weight the O class heavily
class_weights[label2id["O"]] = 0.05

loss_fct = nn.CrossEntropyLoss(
    weight=class_weights,
    ignore_index=-100
)

print("Class weights:", class_weights)

Class weights: tensor([0.0500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], device='cuda:0')


In [57]:
from tqdm import tqdm

model.train()

EPOCHS = 5

for epoch in range(EPOCHS):
    total_loss = 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch in loop:
        batch = {k: v.to(device) for k, v in batch.items()}
    
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            bbox=batch["bbox"],
            pixel_values=batch["pixel_values"]
        )
    
        logits = outputs.logits
    
        loss = loss_fct(
            logits.view(-1, logits.shape[-1]),
            batch["labels"].view(-1)
        )
    
        total_loss += loss.item()
    
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
        loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    print(f"\nEpoch {epoch+1} Average Loss: {avg_loss:.4f}")

Epoch 1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 15/15 [02:33<00:00, 10.24s/it, loss=0.804]  



Epoch 1 Average Loss: 0.0817


Epoch 2: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 15/15 [01:32<00:00,  6.14s/it, loss=0.0134]



Epoch 2 Average Loss: 0.0902


Epoch 3: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 15/15 [01:31<00:00,  6.10s/it, loss=0.354]  



Epoch 3 Average Loss: 0.0763


Epoch 4: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 15/15 [01:31<00:00,  6.11s/it, loss=0.00887]



Epoch 4 Average Loss: 0.2442


Epoch 5: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 15/15 [01:31<00:00,  6.11s/it, loss=0.00919]


Epoch 5 Average Loss: 0.0833





In [58]:
model.eval()

LayoutLMv3ForTokenClassification(
  (layoutlmv3): LayoutLMv3Model(
    (embeddings): LayoutLMv3TextEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (x_position_embeddings): Embedding(1024, 128)
      (y_position_embeddings): Embedding(1024, 128)
      (h_position_embeddings): Embedding(1024, 128)
      (w_position_embeddings): Embedding(1024, 128)
    )
    (patch_embed): LayoutLMv3PatchEmbeddings(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (encoder): LayoutLMv3Encoder

In [60]:
for i, label in enumerate(pred_labels):
    if label != "O":
        print(i, label)

In [61]:
with torch.no_grad():
    sample = dataset[0]

    input_ids = sample["input_ids"].unsqueeze(0).to(device)
    attention_mask = sample["attention_mask"].unsqueeze(0).to(device)
    bbox = sample["bbox"].unsqueeze(0).to(device)
    pixel_values = sample["pixel_values"].unsqueeze(0).to(device)

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        bbox=bbox,
        pixel_values=pixel_values
    )

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)

In [62]:
pred_ids = predictions.squeeze(0).cpu().tolist()
pred_labels = [id2label[id] for id in pred_ids]

print(pred_labels[:50])

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [64]:
from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering

processor = LayoutLMv3Processor.from_pretrained(
    "microsoft/layoutlmv3-base",
    apply_ocr=False
)

model = LayoutLMv3ForQuestionAnswering.from_pretrained(
    "microsoft/layoutlmv3-base"
)

model.to(device)

Loading weights:   0%|          | 0/212 [00:00<?, ?it/s]

[1mLayoutLMv3ForQuestionAnswering LOAD REPORT[0m from: microsoft/layoutlmv3-base
Key                                | Status     | 
-----------------------------------+------------+-
layoutlmv3.embeddings.position_ids | UNEXPECTED | 
qa_outputs.dense.bias              | MISSING    | 
qa_outputs.dense.weight            | MISSING    | 
qa_outputs.out_proj.weight         | MISSING    | 
qa_outputs.out_proj.bias           | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


LayoutLMv3ForQuestionAnswering(
  (layoutlmv3): LayoutLMv3Model(
    (embeddings): LayoutLMv3TextEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (x_position_embeddings): Embedding(1024, 128)
      (y_position_embeddings): Embedding(1024, 128)
      (h_position_embeddings): Embedding(1024, 128)
      (w_position_embeddings): Embedding(1024, 128)
    )
    (patch_embed): LayoutLMv3PatchEmbeddings(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (encoder): LayoutLMv3Encoder(


In [95]:
field_questions = {
    "hospital_name": "What is the hospital name?",
    "bill_number": "What is the bill number?",
    "bill_date": "What is the bill date?",
    "admit_date": "What is the admit date?",
    "discharge_date": "What is the discharge date?",
    "mrn": "What is the MRN number?",
    "total_amount": "What is the total amount?"
}

In [96]:
def normalize_token(t):
    return "".join(e.lower() for e in t if e.isalnum())

def find_answer_span(words, answer_text):
    if answer_text is None:
        return None, None

    words_norm = [normalize_token(w) for w in words]
    answer_tokens = answer_text.split()
    answer_norm = [normalize_token(w) for w in answer_tokens]

    for i in range(len(words_norm) - len(answer_norm) + 1):
        if words_norm[i:i+len(answer_norm)] == answer_norm:
            return i, i + len(answer_norm) - 1

    return None, None

In [101]:
ANNOTATION_PATH = "/kaggle/input/datasets/yamunaaab/final-json/json.txt"

with open(ANNOTATION_PATH, "r") as f:
    raw = f.read()

lines = raw.split("\n")

fixed_lines = []
skip_next = False

for i in range(len(lines)):
    if skip_next:
        skip_next = False
        continue

    line = lines[i]

    # Detect filename split across lines
    if line.strip().startswith('"') and line.strip().endswith('.jpg'):
        # Merge with next non-empty line
        next_line = lines[i+2]  # skip empty line in between
        merged = line.strip() + '": {'
        fixed_lines.append("  " + merged)
        skip_next = True  # skip empty line
    elif line.strip() == '": {':
        continue
    else:
        fixed_lines.append(line)

fixed_raw = "\n".join(fixed_lines)

# Now try loading
annotations_raw = json.loads(fixed_raw)

print("Loaded entries:", len(annotations_raw))
print(list(annotations_raw.keys())[:5])

Loaded entries: 29
['1981063-2.jpg', '2042173-10.jpg', '2042485-05.jpg', '2048528-06.jpg', '2090022-final bill-1.jpg']


In [102]:
with open("/kaggle/working/clean_annotations.json", "w") as f:
    json.dump(annotations_raw, f, indent=2)

print("Clean JSON saved.")

Clean JSON saved.


In [108]:
from torch.utils.data import Dataset
from PIL import Image
import os
import torch

class HospitalQADataset(Dataset):

    def __init__(self, image_paths, annotations_raw, processor):
        self.samples = []
        self.processor = processor

        for image_path in image_paths:
            filename = os.path.basename(image_path)
            annotation_data = annotations_raw.get(filename)

            if annotation_data is None:
                continue

            words, boxes = get_ocr_words_boxes(image_path)

            for field, question in field_questions.items():
                value = annotation_data.get(field)

                if value is None:
                    continue

                start_idx, end_idx = find_answer_span(words, value)

                if start_idx is None:
                    continue

                self.samples.append({
                    "image_path": image_path,
                    "question": question,
                    "words": words,
                    "boxes": boxes,
                    "start_word": start_idx,
                    "end_word": end_idx
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        image = Image.open(sample["image_path"]).convert("RGB")

        encoding = self.processor(
            images=image,
            text=sample["question"],
            text_pair=sample["words"],
            boxes=sample["boxes"],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        word_ids = encoding.word_ids(batch_index=0)

        start_position = 0
        end_position = 0

        for token_idx, word_id in enumerate(word_ids):
            if word_id == sample["start_word"] and start_position == 0:
                start_position = token_idx
            if word_id == sample["end_word"]:
                end_position = token_idx

        encoding["start_positions"] = torch.tensor(start_position)
        encoding["end_positions"] = torch.tensor(end_position)

        return {k: v.squeeze(0) if isinstance(v, torch.Tensor) else v
                for k, v in encoding.items()}

In [111]:
def __len__(self):
    return len(self.samples)

def __getitem__(self, idx):
    sample = self.samples[idx]

    image = Image.open(sample["image_path"]).convert("RGB")

    encoding = self.processor(
        images=image,
        text=sample["question"],
        text_pair=sample["words"],
        boxes=sample["boxes"],
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    word_ids = encoding.word_ids(batch_index=0)

    start_position = None
    end_position = None

    for token_idx, word_id in enumerate(word_ids):
        if word_id == sample["start_word"] and start_position is None:
            start_position = token_idx
        if word_id == sample["end_word"]:
            end_position = token_idx

    if start_position is None:
        start_position = 0
    if end_position is None:
        end_position = 0

    encoding["start_positions"] = torch.tensor(start_position)
    encoding["end_positions"] = torch.tensor(end_position)

    return {k: v.squeeze(0) if isinstance(v, torch.Tensor) else v
            for k, v in encoding.items()}

In [113]:
qa_dataset = HospitalQADataset(
    ordered_image_paths,
    annotations_raw,
    processor
)

print("Total QA samples:", len(qa_dataset))

Total QA samples: 76


In [114]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    qa_dataset,
    batch_size=4,
    shuffle=True
)

print("Batches:", len(train_loader))

Batches: 19


In [115]:
from transformers import LayoutLMv3ForQuestionAnswering

model = LayoutLMv3ForQuestionAnswering.from_pretrained(
    "microsoft/layoutlmv3-base"
)

model.to(device)

Loading weights:   0%|          | 0/212 [00:00<?, ?it/s]

[1mLayoutLMv3ForQuestionAnswering LOAD REPORT[0m from: microsoft/layoutlmv3-base
Key                                | Status     | 
-----------------------------------+------------+-
layoutlmv3.embeddings.position_ids | UNEXPECTED | 
qa_outputs.dense.bias              | MISSING    | 
qa_outputs.dense.weight            | MISSING    | 
qa_outputs.out_proj.weight         | MISSING    | 
qa_outputs.out_proj.bias           | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


LayoutLMv3ForQuestionAnswering(
  (layoutlmv3): LayoutLMv3Model(
    (embeddings): LayoutLMv3TextEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (x_position_embeddings): Embedding(1024, 128)
      (y_position_embeddings): Embedding(1024, 128)
      (h_position_embeddings): Embedding(1024, 128)
      (w_position_embeddings): Embedding(1024, 128)
    )
    (patch_embed): LayoutLMv3PatchEmbeddings(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (encoder): LayoutLMv3Encoder(


In [116]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

In [117]:
from tqdm import tqdm

EPOCHS = 3

model.train()

for epoch in range(EPOCHS):
    total_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch in loop:
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss

        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loop.set_postfix(loss=loss.item())

    print(f"Epoch {epoch+1} Avg Loss:", total_loss / len(train_loader))

Epoch 1: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 19/19 [00:17<00:00,  1.08it/s, loss=4.11]


Epoch 1 Avg Loss: 5.140728975597181


Epoch 2: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 19/19 [00:18<00:00,  1.05it/s, loss=3.11]


Epoch 2 Avg Loss: 3.4224594768724943


Epoch 3: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 19/19 [00:18<00:00,  1.03it/s, loss=3.5] 

Epoch 3 Avg Loss: 2.7847067243174504





In [123]:
model.eval()

raw_sample = qa_dataset.samples[0]

image = Image.open(raw_sample["image_path"]).convert("RGB")

encoding = processor(
    images=image,
    text=raw_sample["question"],
    text_pair=raw_sample["words"],
    boxes=raw_sample["boxes"],
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)

sequence_ids = encoding.sequence_ids(0)

encoding = {k: v.to(device) for k, v in encoding.items()}

with torch.no_grad():
    outputs = model(**encoding)

start_logits = outputs.start_logits.squeeze(0)
end_logits = outputs.end_logits.squeeze(0)

# Mask non-context tokens
for i, seq_id in enumerate(sequence_ids):
    if seq_id != 1:
        start_logits[i] = -1e9
        end_logits[i] = -1e9

start_pred = torch.argmax(start_logits).item()
end_pred = torch.argmax(end_logits).item()

tokens = processor.tokenizer.convert_ids_to_tokens(
    encoding["input_ids"].squeeze(0)
)

predicted_tokens = tokens[start_pred:end_pred+1]
predicted_text = processor.tokenizer.convert_tokens_to_string(predicted_tokens)

print("Question:", raw_sample["question"])
print("Predicted text:", predicted_text)

Question: What is the hospital name?
Predicted text:  SUNSHINE HOSPITALS


In [124]:
input_ids = batch["input_ids"].squeeze(0)
tokens = processor.tokenizer.convert_ids_to_tokens(input_ids)

predicted_tokens = tokens[start_pred:end_pred+1]

print("Predicted tokens:", predicted_tokens)

predicted_text = processor.tokenizer.convert_tokens_to_string(predicted_tokens)

print("Predicted text:", predicted_text)

Predicted tokens: ['Ä SUN', 'SH', 'INE', 'Ä H', 'OSP', 'IT', 'ALS']
Predicted text:  SUNSHINE HOSPITALS


In [125]:
model.save_pretrained("/kaggle/working/layoutlmv3_qa_model")
processor.save_pretrained("/kaggle/working/layoutlmv3_qa_model")

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

['/kaggle/working/layoutlmv3_qa_model/processor_config.json']

In [127]:
def extract_field(image_path, field_name):
    question = field_questions[field_name]

    words, boxes = get_ocr_words_boxes(image_path)

    image = Image.open(image_path).convert("RGB")

    encoding = processor(
        images=image,
        text=question,
        text_pair=words,
        boxes=boxes,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    sequence_ids = encoding.sequence_ids(0)
    encoding = {k: v.to(device) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)

    start_logits = outputs.start_logits.squeeze(0)
    end_logits = outputs.end_logits.squeeze(0)

    for i, seq_id in enumerate(sequence_ids):
        if seq_id != 1:
            start_logits[i] = -1e9
            end_logits[i] = -1e9

    start_pred = torch.argmax(start_logits).item()
    end_pred = torch.argmax(end_logits).item()

    tokens = processor.tokenizer.convert_ids_to_tokens(
        encoding["input_ids"].squeeze(0)
    )

    predicted_tokens = tokens[start_pred:end_pred+1]
    return processor.tokenizer.convert_tokens_to_string(predicted_tokens).strip()

In [128]:
import re

def normalize_text(s):
    s = s.lower()
    s = re.sub(r'[^a-z0-9]', '', s)
    return s

def compute_f1(pred, truth):
    pred_tokens = normalize_text(pred)
    truth_tokens = normalize_text(truth)

    if pred_tokens == truth_tokens:
        return 1.0

    common = set(pred_tokens) & set(truth_tokens)
    if len(common) == 0:
        return 0.0

    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(truth_tokens)
    return 2 * (precision * recall) / (precision + recall)

In [129]:
model.eval()

exact_matches = 0
total = 0
f1_total = 0

for raw_sample in qa_dataset.samples:

    image_path = raw_sample["image_path"]
    question = raw_sample["question"]
    true_answer = " ".join(
        raw_sample["words"][raw_sample["start_word"]:raw_sample["end_word"]+1]
    )

    predicted = extract_field(image_path, 
                              [k for k,v in field_questions.items() 
                               if v == question][0])

    total += 1

    if normalize_text(predicted) == normalize_text(true_answer):
        exact_matches += 1

    f1_total += compute_f1(predicted, true_answer)

exact_match_score = exact_matches / total
avg_f1 = f1_total / total

print("Total samples:", total)
print("Exact Match:", exact_match_score)
print("Average F1:", avg_f1)

Total samples: 76
Exact Match: 0.21052631578947367
Average F1: 0.3363177029845176
