In [2]:
from datasets import load_dataset

train = load_dataset("naver-clova-ix/cord-v2", split="train")
test = load_dataset("naver-clova-ix/cord-v2", split="test")
validation = load_dataset("naver-clova-ix/cord-v2", split="validation")

### Find out all the labels in the CORD v2 dataset

In [3]:
import json
from collections import Counter


all_labels = []

for sample in train["ground_truth"]:
    sample_info = json.loads(sample)["valid_line"]
    for val in sample_info:
        all_labels.append(val["category"])

label_dict = Counter(all_labels)
label_dict, len(label_dict)

(Counter({'menu.nm': 2100,
          'menu.price': 2093,
          'menu.cnt': 1903,
          'total.total_price': 784,
          'menu.unitprice': 621,
          'sub_total.subtotal_price': 541,
          'total.cashprice': 526,
          'total.changeprice': 502,
          'sub_total.tax_price': 362,
          'menu.sub.nm': 319,
          'total.menuqty_cnt': 233,
          'menu.sub.cnt': 142,
          'total.creditcardprice': 124,
          'menu.sub.price': 123,
          'sub_total.etc': 112,
          'sub_total.service_price': 98,
          'menu.discountprice': 97,
          'menu.num': 94,
          'sub_total.discount_price': 68,
          'total.emoneyprice': 47,
          'total.menutype_cnt': 43,
          'total.total_etc': 28,
          'menu.sub.unitprice': 14,
          'menu.etc': 6,
          'sub_total.othersvc_price': 2,
          'menu.vatyn': 2,
          'void_menu.nm': 1,
          'void_menu.price': 1,
          'menu.itemsubtotal': 1}),
 29)

### Some labels are very less so replacing them with the neutral label 'O'.

In [4]:
replacing_labels = {}

for key, val in label_dict.items():
    if val < 100:
        replacing_labels[key] = 'O'

replacing_labels, len(replacing_labels.items())

({'sub_total.service_price': 'O',
  'total.menutype_cnt': 'O',
  'sub_total.discount_price': 'O',
  'total.total_etc': 'O',
  'menu.num': 'O',
  'menu.discountprice': 'O',
  'total.emoneyprice': 'O',
  'menu.sub.unitprice': 'O',
  'void_menu.nm': 'O',
  'void_menu.price': 'O',
  'sub_total.othersvc_price': 'O',
  'menu.vatyn': 'O',
  'menu.itemsubtotal': 'O',
  'menu.etc': 'O'},
 14)

In [5]:
print(type(train[1]["ground_truth"]))

<class 'str'>


In [6]:
# Replace the categories
def map_labels(sample):
    sample_info = json.loads(sample["ground_truth"])["valid_line"]
    for idx_info, val in enumerate(sample_info):
        if val["category"] in replacing_labels:
            sample_info[idx_info]["category"] = 'O'
    
    updated_valid_line = {"valid_line" : sample_info}
    sample["ground_truth"] = json.dumps(updated_valid_line)
    return sample
 

In [7]:
train = train.map(map_labels)
validation = validation.map(map_labels)
test = test.map(map_labels)

In [8]:
all_labels = []

for sample in train["ground_truth"]:
    sample_info = json.loads(sample)["valid_line"]
    for val in sample_info:
        all_labels.append(val["category"])

label_dict = Counter(all_labels)
label_dict, len(label_dict)

(Counter({'menu.nm': 2100,
          'menu.price': 2093,
          'menu.cnt': 1903,
          'total.total_price': 784,
          'menu.unitprice': 621,
          'sub_total.subtotal_price': 541,
          'total.cashprice': 526,
          'O': 502,
          'total.changeprice': 502,
          'sub_total.tax_price': 362,
          'menu.sub.nm': 319,
          'total.menuqty_cnt': 233,
          'menu.sub.cnt': 142,
          'total.creditcardprice': 124,
          'menu.sub.price': 123,
          'sub_total.etc': 112}),
 16)

In [9]:
labels = list(set(label_dict))
labels

['O',
 'menu.nm',
 'sub_total.tax_price',
 'sub_total.etc',
 'menu.cnt',
 'menu.sub.nm',
 'sub_total.subtotal_price',
 'total.cashprice',
 'total.changeprice',
 'total.menuqty_cnt',
 'menu.unitprice',
 'total.creditcardprice',
 'menu.sub.cnt',
 'menu.sub.price',
 'total.total_price',
 'menu.price']

In [10]:
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}
print(label2id)
print(id2label)

