# The purpose of this notebook is to prepare the data for training the CRNN model (word recognition for OCR).

In [None]:
import os, sys
import json
import random
import torch
from torch.utils.data import Dataset
from PIL import Image
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T

# Make project code importable: ../src (the source package) and the parent
# directory. NOTE: sys.path only affects `import` resolution, not file paths
# passed to open(); data files are still read through explicit "../data/..." paths.
# The membership check keeps this cell idempotent when it is re-run.
for _extra_path in (
    os.path.abspath(os.path.join(os.pardir, "src")),
    os.path.abspath(os.pardir),
):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [2]:
# Count the cropped word images; this should equal the number of annotated words.
cropped_words_dir = "../data/cropped_words/"
# Keep regular, non-hidden files only, so stray subdirectories or OS metadata
# (e.g. .DS_Store) do not inflate the count. Sorting makes the order deterministic.
word_files = sorted(
    name
    for name in os.listdir(cropped_words_dir)
    if not name.startswith(".")
    and os.path.isfile(os.path.join(cropped_words_dir, name))
)
print(f"Number of cropped word images: {len(word_files)}")

Number of cropped word images: 2186


In [3]:
# Extract character set from JSON annotations
def extract_charset(json_path):
    """
    Collect every distinct character that appears in the word labels of the annotation JSON.

    The JSON maps each receipt filename to a list of word entries. Each entry
    stores its label under the "word" key.

    Parameters
    ----------
    json_path : str
        Path to the annotation JSON file.

    Returns
    -------
    list of str
        Unique characters sorted by code point, which gives a deterministic vocabulary order.
    """
    # JSON is UTF-8 by spec. Being explicit avoids locale-dependent decoding
    # (e.g. cp1252 on Windows), which would corrupt non-ASCII characters.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    chars = set()
    for words in data.values():
        for entry in words:
            chars.update(entry["word"])

    return sorted(chars)

# Build vocabulary mappings from character set
def build_vocab(chars):
    """
    Build forward and reverse character/index mappings for CTC training.

    Index 0 is reserved for the CTC blank, so the i-th character of ``chars``
    receives index i + 1.

    Returns
    -------
    char2idx : dict[str, int]
        Character -> class index (starting at 1).
    idx2char : dict[int, str]
        Class index -> character.
    blank_idx : int
        Always 0.
    num_classes : int
        len(chars) + 1 (the blank is included).
    """
    BLANK = 0
    char2idx = {}
    idx2char = {}
    for class_id, ch in enumerate(chars, start=BLANK + 1):
        char2idx[ch] = class_id
        idx2char[class_id] = ch
    return char2idx, idx2char, BLANK, len(chars) + 1

# Annotation file: receipt filename -> list of word entries.
json_annotations_path = "../data/filename_to_word_files.json"

chars = extract_charset(json_annotations_path)
print(f"Extracted {len(chars)} unique characters from annotations.")
print(chars)

# build_vocab reserves index 0 for the CTC blank; characters start at 1.
(char2idx, idx2char,
 blank_idx, num_classes) = build_vocab(chars)
print(f"Number of classes (including blank): {num_classes}")

