In [None]:
from __future__ import annotations

import os

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageOps
from datasets import load_metric
from tqdm.notebook import tqdm
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor
)
from ultralytics import YOLO

In [None]:
def seed_everything(seed_value: int) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and torch
    (CPU and all CUDA devices), and forces cuDNN to use deterministic kernels.

    Args:
        seed_value: seed applied to every generator.
    """
    import random

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Deterministic cuDNN kernels; disabling benchmark mode stops cuDNN from
    # auto-tuning (and possibly choosing a different algorithm) on every run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

# Prefer the second GPU (index 1), as before, but fall back to the first GPU on
# single-GPU machines, where moving a model to 'cuda:1' fails with
# "invalid device ordinal". To force CPU: device = torch.device('cpu')
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
class TextRecognizePipeline:
    """Page-level OCR: YOLO text-region detection followed by TrOCR recognition.

    ``recognize`` takes one page image, crops every box the YOLO detector
    returns (ordered by top edge, then left edge) and transcribes each crop
    with a TrOCR ``VisionEncoderDecoderModel``.

    Args:
        ocr_processor_name: Hugging Face id or path of the ``TrOCRProcessor``.
        ocr_model_path: local directory of the fine-tuned TrOCR checkpoint
            (loaded with ``local_files_only=True``).
        detector_path: path to the YOLO weights file.
        target_device: device for both models; ``None`` uses the notebook-level
            ``device``.
    """

    def __init__(
        self,
        ocr_processor_name: str = "raxtemur/trocr-base-ru",
        ocr_model_path: str = "/storage3/vadim/HTR-historical/models/recognizer/trocr_ru_pretrain_3epoch",
        detector_path: str = "/storage3/vadim/HTR-historical/models/detector/best.pt",
        target_device: torch.device | str | None = None,
    ):
        self.device = device if target_device is None else target_device

        self.ocr_processor = TrOCRProcessor.from_pretrained(ocr_processor_name)
        self.ocr_model = VisionEncoderDecoderModel.from_pretrained(
            ocr_model_path, local_files_only=True
        ).to(self.device)
        self.ocr_model.eval()

        self.detection_model = YOLO(detector_path).to(self.device)

        # Not read by any method of this class; kept so existing code that
        # accesses these attributes keeps working.
        self.iou_threshold = 0.7
        self.iou_list = []
        self.cer_list = []
        self.wer_list = []

        self._configure_generation()

    def _configure_generation(self) -> None:
        """Set special tokens and beam-search decoding options on the TrOCR model."""
        tokenizer = self.ocr_processor.tokenizer
        settings = {
            # Special tokens used to build decoder_input_ids and to stop decoding.
            "decoder_start_token_id": tokenizer.cls_token_id,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.sep_token_id,
            # Beam-search decoding options.
            "max_length": 64,
            "early_stopping": True,
            "no_repeat_ngram_size": 3,
            "length_penalty": 2.0,
            "num_beams": 4,
        }
        config = self.ocr_model.config
        # Keep the top-level vocab size in sync with the decoder's.
        config.vocab_size = config.decoder.vocab_size
        for name, value in settings.items():
            setattr(config, name, value)
        # Recent transformers versions read decoding options from
        # ``model.generation_config``. If the checkpoint ships its own
        # generation_config.json, values set only on ``config`` are ignored by
        # ``generate()``, so mirror them there to make sure they take effect.
        generation_config = getattr(self.ocr_model, "generation_config", None)
        if generation_config is not None:
            for name, value in settings.items():
                setattr(generation_config, name, value)

    def get_detections_and_crop_boxes(self, img: Image.Image) -> list[Image.Image]:
        """Detect text regions on ``img`` and return them as RGB crops.

        Boxes are ordered by their top edge (y1), ties broken by the left
        edge (x1), i.e. roughly top-to-bottom reading order.
        """
        crops = []
        for prediction in self.detection_model.predict([img], verbose=False):
            boxes = prediction.boxes.xyxy.cpu().tolist()  # each box: [x1, y1, x2, y2]
            for box in sorted(boxes, key=lambda b: (b[1], b[0])):
                crops.append(img.crop(box).convert("RGB"))
        return crops

    def get_ocr_predictions(
        self, img_list: list[Image.Image], batch_size: int | None = None
    ) -> list[str]:
        """Transcribe each crop in ``img_list`` with TrOCR.

        Args:
            img_list: cropped line/region images.
            batch_size: number of crops sent to ``generate`` at once; ``None``
                or a non-positive value processes all crops in one batch.

        Returns:
            Decoded strings, one per input crop, in input order. An empty
            input returns ``[]`` without calling the model.
        """
        if not img_list:
            return []
        step = batch_size if batch_size and batch_size > 0 else len(img_list)
        texts: list[str] = []
        with torch.no_grad():
            for start in range(0, len(img_list), step):
                batch = img_list[start:start + step]
                pixel_values = self.ocr_processor(batch, return_tensors="pt").pixel_values.to(self.device)
                generated_ids = self.ocr_model.generate(pixel_values)
                texts.extend(self.ocr_processor.batch_decode(generated_ids, skip_special_tokens=True))
        return texts

    def recognize(self, img_list: Image.Image, batch_size: int | None = None) -> list[str]:
        """Detect text regions on one page image and transcribe them.

        Args:
            img_list: the page image. Despite the historical parameter name,
                this is a single image (that is what every caller passes).
            batch_size: forwarded to ``get_ocr_predictions``.

        Returns:
            One string per detected region, in reading order; ``[]`` when
            nothing is detected.
        """
        cropped_images = self.get_detections_and_crop_boxes(img_list)
        return self.get_ocr_predictions(cropped_images, batch_size=batch_size)

In [None]:
import pathlib

def list_image_files(directory: str | pathlib.Path) -> list[pathlib.Path]:
    """Return the regular files directly inside ``directory``, sorted by path.

    The directory is expected to contain only images; sub-directories are skipped.
    """
    return sorted(entry for entry in pathlib.Path(directory).iterdir() if entry.is_file())


def get_rand_image(
    directory: str | pathlib.Path = "../../data/processed/3 Production/text_detector/test/images",
):
    """Open a randomly chosen image from ``directory`` with EXIF orientation applied.

    The candidate list is sorted before sampling, so with a fixed NumPy seed
    the same file is chosen regardless of the order in which the filesystem
    lists the directory.

    Args:
        directory: folder to sample from; defaults to the test-images folder.

    Returns:
        Tuple ``(image, img_path)``.

    Raises:
        FileNotFoundError: if ``directory`` contains no regular files.
    """
    candidates = list_image_files(directory)
    if not candidates:
        raise FileNotFoundError(f"No image files found in {directory}")
    img_path = candidates[np.random.randint(len(candidates))]
    # exif_transpose returns a new image, so the source file can be closed here.
    with Image.open(img_path) as source:
        img = ImageOps.exif_transpose(source)
    return img, img_path

def get_label_text(data: pd.DataFrame, filename: str, exact: bool = True) -> list[str]:
    """Return the ground-truth texts annotated for ``filename``, in row order.

    Args:
        data: annotation frame with ``file_name`` and ``text`` columns.
        filename: image file name to look up.
        exact: if True (default), match ``file_name`` by equality, so e.g.
            ``"1.jpg"`` does not also pick up ``"11.jpg"``. If False, match rows
            whose ``file_name`` contains ``filename`` as a literal substring
            (no regex interpretation of ``.``, ``(`` etc.; missing names never match).

    Returns:
        List of ``text`` values of the matching rows.
    """
    names = data["file_name"]
    if exact:
        mask = names == filename
    else:
        mask = names.str.contains(filename, regex=False, na=False)
    return data.loc[mask, "text"].to_list()

def extract_filename(filename):
    base_name, extension = os.path.splitext(filename)
    parts = base_name.split("___")
    return parts[0] + extension

def get_image(img_path: str | pathlib.Path) -> Image.Image:
    """Open ``img_path`` and return it with its EXIF orientation applied.

    ``ImageOps.exif_transpose`` returns a new image (a transposed image or a
    copy), so the file is closed on leaving the ``with`` block. Otherwise,
    multi-frame files (e.g. camera MPO JPEGs) would keep their handle open
    until garbage collection.
    """
    with Image.open(img_path) as source:
        return ImageOps.exif_transpose(source)

In [None]:
data = pd.read_csv("../../data/processed/3 Production/test.csv", index_col=0)

# Document-type codes in the "label" column:
# 0 - governors' reports (Губернаторские отчёты)
# 1 - charter letters, Afanasenkov (Уставные грамоты – Афанасенков)
# 2 - charter letters in jpg, Prosvetov (Уставные грамоты в jpg (Просветов))
# 3 - Pobedonostsev (Победоносцев), annotated by segments

# Split by document type. `.copy()` makes every subset an independent frame,
# so the later `subset["file_name"] = ...` cells modify it without a
# SettingWithCopyWarning.
governors_reports = data[data["label"] == 0].copy()
charter_letters = data[data["label"].isin([1, 2])].copy()
segment_annotation = data[data["label"] == 3].copy()

In [None]:
# Distinct document-type codes present in the test CSV, in order of appearance.
pd.unique(data["label"])

In [None]:
ocr_pipeline = TextRecognizePipeline()

# Character- and word-error-rate metrics loaded through `datasets`.
# NOTE: `datasets.load_metric` is deprecated and was removed in datasets 3.0;
# the `evaluate` library is the replacement if this notebook is upgraded.
cer_metric, wer_metric = (
    load_metric(metric_name, trust_remote_code=True) for metric_name in ("cer", "wer")
)

#### Подсчёт CER/WER для губернаторских отчётов

In [None]:
# Strip the "___<suffix>" part so every row maps to its source image file name.
# `.assign` builds a new frame instead of writing into a filtered slice of
# `data`, which avoids SettingWithCopyWarning; re-running the cell is harmless.
governors_reports = governors_reports.assign(
    file_name=governors_reports["file_name"].map(extract_filename)
)
filenames = governors_reports["file_name"].unique().tolist()

In [None]:
image_root_path = pathlib.Path("../../data/processed/3 Production/text_detector/test/images")

cer, wer = [], []

for filename in tqdm(filenames, total=len(filenames)):
    # Fall back to a lower-case ".jpg" name when the CSV spelling is not on disk.
    file_path = image_root_path / filename
    if not os.path.exists(file_path):
        file_path = image_root_path / filename.replace('.JPG', '.jpg')
    img = get_image(file_path)

    # Page-level strings: recognized regions joined with spaces, and the
    # annotated texts for this file joined the same way.
    pred_text = " ".join(ocr_pipeline.recognize(img))
    label_text = " ".join(get_label_text(governors_reports, filename))

    cer.append(cer_metric.compute(predictions=[pred_text], references=[label_text]))
    wer.append(wer_metric.compute(predictions=[pred_text], references=[label_text]))

# Macro average: every page contributes equally, regardless of its length.
print(f"CER: {np.mean(cer)} | WER: {np.mean(wer)}")

#### Подсчёт CER/WER для уставных грамот

In [None]:
# Strip the "___<suffix>" part so every row maps to its source image file name.
# `.assign` builds a new frame instead of writing into a filtered slice of
# `data`, which avoids SettingWithCopyWarning; re-running the cell is harmless.
charter_letters = charter_letters.assign(
    file_name=charter_letters["file_name"].map(extract_filename)
)
filenames = charter_letters["file_name"].unique().tolist()

In [None]:
image_root_path = pathlib.Path("../../data/processed/3 Production/text_detector/test/images")

cer, wer = [], []

for filename in tqdm(filenames, total=len(filenames)):
    # Fall back to a lower-case ".jpg" name when the CSV spelling is not on disk.
    file_path = image_root_path / filename
    if not os.path.exists(file_path):
        file_path = image_root_path / filename.replace('.JPG', '.jpg')
    img = get_image(file_path)

    # Page-level strings: recognized regions joined with spaces, and the
    # annotated texts for this file joined the same way.
    pred_text = " ".join(ocr_pipeline.recognize(img))
    label_text = " ".join(get_label_text(charter_letters, filename))

    cer.append(cer_metric.compute(predictions=[pred_text], references=[label_text]))
    wer.append(wer_metric.compute(predictions=[pred_text], references=[label_text]))

# Macro average: every page contributes equally, regardless of its length.
print(f"CER: {np.mean(cer)} | WER: {np.mean(wer)}")

#### Подсчёт CER/WER для «Победоносцев» (разметка сегментами)

In [None]:
# Strip the "___<suffix>" part so every row maps to its source image file name.
# `.assign` builds a new frame instead of writing into a filtered slice of
# `data`, which avoids SettingWithCopyWarning; re-running the cell is harmless.
segment_annotation = segment_annotation.assign(
    file_name=segment_annotation["file_name"].map(extract_filename)
)
filenames = segment_annotation["file_name"].unique().tolist()

In [None]:
image_root_path = pathlib.Path("../../data/processed/3 Production/text_detector/test/images")

cer = []
wer = []

for filename in tqdm(filenames, total=len(filenames)):
    file_path = image_root_path / pathlib.Path(filename)
    # Same fallback as the other evaluation cells: the CSV may spell the
    # extension ".JPG" while the file on disk uses ".jpg".
    if not os.path.exists(file_path):
        file_path = image_root_path / pathlib.Path(filename.replace('.JPG', '.jpg'))
    img = get_image(file_path)

    # Page-level strings: recognized regions joined with spaces, and the
    # annotated segments for this file joined the same way. The names
    # file_path / pred_text / label_text are inspected by the cells below.
    pred_text = " ".join(ocr_pipeline.recognize(img))
    label_text = " ".join(get_label_text(segment_annotation, filename))

    cer.append(
        cer_metric.compute(predictions=[pred_text], references=[label_text])
    )
    wer.append(
        wer_metric.compute(predictions=[pred_text], references=[label_text])
    )

# Macro average: every page contributes equally, regardless of its length.
print(f"CER: {np.mean(cer)} | WER: {np.mean(wer)}")

In [None]:
# Debug: image path from the last iteration of the most recently run evaluation loop.
file_path

In [None]:
# Debug: joined ground-truth text from the last iteration of the most recently run evaluation loop.
label_text

In [None]:
# Debug: joined OCR prediction from the last iteration of the most recently run evaluation loop.
pred_text

In [None]:
# Number of rows per document-type label. Grouping by 'label' and then counting
# 'label' again only produced a redundant (label, label) MultiIndex.
data["label"].value_counts().sort_index()

**TODO:** добавить выгрузку по bbox: распознавание + разметка.

In [None]:
# Preview of the Pobedonostsev (label 3) subset; `.head()` avoids dumping the whole frame.
segment_annotation.head()