In [None]:
import os
import sys
# Resolve the project root (two directories above the current working directory)
# so the project-local packages imported below can be found.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
# Guard the append: re-running this cell must not pile up duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers import TrOCRProcessor
import mlflow

from utils_test import evaluate_model

from trocr.utils.utils import CER_SCORE, inference
# from OCR_VQA.data_preparation import VQAProcessor
from custom_dataset.data_preparation import CustomDataProcessor

In [None]:
# Point MLflow at the tracking server, then select the experiment used by this notebook.
mlflow.set_tracking_uri("http://localhost:5000")

experiment_name = 'trocr_train'
mlflow.set_experiment(experiment_name)

# 1. Get model and dataset

In [None]:
# Restore run_id and model_name, which an upstream notebook saved with %store

%store -r run_id
%store -r model_name

In [None]:
# Build the MLflow model URI in the form runs:/<run_id>/<model_name>.
# It is stored under its own name instead of overwriting `model_name`:
# overwriting made the cell non-idempotent, because a second run would nest
# the prefix (runs:/<id>/runs:/<id>/<name>) and the load below would fail.
model_uri = f'runs:/{run_id}/{model_name}'

# Load the logged transformers pipeline and keep a handle on its underlying model
pipeline = mlflow.transformers.load_model(model_uri)
model = pipeline.model

In [None]:
# TrOCRProcessor bundles the image processor and the tokenizer in a single object
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
data_processor = CustomDataProcessor(processor)

# Name of the dataset handed to the data processor
dataset_name = 'ocr-dataset'
# For a local dataset, point to its full path instead:
# dataset_name = os.path.join(project_root, 'custom_dataset', 'data', dataset_name)

# Options for the split; switch dataset_type to 'local' to use a locally stored dataset
loader_options = dict(
    dataset_type='S3',
    train_frac=0.95,
    val_frac=0.025,
    dataset_name=dataset_name,
    batch_size=16,
)
train_dataset, val_dataset, test_dataset, train_size = data_processor(**loader_options)

# 2. Evaluate model on train, validation and test data

In [None]:
# CER on the training split; the first two return values are discarded.
_, _, train_cer_value = evaluate_model(
    model,
    processor,
    train_dataset.dataset.indeces,
    CER_SCORE,
)
train_cer_value

In [None]:
# CER on the validation split; the first two return values are discarded.
_, _, val_cer_value = evaluate_model(
    model,
    processor,
    val_dataset.dataset.indeces,
    CER_SCORE,
)
val_cer_value

In [None]:
# CER on the test split; the first two return values are discarded.
_, _, test_cer_value = evaluate_model(
    model,
    processor,
    test_dataset.dataset.indeces,
    CER_SCORE,
)
test_cer_value

# 3. Inference

In [None]:
# Display the path under custom_dataset/data for this dataset name (the value is shown, not stored).
# NOTE(review): presumably the location used by the 'local' dataset option — confirm.
os.path.join(project_root, 'custom_dataset', 'data', dataset_name)

__1. Inference on images from train, val and test datasets__

In [None]:
# Take item 5 of the training set; `text` is presumably its reference label (not used below).
sample = train_dataset[5]
img, text = sample

# Run inference on the sample image, print the generated text and display the image.
img, text_generated = inference(img, model, processor)
print(text_generated)
img

In [None]:
# Take item 9 of the validation set; `text` is presumably its reference label (not used below).
sample = val_dataset[9]
img, text = sample

# Run inference on the sample image, print the generated text and display the image.
img, text_generated = inference(img, model, processor)
print(text_generated)
img

In [None]:
# Take item 4 of the test set; `text` is presumably its reference label (not used below).
sample = test_dataset[4]
img, text = sample

# Run inference on the sample image, print the generated text and display the image.
img, text_generated = inference(img, model, processor)
print(text_generated)
img

__2. Inference on new images__

In [None]:
# Folder under the project root that holds the extra images used for inference below
image_fold = os.path.join(project_root, 'test_images')

In [None]:
# Run inference on an image file given by path; print the generated text and display the returned image.
image_path = f'{image_fold}/test_screen.png'
img, text_generated = inference(image_path, model, processor)
print(text_generated)
img

In [None]:
# Run inference on an image file given by path; print the generated text and display the returned image.
image_path = f'{image_fold}/one_channel_image.jpg'
img, text_generated = inference(image_path, model, processor)
print(text_generated)
img

In [None]:
# Run inference on an image file given by path; print the generated text and display the returned image.
image_path = f'{image_fold}/test_screen_2.png'
img, text_generated = inference(image_path, model, processor)
print(text_generated)
img

In [None]:
# Run inference on an image file given by path; print the generated text and display the returned image.
image_path = f'{image_fold}/a.png'
img, text_generated = inference(image_path, model, processor)
print(text_generated)
img

In [None]:
# Run inference on an image file given by path; print the generated text and display the returned image.
image_path = f'{image_fold}/bred.png'
img, text_generated = inference(image_path, model, processor)
print(text_generated)
img