In [1]:
# Load (or reload, if already loaded) IPython's autoreload extension.
%reload_ext autoreload
# Mode 2: reload imported modules before each cell is executed, so edits to
# the project's .py files are picked up without restarting the kernel.
%autoreload 2

In [2]:
import matplotlib.pyplot as plt

# Default every figure in this notebook to a wide 16x4 canvas.
plt.rcParams.update({'figure.figsize': [16, 4]})

# Configuration

In [3]:
import torch

# --- training ---
# Run name; the logger and checkpoint callback build their directories from it.
experiment_name = 'swinformer-ocr'

# Data-loading knobs shared by the train and validation loaders.
batch_size, num_workers = 2, 4

# Batch budgets handed to the Trainer (limit_train_batches / limit_val_batches);
# validation gets 1% of the training budget.
max_train_steps = 20_000
max_val_steps = max_train_steps // 100

# --- environment ---
# True when a CUDA device is visible to torch.
use_cuda = torch.cuda.is_available()

In [5]:
# --- tokenizer ---
# Maximum sequence length, followed by the ids reserved for the special tokens.
model_max_length = 96
bos_token_id, eos_token_id, pad_token_id, unk_token_id = 0, 1, 2, 3

# --- input image ---
# Every image is resized to this (height, width) before entering the model.
height = 128
width = 640

# Single-channel input, so the normalisation statistics are 1-tuples.
channels = 1
pixel_mean = (0.5,)
pixel_std = (0.5,)

# --- encoder architecture ---
patch_size, window_size = 4, 8

embed_dim = 96
# NOTE(review): presumably one entry per encoder stage -- confirm against
# SwinTransformerEncoder.
depths = [2, 6, 2]
num_heads = [6, 12, 24]

# --- decoder architecture ---
# Passed as-is to AutoregressiveDecoder through its `decoder_config` argument.
decoder_config = {
    "dim": 384,
    "depth": 4,
    "heads": 8,
    "cross_attend": True,
    "ff_glu": False,
    "attn_on_attn": False,
    "use_scalenorm": False,
    "rel_pos_bias": False,
}

# --- auto-regressive wrapper ---
# NOTE(review): neither `num_tokens` nor `max_seq_len` is referenced again in
# the visible cells; the decoder is built from `character_tokenizer.vocab_size`
# and `character_tokenizer.model_max_length` instead. Confirm whether these two
# names are still needed.
from dataset.text_cleaning_utils import ALL_CHARACTERS

max_seq_len = model_max_length
num_tokens = len(ALL_CHARACTERS)

# --- optimizer ---
# Optimizer class plus its keyword arguments.
# NOTE(review): presumably LightningBase instantiates base_class(**params) on
# the model parameters -- confirm in modeling.lightning_base.
from timm.optim import AdamW

optimizer_config = {
    "base_class": AdamW,
    "params": {
        "lr": 1e-4,
        "betas": (0.9, 0.999),
        "eps": 1e-8,
        "weight_decay": 1e-2,
    },
}

# --- scheduler ---
# Learning-rate scheduler class plus its keyword arguments.
# NOTE(review): whether t_initial / warmup_t count optimizer steps or epochs
# depends on how LightningBase steps the scheduler -- confirm there.
from timm.scheduler.cosine_lr import CosineLRScheduler

scheduler_config = {
    "base_class": CosineLRScheduler,
    "params": {
        "t_initial": 200,
        "lr_min": 1e-6,
        "cycle_mul": 3,
        "cycle_decay": 0.8,
        "cycle_limit": 20,
        "warmup_t": 20,
        "k_decay": 1.5,
    },
}

# Tokenization

In [6]:
from tokenization.character_tokenizer import CharacterTokenizer

# Character-level tokenizer over ALL_CHARACTERS; the special-token ids and the
# maximum length come from the configuration cell.
character_tokenizer = CharacterTokenizer(
    characters=ALL_CHARACTERS,
    model_max_length=model_max_length,
    bos_token_id=bos_token_id, eos_token_id=eos_token_id,
    pad_token_id=pad_token_id, unk_token_id=unk_token_id,
)

# Transforms

In [7]:
from torchvision.transforms import Compose, Resize, Grayscale, ToTensor, Normalize

# Image preprocessing applied to every sample, in order.
transform = Compose([
    # Bring every image to the fixed input size expected by the encoder.
    Resize(size=(height, width)),
    # Follow the configured channel count instead of hard-coding one channel,
    # so the tensor always matches the encoder's `in_chans=channels`.
    # torchvision accepts 1 or 3 here and raises on anything else.
    Grayscale(num_output_channels=channels),
    # Convert the image to a tensor.
    ToTensor(),
    # Normalise with the configured per-channel mean / std.
    Normalize(mean=pixel_mean, std=pixel_std),
])

# Dataset

In [8]:
import datasets as ds
from dataset.textline_dataset import TextLineDataset
from dataset.text_cleaning_utils import preprocess_wikipedia_dataset

