## TrOCR

In [None]:
import os
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [None]:
# Pretrained TrOCR checkpoint identifier, and the local directory that holds
# the fine-tuned weights evaluated below.
base_model_name, save_model_name = (
    "microsoft/trocr-small-handwritten",
    "models/trocr/trocr-small-handwritten-finetuned",
)

In [None]:
def trocr_inference(model, processor, image_path):
    """Transcribe a single handwriting image with a TrOCR model.

    Args:
        model: VisionEncoderDecoderModel whose ``generate`` produces token ids.
        processor: TrOCRProcessor that converts the image to ``pixel_values``
            and decodes generated token ids back to text.
        image_path: Path of the image file to transcribe.

    Returns:
        The decoded text of the single generated sequence, with special
        tokens removed.
    """
    # The context manager releases the file handle after every call (this is
    # invoked once per validation image). Converting to RGB matches the TrOCR
    # usage examples and avoids feeding single-channel / RGBA scans to the
    # processor.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # batch_decode returns one string per sequence; the batch here has size 1.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
# The processor (image preprocessing + tokenizer) comes from the base
# checkpoint; the model weights come from the fine-tuned directory.
processor, model = (
    TrOCRProcessor.from_pretrained(base_model_name),
    VisionEncoderDecoderModel.from_pretrained(save_model_name),
)

## Our GOAT model (Model9v3_Transfer)

In [None]:
from inference_smartapp import handwriting_model

## Test: compare both models on the validation set

In [None]:
# Root of the transfer dataset; validation images and their .txt labels
# live in its 'val' subfolder.
dataset_path = 'dataset/transfer_dataset/'
val_dataset_path = os.path.join(
    dataset_path,
    'val',
)

In [None]:
from Levenshtein import distance
import tensorflow as tf

In [None]:
# Every entry in the validation folder (kept under its own name in case
# other cells still refer to it).
val_df_list = os.listdir(val_dataset_path)
# Full paths of the .jpg images. os.listdir returns entries in arbitrary
# order, so sort to make the evaluation order and printed output reproducible.
val_df_jpg_list = [
    os.path.join(val_dataset_path, file_name)
    for file_name in sorted(val_df_list)
    if file_name.endswith('.jpg')
]

In [None]:
def _normalize_transcription(text):
    """Return *text* with '|' word separators turned into spaces and outer whitespace trimmed.

    Applied identically to the ground truth and to both models' outputs so the
    Levenshtein distances compare like with like.
    """
    return text.replace('|', ' ').strip()


goat_distances = []
trocr_distances = []

for image_path in val_df_jpg_list:
    # Custom model input: a 1-channel decode of the image file. TensorFlow's
    # decode_png also accepts JPEG bytes, so it works for these .jpg files.
    image = tf.image.decode_png(tf.io.read_file(image_path), 1)
    goat_text = _normalize_transcription(handwriting_model.inference(image))

    # The label sits next to the image with a .txt extension. splitext swaps
    # only the final extension, whereas str.replace('.jpg', '.txt') would also
    # rewrite any '.jpg' occurring in a directory name.
    text_path = os.path.splitext(image_path)[0] + '.txt'
    # Explicit encoding so the labels are read the same on every platform.
    with open(text_path, encoding='utf-8') as f:
        real_text = _normalize_transcription(f.read())

    trocr_text = _normalize_transcription(trocr_inference(model, processor, image_path))

    print(f'TROCR: {trocr_text}')
    print(f'GOAT: {goat_text}')
    print(f'REAL: {real_text}')

    goat_distance = distance(goat_text, real_text)
    trocr_distance = distance(trocr_text, real_text)

    goat_distances.append(goat_distance)
    trocr_distances.append(trocr_distance)

# Fail with a clear message instead of a bare ZeroDivisionError when no
# images were evaluated.
if not goat_distances:
    raise ValueError(f'No .jpg images found in {val_dataset_path!r}')

goat_avg_distance = sum(goat_distances) / len(goat_distances)
trocr_avg_distance = sum(trocr_distances) / len(trocr_distances)

print(f'GOAT average distance: {goat_avg_distance}')
print(f'TROCR average distance: {trocr_avg_distance}')

In [None]:
import matplotlib.pyplot as plt

# Bar labels, in the same order as `values` below.
modules = ['TrOCR-small', 'Model9v3_Transfer']

# Average Levenshtein distance of each model over the validation images
# (lower is better).
values = [trocr_avg_distance, goat_avg_distance]

plt.figure(figsize=(8, 6))
bars = plt.bar(modules, values, color=['blue', 'green'])

# Write each bar's value above it so small differences are readable.
for bar, value in zip(bars, values):
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height(),
        f'{value:.2f}',
        ha='center',
        va='bottom',
    )

plt.title('Average Levenshtein Distance on the Validation Set (lower is better)')
plt.xlabel('Models')
plt.ylabel('Average Levenshtein Distance')

plt.show()
