In [1]:
# Re-import edited project modules (e.g. dataset, modeling, tokenization) automatically
# before each cell executes, so source changes are picked up without restarting the kernel.
# `%reload_ext` (rather than `%load_ext`) keeps this cell safe to re-run.
%reload_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt

# Default to wide, short figures (16 x 4 inches) for every plot in this notebook.
plt.rc('figure', figsize=[16, 4])

# Configuration

In [3]:
# training config
experiment_name = 'swin-autoregressive-ocr'

batch_size = 2
num_workers = 2

# Per-epoch caps: these are handed to the Trainer as `limit_train_batches` /
# `limit_val_batches`, so they bound the batches run in ONE epoch, not the total
# number of optimizer steps over the whole run.
max_train_steps = 1000
max_val_steps = 100

# reproducibility: seed python, numpy and torch (and dataloader workers) before the
# dataset and model are built, so weight init and sampling repeat across runs
seed = 42
import pytorch_lightning as pl
pl.seed_everything(seed, workers=True)

# environment config
import torch
use_cuda = torch.cuda.is_available()

# tokenizer config
from dataset.wikipedia_dataset import FRENCH_CHARACTERS
characters = FRENCH_CHARACTERS
model_max_length = 96
bos_token_id = 0
eos_token_id = 1
pad_token_id = 2
unk_token_id = 3
special_token_ids = (bos_token_id, eos_token_id, pad_token_id, unk_token_id)
assert len(set(special_token_ids)) == len(special_token_ids), "special token ids must be distinct"

# input image config
height, width = 128, 640

channels = 1
pixel_mean = (0.5,) # for one channel
pixel_std = (0.5,) # for one channel

# encoder patch embedding / attention window config
patch_size = 4
window_size = 8
# encoder stage config: `depths` and `num_heads` are paired, one entry per Swin stage
embed_dim = 96
depths = [2, 6, 2]
num_heads = [6, 12, 24]
assert len(depths) == len(num_heads), "depths and num_heads need one entry per stage"

# decoder architecture config
# NOTE(review): 384 == embed_dim * 2 ** (len(depths) - 1), presumably the encoder's
# output width after patch merging — confirm the decoder expects that exact size.
dim = 384
heads = 8
dropout = 0.1
activation = 'gelu'
norm_first = False
# decoder stack config
num_layers = 4
# language modeling config
# vocabulary = one id per character plus one per special token (bos/eos/pad/unk);
# assumes CharacterTokenizer uses exactly this layout — confirm against the tokenizer
num_tokens = len(characters) + len(special_token_ids)
max_seq_len = model_max_length

# optimizer config
from timm.optim import AdamW
optimizer_config = dict(
    base_class=AdamW,
    params=dict(
        lr=1e-4,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
    ),
)

# scheduler config
from timm.scheduler.cosine_lr import CosineLRScheduler
scheduler_config = dict(
    base_class=CosineLRScheduler,
    params=dict(
        t_initial=200,
        lr_min=1e-6,
        cycle_mul=3,
        cycle_decay=0.8,
        cycle_limit=20,
        warmup_t=20,
        k_decay=1.5,
    ),
)

Missing modules for handwritten text generation.


  from .autonotebook import tqdm as notebook_tqdm


# Tokenizer

In [4]:
from tokenization.character_tokenizer import CharacterTokenizer

# Character-level tokenizer built from the vocabulary, special-token ids and
# maximum length declared in the configuration cell.
tokenizer_kwargs = dict(
    characters=characters,
    model_max_length=model_max_length,
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
    unk_token_id=unk_token_id,
)
character_tokenizer = CharacterTokenizer(**tokenizer_kwargs)

# Transform

In [5]:
from torchvision.transforms import Compose, Resize, Grayscale, ToTensor, Normalize

# Per-image preprocessing: resize to the encoder's fixed (height, width) input size,
# convert to grayscale, scale to a [0, 1] float tensor, then normalize per channel.
# The grayscale output width is tied to `channels` so the tensor always matches the
# encoder's `in_chans` (torchvision's Grayscale accepts only 1 or 3 and raises otherwise).
assert len(pixel_mean) == len(pixel_std) == channels, "need one mean/std value per input channel"
simple_transform = Compose([
    Resize((height, width)),
    Grayscale(num_output_channels=channels),
    ToTensor(),
    Normalize(pixel_mean, pixel_std),
])

# Dataset

In [6]:
from dataset.wikipedia_dataset import WikipediaTextLineDataModule

