In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
# TODO: pin a version for reproducible runs.
%pip install -q -U accelerate

In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
# TODO: pin a version for reproducible runs.
%pip install -q transformers

In [None]:
# jiwer is the backend of the CER/WER metrics; %pip targets the kernel's environment.
%pip install -q jiwer

In [None]:
# `datasets.load_metric` was removed in datasets 3.0. Fall back to the `evaluate`
# library, which loads the same metric scripts with the same
# `.compute(predictions=..., references=...)` API (needs `pip install evaluate`).
try:
    from datasets import load_metric
except ImportError:
    from evaluate import load as load_metric

cer_metric = load_metric("cer")
wer_metric = load_metric("wer")

In [None]:
# Download the dataset archive. `-p` pins the destination, because the following
# cells expect the zip at /kaggle/working regardless of the current directory.
!kaggle datasets download -d constantinwerner/cyrillic-handwriting-dataset -p /kaggle/working

In [None]:
# Extract quietly, overwriting any previous extraction.
!unzip -oq /kaggle/working/cyrillic-handwriting-dataset.zip -d /kaggle/working/cyrillic-handwriting-dataset

In [None]:
# Delete the archive now that it has been extracted.
!rm -rf /kaggle/working/cyrillic-handwriting-dataset.zip

In [None]:
# basic random seed
import os 
import random
import numpy as np 

DEFAULT_RANDOM_SEED = 2021


def seedBasic(seed=DEFAULT_RANDOM_SEED):
    """Seed the non-torch sources of randomness.

    Exports ``PYTHONHASHSEED`` and seeds Python's ``random`` module and NumPy's
    global RNG with ``seed``.

    NOTE: ``PYTHONHASHSEED`` is only read at interpreter start-up, so setting it
    here affects child processes launched afterwards, not the running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    
# torch random seed
import torch
def seedTorch(seed=DEFAULT_RANDOM_SEED):
    """Seed PyTorch's RNGs and request deterministic cuDNN behaviour.

    Calls ``torch.manual_seed`` and ``torch.cuda.manual_seed`` with ``seed``,
    then turns cuDNN autotuning (``benchmark``) off and deterministic mode on.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
      
# basic + torch 
def seedEverything(seed=DEFAULT_RANDOM_SEED):
    """Run both seeding helpers (Python/NumPy first, then torch) with ``seed``."""
    for seeder in (seedBasic, seedTorch):
        seeder(seed)

