In [1]:
import os
import tensorflow as tf
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from configs import ModelConfigs
from model import train_model

# Enable GPU memory growth: ask TensorFlow to allocate GPU memory on demand
# rather than reserving it all up front. Any RuntimeError raised by
# set_memory_growth (TF documents this for an already-initialised runtime)
# is printed and otherwise ignored so the notebook keeps going.
gpus = tf.config.experimental.list_physical_devices("GPU")
if not gpus:
    print("No GPUS")
for device in gpus:
    try:
        tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        print(err)

No GPUS


In [None]:
# Safe to re-run: the download below is skipped once Datasets/IAM_Words exists.

from urllib.request import urlopen
from io import BytesIO
from zipfile import ZipFile
from tqdm import tqdm

def download_and_unzip(url, extract_to="Datasets"):
    """Download a zip archive from ``url`` and extract it into ``extract_to``.

    The whole archive is buffered in memory before extraction.

    Args:
        url: Location of the zip archive (any URL ``urllib.request.urlopen``
            accepts, e.g. ``https://`` or ``file://``).
        extract_to: Directory the archive members are extracted into.
    """
    # Close the network response as soon as the payload has been read.
    with urlopen(url) as response:
        payload = response.read()
    # Close the archive deterministically instead of leaving it to the GC.
    with ZipFile(BytesIO(payload)) as archive:
        archive.extractall(extract_to)

# Fetch and unpack the IAM words archive only when its folder is not present yet.
dataset_path = os.path.join("Datasets", "IAM_Words")
needs_download = not os.path.exists(dataset_path)
if needs_download:
    download_and_unzip("https://git.io/J0fjL", extract_to="Datasets")

In [None]:
# DO NOT RUN MULTIPLE TIMES

import tarfile

# Extract the words.tgz archive
file = tarfile.open(os.path.join(dataset_path, "words.tgz"))
file.extractall(os.path.join(dataset_path, "words"))

In [2]:
# Re-declared so the cells below also work when the download/extract cells are skipped.
_DATASET_PARTS = ("Datasets", "IAM_Words")
dataset_path = os.path.join(*_DATASET_PARTS)

In [3]:
from tqdm import tqdm

# Initialize dataset, vocab, and max_len
dataset = []   # [image_path, label] pairs for every usable word image
vocab = set()  # every character that appears in a kept label
max_len = 0    # length of the longest kept label

# Load the words.txt annotation file; `with` closes the handle after reading.
with open(os.path.join(dataset_path, "words.txt"), "r") as words_fh:
    words_file = words_fh.readlines()

for line in tqdm(words_file):
    # Header/comment lines start with "#".
    if line.startswith("#"):
        continue

    line_split = line.rstrip("\n").split(" ")
    # Per the format described in words.txt's own "#" header, field 1 is the
    # segmentation result ("ok" / "err"). Testing that field instead of
    # `"err" in line` keeps correctly segmented words whose transcription just
    # happens to contain "err" (e.g. "error", "preferred").
    if len(line_split) < 2 or line_split[1] == "err":
        continue

    word_id = line_split[0]
    folder1 = word_id[:3]
    folder2 = "-".join(word_id.split("-")[:2])
    file_name = word_id + ".png"
    label = line_split[-1]

    rel_path = os.path.join(dataset_path, "words", folder1, folder2, file_name)
    # Skip annotations whose image file is not on disk.
    if not os.path.exists(rel_path):
        continue

    dataset.append([rel_path, label])
    vocab.update(label)
    max_len = max(max_len, len(label))


100%|██████████| 115338/115338 [00:06<00:00, 18404.10it/s]


In [4]:
# Create a ModelConfigs object to store configurations
configs = ModelConfigs()
# Sort the characters so the vocab string, and with it the char -> index
# mapping handed to LabelIndexer and the padding index len(vocab), is the same
# in every kernel session. Iterating a set of str follows hash order, which
# changes between processes under Python's hash randomization.
configs.vocab = "".join(sorted(vocab))
configs.max_text_length = max_len
configs.save()

