# HMEs-Recognition Demo

This demo assumes that you are running it in Google Colab and have already cloned the repository.

## Preparation

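Download the [dataset](https://www.kaggle.com/datasets/jungomi/chrome-png) and place it in the `data` folder. The cell below automates the download through the Kaggle API. The first time it runs, it asks you to upload your `kaggle.json` API token.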

### Data

```python
# Install the Kaggle CLI
!pip install -q kaggle
import os

# Ask for the Kaggle API token only if it is not configured yet
if not os.path.isfile(os.path.expanduser('~/.kaggle/kaggle.json')):
    from google.colab import files

    print("Upload kaggle.json here")
    files.upload()  # saves kaggle.json to the current directory
    !mkdir -p ~/.kaggle
    !mv ./kaggle.json ~/.kaggle/
    !chmod 600 ~/.kaggle/kaggle.json

# Download and extract the dataset only if it is not already present
if not os.path.isfile('data/tokens.tsv'):
    dataset_name = 'jungomi/chrome-png'
    zip_name = dataset_name.split('/')[-1]

    !kaggle datasets download -d {dataset_name}
    !unzip -q ./{zip_name}.zip -d .
```
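The data module in the Training section reads `tokens.tsv` and exposes a `token_to_id` mapping whose size becomes the model's `vocab_size`. As a rough illustration of what such a mapping looks like, here is a minimal sketch that assumes `tokens.tsv` lists one token per line (token in the first tab-separated column). The special tokens `<pad>`, `<sos>` and `<eos>` are hypothetical; the repository's `MathExpressionDataModule` may use different names, a different order, or none at all.

```python
from pathlib import Path

# Hypothetical special tokens; the repository may use different ones.
SPECIAL_TOKENS = ("<pad>", "<sos>", "<eos>")


def build_token_to_id(tokens_path, special_tokens=SPECIAL_TOKENS):
    """Map each token to a consecutive integer id, special tokens first.

    Assumes one token per line, with the token in the first tab-separated
    column. Blank lines and repeated tokens are skipped.
    """
    token_to_id = {}
    for tok in special_tokens:
        token_to_id.setdefault(tok, len(token_to_id))
    for line in Path(tokens_path).read_text(encoding="utf-8").splitlines():
        token = line.split("\t")[0].strip()
        if token and token not in token_to_id:
            token_to_id[token] = len(token_to_id)
    return token_to_id
```

With a file like this, `len(build_token_to_id("data/tokens.tsv"))` plays the role of `vocab_size`. Reserving the lowest ids for special tokens is a common convention because padding is then id 0.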

### Install Lightning

```python
# Install Lightning
!pip install -q lightning
```

### Download the checkpoint

Download the checkpoint from [Google Drive](https://drive.google.com/drive/folders/1g6LnaHuJkPI2z5X7Qh4Ms2HdZpSvlSxG?usp=drive_link) and place it in the `checkpoints` folder. The Testing section loads it from `checkpoints/best-checkpoint-exp-rate.ckpt`.

## Training

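```python
# Hyperparameters and paths
DATA_DIR = 'data'
BATCH_SIZE = 32
NUM_WORKERS = 4         # DataLoader worker processes
LR = 0.1                # learning rate
MAX_EPOCHS = 100
ENCODER_OUT_DIM = 512   # encoder output dimension

# Checkpoint loaded in the Testing section
CHECKPOINT_PATH = 'checkpoints/best-checkpoint-exp-rate.ckpt'
```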

```python
import lightning as l
from callbacks import checkpoint_callback_exp_rate_3, checkpoint_callback_val_loss, early_stop_callback_loss
from dataset.data_module import MathExpressionDataModule
from model.HMERecognizer import HMERecognizer

data_module = MathExpressionDataModule(data_dir=DATA_DIR, tokens_file='tokens.tsv', batch_size=BATCH_SIZE,
                                       num_workers=NUM_WORKERS)

model = HMERecognizer(token_to_id=data_module.token_to_id, lr=LR, encoder_out_dim=ENCODER_OUT_DIM,
                      vocab_size=len(data_module.token_to_id), batch_size=BATCH_SIZE)

# Checkpointing and early-stopping callbacks from the repository's callbacks module
callbacks = [
    checkpoint_callback_exp_rate_3,
    checkpoint_callback_val_loss,
    # early_stop_callback_exp_rate,
    early_stop_callback_loss,
]

data_module.setup('train')
# fast_dev_run=True runs a single training and validation batch as a quick smoke test
# and disables checkpointing, early stopping and logging.
# For a full training run, use the commented-out Trainer instead.
# trainer = l.Trainer(max_epochs=MAX_EPOCHS, callbacks=callbacks)
trainer = l.Trainer(max_epochs=MAX_EPOCHS, callbacks=callbacks, fast_dev_run=True)
trainer.fit(model, data_module)
```
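Lightning's `EarlyStopping` callback stops training once the monitored metric has not improved by more than `min_delta` for `patience` consecutive validation checks. The pure-Python sketch below reproduces that rule for a metric that should decrease, such as validation loss, so you can see when a run would end. The real monitor name, `patience` and `min_delta` of `early_stop_callback_loss` are set in the repository's `callbacks` module, which is not shown here; the values below are illustrative only.

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the index of the validation check at which training would stop,
    or None if early stopping never triggers.

    A check counts as an improvement only if the loss is lower than the best
    loss seen so far by more than min_delta (the mode="min" rule).
    """
    best = float("inf")
    wait = 0  # consecutive checks without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Loss improves until index 3, then stalls for three checks in a row.
print(early_stop_epoch([1.0, 0.8, 0.81, 0.79, 0.85, 0.9, 0.95], patience=3))
```

Because the counter resets on every improvement, a single bad epoch (index 2 above) does not end training; only a sustained plateau does.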

## Testing

```python
# fast_dev_run=True evaluates only a single batch of each test set;
# use l.Trainer() instead to evaluate on the full test sets.
trainer = l.Trainer(fast_dev_run=True)
data_module.setup('test')
model = HMERecognizer.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH, token_to_id=data_module.token_to_id,
                                           lr=LR, encoder_out_dim=ENCODER_OUT_DIM,
                                           vocab_size=len(data_module.token_to_id), batch_size=BATCH_SIZE)
model.eval()

# test_dataloader() returns one dataloader per test year; evaluate each one separately
results = {}
for year, dataloader in data_module.test_dataloader().items():
    results[year] = trainer.test(model, dataloaders=dataloader)
```
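The checkpoint name refers to "exp rate", which in handwritten mathematical expression recognition usually means ExpRate: the fraction of test expressions whose predicted token sequence matches the ground truth exactly. Results are often also reported with a tolerance of at most one or two token errors. The sketch below is an independent illustration of that metric using token-level edit distance; it is not the repository's own metric code, and the names `edit_distance` and `exp_rate` are hypothetical.

```python
def edit_distance(a, b):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,               # deletion
                            curr[j - 1] + 1,           # insertion
                            prev[j - 1] + (x != y)))   # substitution
        prev = curr
    return prev[-1]


def exp_rate(predictions, references, max_errors=0):
    """Fraction of expressions whose prediction is within max_errors token
    edits of the reference (max_errors=0 means exact match)."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    hits = sum(edit_distance(p, r) <= max_errors
               for p, r in zip(predictions, references))
    return hits / len(references)


preds = [["x", "+", "1"], ["\\frac", "{", "a", "}", "{", "b", "}"], ["y", "=", "2"]]
refs = [["x", "+", "1"], ["\\frac", "{", "a", "}", "{", "c", "}"], ["y", "-", "3"]]
for k in (0, 1, 2):
    print(f"ExpRate (<= {k} errors): {exp_rate(preds, refs, max_errors=k):.3f}")
```

Computing the metric at the token level rather than on raw LaTeX strings means that a wrong multi-character command such as `\frac` counts as one error, not several.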

## Visualization

To view the training logs of the provided checkpoint, download the `lightning_logs/version_0` folder from [Google Drive](https://drive.google.com/drive/folders/1g6LnaHuJkPI2z5X7Qh4Ms2HdZpSvlSxG?usp=drive_link) and place it in the `lightning_logs` folder. If `lightning_logs` already contains a `version_0` folder, for example from a training run above, delete or rename it first.

```python
# Launch TensorBoard on the Lightning log directory
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/
```