In [None]:
from datetime import datetime
import math
import os
import sys
# Make the repository root (two levels above the kernel's working directory,
# presumably this notebook's folder) importable so the project-local packages
# `train` and `custom_dataset` resolve. The membership check keeps the cell
# idempotent: re-running it does not stack duplicate sys.path entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    TrainingArguments,
    Trainer,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    EarlyStoppingCallback,
)
import torch
import mlflow

from train.utils_train import save_model_and_history, plot_history
# from OCR_VQA.data_preparation import VQAProcessor
from custom_dataset.data_preparation import CustomDataProcessor

  from .autonotebook import tqdm as notebook_tqdm


In [29]:
experiment_name = 'trocr_train'

# Respect an MLFLOW_TRACKING_URI already exported in the environment (an
# unconditional set_tracking_uri call would override it); otherwise fall back
# to the local tracking server.
tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
mlflow.set_tracking_uri(tracking_uri)
# Selects the experiment, creating it on first use; the returned Experiment is
# displayed as the cell output.
mlflow.set_experiment(experiment_name)

2025/04/09 20:41:30 INFO mlflow.tracking.fluent: Experiment with name 'trocr_train' does not exist. Creating a new experiment.


<Experiment: artifact_location='mlflow-artifacts:/778888683745344032', creation_time=1744220490751, experiment_id='778888683745344032', last_update_time=1744220490751, lifecycle_stage='active', name='trocr_train', tags={}>

In [30]:
# Fetch the experiment selected above and load every run it contains into a DataFrame.
experiment = mlflow.get_experiment_by_name(experiment_name)
experiment_ids = [experiment.experiment_id]
runs_df = mlflow.search_runs(experiment_ids=experiment_ids)

In [33]:
# Display the experiment's runs table (only column headers when no runs exist yet).
runs_df

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time


# 1. Dataset preparation

In [None]:
# TrOCRProcessor bundles the image processor and the tokenizer of the TrOCR checkpoint.
dataset_name = 'ocr-dataset'
# For a locally stored dataset, point dataset_name at its directory instead:
# dataset_name = os.path.join(project_root, 'custom_dataset', 'data', dataset_name)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")

# Arguments for the split call. Set 'dataset_type' to 'local' to read a locally
# stored dataset instead of S3. Presumably the test split receives the remaining
# 1 - train_frac - val_frac share; confirm in CustomDataProcessor.
# NOTE(review): batch_size=16 here differs from the Trainer batch_size (8) set in
# the parameters cell; presumably it only controls preprocessing batches; confirm.
split_config = dict(
    dataset_type='S3',
    train_frac=0.95,
    val_frac=0.025,
    dataset_name=dataset_name,
    batch_size=16,
)

data_processor = CustomDataProcessor(processor)
train_dataset, val_dataset, test_dataset, train_size = data_processor(**split_config)

# 2. Train

In [None]:
# Resume fine-tuning from a previous run's checkpoint. To start from the
# pretrained weights instead, use "microsoft/trocr-small-printed".
init_checkpoint = "checkpoints/06.04.25_S3_100k_v2/checkpoint-10000"
model = VisionEncoderDecoderModel.from_pretrained(init_checkpoint)

# Decoding strategy used by model.generate(). All options:
# https://huggingface.co/docs/transformers/v4.48.2/en/main_classes/text_generation#transformers.GenerationConfig
gen_config = dict(
    num_beams=7,
    num_beam_groups=1,
    do_sample=True,          # with num_beams > 1: beam-search multinomial sampling (non-deterministic output)
    max_new_tokens=200,
    early_stopping=True,     # True stops as soon as num_beams complete candidates exist
    temperature=1.5,         # T < 1 sharpens, T > 1 smooths the token probability distribution
    top_k=100,               # only the k most probable tokens are candidates
    diversity_penalty=0,     # subtracted from a beam's score if another group generated the token; needs num_beam_groups > 1
    repetition_penalty=1.2,  # > 1 penalises tokens already in the sequence (positive logits divided, negative multiplied)
)

# GenerationConfig.update() returns the keys it did not recognise instead of
# raising; fail loudly so a misspelled option is not silently ignored.
unused_gen_keys = model.generation_config.update(**gen_config)
if unused_gen_keys:
    raise ValueError(f"Unknown generation options: {sorted(unused_gen_keys)}")

