In [None]:
import tensorflow as tf
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding
from mltu.augmentors import RandomBrightness, RandomErodeDilate, RandomSharpen
from mltu.annotations.images import CVImage
from mltu.tensorflow.dataProvider import DataProvider
from mltu.tensorflow.losses import CTCloss
from mltu.tensorflow.callbacks import Model2onnx, TrainLogger
from mltu.tensorflow.metrics import CERMetric, WERMetric
from model import train_model
from configs import ModelConfigs
import os
from tqdm import tqdm


# Ask TensorFlow to grow GPU memory usage on demand for every visible GPU.
# This is best-effort: a failure is reported but must not abort the script.
try:
    for gpu in tf.config.experimental.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
except Exception as exc:
    # Broad on purpose (top-level, best-effort), but unlike a bare ``except``
    # this still lets KeyboardInterrupt / SystemExit propagate.
    print(f"Could not enable GPU memory growth: {exc}")

# Metadata file listing the samples, and the folder containing the .png images.
sentences_txt_path = "Dataset/metadata/sentences.txt"
sentences_folder_path = "Dataset/dataset"

# dataset: [image path, label] pairs; vocab: every character seen in a label;
# max_len: length of the longest label.
dataset, vocab, max_len = [], set(), 0

# Read through a context manager so the file handle is closed right away.
with open(sentences_txt_path, "r") as sentences_file:
    words = sentences_file.readlines()

for line in tqdm(words):
    # Lines beginning with "#" are skipped.
    if line.startswith("#"):
        continue

    line_split = line.split(" ")
    # Blank or truncated lines have no third field; skip them instead of
    # failing with IndexError on the status check below.
    if len(line_split) < 3:
        continue
    # Entries whose third field is "err" are excluded.
    if line_split[2] == "err":
        continue

    # First field is the image name (without extension); the last field is
    # the label, with "|" standing in for spaces.
    file_name = line_split[0] + ".png"
    label = line_split[-1].rstrip("\n")
    label = label.replace("|", " ")

    rel_path = os.path.join(sentences_folder_path, file_name)
    if not os.path.exists(rel_path):
        print(f"File not found: {rel_path}")
        continue

    dataset.append([rel_path, label])
    vocab.update(label)
    max_len = max(max_len, len(label))


# Record the dataset-derived settings on the config object and save it.
configs = ModelConfigs()
# Sort the characters: iteration order of a set of strings changes between
# interpreter runs, which would make the vocabulary string (later handed to
# LabelIndexer and the CER/WER metrics) differ from one run to the next.
configs.vocab = "".join(sorted(vocab))
configs.max_text_length = max_len
configs.save()

# Per-sample transformers: resize the image to the configured width/height
# (keeping aspect ratio), index the label against the vocabulary, and pad the
# label to max_text_length using len(vocab) as the padding value.
sample_transformers = [
    ImageResizer(configs.width, configs.height, keep_aspect_ratio=True),
    LabelIndexer(configs.vocab),
    LabelPadding(
        max_word_length=configs.max_text_length,
        padding_value=len(configs.vocab),
    ),
]

# Provider over the full [path, label] list; images are read as CVImage.
data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=sample_transformers,
)

# split=0.9 — presumably 90% training / 10% validation; confirm against mltu.
train_data_provider, val_data_provider = data_provider.split(split=0.9)

# Augmentation is attached to the training provider only.
train_augmentors = [RandomBrightness(), RandomErodeDilate(), RandomSharpen()]
train_data_provider.augmentors = train_augmentors

