In [None]:
import math
import os
import sys
# Resolve the project root (two directories above the kernel's working directory)
# and put it on the import path so the project-local imports below can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
# Guarded append: re-running this cell must not pile up duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    TrainingArguments,
    Trainer,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    EarlyStoppingCallback,
)
import torch

from utils_dev import save_model_and_history, evaluate_model, cer_score, plot_history
from trocr.utils.utils_inf import inference
# from OCR_VQA.data_preparation import VQAProcessor
from custom_dataset.data_preparation import CustomDataProcessor

# 1. Dataset preparation

In [None]:
# Dataset folder name under custom_dataset/data (alternative: 'S3_100k').
dataset_name = 'wiki_sentences_300k'
data_path = os.path.join(project_root, 'custom_dataset', 'data', dataset_name)

# TrOCRProcessor bundles the image processor and the tokenizer in a single object.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")
data_processor = CustomDataProcessor(processor)

# For a locally stored dataset: use dataset_type='local' and pass `data_fold` instead of `bucket`.
(train_dataset, val_dataset, test_dataset, train_size) = data_processor(
    dataset_type='S3', train_frac=0.95, val_frac=0.025, bucket='ocr-dataset', batch_size=32,
)

# 2. Train

In [None]:
# Start from the pretrained TrOCR checkpoint that will be fine-tuned below.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")

# Choosing a strategy for text generation
# All strategies can be found here:
# https://huggingface.co/docs/transformers/v4.48.2/en/main_classes/text_generation#transformers.GenerationConfig

gen_config = dict(
    num_beams=7,
    num_beam_groups=1,
    do_sample=True, # Sample inside beam search, so repeated runs may produce different text
    max_new_tokens=200,
    early_stopping=True, # True stops when num_beams candidates are reached
    temperature=1.5, # T < 1 sharpens, T > 1 smooths the probability distribution
    top_k=100, # Only top k candidates with highest probabilities will be considered
    diversity_penalty=0, # Subtracted from a beam's score if the token was generated by another group (needs num_beam_groups > 1)
    repetition_penalty=1.2, # > 1 penalizes the logits of tokens that were already generated (independent of temperature)
    # generate() reads the special tokens from the generation config, so they are set
    # here as well as on model.config (used for training) to keep both in agreement.
    decoder_start_token_id=processor.tokenizer.bos_token_id,
    pad_token_id=processor.tokenizer.pad_token_id,
)

# update() returns the keys it did not recognise; surface them instead of dropping them silently.
unused_gen_options = model.generation_config.update(**gen_config)
if unused_gen_options:
    print(f'Generation options ignored by GenerationConfig: {sorted(unused_gen_options)}')

# The same special tokens on the model config, which is what training reads.
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

In [None]:
# Define parameters

model_name = f'26.03.25_{dataset_name}'
output_dir = f'checkpoints/{model_name}'
num_epochs = 10
batch_size = 8
init_learning_rate = 1e-3
warmup_ratio = 0 # fraction of max_steps spent on learning-rate warmup

# Total optimizer steps of the run, used as the scheduler horizon. Rounded up per epoch
# so the partial last batch of every epoch is counted.
# NOTE(review): assumes a single device and no gradient accumulation - confirm.
max_steps = math.ceil(train_size / batch_size) * num_epochs
eval_steps = logging_steps = 1000
# Checkpoint at every evaluation: early stopping and load_best_model_at_end can only
# go back to a state that was actually saved.
save_steps = eval_steps

# Initialize the optimizer. See this for optimizers:
# https://huggingface.co/docs/transformers/en/main_classes/optimizer_schedules
optimizer = torch.optim.AdamW(model.parameters(), lr=init_learning_rate)

# Set up a learning rate scheduler. See this for scheduler types:
# https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/optimizer_schedules#transformers.SchedulerType
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(max_steps * warmup_ratio),
    num_training_steps=max_steps,
)

# Stop once eval_loss has not improved for 5 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

training_hyperparams = TrainingArguments(
    output_dir=output_dir,
    # learning_rate=1e-4,
    # lr_scheduler_type='linear',
    eval_strategy='steps', # evaluate on eval_dataset every eval_steps
    eval_steps=eval_steps,
    # Number of prediction steps whose outputs are accumulated before being moved to the CPU.
    # NOTE(review): tied to logging_steps although the two settings are unrelated - confirm the intent.
    eval_accumulation_steps=logging_steps,
    logging_steps=logging_steps, # update steps to perform before output logs
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    # max_steps=max_steps, # Overrides num_train_epochs
    save_total_limit=1, # Keep one recent checkpoint; the best one is retained as well
    load_best_model_at_end=True, # Reload the best checkpoint when training ends
    metric_for_best_model='eval_loss', # Key from dict, returned by compute_metrics, or some predefined values
    greater_is_better=False,
    save_steps=save_steps,
    # logging_dir='trocr_checkpoints/logs',
    # fp16=True,
    fp16_full_eval=True, # evaluate in float16 instead of 32-bit
)

In [None]:
# Training the model

