# Быстрый бенчмарк распознавания рукописной кириллицы (TrOCR)



In [1]:

# !pip -q install -U "transformers>=4.40" datasets accelerate jiwer evaluate
# !pip -q install -U --no-cache-dir --force-reinstall "pillow==11.0.0"


In [None]:

import os
# Set the HF_HUB_DISABLE_XET flag for this process before any Hugging Face
# download happens in later cells.
# NOTE(review): judging by its name the flag turns off the Xet transfer backend
# of huggingface_hub — confirm against the huggingface_hub env-var docs.
os.environ["HF_HUB_DISABLE_XET"] = "1"

In [None]:
from datasets import load_dataset
# Load the "Timka28/cyrillic" dataset from the Hugging Face Hub. No split is
# requested here; later cells pick one with ds["test"] / ds[split].
# The captured log below lists 462 train shards of several hundred MB each,
# so the first run downloads a lot of data.
ds = load_dataset("Timka28/cyrillic")
ds  # bare expression: show the loaded object as the cell output


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Resolving data files:   0%|          | 0/462 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/462 [00:00<?, ?files/s]

train-00000-of-00462.parquet:   0%|          | 0.00/523M [00:00<?, ?B/s]

train-00001-of-00462.parquet:   0%|          | 0.00/332M [00:00<?, ?B/s]

train-00002-of-00462.parquet:   0%|          | 0.00/633M [00:00<?, ?B/s]

train-00003-of-00462.parquet:   0%|          | 0.00/499M [00:00<?, ?B/s]

train-00004-of-00462.parquet:   0%|          | 0.00/448M [00:00<?, ?B/s]

train-00005-of-00462.parquet:   0%|          | 0.00/420M [00:00<?, ?B/s]

## Метрики CER/WER

In [None]:
import re
from jiwer import wer
from evaluate import load as load_eval

# Load the "cer" metric from `evaluate` once at module level;
# calc_metrics() reuses this object on every call.
cer_metric = load_eval("cer")

def normalize_text(s: str) -> str:
    """Return `s` trimmed, with every whitespace run collapsed to one space.

    `None` is treated as missing text and mapped to the empty string, so the
    result is always a `str`.
    """
    return "" if s is None else re.sub(r"\s+", " ", s.strip())

def calc_metrics(refs, hyps):
    """Compute CER and WER for parallel lists of references and hypotheses.

    Both lists are passed through `normalize_text` first, so `None` entries
    and whitespace differences do not affect the scores.

    Args:
        refs: iterable of ground-truth strings (or None).
        hyps: iterable of predicted strings (or None), aligned with `refs`.

    Returns:
        Tuple `(cer, wer)` of Python floats.
    """
    clean_refs = list(map(normalize_text, refs))
    clean_hyps = list(map(normalize_text, hyps))
    char_err = cer_metric.compute(references=clean_refs, predictions=clean_hyps)
    word_err = wer(clean_refs, clean_hyps)
    return float(char_err), float(word_err)


## Приведение изображений к RGB (фикс для grayscale `ndim=2`)

In [None]:
import numpy as np
from PIL import Image

# Take the "image" field of the first test sample to inspect what the
# dataset actually stores there.
ex = ds["test"][0]["image"]

# View it as an array and print its Python type, shape, dtype and value range.
arr = np.asarray(ex)
print("type:", type(ex))
print("shape:", arr.shape, "dtype:", arr.dtype, "min:", arr.min(), "max:", arr.max())

In [None]:
def to_rgb(img):
    """Convert a PIL image or array-like into a 3-channel RGB `PIL.Image`.

    PIL images are converted with `convert("RGB")`. Anything else is viewed
    as a NumPy array and must have shape (H, W), (H, W, 1) or (H, W, 3).

    Non-uint8 arrays are cast to float32; if their maximum is <= 1.0 they are
    scaled by 255, then values are clipped to [0, 255] and cast to uint8.

    Raises:
        ValueError: if the array has any other shape.
    """
    if isinstance(img, Image.Image):
        return img.convert("RGB")

    pixels = np.asarray(img)

    # Bring the channel layout to (H, W, 3).
    is_hwc = pixels.ndim == 3
    if pixels.ndim == 2:
        pixels = np.stack((pixels, pixels, pixels), axis=-1)
    elif is_hwc and pixels.shape[-1] == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    elif not (is_hwc and pixels.shape[-1] == 3):
        raise ValueError(f"Unexpected image shape: {pixels.shape}")

    # Bring dtype / value range to uint8 in [0, 255].
    if pixels.dtype != np.uint8:
        as_float = pixels.astype(np.float32)
        if as_float.max() <= 1.0:
            # Values look normalised to [0, 1]; stretch them to [0, 255].
            as_float *= 255.0
        pixels = np.clip(as_float, 0, 255).astype(np.uint8)

    return Image.fromarray(pixels)