{'O': 0, 'menu.nm': 1, 'sub_total.tax_price': 2, 'sub_total.etc': 3, 'menu.cnt': 4, 'menu.sub.nm': 5, 'sub_total.subtotal_price': 6, 'total.cashprice': 7, 'total.changeprice': 8, 'total.menuqty_cnt': 9, 'menu.unitprice': 10, 'total.creditcardprice': 11, 'menu.sub.cnt': 12, 'menu.sub.price': 13, 'total.total_price': 14, 'menu.price': 15}
{0: 'O', 1: 'menu.nm', 2: 'sub_total.tax_price', 3: 'sub_total.etc', 4: 'menu.cnt', 5: 'menu.sub.nm', 6: 'sub_total.subtotal_price', 7: 'total.cashprice', 8: 'total.changeprice', 9: 'total.menuqty_cnt', 10: 'menu.unitprice', 11: 'total.creditcardprice', 12: 'menu.sub.cnt', 13: 'menu.sub.price', 14: 'total.total_price', 15: 'menu.price'}


In [11]:
from os import listdir
from torch.utils.data import Dataset
import torch
from PIL import Image

def normalize_bbox(bbox, width, height):
            return [
                int(1000 * (bbox[0] / width)),
                int(1000 * (bbox[1] / height)),
                int(1000 * (bbox[2] / width)),
                int(1000 * (bbox[3] / height)),
            ]

def create_bbox(quad_dict, width, height):
    x_coords = [quad_dict['x1'], quad_dict['x2'], quad_dict['x3'], quad_dict['x4']]
    y_coords = [quad_dict['y1'], quad_dict['y2'], quad_dict['y3'], quad_dict['y4']]

    # Get min/max for rectangular bbox
    x_min = min(x_coords)
    y_min = min(y_coords)
    x_max = max(x_coords)
    y_max = max(y_coords)

    # Ensure valid bounding box (e.g., x_min <= x_max, y_min <= y_max)
    # This can happen if OCR is poor or annotations are malformed
    if x_min > x_max: x_min, x_max = x_max, x_min
    if y_min > y_max: y_min, y_max = y_max, y_min

    # Normalize bounding box
    normalized_bbox = normalize_bbox([x_min, y_min, x_max, y_max], width, height)

    return normalized_bbox

class CORDDataset(Dataset):
    """CORD dataset."""

    def __init__(self, ground_truth, image, processor=None, max_length=512):
        """
        Args:
            annotations (List[List]): List of lists containing the word-level annotations (words, labels, boxes).
            image_dir (string): Directory with all the document images.
            processor (LayoutLMv2Processor): Processor to prepare the text + image.
        """
        self.ground_truth = ground_truth
        self.image = image
        self.processor = processor

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        # first, take an image
        image = self.image[idx]

        width, height = image.size

        sample = json.loads(self.ground_truth[idx])["valid_line"]

        words = []
        boxes = []
        word_labels = []

        for val in sample:
            for item in val["words"]:
                words.append(item["text"])
                normalized_box = create_bbox(item["quad"], width, height)
                boxes.append(normalized_box)
                word_labels.append(val["category"])

        assert len(words) == len(boxes) == len(word_labels)
        
        word_labels = [label2id[label] for label in word_labels]
        # use processor to prepare everything
        encoded_inputs = self.processor(image, words, boxes=boxes, word_labels=word_labels, 
                                        padding="max_length", truncation=True, 
                                        return_tensors="pt")
        
        # remove batch dimension
        for k,v in encoded_inputs.items():
          encoded_inputs[k] = v.squeeze()
      
        return encoded_inputs

In [12]:
from transformers import LayoutLMv3Processor

processor = LayoutLMv3Processor.from_pretrained("nielsr/layoutlmv3-finetuned-cord", apply_ocr=False)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RobertaTokenizer'. 
The class this function is called from is 'LayoutLMv3TokenizerFast'.


In [13]:
train_dataset = CORDDataset(ground_truth=train["ground_truth"],
                            image=train["image"], 
                            processor=processor)
validation_dataset = CORDDataset(ground_truth=validation["ground_truth"],
                            image=validation["image"], 
                            processor=processor)
test_dataset = CORDDataset(ground_truth=test["ground_truth"],
                            image=test["image"], 
                            processor=processor)

In [14]:
encoding = train_dataset[0]
encoding.keys()

