In [1]:
from __future__ import annotations

import os

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageOps
from datasets import load_metric
from tqdm.notebook import tqdm
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor
)
from ultralytics import YOLO

In [2]:
def seed_everything(seed_value):
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

In [3]:
class TextRecognizePipeline:
    
    def __init__(self):
        self.ocr_processor = TrOCRProcessor.from_pretrained("raxtemur/trocr-base-ru")
        # self.ocr_model = VisionEncoderDecoderModel.from_pretrained("../../models/text_recognizer/trocr_ru_pretrain_3epoch/", local_files_only=True).to(device)

        # self.ocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
        self.ocr_model = VisionEncoderDecoderModel.from_pretrained("../../models/text_recognizer/checkpoint-1152/", local_files_only=True).to(device)
        self.ocr_model.eval()
        
        # self.detection_model = YOLO("../../models/new_text_detector/best.pt").to(device)
        self.detection_model = YOLO("/media/admin01/storage1/vadim/Historical-docs-OCR/models/text_detector/best_1024.pt").to(device)
        self.iou_threshold = 0.7
        
        # metrics
        self.iou_list = []
        self.cer_list = []
        self.wer_list = []
        
        # Set special tokens used for creating the decoder_input_ids from the labels.
        self.ocr_model.config.decoder_start_token_id = self.ocr_processor.tokenizer.cls_token_id
        self.ocr_model.config.pad_token_id = self.ocr_processor.tokenizer.pad_token_id
        # Set Correct vocab size.
        self.ocr_model.config.vocab_size = self.ocr_model.config.decoder.vocab_size
        self.ocr_model.config.eos_token_id = self.ocr_processor.tokenizer.sep_token_id
        
        self.ocr_model.config.max_length = 64
        self.ocr_model.config.early_stopping = True
        self.ocr_model.config.no_repeat_ngram_size = 3
        self.ocr_model.config.length_penalty = 2.0
        self.ocr_model.config.num_beams = 4
    
    def get_detections_and_crop_boxes(self, img: Image) -> list[Image]:
        
        def sort_bbox_by_y(bbox_list):
            sorted_bbox = sorted(bbox_list, key=lambda bbox: (bbox[1], bbox[0]))  # –°–æ—Ä—Ç–∏—Ä–æ–≤–∫–∞ –ø–æ –∫–æ–æ—Ä–¥–∏–Ω–∞—Ç–µ y, –∑–∞—Ç–µ–º –ø–æ x
            return sorted_bbox
        
        result = []
        for predict, image in zip(self.detection_model.predict([img], verbose=False), [img]):
            bboxes = predict.boxes.xyxy.cpu().tolist()
            sorted_bboxes = sort_bbox_by_y(bboxes)
            for box in sorted_bboxes:
                cropped_image = image.crop(box)
                result.append(cropped_image.convert("RGB"))
        return result
    
    def get_ocr_predictions(self, img_list: list[Image]) -> list[str]:
        with torch.no_grad():
            pixel_values = self.ocr_processor(img_list, return_tensors="pt").pixel_values.to(device)
            generated_ids = self.ocr_model.generate(pixel_values)
            generated_text = self.ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)
            
        return generated_text
    
    def recognize(self, img_list: list[Image]) -> list[str]:
        cropped_images = self.get_detections_and_crop_boxes(img_list)
        recognized_text = self.get_ocr_predictions(cropped_images)
        return recognized_text

In [4]:
import pathlib

def get_rand_image():
    path = pathlib.Path("../../data/processed/3 Production/text_detector/test/images")
    img_path = np.random.choice(list(path.iterdir()))
    img = Image.open(img_path)
    img = ImageOps.exif_transpose(img)
    return img, img_path

def get_label_text(data: pd.DataFrame, filename: str) -> list[str]:
    return data[data["file_name"].str.contains(filename)]["text"].to_list()

def extract_filename(filename):
    base_name, extension = os.path.splitext(filename)
    parts = base_name.split("___")
    return parts[0] + extension

def get_image(img_path: str | pathlib.Path) -> Image:
    image = Image.open(img_path)
    image = ImageOps.exif_transpose(image)
    return image