In [None]:
import cv2

def crop_to_ink(img, pad=20, min_area=2000):
    """Crop an image to the bounding box of its ink, expanded by `pad` pixels.

    The image is converted to RGB, binarised with an inverted Otsu threshold
    (ink becomes non-zero) and cleaned with a 3x3 morphological opening. The
    bounding box of the remaining non-zero pixels defines the crop.

    Args:
        img: PIL image or array-like accepted by `to_rgb`.
        pad: margin in pixels added on every side, clamped to the image.
        min_area: minimum bounding-box area in pixels; smaller boxes are
            treated as noise and the image is returned uncropped.

    Returns:
        RGB `PIL.Image`: the cropped region, or the full RGB image when no
        ink is found or the ink box is smaller than `min_area`.
    """
    img = to_rgb(img)
    gray = np.array(img.convert("L"))

    # Inverted Otsu threshold: ink pixels -> 255, background -> 0.
    thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

    # Opening (erode, then dilate) with a 3x3 kernel drops isolated specks so
    # that stray noise does not stretch the bounding box.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    thr = cv2.morphologyEx(thr, cv2.MORPH_OPEN, kernel, iterations=1)

    ys, xs = np.where(thr > 0)
    if len(xs) == 0:
        return img  # nothing detected

    # Inclusive pixel indices of the ink bounding box, as plain Python ints.
    x0, x1 = int(xs.min()), int(xs.max())
    y0, y1 = int(ys.min()), int(ys.max())

    # The box is inclusive on both ends, hence the +1 on each side length.
    area = (x1 - x0 + 1) * (y1 - y0 + 1)
    if area < min_area:
        # Too little ink to be real content: keep the full image.
        return img

    width, height = img.size
    left = max(0, x0 - pad)
    top = max(0, y0 - pad)
    # PIL's crop box excludes its right/lower edge, so add 1 to keep the last
    # ink column/row inside the crop.
    right = min(width, x1 + 1 + pad)
    bottom = min(height, y1 + 1 + pad)
    return img.crop((left, top, right, bottom))

In [None]:
def preprocess_image(img, mode="raw"):
    """Prepare an image for recognition according to `mode`.

    The image is always converted to RGB first.

    Args:
        img: PIL image or array-like accepted by `to_rgb`.
        mode: "raw" returns the RGB image unchanged; "crop_ink" additionally
            crops it to the ink bounding box with a 20-pixel margin.

    Raises:
        ValueError: for any other `mode` (the message is the mode itself).
    """
    rgb = to_rgb(img)
    if mode not in ("raw", "crop_ink"):
        raise ValueError(mode)
    return rgb if mode == "raw" else crop_to_ink(rgb, pad=20)

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check of the ink crop on the first test sample.
# Note: `ex` is rebound here to the whole sample row (it held only the image
# in the earlier inspection cell).
ex = ds["test"][0]
raw = to_rgb(ex["image"])
cr  = crop_to_ink(ex["image"], pad=20)
print("raw size:", raw.size, "cropped size:", cr.size)
# One figure per variant: the original page first, then the cropped one.
plt.figure(figsize=(12,5)); plt.imshow(raw); plt.axis("off"); plt.title("RAW"); plt.show()
plt.figure(figsize=(12,5)); plt.imshow(cr);  plt.axis("off"); plt.title("CROP_INK"); plt.show()

