In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from dataset_creator import DatasetCreator
from data_preparation import DataPreperation
from tensorflow.keras.layers import StringLookup

In [None]:
# Root of the handwriting dataset, relative to this notebook. Kept as a
# named constant so the data location is configured in one place.
DATA_DIR = "../data"

ds_prep = DataPreperation(DATA_DIR)
# Later cells read ds_prep.train_img_paths, *_labels_cleaned, .characters and
# .max_len; presumably this call is what populates them — TODO confirm in
# data_preparation.py.
ds_prep()

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Character -> integer-id lookup for the label alphabet.
# sorted() pins the vocabulary order. If `ds_prep.characters` is a set, its
# iteration order for str elements depends on PYTHONHASHSEED and can change
# between kernel restarts, silently remapping every character id (and making
# previously trained weights decode garbage). Sorting is a no-op when the
# characters are already a sorted sequence. NOTE(review): a model trained with
# the old, unsorted mapping must be retrained after this change.
char_to_num = StringLookup(vocabulary=sorted(ds_prep.characters), mask_token=None)

# Inverse mapping (id -> character), built from char_to_num's own vocabulary
# so both directions always agree.
num_to_char = StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)

In [None]:
# --- Input-pipeline configuration -------------------------------------------
# Image size handed to the pipeline; the model summary further down reports
# an input of (None, 128, 32, 1), i.e. width first.
image_width = 128
image_height = 32

batch_size = 64

# Value that fills label sequences up to ds_prep.max_len (the preview cell
# below filters it back out). It must never coincide with a real id from
# char_to_num — presumably why 99 was picked; confirm the vocabulary stays
# smaller than that.
padding_token = 99

# Positional arguments, in the order DatasetCreator expects them.
ds_creator = DatasetCreator(
    ds_prep.train_img_paths,
    ds_prep.train_labels_cleaned,
    ds_prep.max_len,
    char_to_num,
    num_to_char,
    AUTOTUNE,
    batch_size,
    padding_token,
    image_width,
    image_height,
)

In [None]:
# Build one batched tf.data pipeline per split (train, validation, test),
# in that order, from each split's image paths and cleaned labels.
train_ds, validation_ds, test_ds = (
    ds_creator.prepare_dataset(split_paths, split_labels)
    for split_paths, split_labels in (
        (ds_prep.train_img_paths, ds_prep.train_labels_cleaned),
        (ds_prep.validation_img_paths, ds_prep.validation_labels_cleaned),
        (ds_prep.test_img_paths, ds_prep.test_labels_cleaned),
    )
)

