In [None]:
import pathlib

# import os
import subprocess
import sys

from doctr.datasets.vocabs import VOCABS

# Character set for recognition fine-tuning: doctr's multilingual vocabulary,
# then extra punctuation (two-em and em dashes, inverted ¡/¿, curly quotes,
# prime marks), then doctr's currency symbols.
# NOTE(review): the multilingual set may already contain some of the extra
# characters (e.g. ¡/¿), in which case they appear twice here — verify.
vocab = "".join([VOCABS["multilingual"], "⸺¡¿—‘’“”′″", VOCABS["currency"]])


def finetune_detection(
    train_path,
    val_path,
    model_output_path,
    n_epochs=100,
    batch_size=2,
    doctr_base="../doctr",
):
    """Fine-tune doctr's pretrained ``db_resnet50`` text-detection model.

    Launches doctr's ``references/detection/train_pytorch.py`` reference script
    in a child process and blocks until it exits.

    Args:
        train_path: Training-set directory, forwarded as ``--train_path``.
        val_path: Validation-set directory, forwarded as ``--val_path``.
        model_output_path: Experiment name forwarded as ``--name``. Where the
            checkpoint ends up is decided by the training script (presumably
            derived from this name — TODO confirm for the doctr version used).
        n_epochs: Number of training epochs (``--epochs``).
        batch_size: Training batch size (``--batch_size``).
        doctr_base: Path to the doctr source checkout containing the
            ``references`` directory; a relative path is resolved against the
            current working directory.

    Raises:
        subprocess.CalledProcessError: If the training script exits with a
            non-zero status.
    """
    script_path = pathlib.Path(
        doctr_base, "references", "detection", "train_pytorch.py"
    ).resolve()
    script_args = [
        "--train_path",
        train_path,
        "--val_path",
        val_path,
        "db_resnet50",  # architecture, passed as a positional argument
        "--pretrained",  # start from doctr's pretrained weights
        "--lr",
        ".002",
        "--epochs",
        str(n_epochs),
        "--batch_size",
        str(batch_size),
        "--device",
        "0",  # presumably the CUDA device index — confirm against the script
        "--name",
        model_output_path,
    ]
    # Use the kernel's own interpreter: a bare "python" is looked up on PATH and
    # may belong to an environment without doctr/torch installed.
    command = [sys.executable, str(script_path), *script_args]
    # check=True turns a failed training run into an exception instead of
    # letting later cells silently continue without a model.
    subprocess.run(command, check=True)


def finetune_recognition(
    train_path,
    val_path,
    name,
    n_epochs=100,
    charset=None,
    doctr_base="../doctr",
):
    """Fine-tune doctr's pretrained ``crnn_vgg16_bn`` text-recognition model.

    Launches doctr's ``references/recognition/train_pytorch.py`` reference
    script in a child process and blocks until it exits.

    Args:
        train_path: Training-set directory, forwarded as ``--train_path``.
        val_path: Validation-set directory, forwarded as ``--val_path``.
        name: Experiment name forwarded as ``--name``. Where the checkpoint ends
            up is decided by the training script.
        n_epochs: Number of training epochs (``--epochs``).
        charset: Characters the model may predict, forwarded as
            ``--vocab CUSTOM:<charset>``. Defaults to the module-level ``vocab``.
        doctr_base: Path to the doctr source checkout containing the
            ``references`` directory; a relative path is resolved against the
            current working directory.

    Raises:
        subprocess.CalledProcessError: If the training script exits with a
            non-zero status.
    """
    if charset is None:
        charset = vocab
    script_path = pathlib.Path(
        doctr_base, "references", "recognition", "train_pytorch.py"
    ).resolve()
    script_args = [
        "--train_path",
        train_path,
        "--val_path",
        val_path,
        "crnn_vgg16_bn",  # architecture, passed as a positional argument
        "--pretrained",  # start from doctr's pretrained weights
        "--lr",
        ".001",
        "--epochs",
        str(n_epochs),
        "--name",
        name,
        "--device",
        "0",  # presumably the CUDA device index — confirm against the script
        # NOTE(review): the "CUSTOM:" prefix presumably relies on the training
        # script in this doctr checkout accepting a literal character set —
        # confirm it is supported before relying on it.
        "--vocab",
        "CUSTOM:" + charset,
    ]
    # Use the kernel's own interpreter: a bare "python" is looked up on PATH and
    # may belong to an environment without doctr/torch installed.
    command = [sys.executable, str(script_path), *script_args]
    # check=True turns a failed training run into an exception instead of
    # letting later cells silently continue without a model.
    subprocess.run(command, check=True)

