In [None]:
import os
from dataclasses import dataclass
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from datasets import load_metric
from torch.utils.data import Dataset
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)

In [None]:
def seed_everything(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's built-in ``random`` module, NumPy and PyTorch (CPU and all
    CUDA devices), then switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed_value: Integer seed applied to every generator.
    """
    # Imported locally because `random` is not part of the notebook's import cell.
    import random

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Trade cuDNN auto-tuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

# Prefer the second GPU when the machine has one; otherwise fall back to the
# first GPU, and to the CPU when CUDA is unavailable. A fixed 'cuda:1' raises
# on single-GPU machines.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

In [None]:
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable training hyper-parameters read by the cells below."""

    # Per-device batch size, used for both training and evaluation.
    BATCH_SIZE: int = 32
    # Number of training epochs.
    EPOCHS: int = 5
    # Learning rate given to the AdamW optimizer.
    LEARNING_RATE: float = 5e-5

@dataclass(frozen=True)
class DatasetConfig:
    """Location of the dataset on disk.

    ``DATA_ROOT`` is taken from the ``OCR_DATA_ROOT`` environment variable when
    it is set before this cell runs, so the notebook can be executed on another
    machine without editing it. When the variable is absent, the original
    absolute path is used.
    """

    # DATA_ROOT: str = '../../data/processed/2 For OCR'
    DATA_ROOT: str = os.environ.get(
        'OCR_DATA_ROOT',
        '/media/admin01/storage1/vadim/Historical-docs-OCR/data/processed/3 Production',
    )


@dataclass(frozen=True)
class ModelConfig:
    """Checkpoint name used to load both the processor and the model."""

    # Other checkpoints kept for reference; swap one in to use it:
    #   MODEL_NAME: str = 'microsoft/trocr-small-printed'
    #   MODEL_NAME: str = 'microsoft/trocr-small-handwritten'
    #   MODEL_NAME: str = 'microsoft/trocr-small-stage1'
    MODEL_NAME: str = 'raxtemur/trocr-base-ru'


In [None]:
def _read_labelled_split(csv_name):
    """Load one split CSV and keep only the rows with a usable transcription.

    Rows containing any missing value are dropped, and so are rows whose
    ``text`` is the placeholder 'unlabelled' or a lone '.'.

    Args:
        csv_name: File name of the split inside ``DatasetConfig.DATA_ROOT``.

    Returns:
        The filtered DataFrame, indexed by the first CSV column.
    """
    frame = pd.read_csv(os.path.join(DatasetConfig.DATA_ROOT, csv_name), index_col=0)
    frame = frame.dropna()
    usable = (frame['text'] != 'unlabelled') & (frame['text'] != '.')
    return frame[usable]


train_df = _read_labelled_split('train.csv')
valid_df = _read_labelled_split('valid.csv')
test_df = _read_labelled_split('test.csv')

# Displayed as the cell output: sizes of the three splits after filtering.
f"Размер обучающей выборки: {len(train_df)} | Размер валидационной выборки: {len(valid_df)} | Размер тестовой выборки: {len(test_df)}"

In [None]:
def find_bad_imgs(root_dir, train_df):
    """Return the positional indices of rows whose image cannot be loaded.

    Every file listed in the 'file_name' column is opened and converted to
    RGB; any failure marks the row as bad.

    Args:
        root_dir: Directory prefix. It is concatenated directly with each file
            name, so it must end with a path separator.
        train_df: DataFrame with a 'file_name' column. Despite the parameter
            name, it is called with the train, valid and test frames alike.

    Returns:
        List of positional (``iloc``) indices of the unreadable images.
    """
    bad = []
    for idx, file_name in enumerate(tqdm(train_df['file_name'])):
        try:
            # The context manager releases the file handle after the check.
            with Image.open(root_dir + file_name) as img:
                img.convert('RGB')
        except Exception:
            # `Exception` rather than a bare `except`, so Ctrl-C
            # (KeyboardInterrupt) and SystemExit still stop the scan.
            bad.append(idx)

    return bad

# Remove, split by split, the rows whose image file cannot be opened.
# `find_bad_imgs` returns positional indices, so they are translated to index
# labels before being dropped in place.

# --- train ---
root_dir = DatasetConfig.DATA_ROOT + "/text_recognizer/train/"
bad_train = find_bad_imgs(root_dir, train_df)
bad_train_labels = train_df.iloc[bad_train].index
train_df.drop(index=bad_train_labels, inplace=True)

# --- valid ---
root_dir = DatasetConfig.DATA_ROOT + "/text_recognizer/valid/"
bad_valid = find_bad_imgs(root_dir, valid_df)
bad_valid_labels = valid_df.iloc[bad_valid].index
valid_df.drop(index=bad_valid_labels, inplace=True)

# --- test ---
root_dir = DatasetConfig.DATA_ROOT + "/text_recognizer/test/"
bad_test = find_bad_imgs(root_dir, test_df)
bad_test_labels = test_df.iloc[bad_test].index
test_df.drop(index=bad_test_labels, inplace=True)