# Special token ids used when the model builds decoder inputs from labels during training.
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# generate() reads these ids from generation_config rather than model.config;
# keep both in sync so inference starts decoding from the same token as training.
model.generation_config.decoder_start_token_id = model.config.decoder_start_token_id
model.generation_config.pad_token_id = model.config.pad_token_id

VisionEncoderDecoderModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [None]:
# Define parameters

date = datetime.now().strftime("%d.%m.%y")

# Run name: used for the checkpoint folder and as the MLflow run name.
# NOTE(review): if dataset_name is switched to a local path (see the dataset
# cell) it contains path separators; use os.path.basename(dataset_name) then.
model_name = f'{date}_{dataset_name}_v1'

output_dir = f'checkpoints/{model_name}'
num_epochs = 10
batch_size = 8
init_learning_rate = 2.0e-6  # alternative: 1e-5
warmup_ratio = 0.0  # fraction of all optimizer steps spent in warmup

# Total optimizer steps, counted the way Trainer counts them on a single device
# without gradient accumulation: the last partial batch is still a step, so
# round up per epoch. A truncated count lets training outlive the cosine
# schedule, whose learning rate then rises again.
# Assumes train_size is the number of training examples; TODO confirm in CustomDataProcessor.
steps_per_epoch = math.ceil(train_size / batch_size)
max_steps = steps_per_epoch * num_epochs
num_warmup_steps = int(max_steps * warmup_ratio)

eval_steps = logging_steps = 1000
# Checkpoint at every evaluation: load_best_model_at_end can only restore a
# checkpoint that was actually written, so saving less often than evaluating
# (previously every 10000 steps) can lose the best-scoring model.
save_steps = eval_steps

# Initialize the optimizer. See this for optimizers:
# https://huggingface.co/docs/transformers/en/main_classes/optimizer_schedules
optimizer = torch.optim.AdamW(model.parameters(), lr=init_learning_rate)

# Set up a learning rate scheduler. See this for scheduler types:
# https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/optimizer_schedules#transformers.SchedulerType
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=max_steps,
)

# Stop after 5 consecutive evaluations without an eval_loss improvement.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

training_hyperparams = TrainingArguments(
    output_dir=output_dir,
    report_to='mlflow',
    run_name=model_name,  # for mlflow logging
    # Trainer uses the optimizer/scheduler passed via `optimizers=` and ignores
    # these three, but they are logged as run parameters, so mirror the real
    # values instead of logging Trainer's defaults.
    learning_rate=init_learning_rate,
    lr_scheduler_type='cosine',
    warmup_steps=num_warmup_steps,
    eval_strategy='steps',  # evaluate on eval_dataset every eval_steps
    eval_steps=eval_steps,
    # NOTE(review): only the number is shared with logging_steps; this sets how
    # many eval batches are accumulated before results are moved to the CPU.
    eval_accumulation_steps=logging_steps,
    logging_steps=logging_steps,  # update steps between log outputs
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,  # max_steps is not passed; the step count above must match Trainer's
    # Keep one rolling checkpoint; with load_best_model_at_end the best one is
    # retained as well, so up to two checkpoints can exist.
    save_total_limit=1,
    load_best_model_at_end=True,  # reload the best checkpoint when training ends
    metric_for_best_model='eval_loss',  # key from compute_metrics output, or a predefined metric
    greater_is_better=False,
    save_steps=save_steps,
    # fp16=True,
    # Evaluation (and thus the eval_loss used for best-model selection and early
    # stopping) runs in fp16, while training itself is not set to fp16.
    fp16_full_eval=True,
)

In [None]:
# Training the model

if not os.path.exists(output_dir):
    trainer = Trainer(
        model=model,
        args=training_hyperparams,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        # compute_metrics=compute_metrics,
        optimizers=(optimizer, lr_scheduler),
        callbacks=[early_stopping]
        # processing_class=processor,
        # data_collator=...
    )
    trainer.train()
else:
    raise ValueError(f"Model '{model_name}' exists, specify another name")

  1%|          | 1000/118560 [09:58<24:48:53,  1.32it/s]

{'loss': 0.3108, 'grad_norm': 2.1685426235198975, 'learning_rate': 2.199613826365429e-06, 'epoch': 0.08}


                                                        
  1%|          | 1000/118560 [12:11<24:48:53,  1.32it/s]