In [5]:
data = pd.read_csv("../../data/processed/3 Production/test.csv", index_col=0)
# data['file_name'] = data['file_name'].apply(lambda x: x.replace('.JPG', '.jpg'))

# 0 - –ì—É–±–µ—Ä–Ω–∞—Ç–æ—Ä—Å–∫–∏–µ –æ—Ç—á—ë—Ç—ã
# 1 - –£—Å—Ç–∞–≤–Ω—ã–µ –≥—Ä–∞–º–æ—Ç—ã ‚Äì –ê—Ñ–∞–Ω–∞—Å–µ–Ω–∫–æ–≤
# 2 - –£—Å—Ç–∞–≤–Ω—ã–µ –≥—Ä–∞–º–æ—Ç—ã –≤ jpg (–ü—Ä–æ—Å–≤–µ—Ç–æ–≤)
# 3 - –ü–æ–±–µ–¥–æ–Ω–æ—Å—Ü–µ–≤

# –æ—Ç–¥–µ–ª—è–µ–º –≥—É–±–µ—Ä–Ω–∞—Ç–æ—Ä—Å–∫–∏–µ –æ—Ç—á—ë—Ç—ã –∏ —É—Å—Ç–∞–≤–Ω—ã–µ –≥—Ä–∞–º–æ—Ç—ã
governors_reports = data[data["label"] == 0]
charter_letters = data[(data["label"] == 1) | (data["label"] == 2)]
segment_annotation = data[data["label"] == 3]

In [6]:
data['label'].unique()

array([0, 1, 2, 3])

In [7]:
ocr_pipeline = TextRecognizePipeline()

cer_metric = load_metric("cer", trust_remote_code=True)
wer_metric = load_metric("wer", trust_remote_code=True)

  cer_metric = load_metric("cer", trust_remote_code=True)


#### –ü–æ–¥—Å—á—ë—Ç CER/WER –¥–ª—è –≥—É–±–µ—Ä–Ω–∞—Ç–æ—Ä—Å–∫–∏—Ö –æ—Ç—á—ë—Ç–æ–≤

In [8]:
governors_reports["file_name"] = governors_reports["file_name"].apply(extract_filename)
filenames = list(governors_reports.file_name.unique())

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  governors_reports["file_name"] = governors_reports["file_name"].apply(extract_filename)


In [9]:
image_root_path = pathlib.Path("../../data/processed/3 Production/text_detector/test/images")

cer = []
wer = []

for filename in tqdm(filenames, total=len(filenames)):
    file_path = image_root_path / pathlib.Path(filename)
    
    if not os.path.exists(file_path):
        file_path = image_root_path / pathlib.Path(filename.replace('.JPG', '.jpg'))
        
    img = get_image(file_path)
        
    pred_text = ocr_pipeline.recognize(img)
    pred_text = " ".join(pred_text)
    
    label_text = get_label_text(governors_reports, filename)
    label_text = " ".join(label_text)
    
    cer.append(
        cer_metric.compute(predictions=[pred_text], 
                           references=[label_text])
    )
    
    wer.append(
        wer_metric.compute(predictions=[pred_text], 
                           references=[label_text])
    )

print(f"CER: {np.mean(cer)} | WER: {np.mean(wer)}")

  0%|          | 0/298 [00:00<?, ?it/s]

CER: 0.08646262894234258 | WER: 0.2559889583773403


#### –ü–æ–¥—Å—á—ë—Ç CER/WER –¥–ª—è –æ—Ç—á—ë—Ç–Ω—ã—Ö –≥—Ä–∞–º–æ—Ç

In [10]:
charter_letters["file_name"] = charter_letters["file_name"].apply(extract_filename)
filenames = list(charter_letters.file_name.unique())

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  charter_letters["file_name"] = charter_letters["file_name"].apply(extract_filename)


In [11]:
image_root_path = pathlib.Path("../../data/processed/3 Production/text_detector/test/images")

cer = []
wer = []

