In [None]:
# -*- coding: utf-8 -*-
from __future__ import annotations

import argparse
import sys
import re
import logging
from glob import glob
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import pandas as pd
from torch.utils.data import SequentialSampler

sys.path.insert(0, '.')

from configs import CONFIGS  # noqa
from src.dataset import DatasetRetriever  # noqa
from src.ctc_labeling import CTCLabeling  # noqa
from src.model import get_ocr_model  # noqa
import src.utils as utils  # noqa
from src.predictor import Predictor  # noqa
from src.metrics import string_accuracy, cer, wer  # noqa

def as_asr_text(text: str) -> str:
    """Normalize text for ASR-style scoring.

    Removes every character that is neither a word character nor whitespace,
    lowercases what remains and trims surrounding whitespace.
    """
    without_punctuation = re.sub(r'[^\w\s]', '', text)
    lowered = without_punctuation.lower()
    return lowered.strip()

# (log heading, metric-key suffix) for each split / text-normalization combination,
# in the order they are reported.
_EVAL_SECTIONS: Tuple[Tuple[str, str], ...] = (
    ('---- VALID ----', 'valid'),
    ('---- TEST -----', 'test'),
    ('---- VALID as ASR ----', 'valid_asr'),
    ('---- TEST as ASR -----', 'test_asr'),
)


def _decode_predictions(predictions, ctc_labeling) -> pd.DataFrame:
    """Turn raw model outputs into a DataFrame of decoded vs. ground-truth text.

    Each prediction dict must provide ``id``, ``raw_output`` and ``gt_text``.
    ``raw_output.argmax(1)`` (presumably the best class per timestep; confirm
    against ``Predictor``) is decoded with ``ctc_labeling.decode``. The result is
    indexed by ``id`` with columns ``pred_text`` and ``gt_text``.
    """
    return pd.DataFrame([{
        'id': prediction['id'],
        'pred_text': ctc_labeling.decode(prediction['raw_output'].argmax(1)),
        'gt_text': prediction['gt_text'],
    } for prediction in predictions]).set_index('id')


def _text_metrics(df_pred: pd.DataFrame) -> Tuple[float, float, float]:
    """Return (CER, WER, string accuracy) of ``pred_text`` vs ``gt_text``, each rounded to 5 digits."""
    pred_text, gt_text = df_pred['pred_text'], df_pred['gt_text']
    return (
        round(cer(pred_text, gt_text), 5),
        round(wer(pred_text, gt_text), 5),
        round(string_accuracy(pred_text, gt_text), 5),
    )