{'eval_loss': 0.24410735070705414, 'eval_runtime': 133.0535, 'eval_samples_per_second': 18.752, 'eval_steps_per_second': 2.345, 'epoch': 0.08}


  2%|▏         | 2000/118560 [21:51<18:28:01,  1.75it/s]  

{'loss': 0.3356, 'grad_norm': 3.756221055984497, 'learning_rate': 2.1984555766073082e-06, 'epoch': 0.17}


                                                        
  2%|▏         | 2000/118560 [24:04<18:28:01,  1.75it/s]

{'eval_loss': 0.263593852519989, 'eval_runtime': 132.5658, 'eval_samples_per_second': 18.821, 'eval_steps_per_second': 2.354, 'epoch': 0.17}


  3%|▎         | 3000/118560 [33:54<20:28:30,  1.57it/s]  

{'loss': 0.2588, 'grad_norm': 4.268790245056152, 'learning_rate': 2.1965260639720362e-06, 'epoch': 0.25}


                                                        
  3%|▎         | 3000/118560 [36:12<20:28:30,  1.57it/s]

{'eval_loss': 0.2590153217315674, 'eval_runtime': 137.6969, 'eval_samples_per_second': 18.12, 'eval_steps_per_second': 2.266, 'epoch': 0.25}


  3%|▎         | 4000/118560 [46:25<21:41:21,  1.47it/s]  

{'loss': 0.3107, 'grad_norm': 2.8645238876342773, 'learning_rate': 2.1938266432358075e-06, 'epoch': 0.34}


                                                        
  3%|▎         | 4000/118560 [48:38<21:41:21,  1.47it/s]

{'eval_loss': 0.26178398728370667, 'eval_runtime': 133.611, 'eval_samples_per_second': 18.674, 'eval_steps_per_second': 2.335, 'epoch': 0.34}


  4%|▍         | 5000/118560 [59:03<24:45:14,  1.27it/s]  

{'loss': 0.2806, 'grad_norm': 1.5622378587722778, 'learning_rate': 2.1903592097533803e-06, 'epoch': 0.42}


                                                        
  4%|▍         | 5000/118560 [1:01:27<24:45:14,  1.27it/s]

{'eval_loss': 0.2603222727775574, 'eval_runtime': 144.2185, 'eval_samples_per_second': 17.3, 'eval_steps_per_second': 2.163, 'epoch': 0.42}


  5%|▌         | 6000/118560 [1:11:17<21:29:30,  1.45it/s]  

{'loss': 0.3178, 'grad_norm': 4.393310546875, 'learning_rate': 2.186126198127283e-06, 'epoch': 0.51}


                                                          
  5%|▌         | 6000/118560 [1:13:49<21:29:30,  1.45it/s]

{'eval_loss': 0.2639763057231903, 'eval_runtime': 151.4982, 'eval_samples_per_second': 16.469, 'eval_steps_per_second': 2.059, 'epoch': 0.51}


  6%|▌         | 7000/118560 [1:24:29<23:54:49,  1.30it/s]  

{'loss': 0.2644, 'grad_norm': 1.710511565208435, 'learning_rate': 2.181130580498397e-06, 'epoch': 0.59}


                                                          
  6%|▌         | 7000/118560 [1:26:54<23:54:49,  1.30it/s]

{'eval_loss': 0.2622835636138916, 'eval_runtime': 145.2707, 'eval_samples_per_second': 17.175, 'eval_steps_per_second': 2.148, 'epoch': 0.59}


  7%|▋         | 8000/118560 [1:37:05<21:20:30,  1.44it/s]  

{'loss': 0.3103, 'grad_norm': 3.2963521480560303, 'learning_rate': 2.1753758644591165e-06, 'epoch': 0.67}


                                                          
  7%|▋         | 8000/118560 [1:39:34<21:20:30,  1.44it/s]

{'eval_loss': 0.26190298795700073, 'eval_runtime': 149.4012, 'eval_samples_per_second': 16.7, 'eval_steps_per_second': 2.088, 'epoch': 0.67}


  8%|▊         | 9000/118560 [1:49:21<17:40:15,  1.72it/s]  