Extracted 78 unique characters from annotations.
[' ', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '>', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y']
Number of classes (including blank): 79


In [24]:
# Peek at the annotation file: its top-level keys are the receipt image filenames.
# Explicit UTF-8 keeps decoding independent of the OS locale.
with open(json_annotations_path, "r", encoding="utf-8") as f:
    data = json.load(f)

receipts = list(data.keys())
print(receipts[:5])  # Display first 5 receipt filenames

['dev_receipt_00091.png', 'dev_receipt_00085.png', 'dev_receipt_00052.png', 'dev_receipt_00046.png', 'dev_receipt_00047.png']


In [25]:
# Split receipts into training and validation sets - we do this at the receipt level to avoid data leakage
def split_receipts(json_path, train_ratio=0.8, seed=42):
    """
    Randomly split receipt filenames into training and validation lists.

    Splitting by receipt rather than by word keeps all words cropped from one
    receipt in the same split.

    Parameters
    ----------
    json_path : str
        Path to the annotation JSON, whose keys are receipt filenames.
    train_ratio : float
        Fraction of receipts assigned to training, in [0, 1]. The count is
        truncated with ``int()``.
    seed : int
        Seed for the shuffle.

    Returns
    -------
    (train_receipts, val_receipts) : tuple of list of str

    Raises
    ------
    ValueError
        If ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    receipts = list(data.keys())
    # A private RNG yields the same permutation as random.seed(seed) followed by
    # random.shuffle, without resetting the global `random` state.
    random.Random(seed).shuffle(receipts)

    n_train = int(len(receipts) * train_ratio)
    return receipts[:n_train], receipts[n_train:]

# Build samples list from receipt filenames
def build_samples(json_path, receipt_filenames):
    """
    Flatten the word entries of the given receipts into (image_path, text) pairs.

    Parameters
    ----------
    json_path : str
        Path to the annotation JSON. Each receipt maps to a list of entries with
        "word_file" (cropped image path) and "word" (label) keys.
    receipt_filenames : iterable of str
        Receipts to include, e.g. one of the lists returned by ``split_receipts``.

    Returns
    -------
    list of (str, str)
        One (word_file, word) tuple per word, in receipt order and then entry order.

    Raises
    ------
    KeyError
        If a receipt filename is not present in the JSON.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    return [
        (entry["word_file"], entry["word"])
        for rid in receipt_filenames
        for entry in data[rid]
    ]


train_receipts, val_receipts = split_receipts(json_annotations_path, train_ratio=0.8, seed=42)
# The receipt-level split is only leak-free if no receipt lands in both sets.
assert set(train_receipts).isdisjoint(val_receipts), "train/val receipts overlap"
print(f"Number of training receipts: {len(train_receipts)}")
print(f"Number of validation receipts: {len(val_receipts)}")

train_samples = build_samples(json_annotations_path, train_receipts)
val_samples = build_samples(json_annotations_path, val_receipts)
print(f"Number of training samples (words): {len(train_samples)}")
print(f"Number of validation samples (words): {len(val_samples)}")

# print an example training sample
print(f"Example training sample: {train_samples[0]}")

Number of training receipts: 80
Number of validation receipts: 20
Number of training samples (words): 1754
Number of validation samples (words): 432
Example training sample: ('data/cropped_words/dev_receipt_00026_word_0.png', 'Rp.')


In [14]:
class OCRDataset(Dataset):
    """
    PyTorch Dataset for word-level OCR (CRNN trained with CTC loss).

    Each item is a cropped word image paired with its text label. The label is
    encoded into a tensor of character indices using ``char2idx``.

    Items are returned as ``(image, target, target_length, text)``:
      - image: the image tensor. With the default ToTensor this is (1, H, W) float in [0, 1].
      - target: torch.long tensor of character indices, one per character of ``text``.
      - target_length: int, the number of encoded characters.
      - text: the original label string.
    """

    def __init__(self, samples, char2idx, transform=None, root_dir=os.pardir):
        """
        samples: list of tuples (image_path, text); image paths are resolved against ``root_dir``.
        char2idx: dictionary mapping characters to indices (as produced by build_vocab).
        transform: optional callable applied to the grayscale PIL image. It must return
            a tensor, e.g. a T.Compose that includes T.ToTensor(). When None, T.ToTensor()
            is used.
        root_dir: directory that sample paths are joined onto. Defaults to the parent
            directory because the JSON stores paths as "data/...".
        """
        self.samples = samples
        self.char2idx = char2idx
        self.transform = transform
        self.root_dir = root_dir
        self.to_tensor = T.ToTensor()

    def encode(self, text):
        """
        Convert a text string into a torch.long tensor of character indices
        (the target format expected by CTC loss).

        Raises
        ------
        KeyError
            If ``text`` contains a character that is missing from ``char2idx``.
        """
        try:
            indices = [self.char2idx[c] for c in text]
        except KeyError as err:
            raise KeyError(
                f"Character {err.args[0]!r} in {text!r} is not in the vocabulary"
            ) from err
        return torch.tensor(indices, dtype=torch.long)

    def __len__(self):
        return len(self.samples)

    def _load_image(self, path):
        """Open ``path`` and return a grayscale ("L") PIL image, closing the file handle."""
        # convert() returns a new, fully loaded image, so closing the source is safe.
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self, idx):
        img_path, text = self.samples[idx]

        image = self._load_image(os.path.join(self.root_dir, img_path))

        if self.transform is not None:
            image = self.transform(image)
        else:
            # ToTensor: PIL "L" image -> float tensor of shape (1, H, W) scaled to [0, 1]
            image = self.to_tensor(image)

        target = self.encode(text)
        return image, target, len(target), text


