In [None]:
from datetime import datetime
import math
import os
import sys
# Make the directory two levels above the notebook's working directory importable,
# so the project-local modules imported below (utils_train, custom_dataset) resolve.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
sys.path.append(project_root)

from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    TrainingArguments,
    Trainer,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    EarlyStoppingCallback,
    pipeline,
)
import torch
import mlflow

from utils_train import save_model_and_history, plot_history
# from OCR_VQA.data_preparation import VQAProcessor
from custom_dataset.data_preparation import CustomDataProcessor

In [None]:
# Set mlflow experiment

experiment_name = 'trocr_train'

# Respect a tracking server configured through the standard MLFLOW_TRACKING_URI
# environment variable; fall back to the local server this notebook targets.
mlflow_tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment(experiment_name)

In [None]:
# Fetch every run already recorded under this experiment (used below to pick an unused run name).
experiment = mlflow.get_experiment_by_name(experiment_name)
experiment_ids = [experiment.experiment_id]
df_runs = mlflow.search_runs(experiment_ids=experiment_ids)

# 1. Dataset preparation

In [None]:
# TrOCRProcessor bundles the image processor and the tokenizer used by TrOCR.
dataset_name = 'ocr-dataset'
# For a locally stored dataset use instead:
# dataset_name = os.path.join(project_root, 'custom_dataset', 'data', dataset_name)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")

# Split fractions; NOTE(review): presumably the remaining 2.5% becomes test_dataset — confirm in CustomDataProcessor.
split_fractions = dict(train_frac=0.95, val_frac=0.025)

data_processor = CustomDataProcessor(processor)
# NOTE(review): this batch_size (16) differs from the training batch_size set later;
# presumably it only controls preprocessing batching — confirm.
train_dataset, val_dataset, test_dataset, train_size = data_processor(
    dataset_type='S3',  # use 'local' for a locally stored dataset
    dataset_name=dataset_name,
    batch_size=16,
    **split_fractions,
)

# 2. Train

In [None]:
# Weights to fine-tune from: an earlier checkpoint of this project.
# To start from the public pretrained model instead, use "microsoft/trocr-small-printed".
init_checkpoint = "checkpoints/06.04.25_S3_100k_v2/checkpoint-10000"
model = VisionEncoderDecoderModel.from_pretrained(init_checkpoint)

# Text-generation strategy used at inference time (e.g. by the pipeline logged to MLflow).
# All options:
# https://huggingface.co/docs/transformers/v4.48.2/en/main_classes/text_generation#transformers.GenerationConfig
gen_config = dict(
    num_beams=7,
    num_beam_groups=1,
    do_sample=True,  # combined with num_beams > 1 this is beam-search multinomial sampling
    max_new_tokens=200,
    early_stopping=True,  # True: stop as soon as there are num_beams finished candidates
    temperature=1.5,  # T < 1 sharpens, T > 1 flattens the token distribution (only used when sampling)
    top_k=100,  # sample only among the k most probable tokens
    diversity_penalty=0,  # subtracted from a beam's score when another beam group produced the same token; needs num_beam_groups > 1
    repetition_penalty=1.2,  # logits of already generated tokens are divided by it (multiplied if negative); 1.0 = no penalty
)

model.generation_config.update(**gen_config)  # update the existing generation config in place

# Special-token ids for training: when only labels are given, the decoder inputs are
# built by shifting labels right, starting with decoder_start_token_id and padding with pad_token_id.
decoder_start_token_id = processor.tokenizer.bos_token_id
pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = decoder_start_token_id
model.config.pad_token_id = pad_token_id
# generate() reads these ids from generation_config, not model.config; mirror them so
# inference starts decoding from the same token the model was trained to follow.
model.generation_config.decoder_start_token_id = decoder_start_token_id
model.generation_config.pad_token_id = pad_token_id

In [None]:
# Give a unique name to the current run (model) and create it in MLflow.
# Names look like '<dd.mm.yy>_<dataset_name>_v<N>'; N is bumped until the name is free.
date = datetime.now().strftime("%d.%m.%y")

# Collect names already in use. MLflow stores a run's display name in the
# 'tags.mlflow.runName' column. 'params.run_name' only exists for runs that logged
# TrainingArguments, so checking it alone misses bare runs, and it raises KeyError
# when no run in the experiment has that param.
taken_names = set()
for column in ('tags.mlflow.runName', 'params.run_name'):
    if column in df_runs.columns:
        taken_names.update(df_runs[column].dropna())

v = 1
model_name = f'{date}_{dataset_name}_v{v}'
# Also skip names whose checkpoint directory already exists. Otherwise the training cell
# refuses to run after an MLflow run has already been created for that name.
while model_name in taken_names or os.path.exists(os.path.join('checkpoints', model_name)):
    v += 1
    model_name = f'{date}_{dataset_name}_v{v}'
print(model_name)

# Create the run now and export its id. The transformers MLflow integration reads
# MLFLOW_RUN_ID and logs into this run instead of creating a new one.
with mlflow.start_run(run_name=model_name) as run:
    run_id = run.info.run_id
