# Environment Setup

## NOTE: REQUIREMENTS
* Linux kernel version: >=5.5. Check with `uname -r`.
* tesseract. Install with `sudo apt install tesseract-ocr -y`.
   

In [None]:
# Uninstall torch, torchaudio and torchvision; the next cell reinstalls them together with the other dependencies.
!pip uninstall torch torchaudio torchvision -y

In [None]:
!pip install torch torchaudio torchvision transformers[torch] pytesseract tesseract matplotlib Pillow numpy tqdm scikit-learn protobuf sentencepiece pytorch-lightning tensorboardX accelerate nltk wandb

In [None]:
!which tesseract  # Make sure that tesseract is installed: this should print the path of the binary (e.g. /usr/bin/tesseract)

In [None]:
import shutil

import pytesseract

# Point pytesseract at the tesseract binary found on PATH. Fall back to the location used by the
# apt package from the setup instructions when it is not on PATH.
pytesseract.pytesseract.tesseract_cmd = shutil.which('tesseract') or '/usr/bin/tesseract'

# Preparing Hugging Face connection

This allows the training results to be uploaded to a private Hugging Face (HF) repository.

In [None]:
from huggingface_hub import login

# Authenticate against the Hugging Face Hub; the training cells push checkpoints and the final
# model/processor to a Hub repository, which requires a token with write access.
login()  # Use itam-franmr account token

# Define Donut Model and Processor

In [None]:
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
import torch

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration parameters
model_name = "naver-clova-ix/donut-base"  # Pre-trained Donut checkpoint that is fine-tuned.
start_token = "<s_cord-v2>"  # Task start token (the datasets below wrap their targets with the same literal).
end_token = "</s_cord-v2>"  # Task end token.
image_size = [1080, 720]  # (height, width): used for the processor size and the encoder config below.
max_length = 800  # Decoder max length; the dataset also pads/truncates labels to this length.

# Defining processor
processor = DonutProcessor.from_pretrained(model_name, use_fast=True)  # use_fast allows optimization
processor.image_processor.do_align_long_axis = False
processor.image_processor.size = {"height": image_size[0], "width": image_size[1]}

# Defining Model
model_config = VisionEncoderDecoderConfig.from_pretrained(model_name)
model_config.encoder.image_size = image_size
model_config.decoder.max_length = max_length
model_config.pad_token_id = processor.tokenizer.pad_token_id
# NOTE(review): start_token is only added to the tokenizer vocabulary later (dataset cell). If it is
# not already part of the base vocabulary, this lookup presumably resolves to the <unk> id, and that id
# becomes the decoder start token for training and validation — confirm. Either way, prompt generation
# with model.config.decoder_start_token_id so that inference matches training.
model_config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(start_token)

model = VisionEncoderDecoderModel.from_pretrained(model_name, config=model_config)
# model.to(device)  # Sends to GPU if existing.
model.train()  # Activate training mode

new_tokens_list = []  # NOTE(review): not referenced anywhere else in this notebook.

# Define Donut Dataset

Use the PyTorch `Dataset` class to define a dataset in the format that Donut expects.

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import json
import re
import torch
from pathlib import Path, PureWindowsPath
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from collections import defaultdict

