In [None]:
import logging
import os
import pathlib
from pathlib import Path
from sys import stdout

import torch
from torch import load as torch_load

from doctr.datasets.vocabs import VOCABS
from doctr.models import (
    crnn_vgg16_bn,
    db_resnet50,
    detection_predictor,
    ocr_predictor,
    recognition_predictor,
)
from pd_book_tools.pgdp.pgdp_results import PGDPExport

from data_labeler.ipynb_labeler import IpynbLabeler

In [None]:
class ThreadFilter:
    """Logging filter that accepts only records emitted by one thread.

    An instance can be handed to ``Logger.addFilter`` / ``Handler.addFilter``;
    the logging machinery only requires an object with a ``filter`` method.
    """

    def __init__(self, id: int) -> None:
        """Remember the thread identifier that records must match.

        Args:
            id: Thread identifier to compare against ``LogRecord.thread``.
        """
        self.id = id

    def filter(self, record: logging.LogRecord) -> bool:
        """Tell whether ``record`` was created by the configured thread.

        Args:
            record: The log record under consideration.

        Returns:
            The result of comparing ``record.thread`` with ``self.id``.
        """
        emitting_thread = record.thread
        return emitting_thread == self.id


def reset_logger(name=None, level=None):
    """Return a logger after detaching and closing every handler it owns.

    Re-running this notebook cell used to drop old handlers with
    ``handlers.clear()``, which never closes them; the ``FileHandler`` from
    the previous run therefore kept its file descriptor open. Removing and
    closing each handler makes the cell safe to execute repeatedly.

    Args:
        name: Logger name as accepted by ``logging.getLogger``; ``None``
            selects the root logger.
        level: Level to apply after the handlers are removed. ``None`` leaves
            the logger's current level untouched.

    Returns:
        The ``logging.Logger`` instance, now without any handlers of its own.
    """
    target = logging.getLogger(name)
    # Iterate over a copy: removeHandler() mutates target.handlers.
    for stale_handler in list(target.handlers):
        target.removeHandler(stale_handler)
        # Closing is what releases a FileHandler's open stream. A handler that
        # was attached to several loggers may be closed more than once.
        stale_handler.close()
    if level is not None:
        target.setLevel(level)
    return target


# One record layout shared by every handler configured in this cell.
formatter = logging.Formatter(
    "%(asctime)s-%(name)s-%(levelname)s-%(filename)s-%(lineno)s-%(funcName)s | %(message)s"
)

# Console output is limited to CRITICAL records.
sysout_handler = logging.StreamHandler(stdout)
sysout_handler.setLevel(logging.CRITICAL)
sysout_handler.setFormatter(formatter)

# Per-package loggers: start from a clean slate and pin their levels.
pd_book_tools_logger: logging.Logger = reset_logger("pd_book_tools", logging.DEBUG)
doctr_logger: logging.Logger = reset_logger("doctr", logging.ERROR)
matplotlib_logger: logging.Logger = reset_logger("matplotlib", logging.ERROR)
ipynb_labeler_logger: logging.Logger = reset_logger("data_labeler", logging.DEBUG)

# Root logger: only the console handler is attached here.
logger = reset_logger(None, logging.DEBUG)
logger.addHandler(sysout_handler)


# File log (truncated on every run because of mode="w"); only the
# pd_book_tools and data_labeler loggers write to it.
logfile = pathlib.Path("all-logs.log")
log_file_handler = logging.FileHandler(filename=logfile, mode="w", encoding="utf-8")
log_file_handler.setFormatter(formatter)
log_file_handler.setLevel(logging.DEBUG)
pd_book_tools_logger.addHandler(log_file_handler)
ipynb_labeler_logger.addHandler(log_file_handler)


# logfile = pathlib.Path("ipynb_labeler.log")
# ipynb_log_file_handler = logging.FileHandler(
#     filename=logfile, mode="w", encoding="utf-8"
# )
# # ipynb_formatter = logging.Formatter("%(levelname)s-%(funcName)s-%(message)s")
# ipynb_log_file_handler.setFormatter(formatter)
# ipynb_log_file_handler.setLevel(logging.DEBUG)
# ipynb_labeler_logger.addHandler(ipynb_log_file_handler)

