## TrOCR

In [1]:
import os
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [2]:
# Hugging Face hub id of the base checkpoint.
open_source_model_name = "microsoft/trocr-base-handwritten"
# Path under models/ built from the last component of the hub id plus a "-finetuned" suffix.
save_model_name = "models/{}-finetuned".format(open_source_model_name.rpartition('/')[2])

In [3]:
def trocr_inference(model, processor, image_path):
    """Run TrOCR on a single image file and return the recognised text.

    Args:
        model: Object whose ``generate(pixel_values)`` returns token ids.
        processor: Callable that turns an image into an object exposing
            ``pixel_values`` and that offers ``batch_decode``.
        image_path: Path of the image file to transcribe.

    Returns:
        The decoded string of the first sequence in the generated batch.
    """
    # Open inside a context manager so the file handle is released promptly.
    # convert() returns a loaded copy, so the image stays usable after the
    # file is closed. The Hugging Face TrOCR examples feed RGB images;
    # without the conversion a grayscale or palette file would reach the
    # processor with a different channel layout.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # A single image was passed in, so the batch holds exactly one result.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [4]:
# The model weights are read from the local fine-tuned path ...
model = VisionEncoderDecoderModel.from_pretrained(save_model_name)
# ... whereas the processor is loaded from the original hub checkpoint.
processor = TrOCRProcessor.from_pretrained(open_source_model_name)



## Our Goat

In [5]:
from inferenz_smartapp import handwriting_model

models/transferstudent_mafiaboss_ohne_datum
Loading pre-trained model and weights...
Model and weights loaded successfully.


## TEST

In [6]:
# Root directory of the dataset; the two splits are its 'train' and 'val' sub-directories.
dataset_path = 'dataset/dataset_training/'
train_dataset_path, val_dataset_path = (
    os.path.join(dataset_path, split_name) for split_name in ('train', 'val')
)

In [7]:
from Levenshtein import distance
import tensorflow as tf

In [8]:
def _jpg_paths(directory, file_names):
    """Return ``directory``-prefixed paths for the entries of ``file_names`` ending in '.jpg'."""
    return [os.path.join(directory, name) for name in file_names if name.endswith('.jpg')]


# os.listdir returns entries in arbitrary order; sorting keeps the
# evaluation order (and therefore the printed log) reproducible between runs.
train_df_list = sorted(os.listdir(train_dataset_path))
val_df_list = sorted(os.listdir(val_dataset_path))

train_df_jpg_list = _jpg_paths(train_dataset_path, train_df_list)
val_df_jpg_list = _jpg_paths(val_dataset_path, val_df_list)

# All images of both splits, training images first.
df_jpg_list = train_df_jpg_list + val_df_jpg_list

In [10]:
# Evaluate TrOCR alone on the validation images: Levenshtein distance
# between each prediction and the label stored next to the image.
# (This cell no longer resets goat_distances, which it never used.)
trocr_distances = []

for image_path in val_df_jpg_list:
    # Swap only the file extension; str.replace would rewrite every '.jpg'
    # occurring anywhere in the path.
    text_path = os.path.splitext(image_path)[0] + '.txt'
    # Labels contain non-ASCII characters (umlauts, ß), so do not depend on
    # the platform default encoding. Assumes the files are UTF-8 — TODO confirm.
    with open(text_path, encoding='utf-8') as f:
        real_text = f.read()
    # '|' is replaced by a space in both label and prediction before comparing.
    real_text = real_text.replace('|', ' ').strip()
    trocr_text = trocr_inference(model, processor, image_path)
    trocr_text = trocr_text.replace('|', ' ').strip()
    print(f'TROCR: {trocr_text}')
    print(f'REAL: {real_text}')
    trocr_distance = distance(trocr_text, real_text)

    trocr_distances.append(trocr_distance)

# An empty image list yields NaN instead of a ZeroDivisionError.
trocr_avg_distance = sum(trocr_distances) / len(trocr_distances) if trocr_distances else float('nan')

print(f'TROCR average distance: {trocr_avg_distance}')

TROCR: Lüdtke
REAL: Lüdtke
TROCR: Markus
REAL: Markus
TROCR: 3366806
REAL: 3366806
TROCR: luedtke@gmx.de
REAL: luedtke@gmx.de
TROCR: Lüdtke
REAL: Lüdtke
TROCR: Mila
REAL: Mila
TROCR: 6b
REAL: 6b
TROCR: Schumannstrasse 10
REAL: Schumannstrasse 10
TROCR: 28213
REAL: 28213
TROCR: Bremen
REAL: Bremen
TROCR: Finn
REAL: Finn
TROCR: 6B
REAL: 6B
TROCR: Bremen
REAL: Bremen
TROCR: 28309
REAL: 28309
TROCR: Marschstrasse 2
REAL: Marschstrasse 2