for filename in tqdm(filenames, total=len(filenames)):
    file_path = image_root_path / pathlib.Path(filename)
    
    if not os.path.exists(file_path):
        file_path = image_root_path / pathlib.Path(filename.replace('.JPG', '.jpg'))
    
    img = get_image(file_path)
        
    pred_text = ocr_pipeline.recognize(img)
    pred_text = " ".join(pred_text)
    
    label_text = get_label_text(charter_letters, filename)
    label_text = " ".join(label_text)
    
    cer.append(
        cer_metric.compute(predictions=[pred_text], references=[label_text])
    )
    
    wer.append(
        wer_metric.compute(predictions=[pred_text], references=[label_text])
    )
    
print(f"CER: {np.mean(cer)} | WER: {np.mean(wer)}")

  0%|          | 0/59 [00:00<?, ?it/s]

CER: 0.12972285992238589 | WER: 0.37622438091415655


–ü–æ–¥—Å—á—ë—Ç CER/WER –¥–ª—è '–ü–æ–±–µ–¥–æ–Ω–æ—Å—Ü–µ–≤' (—Ä–µ–∑–º–µ—Ç–∫–∞ —Å–µ–≥–º–µ–Ω—Ç–∞–º–∏)

In [12]:
segment_annotation["file_name"] = segment_annotation["file_name"].apply(extract_filename)
filenames = list(segment_annotation.file_name.unique())

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  segment_annotation["file_name"] = segment_annotation["file_name"].apply(extract_filename)


In [13]:
image_root_path = pathlib.Path("../../data/processed/3 Production/text_detector/test/images")

cer = []
wer = []

for filename in tqdm(filenames, total=len(filenames)):
    file_path = image_root_path / pathlib.Path(filename)
    img = get_image(file_path)
    
    pred_text = ocr_pipeline.recognize(img)
    pred_text = " ".join(pred_text)
    
    label_text = get_label_text(segment_annotation, filename)
    label_text = " ".join(label_text)
    
    cer.append(
        cer_metric.compute(predictions=[pred_text], references=[label_text])
    )
    
    wer.append(
        wer_metric.compute(predictions=[pred_text], references=[label_text])
    )

print(f"CER: {np.mean(cer)} | WER: {np.mean(wer)}")

  0%|          | 0/9 [00:00<?, ?it/s]

CER: 0.4731776838193206 | WER: 0.7276771050396693


In [23]:
file_path

PosixPath('../../data/processed/3 Production/text_detector/test/images/c6d63ae5-15.png')

In [21]:
label_text