os.environ['MLFLOW_RUN_ID'] = run_id
print(run_id)

In [None]:
# Store run_name and run_id for downstream use

%store run_id
%store model_name

In [None]:
# Define parameters

output_dir = f'checkpoints/{model_name}'
num_epochs = 1
batch_size = 8
init_learning_rate = 2.0e-6  # previously tried: 1e-5
warmup_ratio = 0.0  # fraction of the training steps used for linear LR warmup
eval_steps = logging_steps = 1000

# Number of optimizer steps the Trainer will run. The last, partially filled batch of an
# epoch still counts as a step, so round up. Flooring could leave the LR schedule one
# step shorter than training.
# NOTE(review): assumes train_size == len(train_dataset), a single device and no
# gradient accumulation — revisit if any of these change.
max_steps = math.ceil(train_size / batch_size) * num_epochs

# Optimizers: https://huggingface.co/docs/transformers/en/main_classes/optimizer_schedules
optimizer = torch.optim.AdamW(model.parameters(), lr=init_learning_rate)

# Scheduler types:
# https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/optimizer_schedules#transformers.SchedulerType
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(max_steps * warmup_ratio),
    num_training_steps=max_steps,
)

# Stop training if eval_loss fails to improve for 5 evaluation calls.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

training_hyperparams = TrainingArguments(
    output_dir=output_dir,
    report_to='mlflow',
    run_name=model_name,  # for mlflow logging
    # optimizer and lr_scheduler above are passed to Trainer directly, so Trainer does not
    # use these two values for training. They are set so the hyperparameters logged to
    # MLflow describe the actual setup instead of the defaults (5e-5, 'linear').
    learning_rate=init_learning_rate,
    lr_scheduler_type='cosine',
    eval_strategy='steps',  # evaluate on eval_dataset every eval_steps
    eval_steps=eval_steps,
    eval_accumulation_steps=logging_steps,  # eval steps to accumulate outputs before moving them to CPU
    logging_steps=logging_steps,  # update steps between log outputs
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    # max_steps=max_steps,  # overrides num_train_epochs
    save_total_limit=1,  # keep only the latest checkpoint, plus the best one (always retained with load_best_model_at_end)
    load_best_model_at_end=True,  # reload the best checkpoint (by metric_for_best_model) when training ends
    metric_for_best_model='eval_loss',  # key of the eval metrics (compute_metrics output or predefined, e.g. eval_loss)
    greater_is_better=False,
    save_steps=10000,  # must be a round multiple of eval_steps when load_best_model_at_end=True
    # logging_dir='trocr_checkpoints/logs',
    # fp16=True,
    fp16_full_eval=True,
)

In [None]:
# Training the model

# Never overwrite the checkpoints of an already trained model with the same name.
if os.path.exists(output_dir):
    raise ValueError(f"Model '{model_name}' exists, specify another name")

trainer = Trainer(
    model=model,
    args=training_hyperparams,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    optimizers=(optimizer, lr_scheduler),  # custom optimizer/scheduler defined above
    callbacks=[early_stopping],
    # Optional: preprocess_logits_for_metrics=..., compute_metrics=...,
    #           processing_class=processor, data_collator=...
)
trainer.train()

# 3. Save model and history

In [None]:
# Wrap model and processor into an image-to-text pipeline so it can be logged to MLflow.

# Use the GPU when there is one; fall back to CPU instead of failing on CUDA-less hosts.
pipeline_device = 'cuda' if torch.cuda.is_available() else 'cpu'

ocr_pipeline = pipeline(
    "image-to-text",
    model=model,
    tokenizer=processor.tokenizer,
    # TrOCRProcessor.feature_extractor is a deprecated alias returning image_processor;
    # pass the same object directly to avoid the deprecation warning.
    feature_extractor=processor.image_processor,
    device=pipeline_device,
)

In [None]:
# Extract train/validation loss curves from the Trainer's log history for plotting.
log_history = trainer.state.log_history
eval_logs = [entry for entry in log_history if 'eval_loss' in entry]

epochs = [entry['epoch'] for entry in eval_logs]
train_loss = [entry['loss'] for entry in log_history if 'loss' in entry]
val_loss = [entry['eval_loss'] for entry in eval_logs]
# val_cer = [10 * entry['eval_cer'] for entry in log_history if 'eval_cer' in entry]

hist = {'train_loss': train_loss, 'val_loss': val_loss}  # , 'val_cer * 10': val_cer}

In [None]:
# Create the loss-history plot.
# to_image=True presumably makes plot_history return an image object (it is later
# passed to mlflow.log_image) — confirm in utils_train.
history_figsize = (15, 10)
fig = plot_history(epochs, hist, run_name=model_name, figsize=history_figsize, to_image=True)

In [None]:
# Re-open the MLflow run created before training, then attach the OCR pipeline,
# the training arguments and the loss-history image to it.
with mlflow.start_run(run_id=run_id, run_name=model_name):
    hyperparams = training_hyperparams.to_dict()
    mlflow.transformers.log_model(ocr_pipeline, model_name)
    mlflow.log_params(hyperparams)
    mlflow.log_image(image=fig, artifact_file="history.png")