In [1]:
from tensorflow.keras.layers.experimental.preprocessing import StringLookup
import os
from tqdm.auto import tqdm
import random
from PIL import Image
import tensorflow as tf

# Location of a dataset manifest file. Not referenced in the visible cells;
# presumably the same "data.txt" that get_dataset opens — TODO confirm.
DATA_FILE_PATH = "./data.txt"
# Target image height in pixels. The prepare_dataset calls in this notebook pass
# (IMAGE_HEIGHT * 4, IMAGE_HEIGHT) as the (width, height) image size.
IMAGE_HEIGHT = 64
# Number of samples per batch handed to prepare_dataset.
BATCH_SIZE = 16
# Hyperparameters that are not used in the visible cells — presumably consumed
# by model-building / training code elsewhere in the notebook.
N_NEURONS=2048
LEARNING_RATE=0.00001
# Checkpoint path template; {epoch:04d} is filled with a zero-padded epoch number.
CP_PATH = "./training/cp-{epoch:04d}.ckpt"
AUTOTUNE = tf.data.AUTOTUNE # Let tf.data choose the level of parallelism automatically


In [2]:
def filter_data(data, len_range, max_space):
    """Decide whether a sample's label is acceptable for the dataset.

    Only the label text is inspected; the image file itself is not opened or
    validated here.

    Args:
        data: Mapping with a "label" entry holding the label string (or None).
        len_range: Pair (min_len, max_len); the label length must lie within
            it, bounds included.
        max_space: Largest number of ' ' characters the label may contain.

    Returns:
        True when the label passes every check, False otherwise.
    """
    text = data["label"]

    # Missing label: nothing to train on.
    if text is None:
        return False

    # Length must fall inside the inclusive [min_len, max_len] window.
    if not (len_range[0] <= len(text) <= len_range[1]):
        return False

    # Too many spaces in the label.
    if text.count(' ') > max_space:
        return False

    # Only pure-ASCII labels are kept.
    if not text.isascii():
        return False

    # Reject line breaks, soft hyphens and non-breaking spaces.
    return not any(banned in text for banned in "\n\r\xad\xa0")

def get_dataset(
        data_file_dir,
        multi_fonts_dir,
        mjsynth_dir,
        max_multi_fonts_len,
        max_mjsynth_len,
        len_range,
        max_space
):
    """Build a shuffled, filtered list of samples from two index files.

    Reads "data.txt" from ``multi_fonts_dir`` (lines of "<path> <label>") and
    "imlist.txt" from ``mjsynth_dir`` (one image path per line, the label being
    the second '_'-separated field of the file name). Each source is shuffled
    and truncated, the two are merged and shuffled again, and samples whose
    label fails ``filter_data`` are dropped.

    Args:
        data_file_dir: Directory whose "data.txt" is opened in append mode
            (created when missing). Nothing is written to it at the moment.
        multi_fonts_dir: Directory containing the multi-fonts "data.txt" index.
        mjsynth_dir: Directory containing the mjsynth "imlist.txt" index.
        max_multi_fonts_len: Maximum number of multi-fonts lines to keep.
        max_mjsynth_len: Maximum number of mjsynth lines to keep.
        len_range: Forwarded to ``filter_data``.
        max_space: Forwarded to ``filter_data``.

    Returns:
        List of dicts with "image_path" and "label" keys.

    Raises:
        OSError: If one of the files cannot be opened.
    """
    # Writing the combined manifest is disabled, but opening it in append mode
    # still creates the file when it is missing. Keep that side effect and let
    # the context manager release the handle immediately.
    with open(os.path.join(data_file_dir, "data.txt"), "a+"):
        pass

    # Context managers guarantee the index files are closed, even on error.
    with open(os.path.join(multi_fonts_dir, 'data.txt')) as multi_fonts_file:
        multi_fonts_data = multi_fonts_file.readlines()
    with open(os.path.join(mjsynth_dir, "imlist.txt")) as mjsynth_file:
        mjsynth_data = mjsynth_file.readlines()

    # Shuffle before truncating so the kept subset is a random sample.
    random.shuffle(multi_fonts_data)
    random.shuffle(mjsynth_data)
    multi_fonts_data = multi_fonts_data[:max_multi_fonts_len]
    mjsynth_data = mjsynth_data[:max_mjsynth_len]

    dataset = []
    for line in tqdm(multi_fonts_data):
        splitted_line = line.split(' ', 1)
        if len(splitted_line) < 2:
            # Blank or malformed line without a "<path> <label>" separator.
            continue
        label = splitted_line[1].split('\n')[0]
        dataset.append({"image_path": os.path.join(multi_fonts_dir, splitted_line[0]), "label": label})
    for image_name in tqdm(mjsynth_data):
        image_name = image_name.split('\n')[0]
        name_fields = image_name.split('/')[-1].split('_')
        if len(name_fields) < 2:
            # File name carries no '_'-separated label field (e.g. blank line).
            continue
        dataset.append({"image_path": os.path.join(mjsynth_dir, image_name), "label": name_fields[1]})

    # Mix both sources before filtering.
    random.shuffle(dataset)
    return [data for data in tqdm(dataset) if filter_data(data, len_range, max_space)]