class DonutDataset(Dataset):
    """
    PyTorch dataset that serves Donut training samples for one data split.

    Each sample is a PNG image plus a JSON annotation sharing the same file stem. Samples are
    discovered through ``*.file_metadata.json`` files under ``dataset_path``, grouped (by synthetic
    template) and split into training, validation and test sets. The split is seeded with
    ``random_state`` and built from a sorted file listing. As a result, the separate instances
    created for the 'train', 'val' and 'test' splits of the same dataset get disjoint, reproducible
    partitions.

    While annotations are converted into token sequences, every JSON key ``k`` is registered as the
    pair ``<s_k>`` / ``</s_k>`` in ``self.new_special_tokens`` (together with the start and end
    tokens), so callers can add them to the tokenizer vocabulary.

    :param dataset_path: Root folder of the dataset.
    :param processor: Donut processor (image processor + tokenizer).
    :param model_config: VisionEncoderDecoder config; ``decoder.max_length`` sets the label length.
    :param split: 'train', 'validation' / 'val', or anything else for the test split.
    :param start_token: Token prepended to every target sequence.
    :param end_token: Token appended to every target sequence.
    :param excluded_keys: JSON keys removed from the annotations (document and product level).
    :param random_state: Seed of the train/validation/test split. It must be identical for all the
        instances built over one dataset, otherwise their splits may overlap.
    """

    def __init__(self,
                 dataset_path: Path,
                 processor,
                 model_config,
                 split: str = 'train',
                 start_token: str = "<s>",
                 end_token: str = "</s>",
                 excluded_keys: tuple = tuple(),
                 random_state: int = 42,
                 ):

        super().__init__()

        self.dataset_path = dataset_path
        self.split = split
        self.start_token = start_token
        self.end_token = end_token
        self.excluded_keys = excluded_keys
        self.random_state = random_state
        self.empty_token = '<s_no_value>'  # Written in place of empty JSON values.
        self.ignore_id = -100  # Label id used for padding positions (ignored during training).
        self.processor = processor
        self.model_config = model_config
        self.new_special_tokens = set()  # Set that will collect all new special tokens found in the given data.

        # Sample paths (without suffix) assigned to each split.
        self.train_set = []
        self.val_set = []
        self.test_set = []

        # Split the dataset, then load the samples of the selected split.
        self._get_synthetic_samples_list()
        self.raw_samples = self._load_dataset()

        # Add start and end tokens as special tokens.
        self.new_special_tokens.add(self.start_token)
        self.new_special_tokens.add(self.end_token)

    @staticmethod
    def _split_train_validation_test(dataset: list, random_state=None):
        """
        Splits the data into training, validation and testing sets.

        10 % of the samples are held out for testing, then 13.5 % of the remaining ones for
        validation. ``random_state`` seeds both shuffles so that the split is reproducible.
        """
        x_train_tmp, x_test = train_test_split(dataset, test_size=0.1, random_state=random_state)
        x_train, x_val = train_test_split(x_train_tmp, test_size=0.135, random_state=random_state)
        return x_train, x_val, x_test

    def _selected_split_paths(self) -> list:
        """ Returns the sample paths (without suffix) that belong to ``self.split``. """
        if self.split == 'train':
            return self.train_set
        if self.split in ['validation', 'val']:
            return self.val_set
        return self.test_set

    def _load_dataset(self) -> dict:
        """ Selects a split from the dataset and loads its images and json contents as 'raw samples'. """
        raw_samples = {'ims': [], 'ids': [], 'json': []}

        for raw_sample in tqdm(self._selected_split_paths()):
            im_path, json_path = raw_sample.with_suffix('.png'), raw_sample.with_suffix('.json')

            # Read image in RGB format. The context manager closes the file once the converted
            # copy has been created.
            with Image.open(im_path) as im_file:
                im_raw = im_file.convert('RGB')

            # Read attached JSON content.
            with open(json_path, 'r', encoding='utf-8') as fp:
                json_content = json.load(fp)

            # Untouched annotation, kept for debugging purposes. The cleaning steps below build new
            # objects instead of mutating this one.
            raw_samples['json'].append(json_content)

            cleaned_content = self._normalize_empty_values(self._clean_excluded_keys(json_content))

            raw_samples['ims'].append(im_raw)
            raw_samples['ids'].append(self.json2token(cleaned_content))

        return raw_samples

    def _clean_excluded_keys(self, obj: dict) -> dict:
        """
        Returns a copy of ``obj`` without the keys listed in ``self.excluded_keys``.

        The keys are removed at document level and inside ``obj['products']``, which may be a
        single dict or a list of product dicts. ``obj`` itself is not modified.
        """
        excluded = set(self.excluded_keys)

        def without_excluded(d: dict) -> dict:
            return {k: v for k, v in d.items() if k not in excluded}

        cleaned = without_excluded(obj)

        products = cleaned.get('products')
        if isinstance(products, dict):
            cleaned['products'] = without_excluded(products)
        elif isinstance(products, list):
            cleaned['products'] = [without_excluded(p) if isinstance(p, dict) else p for p in products]

        return cleaned

    def _normalize_empty_values(self, obj: dict) -> dict:
        """
        Returns a copy of ``obj`` where every falsy value is replaced by ``self.empty_token``.

        This is applied to document-level fields and to each product dict when ``obj['products']``
        is a list. A missing ``products`` key (e.g. because it was excluded) is allowed.
        """
        def with_empty_tokens(d: dict) -> dict:
            return {k: v if v else self.empty_token for k, v in d.items()}

        # Document-level fields
        normalized = with_empty_tokens(obj)

        # Product-level fields
        products = normalized.get('products')
        if isinstance(products, list):
            normalized['products'] = [with_empty_tokens(p) if isinstance(p, dict) else p for p in products]

        return normalized

    def __len__(self):
        """
        Retrieves the number of samples in the current dataset split.
        **Method required by PyTorch DataLoaders to work**
        """
        return len(self._selected_split_paths())

    def __getitem__(self, idx) -> dict:
        """
        Retrieves the sample in position `idx` from the dataset.
        **Method required by PyTorch DataLoaders to work**

        :return: dict with the normalized image tensor ('pixel_values'), the decoder labels
            ('labels') and the target token string ('target_str_seq').
        """
        im_raw, json_tokens = self.raw_samples['ims'][idx], self.raw_samples['ids'][idx]

        im = self._normalize_image(im_raw)
        gt_tokens, target_str_sequence = self._normalize_tagging(json_tokens)

        return {
            'pixel_values': im,
            'labels': gt_tokens,
            'target_str_seq': target_str_sequence,
        }

    def _assign_groups_to_splits(self, groups: dict):
        """
        Distributes every group of sample paths over the training, validation and test sets.

        A group with a single sample goes to training. A group with two samples contributes one to
        training and one to validation. Larger groups are split with
        ``_split_train_validation_test``, seeded with ``self.random_state``.
        """
        for files in groups.values():
            if len(files) == 1:
                self.train_set.extend(files)
            elif len(files) == 2:
                self.train_set.append(files[0])
                self.val_set.append(files[1])
            else:
                train_set, val_set, test_set = self._split_train_validation_test(files, self.random_state)
                self.train_set.extend(train_set)
                self.val_set.extend(val_set)
                self.test_set.extend(test_set)

    def _get_samples_list(self):
        """
        Finds all the samples in the dataset and splits them into training, validation and test sets
        based on providers. The split percentages are applied to each provider, and not directly to
        the full dataset. As a result, a provider with enough samples is represented in every data
        split, which makes validation more reasonable.
        """
        samples_by_provider = defaultdict(list)

        # Sorted so the grouping (and therefore the seeded split) does not depend on filesystem order.
        for metadata_file in sorted(self.dataset_path.glob('**/*.file_metadata.json')):
            with open(metadata_file, 'r', encoding='utf-8') as fp:
                json_content = json.load(fp)

            original_file = PureWindowsPath(json_content['origin']['name'])  # Paths saved in a Windows machine.
            provider = Path(*original_file.parts[:original_file.parts.index('datasets') + 3]).name

            # Obtain the list of images for the current sample document.
            samples_from_document = sorted(i.with_suffix('') for i in metadata_file.parent.glob('*.png'))
            samples_by_provider[provider].extend(samples_from_document)

        # Split providers into training, validation and test sets.
        self._assign_groups_to_splits(samples_by_provider)

    def _get_synthetic_samples_list(self):
        """
        Finds all the samples in a dataset with synthetic data and splits them into training,
        validation and test sets based on the synthetic template used to create them. The split
        percentages are applied to each template case, and not directly to the full dataset.
        """
        samples_by_template = defaultdict(list)

        # Sorted so the grouping (and therefore the seeded split) does not depend on filesystem order.
        for metadata_file in sorted(self.dataset_path.glob('**/*.file_metadata.json')):
            with open(metadata_file, 'r', encoding='utf-8') as fp:
                json_content = json.load(fp)

            template_name = json_content['template_name']
            samples_by_template[template_name].append(metadata_file.parent / metadata_file.parent.name)

        self._assign_groups_to_splits(samples_by_template)

    def _normalize_image(self, im: Image) -> torch.Tensor:
        """ Normalizes the given image to the expected tensor format for Donut. """
        im = self._fix_image_orientation(im)

        # First dimension is removed, as it refers to the batch. Batches are automatically created by DataLoader.
        return self.processor(im, return_tensors='pt').pixel_values.squeeze()

    @staticmethod
    def _fix_image_orientation(image: Image) -> Image:
        """
        Uses pytesseract's orientation detection (image_to_osd) to get the page rotation, then rotates
        the image when needed. If pytesseract cannot get the orientation, the image is returned
        unchanged.

        NOTE: Tesseract runs on every call, i.e. on every ``__getitem__``, which is slow.

        :param image: PIL image to fix.
        :return: The rotated image, or the original one when no rotation could be determined.
        """
        try:
            im_metadata = pytesseract.image_to_osd(image)
            # PIL rotates counter-clockwise, so 360 - x is a clockwise rotation by x degrees
            # (assumes OSD's "Rotate" value is the clockwise correction — confirm).
            angle = 360 - int(re.search(r'(?<=Rotate: )\d+', im_metadata).group(0))
            return image.rotate(angle, expand=True)
        except Exception:
            # Best effort: Tesseract failed or its output could not be parsed. Return image in its original format.
            return image

    def _normalize_tagging(self, json_tokens) -> tuple[torch.Tensor, str]:
        """
        Wraps the JSON token sequence with the start/end tokens and converts it to decoder labels.

        :return: (labels, target_sequence). ``labels`` is a 1-D tensor of length
            ``model_config.decoder.max_length`` where padding positions hold ``self.ignore_id``.
        """
        target_sequence = self.start_token + json_tokens + self.end_token

        # squeeze(0) removes the batch dimension added by return_tensors="pt"; batches are created by the DataLoader.
        input_ids = self.processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.model_config.decoder.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )['input_ids'].squeeze(0)

        labels = input_ids.clone()

        # Replace all pad tokens by the ignore token.
        labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id

        return labels, target_sequence

    @staticmethod
    def recover_from_tensor(t: torch.Tensor) -> Image:
        """ Utility method to get back the image from a normalized tensor. Use only for debugging. """
        arr = t.cpu().numpy()

        # Donut applies mean and std transformations to image values. Reverting.
        mean = np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1)
        std = np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1)
        arr = (arr * std + mean) * 255.0

        # Limit values to uint8 range, which is the expected for images (0 to 255).
        arr = arr.clip(0, 255).astype(np.uint8)

        # [channels, height, width] -> [height, width, channels] (expected format for PIL images).
        arr = np.transpose(arr, (1, 2, 0))
        return Image.fromarray(arr)

    def json2token(self, obj, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence.
        Source: https://github.com/clovaai/donut/blob/4cfcf972560e1a0f26eb3e294c8fc88a0d336626/donut/model.py#L499

        Dict keys become ``<s_key>...</s_key>`` pairs (sorted in reverse order by default), list
        items are joined with ``<sep/>``, and any other value is converted with ``str``.
        """
        if isinstance(obj, dict):
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]

            keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
            output = ""
            for k in keys:
                if update_special_tokens_for_json_key:
                    # Add to the list of special tokens detected.
                    self.new_special_tokens.add(fr"<s_{k}>")
                    self.new_special_tokens.add(fr"</s_{k}>")
                output += (
                    fr"<s_{k}>"
                    + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                    + fr"</s_{k}>"
                )
            return output

        if isinstance(obj, list):
            return r"<sep/>".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )

        # excluded special tokens for now
        obj = str(obj)
        if f"<{obj}/>" in self.new_special_tokens:
            obj = f"<{obj}/>"  # for categorical special tokens
        return obj

# Creating datasets

* A separate dataset is created for each split.
* New tokens detected in all datasets will be added to Donut vocabulary.

In [None]:
# UPDATE WITH YOUR DATASET'S PATH
DATASET_PATH = Path('synth_data')
# JSON keys dropped from every annotation before it is converted into a token sequence.
excluded_keys = [
    "drom",
    "corp",
    "ltype",
    "ctype",
    "atype",
    "metgr",
    "id",
    "po",
    "met",
    "sku",
    "products"
]

# One dataset instance per split; all of them share the same construction arguments.
dataset_kwargs = dict(start_token=start_token, end_token=end_token, excluded_keys=excluded_keys)
training_dataset = DonutDataset(DATASET_PATH, processor, model_config, split='train', **dataset_kwargs)
validation_dataset = DonutDataset(DATASET_PATH, processor, model_config, split='val', **dataset_kwargs)
test_dataset = DonutDataset(DATASET_PATH, processor, model_config, split='test', **dataset_kwargs)

# Gather all new tokens. They are sorted so the order in which they are added (and therefore the ids
# they receive) is reproducible: iteration order of a set of strings changes between Python sessions.
all_new_tokens = (
    training_dataset.new_special_tokens
    | validation_dataset.new_special_tokens
    | test_dataset.new_special_tokens
)
all_new_tokens_list = sorted(all_new_tokens)

# Add new tokens to Donut processor and model.
newly_added_num = processor.tokenizer.add_tokens(all_new_tokens_list)
if newly_added_num > 0:
    model.decoder.resize_token_embeddings(len(processor.tokenizer))

# Special tokens that generation must never produce. Exceptions:
# - the task start/end tokens used in the targets;
# - the EOS token, which is passed as eos_token_id to generate (blocking it means generation can never stop);
# - "<sep/>", which DonutDataset.json2token uses as the list separator.
allowed_tokens = {start_token, end_token, processor.tokenizer.eos_token, "<sep/>"}
special_tokens = set(processor.tokenizer.all_special_tokens)  # Computed once, not once per vocab entry.
bad_words = [
    token_id for token, token_id in processor.tokenizer.get_vocab().items()
    if token in special_tokens and token not in allowed_tokens
]
bad_words_ids = [[token_id] for token_id in bad_words]

# Show final split percentages.
total_samples = len(training_dataset) + len(validation_dataset) + len(test_dataset)
print(f'Training set:   {len(training_dataset)} ({len(training_dataset) * 100// total_samples} %)')
print(f'Validation set: {len(validation_dataset)} ({len(validation_dataset) * 100// total_samples} %)')
print(f'Test set:       {len(test_dataset)} ({len(test_dataset) * 100// total_samples} %)')

In [None]:
from torch.utils.data import DataLoader

# Both loaders share the batch size and the worker count; only the training data is shuffled.
BATCH_SIZE = 4
_loader_options = {"batch_size": BATCH_SIZE, "num_workers": 8}
train_dataloader = DataLoader(training_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(validation_dataset, shuffle=False, **_loader_options)

# Define PyTorch Lightning module

PyTorch Lightning is a lightweight framework built on top of PyTorch that focuses on modularity, which makes training loops easier to configure than in plain PyTorch.
Here, we'll create a LightningModule to define the training behaviour of our model. 

In [None]:
from pathlib import Path
import re
from nltk import edit_distance
import numpy as np
import math

from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only


class DonutModelPLModule(pl.LightningModule):
    """
    LightningModule that fine-tunes a Donut VisionEncoderDecoder model and validates it with greedy
    generation.

    Validation metrics (computed per batch, averaged and logged per epoch):
      * degeneration_ratio: fraction of samples whose prediction or reference cannot be parsed as XML.
      * extraction_similarity_ratio (logged as similarity_ratio): mean normalized edit-distance
        similarity between the values of the reference keys found in the prediction.
      * matching_keys_ratio (logged as matching_keys_ratios): fraction of reference keys present in
        the prediction.
    """

    # Token that DonutDataset writes for empty JSON values. It is an opening tag with no closing tag,
    # so it is removed before sequences are parsed as XML; otherwise parsing always fails.
    _EMPTY_VALUE_TOKEN = "<s_no_value>"
    # '&' characters that do not start an XML entity reference (e.g. "A & B") break XML parsing.
    _BARE_AMPERSAND = re.compile(r"&(?!(?:[A-Za-z]+|#[0-9]+|#x[0-9A-Fa-f]+);)")
    # Name of the first (root) tag of a token sequence.
    _ROOT_TAG = re.compile(r"\s*<([^\s<>/]+)>")

    def __init__(self, config, processor, model, batch_size, max_length, train_dataloader_obj, val_dataloader_obj):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.batch_size = batch_size
        self.max_length = max_length
        self.train_dataloader_obj = train_dataloader_obj
        self.val_dataloader_obj = val_dataloader_obj

        self.validation_step_outputs = []

    def training_step(self, batch, batch_idx):
        """ Computes the teacher-forced loss of the model on one batch and logs it. """
        pixel_values = batch['pixel_values']
        labels = batch['labels']

        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log("train_loss", loss, on_step=True, on_epoch=True, sync_dist=True, batch_size=self.batch_size)
        return loss

    def validation_step(self, batch, batch_idx=0):
        """
        Generates a prediction for every sample of the batch, starting from the model's decoder
        start token, and compares it with the reference target string.
        """
        pixel_values = batch['pixel_values']
        answers = [a.strip().lower() for a in batch['target_str_seq']]

        batch_size = pixel_values.size(0)

        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.model.config.decoder_start_token_id,
            device=self.device
        )

        outputs = self.model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=self.max_length,
            early_stopping=True,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            num_beams=1,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True
        )

        decoded = self.processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
        predictions = [d.strip().lower() for d in decoded]

        # Metrics
        metrics = self._compute_metrics(predictions, answers)

        print("Prediction vs Answer (first sample in batch):")
        print("\n".join([f"\t{k}: {v}" for k, v in metrics.items()]))
        print(f"    Prediction: {predictions[0]}")
        print(f"    Reference : {answers[0]}")

        self.validation_step_outputs.append(metrics)

        return {
            "predictions": predictions,
            "references": answers,
            **metrics
        }

    @classmethod
    def _prepare_for_xml(cls, sequence: str) -> str:
        """
        Makes a Donut token sequence parseable as XML.

        It removes empty-value tokens and escapes bare '&' characters. It also drops anything after
        the first closing tag of the root element, e.g. tokens generated after the task end token.
        """
        sequence = sequence.replace(cls._EMPTY_VALUE_TOKEN, "")
        sequence = cls._BARE_AMPERSAND.sub("&amp;", sequence)

        root = cls._ROOT_TAG.match(sequence)
        if root is not None:
            closing_tag = f"</{root.group(1)}>"
            end = sequence.find(closing_tag)
            if end != -1:
                sequence = sequence[:end + len(closing_tag)]
        return sequence

    def _compute_metrics(self, predictions, answers) -> dict:
        """
        Parses predictions and references as XML trees and compares their keys and values.

        A sample whose prediction or reference cannot be parsed counts as degenerated and is left
        out of the key/value metrics.
        """
        # ElementTree is not imported by the notebook's import cells, so it is imported here.
        import xml.etree.ElementTree as ET

        num_degenerated_samples = 0
        matching_keys_ratios = []
        similarity_values = []

        for prediction, answer in zip(predictions, answers):
            try:
                pred_root = ET.fromstring(self._prepare_for_xml(prediction))
                ans_root = ET.fromstring(self._prepare_for_xml(answer))
            except ET.ParseError:
                num_degenerated_samples += 1
                continue

            # Compare the contents of the root elements (the task start/end token pair) of both trees.
            tree_pred = self._etree_to_dict(pred_root)[pred_root.tag]
            tree_ans = self._etree_to_dict(ans_root)[ans_root.tag]

            matching_keys_ratio, similarity = self._find_matching_keys_and_values(tree_pred, tree_ans)
            matching_keys_ratios.append(matching_keys_ratio)
            similarity_values.append(similarity)

        avg_matching_keys_ratio = sum(matching_keys_ratios) / len(matching_keys_ratios) if matching_keys_ratios else 0.0
        avg_similarity = sum(similarity_values) / len(similarity_values) if similarity_values else 0.0
        return {
            "degeneration_ratio": num_degenerated_samples / len(predictions) if predictions else 0.0,
            "extraction_similarity_ratio": avg_similarity,
            "matching_keys_ratio": avg_matching_keys_ratio
        }

    def _etree_to_dict(self, t) -> dict:
        """
        Source: https://stackoverflow.com/questions/7684333/converting-xml-to-dictionary-using-elementtree
        """
        d = {t.tag: {} if t.attrib else None}
        children = list(t)
        if children:
            dd = defaultdict(list)
            for dc in map(self._etree_to_dict, children):
                for k, v in dc.items():
                    dd[k].append(v)
            d = {t.tag: {k: v[0] if len(v) == 1 else v
                for k, v in dd.items()}}
        if t.attrib:
            d[t.tag].update(('@' + k, v)
                for k, v in t.attrib.items())
        if t.text:
            text = t.text.strip()
            if children or t.attrib:
                if text:
                    d[t.tag]['#text'] = text
            else:
                d[t.tag] = text
        return d

    def _find_matching_keys_and_values(self, pred: dict, ans: dict):
        """
        Compares the keys and values of a predicted tree against the reference tree.

        :return: (matching_keys_ratio, avg_similarity). Both are 0.0 when the reference has no keys
            or no key could be compared.
        """
        # Accept trees still wrapped in the task root element.
        if isinstance(ans, dict) and 's_cord-v2' in ans:
            ans = ans['s_cord-v2']
        if isinstance(pred, dict) and 's_cord-v2' in pred:
            pred = pred['s_cord-v2']

        if not isinstance(ans, dict):
            return 0.0, 0.0
        if not isinstance(pred, dict):
            pred = {}

        matching_keys = 0
        total_keys = 0
        similarities = []

        # Shared keys
        for expected_key, expected_value in ans.items():
            total_keys += 1
            if expected_key in pred:
                matching_keys += 1
                similarities.append(self._compute_similarity(pred[expected_key], expected_value))

        # Product keys, compared product by product.
        ans_products, pred_products = ans.get('s_products'), pred.get('s_products')
        if isinstance(ans_products, list) and isinstance(pred_products, list):
            for pred_product, ans_product in zip(pred_products, ans_products):
                if not isinstance(ans_product, dict):
                    continue
                if not isinstance(pred_product, dict):
                    pred_product = {}
                for product_key, expected_value in ans_product.items():
                    total_keys += 1
                    if product_key in pred_product:
                        matching_keys += 1
                        similarities.append(self._compute_similarity(pred_product[product_key], expected_value))

        avg_similarity = sum(similarities) / len(similarities) if similarities else 0.0
        matching_keys_ratio = matching_keys / total_keys if total_keys else 0.0

        return matching_keys_ratio, avg_similarity

    @staticmethod
    def _compute_similarity(pred_value, gt_value):
        """
        Returns 1 - normalized edit distance between two values (1.0 means identical).

        ``None`` (empty XML element) is treated as an empty string. Strings and lists are compared
        as sequences; any other value is compared through its ``str`` form. Two empty values are
        considered identical.
        """
        def as_sequence(value):
            if value is None:
                return ""
            if isinstance(value, (str, list)):
                return value
            return str(value)

        pred_seq, gt_seq = as_sequence(pred_value), as_sequence(gt_value)
        longest = max(len(pred_seq), len(gt_seq))
        if longest == 0:
            return 1.0
        return 1.0 - (edit_distance(pred_seq, gt_seq) / longest)

    def on_validation_epoch_end(self) -> None:
        """ Logs the epoch mean of every validation metric and resets the collected outputs. """
        outputs = self.validation_step_outputs
        if outputs:  # np.mean of an empty list would log NaN.
            self.log("degeneration_ratio", float(np.mean([o['degeneration_ratio'] for o in outputs])), sync_dist=True)
            self.log("similarity_ratio", float(np.mean([o['extraction_similarity_ratio'] for o in outputs])), sync_dist=True)
            self.log("matching_keys_ratios", float(np.mean([o['matching_keys_ratio'] for o in outputs])), sync_dist=True)

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        # TODO: Add a learning rate scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
        return optimizer

    def train_dataloader(self):
        return self.train_dataloader_obj

    def val_dataloader(self):
        return self.val_dataloader_obj

# Set training configuration parameters

## Parameters list:

* `model_checkpoint_dirpath` (set in the training cell): local folder where the training checkpoints are saved.
* `max_epochs`: An epoch is one complete pass through the entire training dataset. Training for multiple epochs lets the model keep adjusting to the training data, but too many epochs can lead to overfitting. More info [here](https://machinelearningmastery.com/difference-between-a-batch-and-an-epoch/#:~:text=at%20an%20epoch.-,What%20Is%20an%20Epoch%3F,-The%20number%20of)
* `lr`: the learning rate, which controls how fast a model adjusts its parameters. Low values can make training slow, while large values can cause unstable results. More info [here](https://machinelearningmastery.com/understand-the-dynamics-of-learning-rate-on-deep-learning-neural-networks/#:~:text=The-,learning%20rate,-controls%20how%20quickly)
* `BATCH_SIZE` (set in the DataLoader cell): the number of samples the model sees before updating its weights. Larger batches reduce the variance of the gradient estimates, but they also need more VRAM; with big images this can quickly exhaust GPU memory. More info [here](https://machinelearningmastery.com/gentle-introduction-mini-batch-gradient-descent-configure-batch-size/#:~:text=batch%20gradient%20descent.-,How%20to%20Configure%20Mini%2DBatch%20Gradient%20Descent,-Mini%2Dbatch%20gradient)

In [None]:
# Training hyper-parameters. They are read by the Trainer setup cell and by
# DonutModelPLModule.configure_optimizers ("lr"); "num_nodes" is currently not passed to the Trainer.
config = dict(
    max_epochs=80,               # number of passes over the training set
    val_check_interval=1.0,      # run validation once per training epoch
    check_val_every_n_epoch=1,
    gradient_clip_val=1.0,
    lr=2e-6,                     # Adam learning rate
    num_nodes=1,
    precision="16-mixed",        # mixed-precision training
    log_every_n_steps=1,
)

In [None]:
# from huggingface_hub import login

# login()  # Use itam-franmr account token

In [None]:
import logging
import os
import re

from huggingface_hub import HfApi
from pytorch_lightning.callbacks import Callback, EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# Weights & Biases logger used by the Trainer; metrics go to project "Donut" under this run name.
wandb_logger = WandbLogger(project="Donut", name="demo_run_donut_synth_v2")


class PushToHubCallback(Callback):
    """
    Lightning callback that uploads training artifacts to a Hugging Face Hub model repository.

    * When training starts, the repository is created (private) if it does not exist yet.
    * At the end of every training epoch, the previously uploaded checkpoints are deleted and the
      newest checkpoint found in ``ckpt_dirpath`` is uploaded to ``checkpoints/``.
    * When training ends, the processor and the model are pushed with ``push_to_hub``.

    NOTE(review): Lightning runs ModelCheckpoint callbacks after the other callbacks, so at epoch end
    this may still upload the previous epoch's checkpoint — confirm.

    :param logger: Logger exposing ``info`` and ``warning`` (e.g. a ``logging.Logger``).
    :param hf_repo: Hub repository id, e.g. "user/repo".
    :param ckpt_dirpath: Local folder where the checkpoint files are written.
    """

    # Checkpoint names written by ModelCheckpoint(filename='{epoch}_{step}'), e.g. "epoch=3_step=120.ckpt".
    _CKPT_NAME_PATTERN = re.compile(r"epoch=(\d+)_step=(\d+)\.ckpt")

    def __init__(self, logger, hf_repo, ckpt_dirpath):
        super().__init__()
        self.logger = logger
        self.hf_repo = hf_repo
        self.ckpt_dirpath = ckpt_dirpath
        self.api = HfApi()

    def on_train_start(self, trainer, pl_module):
        # upload_file needs an existing repository; create it (private) when it is missing.
        self.api.create_repo(repo_id=self.hf_repo, repo_type='model', private=True, exist_ok=True)

    def _push_checkpoint_file(self):
        """ Uploads the newest local checkpoint after deleting the checkpoints already on the Hub. """
        ckpt_name = self._get_last_checkpoint(self.ckpt_dirpath) if os.path.isdir(self.ckpt_dirpath) else None
        if ckpt_name is None:
            self.logger.warning("No checkpoints available.")
            return

        ckpt_path = os.path.join(self.ckpt_dirpath, ckpt_name)
        print("Checkpoint selected:", ckpt_path)
        self._remove_old_checkpoints_from_hub()
        self.api.upload_file(
            path_or_fileobj=ckpt_path,
            # Hub paths always use '/', whatever the local OS path separator is.
            path_in_repo=f"checkpoints/{ckpt_name}",
            repo_id=self.hf_repo,
            repo_type='model',
            commit_message=f"Checkpoint saved: {ckpt_name}"
        )

    def _remove_old_checkpoints_from_hub(self):
        """ Deletes every file whose path contains 'checkpoints/epoch' from the Hub repository. """
        all_files = self.api.list_repo_files(repo_id=self.hf_repo, repo_type='model')
        ckpt_files = [f for f in all_files if 'checkpoints/epoch' in f]

        for ckpt_file in ckpt_files:
            print("Removing old checkpoint:", ckpt_file)

            self.api.delete_file(
                path_in_repo=ckpt_file,
                repo_id=self.hf_repo,
                repo_type='model',
                commit_message=f'Removed old checkpoint "{ckpt_file}"',
            )

    def _get_last_checkpoint(self, ckpt_dir: str):
        """
        Returns the file name of the newest checkpoint in ``ckpt_dir``, or None if there is none.

        A directory with a single file returns that file. Otherwise, the file with the highest
        (epoch, step) pair is selected (the first one on ties). Files not named
        ``epoch=<n>_step=<m>.ckpt`` are ignored.
        """
        ckpts = os.listdir(ckpt_dir)
        print("Checkpoints saved:", ckpts)
        if not ckpts:
            return None
        if len(ckpts) == 1:
            return ckpts[0]

        # Keep each (epoch, step) pair next to its own file name, so the selected pair always maps
        # back to the right file even when some names do not match the pattern.
        candidates = []
        for ckpt in ckpts:
            matches = self._CKPT_NAME_PATTERN.match(ckpt)
            if matches:
                epoch, step = map(int, matches.groups())
                candidates.append(((epoch, step), ckpt))

        if not candidates:
            return None
        return max(candidates, key=lambda candidate: candidate[0])[1]

    def on_train_epoch_end(self, trainer, pl_module):
        self.logger.info(f"Pushing model to the Hub [epoch {trainer.current_epoch}]")
        self._push_checkpoint_file()

    def on_train_end(self, trainer, pl_module):
        self.logger.info("Pushing model to the hub after training.")
        pl_module.processor.push_to_hub(self.hf_repo, commit_message="Training Finished!")
        pl_module.model.push_to_hub(self.hf_repo, commit_message="Training Finished!")


# Hugging Face Hub repository that receives the checkpoints and the final model.
# UPDATE WITH YOUR REPOSITORY.
hf_repo = "itam-franmr/donut_clem"
model_checkpoint_dirpath = 'checkpoints'

# PushToHubCallback reports through info()/warning(), so it gets a standard logging.Logger
# (the W&B logger does not provide those methods).
push_to_hub_callback = PushToHubCallback(
    logging.getLogger("PushToHubCallback"),
    hf_repo,
    model_checkpoint_dirpath
)

# Writes one checkpoint per epoch to model_checkpoint_dirpath, keeping a single file.
model_checkpoint_callback = ModelCheckpoint(
    dirpath=model_checkpoint_dirpath,
    filename='{epoch}_{step}',
    save_top_k=1,
    every_n_epochs=1,
    save_weights_only=False
)

# early_stop_callback = EarlyStopping(monitor="val_edit_distance", patience=3, verbose=False, mode="min")

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=config.get("max_epochs"),
    val_check_interval=config.get("val_check_interval"),
    check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
    gradient_clip_val=config.get("gradient_clip_val"),
    precision=config.get("precision"),
    num_sanity_val_steps=0,
    logger=wandb_logger,
    callbacks=[model_checkpoint_callback, push_to_hub_callback],
    log_every_n_steps=config.get('log_every_n_steps')
)

In [None]:
# DonutModelPLModule needs the batch size (for logging), the generation max length (for validation)
# and both dataloaders, in addition to the config, processor and model.
model_module = DonutModelPLModule(
    config,
    processor,
    model,
    batch_size=BATCH_SIZE,
    max_length=max_length,
    train_dataloader_obj=train_dataloader,
    val_dataloader_obj=val_dataloader,
)
trainer.fit(model_module)

# Evaluate model

In [None]:
# from transformers import DonutProcessor, VisionEncoderDecoderModel

# processor = DonutProcessor.from_pretrained("itam-franmr/donut_clem")
# model = VisionEncoderDecoderModel.from_pretrained("itam-franmr/donut_clem")

In [None]:
# Provides JSONParseEvaluator, used by the evaluation loop below.
!pip install -q donut-python

In [None]:
import re
import json
import torch
from tqdm.auto import tqdm
import numpy as np

from donut import JSONParseEvaluator

model.eval()  # Activates the evaluation mode

output_list = []
accs = []
evaluator = JSONParseEvaluator()  # Created once and reused for every sample.

# Prompt the decoder with the same start token that validation_step uses
# (model.config.decoder_start_token_id), so inference matches what the model saw during training.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]], device=model.device)
# Stop generating at the tokenizer EOS token or at the task end token, whichever comes first.
eos_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids(end_token)]

for sample in tqdm(validation_dataset, total=len(validation_dataset)):
    # Prepare encoder inputs (add the batch dimension).
    pixel_values = sample["pixel_values"].unsqueeze(0).to(model.device)

    # Autoregressively generate sequences
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=eos_token_ids,
        use_cache=True,
        num_beams=1,
        # bad_words=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Turn response into JSON format: drop EOS/padding and the leading decoder prompt token.
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    prediction = processor.token2json(seq)

    # The reference is stored as a token string. Parse it the same way so that cal_acc compares two
    # JSON objects instead of a JSON object and a raw string.
    ground_truth = processor.token2json(sample["target_str_seq"])
    score = evaluator.cal_acc(prediction, ground_truth)

    accs.append(score)
    output_list.append(prediction)

scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)}
print(scores, f"length: {len(accs)}")
print(f"Mean accuracy: {scores['mean_accuracy']}")

# Run inference on a random test sample

In [None]:
import re
import transformers
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import random

# Hub repository the trained processor and model were pushed to. UPDATE WITH YOUR REPOSITORY.
model_repository_name = "itam-franmr/donut_clem"

# Load processor and model
processor = DonutProcessor.from_pretrained(model_repository_name)
model = VisionEncoderDecoderModel.from_pretrained(model_repository_name)

# Load random document image from test set. randrange excludes its upper bound; randint(0, len)
# can return len(test_dataset), which is out of range.
test_sample = test_dataset[random.randrange(len(test_dataset))]

pixel_values = test_sample['pixel_values'].unsqueeze(0)
task_prompt = "<s_cord-v2>"
# Prompt the decoder with the start token the model was trained with (samples have no 'input_ids').
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

In [None]:
import random
# Load random document image from test set. randrange excludes its upper bound; randint(0, len)
# can return len(test_dataset), which is out of range.
test_sample = test_dataset[random.randrange(len(test_dataset))]

pixel_values = test_sample['pixel_values'].unsqueeze(0)
task_prompt = "<s_cord-v2>"
# Prompt the decoder with the start token used during training. The sample's 'labels' hold the whole
# ground-truth answer (padded with -100), so they must not be used as the generation prompt.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

In [None]:
outputs = model.generate(
    pixel_values.to(model.device),
    decoder_input_ids=decoder_input_ids.to(model.device),
    max_length=model.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    # Stop at the tokenizer EOS token or at the task end token, whichever is generated first. The
    # training targets end with the task end token, so without it generation runs to max_length.
    eos_token_id=[processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids(end_token)],
    use_cache=True,
    num_beams=1,
    bad_words_ids=bad_words_ids,
    return_dict_in_generate=True,
)

In [None]:
# Process output: drop EOS/padding and the leading decoder prompt token, then parse the Donut tokens.
prediction = processor.batch_decode(outputs.sequences)[0]
prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
prediction = re.sub(r"<.*?>", "", prediction, count=1).strip()
prediction = processor.token2json(prediction)

# Load reference target (DonutDataset stores it under 'target_str_seq').
target = processor.token2json(test_sample['target_str_seq'])

print(f"Reference:\n {target}")
print(f"Prediction:\n {prediction}")