## Imports

In [None]:
!pip install thefuzz
!pip install thefuzz[speedup]

Collecting thefuzz
  Downloading thefuzz-0.22.1-py3-none-any.whl.metadata (3.9 kB)
Collecting rapidfuzz<4.0.0,>=3.0.0 (from thefuzz)
  Downloading rapidfuzz-3.12.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading thefuzz-0.22.1-py3-none-any.whl (8.2 kB)
Downloading rapidfuzz-3.12.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, thefuzz
Successfully installed rapidfuzz-3.12.2 thefuzz-0.22.1


In [None]:
from tensorflow.keras.layers import StringLookup
from tensorflow import keras

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
import numpy as np
import pandas as pd
import os

from thefuzz import process
from thefuzz import fuzz

np.random.seed(42)
tf.random.set_seed(42)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive




## Importing IAM Dataset

In [None]:
# Root of the project folder on the mounted Drive; later cells build every path from it.
base_path = "drive/MyDrive/Medical_OCR"

# Read the line-level annotation file. The context manager closes the handle as
# soon as the lines are in memory (a bare open(...).readlines() leaves it open).
with open(f"{base_path}/data/lines.txt", "r") as lines_file:
    lines = lines_file.readlines()

# Keep the raw annotation lines, skipping "#" header comments and entries whose
# second space-separated field is "err". Lines with fewer than two fields
# (e.g. blank lines) are skipped instead of raising IndexError.
lines_list = []
for line in lines:
    fields = line.split(" ")
    if line[0] == "#" or len(fields) < 2:
        continue
    if fields[1] != "err":  # We don't need to deal with errored entries.
        lines_list.append(line)

np.random.shuffle(lines_list)
len(lines_list)

FileNotFoundError: [Errno 2] No such file or directory: 'drive/MyDrive/Medical_OCR/data/lines.txt'

In [None]:
# Read the annotation file of the augmented data. The context manager closes the
# handle as soon as the lines are in memory.
with open(f"{base_path}/aug_data/line_new.txt", "r") as lines_aug_file:
    lines_aug = lines_aug_file.readlines()

# Same filtering as for the original annotations: skip "#" comments and "err"
# entries; lines with fewer than two fields are skipped instead of raising
# IndexError.
lines_aug_list = []
for line in lines_aug:
    fields = line.split(" ")
    if line[0] == "#" or len(fields) < 2:
        continue
    if fields[1] != "err":  # We don't need to deal with errored entries.
        lines_aug_list.append(line)

np.random.shuffle(lines_aug_list)
len(lines_aug_list)

In [None]:
# Randomly discard half of the augmented annotation lines.
# The indices come from a permutation so they are all distinct: drawing them with
# np.random.randint samples *with replacement*, and the duplicate indices meant
# np.delete removed noticeably fewer than half of the entries.
# NOTE: re-running this cell halves the list again.
drop_count = len(lines_aug_list) // 2
drop_indices = np.random.permutation(len(lines_aug_list))[:drop_count]
lines_aug_list = np.delete(lines_aug_list, drop_indices)

## Data input pipeline

We start building our data input pipeline by first preparing the image paths.

In [None]:
# Original data: directory under base_path that holds the line images.
# get_image_paths_and_labels (defined below) reads this module-level variable.
base_image_path = os.path.join(base_path,"data" ,"lines")


def get_image_paths_and_labels(samples):
    """Resolve annotation lines to image paths, keeping only non-empty images.

    The first space-separated field of each line is the image name
    ``part1-part2-part3``; its file is expected at
    ``<base_image_path>/part1/part1-part2/part1-part2-part3.png``.

    Args:
        samples: iterable of annotation lines.

    Returns:
        A ``(paths, corrected_samples)`` pair of parallel lists: the image paths
        whose file size is non-zero, and the matching annotation lines cut at
        their first newline.
    """
    paths, corrected_samples = [], []
    for file_line in samples:
        image_name = file_line.strip().split(" ")[0]
        name_parts = image_name.split("-")
        first_part = name_parts[0]
        second_part = name_parts[1]
        img_path = os.path.join(
            base_image_path,
            first_part,
            f"{first_part}-{second_part}",
            image_name + ".png",
        )
        # Zero-byte image files are left out.
        if os.path.getsize(img_path) == 0:
            continue
        paths.append(img_path)
        corrected_samples.append(file_line.partition("\n")[0])

    return paths, corrected_samples

img_paths, labels = get_image_paths_and_labels(lines_list)