def split_dataset(dataset, training_i, validation_i):
    """Cut a sequence into consecutive train / validation / test parts.

    Args:
        dataset: Sliceable sequence of samples.
        training_i: Fraction of ``len(dataset)`` at which the training part ends.
        validation_i: Fraction of ``len(dataset)`` at which the validation part
            ends; everything after it is the test part.

    Returns:
        Tuple ``(train, validation, test)`` of slices of ``dataset``.
    """
    total = len(dataset)
    # Cut points are truncated to whole indices.
    train_end = int(training_i * total)
    validation_end = int(validation_i * total)
    return (
        dataset[:train_end],
        dataset[train_end:validation_end],
        dataset[validation_end:],
    )


def distortion_free_resize(image, img_size):
    """Fit an image into ``img_size`` keeping its aspect ratio, pad, then reorient.

    The image is resized with ``preserve_aspect_ratio=True``, padded with
    ``tf.pad`` (default fill) up to the requested height and width, transposed
    over its first two axes and finally flipped with
    ``tf.image.flip_left_right``.

    Args:
        image: Rank-3 image tensor; the first two axes are treated as height
            and width.
        img_size: Pair ``(width, height)`` of the target size before the final
            transpose.

    Returns:
        The padded, transposed and flipped image tensor.
    """
    target_w, target_h = img_size
    resized = tf.image.resize(image, size=(target_h, target_w), preserve_aspect_ratio=True)

    # How many rows / columns are still missing to reach the target size.
    resized_shape = tf.shape(resized)
    missing_rows = target_h - resized_shape[0]
    missing_cols = target_w - resized_shape[1]

    # Split each amount over both sides; when it is odd, the extra pixel goes
    # to the top / left side.
    rows_after = missing_rows // 2
    rows_before = missing_rows - rows_after
    cols_after = missing_cols // 2
    cols_before = missing_cols - cols_after

    padded = tf.pad(
        resized,
        paddings=[
            [rows_before, rows_after],
            [cols_before, cols_after],
            [0, 0],
        ],
    )

    # Swap the height and width axes, then mirror the result.
    rotated = tf.transpose(padded, perm=[1, 0, 2])
    return tf.image.flip_left_right(rotated)

def vectorize_label(label, max_len, char_to_num):
    """Encode a label string as a vector of character indices padded to ``max_len``.

    Args:
        label: Label text, split into characters as UTF-8.
        max_len: Length of the returned vector.
        char_to_num: Callable mapping a tensor of characters to integer indices.

    Returns:
        1-D tensor of indices, right-padded with the value 99.
    """
    chars = tf.strings.unicode_split(label, input_encoding="UTF-8")
    indices = char_to_num(chars)
    n_missing = max_len - tf.shape(indices)[0]
    # 99 is the padding token appended after the real characters.
    return tf.pad(indices, paddings=[[0, n_missing]], constant_values=99)

def preprocess_image(image_path, img_size):
    """Load an image file as a single-channel float tensor scaled to [0, 1].

    Args:
        image_path: Path of the image file to read.
        img_size: ``(width, height)`` pair forwarded to ``distortion_free_resize``.

    Returns:
        float32 tensor produced by ``distortion_free_resize`` divided by 255.
    """
    raw_bytes = tf.io.read_file(image_path)
    # Decode as one grayscale channel.
    # NOTE(review): decode_png is applied to every file; per the TensorFlow docs
    # it also accepts JPEG input — confirm if other formats can appear.
    pixels = tf.image.decode_png(raw_bytes, channels=1)
    # Aspect-ratio-preserving resize plus padding (no distortion).
    fitted = distortion_free_resize(pixels, img_size)
    # Scale gray levels from [0, 255] down to [0, 1].
    return tf.cast(fitted, tf.float32) / 255.0

