In [1]:
import sys

# Put the parent directory on the import path (presumably where the local
# `lightning_ocr` package lives -- confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicates.
if "../" not in sys.path:
    sys.path.append("../")

In [10]:
import os
import copy
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning_ocr.models import ABINetVision
from lightning_ocr.datasets import RecogTextDataset, RecogTextDataModule
from sklearn.model_selection import train_test_split
import albumentations as A
import lightning as L

In [7]:
# ENV
# Export the tokenizer-parallelism switch into this process's environment.
os.environ.update({"TOKENIZERS_PARALLELISM": "true"})

# Samples per batch; reused for the data module and the logging interval.
batch_size = 8

# Settings passed to the model constructor: maximum sequence length and the
# character set (digits plus the decimal point) used by the tokenizer.
config = {
    "max_seq_len": 12,
    "tokenizer": {"dict_list": [*"0123456789."]},
}

In [4]:
# MODEL

# Instantiate the ABINet vision model from the `config` dict built in the ENV
# cell (maximum sequence length and tokenizer character list).
model = ABINetVision(config)

In [5]:
# DATASETS

# Load every annotated sample once, wired to the model's training pipeline.
train_dataset = RecogTextDataset(
    pipeline=model.load_train_pipeline(),
    ann_file="ann_file.json",
    data_root="./datasets/MNIST/",
)

# Hold out 20% of the samples for evaluation; the fixed seed keeps the split
# reproducible across runs.
TRAIN, TEST = train_test_split(
    train_dataset.data_list,
    random_state=42,
    test_size=0.2,
)

# The evaluation dataset is a deep copy of the training one, re-pointed at the
# held-out samples and given the model's test-time transforms.
test_dataset = copy.deepcopy(train_dataset)
test_dataset.data_list = TEST
test_dataset.transform = A.Compose(model.load_test_pipeline())

# Only now shrink the training dataset to its own share of the split.
train_dataset.data_list = TRAIN

In [None]:
# CONFIG TRAINER

# Number of full batches per epoch, floored at 1 so the interval below is
# always a valid positive integer.
steps_per_epoch = max(1, len(train_dataset) // batch_size)

# Log every 50 steps on normal-sized runs and every 5 steps on short epochs.
# On very small datasets (fewer than 5 batches) fall back to the batch count,
# otherwise the interval would exceed the epoch length and no step-level logs
# would be written.
log_every_n_steps = 50 if steps_per_epoch >= 50 else min(5, steps_per_epoch)

# Checkpoint once per epoch, tracking the "loss/total_epoch" metric and storing
# weights only.
checkpoint_callback = ModelCheckpoint(
    dirpath="./checkpoints/abinet",
    filename="model-{epoch:02d}-loss-{loss/total_epoch:.2f}",
    monitor="loss/total_epoch",
    save_weights_only=True,
    # The metric name contains "/", so automatic name insertion is disabled to
    # keep the formatted filename from turning into a nested path.
    auto_insert_metric_name=False,
    every_n_epochs=1,
)

# TensorBoard event files are written under logs/abinet/.
tb_logger = TensorBoardLogger(save_dir="logs/abinet/")

# Mixed-precision (fp16) training for 20 epochs with the logger and checkpoint
# callback configured above.
trainer = L.Trainer(
    precision="16-mixed",
    logger=tb_logger,
    log_every_n_steps=log_every_n_steps,
    callbacks=[checkpoint_callback],
    max_epochs=20,
)

In [9]:
# DUMP MODEL CONFIG

# The checkpoint callback only creates its directory once training starts, so
# make sure it exists before writing into it (no-op when already present).
# NOTE(review): `dump_config` may create the directory itself -- not visible
# from this notebook; the explicit call is harmless either way.
os.makedirs(checkpoint_callback.dirpath, exist_ok=True)

# Store the model configuration next to the checkpoints.
model.dump_config(checkpoint_callback.dirpath)

In [None]:
# Bundle the train and held-out datasets into a single data module, then hand
# it to the trainer to run the fit loop.
data_module = RecogTextDataModule(
    batch_size=batch_size,
    train_datasets=[train_dataset],
    eval_datasets=[test_dataset],
)
trainer.fit(model, datamodule=data_module)

In [None]:
# List the checkpoint directory (sizes, hidden files) via an IPython shell escape.
!ls -lah "./checkpoints/abinet"