# Create DataProvider
from mltu.tensorflow.dataProvider import DataProvider
from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding
from mltu.annotations.images import CVImage

# Per-sample transformers: resize to the configured width/height (aspect ratio
# not kept), map label characters to indices via configs.vocab, then pad labels
# to configs.max_text_length using len(configs.vocab) as the padding value.
sample_transformers = [
    ImageResizer(configs.width, configs.height, keep_aspect_ratio=False),
    LabelIndexer(configs.vocab),
    LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab)),
]

data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,  # image paths were already filtered with os.path.exists when building `dataset`
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=sample_transformers,
)

# Split into training and validation providers (split=0.9, presumably 90% / 10%).
# NOTE(review): no seed is passed — confirm whether DataProvider.split shuffles.
train_data_provider, val_data_provider = data_provider.split(split=0.9)


In [5]:
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen

# Augmentations are attached to the training provider only; nothing is
# assigned to val_data_provider here. RandomRotate's angle is presumably in
# degrees — TODO confirm against mltu.
training_augmentors = [
    RandomBrightness(),
    RandomErodeDilate(),
    RandomSharpen(),
    RandomRotate(angle=10),
]
train_data_provider.augmentors = training_augmentors


In [6]:
# Create the model from the configurations: 3-channel images of
# (height, width) in, one output unit group per vocabulary character.
# NOTE(review): presumably train_model adds the CTC blank class itself — confirm in model.py.
image_shape = (configs.height, configs.width, 3)
model = train_model(input_dim=image_shape, output_dim=len(configs.vocab))

# Compile the model
from mltu.tensorflow.losses import CTCloss
from mltu.tensorflow.metrics import CWERMetric

# Adam at the configured learning rate, CTC loss, and a CER/WER metric told
# that len(configs.vocab) is the padding token (the same value LabelPadding pads with).
optimizer = tf.keras.optimizers.Adam(learning_rate=configs.learning_rate)
model.compile(
    optimizer=optimizer,
    loss=CTCloss(),
    metrics=[CWERMetric(padding_token=len(configs.vocab))],
)

# Print a wide model summary so layer names are not truncated.
model.summary(line_length=110)







In [10]:
# Define callbacks for training. All of them watch the validation character
# error rate ("val_CER"), where lower is better.
#
# mode="min" is passed explicitly everywhere. With mode="auto", newer Keras
# releases may be unable to infer the direction of a custom metric and raise.
# Being explicit also keeps EarlyStopping / ReduceLROnPlateau consistent with
# ModelCheckpoint.
earlystopper = EarlyStopping(
    monitor="val_CER",
    patience=20,
    verbose=1,
    mode="min",
)

# Save the model only when val_CER improves.
checkpoint = ModelCheckpoint(
    f"{configs.model_path}/model.h5",
    monitor="val_CER",
    verbose=1,
    save_best_only=True,
    mode="min",
)

# Integer update_freq writes TensorBoard summaries every N batches (here: every batch).
tb_callback = TensorBoard(
    log_dir=f"{configs.model_path}/logs",
    update_freq=1,
)

# Scale the learning rate by 0.9 after 10 epochs without val_CER improvement.
reduceLROnPlat = ReduceLROnPlateau(
    monitor="val_CER",
    factor=0.9,
    min_delta=1e-10,
    patience=10,
    verbose=1,
    mode="min",
)

from mltu.tensorflow.callbacks import Model2onnx, TrainLogger
# mltu callbacks: Model2onnx is pointed at the same model.h5 path ModelCheckpoint
# writes to, and TrainLogger writes into the model directory.
saved_model_dir = configs.model_path
model2onnx = Model2onnx(f"{saved_model_dir}/model.h5")
trainLogger = TrainLogger(saved_model_dir)





In [None]:
# Start the training process. Callback order is kept as-is because Keras
# invokes callbacks in list order.
training_callbacks = [earlystopper, checkpoint, trainLogger, reduceLROnPlat, tb_callback, model2onnx]
history = model.fit(
    train_data_provider,
    validation_data=val_data_provider,
    epochs=configs.train_epochs,
    callbacks=training_callbacks,
    verbose=1,
)