def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's current seed (DataLoader ``worker_init_fn``).

    The torch seed is reduced modulo 2**32 because ``np.random.seed`` only
    accepts 32-bit seeds.

    Args:
        worker_id: worker index passed by the DataLoader (unused).
    """
    worker_seed = torch.initial_seed() % 2**32
    # NumPy is bound as `np` in this notebook; the bare name `numpy` is undefined.
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [None]:
def seed(my_seed=42):
    """Seed every RNG via ``seedEverything`` and rebuild the global ``torch.Generator`` ``g``.

    Args:
        my_seed: seed applied to all RNGs and to ``g``.
    """
    global g
    seedEverything(my_seed)
    generator = torch.Generator()
    generator.manual_seed(my_seed)
    g = generator

In [None]:
seed(my_seed=42)

In [None]:
import numpy as np
import torch
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import random
import os

import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform

In [None]:
import csv


def load_labels(tsv_path):
    """Read a headerless ``<file_name>\\t<text>`` label file into a DataFrame.

    Every field is kept as a literal string:
      * ``dtype=str``: all-digit labels stay strings. Otherwise pandas may infer
        ints, which the dataset code treats as a missing label.
      * ``keep_default_na=False``: empty or "NA"/"null"-like labels are not
        converted to NaN.
      * ``quoting=csv.QUOTE_NONE``: quote characters in the handwritten text are
        data, not CSV quoting. An unbalanced ``"`` could otherwise swallow
        subsequent rows.

    Args:
        tsv_path: path to the tab-separated label file.

    Returns:
        DataFrame with columns ``file_name`` and ``text``.
    """
    return pd.read_csv(
        tsv_path,
        sep='\t',
        header=None,
        names=["file_name", "text"],
        dtype=str,
        keep_default_na=False,
        quoting=csv.QUOTE_NONE,
    )


train_df = load_labels('/kaggle/working/cyrillic-handwriting-dataset/train.tsv')
test_df = load_labels('/kaggle/working/cyrillic-handwriting-dataset/test.tsv')

In [None]:
train_df.head(n=5)

In [None]:
test_df.head(n=5)

In [None]:
class _RandomMorphology(ImageOnlyTransform):
    """Shared base for the random erosion / dilation augmentations.

    Subclasses set ``_operation`` to an OpenCV morphology function. The operation
    uses a 3x3 cross-shaped kernel, one iteration. It is applied with probability
    ``prob``, checked inside ``apply``. That check is on top of albumentations'
    own ``p``, which is left at its library default here.

    Args:
        safe_db_lists: stored as an attribute only; not used by the transform.
        prob: probability that ``apply`` actually modifies the image.
    """

    _operation = None

    def __init__(self, safe_db_lists=None, prob=0.5) -> None:
        super().__init__()
        # A fresh list per instance; a `[]` default would be shared by all instances.
        self.safe_db_lists = [] if safe_db_lists is None else safe_db_lists
        self.prob = prob

    def apply(self, img, copy=True, **params):
        """Return ``img`` unchanged, or the morphology result with probability ``prob``."""
        if np.random.uniform(0, 1) > self.prob:
            return img
        if copy:
            img = img.copy()

        kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
        return self._operation(img, kernel, iterations=1)

    def get_transform_init_args_names(self):
        # Required by albumentations for repr()/serialization of a pipeline
        # containing this transform; without it those calls raise.
        return ("safe_db_lists", "prob")


class Erosion(_RandomMorphology):
    """Randomly apply ``cv2.erode`` (local minimum), which grows dark regions such as ink strokes."""

    _operation = staticmethod(cv2.erode)


class Dilation(_RandomMorphology):
    """Randomly apply ``cv2.dilate`` (local maximum), which shrinks dark regions such as ink strokes."""

    _operation = staticmethod(cv2.dilate)

In [None]:
# Fill colour for pixels that geometric transforms expose at the borders
# (border_mode=0 is cv2.BORDER_CONSTANT). Presumably chosen to match the paper
# background; TODO confirm.
BORDER_FILL = (199, 185, 182)
_constant_border = {"border_mode": 0, "value": BORDER_FILL}

# The whole pipeline runs with p=0.6; within it, each group below fires with p=0.3.
transform = A.Compose(
    [
        # stroke thickness
        A.OneOf([Erosion(), Dilation()], p=.3),
        # noise / compression / sharpening
        A.OneOf(
            [A.GaussNoise(), A.ISONoise(), A.MultiplicativeNoise(), A.ImageCompression(), A.Sharpen()],
            p=.3,
        ),
        # small affine jitter
        A.ShiftScaleRotate(
            p=.3,
            shift_limit=(-0.0625, 0.0625),
            scale_limit=(-0.2, 0.05),
            rotate_limit=(-10, 10),
            **_constant_border,
        ),
        # non-rigid warps
        A.OneOf(
            [
                A.ElasticTransform(alpha=.5, sigma=10, alpha_affine=.75, **_constant_border),
                A.OpticalDistortion(distort_limit=(-0.3, 1.5), shift_limit=(-0.5, 0.5), **_constant_border),
                A.GridDistortion(distort_limit=(-0.2, 0.2), **_constant_border),
            ],
            p=.3,
        ),
        # colour-channel manipulations
        A.OneOf(
            [A.ChannelDropout(), A.ChannelShuffle(), A.Posterize(), A.RGBShift(), A.ToGray(), A.ToSepia()],
            p=.3,
        ),
        # contrast / brightness / hue
        A.OneOf([A.CLAHE(clip_limit=2), A.RandomBrightnessContrast(), A.HueSaturationValue()], p=.3),
        # blur
        A.OneOf([A.MotionBlur(blur_limit=3), A.Blur(blur_limit=3)], p=.3),
    ],
    p=0.6,
)

In [None]:
from transformers import TrOCRProcessor

# Image preprocessor + tokenizer bundled with the microsoft/trocr-base-handwritten checkpoint.
processor = TrOCRProcessor.from_pretrained(
    'microsoft/trocr-base-handwritten'
)

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image

class CyrillicDataset(Dataset):
    """Pairs handwriting images with TrOCR-tokenized text labels.

    Args:
        root_dir: directory containing the files named in ``df['file_name']``.
        df: DataFrame with ``file_name`` and ``text`` columns. Rows are addressed
            by position, so a filtered or re-indexed frame also works.
        processor: TrOCR processor (image preprocessing + ``tokenizer``).
        transform: optional albumentations-style callable, invoked as
            ``transform(image=<numpy array>)['image']``.
        max_target_length: label length; texts are padded/truncated to it.
    """

    def __init__(self, root_dir, df, processor, transform=None, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.transform = transform
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": LongTensor}`` for row ``idx``.

        Label positions holding the tokenizer's pad id are replaced by -100 so
        the loss ignores them. Non-string texts (e.g. NaN) are treated as ''.
        """
        row = self.df.iloc[idx]
        file_name = row['file_name']
        text = row['text']
        if not isinstance(text, str):
            text = ''

        image = Image.open(f"{self.root_dir}/{file_name}").convert("RGB")
        if self.transform is not None:
            # Use this instance's transform (not a global), and hand albumentations
            # the numpy array it requires instead of a PIL image.
            image = self.transform(image=np.array(image))['image']

        # resize + normalize
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        tokenizer = self.processor.tokenizer
        # truncation=True keeps every label exactly max_target_length long, so
        # batches stack even when a text is longer than the limit.
        labels = tokenizer(text,
                           padding="max_length",
                           max_length=self.max_target_length,
                           truncation=True).input_ids
        # PAD tokens must be ignored by the loss function.
        labels = [tok if tok != tokenizer.pad_token_id else -100 for tok in labels]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