def process_images_labels(image_path, label, img_size, max_len, char_to_num):
    """Turn one (image path, label) pair into a model-ready example.

    Args:
        image_path: Path of the image file.
        label: Label text of the image.
        img_size: ``(width, height)`` pair forwarded to ``preprocess_image``.
        max_len: Padded label length forwarded to ``vectorize_label``.
        char_to_num: Character-to-index lookup forwarded to ``vectorize_label``.

    Returns:
        Dict with the preprocessed "image" and the vectorized "label".
    """
    return {
        "image": preprocess_image(image_path, img_size),
        "label": vectorize_label(label, max_len, char_to_num),
    }

def prepare_dataset(image_paths, labels, batch_size, img_size, max_len, char_to_num):
    """Build a batched ``tf.data.Dataset`` of preprocessed images and labels.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of label strings, aligned with ``image_paths``.
        batch_size: Number of examples per batch.
        img_size: ``(width, height)`` pair used when preprocessing images.
        max_len: Length every label vector is padded to.
        char_to_num: Character-to-index lookup used to encode labels.

    Returns:
        Dataset yielding batches of {"image": ..., "label": ...} dicts.
    """
    def _to_example(image_path, label):
        # Bind the fixed preprocessing settings for Dataset.map.
        return process_images_labels(image_path, label, img_size, max_len, char_to_num)

    pairs = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    examples = pairs.map(_to_example, num_parallel_calls=AUTOTUNE)
    return examples.batch(batch_size)

# get_dataset shuffles with the `random` module, and the shuffle order decides
# which samples land in each split. The splits are exported in separate runs
# (only test_ds is enabled below), so the RNG is seeded to make every run over
# the same input files produce the same partition; without it, a test set saved
# in one run can overlap a training set saved in another.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)

dataset = get_dataset(
    "./",
    "../datasets/multi-fonts-generated-text/",
    "../datasets/mjsynth/mnt/ramdisk/max/90kDICT32px/",
    1_000_000_000,  # effectively "no limit" on multi-fonts samples
    1_000_000_000,  # effectively "no limit" on mjsynth samples
    len_range=(3, 32),
    max_space=3
)

# Labels used to derive the maximum label length and the character vocabulary.
# NOTE(review): '|' is replaced by '\n' here, but the datasets below are built
# from the unmodified data["label"], so a '|' in a label is not part of the
# vocabulary (presumably it then maps to the lookup's out-of-vocabulary index).
# Confirm which form is intended.
labels = [data["label"].replace('|',  '\n') for data in dataset]
max_len = max(map(len, labels))
characters = sorted({char for label in labels for char in label})

# 98% train, 1% validation, 1% test.
train_ds, validation_ds, test_ds = split_dataset(dataset, 0.98, 0.99)

# Forward and inverse character <-> index lookups sharing one vocabulary.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
num_to_char = StringLookup(vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)

# The train and validation exports are currently disabled; only the test split
# is prepared and saved in this run.
# train_ds = prepare_dataset(
#     list(map(lambda data: data["image_path"], train_ds)),
#     list(map(lambda data: data["label"], train_ds)),
#     BATCH_SIZE,
#     (IMAGE_HEIGHT * 4, IMAGE_HEIGHT),
#     max_len,
#     char_to_num
# )
# validation_ds = prepare_dataset(
#     list(map(lambda data: data["image_path"], validation_ds)),
#     list(map(lambda data: data["label"], validation_ds)),
#     BATCH_SIZE,
#     (IMAGE_HEIGHT * 4, IMAGE_HEIGHT),
#     max_len,
#     char_to_num
# )
test_ds = prepare_dataset(
    [data["image_path"] for data in test_ds],
    [data["label"] for data in test_ds],
    BATCH_SIZE,
    (IMAGE_HEIGHT * 4, IMAGE_HEIGHT),  # (width, height)
    max_len,
    char_to_num
)

# train_ds.save("./train_ds")
# validation_ds.save("./validation_ds")
test_ds.save("./test_ds")


  0%|          | 0/581581 [00:00<?, ?it/s]

  0%|          | 0/8919273 [00:00<?, ?it/s]

  0%|          | 0/9500854 [00:00<?, ?it/s]

: 

In [3]:
#### Load


1867