In [None]:
def split_lines_projection(pil_img, pad=8, min_line_h=18):
    """Split a page image into text-line crops using a horizontal projection.

    The page is binarised (inverted Otsu, ink is non-zero), median-filtered,
    and the ink is summed per pixel row. Consecutive rows whose sum exceeds a
    threshold form a candidate line; candidates shorter than `min_line_h`
    rows are discarded.

    Args:
        pil_img: PIL image or array-like accepted by `to_rgb`.
        pad: rows added above and below each line, clamped to the page.
        min_line_h: minimum height in rows (before padding) for a run to
            count as a line.

    Returns:
        List of full-width RGB `PIL.Image` crops in top-to-bottom order;
        empty when no run is tall enough.
    """
    # Convert once and reuse for both the grayscale analysis and the crops.
    page_rgb = to_rgb(pil_img)
    gray = np.array(page_rgb.convert("L"))
    H, W = gray.shape

    # 1) Binarise with inversion: ink -> 255, background -> 0.
    thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

    # 2) Remove salt-and-pepper noise.
    thr = cv2.medianBlur(thr, 3)

    # 3) Horizontal projection: the larger the sum, the more ink in that row.
    row_sum = thr.sum(axis=1)

    # A row counts as text when its sum exceeds a quarter of the 70th
    # percentile of all row sums, with a floor of 50.
    t = max(50, int(np.percentile(row_sum, 70) * 0.25))
    has_text = row_sum > t

    # 4) Find runs of consecutive text rows. Padding the mask with False on
    # both sides makes every run produce one rising and one falling edge,
    # including a run that touches the bottom of the page.
    padded = np.concatenate(([False], has_text, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    starts = edges[0::2].tolist()  # first row of each run
    ends = edges[1::2].tolist()    # one past the last row of each run

    # 5) Crop every sufficiently tall run at full page width.
    lines = []
    for y0, y1 in zip(starts, ends):
        if y1 - y0 < min_line_h:
            continue
        top = max(0, y0 - pad)
        bottom = min(H, y1 + pad)
        lines.append(page_rgb.crop((0, top, W, bottom)))
    return lines

In [None]:
import torch, time
from transformers import VisionEncoderDecoderModel, TrOCRProcessor

# Single-entry cache: model_id -> (processor, model, device, fp16).
_TROCR_CACHE = {}


def _load_trocr_cached(model_id):
    """Return `(processor, model, device, fp16)` for `model_id`, reusing the last load."""
    if model_id not in _TROCR_CACHE:
        # Keep a single checkpoint resident: switching to another model drops
        # the reference to the previous one instead of accumulating them.
        _TROCR_CACHE.clear()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        fp16 = (device == "cuda")

        processor = TrOCRProcessor.from_pretrained(model_id, use_fast=True)
        model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device).eval()
        if fp16:
            model = model.half()

        _TROCR_CACHE[model_id] = (processor, model, device, fp16)
    return _TROCR_CACHE[model_id]


def recognize_page_by_lines(model_id, pil_img, batch_size=16, max_new_tokens=64):
    """Recognise a page by splitting it into lines and decoding each line.

    The page is segmented with `split_lines_projection`; the line crops are
    run through the model in batches using greedy decoding (`num_beams=1`,
    `do_sample=False`). On CUDA the model and inputs are used in half
    precision.

    The processor and model for `model_id` are loaded on first use and reused
    by later calls with the same `model_id`, so per-page calls in a loop do
    not reload the checkpoint.

    Args:
        model_id: checkpoint identifier passed to `from_pretrained`.
        pil_img: page image accepted by `split_lines_projection`.
        batch_size: number of line crops per `generate` call.
        max_new_tokens: generation limit per line.

    Returns:
        Non-empty line texts joined by single spaces; "" when segmentation
        finds no lines.
    """
    # Segment first: a page without lines needs no model at all.
    lines = split_lines_projection(pil_img)
    if len(lines) == 0:
        return ""

    processor, model, device, fp16 = _load_trocr_cached(model_id)

    texts = []
    with torch.inference_mode():
        for start in range(0, len(lines), batch_size):
            chunk = lines[start:start + batch_size]
            pv = processor(images=chunk, return_tensors="pt").pixel_values.to(device)
            if fp16:
                # Match the dtype of the half-precision model weights.
                pv = pv.half()

            gen = model.generate(pv, num_beams=1, max_new_tokens=max_new_tokens, do_sample=False)
            preds = processor.batch_decode(gen, skip_special_tokens=True)
            texts.extend(p.strip() for p in preds)

    # Skip lines that decoded to nothing so the join adds no double spaces.
    return " ".join(t for t in texts if t)

def benchmark_linelevel(model_id, max_samples=20, max_new_tokens=64, split="test", batch_size=16):
    """Benchmark one model page by page on the first samples of a split.

    Each page is recognised with `recognize_page_by_lines`; CER and WER are
    then computed over all pages with `calc_metrics`.

    Args:
        model_id: checkpoint identifier forwarded to `recognize_page_by_lines`.
        max_samples: upper bound on the number of pages taken from the start
            of the split.
        max_new_tokens: per-line generation limit.
        split: key of the split in the global `ds` (default "test").
        batch_size: line batch size forwarded to `recognize_page_by_lines`.

    Returns:
        Dict with keys "model", "samples", "CER", "WER", "sec_total" and
        "page_per_sec". The timing covers the whole `recognize_page_by_lines`
        call for every page (so any work that function does per call is
        included) and excludes the metric computation.
    """
    data = ds[split]
    n = min(max_samples, len(data))
    refs, hyps = [], []

    t0 = time.perf_counter()
    for i in range(n):
        # Fetch the row once; indexing data[i] again for each field would
        # load (and decode) the sample a second time.
        ex = data[i]
        hyp = recognize_page_by_lines(
            model_id,
            ex["image"],
            batch_size=batch_size,
            max_new_tokens=max_new_tokens,
        )

        refs.append(ex["text"])
        hyps.append(hyp)

    dt = time.perf_counter() - t0
    cer, w = calc_metrics(refs, hyps)
    return {"model": model_id, "samples": n, "CER": cer, "WER": w, "sec_total": dt, "page_per_sec": n/dt}

In [None]:
# Checkpoints compared in the quick line-level benchmark.
MODELS = [
    "kazars24/trocr-base-handwritten-ru",
    "cyrillic-trocr/trocr-handwritten-cyrillic",
    "akushsky/trocr-large-handwritten-ru",]

import pandas as pd

# Run the benchmark on the first 10 test pages for every model, printing each
# result dict as soon as it is available.
res = []
for m in MODELS:
    r = benchmark_linelevel(m, max_samples=10, max_new_tokens=64)
    print(r)
    res.append(r)

# Cell output: one row per model, best CER first, ties broken by total time.
pd.DataFrame(res).sort_values(["CER","sec_total"], ascending=[True, True])

In [None]:
def benchmark_page_htr_two_models(
    model_id: str,
    split="test",
    max_pages=200,           # regular run: 200 pages
    max_new_tokens_line=64,  # 32-64 is usually enough for a single line
    batch_size_lines=32,
    min_line_h=22,
    log_every_pages=25,
    out_dir="bench_page_htr"
):
    """Page-level HTR benchmark for one model: segment, recognise, score, save.

    Steps:
      1. Load the processor and model (half precision on CUDA).
      2. Segment the first `max_pages` pages of `ds[split]` into line crops
         with `split_lines_projection`.
      3. Recognise all lines in batches with greedy decoding.
      4. Join the line texts back into one hypothesis per page.
      5. Compute CER/WER against the page references with `calc_metrics`.
      6. Write a summary CSV and a CSV with the first 10 reference/prediction
         pairs into `out_dir`.

    Args:
        model_id: checkpoint identifier passed to `from_pretrained`.
        split: key of the split in the global `ds`.
        max_pages: upper bound on the number of pages taken from the start
            of the split.
        max_new_tokens_line: generation limit per line.
        batch_size_lines: number of line crops per `generate` call.
        min_line_h: minimum line height forwarded to `split_lines_projection`.
        log_every_pages: print segmentation progress every this many pages;
            0 or None disables periodic messages (the final one is still
            printed).
        out_dir: directory for the CSV outputs; created if missing.

    Returns:
        Dict with the model id, split, page/line counts, CER, WER, timings
        (`sec_prep`, `sec_infer`, `sec_total`), throughput, device info and
        the run parameters. `sec_total` is segmentation plus inference time;
        model loading is not timed.

    Raises:
        RuntimeError: if segmentation yields no lines at all.
    """
    import os
    # Imported locally so the function does not depend on an earlier cell
    # having bound `pd`; otherwise a missing import would only surface after
    # the whole inference pass, when the CSVs are written.
    import pandas as pd

    os.makedirs(out_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    fp16 = (device == "cuda")

    processor = TrOCRProcessor.from_pretrained(model_id, use_fast=True)
    model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device).eval()
    if fp16:
        model = model.half()

    data = ds[split]
    n_pages = min(max_pages, len(data))

    # 1) Preparation: segment every page into a list of line images.
    page_refs = []
    page_line_imgs = []
    line_counts = []

    t_prep0 = time.perf_counter()
    for i in range(n_pages):
        ex = data[i]
        page_refs.append(ex["text"])
        lines = split_lines_projection(ex["image"], min_line_h=min_line_h)
        page_line_imgs.append(lines)
        line_counts.append(len(lines))

        done = i + 1
        # A falsy interval (0/None) turns periodic logging off instead of
        # failing on the modulo.
        periodic = bool(log_every_pages) and done % log_every_pages == 0
        if periodic or done == n_pages:
            avg_lines = sum(line_counts) / len(line_counts)
            print(f"[prep] {i+1}/{n_pages} pages | avg lines/page: {avg_lines:.1f}")

    t_prep1 = time.perf_counter()

    # Without any lines there is nothing to recognise and the metrics would
    # be meaningless.
    total_lines = sum(line_counts)
    if total_lines == 0:
        raise RuntimeError("Сегментация не нашла ни одной строки. Проверь min_line_h/порог.")

    # Flat list of all lines in page order.
    flat_lines = [im for lines in page_line_imgs for im in lines]

    # 2) Batched inference over the lines.
    line_texts = []
    if device == "cuda":
        # Wait for pending GPU work so it is not charged to the inference timer.
        torch.cuda.synchronize()
    t_inf0 = time.perf_counter()

    with torch.inference_mode():
        for i in range(0, len(flat_lines), batch_size_lines):
            chunk = flat_lines[i:i+batch_size_lines]
            pv = processor(images=chunk, return_tensors="pt").pixel_values.to(device)
            if fp16:
                # Match the dtype of the half-precision model weights.
                pv = pv.half()

            gen = model.generate(
                pv,
                num_beams=1,
                max_new_tokens=max_new_tokens_line,
                do_sample=False,
            )
            preds = processor.batch_decode(gen, skip_special_tokens=True)
            line_texts.extend([p.strip() for p in preds])

    if device == "cuda":
        # Make sure all queued kernels finished before stopping the timer.
        torch.cuda.synchronize()
    t_inf1 = time.perf_counter()

    # 3) Stitch the lines back into one text per page, using the per-page
    # line counts to slice the flat prediction list.
    page_hyps = []
    idx = 0
    for cnt in line_counts:
        txt = " ".join([t for t in line_texts[idx:idx+cnt] if t])
        page_hyps.append(txt)
        idx += cnt

    # 4) Metrics.
    cer, wer_ = calc_metrics(page_refs, page_hyps)

    prep_time = t_prep1 - t_prep0
    inf_time = t_inf1 - t_inf0
    total_time = prep_time + inf_time

    result = {
        "model": model_id,
        "split": split,
        "pages": n_pages,
        "lines_total": total_lines,
        "lines_per_page_avg": float(total_lines / n_pages),
        "CER": float(cer),
        "WER": float(wer_),
        "sec_prep": float(prep_time),
        "sec_infer": float(inf_time),
        "sec_total": float(total_time),
        "pages_per_sec": float(n_pages / total_time),
        "lines_per_sec": float(total_lines / inf_time),
        "device": device,
        "fp16": fp16,
        "params": {
            "batch_size_lines": batch_size_lines,
            "max_new_tokens_line": max_new_tokens_line,
            "min_line_h": min_line_h,
        }
    }

    # "/" in the model id would be read as a directory separator in the path.
    safe = model_id.replace("/", "__")
    pd.DataFrame([result]).to_csv(f"{out_dir}/{safe}__summary.csv", index=False)

    # Page examples (first 10) for manual inspection.
    ex_df = pd.DataFrame({"ref": page_refs[:10], "pred": page_hyps[:10]})
    ex_df.to_csv(f"{out_dir}/{safe}__page_examples_first10.csv", index=False)

    return result

In [None]:
# Checkpoints compared in the full page-level benchmark.
MODELS_2 = [
    "cyrillic-trocr/trocr-handwritten-cyrillic",
    "kazars24/trocr-base-handwritten-ru",]
# Run the page-level benchmark for every model, printing each result dict as
# soon as it is available.
results = []
for m in MODELS_2:
    r = benchmark_page_htr_two_models(
        m,
        split="test",
        max_pages=200,           # regular run size
        max_new_tokens_line=64,  # per-line generation limit
        batch_size_lines=32,     # drop to 16 on out-of-memory
        min_line_h=22,
        out_dir="bench_page_htr_2models"
    )
    print(r)
    results.append(r)

# Cell output: one row per model, best CER first, ties broken by total time.
pd.DataFrame(results).sort_values(["CER", "sec_total"], ascending=[True, True])