In [None]:
_dataset_root = '/kaggle/working/cyrillic-handwriting-dataset'
train_dataset_root_dir = f'{_dataset_root}/train'
eval_dataset_root_dir = f'{_dataset_root}/test'

# NOTE(review): no `transform=` is passed here, so the augmentation pipeline is
# not applied to training samples. Confirm whether that is intended.
train_dataset = CyrillicDataset(train_dataset_root_dir, train_df, processor)
eval_dataset = CyrillicDataset(eval_dataset_root_dir, test_df, processor)

In [None]:
from torch.utils.data import ConcatDataset, Subset

# Fine-tune on a random 1/TRAIN_SUBSET_DIVISOR of the training set.
TRAIN_SUBSET_DIVISOR = 8

# Re-running this cell must not shrink the data again (1/8 -> 1/64 -> ...):
# start from the full dataset if train_dataset is already a Subset.
if isinstance(train_dataset, Subset):
    train_dataset = train_dataset.dataset

train_subset_size = len(train_dataset) // TRAIN_SUBSET_DIVISOR
# Drawn from the global torch RNG seeded earlier by seed(42).
train_subset_indices = torch.randperm(len(train_dataset))[:train_subset_size].tolist()
train_dataset = Subset(train_dataset, train_subset_indices)
train_subset_indices[:10]

In [None]:
# Number of training samples after subsetting.
len(train_dataset)

In [None]:
def visualize(image):
    """Draw ``image`` on a fresh 2x2-inch figure with the axes hidden.

    Args:
        image: anything ``imshow`` accepts (numpy array or PIL image).
    """
    _, ax = plt.subplots(figsize=(2, 2))
    ax.axis('off')
    ax.imshow(image)