In [None]:
# Location of a doctr checkout. NOTE(review): not referenced by the training
# helpers in this notebook, which build their script paths from "../doctr".
doctr_base = "doctr_package"

# Dataset roots, relative to the notebook's working directory.
detection_train_path, detection_val_path = (
    "ml-training/detection",
    "ml-validation/detection",
)
recognition_train_path, recognition_val_path = (
    "ml-training/recognition",
    "ml-validation/recognition",
)

In [None]:
# Fine-tune the detector first, then the recognizer. Each call blocks until its
# training script exits, so this cell runs as long as both trainings combined.
finetune_detection(
    train_path=detection_train_path,
    val_path=detection_val_path,
    model_output_path="ml-models/detection-model-finetuned",
)
finetune_recognition(
    train_path=recognition_train_path,
    val_path=recognition_val_path,
    name="ml-models/recognition-model-finetuned",
)

In [None]:
# NOTE: the cells below are not working yet and are kept commented out.

In [None]:
# from doctr.models import ocr_predictor, detection_predictor, recognition_predictor
# from doctr.models import db_resnet50, crnn_vgg16_bn
# from importlib import reload
# import doctr.models
# from torch import load as torch_load

# from doctr_package.doctr.models.recognition.crnn.pytorch import CRNN

# reload(doctr.models)


# def create_doctr_model(det_pt_path=None, reco_pt_path=None):
#     det_model = db_resnet50(pretrained=True).to("cuda")

#     if det_pt_path is not None:
#         det_params = torch_load(det_pt_path, map_location="cuda:0")
#         det_model.load_state_dict(det_params)

#     if reco_pt_path is not None:
#         reco_model: CRNN = crnn_vgg16_bn(
#             pretrained=True,
#             pretrained_backbone=True,
#             vocab=vocab,
#         ).to("cuda")
#         reco_params = torch_load(reco_pt_path, map_location="cuda:0")
#         reco_model.load_state_dict(reco_params)
#     else:
#         reco_model = crnn_vgg16_bn(
#             pretrained=True,
#             pretrained_backbone=True,
#         ).to("cuda")

#     full_predictor = ocr_predictor(
#         det_arch=det_model,
#         reco_arch=reco_model,
#         pretrained=True,
#         assume_straight_pages=True,
#         disable_crop_orientation=True,
#     )

#     det_predictor = detection_predictor(
#         arch=det_model,
#         pretrained=True,
#         assume_straight_pages=True,
#     )

#     reco_predictor = recognition_predictor(
#         arch=reco_model,
#         pretrained=True,
#     )

#     # baseline performs optimally with default
#     # the finetuned model seems to work much better when instantiated this way
#     # the ocr_predictor must have some extra default config setting
#     # that hinders the finetuned model
#     if det_pt_path is not None:
#         full_predictor.det_predictor = det_predictor
#     if reco_pt_path is not None:
#         full_predictor.reco_predictor = reco_predictor

#     # this might tighten up boxes a bit
#     # full_predictor.det_predictor.model.postprocessor.unclip_ratio = 1.2
#     # det_predictor.model.postprocessor.unclip_ratio = 1.2

#     return det_predictor, full_predictor


# baseline_det_predictor, baseline_ocr_predictor = create_doctr_model(None)
# finetuned_det_predictor, finetuned_ocr_predictor = create_doctr_model(
#     "training-output/detection-model-finetuned.pt",
#     "training-output/recognition-model-finetuned.pt",
# )

In [None]:
# from doctr_package.doctr.io.reader import DocumentFile
# from os import listdir

# # this is a bit of a hack, but it works
# train_finetune_preds = []
# train_baseline_preds = []
# train_finetune_exports = []
# train_baseline_exports = []
# for f in listdir(detection_train_path + "/images"):
#     print(f)
#     doctr_image = DocumentFile.from_images(detection_train_path + "/images/" + f)

