In [1]:
import os
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from datasets import load_metric
from torch.utils.data import Dataset
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)

In [2]:
def seed_everything(seed_value):
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
@dataclass(frozen=True)
class TrainingConfig:
    BATCH_SIZE: int = 64
    EPOCHS: int = 30
    LEARNING_RATE: float = 0.00005

@dataclass(frozen=True)
class DatasetConfig:
    # DATA_ROOT: str = '../../data/processed/2 For OCR'
    DATA_ROOT: str = '/media/admin01/storage1/vadim/Historical-docs-OCR/data/processed/3 Production'


@dataclass(frozen=True)
class ModelConfig:
    # MODEL_NAME: str = 'microsoft/trocr-small-printed'
    MODEL_NAME: str = 'microsoft/trocr-small-handwritten'
    # MODEL_NAME: str = 'raxtemur/trocr-base-ru'
    # MODEL_NAME: str = 'microsoft/trocr-small-stage1'


In [4]:
train_df = pd.read_csv(
    os.path.join(DatasetConfig.DATA_ROOT, 'train.csv'), index_col=0
)

test_df = pd.read_csv(
    os.path.join(DatasetConfig.DATA_ROOT, 'valid.csv'), index_col=0
)

train_df.dropna(inplace=True)
test_df.dropna(inplace=True)

train_df = train_df[train_df['text'] != 'unlabelled']
test_df = test_df[test_df['text'] != 'unlabelled']

f"Размер обучающей выборки: {len(train_df)} | Размер тестовой выборки: {len(test_df)}"

'Размер обучающей выборки: 23609 | Размер тестовой выборки: 5849'

In [5]:
train_df

Unnamed: 0,file_name,text,label
0,IMG_5827___0.JPG,1001 томъ; Пронской 452 названiя,0
1,IMG_5827___1.JPG,921 томъ; Ряжской 410 названiй,0
2,IMG_5827___2.JPG,634 тома и Скопинской 401 названiе,0
3,IMG_5827___3.JPG,783 тома. Пожертвовано въ пользу,0
4,IMG_5827___4.JPG,библiотекъ Раненбургской 49 р. и,0
...,...,...,...
23643,11227470_doc1___17.jpg,тиромъ на выкопировке съ плна,0
23644,11227470_doc1___18.jpg,"шета, которая имеетъ быть переда",0
23645,11227470_doc1___19.jpg,"на Г. Мировому Посреднику. Но,",0
23646,11227470_doc1___20.jpg,"независимо сего, на основанiи ст.",0


In [6]:
# Augmentations.
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=.5, hue=.3),
])

In [7]:
class CustomOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length


    def __len__(self):
        return len(self.df)


    def __getitem__(self, idx):
        # The image file name.
        file_name = self.df['file_name'].iloc[idx]
        # The text (label).
        text = self.df['text'].iloc[idx]
        # Read the image, apply augmentations, and get the transformed pixels.
        image = Image.open(self.root_dir + file_name).convert('RGB')
        
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # Pass the text through the tokenizer and get the labels,
        # i.e. tokenized labels.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length
        ).input_ids
        # We are using -100 as the padding token.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

In [8]:
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer/train/'),
    df=train_df,
    processor=processor
)
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer/valid/'),
    df=test_df,
    processor=processor
)



In [9]:
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


VisionEncoderDecoderModel(
  (encoder): DeiTModel(
    (embeddings): DeiTEmbeddings(
      (patch_embeddings): DeiTPatchEmbeddings(
        (projection): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DeiTEncoder(
      (layer): ModuleList(
        (0-11): 12 x DeiTLayer(
          (attention): DeiTAttention(
            (attention): DeiTSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): DeiTSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): DeiTIntermediate(
            (dense): Linear(

In [10]:
# Set special tokens used for creating the decoder_input_ids from the labels.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Set Correct vocab size.
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [11]:
optimizer = optim.AdamW(
    model.parameters(), lr=TrainingConfig.LEARNING_RATE, weight_decay=0.0005
)

In [12]:
cer_metric = load_metric("cer", trust_remote_code=True)
wer_metric = load_metric("wer", trust_remote_code=True)

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer, "wer": wer}

  cer_metric = load_metric("cer", trust_remote_code=True)


In [13]:
# если есть ClearML, то укажите свои настройки для логирования обучения детектора текста
# с инструкцией, как поднять собственный ClearML, можно ознакомиться тут: https://github.com/allegroai/clearml-server 
%env CLEARML_WEB_HOST=http://localhost:8080
%env CLEARML_API_HOST=http://localhost:8008
%env CLEARML_FILES_HOST=http://localhost:8081
%env CLEARML_API_ACCESS_KEY=LOIP4T1VXIPLP16VZJR9
%env CLEARML_API_SECRET_KEY=RYVetvGfembTTfDKxnlWaXVWc60XWWka2WjNeRlczJmV5k2mgt

env: CLEARML_WEB_HOST=http://localhost:8080
env: CLEARML_API_HOST=http://localhost:8008
env: CLEARML_FILES_HOST=http://localhost:8081
env: CLEARML_API_ACCESS_KEY=LOIP4T1VXIPLP16VZJR9
env: CLEARML_API_SECRET_KEY=RYVetvGfembTTfDKxnlWaXVWc60XWWka2WjNeRlczJmV5k2mgt


In [14]:
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    fp16=True,
    output_dir='seq2seq_model_checkpoints/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='clearml',
    num_train_epochs=TrainingConfig.EPOCHS,
    dataloader_num_workers=4
)

In [15]:
# Initialize trainer.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator
)



In [16]:
res = trainer.train()

ClearML Task: created new task id=57e787289ffd4a9bb5775cec5bd56d55
2024-03-19 00:54:48,514 - clearml.Task - INFO - Storing jupyter notebook directly as code


Unsupported key of type '<class 'int'>' found when connecting dictionary. It will be converted to str


ClearML results page: http://localhost:8080/projects/4ee2d7e866464b1995dd559ab5add5f7/experiments/57e787289ffd4a9bb5775cec5bd56d55/output/log




Epoch,Training Loss,Validation Loss,Cer,Wer
1,5.1326,3.452236,1.112115,1.468708
2,3.3815,3.068127,0.960886,1.102586
3,3.1066,2.86054,0.895642,1.062962
4,2.9212,2.655235,0.867623,1.081888
5,2.5681,2.134887,0.737278,0.94997
6,1.889,1.373746,0.458928,0.775675
7,1.2462,0.938605,0.259292,0.613821
8,0.8486,0.662828,0.178883,0.498115
9,0.6423,0.532609,0.140488,0.426746
10,0.5234,0.473127,0.126709,0.397451


ClearML results page: http://localhost:8080/projects/4ee2d7e866464b1995dd559ab5add5f7/experiments/57e787289ffd4a9bb5775cec5bd56d55/output/log
ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start


Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_si

In [18]:
trainer.save_model()

Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
