In [None]:
from datasets import load_dataset

# Fetch the `nielsr/funsd` dataset from the Hugging Face Hub.
# If the Hub denies access, authenticate first, e.g. with `huggingface-cli login`.
# NOTE(review): `ds` is not referenced by any later visible cell; the data-loading
# cell further down reloads the same dataset under the name `dataset`.
ds = load_dataset(path="nielsr/funsd")

In [None]:
pip install transformers datasets torch pillow pytesseract


In [None]:
from datasets import load_dataset

dataset = load_dataset("nielsr/funsd")

# Named handles for the two splits used by the rest of the notebook.
train_dataset, test_dataset = dataset["train"], dataset["test"]

# Show one raw record so its available fields are visible.
print(train_dataset[0])


In [None]:
# Label vocabulary for token classification.
# The integer ids stored in `ner_tags` index into the dataset's own ClassLabel name list.
# Read that list from the dataset features instead of retyping it by hand: a hand-written
# list in a different order (e.g. QUESTION before HEADER) silently attaches the wrong
# name to every id, both in the model config and in the inference output.
# For nielsr/funsd this is expected to be
# ["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
# (NOTE(review): confirm by displaying LABELS).
LABELS = list(train_dataset.features["ner_tags"].feature.names)

label2id = {label: i for i, label in enumerate(LABELS)}
id2label = {i: label for label, i in label2id.items()}


In [None]:
from transformers import LayoutLMv3Processor
import torch # Import torch for .squeeze()

# Words and bounding boxes come from the dataset annotations themselves,
# so the processor's built-in OCR step is switched off.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

