## Fine-tune Idefics2 for document parsing (PDF -> JSON)

In this notebook, we are going to fine-tune the [Idefics2](https://huggingface.co/docs/transformers/main/en/model_doc/idefics2) model for a document AI use case. Developed by Hugging Face, Idefics2 is one of the best open-source multimodal models at the time of writing. Idefics started as a replication of DeepMind's Flamingo model, and the second iteration incorporates many advancements in the field, such as [NaViT](https://arxiv.org/abs/2307.06304) patching, [Perceiver](https://arxiv.org/abs/2103.03206) resampling and more. However, explaining how Idefics2 works would deserve its own video, which I might put on YouTube!

The goal for the model is to generate a JSON that contains key fields (like food items and their corresponding prices) from receipts. We will fine-tune Idefics2 on the [CORD](https://huggingface.co/datasets/naver-clova-ix/cord-v2) dataset, which contains (receipt image, ground truth JSON) pairs.

Sources:

* Idefics2 [blog post](https://huggingface.co/blog/idefics2)
* Idefics2 [DocVQA notebook](https://colab.research.google.com/drive/1NtcTgRbSBKN7pYD3Vdx1j9m8pt3fhFDB?usp=sharing#scrollTo=06CMDrH7Kkdy) on which this notebook is based 

Note: this notebook is a direct adaptation of my [original Donut notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/CORD/Fine_tune_Donut_on_a_custom_dataset_(CORD)_with_PyTorch_Lightning.ipynb) for Idefics2. You can view Idefics2 as a more powerful Donut model.

## Load dataset

Let's start by loading the dataset from the hub. Here we use the [CORD](https://huggingface.co/datasets/naver-clova-ix/cord-v2) dataset, created by the [Donut](https://huggingface.co/docs/transformers/en/model_doc/donut) authors (Donut is another powerful, but slightly undertrained, document AI model available in the Transformers library). CORD is an important benchmark for receipt understanding. The Donut authors have prepared it in a format that suits vision-language models: we're going to fine-tune Idefics2 to generate the JSON given the image.

If you want to load your own custom dataset, check out this guide: https://huggingface.co/docs/datasets/image_dataset.

In [None]:
from datasets import load_dataset

dataset = load_dataset("naver-clova-ix/cord-v2")

Let's check out the dataset:

In [None]:
dataset

As is often the case, we get a `DatasetDict`: a dictionary containing 3 splits, one each for training, validation and testing. Each split has 2 features: an image and a corresponding ground truth.

Let's check the first training example:

In [None]:
example = dataset['train'][0]
image = example["image"]
# resize image for smaller displaying
width, height = image.size
image = image.resize((int(0.3*width), int(0.3*height)))
image

Let's check the corresponding ground truth:

In [None]:
example["ground_truth"]

Cool! This contains the ground truth parsing that we want the model to output given the image. We can parse it as JSON:

In [None]:
import json

json.loads(example["ground_truth"])

## Load processor

Next, we'll load the processor, which is used to prepare the data in the format that the model expects. Neural networks like Idefics2 don't directly take images and text as input, but rather `pixel_values` (a resized, rescaled, normalized and optionally split version of the receipt images), `input_ids` (text token indices in the vocabulary of the model), etc. This is handled by the processor.

### Image splitting

Idefics2's processor has a setting called `do_image_splitting`, which can be set to `True` or `False`. It defines how images are prepared for the model: either the processor creates just 1 image per input image (with `do_image_splitting=False`), or it creates several, by splitting each image into multiple sub-images and also including the original image (with `do_image_splitting=True`).

This has an effect on the amount of memory used during training: image splitting increases memory usage, as several images are created for each receipt image. We'll use the memory-friendly setting here. Do note that this affects performance: it is typically higher with `do_image_splitting=True`, which is also how the model we're going to use (https://huggingface.co/HuggingFaceM4/idefics2-8b) was trained.

See also from the [model card](https://huggingface.co/HuggingFaceM4/idefics2-8b):

> do_image_splitting=True is especially needed to boost performance on OCR tasks where a very large image is used as input. For the regular VQA or captioning tasks, this argument can be safely set to False with minimal impact on performance (see the evaluation table above).
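
As a rough back-of-the-envelope check, the sketch below counts the image tokens each receipt adds to the language model's input. It assumes (following the Idefics2 blog post) that splitting yields 4 sub-images plus the original image, and that the perceiver resampler maps every image to 64 tokens; the helper name is hypothetical:

```python
TOKENS_PER_IMAGE = 64  # assumption: the perceiver resampler outputs 64 tokens per image


def image_tokens_per_sample(do_image_splitting: bool) -> int:
    """Rough count of image tokens one receipt contributes to the LLM input."""
    # assumption: with splitting, 4 sub-images + the original image are encoded
    num_images = 5 if do_image_splitting else 1
    return num_images * TOKENS_PER_IMAGE


print(image_tokens_per_sample(False))  # 64
print(image_tokens_per_sample(True))   # 320
```

Keep in mind that the collator defined later truncates `input_ids` to 200 tokens, so switching on image splitting would also require raising that limit.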

### Image resolution

Additionally, one can decrease the maximum image resolution used during training to reduce memory usage. To do so, add `size={"longest_edge": 448, "shortest_edge": 378}` when initializing the processor (`AutoProcessor.from_pretrained`). In particular, the `longest_edge` value can be adapted as needed (the default value is 980). We recommend using values that are multiples of 14. No changes are required on the model side.
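
The multiples-of-14 advice follows from the vision encoder operating on 14×14-pixel patches (Idefics2's vision encoder is a SigLIP model with patch size 14). The sketch below, with hypothetical helper names and the simplifying assumption of a square image resized exactly to `longest_edge`, shows how the patch count, and hence vision-encoder compute and memory, scales:

```python
PATCH_SIZE = 14  # patch size of the vision encoder


def round_to_patch_multiple(edge: int, patch_size: int = PATCH_SIZE) -> int:
    """Round an edge length down to the nearest multiple of the patch size."""
    return max(patch_size, (edge // patch_size) * patch_size)


def num_patches(height: int, width: int, patch_size: int = PATCH_SIZE) -> int:
    """Number of patches for an image whose sides are multiples of the patch size."""
    return (height // patch_size) * (width // patch_size)


print(round_to_patch_multiple(450))  # 448
print(num_patches(980, 980))         # 4900 (default longest_edge)
print(num_patches(448, 448))         # 1024
```

Going from the default 980 to 448 cuts the number of patches by almost 5×. The language model itself still receives 64 tokens per image either way, since the perceiver resampler compresses the patches.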

In [None]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)

## Load model

Next, we're going to load the Idefics2 model from the [hub](https://huggingface.co/HuggingFaceM4/idefics2-8b). This is a model with 8 billion trainable parameters. Do note that we load a model here which already has undergone supervised fine-tuning (SFT). The pure pre-trained model (also called "base model") is available here: https://huggingface.co/HuggingFaceM4/idefics2-8b-base. We can benefit from the fine-tuning that the model already has undergone.

Do note that the Idefics2 team is also going to release a chatty version of Idefics2 optimized for chatbot/AI assistant use cases. However, in our case we do not care about the chatty aspect; we just want the model to generate perfect JSON given an image of a receipt.

### Full fine-tuning, LoRA and QLoRA

As this model has 8 billion trainable parameters, it's going to use quite a lot of memory. For reference, fine-tuning a model with the [AdamW optimizer](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW) (which is often used to optimize neural networks) in mixed precision requires roughly 18 bytes of GPU memory per parameter, not counting activations. So in this case, we would need about 18 × 8 billion bytes = 144 GB of GPU RAM if we want to update all the parameters of the model! That's huge, right? And infeasible for most people.
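
The 18 bytes per parameter break down as follows for mixed-precision training with AdamW (this follows the breakdown in the Transformers "Model training anatomy" guide; activations and temporary buffers come on top). A minimal sketch of the arithmetic, with a hypothetical helper name:

```python
# Bytes per parameter for mixed-precision training with AdamW
BYTES_PER_PARAM = {
    "weights (fp32 master copy + fp16 copy)": 4 + 2,
    "gradients (fp32)": 4,
    "AdamW states (2 x fp32)": 2 * 4,
}


def full_finetune_memory_gb(num_params: float) -> float:
    """Estimate GPU memory in GB (1 GB = 1e9 bytes), excluding activations."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


print(sum(BYTES_PER_PARAM.values()))  # 18
print(full_finetune_memory_gb(8e9))   # 144.0
```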

Luckily, some clever people came up with the [LoRA](https://huggingface.co/docs/peft/main/en/conceptual_guides/lora) method (short for low-rank adaptation). It freezes the existing weights and only trains a small number of extra adapter weights: low-rank update matrices added to selected layers of the base model. Hugging Face offers the separate [PEFT library](https://huggingface.co/docs/peft/main/en/index) for easy use of LoRA, along with other Parameter-Efficient Fine-Tuning methods (that's where the name PEFT comes from).
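
Concretely, for a frozen weight matrix W of shape (d_out, d_in), LoRA learns two small matrices B (d_out × r) and A (r × d_in) and computes W x + (lora_alpha / r) · B A x. The minimal NumPy sketch below is an illustration, not PEFT's implementation: it checks that, with B starting at zero as in the original LoRA recipe, the adapted layer initially behaves like the frozen one, and counts trainable parameters for a hypothetical 4096 × 4096 projection using the `r=8, lora_alpha=8` values from the config below:

```python
import numpy as np

r, lora_alpha = 8, 8       # same values as the LoraConfig below
scaling = lora_alpha / r   # PEFT's default LoRA scaling factor


def lora_forward(x, W, A, B):
    """Frozen path plus scaled low-rank update: W x + scaling * B (A x)."""
    return W @ x + scaling * (B @ (A @ x))


# small toy layer to check the forward pass
rng = np.random.default_rng(0)
d_out, d_in = 64, 32
W = rng.standard_normal((d_out, d_in))     # frozen pretrained weight
A = rng.standard_normal((r, d_in)) * 0.01  # trainable, random init
B = np.zeros((d_out, r))                   # trainable, zero init
x = rng.standard_normal(d_in)
# with B initialized to zero, the adapted layer starts out identical to the frozen one
assert np.allclose(lora_forward(x, W, A, B), W @ x)

# trainable parameters for a hypothetical 4096 x 4096 projection
full_params = 4096 * 4096        # 16,777,216
lora_params = r * (4096 + 4096)  # 65,536
print(f"trainable fraction: {lora_params / full_params:.2%}")  # 0.39%
```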

Moreover, one can not only freeze the existing base model but also quantize it (which means shrinking down its size). A neural network's parameters are typically saved in either float32 (32 bits or 4 bytes per parameter value) or float16 (16 bits or 2 bytes, also called half precision). However, with some clever algorithms one can shrink each parameter to just 8 or 4 bits (half a byte!) without a significant effect on final performance. Read all about it here: https://huggingface.co/blog/4bit-transformers-bitsandbytes.
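
For a sense of scale, here is the storage needed for just the weights of an 8-billion-parameter model at each precision. This is simple arithmetic with a hypothetical helper; real checkpoints deviate somewhat, since quantization stores extra scaling constants and some layers are usually kept in higher precision:

```python
def weight_storage_gb(num_params: float, bits_per_param: int) -> float:
    """Storage for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


for dtype, bits in {"float32": 32, "float16": 16, "int8": 8, "4-bit": 4}.items():
    print(f"{dtype}: {weight_storage_gb(8e9, bits):.0f} GB")
# float32: 32 GB, float16: 16 GB, int8: 8 GB, 4-bit: 4 GB
```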

This means that we're going to shrink the size of the base Idefics2-8b model considerably using 4-bit quantization, and then only train small adapter weights on top using LoRA (in float16). This idea of combining LoRA with quantization is called QLoRA, and it is the most memory-friendly of the three options.

Of course, if you have the memory available, feel free to use full fine-tuning or LoRA without quantization! In case of full fine-tuning, the code snippet below instantiates the model with Flash Attention, which considerably speeds up computations.

There exist many forms of quantization; here we leverage the [BitsAndBytes](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig) integration.

In [None]:
from transformers import BitsAndBytesConfig, Idefics2ForConditionalGeneration
from peft import LoraConfig
import torch


DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True

# Load model

# Three options for training, from the lowest precision training to the highest precision training:
# - QLoRA
# - Standard LoRA
# - Full fine-tuning
if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian",
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
        )
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.float16,
        quantization_config=bnb_config if USE_QLORA else None,
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
else:
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",  # requires the flash-attn package and an Ampere or newer GPU (e.g. A100, H100)
    ).to(DEVICE)

## Create PyTorch dataset

Next we'll create a regular [PyTorch dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). For that, one needs to implement 3 methods: an `__init__` method, a `__len__` method (which returns the length of the dataset) and a `__getitem__` method (which returns items of the dataset).

The `__init__` method implements 2 things:
* it goes over all the ground truth JSON sequences and turns them into token sequences (which we want the model to generate) using the `json2token` method
* it adds special tokens to the model/tokenizer using the `add_tokens` method, for which the model can learn an embedding vector. This way, keys which occur in the JSON sequences (like `<s_menu>` in our case) get their own token (and corresponding embedding), whereas otherwise they might be split up into multiple tokens. Do note that with the LoRA/QLoRA setup above, the embedding matrix is not among the adapted modules, so these new embeddings may stay at their initial values unless the embeddings are made trainable too. I also haven't quantified the performance difference this makes; I'm not sure it helps a lot, but the Donut authors did this, so I assume it benefits training.

Typically, one uses the processor in the `__getitem__` method to prepare the data in the format that the model expects, but we'll postpone that here for a reason we'll explain later. In our case, we're just going to return 2 things: the image and a corresponding ground truth token sequence.

In [None]:
import random
from typing import Any, List
from torch.utils.data import Dataset

added_tokens = []

class CustomDataset(Dataset):
    def __init__(
        self,
        hf_dataset,
        split,
        sort_json_key: bool = True,
    ):
        self.dataset = hf_dataset[split]
        self.split = split
        self.sort_json_key = sort_json_key

        ground_truth_token_sequences = []
        for sample in self.dataset:
            ground_truth = json.loads(sample["ground_truth"])
            if "gt_parses" in ground_truth:  # some datasets have multiple ground truths available, e.g. DocVQA
                assert isinstance(ground_truth["gt_parses"], list)
                ground_truth_jsons = ground_truth["gt_parses"]
            else:
                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
                ground_truth_jsons = [ground_truth["gt_parse"]]

            ground_truth_token_sequences.append(
                [
                    self.json2token(
                        ground_truth_json,
                        update_special_tokens_for_json_key=self.split == "train",
                        sort_json_key=self.sort_json_key,
                    )
                    for ground_truth_json in ground_truth_jsons  # load json from list of json
                ]
            )

        self.ground_truth_token_sequences = ground_truth_token_sequences

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        self.add_tokens([rf"<s_{k}>", rf"</s_{k}>"])
                    output += (
                        rf"<s_{k}>"
                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + rf"</s_{k}>"
                    )
                return output
        elif type(obj) == list:
            return r"<sep/>".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj

    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            model.resize_token_embeddings(len(processor.tokenizer))
            added_tokens.extend(list_of_tokens)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        # get the image and corresponding target token sequence
        image = example["image"]
        target_sequence = random.choice(self.ground_truth_token_sequences[idx])  # can be more than one, e.g., DocVQA

        return image, target_sequence
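
To see what the `json2token` conversion produces without touching the tokenizer or the model, here is a stripped-down, standalone re-implementation of the same logic (hypothetical name `json2token_standalone`; it skips the special-token bookkeeping and the categorical-token handling), applied to a made-up receipt:

```python
from typing import Any


def json2token_standalone(obj: Any, sort_json_key: bool = True) -> str:
    """Same conversion as CustomDataset.json2token, minus the tokenizer side effects."""
    if isinstance(obj, dict):
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]
        keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
        return "".join(
            f"<s_{k}>" + json2token_standalone(obj[k], sort_json_key) + f"</s_{k}>" for k in keys
        )
    if isinstance(obj, list):
        # list items are separated by <sep/>
        return "<sep/>".join(json2token_standalone(item, sort_json_key) for item in obj)
    return str(obj)


receipt = {
    "menu": [{"nm": "Latte", "price": "4.50"}, {"nm": "Bagel", "price": "2.00"}],
    "total": {"total_price": "6.50"},
}
print(json2token_standalone(receipt))
```

With `sort_json_key=True` the keys appear in reverse alphabetical order, so the output starts with `<s_total><s_total_price>6.50</s_total_price></s_total>` before the `<s_menu>` block, and within each menu item `<s_price>` comes before `<s_nm>`.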

Let's instantiate the PyTorch datasets:

In [None]:
train_dataset = CustomDataset(hf_dataset=dataset, split="train")
eval_dataset = CustomDataset(hf_dataset=dataset, split="validation")

We can check one item of the dataset:

In [None]:
train_dataset[0]

Let's also check how many tokens were added:

In [None]:
print(f"Added {len(added_tokens)} tokens to the model/tokenizer")
print(added_tokens)

You can verify that special tokens are now known by the tokenizer:

In [None]:
processor.tokenizer.tokenize("</s_total_price>")

## Define DataCollator

Now that we have a PyTorch dataset, we'll define a so-called collator, which defines how items of the dataset should be batched together. This is needed because we typically train neural networks on batches of data (i.e. several images/target sequences combined) rather than one by one, using a variant of stochastic gradient descent (SGD) such as Adam or AdamW.

It's only here that we use the processor to turn each (image, target token sequence) pair into the format that the model expects (`pixel_values`, `input_ids`, etc.). We do it here because this allows for **dynamic padding** of the batches: each batch contains images of various resolutions (as Idefics2 preserves the aspect ratio of images). By only using the processor here, we pad the pixel values up to the largest image in each batch (rather than padding them all to a fixed resolution upfront).
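
A simplified sketch of what dynamic padding of the pixel values amounts to: each image in the batch is zero-padded (here at the bottom/right) up to the largest height and width in that batch, and a boolean mask records which pixels are real. The helper name is hypothetical, and the Idefics2 image processor's exact behavior (padding position, mask dtype) may differ; it returns its mask as `pixel_attention_mask`.

```python
from typing import List, Tuple

import torch


def pad_images_to_batch_max(images: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad (C, H, W) images to the largest H and W in the batch; return values and mask."""
    channels = images[0].shape[0]
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    pixel_values = torch.zeros(len(images), channels, max_h, max_w)
    pixel_mask = torch.zeros(len(images), max_h, max_w, dtype=torch.bool)
    for i, img in enumerate(images):
        _, h, w = img.shape
        pixel_values[i, :, :h, :w] = img  # copy the real pixels into the top-left corner
        pixel_mask[i, :h, :w] = True      # mark them as valid
    return pixel_values, pixel_mask


pixel_values, pixel_mask = pad_images_to_batch_max([torch.ones(3, 4, 6), torch.ones(3, 5, 2)])
print(pixel_values.shape)  # torch.Size([2, 3, 5, 6])
```

The padded size depends only on the images in the current batch, which is why the padding has to happen in the collator rather than in the dataset.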

Importantly, we call `apply_chat_template`, which applies a so-called chat template that turns the inputs into the format the model expects. This is VERY important, as this is how the model expects inputs to be formatted! Read all about it here: https://huggingface.co/docs/transformers/main/en/chat_templating. We'll use the text prompt "Extract JSON.", which is also going to be used at inference time.

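We also limit the length of the text tokens (`input_ids`) to 200 due to memory constraints. Note that this budget covers the whole chat-formatted sequence, including the prompt and the image tokens, not just the target JSON. Feel free to increase it if your target token sequences are longer (I'd recommend plotting the distribution of token lengths to determine a good value).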
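
A minimal sketch for choosing `max_length`: given the token count of each training sequence, report a few summary statistics and the fraction of sequences a given limit would truncate. The helper is hypothetical and the lengths below are made up; in practice you would measure them on the same text the collator feeds to the processor.

```python
import numpy as np


def truncation_report(token_lengths, max_length: int = 200) -> dict:
    """Summarize sequence lengths and how many would be cut off at max_length."""
    lengths = np.asarray(token_lengths)
    return {
        "mean": float(lengths.mean()),
        "p95": float(np.percentile(lengths, 95)),
        "max": int(lengths.max()),
        "frac_truncated": float((lengths > max_length).mean()),
    }


report = truncation_report([50, 100, 150, 200, 250], max_length=200)
print(report)
```

For these made-up lengths, one of the five sequences (20%) exceeds the 200-token limit and would be truncated.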

Labels are created for the model by simply copying the inputs to the LLM (`input_ids`), but with padding tokens replaced by the ignore index of the loss function. This ensures that the model doesn't need to learn to predict padding tokens (used to batch examples together). For Idefics2, this is the `image_token_id` as can be seen [here](https://github.com/huggingface/transformers/blob/6f465d45d98f9eaeef83cfdfe79aecc7193b0f1f/src/transformers/models/idefics2/modeling_idefics2.py#L1860).

Why are the labels a copy of the model inputs, you may ask? The model internally shifts the labels by one position, so that the prediction at each position is compared with the next token; this is how the model learns to predict the next token. This can be seen [here](https://github.com/huggingface/transformers/blob/6f465d45d98f9eaeef83cfdfe79aecc7193b0f1f/src/transformers/models/idefics2/modeling_idefics2.py#L1851-L1855).
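
The label construction and the one-position shift can be illustrated in isolation with a minimal PyTorch sketch. The token ids are made up and `IGNORE_ID` merely stands in for Idefics2's `image_token_id`; the shifting mirrors the standard causal language modeling loss rather than being copied from the Idefics2 source:

```python
import torch
import torch.nn.functional as F

PAD_ID = 0      # hypothetical padding token id
IGNORE_ID = 99  # hypothetical stand-in for the loss's ignore index (Idefics2: image_token_id)
VOCAB_SIZE = 100


def make_labels(input_ids: torch.Tensor) -> torch.Tensor:
    """Copy the inputs and replace padding with the ignore index."""
    labels = input_ids.clone()
    labels[labels == PAD_ID] = IGNORE_ID
    return labels


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Logits at position t are scored against the label at position t + 1."""
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=IGNORE_ID)


torch.manual_seed(0)
input_ids = torch.tensor([[5, 7, 9, PAD_ID, PAD_ID]])
labels = make_labels(input_ids)  # -> [[5, 7, 9, 99, 99]]
logits = torch.randn(1, 5, VOCAB_SIZE)

# only the real next tokens (7 and 9) contribute; padded positions are ignored
expected = F.cross_entropy(logits[0, :2], torch.tensor([7, 9]))
assert torch.allclose(causal_lm_loss(logits, labels), expected)
```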

In [None]:
class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            image, ground_truth = example
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Extract JSON."},
                        {"type": "image"},
                    ],
                },
                {"role": "assistant", "content": [{"type": "text", "text": ground_truth}]},
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([image])

        # pixel values are padded dynamically to the largest image in the batch,
        # while the text is padded/truncated to a fixed length of 200 tokens
        batch = self.processor(
            text=texts,
            images=images,
            padding="max_length",
            max_length=200,
            truncation=True,
            return_tensors="pt",
        )

        # labels = input_ids, with padding replaced by the loss's ignore index (the image token id for Idefics2)
        labels = batch["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels

        return batch


data_collator = MyDataCollator(processor)

## Define evaluation metrics

During training, we'd like to compute some metrics on our evaluation set. For that, we'll implement a `compute_metrics` function, which takes the predicted and target token sequences and computes the average normalized Levenshtein similarity (ANLS). It is based on the Levenshtein edit distance, which quantifies how many character edits are needed to turn the predicted sequence into the target sequence (the fewer edits, the better). Each prediction scores 1 minus its normalized edit distance (or 0 if that distance is 0.5 or more), and the scores are averaged, so the optimal value is 1 (no edits need to be made).

In [None]:
import Levenshtein
import numpy as np


def normalized_levenshtein(s1, s2):
    len_s1, len_s2 = len(s1), len(s2)
    distance = Levenshtein.distance(s1, s2)
    return distance / max(len_s1, len_s2)


def similarity_score(a_ij, o_q_i, tau=0.5):
    nl = normalized_levenshtein(a_ij, o_q_i)
    return 1 - nl if nl < tau else 0


def average_normalized_levenshtein_similarity(ground_truth, predicted_answers):
    assert len(ground_truth) == len(predicted_answers), "Length of ground_truth and predicted_answers must match."

    N = len(ground_truth)
    total_score = 0

    for i in range(N):
        a_i = ground_truth[i]
        o_q_i = predicted_answers[i]
        if o_q_i == "":
            print("Warning: Skipped an empty prediction.")
            max_score = 0
        else:
            max_score = max(similarity_score(a_ij, o_q_i) for a_ij in a_i)

        total_score += max_score

    return total_score / N


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Replace -100s used for padding as we can't decode them
    preds = np.where(preds != -100, preds, processor.tokenizer.pad_token_id)
    decoded_preds = processor.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)

    print("Decoded predictions:", decoded_preds)
    print("Decoded labels:", decoded_labels)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    score = average_normalized_levenshtein_similarity(decoded_labels, decoded_preds)
    result = {"levenshtein": score}

    prediction_lens = [np.count_nonzero(pred != processor.tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result
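
To sanity-check the metric by hand, here is a dependency-free sketch of the per-example score, with a textbook dynamic-programming edit distance swapped in for the `Levenshtein` package (function names are hypothetical, and the guard for two empty strings is an addition):

```python
def levenshtein(s1: str, s2: str) -> int:
    """Minimum number of single-character insertions, deletions and substitutions."""
    previous = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1, start=1):
        current = [i]
        for j, c2 in enumerate(s2, start=1):
            current.append(min(
                previous[j] + 1,               # deletion
                current[j - 1] + 1,            # insertion
                previous[j - 1] + (c1 != c2),  # substitution (free if characters match)
            ))
        previous = current
    return previous[-1]


def anls_single(target: str, prediction: str, tau: float = 0.5) -> float:
    """1 - normalized edit distance, or 0 if the distance is at least tau."""
    if not target and not prediction:
        return 1.0
    nl = levenshtein(target, prediction) / max(len(target), len(prediction))
    return 1 - nl if nl < tau else 0.0


print(levenshtein("kitten", "sitting"))            # 3
print(round(anls_single("kitten", "sitting"), 4))  # 0.5714 (= 1 - 3/7)
print(anls_single("<s_nm>Latte</s_nm>", "xyz"))    # 0.0 (too different)
```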

## Define training arguments

There are various options to train a PyTorch model: one could write the training loop oneself, use the [Trainer API](https://huggingface.co/docs/transformers/en/main_classes/trainer), use frameworks like PyTorch Lightning, etc.

In this notebook, we'll use the `Seq2SeqTrainer` class, which is optimized for seq2seq models like T5 or BART (but also works well in our case, as we'll see). This class requires us to define `Seq2SeqTrainingArguments`, which define all training hyperparameters. There are many more than the ones we set here; most importantly, we pass the batch size for training/evaluation, the number of warmup steps, the learning rate, how frequently we want to save the model, whether we want to use Weights and Biases logging, etc.

Do note that I have not performed any hyperparameter optimization whatsoever so this could definitely be improved.

See the full list here: https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments.

In [None]:
from transformers import Seq2SeqTrainingArguments, GenerationConfig

generation_config = GenerationConfig.from_pretrained("HuggingFaceM4/idefics2-8b", max_new_tokens=200)

training_args = Seq2SeqTrainingArguments(
    num_train_epochs=2,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    output_dir="idefics2_ft_tutorial",
    eval_strategy="epoch",
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    fp16=True,
    # push_to_hub_model_id="idefics2-8b-docvqa-finetuned-tutorial",
    remove_unused_columns=False,
    report_to="none",
    predict_with_generate=True,
    generation_config=generation_config,
)
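
With these settings, the effective batch size on a single GPU is `per_device_train_batch_size × gradient_accumulation_steps` = 16. A quick sketch of the resulting number of optimizer steps helps check whether values like `warmup_steps=50` and `save_steps=250` are sensible. The helper is hypothetical and only approximates how the Trainer counts update steps, and the 800-example training set size is an assumption you should replace with `len(train_dataset)`:

```python
import math


def num_optimizer_steps(num_examples: int, per_device_batch_size: int,
                        grad_accum_steps: int, num_epochs: int, num_devices: int = 1) -> int:
    """Approximate total optimizer (update) steps for a Trainer run."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * num_epochs


total_steps = num_optimizer_steps(num_examples=800, per_device_batch_size=2,
                                  grad_accum_steps=8, num_epochs=2)
print(total_steps)  # 100
```

Under that assumption, warmup would cover half of training and `save_steps=250` would never be reached during the run, so you may want to adjust these values.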

## Train

We use the `Seq2SeqTrainer` class because it supports a `predict_with_generate` option, which allows us to use the [generate](https://huggingface.co/docs/transformers/v4.40.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) method (typically used at inference time) for evaluation.

We override the `prediction_step` method, as we need to apply the chat template in order to prompt the model correctly (see also the inference section of this notebook).

In [None]:
from typing import Optional, Dict, Union, Tuple
from torch import nn
from transformers import Seq2SeqTrainer
import requests
from PIL import Image  # needed for the dummy image in prediction_step

# important: we need to disable caching during training
# otherwise the model generates past_key_values which is of type DynamicCache
model.config.use_cache = False

class Idefics2Trainer(Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        **gen_kwargs,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            gen_kwargs:
                Additional `generate` specific kwargs.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # Priority (handled in generate):
        # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
        if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
            gen_kwargs = self._gen_kwargs.copy()
        if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
            gen_kwargs.pop("num_beams")
        if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
            gen_kwargs.pop("max_length")

        default_synced_gpus = False
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        generation_inputs = inputs.copy()
        # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
        # (otherwise, it would continue generating from the padded `decoder_input_ids`)
        if (
            "labels" in generation_inputs
            and "decoder_input_ids" in generation_inputs
            and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
        ):
            generation_inputs = {
                k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
            }

        # here we need to overwrite the input_ids to only include the prompt
        # (we reuse the global `processor`, created with do_image_splitting=False,
        # instead of reloading it from the hub at every evaluation step)

        # use a blank dummy image: only the prompt's input_ids are needed here, and
        # each image is always turned into 64 image tokens, whatever its content
        test_image = Image.new("RGB", (378, 378), color="white")

        # prepare prompt for the model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Extract JSON."},
                    {"type": "image"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        processor_inputs = processor(text=prompt, images=[test_image], return_tensors="pt")
        custom_inputs = {}
        batch_size = generation_inputs["pixel_values"].shape[0]
        device = generation_inputs["pixel_values"].device
        custom_inputs["input_ids"] = processor_inputs.input_ids.repeat(batch_size, 1).to(
            device
        )  # repeat along batch dimension
        custom_inputs["attention_mask"] = processor_inputs.attention_mask.repeat(batch_size, 1).to(
            device
        )  # repeat along batch dimension
        custom_inputs["pixel_values"] = generation_inputs["pixel_values"]
        custom_inputs["pixel_attention_mask"] = generation_inputs["pixel_attention_mask"]

        generated_tokens = self.model.generate(**custom_inputs, **gen_kwargs)

        # Strip the prompt from the generated_tokens
        generated_tokens = generated_tokens[:, custom_inputs["input_ids"].size(1) :]

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
        if self.model.generation_config._from_model_config:
            self.model.generation_config._from_model_config = False

        # Retrieves GenerationConfig from model.generation_config
        gen_config = self.model.generation_config
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_config.max_length:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
        elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return loss, None, None

        if has_labels:
            labels = inputs["labels"]
            if labels.shape[-1] < gen_config.max_length:
                labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
            elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
                labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
        else:
            labels = None

        return loss, generated_tokens, labels

    def _pad_tensors_to_max_len(self, tensor, max_length):
        pad_token_id = processor.tokenizer.pad_token_id

        padded_tensor = pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor


trainer = Idefics2Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

Next it's finally time to train!

In [None]:
trainer.train()

## Inference

Let's see if the model has learned something. We'll take a receipt image from the test set here.

In [None]:
test_example = dataset["test"][0]
test_image = test_example["image"]
test_image

Next we need to prepare the image for the model, along with the text prompt we used during training. We need to apply the chat template to make sure the format is respected.

In [None]:
# prepare image and prompt for the model
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Extract JSON."},
            {"type": "image"},
        ]
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
print(prompt)

Next we let the model autoregressively generate tokens using the `generate()` method, which is the recommended way to run inference. This method feeds each predicted token back into the model as conditioning for the next time step.

In [None]:
inputs = processor(text=prompt, images=[test_image], return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

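# Generate, then strip the prompt so that only the newly generated tokens are decoded
generated_ids = model.generate(**inputs, max_new_tokens=500)
generated_ids = generated_ids[:, inputs["input_ids"].shape[1]:]
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)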

print(generated_texts)
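
Conceptually, greedy decoding boils down to the loop below: predict the most likely next token, append it to the sequence, and repeat until an end-of-sequence token or the token budget is reached. This toy sketch uses a hypothetical `next_token_logits` function in place of the model and ignores everything `generate()` adds on top (key/value caching, sampling, beam search, other stopping criteria):

```python
from typing import Callable, List


def greedy_decode(next_token_logits: Callable[[List[int]], List[float]],
                  prompt_ids: List[int], eos_id: int, max_new_tokens: int) -> List[int]:
    """Repeatedly append the arg-max token, feeding the growing sequence back in."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids[len(prompt_ids):]  # only the newly generated tokens


# toy "model" over a 4-token vocabulary: always prefers (last token + 1), with 3 as EOS
def toy_model(ids: List[int]) -> List[float]:
    logits = [0.0] * 4
    logits[min(ids[-1] + 1, 3)] = 1.0
    return logits


print(greedy_decode(toy_model, prompt_ids=[0], eos_id=3, max_new_tokens=10))  # [1, 2, 3]
```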

Adapted from Donut's `token2json` method, we can write a function which converts the generated token sequence back into a (JSON-serializable) Python dictionary.

In [None]:
import re

# let's turn that into JSON
def token2json(tokens, is_inner_value=False, added_vocab=None):
    """
    Convert a (generated) token sequence into an ordered JSON format.
    """
    if added_vocab is None:
        added_vocab = processor.tokenizer.get_added_vocab()

    output = {}

    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        key_escaped = re.escape(key)

        end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(
                f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
            )
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True, added_vocab=added_vocab)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]

            tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)

    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}

Let's print the final JSON!

In [None]:
generated_json = token2json(generated_texts[0])
print(generated_json)

In [None]:
for key, value in generated_json.items():
    print(key, value)