<a href="https://colab.research.google.com/github/adamserag1/Interpretability-for-VRDU-models/blob/main/finetuning/RVL_CDIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""RVL CDIP"""

'RVL CDIP'

In [1]:
# Install the Tesseract OCR system package (presumably the engine pytesseract calls later on).
!apt-get install tesseract-ocr -y

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tesseract-ocr is already the newest version (4.1.1-2.1build1).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.


In [2]:
# Mount Google Drive at /content/drive; the finished subset is copied there at the end of the notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Interactive Hugging Face Hub login; renders a widget prompt in the notebook.
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
pip install -U datasets

Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
  Attempting uninstall: datasets
    Found existing installation: datasets 2.14.4
    Uninstalling datasets-2.14.4:
      Successfully uninstalled datasets-2.14.4
[31mERROR: pip's dependency re

In [5]:
# Install pytesseract. `datasets` is listed again but was already upgraded by the previous cell,
# so only pytesseract is actually downloaded here.
!pip install datasets pytesseract

Collecting pytesseract
  Downloading pytesseract-0.3.13-py3-none-any.whl.metadata (11 kB)
Downloading pytesseract-0.3.13-py3-none-any.whl (14 kB)
Installing collected packages: pytesseract
Successfully installed pytesseract-0.3.13


In [6]:
from datasets import load_dataset, Dataset, ClassLabel
from collections import Counter
import multiprocessing
from tqdm import tqdm
from itertools import islice
import random
import os
from PIL import Image
from io import BytesIO
import pytesseract


In [7]:
# Sampling configuration for the subset build.
N_PER_CLASS = 1000  # cap on examples collected per target class
SEED = 42  # default seed handed to random.seed() before the collected examples are shuffled
SAVE_PATH = "/content/rvl_cdip_financial_subset"  # directory the final dataset is saved to

# Subset of RVL-CDIP classes to keep (NOT the full label map), keyed by the
# original RVL-CDIP label id. remap_labels() later renumbers these to
# contiguous ids following the ascending order of the keys below.
TARGET_LABELS = {
    1: "form",
    11: "invoice",
    10: "budget",
    8: "file folder",
    13: "questionnaire"
}


In [8]:
def stream_and_filter_rvl_cdip(n_per_class=N_PER_CLASS, seed=SEED):
    """Stream the RVL-CDIP train split and keep a capped sample per target class.

    Examples are consumed in stream order. An example is kept when its
    "label" is one of the keys of TARGET_LABELS and fewer than ``n_per_class``
    examples with that label have been kept so far. Reading stops as soon as
    every target class has reached the cap (or the stream is exhausted). The
    kept examples are then shuffled with the ``random`` module.

    Args:
        n_per_class: Maximum number of examples kept for each target label.
        seed: Value passed to ``random.seed`` before anything else happens.

    Returns:
        A ``Dataset`` created by ``Dataset.from_list`` from the shuffled examples.
    """
    random.seed(seed)
    print("Streaming RVL-CDIP...")
    stream = load_dataset("rvl_cdip", split="train", streaming=True)

    # One bucket per wanted label id, in TARGET_LABELS insertion order.
    buckets = {label_id: [] for label_id in TARGET_LABELS}

    for record in stream:
        bucket = buckets.get(record["label"])
        if bucket is not None and len(bucket) < n_per_class:
            bucket.append(record)
        # Stop reading once no bucket is still below the cap.
        if not any(len(kept) < n_per_class for kept in buckets.values()):
            break

    pooled = []
    for kept in buckets.values():
        pooled.extend(kept)
    random.shuffle(pooled)
    print(f"Collected {len(pooled)} examples.")
    return Dataset.from_list(pooled)

In [9]:
def remap_labels(dataset, label_map):
    """Renumber the "label" column to contiguous ids and attach class names.

    The keys of ``label_map`` are sorted ascending; the k-th key becomes new
    id ``k`` and its mapped value becomes the k-th class name.

    Args:
        dataset: Object providing ``map`` and ``cast_column``. Every example's
            "label" must be a key of ``label_map`` (a ``KeyError`` is raised
            by the lookup otherwise).
        label_map: Mapping of original label id -> class name.

    Returns:
        The mapped dataset with "label" cast to ``ClassLabel`` using the
        ordered class names.
    """
    ordered_ids = sorted(label_map)
    old_to_new = dict(zip(ordered_ids, range(len(ordered_ids))))
    class_names = [label_map[old_id] for old_id in ordered_ids]

    def _remap(example):
        # Swap the original label id for its position in the sorted key list.
        example["label"] = old_to_new[example["label"]]
        return example

    remapped = dataset.map(_remap)
    return remapped.cast_column("label", ClassLabel(names=class_names))

In [10]:
def save_and_validate(dataset, path):
    """Save ``dataset`` to ``path`` and print a short sanity-check summary.

    Steps:
      1. Create ``path`` if needed and call ``dataset.save_to_disk(path)``.
      2. Print the number of examples per class, using the class names found
         in ``dataset.features["label"].names``.
      3. Print the label, the first 10 words and the first 10 boxes of the
         first example. If the dataset has no rows this preview is replaced
         by a short notice instead of raising ``IndexError``.

    Args:
        dataset: Dataset whose "label" feature exposes ``names`` and whose
            examples carry "words" and "bboxes" fields.
        path: Target directory for ``save_to_disk``.

    Returns:
        None. Output is written to disk and to stdout.
    """
    os.makedirs(path, exist_ok=True)
    dataset.save_to_disk(path)
    print(f"Saved to: {path}")

    label_names = dataset.features["label"].names
    counts = Counter(dataset["label"])
    print("\nLabel distribution:")
    for label_id, count in sorted(counts.items()):
        print(f"{label_names[label_id]}: {count}")

    print("\nSample:")
    if not counts:
        # No labels means no rows: dataset[0] would raise IndexError even
        # though the save above already succeeded, so report and stop here.
        print("(dataset is empty)")
        return

    sample = dataset[0]
    print(f"Label: {label_names[sample['label']]}")
    print(f"Words (first 10): {sample['words'][:10]}")
    print(f"BBoxes (first 10): {sample['bboxes'][:10]}")

In [11]:
def run_tesseract_ocr(example):
    """Attach Tesseract OCR output to a single example.

    Reads ``example["image"]`` (a PIL image, or raw bytes that are opened with
    PIL), converts it to RGB, runs ``pytesseract.image_to_data`` and stores:

    * ``example["words"]``  - the recognised tokens, stripped, blanks skipped;
    * ``example["bboxes"]`` - one ``[x0, y0, x1, y1]`` box per token, in
      pixel coordinates (left, top, left + width, top + height).

    This is best-effort: any exception is reported with ``print`` and both
    fields are set to empty lists so a single bad image does not abort a
    whole mapping run.

    Args:
        example: Mutable mapping holding an "image" entry.

    Returns:
        The same ``example`` object, updated in place.
    """
    try:
        raw = example["image"]
        opened = raw if isinstance(raw, Image.Image) else Image.open(BytesIO(raw))
        rgb = opened.convert("RGB")

        ocr = pytesseract.image_to_data(rgb, output_type=pytesseract.Output.DICT)
        tokens, boxes = [], []

        for idx, text in enumerate(ocr["text"]):
            token = text.strip()
            if not token:
                # Tesseract also emits rows with no text; skip them.
                continue
            left, top = ocr["left"][idx], ocr["top"][idx]
            right = left + ocr["width"][idx]
            bottom = top + ocr["height"][idx]
            tokens.append(token)
            boxes.append([left, top, right, bottom])  # pixel-space

        # Assigned only after the whole image was processed successfully.
        example["words"] = tokens
        example["bboxes"] = boxes
    except Exception as e:
        print(f"OCR failed on example: {e}")
        example["words"] = []
        example["bboxes"] = []

    return example

In [15]:
def parallel_ocr(dataset, num_processes=None):
    """Apply Tesseract OCR to every example using a multiprocessing pool.

    Each example is passed to ``run_tesseract_ocr`` in a worker process while
    a tqdm bar tracks progress. Results are collected in input order, so the
    returned dataset has the same row order as ``dataset`` on every run.

    Args:
        dataset: Sized iterable of examples, normally a ``Dataset``.
        num_processes: Number of worker processes. ``None`` uses all CPUs but
            one, with a minimum of one.

    Returns:
        A new ``Dataset`` built from the OCR-annotated examples. When the
        input exposes a "label" feature, that feature type is re-applied to
        the output so class names are not lost.
    """
    if num_processes is None:
        num_processes = max(1, multiprocessing.cpu_count() - 1)

    print(f"\nRunning OCR with {num_processes} processes...")
    results = []
    with multiprocessing.Pool(processes=num_processes) as pool:
        with tqdm(total=len(dataset)) as pbar:
            # imap (unlike imap_unordered) yields results in submission order,
            # which keeps the output reproducible after the seeded shuffle.
            for result in pool.imap(run_tesseract_ocr, dataset):
                results.append(result)
                pbar.update(1)

    ocr_dataset = Dataset.from_list(results)

    # from_list infers column types from plain Python values, so a ClassLabel
    # "label" column would come back as a bare integer column. Re-apply the
    # source feature so downstream code can still read the class names.
    source_features = getattr(dataset, "features", None)
    if results and source_features is not None and "label" in source_features:
        ocr_dataset = ocr_dataset.cast_column("label", source_features["label"])

    return ocr_dataset

In [21]:
def build_rvl_cdip_subset_with_ocr(num_proc=2, save_path=None):
    """Build, OCR, save and return the RVL-CDIP subset.

    Pipeline: stream and filter the target classes, remap their labels to
    contiguous ids, run ``run_tesseract_ocr`` over every example with
    ``Dataset.map``, then save and summarise the result.

    Args:
        num_proc: Number of processes handed to ``Dataset.map``. Defaults to 2
            (the previous hard-coded value). Pass ``None`` to run OCR in the
            main process, e.g. when multiprocess mapping fails in the hosting
            environment.
        save_path: Directory to save the dataset to. ``None`` (default) uses
            the module-level ``SAVE_PATH``.

    Returns:
        The OCR-annotated dataset.
    """
    if save_path is None:
        # Resolved at call time so a changed SAVE_PATH is picked up.
        save_path = SAVE_PATH

    filtered = stream_and_filter_rvl_cdip()
    processed = remap_labels(filtered, TARGET_LABELS)
    print("Running OCR over dataset...")
    ocr_processed = processed.map(run_tesseract_ocr, num_proc=num_proc)
    save_and_validate(ocr_processed, save_path)
    return ocr_processed


In [22]:
# Build the OCR-annotated subset; the call also saves it under SAVE_PATH.
subset = build_rvl_cdip_subset_with_ocr()
# Copy the saved dataset directory from local Colab storage to the mounted Google Drive.
!cp -r /content/rvl_cdip_financial_subset /content/drive/MyDrive/THESIS

Streaming RVL-CDIP...
Collected 5000 examples.


Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Running OCR over dataset...


Map (num_proc=2):   0%|          | 0/5000 [00:00<?, ? examples/s]

TimeoutError: 