In [1]:
import os
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from datasets import load_metric
from torch.utils.data import Dataset
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)

In [2]:
def seed_everything(seed_value):
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
@dataclass(frozen=True)
class TrainingConfig:
    BATCH_SIZE: int = 16
    EPOCHS: int = 5
    LEARNING_RATE: float = 0.00005

@dataclass(frozen=True)
class DatasetConfig:
    DATA_ROOT: str = '../../data/processed/4 Segmenter test/'

@dataclass(frozen=True)
class ModelConfig:
    # MODEL_NAME: str = 'microsoft/trocr-small-printed'
    # MODEL_NAME: str = 'microsoft/trocr-small-handwritten'
    MODEL_NAME: str = 'raxtemur/trocr-base-ru'

In [4]:
train_df = pd.read_csv(
    os.path.join(DatasetConfig.DATA_ROOT, 'train.csv'), index_col=0
).reset_index()

valid_df = pd.read_csv(
    os.path.join(DatasetConfig.DATA_ROOT, 'valid.csv'), index_col=0
).reset_index()

test_df = pd.read_csv(
    os.path.join(DatasetConfig.DATA_ROOT, 'test.csv'), index_col=0
).reset_index()

f"Размер обучающей выборки: {len(train_df)} | Размер валидационной выборки: {len(valid_df)} | Размер тестовой выборки: {len(test_df)}"

'Размер обучающей выборки: 957 | Размер валидационной выборки: 263 | Размер тестовой выборки: 325'

In [5]:
train_df.dropna(inplace=True)
valid_df.dropna(inplace=True)
test_df.dropna(inplace=True)

train_df = train_df[train_df['text'] != '.']
valid_df = valid_df[valid_df['text'] != '.']
test_df = test_df[test_df['text'] != '.']

f"Размер обучающей выборки: {len(train_df)} | Размер валидационной выборки: {len(valid_df)} | Размер тестовой выборки: {len(test_df)}"

'Размер обучающей выборки: 899 | Размер валидационной выборки: 228 | Размер тестовой выборки: 325'

In [6]:
# Augmentations.
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=.5, hue=.3),
])

In [7]:
class CustomOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length


    def __len__(self):
        return len(self.df)


    def __getitem__(self, idx):
        # The image file name.
        file_name = self.df['file_name'].iloc[idx]
        # The text (label).
        text = self.df['text'].iloc[idx]
        # Read the image, apply augmentations, and get the transformed pixels.
        image = Image.open(self.root_dir + file_name).convert('RGB')
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # Pass the text through the tokenizer and get the labels,
        # i.e. tokenized labels.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length
        ).input_ids
        # We are using -100 as the padding token.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

In [8]:
train_df

Unnamed: 0,file_name,text
0,0_fdf0882e-20.png,1875
1,1_fdf0882e-20.png,Мартъ. Вышелъ 2 томъ курса гр. права.
2,2_fdf0882e-20.png,16 Мая Свадьба Кати Пеликанъ.
3,3_fdf0882e-20.png,22 – 26. въ Москве.
4,4_fdf0882e-20.png,4 Iюня. Уехали заграницу съ Соничкой и съ М. Е.
...,...,...
952,31_10a09237-0191.jpeg,"Счастье, счастье увиделъ я на лице у"
953,32_10a09237-0191.jpeg,милой моей Катюши!
954,33_10a09237-0191.jpeg,"Мы скрываемся ото всехъ днемъ, а вечеромъ"
955,34_10a09237-0191.jpeg,въ темноте – прогулки къ нашему бревнышку.


In [9]:
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer', 'train/'),
    df=train_df,
    processor=processor
)

valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer', 'valid/'),
    df=valid_df,
    processor=processor
)

test_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer', 'test/'),
    df=test_df,
    processor=processor
)

In [10]:
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fea

In [11]:
# Set special tokens used for creating the decoder_input_ids from the labels.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Set Correct vocab size.
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [12]:
optimizer = optim.AdamW(
    model.parameters(), lr=TrainingConfig.LEARNING_RATE, weight_decay=0.0005
)

In [13]:
cer_metric = load_metric("cer", trust_remote_code=True)
wer_metric = load_metric("wer", trust_remote_code=True)

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer, "wer": wer}

  cer_metric = load_metric("cer", trust_remote_code=True)


