In [None]:
%%capture

!pip install --upgrade accelerate
!pip install jiwer

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

Ноутбук основан на следующем туториале: [ссылка](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_Seq2SeqTrainer.ipynb)

GitHub: NielsRogge/Transformers-Tutorials

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of OCR images with optional tokenized transcriptions.

    Each row of ``image_labels`` must have an ``Id`` column holding an image file
    name relative to ``images_path``. When an ``Expected`` column is present
    (train/validation), the transcription is tokenized into a fixed-length
    ``labels`` tensor; otherwise (test) only ``pixel_values`` are returned.

    Args:
        images_path: Directory that contains the image files.
        image_labels: DataFrame with an ``Id`` column and optionally ``Expected``.
        processor: TrOCR processor; it is called on the image to obtain
            ``pixel_values`` and its ``tokenizer`` encodes the labels.
        max_target_length: Length of every ``labels`` tensor. Longer
            transcriptions are truncated so all label tensors share one length
            and can be stacked by ``default_data_collator``.
    """

    def __init__(self, images_path: Path, image_labels: pd.DataFrame, processor, max_target_length=100):
        self.images_path = images_path
        self.image_labels = image_labels
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.image_labels)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": Tensor, "labels": Tensor}`` (``labels`` only if ``Expected`` exists)."""
        file_name = self.image_labels['Id'].iloc[idx]
        # Use a context manager so the underlying file handle is released right after
        # decoding (convert() returns a new in-memory image) instead of lingering until
        # garbage collection inside long-running DataLoader workers.
        with Image.open(self.images_path / file_name) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # The processor returns a batch of one image; drop only that leading batch axis.
        encoding = {"pixel_values": pixel_values.squeeze(0)}

        if "Expected" in self.image_labels.columns:
            text = self.image_labels['Expected'].iloc[idx]
            labels = self.processor.tokenizer(
                text,
                padding="max_length",
                # Without truncation, texts longer than max_length produce longer
                # label lists, and batches can no longer be stacked by the collator.
                truncation=True,
                max_length=self.max_target_length,
            ).input_ids

            pad_token_id = self.processor.tokenizer.pad_token_id
            # Replace padding with -100 (PyTorch's default cross-entropy ignore_index)
            # so padded positions do not contribute to the loss.
            labels = [label if label != pad_token_id else -100 for label in labels]
            encoding["labels"] = torch.tensor(labels)

        return encoding

# Подготовка данных

In [None]:
# Input locations of the competition data on Kaggle.
train_labels_path = Path("/kaggle/input/vk-made-ocr/train_labels.csv")
sample_submission_path = Path("/kaggle/input/vk-made-ocr/sample_submission.csv")
train_folder_path = Path("/kaggle/input/vk-made-ocr/train/train")
test_folder_path = Path("/kaggle/input/vk-made-ocr/test/test")

# Fail fast if the data is not mounted where expected: label files first, then image folders.
for required_file in (train_labels_path, sample_submission_path):
    assert required_file.is_file()
for required_dir in (train_folder_path, test_folder_path):
    assert required_dir.is_dir()

In [None]:
# Read transcriptions as literal strings. With default parsing, labels such as "NA",
# "null", "None" or an empty field become NaN (and then the string "nan" after
# astype(str)); dtype=str also rules out any numeric conversion of the labels.
train_labels_df = pd.read_csv(train_labels_path, dtype={"Expected": str}, keep_default_na=False)
train_labels_df["Expected"] = train_labels_df["Expected"].astype(str)

# Template for the submission; its row order defines the test set order.
sample_submission_df = pd.read_csv(sample_submission_path)

train_labels_df.head(2)

In [None]:
# Hold out a fixed fraction of the labelled images for validation; the seed keeps the split reproducible.
VAL_FRACTION = 0.1
SPLIT_SEED = 12
# Row labels excluded from both splits.
# TODO: document why row 190919 is excluded (presumably a broken image or label — confirm).
EXCLUDED_ROW_IDS = [190919]

train_images, val_images = train_test_split(train_labels_df, test_size=VAL_FRACTION, random_state=SPLIT_SEED)
# Drop excluded rows from whichever split they landed in; errors="ignore" avoids a
# KeyError when a row is not present in a given split.
train_images = train_images.drop(index=EXCLUDED_ROW_IDS, errors="ignore")
val_images = val_images.drop(index=EXCLUDED_ROW_IDS, errors="ignore")

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

train_dataset = Dataset(train_folder_path, train_images, processor)
val_dataset = Dataset(train_folder_path, val_images, processor)
test_dataset = Dataset(test_folder_path, sample_submission_df, processor)