# Show how a few samples are encoded, using a two-item subset of the training data.
sample_dataset = OCRDataset(train_samples[:2], char2idx)
for i in range(len(sample_dataset)):
    image, target, target_length, text = sample_dataset[i]
    print(
        f"Sample {i}:",
        f"  Target indices: {target}",
        f"  Target length: {target_length}",
        f"  Original text: {text}",
        sep="\n",
    )
    

Sample 0:
  Target indices: tensor([44, 69, 11])
  Target length: 3
  Original text: Rp.
Sample 1:
  Target indices: tensor([40, 27, 45, 35])
  Target length: 4
  Original text: NASI


In [22]:
# def build_transforms():
#     return T.Compose([
#         T.Resize((128, 128)),
#         T.ToTensor(),
#         T.Normalize(mean=[0.5], std=[0.5]),
#     ])


def ctc_collate_fn(batch):
    """
    Collate a batch of OCR samples into the layout expected by CTC loss.

    Default collation would try to stack the targets, which fails because
    words have different lengths. CTC instead takes all targets concatenated
    into one tensor plus a separate tensor of per-sample lengths.

    Parameters
    ----------
    batch : list of tuples
        Each tuple contains (image, target, target_length, text).

    Returns
    -------
    images : torch.Tensor
        Images stacked along a new batch dimension; all images must share one shape.
    targets : torch.Tensor
        Every sample's target concatenated along dim 0.
    target_lengths : torch.Tensor
        torch.long tensor holding each sample's target length, in batch order.
    texts : tuple of str
        The original label strings, in batch order.
    """
    image_seq, target_seq, length_seq, text_seq = zip(*batch)
    return (
        torch.stack(image_seq),
        torch.cat(target_seq),
        torch.tensor(length_seq, dtype=torch.long),
        text_seq,
    )


def build_dataloaders(
    train_dataset,
    val_dataset,
    batch_size=16,
    num_workers=0,
    pin_memory=None,
):
    """
    Wrap the train/val datasets in DataLoaders that batch with ``ctc_collate_fn``.

    Parameters
    ----------
    train_dataset, val_dataset : torch.utils.data.Dataset
        Datasets yielding (image, target, target_length, text) items.
    batch_size : int
        Batch size for both loaders.
    num_workers : int
        Number of loader worker processes (0 loads in the main process, which is
        the DataLoader default).
    pin_memory : bool or None
        Use page-locked host memory for faster host-to-GPU copies. None (default)
        enables it only when CUDA is available. Pinning is useless without an
        accelerator, and PyTorch warns when it is requested there. The setting
        applies to both loaders.

    Returns
    -------
    (train_loader, val_loader)
        The training loader reshuffles every epoch; the validation loader keeps dataset order.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # Options shared by both loaders, so train and val batches are built identically.
    common = dict(
        batch_size=batch_size,
        collate_fn=ctc_collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **common)
    val_loader = DataLoader(val_dataset, shuffle=False, **common)

    return train_loader, val_loader


# Example usage of DataLoader with the OCRDataset and ctc_collate_fn
train_dataset = OCRDataset(train_samples, char2idx)
val_dataset = OCRDataset(val_samples, char2idx)

train_loader, val_loader = build_dataloaders(
    train_dataset,
    val_dataset,
    batch_size=16
)

images, targets, target_lengths, texts = next(iter(train_loader))

print("Images:", images.shape)              # Shape is expected to be (B, 1, 128, 128)
print("Targets:", targets.shape)            # Shape is expected to be (sum of target lengths in batch,)
print("Target lengths:", target_lengths)    # Shape is expected to be (B,)
assert target_lengths.sum().item() == targets.shape[0], "Sum of target lengths must equal number of target indices"
assert images.shape[0] == target_lengths.shape[0] == len(texts), "Images, target lengths and texts must share the batch size"
print("Example texts:", texts[:3])

Images: torch.Size([16, 1, 128, 128])
Targets: torch.Size([74])
Target lengths: tensor([ 3,  5,  5,  6,  1,  3, 11,  4,  3,  6,  6,  5,  6,  6,  1,  3])
Example texts: ('Tax', 'Bayar', 'GREEN')


The dataset is now ready to use. Next, we can refactor this code into Python scripts.

In the `dataset.py` file