# Text-line data module over the '20220301.fr' Wikipedia snapshot; images go through
# `simple_transform`, and the character tokenizer is handed over for the text side.
wikipedia_config_name = '20220301.fr'
data_module = WikipediaTextLineDataModule(
    name=wikipedia_config_name,
    characters=characters,
    tokenizer=character_tokenizer,
    transform=simple_transform,
    batch_size=batch_size,
    num_workers=num_workers,
)
# setup() is called explicitly because the train/val dataloaders are later taken
# straight from the module instead of passing it to the Trainer as a datamodule.
data_module.setup()

Found cached dataset wikipedia (/home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Found cached dataset wikipedia (/home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Loading cached processed dataset at /home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559/cache-3e3b644afaa1b40a_*_of_00004.arrow
Loading cached processed dataset at /home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559/cache-1499956ad50080bc_*_of_00004.arrow


# Model

In [12]:
from modeling.encoder import SwinTransformerEncoder
from modeling.decoder import AutoregressiveTransformerDecoder
from modeling.encoder_decoder import VisionEncoderLanguageDecoder
from modeling.lightning_wrapper import VisionEncoderLanguageDecoderWrapper

# Swin vision encoder over (channels, height, width) images.
encoder_kwargs = dict(
    img_size=(height, width),
    in_chans=channels,
    patch_size=patch_size,
    window_size=window_size,
    embed_dim=embed_dim,
    depths=depths,
    num_heads=num_heads,
)
vision_encoder = SwinTransformerEncoder(**encoder_kwargs)

# Autoregressive transformer decoder: layer shape, stack depth, vocabulary/sequence
# limits, and the special-token ids it needs for generation and padding.
decoder_kwargs = dict(
    # per-layer config
    dim=dim,
    heads=heads,
    dropout=dropout,
    activation=activation,
    norm_first=norm_first,
    # stack config
    num_layers=num_layers,
    # language-modeling config
    num_tokens=num_tokens,
    max_seq_len=max_seq_len,
    # special tokens
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
)
language_decoder = AutoregressiveTransformerDecoder(**decoder_kwargs)

# Glue the two halves into a single encoder-decoder network...
vision_encoder_language_decoder = VisionEncoderLanguageDecoder(
    vision_encoder=vision_encoder,
    language_decoder=language_decoder,
)

# ...and wrap it for Lightning together with the tokenizer and optimizer/scheduler specs.
lightning_model = VisionEncoderLanguageDecoderWrapper(
    model=vision_encoder_language_decoder,
    tokenizer=character_tokenizer,
    optimizer_config=optimizer_config,
    scheduler_config=scheduler_config,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


# Callbacks

In [13]:
import pytorch_lightning as pl

# Redraw the tqdm progress bar every 10 batches rather than on every batch.
prog_bar = pl.callbacks.progress.TQDMProgressBar(refresh_rate=10)

# TensorBoard event files are written under logs/<experiment_name>/.
logger = pl.loggers.TensorBoardLogger(save_dir=f"logs/{experiment_name}/")

# Keep the three checkpoints with the lowest validation CER, plus the most recent one.
checkpoint_dir = f"checkpoints/{experiment_name}/"
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="checkpoint-{epoch:03d}-{val_cer:.5f}-{val_acc:.5f}",
    monitor="val_cer",
    mode="min",
    save_top_k=3,
    save_last=True,
)

# Record the learning rate at every optimizer step instead of once per epoch.
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")

# Training

In [14]:
# Epoch budget stated explicitly. When neither `max_epochs` nor `max_steps` is given,
# Lightning silently falls back to 1000 epochs (emitting a warning in recent versions);
# setting it here keeps that budget but makes it visible and easy to change.
max_epochs = 1000

trainer = pl.Trainer(
    accelerator='gpu' if use_cuda else 'cpu',
    max_epochs=max_epochs,

    # log metrics on every step; run a single validation batch as a sanity check first
    log_every_n_steps=1,
    num_sanity_val_steps=1,

    # integer limits cap the number of batches PER EPOCH (not the global step count)
    limit_val_batches=max_val_steps,
    limit_train_batches=max_train_steps,

    callbacks=[ckpt_callback, lr_monitor, prog_bar],
    enable_progress_bar=True,
    logger=logger,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [15]:
# Pull the loaders from the already set-up data module, then train; validation runs
# on `val_loader` during fitting.
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

trainer.fit(
    lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

  rank_zero_warn(

  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | VisionEncoderLanguageDecoder | 18.4 M
-------------------------------------------------------
18.4 M    Trainable params
0         Non-trainable params
18.4 M    Total params
73.654    Total estimated model params size (MB)


Epoch 0:   0%|          | 0/1000 [00:00<?, ?it/s]                          

In [None]:
%tensorboard --logdir logs