Ссылка на модель: [HuggingFace](https://huggingface.co/microsoft/trocr-base-stage1)

In [None]:
# Starting checkpoint for fine-tuning (see the model link above).
pretrained_checkpoint = "microsoft/trocr-base-stage1"
model = VisionEncoderDecoderModel.from_pretrained(pretrained_checkpoint)

In [None]:
# Special tokens used when creating decoder_input_ids from the labels, the decoder
# vocabulary size, and the beam-search settings used by model.generate().
# NOTE(review): max_length=64 is shorter than the Dataset's max_target_length (100);
# confirm no transcription needs more than 64 tokens.
# NOTE(review): no_repeat_ngram_size=3 forbids repeating any token 3-gram, which can
# block legitimately repetitive strings in OCR output; confirm this is intended.
config_overrides = {
    "decoder_start_token_id": processor.tokenizer.cls_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
    # make sure vocab size is set correctly
    "vocab_size": model.config.decoder.vocab_size,
    "eos_token_id": processor.tokenizer.sep_token_id,
    "max_length": 64,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.0,
    "num_beams": 4,
}
for attr_name, attr_value in config_overrides.items():
    setattr(model.config, attr_name, attr_value)

In [None]:
# Trainer checkpoints (checkpoint-<step> subdirectories) are written here.
model_output_dir = Path("/kaggle/working/trocr_base_v1")
# parents=True also creates missing parent directories when running outside Kaggle.
model_output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # Run generate() during evaluation so compute_metrics receives generated token ids.
    predict_with_generate=True,
    # NOTE(review): newer transformers versions rename this to `eval_strategy`; check the installed version.
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    # Mixed precision needs a CUDA device; fall back to fp32 instead of failing on CPU.
    fp16=torch.cuda.is_available(),
    output_dir=str(model_output_dir),
    logging_strategy="epoch",
    save_strategy="epoch",
    dataloader_num_workers=2,
    load_best_model_at_end=True,
    # Pick the best checkpoint by validation CER (lower is better) rather than the default eval loss.
    metric_for_best_model="cer",
    greater_is_better=False,
    report_to="none",
)

In [None]:
from datasets import load_metric

# Character error rate metric consumed by compute_metrics.
# NOTE(review): `load_metric` is deprecated in favour of the `evaluate` library and is
# absent from recent `datasets` releases; confirm the installed version provides it.
cer_metric = load_metric(path="cer")

In [None]:
def compute_metrics(pred):
    """Compute the character error rate (CER) for a Seq2SeqTrainer evaluation pass.

    Args:
        pred: EvalPrediction-like object with ``predictions`` (generated token ids,
            since ``predict_with_generate=True``) and ``label_ids``.

    Returns:
        dict: ``{"cer": value}`` as computed by ``cer_metric``.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    # Some configurations return a tuple whose first element holds the token ids.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks ignored label positions, and the trainer can also use it to pad
    # variable-length generations when concatenating batches. It is not a valid token
    # id and breaks batch_decode, so map it to pad (removed by skip_special_tokens).
    # np.where builds new arrays, so the trainer's buffers are not modified in place.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_token_id)
    labels_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

In [None]:
from transformers import default_data_collator

# default_data_collator stacks the per-item dicts returned by Dataset into batch tensors.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    # Passed as `tokenizer` so the preprocessing config is saved together with checkpoints.
    tokenizer=processor.feature_extractor,
)

trainer.train()

Обучение (1 эпоха) на Kaggle занимает около 10 часов.

# Inference

In [None]:
# Prefer the second GPU (the original choice) when at least two are present; a hard-coded
# "cuda:1" fails on single-GPU machines, so fall back to the first GPU, then to the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Reload the fine-tuned weights from the checkpoint the trainer wrote into model_output_dir
# (resolving it from that directory instead of a cwd-relative path keeps the two in sync).
# NOTE: the step number depends on dataset size, batch size and epoch count — update it if they change.
checkpoint_dir = model_output_dir / "checkpoint-15525"
model = VisionEncoderDecoderModel.from_pretrained(str(checkpoint_dir))

In [None]:
%%capture

model.to(device)
model.eval()

In [None]:
# shuffle=False keeps batches in sample_submission order so predictions map back to rows.
test_dataloader = DataLoader(dataset=test_dataset, batch_size=2, shuffle=False)

In [None]:
# Text file that receives the raw per-batch predictions (periodic and final dumps).
save_path: Path = Path("./predictions_v1.txt")

In [None]:
from tqdm.notebook import tqdm
predictions = []

print("Running evaluation...")
# `predictions` collects one list of decoded strings per batch, in test_dataloader order.
for i, batch in enumerate(tqdm(test_dataloader), start=1):
    # predict using generate; no decoding arguments are passed, so the model's stored settings apply
    pixel_values = batch["pixel_values"].to(device)
    outputs = model.generate(pixel_values)

    # decode
    pred_str = processor.batch_decode(outputs, skip_special_tokens=True)
    predictions.append(pred_str)

    # Dump partial results every 100 batches so an interrupted run keeps its progress.
    # Explicit UTF-8: predictions may contain non-ASCII text, and the platform default
    # encoding is not guaranteed to support it.
    if i % 100 == 0:
        save_path.write_text(str(predictions), encoding="utf-8")

In [None]:
# Persist all batch predictions; explicit UTF-8 so non-ASCII text is written portably.
save_path.write_text(str(predictions), encoding="utf-8")

In [None]:
# Flatten the per-batch lists of decoded strings into one array in test-set order.
concat_predictions = np.concatenate(predictions, axis=0)

In [None]:
# Work on an independent copy so the original template stays untouched.
submission_df = sample_submission_df.copy(deep=True)

In [None]:
# Fill the "Predicted" column with the model outputs, aligned row-by-row with the template.
submission_df = submission_df.assign(Predicted=concat_predictions)

In [None]:
# Write the submission file without the pandas index column.
submission_df.to_csv(path_or_buf="submission_1.csv", index=False)