# In [None]:
import os
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from torch.utils.data import Dataset
from dataclasses import dataclass
import numpy as np
from jiwer import cer  # For CER calculation

@dataclass(frozen=True)
class ModelConfig:
    """Immutable settings for the base TrOCR model.

    Attributes:
        MODEL_NAME: Hugging Face model id; the script loads its
            ``TrOCRProcessor`` (image preprocessing) from this id.
    """

    MODEL_NAME: str = "microsoft/trocr-small-printed"

# Vocabulary wrapper used to turn OCR model output ids back into text.
class CustomOCRDataset(Dataset):
    """Character vocabulary for decoding OCR model outputs.

    The vocabulary is the five reserved entries below followed by the lines
    of the dictionary file, in file order, so token ids are:

    * 0 ``<START>``, 1 ``<END>``, 2 ``<UNK>``, 3 ``<PAD>``, 4 ``" "`` (space)
    * 5 onward: the dictionary file's lines.

    Although it subclasses ``torch.utils.data.Dataset``, this class defines no
    ``__len__``/``__getitem__``; in this file it is only used for decoding.

    Args:
        dictionary_file: Path to a UTF-8 text file with one token per line.
        max_target_length: Maximum target sequence length. Stored as
            ``self.max_target_length``; not otherwise used by this class.
    """

    def __init__(self, dictionary_file, max_target_length=100):
        # Keep the argument instead of silently discarding it.
        self.max_target_length = max_target_length
        self.dictionary = self.load_dictionary(dictionary_file)
        # If a token appears more than once, its last position wins.
        self.word_to_id = {word: idx for idx, word in enumerate(self.dictionary)}

        # Pin the special tokens to their reserved ids. This matters if the
        # dictionary file itself contains one of these strings, which would
        # otherwise overwrite the mapping built above.
        self.word_to_id["<START>"] = 0
        self.word_to_id["<END>"] = 1
        self.word_to_id["<UNK>"] = 2
        self.word_to_id["<PAD>"] = 3

    def load_dictionary(self, tokenized_file):
        """Return the vocabulary: special tokens followed by the file's lines.

        Each line is stripped of surrounding whitespace. Blank lines are kept
        (as ``""``) so that the ids of the tokens after them do not shift. A
        leading UTF-8 byte-order mark is dropped instead of being glued onto
        the first token (``str.strip`` does not remove U+FEFF).

        Args:
            tokenized_file: Path to the dictionary file.

        Returns:
            ``["<START>", "<END>", "<UNK>", "<PAD>", " "]`` + file lines.
        """
        with open(tokenized_file, 'r', encoding='utf-8-sig') as file:
            words = [line.strip() for line in file]
        special_tokens = ["<START>", "<END>", "<UNK>", "<PAD>", " "]
        return special_tokens + words

    def decode_labels(self, labels):
        """Convert a sequence of token ids into a string.

        ``<START>``, ``<END>`` and ``<PAD>`` ids are dropped. The ``<UNK>``
        id is rendered as the literal ``"<UNK>"``. Every other id is looked up
        in ``self.dictionary``, and the pieces are joined with no separator.

        Ids outside ``[0, len(self.dictionary))`` are rendered as ``"<UNK>"``.
        Otherwise a negative id would silently index from the end of the
        vocabulary, and a too-large id would raise ``IndexError``.

        Args:
            labels: Iterable of integer-like ids (Python ints, NumPy integers,
                or 0-d tensors).

        Returns:
            The decoded text.
        """
        # Hoisted out of the loop: these lookups do not change per label.
        skip_ids = {
            self.word_to_id["<PAD>"],
            self.word_to_id["<START>"],
            self.word_to_id["<END>"],
        }
        unk_id = self.word_to_id["<UNK>"]
        vocab_size = len(self.dictionary)

        pieces = []
        for label in labels:
            label = int(label)
            if label in skip_ids:
                continue
            if label == unk_id or not 0 <= label < vocab_size:
                pieces.append("<UNK>")
            else:
                pieces.append(self.dictionary[label])
        return "".join(pieces)

