In [1]:
from datasets import load_dataset

train = load_dataset("naver-clova-ix/cord-v2", split="train")
test = load_dataset("naver-clova-ix/cord-v2", split="test")
val = load_dataset("naver-clova-ix/cord-v2", split="validation")

  from .autonotebook import tqdm as notebook_tqdm


### Find out all the labels in the CORD v2 dataset

In [2]:
import json
from collections import Counter


all_labels = []

for sample in train["ground_truth"]:
    sample_info = json.loads(sample)["valid_line"]
    for val in sample_info:
        all_labels.append(val["category"])

label_dict = Counter(all_labels)
label_dict, len(label_dict)

(Counter({'menu.nm': 2100,
          'menu.price': 2093,
          'menu.cnt': 1903,
          'total.total_price': 784,
          'menu.unitprice': 621,
          'sub_total.subtotal_price': 541,
          'total.cashprice': 526,
          'total.changeprice': 502,
          'sub_total.tax_price': 362,
          'menu.sub.nm': 319,
          'total.menuqty_cnt': 233,
          'menu.sub.cnt': 142,
          'total.creditcardprice': 124,
          'menu.sub.price': 123,
          'sub_total.etc': 112,
          'sub_total.service_price': 98,
          'menu.discountprice': 97,
          'menu.num': 94,
          'sub_total.discount_price': 68,
          'total.emoneyprice': 47,
          'total.menutype_cnt': 43,
          'total.total_etc': 28,
          'menu.sub.unitprice': 14,
          'menu.etc': 6,
          'sub_total.othersvc_price': 2,
          'menu.vatyn': 2,
          'void_menu.nm': 1,
          'void_menu.price': 1,
          'menu.itemsubtotal': 1}),
 29)

### Some labels are very less so replacing them with the neutral label 'O'.

In [3]:
replacing_labels = {}

for key, val in label_dict.items():
    if val < 100:
        replacing_labels[key] = 'O'

replacing_labels, len(replacing_labels.items())

({'sub_total.service_price': 'O',
  'total.menutype_cnt': 'O',
  'sub_total.discount_price': 'O',
  'total.total_etc': 'O',
  'menu.num': 'O',
  'menu.discountprice': 'O',
  'total.emoneyprice': 'O',
  'menu.sub.unitprice': 'O',
  'void_menu.nm': 'O',
  'void_menu.price': 'O',
  'sub_total.othersvc_price': 'O',
  'menu.vatyn': 'O',
  'menu.itemsubtotal': 'O',
  'menu.etc': 'O'},
 14)

In [4]:
print(type(train[1]["ground_truth"]))

<class 'str'>


In [5]:
# Replace the categories
def map_labels(sample):
    sample_info = json.loads(sample["ground_truth"])["valid_line"]
    for idx_info, val in enumerate(sample_info):
        if val["category"] in replacing_labels:
            sample_info[idx_info]["category"] = 'O'
    
    updated_valid_line = {"valid_line" : sample_info}
    sample["ground_truth"] = json.dumps(updated_valid_line)
    return sample
 

In [6]:
train = train.map(map_labels)

In [7]:
all_labels = []

for sample in train["ground_truth"]:
    sample_info = json.loads(sample)["valid_line"]
    for val in sample_info:
        all_labels.append(val["category"])

label_dict = Counter(all_labels)
label_dict, len(label_dict)

(Counter({'menu.nm': 2100,
          'menu.price': 2093,
          'menu.cnt': 1903,
          'total.total_price': 784,
          'menu.unitprice': 621,
          'sub_total.subtotal_price': 541,
          'total.cashprice': 526,
          'O': 502,
          'total.changeprice': 502,
          'sub_total.tax_price': 362,
          'menu.sub.nm': 319,
          'total.menuqty_cnt': 233,
          'menu.sub.cnt': 142,
          'total.creditcardprice': 124,
          'menu.sub.price': 123,
          'sub_total.etc': 112}),
 16)

In [8]:
labels = list(set(label_dict))
labels