KeysView({'input_ids': tensor([    0,   112,  3023,   234,  8209,  4746,   710,   163,  3644,  3337,
            6,   151,   112,  3023,   163,   428,   330, 30244,   718,   234,
         8209, 10529,     6,   151,   112,  3023, 22124,  3609,  5113,  2141,
        41161,  2908,     6,   151,   112,  3023,  8761, 22059, 16577,   706,
            6,   151,   112,  3023,   234,  8209,  5847,   424, 17340,  2186,
         1510,     6,   151,   155,  3023,  3130,  8761, 16577,   321,   112,
         3023, 24349,  1628,  4141,  3620,     6,   151,   112,  3023,  8761,
        16577,   504,     6,   151,   112,  3023,  8761,  5726,  1132,     6,
          151,   112,  3023,  5847,   424,  3296,   853,   163,  3644,  5663,
            6,   151,   132,  3023,   255, 17421, 16603,  2590,  2491,     6,
          151,   132,  3023,  9188,  2379, 16603,  2590,  2491,     6,   151,
          112,  3023,   255, 17421,  5477,   368,   287,   179,   843,     6,
          151,     4,   112,  3023,   234

In [15]:
for k,v in encoding.items():
  print(k, v.shape)

input_ids torch.Size([512])
attention_mask torch.Size([512])
bbox torch.Size([512, 4])
labels torch.Size([512])
pixel_values torch.Size([3, 224, 224])


In [16]:
print(processor.tokenizer.decode(encoding['input_ids']))

<s> 1 x Nasi Campur Bali 75,000 1 x Bbk Bengil Nasi 125,000 1 x MilkShake Starwb 37,000 1 x Ice Lemon Tea 24,000 1 x Nasi Ayam Dewata 70,000 3 x Free Ice Tea 0 1 x Organic Green Sa 65,000 1 x Ice Tea 18,000 1 x Ice Orange 29,000 1 x Ayam Suir Bali 85,000 2 x Tahu Goreng 36,000 2 x Tempe Goreng 36,000 1 x Tahu Telor Asin 40,000. 1 x Nasi Goreng Samb 70,000 3 x Bbk Panggang Sam 366,000 1 x Ayam Sambal Hija 92,000 2 x Hot Tea 44,000 1 x Ice Kopi 32,000 1 x Tahu Telor Asin 40,000 1 x Free Ice Tea 0 1 x Bebek Street 44,000 1 x Ice Tea Tawar 18,000 Sub-Total 1,346,000 Service 100,950 PB1 144,695 Rounding -45 Grand Total 1,591,600</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

In [17]:
[id2label[label] for label in encoding['labels'].tolist() if label != -100]

['menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm

In [18]:
for id, label in zip(encoding['input_ids'][:30], encoding['labels'][:30]):
  print(processor.tokenizer.decode([id]), label.item())

<s> -100
 1 4
 x 4
 N 1
asi -100
 Camp 1
ur -100
 B 1
ali -100
 75 15
, -100
000 -100
 1 4
 x 4
 B 1
b -100
k -100
 Beng 1
il -100
 N 1
asi -100
 125 15
, -100
000 -100
 1 4
 x 4
 Milk 1
Sh -100
ake -100
 Star 1


In [19]:
from torch.utils.data import DataLoader

BATCH_SIZE=2

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

train_dataloader, validation_dataloader, test_dataloader

(<torch.utils.data.dataloader.DataLoader at 0x15109beefe0>,
 <torch.utils.data.dataloader.DataLoader at 0x15109befbb0>,
 <torch.utils.data.dataloader.DataLoader at 0x15109befb20>)

### Model

In [20]:
from transformers import AutoModelForTokenClassification
import torch
from tqdm.notebook import tqdm

model = AutoModelForTokenClassification.from_pretrained('nielsr/layoutlmv3-finetuned-cord', num_labels=len(labels), ignore_mismatched_sizes=True)

Some weights of LayoutLMv3ForTokenClassification were not initialized from the model checkpoint at nielsr/layoutlmv3-finetuned-cord and are newly initialized because the shapes did not match:
- classifier.out_proj.bias: found shape torch.Size([61]) in the checkpoint and torch.Size([16]) in the model instantiated
- classifier.out_proj.weight: found shape torch.Size([61, 768]) in the checkpoint and torch.Size([16, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
device="cpu"

In [22]:
sequence_length = 512
image_channels = 3
image_height = 224
image_width = 224

# Create dummy input tensors with the CORRECT dtypes and move them to the device
# input_ids, attention_mask, bbox, and token_type_ids should be torch.long
# pixel_values should be torch.float

dummy_input_ids = torch.randint(0, model.config.vocab_size, (BATCH_SIZE, sequence_length), dtype=torch.long, device=device)
dummy_attention_mask = torch.ones((BATCH_SIZE, sequence_length), dtype=torch.long, device=device)
dummy_bbox = torch.randint(0, 1000, (BATCH_SIZE, sequence_length, 4), dtype=torch.long, device=device)
dummy_pixel_values = torch.randn((BATCH_SIZE, image_channels, image_height, image_width), dtype=torch.float, device=device)

# Crucially, LayoutLMv3 models often expect `token_type_ids`.
# Even if your specific forward pass doesn't show it explicitly in your snippet,
# the traceback shows `token_type_ids=token_type_ids` deeper in the call stack.
# It's safest to include it.
dummy_labels = torch.zeros((BATCH_SIZE, sequence_length), dtype=torch.long, device=device)


# Provide these dummy tensors to torchinfo using `input_data`
# This ensures that torchinfo passes inputs with the correct names and dtypes
input_data = {
    "input_ids": dummy_input_ids,
    "attention_mask": dummy_attention_mask,
    "bbox": dummy_bbox,
    "labels": dummy_labels,
    "pixel_values": dummy_pixel_values, # Add this
    # If your model expects 'labels' for `forward` during summary, you'd add:
    # "labels": torch.randint(0, num_classes, (BATCH_SIZE,), dtype=torch.long, device=device)
    # or for token classification:
    # "labels": torch.randint(0, num_labels, (BATCH_SIZE, sequence_length), dtype=torch.long, device=device)
}


In [23]:
from torchinfo import summary

# input_ids torch.Size([512])
# attention_mask torch.Size([512])
# bbox torch.Size([512, 4])
# labels torch.Size([512])
# pixel_values torch.Size([3, 224, 224])


summary(model, input_data=input_data, depth=4, col_names=["input_size", "output_size", "num_params", "trainable"], col_width=20, row_settings=["var_names"])



Layer (type (var_name))                                                Input Shape          Output Shape         Param #              Trainable
LayoutLMv3ForTokenClassification (LayoutLMv3ForTokenClassification)    --                   [2, 512, 16]         --                   True
├─LayoutLMv3Model (layoutlmv3)                                         [2, 512]             [2, 709, 768]        152,064              True
│    └─LayoutLMv3TextEmbeddings (embeddings)                           --                   [2, 512, 768]        --                   True
│    │    └─Embedding (word_embeddings)                                [2, 512]             [2, 512, 768]        38,603,520           True
│    │    └─Embedding (token_type_embeddings)                          [2, 512]             [2, 512, 768]        768                  True
│    │    └─Embedding (position_embeddings)                            [2, 512]             [2, 512, 768]        394,752              True
│    │    └─Embedding 

In [24]:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("microsoft/layoutlmv3-base", torch_dtype="auto")

In [25]:
from torchinfo import summary

summary(model)

Layer (type:depth-idx)                                  Param #
LayoutLMv3Model                                         152,064
├─LayoutLMv3TextEmbeddings: 1-1                         --
│    └─Embedding: 2-1                                   38,603,520
│    └─Embedding: 2-2                                   768
│    └─LayerNorm: 2-3                                   1,536
│    └─Dropout: 2-4                                     --
│    └─Embedding: 2-5                                   394,752
│    └─Embedding: 2-6                                   131,072
│    └─Embedding: 2-7                                   131,072
│    └─Embedding: 2-8                                   131,072
│    └─Embedding: 2-9                                   131,072
├─LayoutLMv3PatchEmbeddings: 1-2                        --
│    └─Conv2d: 2-10                                     590,592
├─Dropout: 1-3                                          --
├─LayerNorm: 1-4                                        1,536
├

### Evaluation

In [26]:
encoding = test_dataset[0]
processor.tokenizer.decode(encoding['input_ids'])

'<s> 901016 -TICKET CP 2 60.000 60.000 TOTAL DISC $ -60.000 TAX 5.455 Subtotal 60.000 TOTAL 60.000 (Qty 2.00 EDC CIMB NIAGA No: xx7730 60.000</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

In [27]:
ground_truth_labels = [id2label[label] for label in encoding['labels'].squeeze().tolist() if label != -100]
print(ground_truth_labels)

['O', 'menu.nm', 'menu.nm', 'menu.cnt', 'menu.price', 'O', 'O', 'O', 'O', 'O', 'sub_total.tax_price', 'sub_total.tax_price', 'sub_total.subtotal_price', 'sub_total.subtotal_price', 'total.total_price', 'total.total_price', 'total.menuqty_cnt', 'total.menuqty_cnt', 'total.creditcardprice', 'total.creditcardprice', 'total.creditcardprice', 'total.creditcardprice', 'total.creditcardprice', 'total.creditcardprice']