In [None]:
path = f"{train_dataset_root_dir}/{train_df['file_name'][0]}"
print(f"path={path!r}")
# OpenCV loads BGR; convert to RGB for matplotlib.
image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
visualize(image)

In [None]:
# Preview the augmentation pipeline on a couple of random training images.
for _ in range(2):
    # randrange(n) yields a valid 0-based row position in train_df (0..n-1),
    # without listing the image directory on every iteration.
    sample_idx = random.randrange(len(train_df))
    image = cv2.imread(f"{train_dataset_root_dir}/{train_df['file_name'].iloc[sample_idx]}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Show the raw image when no pipeline is defined instead of failing.
    augmented_image = transform(image=image)['image'] if transform is not None else image
    visualize(augmented_image)

In [None]:
encoding = train_dataset[0]
for key, tensor in encoding.items():
    print(key, tensor.shape)

In [None]:
# Show the image behind `encoding = train_dataset[0]`. train_dataset is a Subset,
# so its position 0 is train_df row train_dataset.indices[0], not row 0.
sample_row = train_dataset.indices[0] if isinstance(train_dataset, Subset) else 0
with Image.open(f"{train_dataset_root_dir}/{train_df['file_name'].iloc[sample_row]}") as raw_image:
    image = raw_image.convert("RGB")
image

In [None]:
# Decode the sample's labels back to text as a sanity check. Work on a copy so
# the -100 -> pad replacement does not modify `encoding` in place.
labels = encoding['labels'].clone()
labels[labels == -100] = processor.tokenizer.pad_token_id
label_str = processor.decode(labels, skip_special_tokens=True)
print(label_str)

In [None]:
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from transformers import VisionEncoderDecoderModel

# Start from the pretrained microsoft/trocr-base-handwritten checkpoint, placed on `device`.
model = VisionEncoderDecoderModel.from_pretrained(
    'microsoft/trocr-base-handwritten'
).to(device)

In [None]:
# Count all parameters and the subset that requires gradients.
total_params = 0
total_trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        total_trainable_params += param.numel()
print(f"{total_params:,} total parameters.")
print(f"{total_trainable_params:,} training parameters.")

In [None]:
# Explicit imports instead of `import *`, which dumps every pynvml (and ctypes)
# name into the notebook namespace.
from pynvml import (
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    nvmlInit,
    nvmlShutdown,
)


def print_gpu_utilization(device_index=0):
    """Print the memory currently in use on one GPU, as reported by NVML.

    Args:
        device_index: NVML index of the GPU to query (default: the first GPU).
    """
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(device_index)
        info = nvmlDeviceGetMemoryInfo(handle)
        print(f"GPU memory occupied: {info.used//1024**2} MB.")
    finally:
        # Balance nvmlInit() so repeated calls don't leave NVML initialized.
        nvmlShutdown()


def print_summary(result):
    """Print training runtime and throughput from a train result, then GPU memory.

    Args:
        result: object whose ``metrics`` dict has ``train_runtime`` and
            ``train_samples_per_second`` (e.g. the value returned by ``trainer.train()``).
    """
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()

In [None]:
# Report current GPU memory usage.
print_gpu_utilization()

In [None]:
# Special token ids. VisionEncoderDecoderModel reads decoder_start_token_id and
# pad_token_id from model.config when it builds decoder inputs from `labels`
# during training, so they stay on the config. They are mirrored onto
# generation_config, which is what generate() uses.
special_token_ids = {
    "decoder_start_token_id": processor.tokenizer.cls_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
    "eos_token_id": processor.tokenizer.sep_token_id,
}
for attr, token_id in special_token_ids.items():
    setattr(model.config, attr, token_id)
    setattr(model.generation_config, attr, token_id)

# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# Beam-search settings go on generation_config. Newer transformers releases
# deprecate generation parameters on model.config, and some refuse to
# save_pretrained a model that has them (the Trainer saves every epoch).
generation_config = model.generation_config
generation_config.max_length = 64
generation_config.early_stopping = True
generation_config.no_repeat_ngram_size = 3
generation_config.length_penalty = 2.0
generation_config.num_beams = 4

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

import inspect

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases and the old name was later dropped; use whichever this version accepts.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(Seq2SeqTrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

training_args = Seq2SeqTrainingArguments(
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # fp16 mixed precision requires CUDA; train in fp32 when no GPU is present.
    fp16=torch.cuda.is_available(),
    output_dir="/kaggle/working/seq2seq_model_handwritten",
    logging_strategy="epoch",
    report_to="tensorboard",
    save_strategy="epoch",
    **{_eval_strategy_kw: "epoch"},
)

In [None]:
def accuracy(labels, pred_labels):
    """Exact-match accuracy: the fraction of positions where the two strings are equal.

    Comparison is exact (no whitespace stripping).

    Args:
        labels: sequence of reference strings.
        pred_labels: sequence of predicted strings, same length as ``labels``.

    Returns:
        float in [0, 1]; 0.0 for empty inputs.

    Raises:
        ValueError: if the sequences differ in length.
    """
    if len(labels) != len(pred_labels):
        raise ValueError(
            f"labels and pred_labels must have the same length "
            f"({len(labels)} != {len(pred_labels)})"
        )
    if len(labels) == 0:
        return 0.0
    # Plain Python comparison; np.compare_chararrays is no longer exposed in
    # NumPy 2.x's top-level namespace.
    matches = sum(ref == hyp for ref, hyp in zip(labels, pred_labels))
    return matches / len(labels)


def compute_metrics(pred):
    """Compute CER, WER and exact-match accuracy for a Seq2SeqTrainer evaluation.

    Args:
        pred: EvalPrediction. ``pred.predictions`` is the tuple produced by
            ``preprocess_logits_for_metrics`` (argmax token ids first).
            ``pred.label_ids`` holds label ids with -100 at ignored positions.

    Returns:
        dict with keys ``cer``, ``wer`` and ``acc``.
    """
    pad_id = processor.tokenizer.pad_token_id
    labels_ids = np.asarray(pred.label_ids)
    pred_ids = np.asarray(pred.predictions[0])

    ignored = labels_ids == -100
    if pred_ids.shape == labels_ids.shape:
        # Without predict_with_generate these are teacher-forced argmax ids,
        # aligned with the labels. At positions the loss ignores they are
        # arbitrary tokens, which would otherwise be decoded into the prediction
        # text and inflate CER/WER.
        pred_ids = np.where(ignored, pad_id, pred_ids)
    # -100 padding (e.g. from concatenating batches) cannot be decoded.
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    labels_ids = np.where(ignored, pad_id, labels_ids)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    acc = accuracy(pred_str, label_str)

    return {"cer": cer, "wer": wer, "acc": acc}


def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to argmax token ids before the Trainer accumulates them.

    The Trainer may otherwise keep the full vocabulary-sized logits for the whole
    eval set in memory; token ids are all ``compute_metrics`` needs.

    Args:
        logits: model logits, either a tuple/list whose first element is the LM
            logits (seq2seq model outputs) or a bare logits tensor.
        labels: label ids, passed through unchanged.

    Returns:
        (pred_ids, labels)
    """
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels

In [None]:
import inspect

from transformers import default_data_collator

# Newer transformers releases renamed the Trainer's `tokenizer` argument to
# `processing_class` (the old name is deprecated); use whichever is accepted.
_processor_kw = (
    "processing_class"
    if "processing_class" in inspect.signature(Seq2SeqTrainer.__init__).parameters
    else "tokenizer"
)

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    **{_processor_kw: processor},
)
trainer.train()

In [None]:
# -q: don't flood the notebook with one line per checkpoint file.
!zip -qr "/kaggle/working/seq2seq_model_handwritten.zip" "/kaggle/working/seq2seq_model_handwritten"

In [None]:
from IPython.display import FileLink
# Relative path: assumes the kernel's working directory is /kaggle/working,
# where the archive was written.
FileLink('seq2seq_model_handwritten.zip')

In [None]:
from PIL import Image

def batch(iterable, batch_size):
    """Yield consecutive slices of ``iterable`` with at most ``batch_size`` items each.

    ``iterable`` must support ``len()`` and slicing (list, tuple, str, ...).
    The final slice may be shorter.
    """
    starts = range(0, len(iterable), batch_size)
    yield from (iterable[s:s + batch_size] for s in starts)
        

def flatten(matrix):
    """Concatenate an iterable of iterables into one list (one level deep).

    Args:
        matrix: iterable whose items are themselves iterable.

    Returns:
        list of all inner items, in order.
    """
    # `chain` was used here without ever being imported.
    from itertools import chain
    return list(chain.from_iterable(matrix))

def read_and_show(image_path):
    """Load an image file as an RGB PIL image (nothing is displayed, despite the name).

    Args:
        image_path: path to the image file.

    Returns:
        PIL Image in RGB mode.
    """
    with Image.open(image_path) as raw:
        return raw.convert('RGB')


def ocr(images, processor, model):
    """Transcribe images with a Hugging Face OCR encoder-decoder via ``generate``.

    Args:
        images: image(s) accepted by ``processor``; the notebook passes a list
            of PIL images.
        processor: Hugging Face OCR processor (image preprocessing + decoding).
        model: Hugging Face OCR model exposing ``generate``.

    Returns:
        list of decoded strings, special tokens removed.

    Side effects: puts ``model`` in eval mode and moves it to CUDA when available.
    """
    target = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    pixel_values = processor(images, return_tensors='pt').pixel_values.to(target)
    model.eval()
    # inference_mode also disables gradient tracking, so no separate no_grad is needed.
    with torch.inference_mode():
        generated_ids = model.to(target).generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
import os
import glob
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import time

class AverageMeter:
    """Running (weighted) sum, count and mean of a stream of values."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the accumulated sum, count and average."""
        self.sum = self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and recompute the running average."""
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def eval_new_data(data_path, num_samples=50):
    """Run OCR over the images matching a glob and average WER/CER across batches.

    The ground truth for each image is looked up in ``test_df`` by file name
    (first match).

    Args:
        data_path: glob pattern selecting the image files.
        num_samples: number of images per OCR batch.

    Returns:
        (wer_avg, cer_avg): AverageMeter objects holding per-batch WER and CER,
        each batch weighted by its number of images.

    Raises:
        KeyError: if an image's file name has no entry in ``test_df``.
    """
    # Sorted so batches, and therefore the printed per-batch numbers, are reproducible.
    image_paths = sorted(glob.glob(data_path))

    # Build the name -> label map once instead of scanning test_df for every image.
    label_by_name = {}
    for name, text in zip(test_df['file_name'], test_df['text']):
        label_by_name.setdefault(name, text)

    wer_avg = AverageMeter()
    cer_avg = AverageMeter()

    n_batches = -(-len(image_paths) // num_samples)  # ceil: count the final partial batch
    for batch_paths in tqdm(batch(image_paths, num_samples), total=n_batches):
        images = [read_and_show(path) for path in batch_paths]
        labels = [label_by_name[os.path.basename(path)] for path in batch_paths]
        print(labels[:10])

        start = time.time()
        texts = ocr(images, processor, model)
        end = time.time()
        print(texts[:10])
        print(f"Time: {end - start}")

        cer = cer_metric.compute(predictions=texts, references=labels)
        wer = wer_metric.compute(predictions=texts, references=labels)
        print(f"{wer=} {cer=}")

        # Weight by batch size so a short final batch doesn't count as a full one.
        wer_avg.update(wer, len(batch_paths))
        cer_avg.update(cer, len(batch_paths))

    return wer_avg, cer_avg

# Evaluate the fine-tuned model on every test image, 100 images per OCR batch.
test_glob = os.path.join('/kaggle/working/cyrillic-handwriting-dataset', 'test', '*')
wer_avg, cer_avg = eval_new_data(data_path=test_glob, num_samples=100)