def _as_asr_frame(df_pred: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df_pred`` with both text columns normalized by ``as_asr_text``."""
    return df_pred.assign(
        pred_text=df_pred['pred_text'].apply(as_asr_text),
        gt_text=df_pred['gt_text'].apply(as_asr_text),
    )


def _best_per_metric(exp_metrics: Dict[str, List[float]]) -> Dict[str, float]:
    """Reduce per-checkpoint values to the best one per metric: max for ``acc_*``, min otherwise."""
    return {
        key: (max if key.startswith('acc') else min)(values)
        for key, values in exp_metrics.items()
    }


def _log_sections(render) -> None:
    """Log CER/WER/ACC under every ``_EVAL_SECTIONS`` heading; ``render(key)`` formats one metric."""
    for heading, suffix in _EVAL_SECTIONS:
        logging.info(heading)
        logging.info(f'CER: {render(f"cer_{suffix}")}')
        logging.info(f'WER: {render(f"wer_{suffix}")}')
        logging.info(f'ACC: {render(f"acc_{suffix}")}')


def main(args: argparse.Namespace) -> None:
    """Evaluate every experiment under ``args.experiment_folder`` on the valid and test splits.

    Each entry of ``args.experiment_folder`` is treated as one experiment. Every ``*.pt``
    checkpoint inside it is loaded into the same model (``checkpoint['model_state_dict']``).
    The checkpoint is scored with CER, WER and string accuracy, both on raw text and
    after ``as_asr_text`` normalization. Per experiment, the best value of each metric
    over its checkpoints is logged. Finally, mean ± std (and std relative to the mean)
    across experiments is logged. Entries without checkpoints are skipped with a warning.
    """
    logging.basicConfig(level=logging.INFO)

    assert args.dataset_name in CONFIGS, f"Invalid dataset_name: {args.dataset_name}"

    utils.seed_everything(args.seed)

    config = CONFIGS[args.dataset_name](
        data_dir=args.data_dir,
        image_w=args.image_w,
        image_h=args.image_h,
        bs=args.bs,
        num_workers=args.num_workers,
        seed=args.seed,
    )

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logging.info(f'DEVICE: {device}')
    logging.info(f'DATASET: {args.dataset_name}')

    ctc_labeling = CTCLabeling(config)

    df = pd.read_csv(args.data_dir / f'{args.dataset_name}/marking.csv', index_col='sample_id')

    valid_dataset = DatasetRetriever(df[df['stage'] == 'valid'], config, ctc_labeling)
    test_dataset = DatasetRetriever(df[df['stage'] == 'test'], config, ctc_labeling)

    # Only the architecture is built here; weights come from each checkpoint below.
    model = get_ocr_model(config, pretrained=False)
    model = model.to(device)

    def make_loader(dataset) -> torch.utils.data.DataLoader:
        # In-order, non-dropping loader so every valid/test sample is scored exactly once.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=config['bs'],
            sampler=SequentialSampler(dataset),
            pin_memory=False,
            drop_last=False,
            num_workers=config['num_workers'],
            collate_fn=utils.kw_collate_fn,
        )

    valid_loader = make_loader(valid_dataset)
    test_loader = make_loader(test_dataset)

    result_metrics: List[Dict[str, float]] = []
    # sorted() makes the processing and log order deterministic.
    for experiment_folder in sorted(glob(str(args.experiment_folder / '*'))):
        logging.info(experiment_folder)
        checkpoint_paths = sorted(glob(f'{experiment_folder}/*.pt'))
        if not checkpoint_paths:
            # Without checkpoints there is nothing to score; min()/max() below would
            # otherwise fail on an empty list.
            logging.warning(f'No *.pt checkpoints in {experiment_folder}, skipping it')
            continue

        exp_metrics: Dict[str, List[float]] = defaultdict(list)
        for checkpoint_path in checkpoint_paths:
            # map_location lets checkpoints saved on a GPU load on the CPU fallback device.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

            predictor = Predictor(model, device)
            valid_predictions = predictor.run_inference(valid_loader)
            test_predictions = predictor.run_inference(test_loader)

            frames = {
                'valid': _decode_predictions(valid_predictions, ctc_labeling),
                'test': _decode_predictions(test_predictions, ctc_labeling),
            }
            # The same predictions scored "as ASR": punctuation removed, lowercased.
            frames['valid_asr'] = _as_asr_frame(frames['valid'])
            frames['test_asr'] = _as_asr_frame(frames['test'])

            for suffix, df_pred in frames.items():
                cer_value, wer_value, acc_value = _text_metrics(df_pred)
                exp_metrics[f'cer_{suffix}'].append(cer_value)
                exp_metrics[f'wer_{suffix}'].append(wer_value)
                exp_metrics[f'acc_{suffix}'].append(acc_value)

        # NOTE(review): the best value is picked per metric and per split independently,
        # including on test, so the test numbers are selected on the test set itself.
        best = _best_per_metric(exp_metrics)
        result_metrics.append(best)
        _log_sections(best.__getitem__)

    if not result_metrics:
        logging.warning(f'No checkpoints found under {args.experiment_folder}; nothing to aggregate')
        return

    df_results = pd.DataFrame(result_metrics)

    def mean_std(key: str, ndigits: int = 4) -> str:
        # Mean ± sample std across experiments, plus std as a percentage of the mean.
        # pandas' std() uses ddof=1, so it is NaN when only one experiment was evaluated.
        mean = round(df_results[key].mean(), ndigits=ndigits)
        std = round(df_results[key].std(), ndigits=ndigits)
        return f'{mean} ± {std} [{round(std / mean * 100, 1)}%]'

    logging.info('---- ----- ----')
    _log_sections(mean_std)

def _build_eval_parser() -> argparse.ArgumentParser:
    """Create the CLI parser for the evaluation script; options are registered in help order."""
    options = (
        ('--experiment_folder', dict(type=Path, required=True, help='Path to the experiment folder')),
        ('--dataset_name', dict(type=str, required=True, help='Name of the dataset')),
        ('--image_w', dict(type=int, required=True, help='Image width')),
        ('--image_h', dict(type=int, required=True, help='Image height')),
        ('--data_dir', dict(type=Path, default=Path('../StackMix-OCR-DATA'), help='Path to the data directory')),
        ('--bs', dict(type=int, default=16, help='Batch size')),
        ('--num_workers', dict(type=int, default=4, help='Number of workers for data loading')),
        ('--seed', dict(type=int, default=6955, help='Random seed')),
    )
    cli = argparse.ArgumentParser(description='Run evaluation script.')
    for flag, spec in options:
        cli.add_argument(flag, **spec)
    return cli


if __name__ == '__main__':
    # Parse the command line and run the evaluation.
    parser = _build_eval_parser()
    args = parser.parse_args()
    main(args)

In [None]:
# -*- coding: utf-8 -*-
import argparse
import logging
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

import albumentations as A
import neptune
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

sys.path.insert(0, '.')

from configs import CONFIGS  # noqa
from src.dataset import DatasetRetriever  # noqa
from src.ctc_labeling import CTCLabeling  # noqa
from src.model import get_ocr_model  # noqa
from src.experiment import OCRExperiment  # noqa
import src.utils as utils  # noqa
from src.predictor import Predictor  # noqa
from src.blot import get_blot_transforms  # noqa
from src.metrics import string_accuracy, cer, wer  # noqa
from src.stackmix import StackMix  # noqa

def _generic_augs() -> list:
    """Return fresh instances of the augmentations shared by the augs-only and blots+augs pipelines."""
    # NOTE(review): A.JpegCompression / always_apply are deprecated in newer albumentations
    # releases (ImageCompression replaces JpegCompression); confirm the pinned version.
    return [
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.25, always_apply=False),
        A.Rotate(limit=3, interpolation=1, border_mode=0, p=0.5),
        A.JpegCompression(quality_lower=75, p=0.5),
    ]


def get_transforms(config: Dict, use_augs: bool, use_blot: bool) -> Optional[A.Compose]:
    """Build the training-time image transform pipeline.

    Args:
        config: experiment config, forwarded to ``get_blot_transforms``.
        use_augs: include the generic augmentations (CLAHE, small rotation, JPEG compression).
        use_blot: include the blot transforms from ``src.blot``.

    Returns:
        If both flags are set, an ``A.Compose`` of the blot transforms followed by the
        generic augmentations. If only ``use_blot`` is set, ``get_blot_transforms(config)``
        unchanged. If only ``use_augs`` is set, an ``A.Compose`` of the generic
        augmentations. Otherwise ``None``.
    """
    if use_blot and use_augs:
        return A.Compose([get_blot_transforms(config), *_generic_augs()], p=1.0)
    if use_blot:
        return get_blot_transforms(config)
    if use_augs:
        return A.Compose(_generic_augs(), p=1.0)
    return None

def _experiment_tag(use_blot: int, use_augs: int, use_stackmix: int) -> str:
    """Build the Neptune tag naming the enabled augmentations.

    Enabled components are joined with '_' in the fixed order blots, augs, stackmix
    (e.g. 'blots_augs_stackmix', 'augs_stackmix', 'blots'); 'base' if none is enabled.
    """
    components = (('blots', use_blot), ('augs', use_augs), ('stackmix', use_stackmix))
    return '_'.join(name for name, enabled in components if enabled) or 'base'


def main(args: argparse.Namespace) -> None:
    """Train (or resume) an OCR model, then evaluate its saved checkpoints on the test split.

    Builds the config, datasets/loaders (the training split optionally with blot/generic
    augmentations and StackMix), the model, AdamW, OneCycleLR and CTC loss. It then either
    fits a new ``OCRExperiment`` or resumes one from ``args.checkpoint_path``. Afterwards
    the ``best_cer``, ``best_wer``, ``best_acc`` and ``last`` checkpoints are each reloaded
    and decoded on the test split. Per-sample predictions are written to
    ``pred__<checkpoint>.csv``, and CER/WER/accuracy and inference speed are logged
    (also to Neptune when ``args.neptune_project`` is set).
    """
    logging.basicConfig(level=logging.INFO)

    assert args.dataset_name in CONFIGS, f"Invalid dataset_name: {args.dataset_name}"

    if args.checkpoint_path:
        # A resumed run needs a different seed than the original run, so derive one from
        # the clock. time.time() is the real Unix timestamp, whereas the former
        # datetime.utcnow().timestamp() interpreted a naive UTC time as local time
        # (and utcnow() is deprecated since Python 3.12).
        seed = round(time.time()) % 10000
        logging.info(f'RESUME SEED: {seed}')
    else:
        seed = args.seed

    utils.seed_everything(seed)

    config = CONFIGS[args.dataset_name](
        data_dir=args.data_dir,
        experiment_name=args.experiment_name,
        experiment_description=args.experiment_description,
        image_w=args.image_w,
        image_h=args.image_h,
        num_epochs=args.num_epochs,
        bs=args.bs,
        num_workers=args.num_workers,
        seed=seed,
        use_blot=args.use_blot,
        use_augs=args.use_augs,
        use_stackmix=args.use_stackmix,
        use_pretrained_backbone=args.use_pretrained_backbone,
    )

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logging.info(f'DEVICE: {device}')
    logging.info(f'DATASET: {args.dataset_name}')

    ctc_labeling = CTCLabeling(config)

    df = pd.read_csv(args.data_dir / f'{args.dataset_name}/marking.csv', index_col='sample_id')

    transforms = get_transforms(config, args.use_augs, args.use_blot)

    # Augmentations (and optionally StackMix) apply to the training split only.
    train_dataset_kwargs = {'transforms': transforms}
    if args.use_stackmix:
        stackmix = StackMix(
            mwe_tokens_dir=args.mwe_tokens_dir,
            data_dir=args.data_dir,
            dataset_name=args.dataset_name,
            image_h=args.image_h,
        )
        stackmix.load()
        stackmix.load_corpus(ctc_labeling, args.data_dir / f'corpora/{config.corpus_name}')
        train_dataset_kwargs['stackmix'] = stackmix

    df_train = df[~df['stage'].isin(['valid', 'test'])]
    train_dataset = DatasetRetriever(df_train, config, ctc_labeling, **train_dataset_kwargs)
    valid_dataset = DatasetRetriever(df[df['stage'] == 'valid'], config, ctc_labeling)
    test_dataset = DatasetRetriever(df[df['stage'] == 'test'], config, ctc_labeling)

    model = get_ocr_model(config, pretrained=bool(args.use_pretrained_backbone))

    model = model.to(device)
    criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), **config['optimizer']['params'])

    def make_loader(dataset, sampler, drop_last: bool) -> DataLoader:
        # All loaders share batch size, worker count and the keyword-collate function.
        return DataLoader(
            dataset,
            batch_size=config['bs'],
            sampler=sampler,
            pin_memory=False,
            drop_last=drop_last,
            num_workers=config['num_workers'],
            collate_fn=utils.kw_collate_fn,
        )

    # drop_last keeps every training batch full; steps_per_epoch below uses len(train_loader).
    train_loader = make_loader(train_dataset, RandomSampler(train_dataset), drop_last=True)
    valid_loader = make_loader(valid_dataset, SequentialSampler(valid_dataset), drop_last=False)
    test_loader = make_loader(test_dataset, SequentialSampler(test_dataset), drop_last=False)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        epochs=config['num_epochs'],
        steps_per_epoch=len(train_loader),
        **config['scheduler']['params'],
    )

    neptune_kwargs = {}
    if args.neptune_project:
        tags = [args.dataset_name, _experiment_tag(args.use_blot, args.use_augs, args.use_stackmix)]

        neptune.init(
            project_qualified_name=args.neptune_project,
            api_token=args.neptune_token,
        )
        neptune_kwargs = dict(
            neptune=neptune,
            neptune_params={
                'description': config['experiment_description'],
                'params': config.params,
                'tags': tags,
            }
        )

    if not args.checkpoint_path:
        experiment = OCRExperiment(
            experiment_name=config['experiment_name'],
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            scheduler=scheduler,
            device=device,
            base_dir=args.output_dir,
            best_saving={'cer': 'min', 'wer': 'min', 'acc': 'max'},
            last_saving=True,
            low_memory=True,
            verbose_step=10**5,
            seed=seed,
            use_progress_bar=bool(args.use_progress_bar),
            **neptune_kwargs,
            ctc_labeling=ctc_labeling,
        )
        experiment.fit(train_loader, valid_loader, config['num_epochs'])
    else:
        logging.info(f'RESUMED FROM: {args.checkpoint_path}')
        experiment = OCRExperiment.resume(
            checkpoint_path=args.checkpoint_path,
            train_loader=train_loader,
            valid_loader=valid_loader,
            n_epochs=config['num_epochs'],
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            scheduler=scheduler,
            device=device,
            seed=seed,
            neptune=neptune_kwargs.get('neptune'),
            ctc_labeling=ctc_labeling,
        )

    # Evaluate each checkpoint saved by the experiment (names match best_saving/last_saving).
    time_inference = []
    for best_metric in ['best_cer', 'best_wer', 'best_acc', 'last']:
        experiment.load(experiment.experiment_dir / f'{best_metric}.pt')
        experiment.model.eval()
        predictor = Predictor(experiment.model, device)
        # perf_counter() is monotonic and high-resolution, unlike wall-clock time.time().
        start = time.perf_counter()
        predictions = predictor.run_inference(test_loader)
        elapsed = time.perf_counter() - start
        time_inference.append(elapsed)
        df_pred = pd.DataFrame([{
            'id': prediction['id'],
            'pred_text': ctc_labeling.decode(prediction['raw_output'].argmax(1)),
            'gt_text': prediction['gt_text']
        } for prediction in predictions]).set_index('id')

        cer_metric = round(cer(df_pred['pred_text'], df_pred['gt_text']), 5)
        wer_metric = round(wer(df_pred['pred_text'], df_pred['gt_text']), 5)
        acc_metric = round(string_accuracy(df_pred['pred_text'], df_pred['gt_text']), 5)

        if args.neptune_project:
            experiment.neptune.log_metric(f'cer_test__{best_metric}', cer_metric)
            experiment.neptune.log_metric(f'wer_test__{best_metric}', wer_metric)
            experiment.neptune.log_metric(f'acc_test__{best_metric}', acc_metric)

        mistakes = df_pred[df_pred['pred_text'] != df_pred['gt_text']]
        df_pred.to_csv(experiment.experiment_dir / f'pred__{best_metric}.csv')

        if args.neptune_project:
            experiment.neptune.log_metric(f'mistakes__{best_metric}', mistakes.shape[0])
            experiment.neptune.log_artifact(str(experiment.experiment_dir / f'pred__{best_metric}.csv'))

        experiment._log(  # noqa
            f'Results for {best_metric}.pt.',
            cer=cer_metric,
            wer=wer_metric,
            acc=acc_metric,
            speed_inference=len(test_dataset) / elapsed,
        )

    if args.neptune_project:
        experiment.neptune.log_metric('time_inference', np.mean(time_inference))
        experiment.neptune.log_metric('speed_inference', len(test_dataset) / np.mean(time_inference))  # sample / sec
        experiment.neptune.log_metric('train_per_epoch_iterations', len(train_loader))

    experiment.destroy()

def _build_train_parser() -> argparse.ArgumentParser:
    """Create the CLI parser for the training script; options are registered in help order."""
    options = (
        ('--checkpoint_path', dict(type=Path, default=None, help='Path to the checkpoint file')),
        ('--experiment_name', dict(type=str, required=True, help='Name of the experiment')),
        ('--use_augs', dict(type=int, required=True, help='Whether to use augmentations')),
        ('--use_blot', dict(type=int, required=True, help='Whether to use blot augmentations')),
        ('--use_stackmix', dict(type=int, required=True, help='Whether to use StackMix augmentations')),
        ('--neptune_project', dict(type=str, default=None, help='Neptune project name')),
        ('--neptune_token', dict(type=str, default=None, help='Neptune token')),
        ('--data_dir', dict(type=Path, required=True, help='Path to the data directory')),
        ('--mwe_tokens_dir', dict(type=Path, required=True, help='Path to the MWE tokens directory')),
        ('--output_dir', dict(type=Path, required=True, help='Path to the output directory')),
        ('--experiment_description', dict(type=str, required=True, help='Description of the experiment')),
        ('--dataset_name', dict(type=str, required=True, help='Name of the dataset')),
        ('--image_w', dict(type=int, required=True, help='Image width')),
        ('--image_h', dict(type=int, required=True, help='Image height')),
        ('--num_epochs', dict(type=int, required=True, help='Number of epochs')),
        ('--bs', dict(type=int, required=True, help='Batch size')),
        ('--num_workers', dict(type=int, required=True, help='Number of workers for data loading')),
        ('--seed', dict(type=int, default=6955, help='Random seed')),
        ('--use_progress_bar', dict(type=int, default=0, help='Whether to use progress bar')),
        ('--use_pretrained_backbone', dict(type=int, default=1, help='Whether to use pretrained backbone')),
    )
    cli = argparse.ArgumentParser(description='Run train script.')
    for flag, spec in options:
        cli.add_argument(flag, **spec)
    return cli


if __name__ == '__main__':
    # Parse the command line and run training followed by test-set evaluation.
    parser = _build_train_parser()
    args = parser.parse_args()
    main(args)

In [None]:
from PIL import Image, ImageDraw, ImageFont
import textwrap
import os

# Render a short sample text with a handwriting-style TrueType font onto a white
# canvas and save it as "handwritten_text.png".

# Text to be rendered. textwrap.wrap() below re-flows it, so the embedded newline
# is treated as ordinary whitespace rather than as a forced line break.
txt = """Python is an interpreted high-level general-purpose programming
Its design philosophy emphasizes code readability with its use of signi"""

# Canvas and layout settings (distances in pixels unless noted otherwise).
width, height = 800, 600
margin = 10                          # x of every line and y of the first line
line_gap = 10                        # extra vertical space added after each line
max_line_width = 50                  # wrap width in *characters*, not pixels
font_path = 'wg_goodbye_font.ttf'    # custom font file, not one of PIL's built-in fonts
font_size = 24
text_color = (0, 0, 138)             # RGB dark blue

# Create a new image with a white background and a drawing context on it.
image = Image.new('RGB', (width, height), color=(255, 255, 255))
draw = ImageDraw.Draw(image)

# Load the TrueType font from disk; ImageFont.truetype raises OSError if the file
# cannot be found.
font = ImageFont.truetype(font_path, font_size)

# Split the text into lines of at most max_line_width characters.
lines = textwrap.wrap(txt, max_line_width)

# Draw the lines top to bottom. Wrapping is by character count, so a wide font can
# push text past the canvas edge; report that instead of clipping it silently.
y = margin
for line in lines:
    if margin + draw.textlength(line, font=font) > width or y + font.size > height:
        print(f'WARNING: text clipped by the {width}x{height} canvas: {line!r}')
    draw.text((margin, y), line, font=font, fill=text_color)
    y += font.size + line_gap

# Save the image
image.save("handwritten_text.png")
print("END")