KeyboardInterrupt: 

In [9]:
# Compare both recognisers on every image (train + val): Levenshtein
# distance between each prediction and the label stored next to the image.
goat_distances = []
trocr_distances = []

for image_path in df_jpg_list:
    image = tf.io.read_file(image_path)
    # NOTE(review): the files are named .jpg but are decoded with decode_png.
    # Left unchanged because the TensorFlow docs state this op also accepts
    # JPEG data — confirm it matches the preprocessing handwriting_model
    # expects. The second argument requests a single (grayscale) channel.
    image = tf.image.decode_png(image, 1)
    goat_text = handwriting_model.inference(image).strip()
    # Swap only the file extension; str.replace would rewrite every '.jpg'
    # occurring anywhere in the path.
    text_path = os.path.splitext(image_path)[0] + '.txt'
    # Labels contain non-ASCII characters (umlauts, ß), so do not depend on
    # the platform default encoding. Assumes the files are UTF-8 — TODO confirm.
    with open(text_path, encoding='utf-8') as f:
        real_text = f.read()
    # '|' is replaced by a space in the label and the TrOCR prediction before comparing.
    real_text = real_text.replace('|', ' ').strip()
    trocr_text = trocr_inference(model, processor, image_path)
    trocr_text = trocr_text.replace('|', ' ').strip()
    print(f'TROCR: {trocr_text}')
    print(f'GOAT: {goat_text}')
    print(f'REAL: {real_text}')
    goat_distance = distance(goat_text, real_text)
    trocr_distance = distance(trocr_text, real_text)

    goat_distances.append(goat_distance)
    trocr_distances.append(trocr_distance)

# An empty image list yields NaN instead of a ZeroDivisionError.
goat_avg_distance = sum(goat_distances) / len(goat_distances) if goat_distances else float('nan')
trocr_avg_distance = sum(trocr_distances) / len(trocr_distances) if trocr_distances else float('nan')

print(f'GOAT average distance: {goat_avg_distance}')
print(f'TROCR average distance: {trocr_avg_distance}')

TROCR: Fabian
GOAT: Fabian
REAL: Fabian
TROCR: 12A
GOAT: 12A
REAL: 12A
TROCR: Aachener Straße 5
GOAT: Aachener Straße 5
REAL: Aachener Straße 5
TROCR: 28327
GOAT: 28327
REAL: 28327
TROCR: Bremen
GOAT: Bremen
REAL: Bremen
TROCR: Huhn
GOAT: Huhn
REAL: Huhn
TROCR: Sandra
GOAT: Sandra
REAL: Sandra
TROCR: 1625128
GOAT: 1625128
REAL: 1625128
TROCR: huhns@web.de
GOAT: huhns@web.de
REAL: huhns@web.de
TROCR: Huhn
GOAT: Huhn
REAL: Huhn
TROCR: Dussen
GOAT: Dussen
REAL: Dussen
TROCR: 12A
GOAT: 12A
REAL: 12A
TROCR: Cuxhavener Straße 4
GOAT: Cuxhauener Straße 4
REAL: Cuxhavener Straße 4
TROCR: 28357
GOAT: 28357
REAL: 28357
TROCR: Bremen
GOAT: Bremen
REAL: Bremen
TROCR: Dussen
GOAT: Dussen
REAL: Dussen
TROCR: Julia
GOAT: Julia
REAL: Julia
TROCR: 7214608
GOAT: 7214608
REAL: 7214608
TROCR: dussduss@gmx.de
GOAT: dussduss@gmx.de
REAL: dussduss@gmx.de
TROCR: Gunnar
GOAT: Gunnar
REAL: Gunnar
TROCR: Anders
GOAT: Anders
REAL: Anders
TROCR: Iwona
GOAT: Iwona
REAL: Iwona
TROCR: 12A
GOAT: 12A
REAL: 12A
TROCR: D

KeyboardInterrupt: 

In [None]:
import os

import matplotlib.pyplot as plt

# Bar chart of the two average Levenshtein distances computed by the
# evaluation cell that runs both models.

# Names of modules
modules = ['TrOCR', 'TransferStudentMafia']

# Corresponding values from each module
values = [trocr_avg_distance, goat_avg_distance]

# Creating the bar chart
plt.figure(figsize=(8, 6))  # Optional: Adjust the size of the figure
plt.bar(modules, values, color=['blue', 'green'])  # You can specify different colors for each bar

# Adding titles and labels
plt.title('Comparison of Models Values')
plt.xlabel('Models')
plt.ylabel('Values')

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('test_results', exist_ok=True)
# Save before show(): the figure may be empty once show() has returned.
plt.savefig('test_results/comparison.png')
# Show the plot
plt.show()