#     baseline_result = baseline_det_predictor(doctr_image)
#     train_baseline_preds.append(baseline_result[0]["words"])

#     baseline_export = baseline_ocr_predictor(doctr_image).export()
#     train_baseline_exports.append(baseline_export)

#     finetuned_result = finetuned_det_predictor(doctr_image)
#     train_finetune_preds.append(finetuned_result[0]["words"])

#     finetuned_export = finetuned_ocr_predictor(doctr_image).export()
#     train_finetune_exports.append(finetuned_export)

# val_finetune_preds = []
# val_baseline_preds = []
# val_finetune_exports = []
# val_baseline_exports = []
# for f in listdir(detection_val_path + "/images"):
#     print(f)
#     doctr_image = DocumentFile.from_images(detection_val_path + "/images/" + f)
#     # print(doctr_image)

#     finetuned_result = finetuned_ocr_predictor.det_predictor(doctr_image)
#     val_finetune_preds.append(finetuned_result[0]["words"])

#     finetuned_export = finetuned_ocr_predictor(doctr_image).export()
#     val_finetune_exports.append(finetuned_export)

#     baseline_result = baseline_ocr_predictor.det_predictor(doctr_image)
#     val_baseline_preds.append(baseline_result[0]["words"])

#     baseline_export = baseline_ocr_predictor(doctr_image).export()
#     train_baseline_exports.append(baseline_export)

In [None]:
# ## TODO: this needs to be updated — the ground-truth `targets`, `val_targets` and `val_fnames` used below are never defined.

# from doctr.utils.metrics import LocalizationConfusion

# lc_baseline_train = LocalizationConfusion(iou_thresh=0.5)
# lc_baseline_val = LocalizationConfusion(iou_thresh=0.5)
# lc_finetune_train = LocalizationConfusion(iou_thresh=0.5)
# lc_finetune_val = LocalizationConfusion(iou_thresh=0.5)

# print("format: recall, precision, mean_iou")
# print("TRAIN")
# for i, fname in enumerate(listdir(detection_train_path + "/images")):
#     print("========")
#     print(fname)
#     lc_baseline_for_file = LocalizationConfusion(iou_thresh=0.5)
#     lc_finetune_for_file = LocalizationConfusion(iou_thresh=0.5)

#     target = targets[i]
#     baseline_pred = train_baseline_preds[i]
#     finetune_pred = train_finetune_preds[i]

#     lc_baseline_for_file.update(target["words"], baseline_pred[:, :4])
#     lc_finetune_for_file.update(target["words"], finetune_pred[:, :4])
#     print("baseline:", lc_baseline_for_file.summary())
#     print("finetune:", lc_finetune_for_file.summary())
#     lc_baseline_train.update(target["words"], baseline_pred[:, :4])
#     lc_finetune_train.update(target["words"], finetune_pred[:, :4])

# print("\nVAL")
# for i, fname in enumerate(val_fnames):
#     print("========")
#     print(fname)
#     lc_baseline_for_file = LocalizationConfusion(iou_thresh=0.5)
#     lc_finetune_for_file = LocalizationConfusion(iou_thresh=0.5)

#     target = val_targets[i]
#     baseline_pred = val_baseline_preds[i]
#     finetune_pred = val_finetune_preds[i]

#     lc_baseline_for_file.update(target["words"], baseline_pred[:, :4])
#     lc_finetune_for_file.update(target["words"], finetune_pred[:, :4])
#     print("baseline:", lc_baseline_for_file.summary())
#     print("finetune:", lc_finetune_for_file.summary())
#     lc_baseline_val.update(target["words"], baseline_pred[:, :4])
#     lc_finetune_val.update(target["words"], finetune_pred[:, :4])


# print("\noverall baseline (train):", lc_baseline_train.summary())
# print("overall finetune (train):", lc_finetune_train.summary())
# print("overall baseline (val):", lc_baseline_val.summary())
# print("overall finetune (val):", lc_finetune_val.summary())

In [None]:
# # TODO: this is broken — `train_fnames` is never defined.
# doctr_image = DocumentFile.from_images(
#     detection_train_path + "/images/" + train_fnames[4]
# )

# baseline_ocr_predictor(doctr_image).show()
# finetuned_ocr_predictor(doctr_image).show()

# # TODO - add comparison of the baseline vs finetuned recognition models