['menu.nm',
 'total.creditcardprice',
 'sub_total.etc',
 'menu.sub.nm',
 'total.changeprice',
 'menu.sub.cnt',
 'total.cashprice',
 'total.menuqty_cnt',
 'menu.unitprice',
 'sub_total.tax_price',
 'sub_total.subtotal_price',
 'menu.price',
 'total.total_price',
 'menu.cnt',
 'menu.sub.price',
 'O']

In [9]:
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}
print(label2id)
print(id2label)

{'menu.nm': 0, 'total.creditcardprice': 1, 'sub_total.etc': 2, 'menu.sub.nm': 3, 'total.changeprice': 4, 'menu.sub.cnt': 5, 'total.cashprice': 6, 'total.menuqty_cnt': 7, 'menu.unitprice': 8, 'sub_total.tax_price': 9, 'sub_total.subtotal_price': 10, 'menu.price': 11, 'total.total_price': 12, 'menu.cnt': 13, 'menu.sub.price': 14, 'O': 15}
{0: 'menu.nm', 1: 'total.creditcardprice', 2: 'sub_total.etc', 3: 'menu.sub.nm', 4: 'total.changeprice', 5: 'menu.sub.cnt', 6: 'total.cashprice', 7: 'total.menuqty_cnt', 8: 'menu.unitprice', 9: 'sub_total.tax_price', 10: 'sub_total.subtotal_price', 11: 'menu.price', 12: 'total.total_price', 13: 'menu.cnt', 14: 'menu.sub.price', 15: 'O'}


In [14]:
from os import listdir
from torch.utils.data import Dataset
import torch
from PIL import Image

def normalize_bbox(bbox, width, height):
            return [
                int(1000 * (bbox[0] / width)),
                int(1000 * (bbox[1] / height)),
                int(1000 * (bbox[2] / width)),
                int(1000 * (bbox[3] / height)),
            ]

def create_bbox(quad_dict, width, height):
    x_coords = [quad_dict['x1'], quad_dict['x2'], quad_dict['x3'], quad_dict['x4']]
    y_coords = [quad_dict['y1'], quad_dict['y2'], quad_dict['y3'], quad_dict['y4']]

    # Get min/max for rectangular bbox
    x_min = min(x_coords)
    y_min = min(y_coords)
    x_max = max(x_coords)
    y_max = max(y_coords)

    # Ensure valid bounding box (e.g., x_min <= x_max, y_min <= y_max)
    # This can happen if OCR is poor or annotations are malformed
    if x_min > x_max: x_min, x_max = x_max, x_min
    if y_min > y_max: y_min, y_max = y_max, y_min

    # Normalize bounding box
    normalized_bbox = normalize_bbox([x_min, y_min, x_max, y_max], width, height)

    return normalized_bbox

class CORDDataset(Dataset):
    """CORD dataset."""

    def __init__(self, ground_truth, image, processor=None, max_length=512):
        """
        Args:
            annotations (List[List]): List of lists containing the word-level annotations (words, labels, boxes).
            image_dir (string): Directory with all the document images.
            processor (LayoutLMv2Processor): Processor to prepare the text + image.
        """
        self.ground_truth = ground_truth
        self.image = image
        self.processor = processor

    def __len__(self):
        return len(self.image_file_names)

    def __getitem__(self, idx):
        # first, take an image
        image = self.image[idx]

        width, height = image.size

        sample = json.loads(self.ground_truth[idx])["valid_line"]

        words = []
        boxes = []
        word_labels = []

        for val in sample:
            for item in val["words"]:
                words.append(item["text"])
                normalized_box = create_bbox(item["quad"], width, height)
                boxes.append(normalized_box)
                word_labels.append(val["category"])

        assert len(words) == len(boxes) == len(word_labels)
        
        word_labels = [label2id[label] for label in word_labels]
        # use processor to prepare everything
        encoded_inputs = self.processor(image, words, boxes=boxes, word_labels=word_labels, 
                                        padding="max_length", truncation=True, 
                                        return_tensors="pt")
        
        # remove batch dimension
        for k,v in encoded_inputs.items():
          encoded_inputs[k] = v.squeeze()
      
        return encoded_inputs