In [None]:
# Augmented data: directory under base_path that holds the augmented line images.
# NOTE(review): this rebinds the module-level `base_image_path` that the
# original-data cell also assigns, so re-running cells out of order changes
# which directory that name points to.
base_image_path = os.path.join(base_path,"aug_data" ,"aug_lines")


def get_image_paths_and_labels(samples):
    """Resolve augmented annotation lines to image paths, keeping non-empty images.

    The first space-separated field of each line is used as the image file name
    inside ``<base_path>/aug_data/aug_lines``.

    NOTE: this definition replaces the earlier function of the same name.

    Args:
        samples: iterable of annotation lines.

    Returns:
        A ``(paths, corrected_samples)`` pair of parallel lists: the image paths
        whose file size is non-zero, and the matching annotation lines cut at
        their first newline.
    """
    paths, corrected_samples = [], []
    for file_line in samples:
        image_name = file_line.strip().split(" ")[0]
        img_path = os.path.join(base_path, "aug_data", "aug_lines", image_name)
        # Zero-byte image files are left out.
        if os.path.getsize(img_path) == 0:
            continue
        paths.append(img_path)
        corrected_samples.append(file_line.partition("\n")[0])

    return paths, corrected_samples

img_paths_aug, labels_aug = get_image_paths_and_labels(lines_aug_list)

In [None]:
# One random order, applied to paths and labels alike so the pairs stay aligned.
r_indexes = np.arange(len(img_paths) + len(img_paths_aug))
np.random.shuffle(r_indexes)

# Concatenate original and augmented samples, then reorder them.
# NOTE: `img_paths` and `labels` are rebound here, so this cell is not re-runnable
# without first re-running the two cells that produce them.
img_paths = np.array(img_paths + img_paths_aug)[r_indexes]
labels = np.array(labels + labels_aug)[r_indexes]

len(img_paths)

We will split the dataset into three subsets with a 70:15:15 ratio (train:validation:test).

In [None]:
# The first 70% of the shuffled samples are used for training ...
split_idx = int(0.7 * len(img_paths))
train_img_paths, holdout_img_paths = img_paths[:split_idx], img_paths[split_idx:]
train_labels, holdout_labels = labels[:split_idx], labels[split_idx:]

# ... and the remainder is cut in half: validation first, test second.
val_split_idx = int(0.5 * len(holdout_img_paths))
validation_img_paths = holdout_img_paths[:val_split_idx]
validation_labels = holdout_labels[:val_split_idx]
test_img_paths = holdout_img_paths[val_split_idx:]
test_labels = holdout_labels[val_split_idx:]

# The three subsets must account for every sample.
assert len(img_paths) == len(train_img_paths) + len(validation_img_paths) + len(test_img_paths)

print(f"Total training samples: {len(train_img_paths)}")
print(f"Total validation samples: {len(validation_img_paths)}")
print(f"Total test samples: {len(test_img_paths)}")

Then we prepare the ground-truth labels.

In [None]:
# Reduce every training annotation to its transcription: the last
# space-separated field, stripped, with "|" turned into a space.
train_labels_cleaned = [
    label.split(" ")[-1].strip().replace("|", " ") for label in train_labels
]

# Vocabulary: every distinct character seen in the training transcriptions, sorted.
characters = sorted({char for label in train_labels_cleaned for char in label})

# Longest training transcription (0 when there are none).
max_len = max((len(label) for label in train_labels_cleaned), default=0)

print("Maximum length: ", max_len)
print("Vocab size: ", len(characters))

# Check some label samples.
train_labels_cleaned[:10]

Now we clean the validation and the test labels as well.

In [None]:

def clean_labels(labels):
    """Extract the plain transcription from each annotation line.

    For every entry the last space-separated field is taken, surrounding
    whitespace is stripped, and each "|" is replaced by a space.

    Args:
        labels: iterable of annotation lines.

    Returns:
        List of cleaned transcription strings, in input order.
    """
    return [label.split(" ")[-1].strip().replace("|", " ") for label in labels]


validation_labels_cleaned = clean_labels(validation_labels)
test_labels_cleaned = clean_labels(test_labels)

### Building the character vocabulary