In [None]:
# Split sizes again, shown as the cell output, now that rows with unreadable
# images have been dropped.
f"Размер обучающей выборки: {len(train_df)} | Размер валидационной выборки: {len(valid_df)} | Размер тестовой выборки: {len(test_df)}"

In [None]:
# Augmentation pipeline: a single random brightness / hue jitter step.
train_transforms = transforms.Compose(
    [transforms.ColorJitter(brightness=0.5, hue=0.3)]
)

In [None]:
class CustomOCRDataset(Dataset):
    """Image / transcription pairs prepared for TrOCR fine-tuning.

    Each item is a dict with two tensors:
      * ``pixel_values`` - the processor output for the image;
      * ``labels`` - token ids of the text, padded or truncated to
        ``max_target_length``, with padding positions set to -100.

    Args:
        root_dir: Directory prefix, concatenated directly with each file name,
            so it must end with a path separator.
        df: DataFrame with 'file_name' and 'text' columns.
        processor: Callable processor exposing ``.tokenizer``.
        max_target_length: Fixed length of every label sequence.
        augment: When True (default, the historical behaviour) the module-level
            ``train_transforms`` are applied to every image. Pass False for
            validation / test data to evaluate on un-augmented images.
    """

    def __init__(self, root_dir, df, processor, max_target_length=128, augment=True):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.augment = augment

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Build the model inputs for the row at positional index ``idx``."""
        # The image file name.
        file_name = self.df['file_name'].iloc[idx]
        # The text (label).
        text = self.df['text'].iloc[idx]

        # Read the image; the context manager releases the file handle once the
        # RGB copy has been made.
        with Image.open(self.root_dir + file_name) as img:
            image = img.convert('RGB')

        if self.augment:
            image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values

        # Tokenize the text. `truncation=True` guarantees exactly
        # `max_target_length` ids; without it a long line yields a longer
        # sequence and the batch can no longer be stacked by the collator.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_target_length
        ).input_ids

        # Padding positions are replaced by -100.
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [token if token != pad_id else -100 for token in labels]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

In [None]:
# One processor instance is shared by all three datasets.
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)

# Image folders of the individual splits.
train_images_dir = os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer/train/')
valid_images_dir = os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer/valid/')
test_images_dir = os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer/test/')

train_dataset = CustomOCRDataset(root_dir=train_images_dir, df=train_df, processor=processor)
valid_dataset = CustomOCRDataset(root_dir=valid_images_dir, df=valid_df, processor=processor)
test_dataset = CustomOCRDataset(root_dir=test_images_dir, df=test_df, processor=processor)

In [None]:
# Load the pretrained encoder-decoder checkpoint and move it to the chosen device.
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
# model = VisionEncoderDecoderModel.from_pretrained("../../models/text_recognizer/checkpoint-1152/", local_files_only=True)
model.to(device)

print(model)

# Collect (size, trainable?) once per parameter, then report both totals.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]

total_params = sum(size for size, _ in param_sizes)
print(f"{total_params:,} total parameters.")

total_trainable_params = sum(size for size, trainable in param_sizes if trainable)
print(f"{total_trainable_params:,} training parameters.")

In [None]:
# Special token ids used for creating the decoder_input_ids from the labels,
# taken from the processor's tokenizer.
special_token_ids = {
    "decoder_start_token_id": processor.tokenizer.cls_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
    "eos_token_id": processor.tokenizer.sep_token_id,
}

# Generation settings stored on the model config.
generation_settings = {
    "max_length": 64,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.0,
    "num_beams": 4,
}

# The top-level vocab size mirrors the decoder's vocab size.
config_overrides = {
    **special_token_ids,
    "vocab_size": model.config.decoder.vocab_size,
    **generation_settings,
}
for _config_key, _config_value in config_overrides.items():
    setattr(model.config, _config_key, _config_value)

In [None]:
# AdamW over every model parameter, using the configured learning rate.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=TrainingConfig.LEARNING_RATE,
    weight_decay=5e-4,
)

In [None]:
# Character- and word-error-rate metrics, loaded in that order.
# NOTE(review): `load_metric` appears to be deprecated in recent `datasets`
# releases - confirm against the pinned version before upgrading.
cer_metric, wer_metric = (
    load_metric(metric_name, trust_remote_code=True)
    for metric_name in ("cer", "wer")
)

def compute_metrics(pred):
    """Compute character and word error rates for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with -100 at ignored positions).

    Returns:
        Dict with the keys "cer" and "wer".
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some models return extra outputs; the token ids come first.
        pred_ids = pred_ids[0]

    # -100 is not a valid token id, so it is mapped back to the pad token
    # before decoding. `np.where` builds new arrays, which leaves the caller's
    # `label_ids` / `predictions` untouched.
    pred_ids = np.where(np.asarray(pred_ids) != -100, pred_ids, pad_id)
    labels_ids = np.where(np.asarray(pred.label_ids) != -100, pred.label_ids, pad_id)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer, "wer": wer}