In [None]:
# Fine-tuned checkpoint locations. They are defined once so the existence
# check and the loading code below cannot drift apart.
finetuned_detection = "ml-models/detection-model-finetuned.pt"
finetuned_recognition = "ml-models/recognition-model-finetuned.pt"

# Stays None unless both checkpoints are found and loaded.
full_predictor = None

if os.path.exists(finetuned_detection) and os.path.exists(finetuned_recognition):
    # Prefer the first CUDA device when available, otherwise run on CPU.
    device, device_nbr = (
        ("cuda", "cuda:0") if torch.cuda.is_available() else ("cpu", "cpu")
    )
    # NOTE(review): `logger` is the root logger; if its only handler is limited
    # to CRITICAL, this INFO record is not shown anywhere — confirm.
    logger.info(f"Using {device} for OCR")

    # NOTE(review): torch.load unpickles the checkpoint. Only load files from a
    # trusted source, or pass weights_only=True if the installed torch
    # version supports it.
    det_model = db_resnet50(pretrained=True).to(device)
    det_params = torch_load(finetuned_detection, map_location=device_nbr)
    det_model.load_state_dict(det_params)

    # Recognition vocabulary: the multilingual and currency vocabs plus a few
    # extra punctuation marks, de-duplicated and sorted.
    vocab = "".join(
        sorted(
            dict.fromkeys(VOCABS["multilingual"] + "⸺¡¿—‘’“”′″" + VOCABS["currency"])
        )
    )

    reco_model = crnn_vgg16_bn(
        pretrained=True,
        pretrained_backbone=True,
        vocab=vocab,  # model was fine-tuned on multilingual data with some additional unicode characters
    ).to(device)
    reco_params = torch_load(finetuned_recognition, map_location=device_nbr)
    reco_model.load_state_dict(reco_params)

    full_predictor = ocr_predictor(
        det_arch=det_model,
        reco_arch=reco_model,
        pretrained=True,
        assume_straight_pages=True,
        disable_crop_orientation=True,
    )

    det_predictor = detection_predictor(
        arch=det_model,
        pretrained=True,
        assume_straight_pages=True,
    )

    reco_predictor = recognition_predictor(
        arch=reco_model,
        pretrained=True,
    )

    # Replace the sub-predictors of the combined predictor with the ones built
    # explicitly above.
    full_predictor.det_predictor = det_predictor
    full_predictor.reco_predictor = reco_predictor
else:
    # Make the fallback explicit instead of silently leaving the predictor unset.
    logger.warning(
        f"Fine-tuned models not found ({finetuned_detection}, "
        f"{finetuned_recognition}); full_predictor is left as None"
    )

In [None]:
# Available projects (title -> project id); set `project_id` to one of these:
#   A history of the american people  -> projectID629292e7559a8
#   Chile and the Nitrate Fields      -> projectID63ac684a641d4
#   From magic to science             -> projectID6737b15d33ff3
#   Credulities past and present      -> projectID63ac6757567bd
#   French furniture and decoration   -> projectID66c62fca99a93
#       (has sidenotes and footnotes)
#   The book of filial duty           -> projectID67658de495d0c
project_id = "projectID63ac684a641d4"

# Exported page data for the selected project.
source_file = f"source-pgdp-data/output/{project_id}/pages.json"
pgdp_export = PGDPExport.from_json_file(source_file)

i = IpynbLabeler(
    # Inputs: the exported pages, the OCR predictor (None when the fine-tuned
    # models were not loaded) and the page index to start from.
    pgdp_export=pgdp_export,
    doctr_predictor=full_predictor,
    start_page_idx=35,
    # Output directories.
    labeled_ocr_path=Path("./matched-ocr"),
    training_set_output_path=Path("./ml-training"),
    validation_set_output_path=Path("./ml-validation"),
    # Monospace font, given by name and by file path.
    monospace_font_name="DPSansMono",
    monospace_font_path=Path("./DPSansMono.ttf"),
)

In [None]:
# —