'1866. –í—ä –î–µ–∫–∞–±—Ä–µ, 2 —á–∏—Å–ª–∞, –Ω–∞—á–∞–ª—ä, –ø–æ –∂–µ–ª–∞–Ωi—é –ò–º–ø–µ—Ä–∞—Ç—Ä–∏—Ü—ã, –∑–∞–Ω—è—Çi—è —Å—ä —Ü–µ—Å–∞—Ä–µ–≤–Ω–æ–π —Ä—É—Å–∫–æ—é –∏—Å—Ç–æ—Äi–µ–π –ü—Ä–µ–æ–±–ª–∞–¥–∞–Ωi–µ –ì—Ä –ü. –ê. –®—É–≤–∞–ª–æ–≤–∞. –ì—Ä–∞—Ñ—ä –ü–∞–ª–µ–Ω—ä ‚Äì –º–∏–Ω–∏—Å—Ç—Ä—ä —é—Å—Ç–∏—Üi. 1867. –ê–ø—Ä–µ–ª—å. –ü–∞—Å—Ö–∞ –≤—ä –ú–æ—Å–∫–≤–µ, —Å—ä –ö–∞—Ç–µ–∏. 6 I—é–Ω—è. –ï–¥–µ–º—ä –Ω–∞ –ª–µ—Ç–æ –≤—ä –ì–æ–º–µ–ª—å, —Å—ä –°–æ—Ñ—å–µ–π –ù–∏–∫–æ–Ω–æ—Ä–æ–≤–Ω–æ–π, —Å—ä –°–æ–Ω–∏—á–∫–æ–π –∏ –í–æ–ª–æ–¥–µ–π. 17 –∞–≤–≥. —É–µ–∑–∂–∞–µ–º—ä –æ–±—Ä–∞—Ç–Ω–æ 28 –∞–≤–≥—É—Å—Ç–∞ ‚Äì –¥–æ 3 —Å–µ–Ω—Ç. —è –≤—ä –ú–æ—Å–∫–≤–µ, –æ–¥–∏–Ω. –Ω–æ –≤—Å–∫–æ—Ä–µ —É–µ—Ö–∞–ª—ä —Ç—É–¥–∞ —Å–Ω–æ–≤–∞, —Å—ä –ö–∞—Ç–µ–π. + 6 –°–µ–Ω—Ç—è–±—Ä—è —Å–∫–æ–Ω—á–∞–ª–∞—Å—å –º–∏–ª–∞—è –º–∞–º–µ–Ω—å–∫–∞. –ü–æ—Ö–æ—Ä–æ–Ω–∏–ª–∏ –µ–µ 9 —á–∏—Å–ª–∞ –Ω–∞ –í–∞–≥–∞–Ω—å–∫–æ–≤–µ! 15 –°–µ–Ω—Ç. —Å—Ç—Ä–∞—à–Ω–∞—è —Å–º–µ—Ä—Ç—å –ß–∏–≤–∏–ª–µ–≤–∞ –≤—ä –¶. –°–µ–ª–µ. –ó–∞–Ω—è—Çi—è —Å—ä –í. –ö. –í–ª–∞–¥–∏–ºi—Ä–æ–º—ä –ê–ª–µ–∫—Å–∞–Ω–¥—Ä–æ–≤–∏—á–µ–º—ä. 22 –æ–∫—Ç. —Å–≤–∞–¥—å–±–∞ –ø–ª

In [20]:
pred_text

'–í—ä –¥–µ–∫–∞–±—Ä—å, 2 —á–∏—Å–ª–∞, –Ω–∏—á–∞—Å—Ç—å, –ø–æ–∂–µ–ª–∞–Ωi—é –ò–º–ø–µ—Ä–∞—Ç—Ä–∏—Ü—ã, –∑–∞ —Å—ä —Ü–µ—Å–∞—Ä–µ–≤–Ω–æ–π —Ä—É—Å—Å–∫–æ—é –∏—Å—Ç–æ—Ä—á–µ–π –ü—Ä–µ–æ–±–ª—é–¥–∞—Ç–µ –¥–µ –ü. –ê. –®—É–≤–∞–ª–æ–≤–∞. –ü–æ—Å—Ç—ä –ü–æ–ª—è–Ω—ä ‚Äì –ú–∏–Ω–∏—Å—Ç—Ä—ä —é—Å—Ç–∏—Ü—ã. –ê–ø—Ä–µ–ª—å. –ü–∞–Ω–∞ –≤—ä –ú–æ—Å–∫–≤–µ, —Å—ä–∫–∞—Ç–µ—é 6 I—é–Ω—è: –í–¥–µ–Ω—å –Ω–∞ –ø–æ—ç—Ç–æ–≤—ä –ì–∞–Ω—Å–æ–Ω—ä, —Å—ä —Å–∞–∂–µ–π –ù–∏–∫–æ–Ω–æ—Ä–æ–≤–Ω–æ–π —Å—ä —Å–æ–Ω–∏—á–∫–æ–π –∏ –ï–≥–æ–ª–æ–≤–µ–π. 17 –∞–≤–≥. —É–µ–∑–∂–∞–µ–º—ä –æ–±—Ä–æ—Ç–∏–ª–∞ 28 –∞–≤–≥—É—Å—Ç–∞ ‚Äì –¥–æ 3 –°–µ–Ω—Ç. —è –≤—ä –ú–æ—Å–∫–≤–µ, –æ–¥–∏–Ω—ä, –ù–∞ –≤—Å–∫–æ—Ä–µ —É–µ—Ö–∞–ª—ä —Ç—É–¥–∞ —Å–Ω–æ–≤–∞, —Å—ä –ö–∞—Ç–µ ‚Ä† 6 –°–µ–Ω—Ç—è–±—Ä—è —Å–∫–æ–Ω—á–∞–ª–∞—Å—å –º–∏–ª–∞—è –º–æ–Ω–µ–Ω—å–∫–∞ –ü–æ–∫–æ—Ä–Ω–∏–∫–∏ –µ–µ 9 —á–∏—Å–ª–æ –Ω–∞ –í–∞–≥–∞–Ω—å–∫–æ–≤–µ. 15 –°–µ–Ω—Ç. —Å—Ç—Ä–æ–µ–Ω–Ω–∞—è –°–º–µ—Ä—Ç—å –ß–∏–≤–∏–∫–µ–µ–≤—ä –≤—ä —Ü. –°–∏–ª–µ—Ç –ó–∞–Ω—è—Çi—è —Å—ä –í. –ö. –í–ª–∞–¥–∏–º–∏—Ä–∞ –ê–ª–µ–∫—Å–∞–Ω–¥—Ä–æ–≤–∏—á–µ–º—ä. 22 –æ–∫—Ç. —Å–≤–∞–¥—å–±–∞ –ø–ª–µ—Ç—è–Ω–∏–Ω—Ü—ã.

