In [1]:
import sys

# Put the parent directory on the import path.
# NOTE(review): presumably this is what lets the local `lightning_ocr` package
# resolve from this notebook's folder -- confirm against the repo layout.
# The membership check keeps the cell idempotent: re-running it must not keep
# stacking duplicate "../" entries onto sys.path.
if "../" not in sys.path:
    sys.path.append("../")

In [10]:
import os
import copy
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning_ocr.models import ABINetVision
from lightning_ocr.datasets import RecogTextDataset, RecogTextDataModule
from sklearn.model_selection import train_test_split
import albumentations as A
import lightning as L

In [7]:
# ENV
# Export TOKENIZERS_PARALLELISM="true" into this process' environment.
os.environ.update({"TOKENIZERS_PARALLELISM": "true"})

# Batch size shared by the logging-interval heuristic and the datamodule.
batch_size = 8

# Model configuration: maximum sequence length and the tokenizer's character
# dictionary (the ten digits plus ".").
config = {
    "max_seq_len": 12,
    "tokenizer": {"dict_list": [*"0123456789."]},
}

In [4]:
# MODEL

# Seed the global RNGs before the model is instantiated so that runs are
# repeatable after a kernel restart. 42 matches the `random_state` used for the
# train/test split; `workers=True` asks Lightning to derive dataloader-worker
# seeds from it as well.
L.seed_everything(42, workers=True)

# Build the recognizer from the notebook-level `config` dict.
model = ABINetVision(config)

In [5]:
# DATASETS

# Build the dataset from the annotation file, using the model's training pipeline.
train_pipeline = model.load_train_pipeline()
train_dataset = RecogTextDataset(
    data_root="./datasets/MNIST/",
    ann_file="ann_file.json",
    pipeline=train_pipeline,
)

# Fixed-seed 80/20 split of the dataset's sample list.
TRAIN, TEST = train_test_split(
    train_dataset.data_list,
    test_size=0.2,
    random_state=42,
)

# Clone the dataset object for evaluation, then give the clone the held-out
# samples and a transform composed from the model's test pipeline.
test_dataset = copy.deepcopy(train_dataset)
test_dataset.transform = A.Compose(model.load_test_pipeline())
test_dataset.data_list = TEST

# Finally restrict the original dataset to the training portion.
train_dataset.data_list = TRAIN

In [8]:
# CONFIG TRAINER

# Log every 50 steps, or every 5 steps when one epoch holds fewer than
# 50 full batches.
steps_per_epoch = len(train_dataset) // batch_size
log_every_n_steps = 5 if steps_per_epoch < 50 else 50

# Weights-only checkpoints written to ./checkpoints/abinet, monitoring the
# "loss/total_epoch" metric. The filename template spells out its own labels,
# hence auto_insert_metric_name=False.
checkpoint_dir = "./checkpoints/abinet"
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    monitor="loss/total_epoch",
    filename="model-{epoch:02d}-loss-{loss/total_epoch:.2f}",
    auto_insert_metric_name=False,
    save_weights_only=True,
    every_n_epochs=1,
)

# TensorBoard event files go under logs/abinet/.
tb_logger = TensorBoardLogger(save_dir="logs/abinet/")

# Mixed-precision trainer capped at 20 epochs.
trainer = L.Trainer(
    max_epochs=20,
    precision="16-mixed",
    logger=tb_logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=log_every_n_steps,
)

Using 16bit Automatic Mixed Precision (AMP)
/home/mixaill76/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py:54: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# DUMP MODEL CONFIG

# This cell runs before any training, so the checkpoint directory may not
# exist yet on a fresh checkout.
# NOTE(review): whether `dump_config` creates missing directories itself is not
# visible from this notebook; creating it up front is harmless either way.
os.makedirs(checkpoint_callback.dirpath, exist_ok=True)

# Write the model configuration into the checkpoint directory.
model.dump_config(checkpoint_callback.dirpath)

In [11]:
# TRAIN

# Wrap the train / eval datasets in a datamodule using the shared batch size.
datamodule = RecogTextDataModule(
    train_datasets=[train_dataset],
    eval_datasets=[test_dataset],
    batch_size=batch_size,
)

# Run the training loop.
trainer.fit(model, datamodule=datamodule)

You are using a CUDA device ('NVIDIA GeForce RTX 3080 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: logs/abinet/lightning_logs
/home/mixaill76/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory ./checkpoints/abinet exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | backbone | ResNetABI        | 13.0 M
1 | encoder  | ABIEncoder       | 9.5 M 
2 | decoder  | ABIVisionDecoder | 1.1 M 
----------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.145    Total estimated mod

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/home/mixaill76/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [12]:
# List the checkpoint directory (shell escape) to see which files were written.
!ls -lah "./checkpoints/abinet"

total 91M
drwxr-xr-x 2 mixaill76 mixaill76 4.0K Jan 11 23:26 .
drwxr-xr-x 3 mixaill76 mixaill76 4.0K Jan 11 23:20 ..
-rw-r--r-- 1 mixaill76 mixaill76  288 Jan 11 23:20 base_config.json
-rw-r--r-- 1 mixaill76 mixaill76  91M Jan 11 23:26 model-04-loss-0.17.ckpt
