<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Donut/RVL-CDIP/Preparing_an_image_classification_dataset_for_Donut.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prepare a document image classification dataset for Donut

In this notebook, I'll show how to prepare a document image classification dataset for Donut. Basically, there are 2 steps:

1. Load an image classification dataset: this could be an existing dataset from the hub, or your own custom dataset, in which case you can use the [ImageFolder](https://huggingface.co/docs/datasets/image_load#imagefolder) feature.
2. Prepare in Donut format: add a `ground_truth` column to the dataset. Each ground truth is a string with a `gt_parse` key whose value (the target sequence) has the format `{"class" : class_name}`, for example `{"class" : "scientific_report"}` or `{"class" : "presentation"}`.

## Set-up environment

We'll install 🤗 Datasets first.

In [2]:
!pip install -q datasets

## 1. Load an image classification dataset (toy RVL-CDIP in our case, 10 examples per class)

The first step is to load an image classification dataset as a 🤗 [Dataset](https://huggingface.co/docs/datasets/v2.4.0/en/package_reference/main_classes#datasets.Dataset) or 🤗 [DatasetDict](https://huggingface.co/docs/datasets/v2.4.0/en/package_reference/main_classes#datasets.DatasetDict), with 2 columns, namely "image" and "label".

If you have your own custom data, we recommend using the [ImageFolder](https://huggingface.co/docs/datasets/image_load#imagefolder) feature. It lets you easily create a 🤗 Dataset from your own local or remote files. You can optionally push your dataset to the hub using [`push_to_hub`](https://huggingface.co/docs/datasets/v2.4.0/en/package_reference/main_classes#datasets.Dataset.push_to_hub) and reload it afterwards with [`load_dataset`](https://huggingface.co/docs/datasets/v2.4.0/en/package_reference/loading_methods#datasets.load_dataset).
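ImageFolder infers each image's label from the name of the folder it is stored in, so a custom dataset is typically organized as `<data_dir>/<split>/<class_name>/<image>` (for example, a hypothetical `my_documents/train/invoice/0001.tif`). The pure-Python sketch below is not part of 🤗 Datasets: `list_imagefolder_examples` is a hypothetical helper, and its list of image extensions is an assumption. It walks such a tree and lists the (path, class name) pairs the folder structure encodes, which can be a quick sanity check before calling `load_dataset("imagefolder", data_dir=...)`.

```python
from pathlib import Path

# Image extensions considered by this sketch (an assumption, not ImageFolder's exact list).
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}

def list_imagefolder_examples(data_dir):
    """Return {split: [(relative_path, class_name), ...]} for a
    <data_dir>/<split>/<class_name>/<image> layout (hypothetical helper)."""
    root = Path(data_dir)
    examples = {}
    for split_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        pairs = []
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            for image_path in sorted(class_dir.iterdir()):
                if image_path.suffix.lower() in IMAGE_SUFFIXES:
                    # the class name is simply the name of the parent folder
                    pairs.append((image_path.relative_to(root).as_posix(), class_dir.name))
        examples[split_dir.name] = pairs
    return examples

# e.g. list_imagefolder_examples("my_documents")  (hypothetical path)
```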

Alternatively (and that's what we'll do here), you can load an existing dataset from the hub. If you're using ImageFolder, feel free to skip the code below and continue at [section 2](#section-2).

We'll create a small subset of [RVL-CDIP](https://paperswithcode.com/dataset/rvl-cdip), an important benchmark for document image classification. As RVL-CDIP is huge (it contains 400,000 images), we create a toy dataset with a train and a test split, each containing 10 documents per class. As there are 16 classes, each split contains 160 documents. We'll build both splits from the test set of RVL-CDIP.

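We'll do this by filtering the dataset on each class, and then selecting 10 documents of that class for training and the next 10 for testing. Let's load the test set first. Note that `load_dataset` still downloads and prepares all three splits of RVL-CDIP (about 36 GiB of downloads, as the log below shows), even though we only ask for the test split: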



In [None]:
from datasets import load_dataset

dataset = load_dataset("rvl_cdip", split="test")

Downloading and preparing dataset rvl_cdip/default (download: 36.12 GiB, generated: 45.21 GiB, post-processed: Unknown size, total: 81.33 GiB) to /root/.cache/huggingface/datasets/rvl_cdip/default/1.0.0/ea410993ed3f5b9744d8616ffbaad5f70a75a21a4233626dd07b3de31d381e53...


Dataset rvl_cdip downloaded and prepared to /root/.cache/huggingface/datasets/rvl_cdip/default/1.0.0/ea410993ed3f5b9744d8616ffbaad5f70a75a21a4233626dd07b3de31d381e53. Subsequent calls will reuse this data.


In [None]:
dataset 

Dataset({
    features: ['image', 'label'],
    num_rows: 40000
})

Next, let's filter the dataset to create two lists of 🤗 Dataset objects (one for training, one for testing), each containing one dataset per class:

In [None]:
train_datasets, test_datasets = [], []

id2label = {
  0: "letter",
  1: "form",
  2: "email",
  3: "handwritten",
  4: "advertisement",
  5: "scientific_report",
  6: "scientific_publication",
  7: "specification",
  8: "file_folder",
  9: "news_article",
  10: "budget",
  11: "invoice",
  12: "presentation",
  13: "questionnaire",
  14: "resume",
  15: "memo"
}

def check_label(examples, label_index):
  # batched filter function: return one boolean per example,
  # True if the example belongs to the class `label_index`
  return [label == label_index for label in examples["label"]]

# for each class: filter the dataset on documents with that class,
# then use the first 10 for training and the next 10 for testing
for label_id in id2label.keys():
  # filter dataset on this particular label
  filtered_dataset = dataset.filter(check_label, fn_kwargs={"label_index": label_id}, batched=True)
  # select the first 10 examples for training
  train_datasets.append(filtered_dataset.select(range(10)))
  # select the next 10 examples (indices 10 to 19) for testing
  test_datasets.append(filtered_dataset.select(range(10, 20)))



In [None]:
train_datasets[0]

Dataset({
    features: ['image', 'label'],
    num_rows: 10
})

In [None]:
test_datasets[0]

Dataset({
    features: ['image', 'label'],
    num_rows: 10
})

Next, we concatenate each list into a single `Dataset` and combine the two into one `DatasetDict`:

In [None]:
from datasets import DatasetDict, concatenate_datasets

toy_dataset = DatasetDict({"train": concatenate_datasets(train_datasets),
                           "test": concatenate_datasets(test_datasets)
                           })
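
The loop above calls `filter` once per class, so it scans the 40,000-example test split 16 times. A possible single-pass alternative, sketched below, collects the row indices of each class from the `label` column and then calls `Dataset.select` once per split. Because the indices are grouped by class in label order, the result should have the same row order as the filter-and-concatenate approach. `split_indices_per_class` is a hypothetical helper, and the sketch assumes that `dataset["label"]` returns a plain list of integers.

```python
from collections import defaultdict

def split_indices_per_class(labels, num_classes, n_train=10, n_test=10):
    """Return (train_indices, test_indices): for each class in label order, the
    first `n_train` rows with that label go to train and the next `n_test` to test.
    Hypothetical helper mirroring the filter-and-select loop in a single pass."""
    rows_per_class = defaultdict(list)
    for row, label in enumerate(labels):
        rows_per_class[label].append(row)

    train_indices, test_indices = [], []
    for label in range(num_classes):
        rows = rows_per_class[label]
        if len(rows) < n_train + n_test:
            raise ValueError(f"class {label} has only {len(rows)} examples")
        train_indices.extend(rows[:n_train])
        test_indices.extend(rows[n_train:n_train + n_test])
    return train_indices, test_indices

# Usage with the RVL-CDIP test split loaded above (not executed here):
# train_idx, test_idx = split_indices_per_class(dataset["label"], num_classes=len(id2label))
# toy_dataset = DatasetDict({"train": dataset.select(train_idx), "test": dataset.select(test_idx)})
```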

In [None]:
toy_dataset["test"][144]

{'image': <PIL.TiffImagePlugin.TiffImageFile image mode=L size=754x1000 at 0x7FE145FBBAD0>,
 'label': 14}
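
Because the 16 per-class datasets are concatenated in label order and each contributes exactly 10 rows, the label of row `i` in either split should simply be `i // 10`, which is why row 144 above has label 14 (`resume`). A minimal sketch of that reasoning (the helper name `expected_label` is hypothetical):

```python
EXAMPLES_PER_CLASS = 10
NUM_CLASSES = 16

def expected_label(row_index, examples_per_class=EXAMPLES_PER_CLASS):
    """Label of a row in a split built by concatenating per-class blocks in label order."""
    return row_index // examples_per_class

# each split holds NUM_CLASSES * EXAMPLES_PER_CLASS rows
split_size = NUM_CLASSES * EXAMPLES_PER_CLASS
print(split_size, expected_label(144))  # 160 14

# To check a real split (not executed here):
# assert all(label == expected_label(i) for i, label in enumerate(toy_dataset["test"]["label"]))
```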

### Push to the hub

Pushing the image classification dataset to the hub is as easy as:

In [None]:
# add `private=True` if you want the dataset repository to be private
toy_dataset.push_to_hub("nielsr/rvl_cdip_10_examples_per_class")



Reloading is as easy as:

In [None]:
toy_dataset = load_dataset("nielsr/rvl_cdip_10_examples_per_class")



In [None]:
toy_dataset

DatasetDict({
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 160
    })
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 160
    })
})

<a name="section-2"></a>
## 2. Prepare in Donut format

The Donut model requires a dataset of (image, text) pairs: rather than just the label, each text is a string containing a `gt_parse` key.

We do this by adding a new `ground_truth` column to the dataset, built from the template below and the class name of each example.

In [None]:
template = '{"gt_parse": {"class" : '

In [None]:
id2label = {
  0: "letter",
  1: "form",
  2: "email",
  3: "handwritten",
  4: "advertisement",
  5: "scientific_report",
  6: "scientific_publication",
  7: "specification",
  8: "file_folder",
  9: "news_article",
  10: "budget",
  11: "invoice",
  12: "presentation",
  13: "questionnaire",
  14: "resume",
  15: "memo"
}


def update_examples(examples):
  # build one ground truth string per example, e.g. '{"gt_parse": {"class" : "letter"}}'
  ground_truths = []
  for label in examples['label']:
    ground_truths.append(template + '"' + id2label[label] + '"' + "}}")

  examples['ground_truth'] = ground_truths

  return examples

toy_dataset = toy_dataset.map(update_examples, batched=True)

In [None]:
test = toy_dataset['train'][0]['ground_truth']
test

'{"gt_parse": {"class" : "letter"}}'
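
Building the string by concatenation works for these class names, but it would produce invalid JSON if a name ever contained a double quote or a backslash. A possible alternative, sketched below, is to build a Python dictionary and serialize it with the standard-library `json.dumps`, which handles escaping. The output differs from the template version only in spacing: `json.dumps` writes `"class": ` rather than `"class" : `. `build_ground_truth` and `update_examples_json` are hypothetical helpers, not part of Donut or 🤗 Datasets.

```python
import json

def build_ground_truth(class_name):
    """Serialize a class name into a ground truth string with a `gt_parse` key."""
    return json.dumps({"gt_parse": {"class": class_name}})

def update_examples_json(examples, id2label):
    # batched variant of `update_examples`: one ground truth string per label in the batch
    examples["ground_truth"] = [build_ground_truth(id2label[label]) for label in examples["label"]]
    return examples

print(build_ground_truth("letter"))  # {"gt_parse": {"class": "letter"}}

# Usage (not executed here):
# toy_dataset = toy_dataset.map(update_examples_json, fn_kwargs={"id2label": id2label}, batched=True)
```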

Let's verify that we can parse it back into a Python dictionary:

In [None]:
from ast import literal_eval

test2 = literal_eval(test)
test2['gt_parse']

{'class': 'letter'}
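
`literal_eval` parses Python literals, so it also accepts strings that are not valid JSON (for example, single-quoted keys). If you want this check to fail on anything that isn't JSON, the standard-library `json.loads` is a stricter alternative. A minimal sketch comparing the two:

```python
import json
from ast import literal_eval

ground_truth = '{"gt_parse": {"class" : "letter"}}'

# both parsers return the same nested dictionary for a valid JSON string
assert json.loads(ground_truth) == literal_eval(ground_truth)
print(json.loads(ground_truth)["gt_parse"])  # {'class': 'letter'}

# a Python-style string with single quotes is accepted by literal_eval but rejected by json.loads
not_json = "{'gt_parse': {'class': 'letter'}}"
print(literal_eval(not_json)["gt_parse"])  # {'class': 'letter'}
try:
    json.loads(not_json)
except json.JSONDecodeError as error:
    print("not valid JSON:", error.msg)
```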

In [None]:
toy_dataset

DatasetDict({
    test: Dataset({
        features: ['image', 'label', 'ground_truth'],
        num_rows: 160
    })
    train: Dataset({
        features: ['image', 'label', 'ground_truth'],
        num_rows: 160
    })
})
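
Before pushing, it can be worth checking that every `ground_truth` string parses as JSON and names the class matching its `label`. A minimal sketch on plain lists (`find_mismatches` is a hypothetical helper); with 🤗 Datasets you would pass it columns such as `toy_dataset["train"]["label"]` and `toy_dataset["train"]["ground_truth"]`.

```python
import json

def find_mismatches(labels, ground_truths, id2label):
    """Return the row indices whose ground truth does not parse or does not match the label."""
    if len(labels) != len(ground_truths):
        raise ValueError("labels and ground_truths must have the same length")
    mismatches = []
    for row, (label, ground_truth) in enumerate(zip(labels, ground_truths)):
        try:
            class_name = json.loads(ground_truth)["gt_parse"]["class"]
        except (json.JSONDecodeError, KeyError, TypeError):
            # not JSON, or missing the expected keys
            mismatches.append(row)
            continue
        if class_name != id2label[label]:
            mismatches.append(row)
    return mismatches

# Usage (not executed here):
# for split in ["train", "test"]:
#     assert find_mismatches(toy_dataset[split]["label"], toy_dataset[split]["ground_truth"], id2label) == []
```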

### Push Donut dataset to the hub

Finally, we push this dataset to the hub so that we can easily reuse it and share it with colleagues.

In [None]:
# add `private=True` if you want the dataset repository to be private
toy_dataset.push_to_hub("nielsr/rvl_cdip_10_examples_per_class_donut")



Reloading is as easy as:

In [5]:
from datasets import load_dataset

toy_dataset = load_dataset("nielsr/rvl_cdip_10_examples_per_class_donut")

Downloading and preparing dataset None/None (download: 33.60 MiB, generated: 35.68 MiB, post-processed: Unknown size, total: 69.28 MiB) to /root/.cache/huggingface/datasets/nielsr___parquet/nielsr--rvl_cdip_10_examples_per_class_donut-f7a67080e6d136af/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/nielsr___parquet/nielsr--rvl_cdip_10_examples_per_class_donut-f7a67080e6d136af/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


In [6]:
toy_dataset

DatasetDict({
    test: Dataset({
        features: ['image', 'label', 'ground_truth'],
        num_rows: 160
    })
    train: Dataset({
        features: ['image', 'label', 'ground_truth'],
        num_rows: 160
    })
})