In [14]:
# если есть ClearML, то укажите свои настройки для логирования обучения детектора текста
# с инструкцией, как поднять собственный ClearML, можно ознакомиться тут: https://github.com/allegroai/clearml-server 
%env CLEARML_WEB_HOST=http://localhost:8080
%env CLEARML_API_HOST=http://localhost:8008
%env CLEARML_FILES_HOST=http://localhost:8081
%env CLEARML_API_ACCESS_KEY=LOIP4T1VXIPLP16VZJR9
%env CLEARML_API_SECRET_KEY=RYVetvGfembTTfDKxnlWaXVWc60XWWka2WjNeRlczJmV5k2mgt

env: CLEARML_WEB_HOST=http://localhost:8080
env: CLEARML_API_HOST=http://localhost:8008
env: CLEARML_FILES_HOST=http://localhost:8081
env: CLEARML_API_ACCESS_KEY=LOIP4T1VXIPLP16VZJR9
env: CLEARML_API_SECRET_KEY=RYVetvGfembTTfDKxnlWaXVWc60XWWka2WjNeRlczJmV5k2mgt


In [15]:
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    fp16=True,
    output_dir='seq2seq_model_checkpoints/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='clearml',
    num_train_epochs=TrainingConfig.EPOCHS
)

In [16]:
# Initialize trainer.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator
)



In [17]:
res = trainer.train()

ClearML Task: created new task id=ed0dc172068f4599bf844c666e67a278
2024-03-28 16:05:59,009 - clearml.Task - INFO - Storing jupyter notebook directly as code


Unsupported key of type '<class 'int'>' found when connecting dictionary. It will be converted to str


ClearML results page: http://localhost:8080/projects/4ee2d7e866464b1995dd559ab5add5f7/experiments/ed0dc172068f4599bf844c666e67a278/output/log




Epoch,Training Loss,Validation Loss,Cer,Wer
1,1.1845,0.666644,0.293156,0.67148
2,0.4804,0.429812,0.218776,0.550903
3,0.2704,0.393281,0.204967,0.5213
4,0.1709,0.374909,0.193701,0.506859
5,0.1261,0.378726,0.190672,0.510469


Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


## Test

In [18]:
trainer.evaluate(test_dataset)



{'eval_loss': 0.4248153865337372,
 'eval_cer': 0.19709905443863032,
 'eval_wer': 0.4975793437331899,
 'eval_runtime': 222.8998,
 'eval_samples_per_second': 1.458,
 'eval_steps_per_second': 0.049,
 'epoch': 5.0}

In [20]:
# Example for a sequence-to-sequence task
predictions = trainer.predict(test_dataset)

decoded_predictions = [processor.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in predictions.predictions]



In [21]:
test_df['pred'] = decoded_predictions

In [22]:
test_df.to_csv('pobed_seg.csv')

In [23]:
test_df

Unnamed: 0,file_name,text,pred
0,0_909ccbb0-18.png,Начало общества у В.К. Константина. Отделъ Общ...,На что общества у В. К. Насильичина. Отделъ Об...
1,1_909ccbb0-18.png,Валуевъ – М-ръ Госуд. имуществъ.,Валуевъ – М заъ Госуда имущества.
2,2_909ccbb0-18.png,20 мая. Поездка съ Катей и Соничкой черезъ Москву,28. Мая. Поездка съ Котей и Соничкой чередъ М...
3,3_909ccbb0-18.png,въ Смоленскъ. у а. в. шевандиной и у Дiодора,въ смоленiя. у А. в. Шевандиной и утодора
4,4_909ccbb0-18.png,въ Александровскомъ. вернулись 1 Iюня.,въ Александровскомъ. вернулись 17юня.
...,...,...,...
320,37_c6d63ae5-15.png,"Варшаву и Берлинъ, и Парижъ и Лондонъ,","Варшаву и берлинъ, и ""парижъ и-Лондоне,"
321,38_c6d63ae5-15.png,на о-въ Вайтъ. – Шенклинъ. На обратномъ,на О въ Войте. – Шенклинъ. На обратность
322,39_c6d63ae5-15.png,Пути черезъ Ломжу – возвращаемся,пути черезъ Ломжу – возвращаемся
323,40_c6d63ae5-15.png,1 Сентября.,1 Сентябрь.
