In [None]:
import math
import os

import numpy as np
import torch
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    TrainingArguments,
    Trainer,
)

from utils import inference, save_model_and_history, evaluate_model, cer_score, plot_history
# from OCR_VQA.data_preparation import VQAProcessor
from custom_dataset.data_preparation import CustomDataProcessor

# 1. Dataset preparation

In [2]:
# The processor bundles TrOCR's image preprocessing together with its tokenizer.
processor = TrOCRProcessor.from_pretrained(
    "microsoft/trocr-small-printed"
)

# Build the dataset splits from the custom data. The fourth returned value,
# train_size, is what the parameters cell uses to size max_steps
# (presumably the number of training samples — confirm in CustomDataProcessor).
data_processor = CustomDataProcessor(processor)
(
    train_dataset,
    val_dataset,
    test_dataset,
    train_size,
) = data_processor(dataset_batch_size=1000)

# 2. Train

In [3]:
# Postprocessing functions

def preprocess_logits_for_metrics(logits, labels):
    """Reduce each evaluation step's output before Trainer accumulates it.

    Only the argmax token ids of ``logits[0]`` (taken over the last axis) are
    kept, so the full score tensor is never stored across evaluation steps.
    ``logits`` is presumably the tuple of model outputs without the loss, with
    element 0 being the (batch, seq, vocab) scores — TODO confirm.

    The labels are returned next to the ids, so the pair becomes the
    ``predictions`` that ``compute_metrics`` later indexes with ``[0]``.
    """
    token_scores = logits[0]
    return token_scores.argmax(dim=-1), labels

def compute_metrics(eval_pred, tokenizer=None, metric=None):
    """Compute the character error rate (CER) for one evaluation pass.

    ``eval_pred`` unpacks into ``(predictions, label_ids)``.
    ``preprocess_logits_for_metrics`` returns ``(output_ids, labels)``, so the
    accumulated ``predictions`` is that pair, and its element 0 holds the argmax
    token ids. Those ids come from a forward pass with labels, so they should
    line up position-by-position with the labels.

    Args:
        eval_pred: the evaluation predictions passed in by ``Trainer``.
        tokenizer: tokenizer used for decoding; defaults to ``processor.tokenizer``.
        metric: object exposing ``compute(predictions=..., references=...)``;
            defaults to ``cer_score``.

    Returns:
        ``{'cer': <metric result>}``.
    """
    if tokenizer is None:
        tokenizer = processor.tokenizer
    if metric is None:
        metric = cer_score

    output_ids, labels_ids = eval_pred
    pred_ids = np.asarray(output_ids[0])
    label_ids = np.asarray(labels_ids)
    pad_id = tokenizer.pad_token_id

    # -100 is not a token id. It marks label positions ignored by the loss, and
    # Trainer also uses it to pad eval batches of different lengths when it
    # concatenates them. Decoding it fails, so map it to the pad token.
    # Ignored/pad label positions carry no text, so the predictions made there
    # are blanked as well. That is only possible when both arrays share a shape.
    ignored = (label_ids == -100) | (label_ids == pad_id)
    if pred_ids.shape == label_ids.shape:
        pred_ids = np.where(ignored, pad_id, pred_ids)
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    label_ids = np.where(ignored, pad_id, label_ids)

    # Drop special tokens (<s>, </s>, <pad>, ...) so the CER is measured on text only.
    # NOTE(review): some CER backends reject empty reference strings. That can
    # happen if a label consists solely of special tokens.
    words_predicted = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    words_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {'cer': metric.compute(predictions=words_predicted, references=words_labels)}

In [None]:
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")

# Align the decoder's special-token ids with the tokenizer: the decoder start
# id is taken from the tokenizer's BOS token, and padding uses its pad token.
# NOTE(review): the pretrained TrOCR config may ship a different
# decoder_start_token_id (and a generation config derived from it); confirm
# that generation at inference time starts with the same token used here.
model.config.update({
    "decoder_start_token_id": processor.tokenizer.bos_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
})

In [5]:
# Define parameters

model_name = '04.02.25_12787_v1'
num_epochs = 30
batch_size = 8
# Trainer takes ceil(train_size / batch_size) optimizer steps per epoch. The
# final partial batch is kept because dataloader_drop_last stays at its default
# of False. So round up per epoch; truncating the float product would stop
# training short of `num_epochs`. Assumes a single device and no gradient
# accumulation.
steps_per_epoch = math.ceil(train_size / batch_size)
max_steps = steps_per_epoch * num_epochs
eval_steps = logging_steps = 1000

training_hyperparams = TrainingArguments(
    output_dir=f'trocr_checkpoints/{model_name}',
    eval_strategy='steps',  # evaluate on eval_dataset every eval_steps
    eval_steps=eval_steps,
    # Number of eval steps whose outputs are accumulated on the device before
    # being moved to the CPU (reuses the logging interval value).
    eval_accumulation_steps=logging_steps,
    logging_steps=logging_steps,  # update steps between training-loss log entries
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Keep one rotating checkpoint on disk. Because load_best_model_at_end is
    # set, Trainer also retains the best checkpoint, so up to two may exist.
    save_total_limit=1,
    # Reload the best checkpoint when training finishes. With
    # metric_for_best_model unset, "best" means lowest eval loss. To select on
    # CER instead, set metric_for_best_model='cer' and greater_is_better=False
    # (lower CER is better).
    load_best_model_at_end=True,
    save_steps=10000,  # must be a multiple of eval_steps when load_best_model_at_end=True
    max_steps=max_steps,  # overrides num_train_epochs
    fp16_full_eval=True,  # evaluate in float16 (NOTE(review): assumes a CUDA device)
)