def create_tf_dataset(data_provider):
    """Expose a data provider as a batched, prefetching ``tf.data.Dataset``.

    The provider is iterated batch by batch and every batch is flattened into
    individual ``(image, label)`` samples; tf.data then regroups those samples
    into batches of ``configs.batch_size``.

    Args:
        data_provider: Iterable yielding ``(data, label)`` batch pairs.

    Returns:
        A ``tf.data.Dataset`` whose elements are declared as float32 images of
        shape ``(configs.height, configs.width, 3)`` and int32 labels of shape
        ``(configs.max_text_length,)``, batched and prefetched.
    """
    def sample_stream():
        # Un-batch: emit one (image, label) pair at a time.
        for batch_images, batch_labels in data_provider:
            yield from zip(batch_images, batch_labels)

    sample_signature = (
        tf.TensorSpec(shape=(configs.height, configs.width, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(configs.max_text_length,), dtype=tf.int32),
    )
    samples = tf.data.Dataset.from_generator(
        sample_stream,
        output_signature=sample_signature,
    )
    return samples.batch(configs.batch_size).prefetch(tf.data.AUTOTUNE)

# Wrap the training and validation providers as tf.data pipelines.
train_dataset, val_dataset = map(
    create_tf_dataset,
    (train_data_provider, val_data_provider),
)

# Build the network for 3-channel images of the configured size, with one
# output per vocabulary character (train_model comes from the local model module).
input_shape = (configs.height, configs.width, 3)
model = train_model(input_dim=input_shape, output_dim=len(configs.vocab))


class Float32CastingLayer(tf.keras.layers.Layer):
    """Keras layer that casts its input to float32 and changes nothing else.

    NOTE(review): presumably here so the loss and metrics always receive
    float32 even if the network computes in a lower precision — confirm.
    """

    def call(self, inputs):
        """Return ``inputs`` cast to ``tf.float32``."""
        return tf.cast(inputs, tf.float32)


# Append the float32 cast to the network output and rebuild the model around it.
outputs = Float32CastingLayer()(model.output)
model = tf.keras.Model(inputs=model.input, outputs=outputs)

# Training setup: Adam with the configured learning rate, CTC loss, and
# character/word error-rate metrics computed against the vocabulary.
adam = tf.keras.optimizers.Adam(learning_rate=configs.learning_rate)
ctc_loss = CTCloss()
error_rate_metrics = [
    CERMetric(vocabulary=configs.vocab),
    WERMetric(vocabulary=configs.vocab),
]
model.compile(
    optimizer=adam,
    loss=ctc_loss,
    metrics=error_rate_metrics,
    run_eagerly=False,
)
model.summary(line_length=110)


# Training callbacks. All monitoring is done on the validation character
# error rate ("val_CER"), where lower is better.

# Stop when val_CER has not improved for 20 epochs.
earlystopper = EarlyStopping(monitor="val_CER", patience=20, verbose=1, mode="min")
# Keep only the best model (by val_CER) on disk.
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_CER", verbose=1, save_best_only=True, mode="min")
train_logger = TrainLogger(configs.model_path)
tensorboard = TensorBoard(log_dir=os.path.join(configs.model_path, "logs"), update_freq=1)
# Multiply the learning rate by 0.9 after 5 epochs without val_CER improvement.
reduce_lr = ReduceLROnPlateau(monitor="val_CER", factor=0.9, min_delta=1e-10, patience=5, verbose=1, mode="auto")
# Model2onnx expects the path of the saved .h5 checkpoint (the same file
# ModelCheckpoint writes above): per the mltu callback documentation it reloads
# those weights when training ends and writes the ONNX file next to it with
# the ".h5" suffix replaced by ".onnx". Passing the .onnx path here would make
# it try to load weights from a file that does not exist.
model_to_onnx = Model2onnx(f"{configs.model_path}/model.h5")


# Train with validation after each epoch; callbacks are passed in this order.
training_callbacks = [
    earlystopper,
    checkpoint,
    train_logger,
    reduce_lr,
    tensorboard,
    model_to_onnx,
]
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=configs.train_epochs,
    callbacks=training_callbacks,
)

# Write both splits as CSV files into the model folder.
for csv_name, provider in (
    ("train.csv", train_data_provider),
    ("val.csv", val_data_provider),
):
    provider.to_csv(os.path.join(configs.model_path, csv_name))