def load_ground_truth(ground_truth_file):
    """Parse a ground-truth file of ``<filename> <text>`` lines into a dict.

    Each line is stripped and split at the first run of whitespace. The part
    before it is the filename key; the rest, with its internal spacing kept,
    is the reference text.

    Blank lines are skipped. A line containing only a filename maps to
    ``""``, which callers can treat as "no ground truth". Without this, both
    cases raised ``ValueError`` during tuple unpacking. A leading UTF-8
    byte-order mark is dropped, so the first filename is not corrupted.
    If a filename repeats, the last line wins.

    Args:
        ground_truth_file: Path to a UTF-8 text file.

    Returns:
        dict mapping filename -> reference text.
    """
    ground_truth = {}
    with open(ground_truth_file, 'r', encoding='utf-8-sig') as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if not parts:
                continue
            filename = parts[0]
            ground_truth[filename] = parts[1] if len(parts) > 1 else ""
    return ground_truth

# Build the vocabulary used to decode model outputs, then load the TrOCR
# image processor and the fine-tuned checkpoint on the selected device
# (CUDA when available, otherwise CPU).
dictionary_file = r"D:\Github\Khmer-OCR\Experiments\Dicts\unique_characters.txt"
custom_dataset = CustomOCRDataset(dictionary_file)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)

trained_model = VisionEncoderDecoderModel.from_pretrained(
    'D:/Github/Khmer-OCR/Experiments/Results/handwritten_v4_fine_tuning_ex_v6/checkpoint-26640'
).to(device)

# Evaluation inputs: a folder of test images, plus a text file mapping each
# image filename to its reference transcription.
ground_truth_file = r"D:\Github\Khmer-OCR\Experiments\Tests\eng_char.txt"
image_folder = "D:/Github/Khmer-OCR/Experiments/Tests/eng_char"

ground_truth = load_ground_truth(ground_truth_file)

# Running totals for the per-image CER average.
num_samples = 0
total_cer = 0.0

# Evaluate every image in the folder that has a ground-truth transcription.
# Files are visited in sorted order so that the printed report is reproducible
# (os.listdir order is arbitrary).
for image_file in sorted(os.listdir(image_folder)):
    # Case-insensitive match, so ".JPG"/".PNG"/".jpeg" files are not skipped.
    if not image_file.lower().endswith((".jpg", ".jpeg", ".png")):
        continue

    # Look up the reference before running the model: an image that cannot
    # be scored does not need an (expensive) generate() call.
    ground_truth_text = ground_truth.get(image_file, "")
    if not ground_truth_text:
        print(f"Skipping image {image_file} because ground truth is missing.")
        continue

    image_path = os.path.join(image_folder, image_file)

    # Load and preprocess the image. The context manager guarantees the
    # file handle opened by PIL is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Generate predictions (no gradients needed for inference).
    with torch.no_grad():
        generated_ids = trained_model.generate(pixel_values)

    # Decode the first (only) sequence with the custom vocabulary.
    predicted_text = custom_dataset.decode_labels(generated_ids[0].cpu().numpy())

    if predicted_text:
        sample_cer = cer(ground_truth_text, predicted_text)
    else:
        # An empty prediction is a complete miss: every reference character
        # is a deletion, so CER = len(ref) / len(ref) = 1.0. Dropping these
        # samples (as before) excluded failures and biased the average down.
        print(f"Image {image_file}: predicted text is empty; counting CER as 1.0.")
        sample_cer = 1.0

    total_cer += sample_cer
    num_samples += 1

    # Output results for this image
    print(f"Image: {image_file}")
    print(f"Ground Truth: {ground_truth_text}")
    print(f"Predicted: {predicted_text}")
    print(f"CER: {sample_cer:.4f}")
    print("-" * 30)

# Report the mean of the per-image CER values, or note that nothing was scored.
if num_samples == 0:
    print("No samples processed.")
else:
    avg_cer = total_cer / num_samples
    print(f"Average CER: {avg_cer:.4f}")