# An existing checkpoint folder means the run name is already taken: refuse to train over it.
if os.path.exists(output_dir):
    raise ValueError(f"Model '{model_name}' exists, specify another name")
else:
    trainer = Trainer(
        model=model,
        args=training_hyperparams,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        optimizers=(optimizer, lr_scheduler),
        callbacks=[early_stopping],
        # Optional arguments, currently unused:
        # preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        # compute_metrics=compute_metrics,
        # processing_class=processor,
        # data_collator=...
    )
    trainer.train()

# 3. Save model and history

In [None]:
# Hand the run name and the trainer to the project helper. Judging by its name it stores
# the model and the training history - see utils_dev for what is written and where.
save_model_and_history(model_name, trainer)

# 4. Evaluation and Inference

__1. Plot training history__

In [None]:
# Split the trainer log into evaluation records and training records.
log_history = trainer.state.log_history
eval_records = [record for record in log_history if 'eval_loss' in record]
train_records = [record for record in log_history if 'loss' in record]

# The x-axis values are taken from the evaluation records only.
epochs = [record['epoch'] for record in eval_records]
train_loss = [record['loss'] for record in train_records]
val_loss = [record['eval_loss'] for record in eval_records]
# val_cer = [10 * record['eval_cer'] for record in log_history if 'eval_cer' in record]

hist = dict(train_loss=train_loss, val_loss=val_loss)  # optionally also 'val_cer * 10': val_cer

In [None]:
# Plot the collected curves, passing the run name and the figure size through to the helper.
plot_history(epochs, hist, run_name=model_name, figsize=(15, 10))

__2. Load model__

In [None]:
# NOTE(review): this rebinds model_name, and the date prefix differs from the one used for
# training above - confirm that evaluating this run (not the one just trained) is intended.
model_name = f'20.03.25_{dataset_name}'
model_path = f'models/{model_name}/model'

# Load the stored model and switch it to evaluation mode in one go.
model = VisionEncoderDecoderModel.from_pretrained(model_path).eval()

__3. Evaluate model on datasets__

In [None]:
# Evaluate on the training split; only the third returned value is kept and shown.
_, _, train_cer_value = evaluate_model(
    model,
    processor,
    train_dataset.dataset.indeces,
    cer_score,
    data_path=data_path,
)
train_cer_value

In [None]:
# Evaluate on the validation split; only the third returned value is kept and shown.
_, _, val_cer_value = evaluate_model(
    model,
    processor,
    val_dataset.dataset.indeces,
    cer_score,
    data_path=data_path,
)
val_cer_value

In [None]:
# Evaluate on the test split; only the third returned value is kept and shown.
_, _, test_cer_value = evaluate_model(
    model,
    processor,
    test_dataset.dataset.indeces,
    cer_score,
    data_path=data_path,
)
test_cer_value

__4. Inference on images from train, valid and test datasets__

In [None]:
# Folder with the dataset images; the next cells build paths like <image_fold>/image_<idx>.png.
image_fold = os.path.join(data_path, 'images')

In [None]:
# Take the image id stored at position 5 of the training split's index list.
idx = train_dataset.dataset.indeces[5]

# inference() returns an (image, text) pair: print the text, then show the image as the cell output.
img, text = inference(f'{image_fold}/image_{idx}.png', model, processor)
print(text)
img

In [None]:
# Take the image id stored at position 9 of the validation split's index list.
idx = val_dataset.dataset.indeces[9]

# inference() returns an (image, text) pair: print the text, then show the image as the cell output.
img, text = inference(f'{image_fold}/image_{idx}.png', model, processor)
print(text)
img

In [None]:
# Take the image id stored at position 4 of the test split's index list.
idx = test_dataset.dataset.indeces[4]

# inference() returns an (image, text) pair: print the text, then show the image as the cell output.
img, text = inference(f'{image_fold}/image_{idx}.png', model, processor)
print(text)
img

__5. Inference on new images__

In [None]:
# From here on image_fold points at the project's test_images folder.
# NOTE: this rebinds the name used by the dataset-sample cells above, so re-running
# those cells afterwards would look for the images in this folder instead.
image_fold = os.path.join(project_root, 'test_images')

In [None]:
# Run inference() on test_screen.png; print the returned text and show the returned image.
img, text = inference(f'{image_fold}/test_screen.png', model, processor)
print(text)
img

In [None]:
# Run inference() on one_channel_image.jpg (presumably a single-channel image, going by the file name);
# print the returned text and show the returned image.
img, text = inference(f'{image_fold}/one_channel_image.jpg', model, processor)
print(text)
img

In [None]:
# Run inference() on test_screen_2.png; print the returned text and show the returned image.
img, text = inference(f'{image_fold}/test_screen_2.png', model, processor)
print(text)
img

In [None]:
# Run inference() on a.png; print the returned text and show the returned image.
img, text = inference(f'{image_fold}/a.png', model, processor)
print(text)
img

In [None]:
# Run inference() on bred.png; print the returned text and show the returned image.
img, text = inference(f'{image_fold}/bred.png', model, processor)
print(text)
img