In [6]:
import os
import scipy.io
import cv2
import random
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from string import ascii_letters, digits

In [1]:
# Root folder of the SynthText release and its ground-truth (GT) annotation file
DATA_ROOT = "datasets/SynthText"
GT_PATH   = os.path.join(DATA_ROOT, "gt.mat")

print("Loading annotations...")
# squeeze_me=True squeezes unit dimensions out of the loaded arrays;
# struct_as_record=False loads MATLAB structs as attribute-style objects.
gt = scipy.io.loadmat(GT_PATH, squeeze_me=True, struct_as_record=False)
# `gt` is a dict; it is expected to include 'imnames', 'wordBB', 'charBB' and 'txt'
print(gt.keys())

NameError: name 'os' is not defined

In [8]:
num_images = len(gt['imnames']) # should be 858750
NUM_SAMPLES = 10000

# Pick up to NUM_SAMPLES unique image indices. random.sample raises ValueError
# when asked for more items than the population holds (e.g. a partial copy of
# the dataset), so the request is capped at the number of available images.
sampled_idxs = random.sample(range(num_images), min(NUM_SAMPLES, num_images))

In [9]:
# Destination folder for the per-character crops.
# Created here if missing; an already-existing folder is left untouched.
OUTPUT_ROOT = "datasets/step1_chars"
os.makedirs(name=OUTPUT_ROOT, exist_ok=True)

In [17]:
DATA_ROOT             = "datasets/SynthText"
GT_PATH               = os.path.join(DATA_ROOT, "gt.mat")
OUTPUT_ROOT           = "datasets/step1_chars"
# Extraction stops once every class holds at least MIN_SAMPLES_PER_CLASS crops;
# no class ever receives more than MAX_SAMPLES_PER_CLASS.
MIN_SAMPLES_PER_CLASS = 1000 # Often bottlenecks on capital Z's for ascii chars
MAX_SAMPLES_PER_CLASS = 1000
random.seed(42)  # fixed seed so the shuffled image order is reproducible

# Pull the per-image annotation arrays out of the already-loaded `gt` dict
imnames = gt['imnames']
charBB  = gt['charBB']
txt     = gt['txt']

# Shuffle indices
indices = list(range(len(imnames)))
random.shuffle(indices)

# Valid alphanumeric labels
valid_labels = list(ascii_letters + digits)

# Prepare output dirs & counters (one sub-folder and one counter per label)
os.makedirs(OUTPUT_ROOT, exist_ok=True)
counts = {lbl: 0 for lbl in valid_labels}
for lbl in valid_labels:
    os.makedirs(os.path.join(OUTPUT_ROOT, lbl), exist_ok=True)

# Case-insensitive filesystems (the default on macOS and Windows) resolve "a"
# and "A" to the same folder, which silently merges the upper- and lower-case
# letter classes on disk even though `counts` tracks them separately.
if os.path.samefile(os.path.join(OUTPUT_ROOT, "a"), os.path.join(OUTPUT_ROOT, "A")):
    print("WARNING: output filesystem is case-insensitive; upper- and "
          "lower-case letter classes will be written to the same folders.")

# Helpers
def extract_chars(txt_entry):
    """
    Function name: extract_chars
    Description: Convert a raw text entry into a list of non-whitespace characters.
                 Arrays (and lists/tuples) are flattened recursively, so nested or
                 multi-dimensional containers contribute their text instead of
                 being silently dropped.
    Parameters:
        txt_entry (str or np.ndarray): Raw string or array of substrings from the SynthText annotations.
                                       Any other value is converted with str().
    Return Value:
        list[str]: Flattened list of individual characters, excluding any whitespace.
    """
    def _as_text(entry):
        # Strings (including NumPy string scalars, which subclass str) pass through.
        if isinstance(entry, str):
            return entry
        # Arrays: visit every element in flattened order. Object arrays may hold
        # further arrays, hence the recursion.
        if isinstance(entry, np.ndarray):
            return "".join(_as_text(item) for item in entry.flatten().tolist())
        # Plain Python containers can appear as elements of object arrays.
        if isinstance(entry, (list, tuple)):
            return "".join(_as_text(item) for item in entry)
        # Anything else (numbers, ...) is stringified.
        return str(entry)

    return [c for c in _as_text(txt_entry) if not c.isspace()]

def crop_char_safe(img, box):
    """
    Function name: crop_char_safe
    Description: Cut the axis-aligned rectangle that encloses a character box out of an image,
                 after clamping the rectangle to the image boundaries.
    Parameters:
        img (np.ndarray): Source image; only its first two dimensions (height, width) are used.
        box (np.ndarray): Array whose row 0 holds the corner x coordinates and row 1 the corner y coordinates.
    Return Value:
        np.ndarray or None: The cropped patch, or None when the clamped rectangle has no area.
    """
    # Integer extents of the box (int() truncates toward zero)
    left, right = int(box[0, :].min()), int(box[0, :].max())
    top, bottom = int(box[1, :].min()), int(box[1, :].max())

    height, width = img.shape[:2]

    # Lower bounds are clamped at 0, upper bounds at the image size
    left = left if left > 0 else 0
    top = top if top > 0 else 0
    right = right if right < width else width
    bottom = bottom if bottom < height else height

    if right > left and bottom > top:
        return img[top:bottom, left:right]
    return None