Keras provides different preprocessing layers to deal with different modalities of data.
[This guide](https://keras.io/guides/preprocessing_layers/) provides a comprehensive introduction.
Our example involves preprocessing labels at the character
level. This means that if there are two labels, e.g. "cat" and "dog", then our character
vocabulary should be {a, c, d, g, o, t} (without any special tokens). We use the
[`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup/)
layer for this purpose.

In [None]:
# Shorthand used for num_parallel_calls / prefetch in the tf.data pipelines below.
AUTOTUNE = tf.data.AUTOTUNE

# Mapping characters to integers. The vocabulary is the sorted character list
# collected from the training labels; no mask token is reserved.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)

# Mapping integers back to original characters. It is built from
# char_to_num's own vocabulary so both layers agree on the index of each character.
num_to_char = StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)

### Resizing images without distortion

Instead of square images, many OCR models work with rectangular images. This will become
clearer in a moment when we will visualize a few samples from the dataset. While
aspect-unaware resizing of square images does not introduce a significant amount of
distortion, this is not the case for rectangular images. But resizing images to a uniform
size is a requirement for mini-batching. So we need to perform our resizing such that
the following criteria are met:

* Aspect ratio is preserved.
* Content of the images is not affected.

In [None]:
def distortion_free_resize(image, img_size):
    """Resize an image to `img_size` without distorting it, then rotate it.

    The image is resized with its aspect ratio preserved, padded up to the full
    target size, and finally transposed and flipped left-right.

    Args:
        image: image tensor with a trailing channel axis.
        img_size: target size as ``(width, height)``.

    Returns:
        The padded image after ``transpose(perm=[1, 0, 2])`` and a left-right flip.
    """
    target_w, target_h = img_size
    image = tf.image.resize(image, size=(target_h, target_w), preserve_aspect_ratio=True)

    # Amount of padding still missing in each direction after the resize.
    resized_shape = tf.shape(image)
    pad_height = target_h - resized_shape[0]
    pad_width = target_w - resized_shape[1]

    # Split the padding between both sides; when the amount is odd, the extra
    # pixel goes to the top / left.
    pad_height_bottom = pad_height // 2
    pad_height_top = pad_height - pad_height_bottom
    pad_width_right = pad_width // 2
    pad_width_left = pad_width - pad_width_right

    image = tf.pad(
        image,
        paddings=[
            [pad_height_top, pad_height_bottom],
            [pad_width_left, pad_width_right],
            [0, 0],
        ],
    )

    return tf.image.flip_left_right(tf.transpose(image, perm=[1, 0, 2]))


### Putting the utilities together

In [None]:
# Number of samples per batch produced by prepare_dataset.
batch_size = 24
# Value used to right-pad encoded labels up to max_len (see vectorize_label);
# the plotting cells filter it out again before decoding labels.
padding_token = 99
# Target image size, passed to distortion_free_resize as (width, height).
image_width = 512
image_height = 128


def preprocess_image(image_path, img_size=(image_width, image_height)):
    """Load a PNG file and turn it into a model-ready float image.

    Args:
        image_path: path of the PNG file to read.
        img_size: target ``(width, height)`` handed to `distortion_free_resize`.

    Returns:
        A float32 tensor: the single-channel image after
        `distortion_free_resize`, with its values divided by 255.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_png(raw_bytes, channels=1)
    resized = distortion_free_resize(decoded, img_size)
    return tf.cast(resized, tf.float32) / 255.0


def vectorize_label(label):
    """Encode a transcription as a fixed-length vector of character ids.

    The string is split into characters, mapped through `char_to_num`, and
    right-padded with `padding_token` to exactly `max_len` entries.

    `max_len` is measured on the training labels only, so a validation, test or
    fine-tuning label can be longer. Such a label is truncated to `max_len`
    (the same limit the decoding helpers apply to predictions); without the
    truncation the pad amount would be negative and `tf.pad` would fail.

    Args:
        label: scalar string tensor holding the transcription.

    Returns:
        1-D tensor of length `max_len` with the character ids followed by padding.
    """
    label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
    label = label[:max_len]
    pad_amount = max_len - tf.shape(label)[0]
    return tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)


def process_images_labels(image_path, label):
    """Build one model input sample from an image path and its transcription.

    Returns:
        Dict with the preprocessed image under "image" and the encoded label
        under "label".
    """
    return {
        "image": preprocess_image(image_path),
        "label": vectorize_label(label),
    }


def prepare_dataset(image_paths, labels):
    """Create the batched `tf.data` pipeline for a set of samples.

    Args:
        image_paths: image file paths.
        labels: transcriptions, parallel to `image_paths`.

    Returns:
        A dataset of ``{"image", "label"}`` dicts in batches of `batch_size`,
        cached and prefetched.
    """
    paired = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    processed = paired.map(process_images_labels, num_parallel_calls=AUTOTUNE)
    batched = processed.batch(batch_size)
    return batched.cache().prefetch(AUTOTUNE)


## Prepare `tf.data.Dataset` objects

In [None]:
# Build the batched input pipelines for the three splits from their image paths
# and cleaned transcriptions.
train_ds = prepare_dataset(train_img_paths, train_labels_cleaned)
validation_ds = prepare_dataset(validation_img_paths, validation_labels_cleaned)
test_ds = prepare_dataset(test_img_paths, test_labels_cleaned)

## Visualize a few samples

In [None]:
# Show the first four samples of the first training batch with their labels.
for data in train_ds.take(1):
    images, labels = data["image"], data["label"]

    _, ax = plt.subplots(2, 2, figsize=(15, 8))

    for i in range(4):
        # Undo the transpose + flip applied by distortion_free_resize and convert
        # back to a 2-D uint8 image for display.
        img = tf.transpose(tf.image.flip_left_right(images[i]), perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)[:, :, 0]

        # Gather indices where label != padding_token, then convert to string.
        label = labels[i]
        kept_positions = tf.where(tf.math.not_equal(label, padding_token))
        indices = tf.gather(label, kept_positions)
        label = tf.strings.reduce_join(num_to_char(indices)).numpy().decode("utf-8")

        panel = ax[i // 2, i % 2]
        panel.imshow(img, cmap="gray")
        panel.set_title(label)
        panel.axis("off")


plt.show()

You will notice that the content of original image is kept as faithful as possible and has
been padded accordingly.

## Model

Our model will use the CTC loss as an endpoint layer. For a detailed understanding of the
CTC loss, refer to [this post](https://distill.pub/2017/ctc/).

In [None]:
class CTCLayer(keras.layers.Layer):
    """Pass-through layer that attaches the CTC loss to the model.

    Calling the layer with ``(y_true, y_pred)`` computes the batch CTC cost,
    registers it through `add_loss`, and returns `y_pred` unchanged.
    """

    def __init__(self, name=None, **kwargs):
        # Forward the standard Layer keyword arguments (trainable, dtype, ...) so
        # the layer can be rebuilt from its config, e.g. when a saved model that
        # contains it is loaded.
        super().__init__(name=name, **kwargs)
        self.loss_fn = keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        # ctc_batch_cost wants one length per sample, shaped (batch, 1).
        # NOTE(review): label_length is the padded width of y_true for every
        # sample, so padding positions are counted as label steps - confirm
        # this is intended.
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions.
        return y_pred


def build_model():
    """Build and compile the handwriting-recognition model.

    The network takes an "image" input of shape (image_width, image_height, 1)
    and a "label" input. The image goes through five Conv2D layers with two 2x2
    max-pools, a Dense projection, and three bidirectional LSTMs, ending in the
    softmax layer "dense2". A CTCLayer then attaches the CTC loss computed from
    the labels and the softmax output.

    Returns:
        A keras Model named "handwriting_recognizer", compiled with an Adam
        optimizer and no explicit loss (the loss is added by the CTCLayer).
    """
    # Inputs to the model
    input_img = keras.Input(shape=(image_width, image_height, 1), name="image")
    labels = keras.layers.Input(name="label", shape=None)

    # Conv blocks
    x = keras.layers.Conv2D(32,(3, 3),activation="relu",kernel_initializer="he_normal",padding="same",name="Conv1",)(input_img)
    x = keras.layers.Conv2D(128,(3, 3),activation="relu",kernel_initializer="he_normal",padding="same",name="Conv2",)(x)
    x = keras.layers.MaxPooling2D((2, 2), name="pool1")(x)
    x = keras.layers.Conv2D(256,(3, 3),activation="relu",kernel_initializer="he_normal",padding="same",name="Conv3",)(x)
    x = keras.layers.Dropout(0.2)(x)
    x = keras.layers.Conv2D(1024,(3, 3),activation="relu",kernel_initializer="he_normal",padding="same",name="Conv4",)(x)
    x = keras.layers.MaxPooling2D((2, 2), name="pool2")(x)
    x = keras.layers.Conv2D(64,(3, 3),activation="relu",kernel_initializer="he_normal",padding="same",name="Conv5",)(x)
    # The two 2x2 max-pools shrink both spatial dimensions by 4 and the last conv
    # has 64 filters, so each of the image_width // 4 steps gets
    # (image_height // 4) * 64 features.
    new_shape = ((image_width // 4), (image_height // 4) * 64)
    x = keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = keras.layers.Dense(64, activation="relu", name="dense1")(x)
    x = keras.layers.Dropout(0.2)(x)

    # RNNs.
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(1024, return_sequences=True, dropout=0.3)
    )(x)
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(512, return_sequences=True, dropout=0.25)
    )(x)
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(64, return_sequences=True, dropout=0.2)
    )(x)

    # Per-step class probabilities; the output size is the lookup vocabulary size plus 2.
    # NOTE(review): presumably the extra units leave room for the CTC blank - confirm.
    x = keras.layers.Dense(len(char_to_num.get_vocabulary()) + 2, activation="softmax", name="dense2")(x)

    # Add CTC layer for calculating CTC loss at each step.
    output = CTCLayer(name="ctc_loss")(labels, x)

    # Define the model.
    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="handwriting_recognizer"
    )
    # Optimizer.
    opt = keras.optimizers.Adam()
    # Compile the model and return.
    model.compile(optimizer=opt)
    return model


## Training

Now we are ready to kick off model training.

In [None]:
# Build a fresh, compiled model and print its layer summary.
model = build_model()
# Optional: uncomment to start from previously saved weights instead of from scratch.
# model.load_weights(base_path+"/handwriting.h5")
model.summary()

In [None]:
# Train the model.
epochs = 50

# Keep only the best model seen so far on disk.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=base_path + "/handwriting.h5",
    save_best_only=True,
)
# Stop once val_loss has not improved for 10 epochs and roll back to the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

history = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=epochs,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

In [None]:
# Plot every metric recorded during training against the epoch number.
history_ax = pd.DataFrame(history.history).plot(figsize=(8, 5))
history_ax.grid(True)
# NOTE(review): the y-axis is clipped to [0, 1]; curves above 1 are not visible -
# confirm this range suits the loss values.
history_ax.set_ylim(0, 1)
plt.show()

## Evaluation metric

[Edit Distance](https://en.wikipedia.org/wiki/Edit_distance)
is the most widely used metric for evaluating OCR models. In this section, we will
implement it as a helper function that the evaluation cells below use to score the model's predictions.

In [None]:
def calculate_edit_distance(labels, predictions):
    """Mean unnormalized edit distance between labels and decoded predictions.

    Padding is removed from both sides before comparing: `padding_token` from the
    labels and -1 (the value filtered out by `decode_batch_predictions`) from the
    greedy CTC decoding. `tf.sparse.from_dense` is not used for this because it
    drops only zeros - it would keep both padding values as if they were
    characters and silently discard the valid id 0.

    Args:
        labels: dense batch of encoded labels, padded with `padding_token`.
        predictions: batch of per-step class probabilities from the model.

    Returns:
        Scalar tensor: the edit distances of the batch, averaged.
    """

    def _dense_to_sparse(dense, pad_value):
        # Keep every entry that is not padding, in row-major order.
        dense = tf.cast(dense, tf.int64)
        keep = tf.where(tf.math.not_equal(dense, pad_value))
        return tf.SparseTensor(
            indices=keep,
            values=tf.gather_nd(dense, keep),
            dense_shape=tf.shape(dense, out_type=tf.int64),
        )

    sparse_labels = _dense_to_sparse(labels, padding_token)

    # Make predictions and convert them to sparse tensors.
    input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
    predictions_decoded = keras.backend.ctc_decode(
        predictions, input_length=input_len, greedy=True
    )[0][0][:, :max_len]
    sparse_predictions = _dense_to_sparse(predictions_decoded, -1)

    # Compute individual edit distances and average them out.
    edit_distances = tf.edit_distance(
        sparse_predictions, sparse_labels, normalize=False
    )
    return tf.reduce_mean(edit_distances)

## Inference

In [None]:
# Rebuild the network and restore the weights saved during training.
model = build_model()
model.load_weights(base_path + "/handwriting.h5")

# Inference-only view of the same network: image in, "dense2" softmax out
# (no label input, no CTC loss layer).
prediction_model = keras.models.Model(
    inputs=model.get_layer(name="image").input,
    outputs=model.get_layer(name="dense2").output,
)

In [None]:
# A utility function to decode the output of the network.
def decode_batch_predictions(pred):
    """Turn a batch of network outputs into text.

    Args:
        pred: batch of per-step class probabilities.

    Returns:
        List with one decoded string per sample.
    """
    # Every sample uses the full number of time steps.
    step_counts = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    decoded_paths = keras.backend.ctc_decode(
        pred, input_length=step_counts, greedy=True
    )[0]
    results = decoded_paths[0][:, :max_len]

    # Drop the -1 filler entries and map the remaining ids back to characters.
    output_text = []
    for res in results:
        valid_positions = tf.where(tf.math.not_equal(res, -1))
        chars = num_to_char(tf.gather(res, valid_positions))
        output_text.append(tf.strings.reduce_join(chars).numpy().decode("utf-8"))
    return output_text


#  Let's check results on some test samples.
for batch in test_ds.take(1):
    batch_images, batch_labels = batch["image"], batch["label"]
    print("len is : ",len(batch))
    print("batch img shape: ",batch_images[1].shape)
    _, ax = plt.subplots(2, 2, figsize=(15, 8))

    preds = prediction_model.predict(batch_images)
    pred_texts = decode_batch_predictions(preds)

    for i in range(4):
        # Undo the transpose + flip from preprocessing and convert to 2-D uint8.
        img = tf.transpose(tf.image.flip_left_right(batch_images[i]), perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)[:, :, 0]

        # Strip the padding from the ground-truth label and convert it to a string.
        label = batch_labels[i]
        kept_positions = tf.where(tf.math.not_equal(label, padding_token))
        indices = tf.gather(label, kept_positions)
        label = tf.strings.reduce_join(num_to_char(indices)).numpy().decode("utf-8")

        title = f"Prediction: {pred_texts[i]} \n\n Original : {label}"
        panel = ax[i // 2, i % 2]
        panel.imshow(img, cmap="gray")
        panel.set_title(title)
        panel.axis("off")

plt.show()

In [None]:
# Edit Distance

# Enough batches to cover every test sample.
num_test_batches = len(test_img_paths) // batch_size + 1
ed_sum, count = 0, 0

for batch in test_ds.take(num_test_batches):
    batch_images, batch_labels = batch["image"], batch["label"]

    preds = model.predict(batch)
    ed_sum += calculate_edit_distance(batch_labels, preds)
    count += 1

# Average of the per-batch mean edit distances.
print("\nEdit Distance : ", ed_sum / count)

In [None]:
img_path = base_path + '/test1.png'


def prepare_dataset_custom(img_paths_3):
    """Create a batched, label-free image pipeline from image paths.

    Args:
        img_paths_3: image file paths.

    Returns:
        A dataset of preprocessed images in batches of `batch_size`, cached and
        prefetched.
    """
    path_ds = tf.data.Dataset.from_tensor_slices(img_paths_3)
    image_ds = path_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
    return image_ds.batch(batch_size).cache().prefetch(AUTOTUNE)


img_paths_3 = [img_path]
custom_ds = prepare_dataset_custom(img_paths_3)

In [None]:
# Predict the text of the single custom image and display it.
for batch in custom_ds.take(1):
    print("len is : ",len(batch))
    print("batch img shape: ",batch[0].shape)

    preds = prediction_model.predict(batch)
    pred_texts = decode_batch_predictions(preds)
    print(pred_texts)

    for i in range(1):
        # Undo the transpose + flip from preprocessing and convert to 2-D uint8.
        img = tf.transpose(tf.image.flip_left_right(batch[i]), perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)[:, :, 0]

        plt.imshow(img, cmap="gray")
        plt.title(f"Prediction: {pred_texts[i]}")
        plt.axis("off")

plt.show()

## Fine Tuning

### Import data and preprocessing

In [None]:
# Load both annotation tables of the custom data set ...
df_orig = pd.read_csv(base_path + "/custom_data/annot.csv")
df_aug = pd.read_csv(base_path + "/custom_data/annot1.csv")

# ... stack them and shuffle the rows, renumbering the index.
df = pd.concat([df_orig, df_aug]).sample(frac=1, ignore_index=True)
df.head()

In [None]:
data_path = base_path + "/custom_data/images/"


def get_paths(x):
    """Return the full path of image file name `x` inside `data_path`."""
    return data_path + x


# Turn the file names in the "images" column into full paths.
# NOTE: the column is overwritten in place, so re-running this cell prefixes
# the directory a second time.
df['images'] = df['images'].apply(get_paths)
df.head()

In [None]:
# 80% of the rows for training; the remaining 20% is halved into validation and test.
X_train, X_holdout, y_train, y_holdout = train_test_split(
    df['images'], df['label'], test_size=0.2
)
X_val, X_test, y_val, y_test = train_test_split(X_holdout, y_holdout, test_size=0.5)

print("Training Size:",len(X_train))
print("Validation Size:",len(X_val))
print("Testing Size:",len(X_test))

In [None]:
# Rebuild the three pipelines from the custom data. This rebinds train_ds,
# validation_ds and test_ds, replacing the datasets created earlier in the notebook.
# NOTE(review): the labels are passed as read from the CSV, without clean_labels -
# confirm they are already plain transcriptions.
train_ds = prepare_dataset(X_train, y_train)
validation_ds = prepare_dataset(X_val, y_val)
test_ds = prepare_dataset(X_test, y_test)

## Performance before transfer learning

In [None]:
# Edit distance of the current model on the custom *training* split.
ed_sum = 0
count=0

# Iterate over every batch of train_ds. The earlier bound on the number of
# batches came from len(test_img_paths), the size of a different data set's test
# split, and could cut the evaluation short.
for batch in train_ds:
    batch_images,batch_labels = batch["image"],batch["label"]

    preds = model.predict(batch)
    ed_sum+=calculate_edit_distance(batch_labels,preds)
    count+=1


# Average of the per-batch mean edit distances.
print("\nEdit Distance : ",ed_sum/count)

In [None]:
# Edit distance of the current model on the custom *test* split.
ed_sum = 0
count=0

# Iterate over every batch of test_ds. The earlier bound on the number of batches
# came from len(test_img_paths), which describes the test split built earlier in
# the notebook, not the data set test_ds now refers to.
for batch in test_ds:
    batch_images,batch_labels = batch["image"],batch["label"]

    preds = model.predict(batch)
    ed_sum+=calculate_edit_distance(batch_labels,preds)
    count+=1


# Average of the per-batch mean edit distances.
print("\nEdit Distance : ",ed_sum/count)

### Visualize Data

In [None]:
# Show the first four samples of the first custom training batch with their labels.
for data in train_ds.take(1):
    images, labels = data["image"], data["label"]

    _, ax = plt.subplots(2, 2, figsize=(15, 8))

    for i in range(4):
        # Undo the transpose + flip applied by distortion_free_resize and convert
        # back to a 2-D uint8 image for display.
        img = tf.transpose(tf.image.flip_left_right(images[i]), perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)[:, :, 0]

        # Gather indices where label != padding_token, then convert to string.
        label = labels[i]
        kept_positions = tf.where(tf.math.not_equal(label, padding_token))
        indices = tf.gather(label, kept_positions)
        label = tf.strings.reduce_join(num_to_char(indices)).numpy().decode("utf-8")

        panel = ax[i // 2, i % 2]
        panel.imshow(img, cmap="gray")
        panel.set_title(label)
        panel.axis("off")


plt.show()

### Fine Tuning

In [None]:
# Start fine-tuning from the weights saved by the first training run.
model=build_model()
model.load_weights(base_path+"/handwriting.h5")

In [None]:
# Train the model.
epochs = 50

# Keep only the best fine-tuned model seen so far on disk.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=base_path + "/fine_tuned.h5",
    save_best_only=True,
)
# Stop once val_loss has not improved for 10 epochs and roll back to the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

history = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=epochs,
    callbacks=[early_stopping_cb, checkpoint_cb],
)

In [None]:
# Plot every metric recorded during fine-tuning against the epoch number.
history_ax = pd.DataFrame(history.history).plot(figsize=(8, 5))
history_ax.grid(True)
plt.show()

### Evaluation

In [None]:
# Rebuild the network and restore the checkpoint written during fine-tuning.
model=build_model()
model.load_weights(base_path+"/fine_tuned.h5")

In [None]:
# Inference-only view of the fine-tuned network: image in, "dense2" softmax out.
prediction_model = keras.models.Model(
    inputs=model.get_layer(name="image").input,
    outputs=model.get_layer(name="dense2").output,
)

In [None]:
# A utility function to decode the output of the network.
def decode_batch_predictions(pred):
    """Turn a batch of network outputs into text.

    NOTE: this repeats the definition from the Inference section.

    Args:
        pred: batch of per-step class probabilities.

    Returns:
        List with one decoded string per sample.
    """
    # Every sample uses the full number of time steps.
    step_counts = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    decoded_paths = keras.backend.ctc_decode(
        pred, input_length=step_counts, greedy=True
    )[0]
    results = decoded_paths[0][:, :max_len]

    # Drop the -1 filler entries and map the remaining ids back to characters.
    output_text = []
    for res in results:
        valid_positions = tf.where(tf.math.not_equal(res, -1))
        chars = num_to_char(tf.gather(res, valid_positions))
        output_text.append(tf.strings.reduce_join(chars).numpy().decode("utf-8"))
    return output_text


#  Let's check results on some test samples.
for batch in test_ds.take(1):
    batch_images, batch_labels = batch["image"], batch["label"]
    print("len is : ",len(batch))
    print("batch img shape: ",batch_images[1].shape)
    _, ax = plt.subplots(2, 2, figsize=(15, 8))

    # The full model receives the whole batch dict (image and label).
    preds = model.predict(batch)
    pred_texts = decode_batch_predictions(preds)

    for i in range(4):
        # Undo the transpose + flip from preprocessing and convert to 2-D uint8.
        img = tf.transpose(tf.image.flip_left_right(batch_images[i]), perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)[:, :, 0]

        # Strip the padding from the ground-truth label and convert it to a string.
        label = batch_labels[i]
        kept_positions = tf.where(tf.math.not_equal(label, padding_token))
        indices = tf.gather(label, kept_positions)
        label = tf.strings.reduce_join(num_to_char(indices)).numpy().decode("utf-8")

        title = f"Prediction: {pred_texts[i]} \n\n Original : {label}"
        panel = ax[i // 2, i % 2]
        panel.imshow(img, cmap="gray")
        panel.set_title(title)
        panel.axis("off")

plt.show()

In [None]:
# Enough batches to cover every custom test sample.
num_test_batches = len(X_test) // batch_size + 1
ed_sum, count = 0, 0

for batch in test_ds.take(num_test_batches):
    batch_images, batch_labels = batch["image"], batch["label"]

    preds = model.predict(batch)
    ed_sum += calculate_edit_distance(batch_labels, preds)
    count += 1

# Average of the per-batch mean edit distances.
print("\nEdit Distance : ", ed_sum / count)

## Predictions

In [None]:
# Medicine lookup table, without its "id" column.
df_med = pd.read_csv(base_path + "/medicines.csv").drop(columns='id')

In [None]:
def prepare_dataset_custom(img_paths_3):
    """Create a batched, label-free image pipeline from image paths.

    NOTE: this repeats the definition from the Inference section.

    Args:
        img_paths_3: image file paths.

    Returns:
        A dataset of preprocessed images in batches of `batch_size`, cached and
        prefetched.
    """
    path_ds = tf.data.Dataset.from_tensor_slices(img_paths_3)
    image_ds = path_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
    return image_ds.batch(batch_size).cache().prefetch(AUTOTUNE)

def make_pred(img_path):
    """Recognize the text in one image and show the image with the prediction.

    Args:
        img_path: path of the image file.

    Returns:
        The decoded text predicted for the image.
    """
    single_image_ds = prepare_dataset_custom([img_path])
    for batch in single_image_ds.take(1):
        preds = prediction_model.predict(batch)
        pred_texts = decode_batch_predictions(preds)

        # Undo the transpose + flip from preprocessing and convert to 2-D uint8.
        img = tf.transpose(tf.image.flip_left_right(batch[0]), perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)[:, :, 0]

        plt.imshow(img, cmap="gray")
        plt.title(f"Prediction: {pred_texts[0]}")
        plt.axis("off")

    plt.show()
    return pred_texts[0]

In [None]:
def extract_medicine_info(pred_texts):
    """Look up the medicine whose name best matches the recognized text.

    The text is fuzzy-matched against `df_med['name']` with the token-sort
    ratio. A match counts only when its score is at least 80.

    Args:
        pred_texts: recognized text to search for.

    Returns:
        The matching `df_med` row as a dict, or None when no name scores at
        least 80 (including when the table has no rows).
    """
    collection = df_med['name']
    # Only the best candidate is used, so ask for a single match.
    matches = process.extract(
        pred_texts, collection, scorer=fuzz.token_sort_ratio, limit=1
    )
    if not matches or matches[0][1] < 80:
        print("Medicine not found")
        return None

    # The third element of a match identifies the matched entry of the Series by
    # its index label, so the row is selected with .loc (label-based) rather than
    # .iloc (position-based); the two agree only for a default 0..n-1 index.
    row = df_med.loc[matches[0][2]]
    print(row)
    return row.to_dict()

In [None]:
# Test Image 1: recognize the text, then look the result up in the medicine table.
img_path1 = base_path+'/test1.png'
pred_texts1=make_pred(img_path1)
medicine1=extract_medicine_info(pred_texts1)

In [None]:
print(medicine1)

In [None]:
# Test Image 2: recognize the text, then look the result up in the medicine table.
img_path2 = base_path+'/test2.png'
pred_texts2=make_pred(img_path2)
medicine2=extract_medicine_info(pred_texts2)

In [None]:
print(medicine2)

In [None]:
# Test Image 3: recognize the text, then look the result up in the medicine table.
img_path3 = base_path+'/test3.png'
pred_texts3=make_pred(img_path3)
medicine3=extract_medicine_info(pred_texts3)

In [None]:
print(medicine3)