In [None]:
# Automatically reload edited project modules (e.g. `data`, `model.transformer`)
# before executing code, so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import yaml

# Load the run configuration. `safe_load` handles plain YAML config and rejects
# arbitrary Python-object tags; the `with` block guarantees the file handle is closed.
with open('hyperparameters.yaml', encoding='utf-8') as config_file:
    hyperparameters = yaml.safe_load(config_file)

# Fail fast with a readable message if a section the notebook reads is missing,
# instead of a bare KeyError several cells later.
_required_keys = {'model', 'optimizer', 'batch_size', 'num_workers', 'num_epochs'}
_missing_keys = _required_keys - set(hyperparameters or {})
assert not _missing_keys, f"hyperparameters.yaml is missing keys: {sorted(_missing_keys)}"

In [None]:
from data import MathEquationsDatamodule

# Pull the values the datamodule needs out of the config first, then construct it.
# Arguments are positional, so their order must match MathEquationsDatamodule's signature.
image_size = hyperparameters['model']['image_size']
batch_size = hyperparameters['batch_size']
num_workers = hyperparameters['num_workers']

# NOTE(review): 'data' is presumably the dataset root directory — confirm in data.py.
data_module = MathEquationsDatamodule('data', image_size, batch_size, num_workers)

In [None]:
from model.transformer import Image2LaTeXVisionTransformer
# Import from `pytorch_lightning`, the same package the Trainer/callbacks cell uses,
# instead of mixing in the separate `lightning.pytorch` namespace.
# NOTE(review): this must match the package Image2LaTeXVisionTransformer's
# LightningModule base class is imported from — confirm in model/transformer.py.
from pytorch_lightning.utilities.model_summary import ModelSummary

# Build the model from the `model` and `optimizer` config sections and display
# a summary of its submodules, nested up to depth 3, as the cell output.
model = Image2LaTeXVisionTransformer(hyperparameters['model'], hyperparameters['optimizer'])
ModelSummary(model, max_depth=3)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

# Metric driving both checkpoint selection and early stopping (mode='min': lower is
# better). Defined once so the two callbacks cannot drift apart.
MONITOR_METRIC = 'val_loss'
# Optional config key; falls back to 5 epochs without improvement when absent.
EARLY_STOPPING_PATIENCE = hyperparameters.get('early_stopping_patience', 5)

# Keep only the single best checkpoint according to the monitored metric.
checkpoint_callback = ModelCheckpoint(monitor=MONITOR_METRIC, save_top_k=1, mode='min')
early_stopping_callback = EarlyStopping(
    monitor=MONITOR_METRIC,
    patience=EARLY_STOPPING_PATIENCE,
    mode='min',
)
# Log the learning rate per optimizer step rather than per epoch.
lr_monitor = LearningRateMonitor(logging_interval='step')
# NOTE(review): log_graph=True relies on the model defining `example_input_array`;
# confirm Image2LaTeXVisionTransformer sets it, otherwise no graph will be logged.
logger = TensorBoardLogger('logs', name='image2latex', log_graph=True)

trainer = Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
    max_epochs=hyperparameters['num_epochs'],
    enable_checkpointing=True,
    enable_progress_bar=True,
)

In [None]:
# Train the model on the datamodule's dataloaders. Passing the datamodule by keyword
# makes its role explicit; Lightning treats the positional form identically.
trainer.fit(model, datamodule=data_module)