In [26]:
data.groupby('label')['label'].value_counts()

label
0    6101
1     720
2     595
3     326
Name: count, dtype: int64

#TODO: –î–æ–±–∞–≤–∏—Ç—å –≤—ã–≥—Ä—É–∑–∫—É –ø–æ bbox —Ä–∞—Å–ø–æ–∑–Ω–∞–≤–∞–Ω–∏–µ + —Ä–∞–∑–º–µ—Ç–∫–∞

In [27]:
segment_annotation

Unnamed: 0,file_name,text,label
0,909ccbb0-18.png,–ù–∞—á–∞–ª–æ –æ–±—â–µ—Å—Ç–≤–∞ —É –í.–ö. –ö–æ–Ω—Å—Ç–∞–Ω—Ç–∏–Ω–∞. –û—Ç–¥–µ–ª—ä –û–±—â...,3
1,909ccbb0-18.png,–í–∞–ª—É–µ–≤—ä ‚Äì –ú-—Ä—ä –ì–æ—Å—É–¥. –∏–º—É—â–µ—Å—Ç–≤—ä.,3
2,909ccbb0-18.png,20 –º–∞—è. –ü–æ–µ–∑–¥–∫–∞ —Å—ä –ö–∞—Ç–µ–π –∏ –°–æ–Ω–∏—á–∫–æ–π —á–µ—Ä–µ–∑—ä –ú–æ—Å–∫–≤—É,3
3,909ccbb0-18.png,–≤—ä –°–º–æ–ª–µ–Ω—Å–∫—ä. —É –∞. –≤. —à–µ–≤–∞–Ω–¥–∏–Ω–æ–π –∏ —É –îi–æ–¥–æ—Ä–∞,3
4,909ccbb0-18.png,–≤—ä –ê–ª–µ–∫—Å–∞–Ω–¥—Ä–æ–≤—Å–∫–æ–º—ä. –≤–µ—Ä–Ω—É–ª–∏—Å—å 1 I—é–Ω—è.,3
...,...,...,...
321,c6d63ae5-15.png,"–í–∞—Ä—à–∞–≤—É –∏ –ë–µ—Ä–ª–∏–Ω—ä, –∏ –ü–∞—Ä–∏–∂—ä –∏ –õ–æ–Ω–¥–æ–Ω—ä,",3
322,c6d63ae5-15.png,–Ω–∞ –æ-–≤—ä –í–∞–π—Ç—ä. ‚Äì –®–µ–Ω–∫–ª–∏–Ω—ä. –ù–∞ –æ–±—Ä–∞—Ç–Ω–æ–º—ä,3
323,c6d63ae5-15.png,–ü—É—Ç–∏ —á–µ—Ä–µ–∑—ä –õ–æ–º–∂—É ‚Äì –≤–æ–∑–≤—Ä–∞—â–∞–µ–º—Å—è,3
324,c6d63ae5-15.png,1 –°–µ–Ω—Ç—è–±—Ä—è.,3