def preprocess(example):
    """Turn one raw FUNSD record into fixed-length LayoutLMv3 features.

    The record's ``"image"`` is converted to RGB and passed to the module-level
    ``processor`` together with its ``"words"``, ``"bboxes"`` and ``"ner_tags"``.
    The output is truncated and padded to a fixed length (``padding="max_length"``).

    Returns:
        The processor output. Every multi-dimensional tensor whose leading axis
        has size 1 has that axis removed, so the Trainer's collator can stack
        per-example rows into batches itself.
    """
    rgb_image = example["image"].convert("RGB")  # force a 3-channel image
    features = processor(
        rgb_image,
        example["words"],
        boxes=example["bboxes"],
        word_labels=example["ner_tags"],
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # return_tensors="pt" gives a single example a leading batch axis of 1; strip it.
    for name in list(features.keys()):
        value = features[name]
        if not isinstance(value, torch.Tensor):
            continue
        if value.ndim > 1 and value.shape[0] == 1:
            features[name] = value.squeeze(0)

    return features

# Encode both splits one example at a time with `preprocess`; the raw columns
# (image, words, ...) are removed so only model inputs remain.
encoded_train, encoded_test = (
    split.map(preprocess, batched=False, remove_columns=split.column_names)
    for split in (train_dataset, test_dataset)
)

In [None]:
from transformers import LayoutLMv3ForTokenClassification
from transformers import Trainer, TrainingArguments, DefaultDataCollator

# Token-classification head sized to the label vocabulary; both mappings are
# passed so the model config carries human-readable label names.
model = LayoutLMv3ForTokenClassification.from_pretrained(
    "microsoft/layoutlmv3-base",
    label2id=label2id,
    id2label=id2label,
    num_labels=len(LABELS),
)

training_args = TrainingArguments(
    output_dir="./funsd_model",
    report_to="none",  # no external experiment trackers
    # Optimisation
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Logging / evaluation cadence. eval_steps is not set, so (per the
    # TrainingArguments docs) evaluation follows logging_steps.
    logging_steps=100,
    eval_strategy="steps",
    # Checkpointing. NOTE(review): a checkpoint is only written once training
    # reaches 500 optimisation steps, and no explicit save call is visible here.
    save_steps=500,
    save_total_limit=2,
)

# Examples are already padded to a fixed length during preprocessing
# (padding="max_length"), so the default collator only needs to stack them.
data_collator = DefaultDataCollator()

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=encoded_train,
    eval_dataset=encoded_test,
)

trainer.train()

In [None]:
import torch

def infer(example):
    """Predict one FUNSD label per word for a single example.

    NOTE: a later cell redefines ``infer`` (returning tuples instead of dicts);
    after a top-to-bottom run, that later definition is the one in effect.

    Args:
        example: mapping with ``"image"`` (an object with ``.convert``, e.g. a PIL
            image), ``"words"`` (list of str) and ``"bboxes"`` (one box per word).

    Returns:
        List of ``{"word": str, "label": str}`` dicts in word order. Words whose
        tokens were cut off by truncation get no entry.
    """
    model.eval()

    # Same RGB conversion as used during preprocessing.
    image_rgb = example["image"].convert("RGB")

    encoding = processor(
        image_rgb,
        example["words"],
        boxes=example["bboxes"],
        return_tensors="pt",
        truncation=True,
        padding="max_length",
    )

    # Token -> word index for the single sequence (None for special/padding tokens).
    # Read it before building the plain input dict below, which drops this metadata.
    word_ids = encoding.word_ids(0)

    # After Trainer.train() the model may live on a GPU; inputs must be on the same device.
    model_inputs = {
        name: value.to(model.device) if isinstance(value, torch.Tensor) else value
        for name, value in encoding.items()
    }

    with torch.no_grad():
        outputs = model(**model_inputs)

    token_label_ids = outputs.logits.argmax(-1)[0].tolist()

    # Logits are per token (including special and padding tokens), not per word,
    # so zipping them directly with the words misaligns the labels. Keep the
    # prediction of each word's first sub-token instead.
    word_label_ids = {}
    for word_idx, label_id in zip(word_ids, token_label_ids):
        if word_idx is not None and word_idx not in word_label_ids:
            word_label_ids[word_idx] = label_id

    results = []
    for word_idx, word in enumerate(example["words"]):
        if word_idx in word_label_ids:
            results.append({
                "word": word,
                "label": id2label[word_label_ids[word_idx]],
            })

    return results

In [None]:

import torch
from PIL import Image
import numpy as np

def infer(example):
    """Predict one FUNSD label per word for a single example.

    Args:
        example: mapping with ``"image"`` (a PIL image or a numpy array),
            ``"words"`` (list of str) and ``"bboxes"`` (one box per word).

    Returns:
        List of ``(word, label)`` tuples in word order. Words whose tokens were
        cut off by truncation get no entry.
    """
    model.eval()

    image = example["image"]

    # Ensure a 3-channel RGB image whatever container the image arrives in.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    elif isinstance(image, np.ndarray):
        if image.ndim == 2:  # single-channel array: replicate into 3 channels on a new last axis
            image = np.stack([image] * 3, axis=-1)
        image = Image.fromarray(image).convert("RGB")

    encoding = processor(
        image,
        example["words"],
        boxes=example["bboxes"],
        return_tensors="pt",
        truncation=True,
        padding="max_length",
    )

    # Token -> word index for the single sequence (None for special/padding tokens).
    # Read it before building the plain input dict below, which drops this metadata.
    word_ids = encoding.word_ids(0)

    # After Trainer.train() the model may live on a GPU; inputs must be on the same device.
    model_inputs = {
        name: value.to(model.device) if isinstance(value, torch.Tensor) else value
        for name, value in encoding.items()
    }

    with torch.no_grad():
        outputs = model(**model_inputs)

    token_label_ids = outputs.logits.argmax(-1)[0].tolist()

    # Logits are per token (including special and padding tokens), not per word,
    # so zipping them directly with the words misaligns the labels. Keep the
    # prediction of each word's first sub-token instead.
    word_label_ids = {}
    for word_idx, label_id in zip(word_ids, token_label_ids):
        if word_idx is not None and word_idx not in word_label_ids:
            word_label_ids[word_idx] = label_id

    return [
        (word, id2label[word_label_ids[word_idx]])
        for word_idx, word in enumerate(example["words"])
        if word_idx in word_label_ids
    ]


# Run the trained model on the first test document and show word -> label pairs.
sample = test_dataset[0]
predictions = infer(sample)

for word, label in predictions:
    print(word, "->", label)