In [None]:
# Sanity check: show (up to) the first four training samples with their
# decoded ground-truth labels.
for data in train_ds.take(1):
    images, labels = data["image"], data["label"]

    _, ax = plt.subplots(2, 2, figsize=(15, 8))

    # Guard against a first batch smaller than the 2x2 grid.
    for i in range(min(4, int(images.shape[0]))):
        # Undo the storage orientation for display: mirror, swap the first two
        # axes (presumably width/height — the model input is width-first),
        # rescale to 8-bit and drop the channel axis.
        img = images[i]
        img = tf.image.flip_left_right(img)
        img = tf.transpose(img, perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
        img = img[:, :, 0]

        # Strip label padding before mapping ids back to characters. Use the
        # configured padding_token rather than a literal 99 so the two cannot
        # drift apart.
        label = labels[i]
        indices = tf.gather(label, tf.where(tf.math.not_equal(label, padding_token)))

        label = tf.strings.reduce_join(num_to_char(indices))
        label = label.numpy().decode("utf-8")

        ax[i // 2, i % 2].imshow(img, cmap="gray")
        ax[i // 2, i % 2].set_title(label)
        ax[i // 2, i % 2].axis("off")


plt.show()

In [None]:
# Materialise the validation split once, as two parallel lists of per-batch
# tensors, so EditDistanceCallback can walk it every epoch without re-running
# the tf.data pipeline. Note this holds the whole validation set in memory.
validation_images, validation_labels = [], []

for batch in validation_ds:
    validation_images.append(batch["image"])
    validation_labels.append(batch["label"])

In [None]:
def calculate_edit_distance(labels, predictions, padding_token=padding_token):
    """Mean Levenshtein distance between greedy-CTC predictions and true labels.

    Args:
        labels: dense batch of integer label ids, right-padded with
            ``padding_token`` (presumably shape (batch, ds_prep.max_len) —
            TODO confirm against DatasetCreator).
        predictions: per-timestep model outputs, (batch, time, classes), as
            returned by ``prediction_model.predict``.
        padding_token: label padding value to ignore. Defaults to the
            notebook's ``padding_token`` as bound when this cell is executed.

    Returns:
        Scalar tensor: unnormalised edit distance averaged over the batch.
    """

    def _to_sparse_without(dense, pad_value):
        # Dense -> sparse int64, dropping every `pad_value` entry. from_dense
        # alone only drops zeros, so padding would otherwise be compared as if
        # it were real characters.
        sparse = tf.sparse.from_dense(tf.cast(dense, tf.int64))
        return tf.sparse.retain(sparse, tf.not_equal(sparse.values, pad_value))

    # Every sample is decoded over the full time axis.
    input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
    predictions_decoded = tf.keras.backend.ctc_decode(
        predictions, input_length=input_len, greedy=True
    )[0][0][:, :ds_prep.max_len]

    # Labels are padded with padding_token; ctc_decode pads with -1.
    sparse_labels = _to_sparse_without(labels, padding_token)
    sparse_predictions = _to_sparse_without(predictions_decoded, -1)

    edit_distances = tf.edit_distance(
        sparse_predictions, sparse_labels, normalize=False
    )
    return tf.reduce_mean(edit_distances)


class EditDistanceCallback(tf.keras.callbacks.Callback):
    """Print the mean greedy-CTC edit distance on validation data after each epoch.

    Args:
        pred_model: inference model mapping an image batch to the per-timestep
            outputs that ``calculate_edit_distance`` decodes.
        images: optional list of validation image batches. ``None`` (default)
            uses the notebook global ``validation_images`` at epoch end.
        labels: optional list of matching label batches. ``None`` (default)
            uses the notebook global ``validation_labels``.
    """

    def __init__(self, pred_model, images=None, labels=None):
        super().__init__()
        self.prediction_model = pred_model
        self.validation_images = images
        self.validation_labels = labels

    def on_epoch_end(self, epoch, logs=None):
        images = validation_images if self.validation_images is None else self.validation_images
        labels = validation_labels if self.validation_labels is None else self.validation_labels

        total_distance = 0.0
        total_samples = 0
        for image_batch, label_batch in zip(images, labels):
            # verbose=0: otherwise predict prints one progress bar per
            # validation batch, every epoch.
            predictions = self.prediction_model.predict(image_batch, verbose=0)
            n = int(predictions.shape[0])
            batch_mean = float(calculate_edit_distance(label_batch, predictions).numpy())
            # Weight by batch size so a smaller final batch does not count as
            # much as a full one.
            total_distance += batch_mean * n
            total_samples += n

        if total_samples == 0:
            print(f"No validation data for epoch {epoch + 1}; edit distance skipped.")
            return

        print(
            f"Mean edit distance for epoch {epoch + 1}: {total_distance / total_samples:.4f}"
        )

In [None]:
# Architecture preview only: the training cell below calls build_model again,
# so this particular instance is never trained.
from model_architacture import build_model

model = build_model(
    image_width,
    image_height,
    char_to_num,
)
model.summary()

In [None]:
epochs = 10  # originally annotated "50" — presumably the full-length run

# Rebuild so re-running this cell starts from a newly built model rather than
# continuing to train the preview instance.
model = build_model(image_width, image_height, char_to_num)

# Inference sub-model reusing `model`'s layers (hence its weights): image
# input -> "dense2" per-timestep outputs, which the CTC decoding helpers use.
prediction_model = tf.keras.models.Model(
    inputs=model.get_layer(name="image").input,
    outputs=model.get_layer(name="dense2").output,
)
edit_distance_callback = EditDistanceCallback(prediction_model)

history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=validation_ds,
    callbacks=[edit_distance_callback],
)

In [None]:
# Training vs. validation loss per epoch (epochs numbered from 1).
fig, ax = plt.subplots(figsize=(12, 8))
epoch_axis = range(1, len(history.history["loss"]) + 1)
ax.plot(epoch_axis, history.history["loss"], lw=2.5, c="gray", label="Train Cost function output")
ax.plot(epoch_axis, history.history["val_loss"], lw=2.5, c="black", label="Validation Cost function output")
ax.set_title("Training and validation loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss Values")
# Decreasing loss curves end in the lower-right corner, so a fixed
# "lower right" legend would cover them; let matplotlib pick a free spot.
ax.legend(loc="best")
# plt.show() instead of the former bare plt.plot(), which only added an
# empty line set and echoed "[]".
plt.show()

In [None]:
# Inference sub-model from the training cell: "image" input -> "dense2" output.
prediction_model.summary()

In [25]:
# Full model ("handwriting_recognizer") as built in the training cell.
model.summary()

Model: "handwriting_recognizer"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 image (InputLayer)          [(None, 128, 32, 1)]         0         []                            
                                                                                                  
 Conv1 (Conv2D)              (None, 128, 32, 32)          320       ['image[0][0]']               
                                                                                                  
 pool1 (MaxPooling2D)        (None, 64, 16, 32)           0         ['Conv1[0][0]']               
                                                                                                  
 Conv2 (Conv2D)              (None, 64, 16, 64)           18496     ['pool1[0][0]']               
                                                                             

In [None]:
def decode_batch_predictions(pred):
    """Greedy-CTC-decode a batch of model outputs into text strings.

    Args:
        pred: per-timestep outputs (batch, time, classes), e.g. from
            ``prediction_model.predict``.

    Returns:
        List of decoded strings, one per sample, each truncated to at most
        ``ds_prep.max_len`` characters.
    """
    n_samples, n_steps = pred.shape[0], pred.shape[1]
    seq_lengths = np.ones(n_samples) * n_steps  # decode over the full time axis

    best_paths = tf.keras.backend.ctc_decode(pred, input_length=seq_lengths, greedy=True)[0][0]
    best_paths = best_paths[:, :ds_prep.max_len]

    def _ids_to_text(seq):
        # ctc_decode right-pads each sequence with -1; drop it before lookup.
        kept_ids = tf.gather(seq, tf.where(tf.math.not_equal(seq, -1)))
        return tf.strings.reduce_join(num_to_char(kept_ids)).numpy().decode("utf-8")

    return [_ids_to_text(seq) for seq in best_paths]

# Qualitative check: predicted text for (up to) the first 16 test images.
n_rows, n_cols = 4, 4
for batch in test_ds.take(1):
    batch_images = batch["image"]
    _, ax = plt.subplots(n_rows, n_cols, figsize=(15, 8))

    preds = prediction_model.predict(batch_images)
    pred_texts = decode_batch_predictions(preds)

    # A test batch smaller than the grid would otherwise raise IndexError.
    n_shown = min(n_rows * n_cols, len(pred_texts))
    for i in range(n_rows * n_cols):
        cell = ax[i // n_cols, i % n_cols]
        if i < n_shown:
            # Same display transform as the training preview: mirror, swap
            # the first two axes, rescale to 8-bit, drop the channel axis.
            img = batch_images[i]
            img = tf.image.flip_left_right(img)
            img = tf.transpose(img, perm=[1, 0, 2])
            img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
            img = img[:, :, 0]

            cell.imshow(img, cmap="gray")
            cell.set_title(f"Prediction: {pred_texts[i]}")
        # Hide ticks on every cell, including unused ones.
        cell.axis("off")

plt.show()