In [None]:
# Training the model

# Refuse to train into an existing run directory. New checkpoints would land
# next to the old run's files, and save_total_limit rotation may delete the old
# ones. The path is read from the training arguments, so this check cannot drift
# from the output_dir actually used.
if os.path.exists(training_hyperparams.output_dir):
    raise ValueError(f"Model '{model_name}' exists, specify another name")

trainer = Trainer(
    model=model,
    args=training_hyperparams,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # Keep only argmax ids per eval step instead of the full logits.
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)
trainer.train()

# 3. Save model and history

In [None]:
# Persist the trained model and the Trainer's history under `model_name` via the
# project helper. Presumably it writes models/{model_name}/..., since the
# "Load model" cell reads models/{model_name}/model; confirm in utils.
save_model_and_history(model_name, trainer)

# 4. Evaluation and Inference

__1. Plot training history__

In [None]:
# Pull per-evaluation series out of the Trainer's log.
# Evaluation entries carry 'eval_*' keys, training entries carry 'loss'. With
# eval_steps == logging_steps, both should appear once per interval, so the
# series line up on the shared `epochs` axis.
val_loss = [rec['eval_loss'] for rec in trainer.state.log_history if 'eval_loss' in rec]
val_cer = [rec['eval_cer'] * 10 for rec in trainer.state.log_history if 'eval_cer' in rec]
epochs = [rec['epoch'] for rec in trainer.state.log_history if 'eval_cer' in rec]
train_loss = [rec['loss'] for rec in trainer.state.log_history if 'loss' in rec]

# CER is multiplied by 10 (as the key says), presumably so it reads on a scale
# comparable to the losses.
hist = {
    'train_loss': train_loss,
    'val_loss': val_loss,
    'val_cer * 10': val_cer,
}

In [None]:
# Plot the loss curves and the scaled CER against the evaluation epochs.
plot_history(epochs, hist, run_name=model_name, figsize=(15, 10))

__2. Load model__

In [None]:
# Reload saved fine-tuned weights for evaluation.
# NOTE(review): this repeats the run name from the parameters cell so the
# section can target a specific saved run; keep the two in sync when retraining.
model_name = '04.02.25_12787_v1'
model_path = f'models/{model_name}/model'
model = VisionEncoderDecoderModel.from_pretrained(model_path).eval()  # inference mode (disables dropout)

__3. Evaluate model on datasets__

In [None]:
# CER on the training split. evaluate_model returns a three-element result whose
# third element is the CER. Index it rather than unpacking into `_`: rebinding
# `_` shadows IPython's last-output variable, and IPython then stops updating it.
# NOTE(review): confirm `.dataset.indeces` is split-specific. If the splits are
# torch.utils.data.Subset views of one dataset, `.dataset` is the shared parent
# and the train/val/test evaluations would cover the same samples.
train_cer_value = evaluate_model(model, processor, train_dataset.dataset.indeces, cer_score)[2]
train_cer_value

In [None]:
# CER on the validation split. evaluate_model returns a three-element result
# whose third element is the CER. Index it rather than unpacking into `_`,
# which would shadow IPython's last-output variable.
# NOTE(review): confirm `.dataset.indeces` differs per split (see Subset caveat:
# a Subset's `.dataset` is the shared parent dataset).
val_cer_value = evaluate_model(model, processor, val_dataset.dataset.indeces, cer_score)[2]
val_cer_value

In [None]:
# CER on the held-out test split. evaluate_model returns a three-element result
# whose third element is the CER. Index it rather than unpacking into `_`,
# which would shadow IPython's last-output variable.
# NOTE(review): confirm `.dataset.indeces` differs per split (a Subset's
# `.dataset` is the shared parent dataset).
test_cer_value = evaluate_model(model, processor, test_dataset.dataset.indeces, cer_score)[2]
test_cer_value

__4. Inference on new images__

In [43]:
# Folder with ad-hoc images that are not part of the custom dataset.
image_fold = 'test_images'

In [None]:
# OCR on test_screen.png: print the recognised text, then display the image.
img, text = inference('{}/test_screen.png'.format(image_fold), model, processor)
print(text)
img

In [None]:
# OCR on a single-channel image, presumably to check that inference copes with
# non-RGB input.
img, text = inference('{}/one_channel_image.jpg'.format(image_fold), model, processor)
print(text)
img

In [None]:
# OCR on test_screen_2.png: print the recognised text, then display the image.
img, text = inference('{}/test_screen_2.png'.format(image_fold), model, processor)
print(text)
img

__5. Inference on images from dataset__

In [7]:
# Switch to the custom dataset's own image folder; the cells below load
# image_<idx>.png from it. This rebinds image_fold, so re-running the
# "new images" cells after this one would look in this folder instead.
image_fold = 'custom_dataset/data/images'

In [None]:
idx = 2
# Dataset images are stored as image_<idx>.png inside image_fold.
img, text = inference('{}/image_{}.png'.format(image_fold, idx), model, processor)
print(text)
img

In [None]:
idx = 35
# Dataset images are stored as image_<idx>.png inside image_fold.
img, text = inference('{}/image_{}.png'.format(image_fold, idx), model, processor)
print(text)
img

In [None]:
idx = 1
# Dataset images are stored as image_<idx>.png inside image_fold.
img, text = inference('{}/image_{}.png'.format(image_fold, idx), model, processor)
print(text)
img