In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping,  ReduceLROnPlateau,  TensorBoard
from tensorflow.keras.models import load_model
import tensorflow.keras.backend as K
from model_funcs import character_error_rate, word_error_rate, CTCLoss2, build_CRNN_model,ctc_decoder
from data_processing import create_datasets, batch_generator
from tester_functions import inpute_batch_displayer, display_single_image
from configs import Configs 
import datetime


In [None]:
# Enable on-demand (dynamic) VRAM allocation instead of reserving all GPU memory up front.
# TensorFlow only accepts this setting before the GPUs have been initialised, so it is
# done first, before any other TensorFlow/Keras call in this cell can touch the devices.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when the device is already initialised (e.g. this cell is re-run in a
        # live kernel); the allocation mode chosen at initialisation stays in effect.
        print(e)

# Reset Keras' global state left over from earlier runs in this kernel.
K.clear_session()

# Mixed precision: compute in float16 while keeping variables in float32, for faster
# training on GPUs that support it. Must be set before the model is built.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)


In [3]:
# Build the tf.data pipelines for variable-size images and their ground-truth labels.
# Each element of `total_dataset` is unpacked below as a (train, cv) pair and split into
# separate training and cross-validation datasets.
c = Configs()
batch_size = c.batch_size
buffer_size = c.buffer_size  # only consumed by the (currently disabled) shuffle below

total_dataset = create_datasets(
    c.image_paths,
    c.label_path,
    batch_size,
    c.image_height,
    c.image_max_width,
    c.augmentation_probability,
    c.cv_add_data,
)
# Shuffling is disabled; to re-enable it:
# total_dataset = total_dataset.shuffle(buffer_size=buffer_size)


def _pick_train(train, cv):
    """Keep the training half of a (train, cv) element."""
    return train


def _pick_cv(train, cv):
    """Keep the cross-validation half of a (train, cv) element."""
    return cv


training_datasets = total_dataset.map(_pick_train)
cv_datasets = total_dataset.map(_pick_cv)


In [4]:
# # test data set works
# for x, y in training_datasets.take(2):
#     for x ,y in zip(x,y):
#         tf.print(ctc_decoder(y))
#         tf.print(len(y))
#         tf.print(tf.shape(x))
#         tf.print(tf.reduce_max(x), tf.reduce_min(x))
#         display_single_image(x)

In [5]:
# Inspect a dataset's batches and report label sequence lengths.
def test_dataset_seq_len(dataset, name, pad_val=None):
    """Print per-batch shapes/dtypes and the longest label length of an (x, y) dataset.

    For every batch, prints the shapes and dtypes of ``x`` and ``y`` and the longest
    label in that batch. A label's length is the number of entries along the last axis
    of ``y`` that differ from the padding value. Finally prints the longest label
    length seen across the whole dataset.

    Args:
        dataset: Iterable of ``(x, y)`` batches (e.g. a ``tf.data.Dataset``).
        name: Label used in the printed messages.
        pad_val: Value used to pad ``y``. Defaults to ``c.seq_pad_val`` from the global
            ``Configs`` instance created earlier in this notebook.

    Returns:
        The longest label length found (Python int), or ``None`` if ``dataset``
        yielded no batches.
    """
    if pad_val is None:
        pad_val = c.seq_pad_val

    longest = None
    for i, (x, y) in enumerate(dataset):
        # Print shapes and data types for debugging
        print(f"Testing {name} - Sample {i}:")
        print("Shape of x:", x.shape)
        print("Shape of y:", y.shape)
        print("Data type of x:", x.dtype)
        print("Data type of y:", y.dtype)

        # Count non-padding entries along the last axis -> one length per label.
        true_lengths = tf.reduce_sum(tf.cast(tf.not_equal(y, pad_val), dtype=tf.int32), axis=-1)
        batch_max = int(tf.reduce_max(true_lengths))
        print("Max length in this batch:", batch_max)
        longest = batch_max if longest is None else max(longest, batch_max)

    # Guard against an empty dataset: tf.reduce_max([]) would raise here.
    if longest is None:
        print(f'Longest sequence in {name}: dataset is empty')
    else:
        print(f'Longest sequence in {name}:', longest)
    return longest

# Check the training datasets (longest label sequence found: 87)
# test_dataset_seq_len(training_datasets, "Training Dataset")

# Check the cross-validation datasets (longest label sequence found: 75)
# test_dataset_seq_len(cv_datasets, "Cross-Validation Dataset")

In [None]:
# Build the CRNN, compile it with the CTC loss and error-rate metrics, and prepare
# the training callbacks.
activation = c.activation_function
image_height = c.image_height
learn_rate = c.learning_rate

# Width dimension is None, matching the variable-width images produced by the dataset cell.
model = build_CRNN_model((image_height, None, 1), c.num_classes, activation)
model.summary()

# Adam with gradient-norm clipping at 1.0. CTCLoss2 is told the label padding value
# (presumably so padded positions are ignored — confirm in model_funcs).
model.compile(
    optimizer=Adam(learning_rate=learn_rate, clipnorm=1.0),
    loss=CTCLoss2(seq_pad_val=c.seq_pad_val),
    metrics=[character_error_rate, word_error_rate],
)

# All callbacks track validation loss:
#   - checkpoint: write the model to 'OCR model' only when val_loss improves
#   - early_stopping: stop after 10 epochs without improvement, restoring the best weights
#   - reduce_lr: multiply the learning rate by 0.2 after 5 stagnant epochs (floor 1e-6)
#   - tensorboard_callback: log to a timestamped directory under logs/fit/
log_dir = f"logs/fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"
checkpoint = ModelCheckpoint('OCR model', monitor='val_loss', save_best_only=True, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
callbacks = [checkpoint, early_stopping, reduce_lr, tensorboard_callback]

In [None]:
# Train the model. `fit` returns a History object holding the per-epoch loss/metric
# values; keep it under its own name so that `model` still refers to the trained
# Keras model (the next cell calls `model.save`, which History does not have).
epochs = c.epoch_num
# epochs = 1  # quick smoke-test override
history = model.fit(
    training_datasets,
    epochs=epochs,
    validation_data=cv_datasets,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# save the model to be able to import later
# NOTE(review): this writes to the same 'OCR model' path used by the ModelCheckpoint
# callback, so it replaces the best-val_loss checkpoint with the model's current
# in-memory weights (the best ones only if EarlyStopping restored them).
# It also requires `model` to still be the Keras model: if the training cell assigns
# the return value of fit() (a History object) to `model`, this call fails.
model.save('OCR model')