In [55]:
from datasets import load_dataset

train = load_dataset("naver-clova-ix/cord-v2", split="train")
test = load_dataset("naver-clova-ix/cord-v2", split="test")
val = load_dataset("naver-clova-ix/cord-v2", split="validation")

### Find out all the labels in the CORD v2 dataset

In [56]:
import json
from collections import Counter


all_labels = []

for sample in train["ground_truth"]:
    sample_info = json.loads(sample)["valid_line"]
    for val in sample_info:
        all_labels.append(val["category"])

label_dict = Counter(all_labels)
label_dict, len(label_dict)

(Counter({'menu.nm': 2100,
          'menu.price': 2093,
          'menu.cnt': 1903,
          'total.total_price': 784,
          'menu.unitprice': 621,
          'sub_total.subtotal_price': 541,
          'total.cashprice': 526,
          'total.changeprice': 502,
          'sub_total.tax_price': 362,
          'menu.sub.nm': 319,
          'total.menuqty_cnt': 233,
          'menu.sub.cnt': 142,
          'total.creditcardprice': 124,
          'menu.sub.price': 123,
          'sub_total.etc': 112,
          'sub_total.service_price': 98,
          'menu.discountprice': 97,
          'menu.num': 94,
          'sub_total.discount_price': 68,
          'total.emoneyprice': 47,
          'total.menutype_cnt': 43,
          'total.total_etc': 28,
          'menu.sub.unitprice': 14,
          'menu.etc': 6,
          'sub_total.othersvc_price': 2,
          'menu.vatyn': 2,
          'void_menu.nm': 1,
          'void_menu.price': 1,
          'menu.itemsubtotal': 1}),
 29)

### Some labels are very less so replacing them with the neutral label 'O'.

In [57]:
replacing_labels = {}

for key, val in label_dict.items():
    if val < 100:
        replacing_labels[key] = 'O'

replacing_labels, len(replacing_labels.items())

({'sub_total.service_price': 'O',
  'total.menutype_cnt': 'O',
  'sub_total.discount_price': 'O',
  'total.total_etc': 'O',
  'menu.num': 'O',
  'menu.discountprice': 'O',
  'total.emoneyprice': 'O',
  'menu.sub.unitprice': 'O',
  'void_menu.nm': 'O',
  'void_menu.price': 'O',
  'sub_total.othersvc_price': 'O',
  'menu.vatyn': 'O',
  'menu.itemsubtotal': 'O',
  'menu.etc': 'O'},
 14)

In [58]:
print(type(train[1]["ground_truth"]))

<class 'str'>


In [59]:
# Replace the categories
def map_labels(sample):
    sample_info = json.loads(sample["ground_truth"])["valid_line"]
    for idx_info, val in enumerate(sample_info):
        if val["category"] in replacing_labels:
            sample_info[idx_info]["category"] = 'O'
    
    updated_valid_line = {"valid_line" : sample_info}
    sample["ground_truth"] = json.dumps(updated_valid_line)
    return sample
 

In [60]:
train = train.map(map_labels)

In [61]:
all_labels = []

for sample in train["ground_truth"]:
    sample_info = json.loads(sample)["valid_line"]
    for val in sample_info:
        all_labels.append(val["category"])

label_dict = Counter(all_labels)
label_dict, len(label_dict)

(Counter({'menu.nm': 2100,
          'menu.price': 2093,
          'menu.cnt': 1903,
          'total.total_price': 784,
          'menu.unitprice': 621,
          'sub_total.subtotal_price': 541,
          'total.cashprice': 526,
          'O': 502,
          'total.changeprice': 502,
          'sub_total.tax_price': 362,
          'menu.sub.nm': 319,
          'total.menuqty_cnt': 233,
          'menu.sub.cnt': 142,
          'total.creditcardprice': 124,
          'menu.sub.price': 123,
          'sub_total.etc': 112}),
 16)

In [62]:
labels = list(set(label_dict))
labels

['total.total_price',
 'total.cashprice',
 'menu.sub.cnt',
 'menu.price',
 'total.changeprice',
 'menu.sub.price',
 'total.creditcardprice',
 'O',
 'menu.nm',
 'sub_total.etc',
 'menu.sub.nm',
 'menu.cnt',
 'sub_total.subtotal_price',
 'total.menuqty_cnt',
 'sub_total.tax_price',
 'menu.unitprice']

In [63]:
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}
print(label2id)
print(id2label)

{'total.total_price': 0, 'total.cashprice': 1, 'menu.sub.cnt': 2, 'menu.price': 3, 'total.changeprice': 4, 'menu.sub.price': 5, 'total.creditcardprice': 6, 'O': 7, 'menu.nm': 8, 'sub_total.etc': 9, 'menu.sub.nm': 10, 'menu.cnt': 11, 'sub_total.subtotal_price': 12, 'total.menuqty_cnt': 13, 'sub_total.tax_price': 14, 'menu.unitprice': 15}
{0: 'total.total_price', 1: 'total.cashprice', 2: 'menu.sub.cnt', 3: 'menu.price', 4: 'total.changeprice', 5: 'menu.sub.price', 6: 'total.creditcardprice', 7: 'O', 8: 'menu.nm', 9: 'sub_total.etc', 10: 'menu.sub.nm', 11: 'menu.cnt', 12: 'sub_total.subtotal_price', 13: 'total.menuqty_cnt', 14: 'sub_total.tax_price', 15: 'menu.unitprice'}