# Source corpus, addressed by (dataset name, configuration name).
wikipedia_source = ("wikipedia", "20220301.fr")

# Carve the single "train" split into its first 90% (training) and its last
# 10% (validation).
# NOTE(review): the two percent boundaries are resolved separately; "train[90%:]"
# would be the explicit complement of "train[:90%]". Left unchanged because a
# different split string would likely miss the existing dataset cache -- confirm.
train_articles_raw = ds.load_dataset(*wikipedia_source, split="train[:90%]")
val_articles_raw = ds.load_dataset(*wikipedia_source, split="train[-10%:]")

# Project-specific cleaning step (dataset.text_cleaning_utils). Each stage
# keeps its own name so `train_dataset` / `val_dataset` are only ever bound to
# the final TextLineDataset objects.
train_articles = preprocess_wikipedia_dataset(train_articles_raw)
val_articles = preprocess_wikipedia_dataset(val_articles_raw)

# Wrap the cleaned text with the tokenizer and image transform; both datasets
# are built with the same keyword arguments.
train_dataset = TextLineDataset(
    dataset=train_articles,
    tokenizer=character_tokenizer,
    transform=transform,
)

val_dataset = TextLineDataset(
    dataset=val_articles,
    tokenizer=character_tokenizer,
    transform=transform,
)

Missing modules for handwritten text generation.


Found cached dataset wikipedia (/home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Found cached dataset wikipedia (/home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Loading cached processed dataset at /home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559/cache-421afcd459ec1f9d_*_of_00004.arrow
Loading cached processed dataset at /home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559/cache-c181330655572a4c_*_of_00004.arrow


# Model

In [9]:
from modeling.lightning_base import LightningBase
from modeling.encoder import SwinTransformerEncoder
from modeling.decoder import AutoregressiveDecoder
from modeling.vision_encoder_decoder import VisionEncoderDecoder

# 1) Vision encoder, configured from the encoder architecture settings.
encoder = SwinTransformerEncoder(
    img_size=(height, width), in_chans=channels,
    patch_size=patch_size, window_size=window_size,
    embed_dim=embed_dim, depths=depths, num_heads=num_heads,
)

# 2) Text decoder; vocabulary size, maximum length and special-token ids are
#    all read from the tokenizer so the two cannot drift apart.
decoder = AutoregressiveDecoder(
    decoder_config=decoder_config,
    num_tokens=character_tokenizer.vocab_size,
    max_seq_len=character_tokenizer.model_max_length,
    bos_token_id=character_tokenizer.bos_token_id,
    eos_token_id=character_tokenizer.eos_token_id,
    pad_token_id=character_tokenizer.pad_token_id,
)

# 3) Join encoder and decoder into one image-to-text model.
vision_encoder_decoder = VisionEncoderDecoder(encoder=encoder, decoder=decoder)

# 4) Lightning wrapper that receives the model, tokenizer, optimizer settings
#    and scheduler settings.
lightning_model = LightningBase(
    model=vision_encoder_decoder,
    tokenizer=character_tokenizer,
    optimizer_config=optimizer_config,
    scheduler_config=scheduler_config,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


# Dataloaders

In [11]:
from torch.utils.data import DataLoader

# Settings common to both loaders. Memory is pinned only when CUDA is
# available, for faster cpu to gpu transfer.
shared_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=use_cuda,
)

# Training batches are shuffled; validation batches keep dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

# Callbacks

In [12]:
import pytorch_lightning as pl

# Progress bar configured with refresh_rate=10.
prog_bar = pl.callbacks.progress.TQDMProgressBar(refresh_rate=10)

# TensorBoard event files go under logs/<experiment_name>/.
logger = pl.loggers.TensorBoardLogger(save_dir=f"logs/{experiment_name}/")

# Keep the three checkpoints with the lowest `val_cer`, plus the most recent one.
ckpt_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_cer",
    mode="min",
    save_top_k=3,
    save_last=True,
    dirpath=f"checkpoints/{experiment_name}/",
    filename="checkpoint-{epoch:03d}-{val_cer:.5f}",
)

# Record the learning rate at every step.
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")


# Training

In [19]:
# Build the Lightning trainer from the configuration, callbacks and logger
# defined in the earlier cells.
trainer = pl.Trainer(
    # Always name the accelerator explicitly: `accelerator=None` is rejected
    # by newer pytorch_lightning releases, whereas "cpu" is accepted by both
    # old and new ones.
    accelerator="gpu" if use_cuda else "cpu",
    benchmark=True,

    # Log on every training step.
    log_every_n_steps=1,

    # Both limits are integers, so they are batch counts, not fractions.
    limit_val_batches=max_val_steps,
    limit_train_batches=max_train_steps,

    callbacks=[ckpt_callback, lr_monitor, prog_bar],
    enable_progress_bar=True,
    logger=logger,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Launch training; the second loader is used for validation.
trainer.fit(lightning_model, train_dataloader, val_dataloader)

In [None]:
%tensorboard --logdir logs