In [15]:
from transformers import LayoutLMv3Processor

processor = LayoutLMv3Processor.from_pretrained("nielsr/layoutlmv3-finetuned-cord", apply_ocr=False)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RobertaTokenizer'. 
The class this function is called from is 'LayoutLMv3TokenizerFast'.


In [16]:
train_dataset = CORDDataset(ground_truth=train["ground_truth"],
                            image=train["image"], 
                            processor=processor)

In [17]:
encoding = train_dataset[0]
encoding.keys()

KeysView({'input_ids': tensor([    0,   112,  3023,   234,  8209,  4746,   710,   163,  3644,  3337,
            6,   151,   112,  3023,   163,   428,   330, 30244,   718,   234,
         8209, 10529,     6,   151,   112,  3023, 22124,  3609,  5113,  2141,
        41161,  2908,     6,   151,   112,  3023,  8761, 22059, 16577,   706,
            6,   151,   112,  3023,   234,  8209,  5847,   424, 17340,  2186,
         1510,     6,   151,   155,  3023,  3130,  8761, 16577,   321,   112,
         3023, 24349,  1628,  4141,  3620,     6,   151,   112,  3023,  8761,
        16577,   504,     6,   151,   112,  3023,  8761,  5726,  1132,     6,
          151,   112,  3023,  5847,   424,  3296,   853,   163,  3644,  5663,
            6,   151,   132,  3023,   255, 17421, 16603,  2590,  2491,     6,
          151,   132,  3023,  9188,  2379, 16603,  2590,  2491,     6,   151,
          112,  3023,   255, 17421,  5477,   368,   287,   179,   843,     6,
          151,     4,   112,  3023,   234

In [18]:
for k,v in encoding.items():
  print(k, v.shape)

input_ids torch.Size([512])
attention_mask torch.Size([512])
bbox torch.Size([512, 4])
labels torch.Size([512])
pixel_values torch.Size([3, 224, 224])


In [19]:
print(processor.tokenizer.decode(encoding['input_ids']))

<s> 1 x Nasi Campur Bali 75,000 1 x Bbk Bengil Nasi 125,000 1 x MilkShake Starwb 37,000 1 x Ice Lemon Tea 24,000 1 x Nasi Ayam Dewata 70,000 3 x Free Ice Tea 0 1 x Organic Green Sa 65,000 1 x Ice Tea 18,000 1 x Ice Orange 29,000 1 x Ayam Suir Bali 85,000 2 x Tahu Goreng 36,000 2 x Tempe Goreng 36,000 1 x Tahu Telor Asin 40,000. 1 x Nasi Goreng Samb 70,000 3 x Bbk Panggang Sam 366,000 1 x Ayam Sambal Hija 92,000 2 x Hot Tea 44,000 1 x Ice Kopi 32,000 1 x Tahu Telor Asin 40,000 1 x Free Ice Tea 0 1 x Bebek Street 44,000 1 x Ice Tea Tawar 18,000 Sub-Total 1,346,000 Service 100,950 PB1 144,695 Rounding -45 Grand Total 1,591,600</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

In [21]:
[id2label[label] for label in encoding['labels'].tolist() if label != -100]

['menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm

In [22]:
for id, label in zip(encoding['input_ids'][:30], encoding['labels'][:30]):
  print(processor.tokenizer.decode([id]), label.item())

<s> -100
 1 13
 x 13
 N 0
asi -100
 Camp 0
ur -100
 B 0
ali -100
 75 11
, -100
000 -100
 1 13
 x 13
 B 0
b -100
k -100
 Beng 0
il -100
 N 0
asi -100
 125 11
, -100
000 -100
 1 13
 x 13
 Milk 0
Sh -100
ake -100
 Star 0


In [23]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

AttributeError: 'CORDDataset' object has no attribute 'image_file_names'