In [None]:
# If you have a ClearML server, point the notebook at it here so the training
# run is logged. Instructions for hosting your own ClearML server:
# https://github.com/allegroai/clearml-server
#
# The host URLs below are only defaults; values already present in the
# environment are kept.
for _clearml_var, _clearml_default in (
    ("CLEARML_WEB_HOST", "http://localhost:8080"),
    ("CLEARML_API_HOST", "http://localhost:8008"),
    ("CLEARML_FILES_HOST", "http://localhost:8081"),
):
    os.environ.setdefault(_clearml_var, _clearml_default)

# Credentials must never be stored in the notebook. Export
# CLEARML_API_ACCESS_KEY and CLEARML_API_SECRET_KEY before starting Jupyter
# (or configure them outside the notebook, e.g. with `clearml-init`).
# NOTE(review): the key pair that used to be hard-coded in this cell should be
# treated as leaked and revoked on the server.
_missing_clearml_keys = [
    _key
    for _key in ("CLEARML_API_ACCESS_KEY", "CLEARML_API_SECRET_KEY")
    if not os.environ.get(_key)
]
if _missing_clearml_keys:
    print(
        "ClearML credentials not found in the environment: "
        + ", ".join(_missing_clearml_keys)
    )

In [None]:
# Training / evaluation schedule: evaluate, log and checkpoint once per epoch,
# keeping at most five checkpoints.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    # Mixed precision is only enabled when CUDA is present, because the
    # notebook falls back to the CPU when no GPU is available.
    fp16=torch.cuda.is_available(),
    output_dir='seq2seq_model_checkpoints/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='clearml',
    num_train_epochs=TrainingConfig.EPOCHS,
    dataloader_num_workers=4
)

In [None]:
# Initialize trainer.
# The AdamW optimizer built earlier is passed in explicitly; otherwise the
# trainer creates its own optimizer and the configured learning rate and
# weight decay are never used. The `None` scheduler slot leaves scheduler
# creation to the trainer.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
    optimizers=(optimizer, None)
)

# Обучим модель и посмотрим качество на тесте

In [None]:
# Run the fine-tuning; `res` keeps whatever `trainer.train()` returns.
res = trainer.train()

In [None]:
# Evaluate on the full test split; the return value is shown as the cell output.
trainer.evaluate(test_dataset)

# Инференс по грамотам

In [None]:
# Test rows whose 'label' equals 0.
guber_df = test_df.loc[test_df['label'] == 0]

In [None]:
# Dataset over the label-0 subset; images are read from the test folder.
guber_images_dir = os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer/test/')
guber_dataset = CustomOCRDataset(root_dir=guber_images_dir, df=guber_df, processor=processor)

In [None]:
# Evaluate on the label-0 subset; the return value is shown as the cell output.
trainer.evaluate(guber_dataset)

# Инференс по уставным

In [None]:
# Test rows whose 'label' is 1 or 2.
otchet_df = test_df[test_df['label'].isin([1, 2])]

In [None]:
# Dataset over the label-1/2 subset; images are read from the test folder.
otchet_images_dir = os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer/test/')
otchet_dataset = CustomOCRDataset(root_dir=otchet_images_dir, df=otchet_df, processor=processor)

In [None]:
# Evaluate on the label-1/2 subset; the return value is shown as the cell output.
trainer.evaluate(otchet_dataset)

# Инференс только по Победоносцеву

In [None]:
# Test rows whose 'label' equals 3.
pobed_df = test_df.loc[test_df['label'] == 3]

In [None]:
# Dataset over the label-3 subset; images are read from the test folder.
pobed_images_dir = os.path.join(DatasetConfig.DATA_ROOT, 'text_recognizer/test/')
pobed_dataset = CustomOCRDataset(root_dir=pobed_images_dir, df=pobed_df, processor=processor)

In [None]:
# Evaluate on the label-3 subset; the return value is shown as the cell output.
trainer.evaluate(pobed_dataset)

In [None]:
# Generate transcriptions for the label-3 subset.
predictions = trainer.predict(pobed_dataset)

# -100 is not a valid token id, so any -100 padding in the generated ids is
# mapped to the pad token before decoding.
pred_token_ids = np.where(
    predictions.predictions != -100,
    predictions.predictions,
    processor.tokenizer.pad_token_id,
)

decoded_predictions = [
    processor.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    for g in pred_token_ids
]

In [None]:
# `pobed_df` is a filtered slice of `test_df`; `assign` returns a new frame
# with the decoded text in 'pred', which avoids pandas' SettingWithCopyWarning
# and makes the cell safe to re-run.
pobed_df = pobed_df.assign(pred=decoded_predictions)

In [None]:
# Display the label-3 rows as the cell output.
pobed_df

In [None]:
# Create the reports folder if needed, so the export cannot fail on a missing
# directory after the whole training run has finished.
os.makedirs('../../reports', exist_ok=True)
pobed_df.to_csv('../../reports/pobed_pred.csv', index=False)