{'loss': 0.2306, 'grad_norm': 2.1879565715789795, 'learning_rate': 2.1688660905905485e-06, 'epoch': 0.76}


                                                          
  8%|▊         | 9000/118560 [1:51:28<17:40:15,  1.72it/s]

{'eval_loss': 0.28553342819213867, 'eval_runtime': 127.6767, 'eval_samples_per_second': 19.542, 'eval_steps_per_second': 2.444, 'epoch': 0.76}


  8%|▊         | 10000/118560 [2:02:31<24:43:06,  1.22it/s] 

{'loss': 0.3361, 'grad_norm': 1.6545602083206177, 'learning_rate': 2.161605829625483e-06, 'epoch': 0.84}


                                                           
  8%|▊         | 10000/118560 [2:05:11<24:43:06,  1.22it/s]

{'eval_loss': 0.25895318388938904, 'eval_runtime': 160.2443, 'eval_samples_per_second': 15.57, 'eval_steps_per_second': 1.947, 'epoch': 0.84}


  9%|▉         | 11000/118560 [2:18:00<24:53:03,  1.20it/s]  

{'loss': 0.3071, 'grad_norm': 3.4142446517944336, 'learning_rate': 2.15360017923913e-06, 'epoch': 0.93}


                                                           
  9%|▉         | 11000/118560 [2:20:43<24:53:03,  1.20it/s]

{'eval_loss': 0.2567855417728424, 'eval_runtime': 162.7511, 'eval_samples_per_second': 15.33, 'eval_steps_per_second': 1.917, 'epoch': 0.93}


  9%|▉         | 11113/118560 [2:22:05<19:48:53,  1.51it/s]  

# 3. Save model and history

In [None]:
# Log the trained model, its hyperparameters and final metrics to an MLflow run.
# NOTE(review): with report_to='mlflow' the Trainer integration likely already
# logged training to its own run; this opens a separate run with the same name.
with mlflow.start_run(run_name=model_name):

    # mlflow.transformers.log_model accepts a Pipeline, a dict of pipeline
    # components, or a checkpoint path, not a bare model. Bundle the model with
    # the processor parts needed for image-to-text inference.
    mlflow.transformers.log_model(
        transformers_model={
            "model": trainer.model,
            "tokenizer": processor.tokenizer,
            "image_processor": processor.image_processor,
        },
        artifact_path=model_name,
        task="image-to-text",
    )
    mlflow.log_params(training_hyperparams.to_dict())
    metrics = trainer.evaluate()
    mlflow.log_metrics(metrics)
    # The held-out test split is otherwise unused; log its metrics too
    # (keys are prefixed with "test_").
    test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
    mlflow.log_metrics(test_metrics)

In [None]:
# Persist the trained model and its training history via the project helper
# train.utils_train.save_model_and_history (presumably keyed by model_name).
save_model_and_history(model_name, trainer)

# 4. Evaluation

__1. Plot training history__

In [None]:
# Build aligned (epoch, train loss, val loss) series from the Trainer log.
# Training and evaluation losses are separate log_history entries, and extra
# evaluations (e.g. a later trainer.evaluate() call) add eval entries with no
# matching training loss. Filtering the two independently can therefore give
# lists of different lengths, so pair them by global step instead, keeping the
# first entry recorded at each step (the one logged during training).
train_loss_by_step = {}
eval_by_step = {}
for entry in trainer.state.log_history:
    step = entry.get('step')
    if step is None:
        continue
    if 'loss' in entry:
        train_loss_by_step.setdefault(step, entry['loss'])
    if 'eval_loss' in entry:
        eval_by_step.setdefault(step, entry)

shared_steps = sorted(train_loss_by_step.keys() & eval_by_step.keys())
epochs = [eval_by_step[s]['epoch'] for s in shared_steps]
train_loss = [train_loss_by_step[s] for s in shared_steps]
val_loss = [eval_by_step[s]['eval_loss'] for s in shared_steps]
# val_cer = [10 * eval_by_step[s]['eval_cer'] for s in shared_steps if 'eval_cer' in eval_by_step[s]]

hist = {'train_loss': train_loss, 'val_loss': val_loss}#, 'val_cer * 10': val_cer}

In [None]:
# Plot the collected loss curves for this run (x values: the epochs list above).
history_figsize = (15, 10)
fig = plot_history(epochs, hist, run_name=model_name, figsize=history_figsize)
