In [1]:
# Re-import edited project modules (dataset, tokenization, modeling) before each cell runs,
# so changes to their .py files are picked up without restarting the kernel.
# `%reload_ext` (rather than `%load_ext`) keeps this cell quiet when it is re-run.
%reload_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt

# Default every figure to a wide, short 16 x 4 inch canvas
# (presumably chosen to suit the 128 x 640 text-line images configured below).
plt.rcParams.update({"figure.figsize": [16, 4]})

# Configuration

In [3]:
# classes and vocabulary referenced by the configuration below
from dataset.text_cleaning_utils import ALL_CHARACTERS
from timm.optim import AdamW
from timm.scheduler.cosine_lr import CosineLRScheduler

# ---- tokenizer ----------------------------------------------------------------
# maximum sequence length given to the tokenizer; reused as the decoder's max_seq_len below
model_max_length = 96
# reserved special-token ids: all four go to the tokenizer, bos/eos/pad also go to the decoder
bos_token_id = 0
eos_token_id = 1
pad_token_id = 2
unk_token_id = 3

# ---- encoder (Swin Transformer) -----------------------------------------------
# input image geometry
height = 128
width = 640
channels = 3

patch_size = 4
window_size = 8

embed_dim = 96
# presumably one entry per Swin stage, so both lists should have the same length
# NOTE(review): confirm against SwinTransformerEncoder
depths = [2, 6, 2]
num_heads = [6, 12, 24]

# ---- decoder --------------------------------------------------------------------
# NOTE(review): dim=384 equals embed_dim * 2 ** (len(depths) - 1), the last-stage width of a
# standard Swin encoder; confirm whether VisionEncoderDecoder requires the two to match.
decoder_config = {
    "dim": 384,
    "depth": 4,
    "heads": 8,
    "cross_attend": True,
    "ff_glu": False,
    "attn_on_attn": False,
    "use_scalenorm": False,
    "rel_pos_bias": False,
}

# ---- autoregressive wrapper -------------------------------------------------------
# NOTE(review): the vocabulary size counts only the raw character set. Confirm that every id
# CharacterTokenizer can emit (it also reserves the special ids 0-3 above) is < num_tokens.
num_tokens = len(ALL_CHARACTERS)
max_seq_len = model_max_length

# ---- optimizer ------------------------------------------------------------------
optimizer_config = {
    "base_class": AdamW,
    "params": {
        "lr": 1e-4,
        "betas": (0.9, 0.999),
        "eps": 1e-8,
        "weight_decay": 1e-2,
    },
}

# ---- learning-rate scheduler ---------------------------------------------------
scheduler_config = {
    "base_class": CosineLRScheduler,
    "params": {
        "t_initial": 500,
        "lr_min": 1e-6,
        "cycle_mul": 3,
        "cycle_decay": 0.8,
        "cycle_limit": 20,
        "warmup_t": 20,
        "k_decay": 1.5,
    },
}

# Tokenization

In [4]:
from tokenization.character_tokenizer import CharacterTokenizer

# Tokenizer over the ALL_CHARACTERS vocabulary. The length limit and the special-token ids come
# from the configuration cell, which also hands max_seq_len and bos/eos/pad ids to the decoder,
# so both sides use the same values.
tokenizer = CharacterTokenizer(
    characters=ALL_CHARACTERS,
    model_max_length=model_max_length,
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
    unk_token_id=unk_token_id,
)

# Dataset

In [5]:
import datasets as ds
from dataset.textline_dataset import TextLineDataset
from dataset.text_cleaning_utils import preprocess_wikipedia_dataset

# Source corpus: HF "wikipedia" builder, config "20220301.fr" (presumably the 2022-03-01 French dump).
# Train and validation are both carved from its "train" split with HF percent slicing:
# the first (100 - val_percent)% for training, the last val_percent% for validation.
wikipedia_path, wikipedia_config = "wikipedia", "20220301.fr"
val_percent = 10

# Each stage gets its own name (raw -> clean -> TextLineDataset) instead of rebinding
# `train_dataset` / `val_dataset` to objects of different kinds, so the lineage stays visible.
train_raw = ds.load_dataset(wikipedia_path, wikipedia_config, split=f"train[:{100 - val_percent}%]")
val_raw = ds.load_dataset(wikipedia_path, wikipedia_config, split=f"train[-{val_percent}%:]")

train_clean = preprocess_wikipedia_dataset(train_raw)
val_clean = preprocess_wikipedia_dataset(val_raw)

# Final datasets consumed by training; names unchanged for downstream cells.
train_dataset = TextLineDataset(tokenizer, train_clean)
val_dataset = TextLineDataset(tokenizer, val_clean)

Missing modules for handwritten text generation.


Found cached dataset wikipedia (/home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Found cached dataset wikipedia (/home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
Loading cached processed dataset at /home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559/cache-136ef6fc905cefd4_*_of_00004.arrow
Loading cached processed dataset at /home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559/cache-f20708804beaf2ed_*_of_00004.arrow
Loading cached processed dataset at /home/ilyas/.cache/huggingface/datasets/wikipedia/20220301.fr/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559/cache-d91aeb68258ac9ef_*_of_00004.arrow
Loading cached processed dataset at /home

In [6]:
from modeling.encoder import SwinTransformerEncoder
from modeling.decoder import AutoregressiveDecoder
from modeling.vision_encoder_decoder import VisionEncoderDecoder
from modeling.lightning_base import LightningBase

# Image encoder: Swin Transformer over `channels`-channel images of size (height, width),
# configured entirely from the configuration cell.
# NOTE(review): the `_VF.meshgrid` warning in this cell's output is emitted by library code, not by
# this notebook (presumably while Swin builds its relative-position indices) — confirm.
encoder = SwinTransformerEncoder(
    img_size=(height, width),
    in_chans=channels,
    patch_size=patch_size,
    window_size=window_size,
    embed_dim=embed_dim,
    depths=depths,
    num_heads=num_heads,
)

# Text decoder: autoregressive over `num_tokens` ids with a `max_seq_len` length limit,
# reusing the bos/eos/pad ids that were also given to the tokenizer.
decoder = AutoregressiveDecoder(
    decoder_config=decoder_config,
    num_tokens=num_tokens,
    max_seq_len=max_seq_len,
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
)

# Join encoder and decoder, then wrap the result together with the optimizer and
# LR-scheduler configs (LightningBase is presumably a PyTorch Lightning module).
model = VisionEncoderDecoder(encoder=encoder, decoder=decoder)
lightning_model = LightningBase(
    model=model,
    optimizer_config=optimizer_config,
    scheduler_config=scheduler_config,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