# Main loop: walk the shuffled images and save 32x32 character crops until
# every class has reached MIN_SAMPLES_PER_CLASS.
saved, skipped = 0, 0
for idx in indices:
    if all(count >= MIN_SAMPLES_PER_CLASS for count in counts.values()):
        break

    name = imnames[idx]

    # An image annotated with a single character can come back as a (2, 4)
    # array with no trailing box axis; np.atleast_3d restores (2, 4, 1) so
    # that shape[2] and boxes[:, :, j] below stay valid.
    boxes     = np.atleast_3d(charBB[idx])
    txt_entry = txt[idx]
    chars     = extract_chars(txt_entry)
    if boxes.shape[2] != len(chars):
        # Characters and boxes cannot be paired one-to-one. Checked before the
        # image is decoded so unusable entries cost no image I/O.
        skipped += 1
        continue

    img_path = os.path.join(DATA_ROOT, name)
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals a missing or unreadable file by returning None
        skipped += 1
        continue
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) #convert to RGB for consistency with Python libraries

    # NOTE(review): crop file names use only the image's base name; if two
    # dataset sub-folders contain the same base name, later crops would
    # overwrite earlier ones -- confirm names are unique across folders.
    base = os.path.splitext(os.path.basename(name))[0]
    for j, ch in enumerate(chars):
        # `counts` is keyed by the valid labels, so this is both the label
        # filter and the per-class cap.
        if ch not in counts or counts[ch] >= MAX_SAMPLES_PER_CLASS:
            continue
        patch = crop_char_safe(img_rgb, boxes[:, :, j])
        if patch is None:
            skipped += 1
            continue

        patch = cv2.resize(patch, (32, 32), interpolation=cv2.INTER_AREA) #32 x 32 image resizing
        out_dir = os.path.join(OUTPUT_ROOT, ch)
        out_name = f"{base}_{j}.png"
        written = cv2.imwrite(os.path.join(out_dir, out_name),
                              cv2.cvtColor(patch, cv2.COLOR_RGB2BGR))
        if not written:
            # A failed write must not count toward the class quota
            skipped += 1
            continue

        counts[ch] += 1
        saved += 1

# Totals for the extraction run
print(f"Done: saved {saved} crops, skipped {skipped} bad entries.\n")

# Per-class totals, listed in sorted label order
print("=== Sample counts per class ===")
ordered_labels = list(valid_labels)
ordered_labels.sort()
for lbl in ordered_labels:
    print(f"{lbl}: {counts[lbl]}")

# Classes that ended the run short of the minimum quota
print("\nBelow minimum threshold:")
under_quota = [(lbl, cnt) for lbl, cnt in counts.items() if cnt < MIN_SAMPLES_PER_CLASS]
for lbl, cnt in under_quota:
    print(f"  {lbl}: {cnt}")


Done: saved 62000 crops, skipped 5 bad entries.

=== Sample counts per class ===
0: 1000
1: 1000
2: 1000
3: 1000
4: 1000
5: 1000
6: 1000
7: 1000
8: 1000
9: 1000
A: 1000
B: 1000
C: 1000
D: 1000
E: 1000
F: 1000
G: 1000
H: 1000
I: 1000
J: 1000
K: 1000
L: 1000
M: 1000
N: 1000
O: 1000
P: 1000
Q: 1000
R: 1000
S: 1000
T: 1000
U: 1000
V: 1000
W: 1000
X: 1000
Y: 1000
Z: 1000
a: 1000
b: 1000
c: 1000
d: 1000
e: 1000
f: 1000
g: 1000
h: 1000
i: 1000
j: 1000
k: 1000
l: 1000
m: 1000
n: 1000
o: 1000
p: 1000
q: 1000
r: 1000
s: 1000
t: 1000
u: 1000
v: 1000
w: 1000
x: 1000
y: 1000
z: 1000

Below minimum threshold:


In [18]:
# From here on DATA_ROOT points at the folder of extracted character crops
DATA_ROOT = Path("datasets/step1_chars")

# Table header
print("Character : # images")
print("-" * 24)

# One row per class folder, in sorted path order; non-directory entries are ignored
for char_dir in sorted(DATA_ROOT.iterdir()):
    if char_dir.is_dir():
        count = sum(1 for _ in char_dir.glob("*.png"))
        print(f"{char_dir.name:8s} : {count}")


Character : # images
------------------------
0        : 1000
1        : 1000
2        : 1000
3        : 1000
4        : 1000
5        : 1000
6        : 1000
7        : 1000
8        : 1000
9        : 1000
a        : 2000
b        : 2000
c        : 2000
d        : 2000
e        : 2000
f        : 2000
g        : 2000
h        : 2000
i        : 2000
j        : 2000
k        : 2000
l        : 2000
m        : 2000
n        : 2000
o        : 2000
p        : 2000
q        : 2000
r        : 2000
s        : 2000
t        : 2000
u        : 2000
v        : 2000
w        : 2000
x        : 2000
y        : 2000
z        : 2000


In [19]:
DATA_ROOT = Path("datasets/step1_chars")

# Map each class folder name to the number of PNG crops stored on disk.
# This rebinds `counts` from in-memory tallies to on-disk file counts.
counts = {}
for lbl in os.listdir(DATA_ROOT):
    class_dir = DATA_ROOT / lbl
    if class_dir.is_dir():
        counts[lbl] = len(list(class_dir.glob("*.png")))

# Labels with fewer than MIN_COUNT crops are treated as rare
MIN_COUNT = 200

# Split the labels into kept / rare, preserving the dict's iteration order
valid_labels, rare_labels = [], []
for lbl, c in counts.items():
    if c >= MIN_COUNT:
        valid_labels.append(lbl)
    else:
        rare_labels.append(lbl)

print(f"Keeping {len(valid_labels)} labels, dropping {len(rare_labels)} labels")
print("Rare labels:", rare_labels)

Keeping 36 labels, dropping 0 labels
Rare labels: []
