# Info

This notebook allows us to export some TrOCR models to the ONNX format.

NOTE: Much of this code comes from a notebook someone else wrote for converting MBart models to ONNX, which itself follows fastT5, a library for converting T5 models to ONNX.
See: https://github.com/Ki6an/fastT5/issues/7

NOTE: This probably does not work for training (some logic related to loss calculation was removed), so it should only be used for inference.
Additional NOTE: The model seems to perform poorly when optimized (not to be confused with quantized). It is not clear why.

NOTE: Quantization seems to noticeably degrade model quality, so it may not be worth it. This has only been tested on single images.

I probably only understand ~50% of this code. Welcome to hell.

Some possible things to fix in the future:
- Give proper names to some of the dynamic axes in `generate_onnx_representation`. Some of them had to be given differing names; otherwise, running inference fails with shape-mismatch errors.
- Add the ability to return the loss from the ONNX models, for use in training.

Tested on the `microsoft/trocr-small-printed` model.

In [None]:
import torch

from transformers import (
    AutoTokenizer,
    AutoConfig,
)
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
import functools
import operator
from pathlib import Path

In [None]:

# Which stages of this notebook to run.
run_export = True  # Export the Hugging Face model to the three ONNX files.
run_quantize = False  # Quantize the exported files; inference then loads the quantized copies.
run_inference = True  # Compare ONNX and PyTorch outputs on `image_path`.

huggingface_model = 'microsoft/trocr-small-printed' # The Huggingface model name/path to load and export.

# The path to export all of the files to.
output_folder = './'

# The path of a sample image to ensure that the ONNX model is giving the proper output.
image_path = './sample.jpg'

In [None]:
def _model_paths(suffix='', folder=None):
    """Return the (encoder, decoder, initial decoder) ONNX file paths inside ``folder``.

    ``folder`` defaults to the module-level ``output_folder``. It is resolved at call
    time, so later changes to that variable are honoured. ``suffix`` is appended to
    each file stem, e.g. ``'_q'`` for the quantized copies.
    """
    base = Path(output_folder if folder is None else folder)
    # Joining with pathlib avoids doubled separators such as './/encoder.onnx'.
    return tuple(
        (base / f'{stem}{suffix}.onnx').as_posix()
        for stem in ('encoder', 'decoder', 'decoder_init')
    )


def get_exported_paths(folder=None):
    """Return the paths of the exported (unquantized) encoder, decoder and initial decoder."""
    return _model_paths('', folder)


def get_quantized_paths(folder=None):
    """Return the paths of the quantized encoder, decoder and initial decoder."""
    return _model_paths('_q', folder)

# Export to ONNX Utils

## Submodels

In [None]:
class VisionEncoder(torch.nn.Module):
    """Wrap an encoder so that its forward pass yields only the last hidden state.

    The wrapped encoder returns a BaseModelOutput-like object. Index 0 is the last
    hidden state, which is what TrOCR needs. Index 1 is the pooler state, which
    TrOCR does not use.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, *args, **kwargs):
        # Integer indexing (rather than unpacking) is used because iterating a
        # ModelOutput yields its keys, not its values.
        return self.encoder(*args, **kwargs)[0]

In [None]:
class VisionDecoderWithLMhead(torch.nn.Module):
    """Decoder step that consumes cached past key/values (every step after the first).

    ONNX export works on flat tensor lists. ``forward`` therefore receives
    ``input_ids, attention_mask, encoder_hidden_states`` followed by the flattened
    past key/values, four tensors per layer. It regroups them before calling the
    wrapped decoder. ``attention_mask`` is forwarded as the decoder's
    ``encoder_attention_mask``.
    """

    def __init__(self, decoder, config):
        # Unlike the MBart notebook this was adapted from, no separate LM logit bias
        # is added: the TrOCR decoder already includes its LM head (per the author).
        super().__init__()
        self.decoder = decoder
        self.config = config

    def forward(self, *inputs):
        input_ids, attention_mask, encoder_hidden_states = inputs[:3]

        # Regroup the flat cache into one (self keys, self values, cross keys,
        # cross values) tuple per decoder layer, as the decoder expects.
        flat_cache = inputs[3:]
        layer_starts = range(0, len(flat_cache), 4)
        grouped_cache = tuple(flat_cache[start:start + 4] for start in layer_starts)

        step_output = self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            past_key_values=grouped_cache,
            # Passed explicitly: the author found some TrOCR checkpoints need it to
            # return past key/values while others enable it already.
            use_cache=True,
        )

        logits = step_output[0]
        new_cache = step_output[1]
        return logits, new_cache

In [None]:
class VisionDecoderWithLMheadInitial(torch.nn.Module):
    """First decoder step: runs without a cache and returns logits plus the new past key/values.

    ``attention_mask`` is forwarded as the decoder's ``encoder_attention_mask``.
    """

    def __init__(self, decoder, config):
        super().__init__()
        self.decoder = decoder
        self.config = config

    def forward(self, input_ids, attention_mask, encoder_hidden_states):
        step_output = self.decoder(
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            input_ids=input_ids,
            use_cache=True,
        )
        first, second = step_output[0], step_output[1]
        return first, second

## Setup Utils

In [None]:
def turn_model_into_encoder_decoder(model):
    """Split ``model`` into the three sub-modules that are exported separately.

    Returns ``(encoder, decoder_with_past, initial_decoder)``. The two decoder
    wrappers share the same underlying decoder and ``model.config``. No output
    embedding module is extracted because the decoder already contains the LM head.
    """
    enc = model.get_encoder()
    dec = model.get_decoder()
    cfg = model.config

    return (
        VisionEncoder(enc),
        VisionDecoderWithLMhead(dec, cfg),
        VisionDecoderWithLMheadInitial(dec, cfg),
    )

In [None]:
import torch

def generate_onnx_representation(model, tokenizer, feature_extractor):
    """Export ``model`` as three ONNX graphs at the paths from ``get_exported_paths()``.

    * encoder: ``pixel_values`` -> ``hidden_states``
    * decoder: one cached generation step; takes and returns the flattened past
      key/values (4 tensors per decoder layer)
    * initial decoder: the first generation step; takes no cache and returns the
      initial past key/values

    Args:
        model: the Hugging Face vision encoder-decoder model to export.
        tokenizer: the model's tokenizer; only used to build dummy decoder inputs.
        feature_extractor: the model's image processor; only used to build a dummy
            encoder input.

    Returns:
        ``(encoder_path, decoder_path, init_decoder_path)``.
    """
    (
        simplified_encoder,
        decoder_with_lm_head,
        decoder_with_lm_head_init,
    ) = turn_model_into_encoder_decoder(model)

    model_config = model.config

    encoder_path, decoder_path, init_decoder_path = get_exported_paths()
    # Make sure the destination folder(s) exist before exporting into them.
    for path in (encoder_path, decoder_path, init_decoder_path):
        Path(path).parent.mkdir(parents=True, exist_ok=True)

    # Dummy inputs used for tracing. The axes listed in the dynamic_axes maps
    # below are left variable in the exported graphs.
    sample_encoder_input = torch.randn((3, 224, 224))  # C * H * W
    pixel_values = feature_extractor(sample_encoder_input, return_tensors='pt').pixel_values

    sample_decoder_input = "The universe is a dark forest."
    model_decoder_inputs = tokenizer(sample_decoder_input, return_tensors="pt")
    input_ids = model_decoder_inputs["input_ids"]

    batch_size = 5

    # TrOCR stores most of the important properties in the encoder and decoder configs.
    decoder_attention_heads = model_config.decoder.num_attention_heads
    d_model = model_config.decoder.hidden_size
    d_enc_model = model_config.encoder.hidden_size
    num_decoder_layers = model_config.decoder.num_hidden_layers

    n_heads = decoder_attention_heads
    # input_ids comes from tokenizing one sentence, so it is (1, num_tokens).
    # seq_length_a (the 1) is reused as the dummy self-attention cache length.
    # seq_length_b is the dummy encoder sequence length (it sizes enc_out and
    # the cross-attention cache).
    seq_length_a, seq_length_b = input_ids.shape

    d_kv = d_model // decoder_attention_heads

    input_ids_dec = torch.ones((batch_size, 1), dtype=torch.int64)
    attention_mask_dec = torch.ones((batch_size, seq_length_b), dtype=torch.int64)
    enc_out = torch.ones((batch_size, seq_length_b, d_enc_model), dtype=torch.float32)
    sa = torch.ones((batch_size, n_heads, seq_length_a, d_kv), dtype=torch.float32)
    ca = torch.ones((batch_size, n_heads, seq_length_b, d_kv), dtype=torch.float32)
    # (self attention keys, self attention values, cross attention keys, cross attention values)
    attention_block = (sa, sa, ca, ca)
    past_key_values = (attention_block,) * num_decoder_layers
    flat_past_key_values = [tensor for block in past_key_values for tensor in block]

    decoder_all_inputs = tuple(
        [input_ids_dec, attention_mask_dec, enc_out] + flat_past_key_values
    )
    num_of_inputs = 4 * num_decoder_layers

    # Exports to ONNX
    with torch.no_grad():
        # Non-cache inputs shared by the decoder and the initial decoder.
        base_input_names = [
            "input_ids",
            "encoder_attention_mask",
            "encoder_hidden_states",
        ]
        # Flattened cache names: l{layer}_{self|cross}_{keys|vals}.
        pkv_input_names = [
            f"l{i//4}_{'self' if i % 4 < 2 else 'cross'}_{'keys' if not i % 2 else 'vals'}"
            for i in range(num_of_inputs)
        ]
        decoder_input_names = base_input_names + pkv_input_names

        # Self-attention outputs get a distinct "_out" suffix; cross-attention
        # outputs reuse the input names.
        pkv_output_names = [
            f"l{i//4}_{'self_out' if i % 4 < 2 else 'cross'}_{'keys' if not i % 2 else 'vals'}"
            for i in range(num_of_inputs)
        ]
        decoder_output_names = ["logits"] + pkv_output_names

        # NOTE: several axes share the "seq_length" name; per the notebook's notes,
        # renaming them caused shape-mismatch errors at inference, so they are kept.
        dyn_axis = {
            "input_ids": {0: "batch", 1: "input_ids_dim1"},
            "encoder_attention_mask": {0: "batch", 1: "encoder_attention_mask_dim1"},
            "encoder_hidden_states": {0: "batch", 1: "seq_length"},
            "logits": {0: "batch", 1: "seq_length"},
        }
        dyn_pkv = {
            name: {0: "batch", 2: "seq_length"} for name in pkv_input_names + pkv_output_names
        }
        dyn_axis_params = {**dyn_axis, **dyn_pkv}

        # encoder
        torch.onnx.export(
            simplified_encoder,
            args=(pixel_values,),  # a real 1-tuple; `(pixel_values)` is just a tensor
            f=Path(encoder_path).as_posix(),
            export_params=True,
            opset_version=12,
            do_constant_folding=True,
            input_names=["pixel_values"],
            output_names=["hidden_states"],
            dynamic_axes={
                "pixel_values": {0: "batch", 1: "pixel_values_dim1"},
                "hidden_states": {0: "batch", 1: "seq_length"},
            },
        )

        # decoder
        torch.onnx.export(
            decoder_with_lm_head,
            decoder_all_inputs,
            Path(decoder_path).as_posix(),
            export_params=True,
            do_constant_folding=False,
            opset_version=12,
            input_names=decoder_input_names,
            output_names=decoder_output_names,
            dynamic_axes=dyn_axis_params,
        )

        # initial decoder to produce past key values
        torch.onnx.export(
            decoder_with_lm_head_init,
            (input_ids_dec, attention_mask_dec, enc_out),
            Path(init_decoder_path).as_posix(),
            export_params=True,
            do_constant_folding=False,
            opset_version=12,
            input_names=list(base_input_names),
            output_names=decoder_output_names,
            dynamic_axes={
                **dyn_axis,
                **{name: {0: "batch", 2: "seq_length"} for name in pkv_output_names},
            },
        )

    return encoder_path, decoder_path, init_decoder_path

## Run Utils

In [None]:
# Load the image processor, tokenizer and PyTorch model, then write the three
# ONNX files to the paths returned by get_exported_paths().
if run_export:
    feature_extractor = ViTFeatureExtractor.from_pretrained(huggingface_model)
    tokenizer = AutoTokenizer.from_pretrained(huggingface_model)
    model = VisionEncoderDecoderModel.from_pretrained(huggingface_model)

    onnx_model_paths = generate_onnx_representation(model, tokenizer, feature_extractor)

# Quantize

### Setup Utils

In [None]:
from onnxruntime.quantization import quantize_dynamic, QuantType

def quantize(onnx_model_path, output_path):
    """Dynamically quantize the ONNX model at ``onnx_model_path`` and save it to ``output_path``.

    Uses onnxruntime's ``quantize_dynamic`` with per-channel, reduced-range
    unsigned 8-bit (``QUInt8``) weights.
    """
    settings = dict(
        per_channel=True,
        weight_type=QuantType.QUInt8,
        reduce_range=True,
    )
    quantize_dynamic(model_input=onnx_model_path, model_output=output_path, **settings)
    print('Quantized a model.')

## Run Utils

In [None]:
# Quantize each exported model into its counterpart from get_quantized_paths().
if run_quantize:
    enc, dec, dec_init = get_exported_paths()
    enc_out, dec_out, dec_init_out = get_quantized_paths()

    quantize(enc, enc_out)
    quantize(dec, dec_out)
    quantize(dec_init, dec_init_out)

# Inference

## Architecture

In [None]:
import torch
from transformers.modeling_outputs import (
    Seq2SeqLMOutput,
    BaseModelOutput,
)

class OnnxVisionEncoder(torch.nn.Module):
    """Expose an ONNX Runtime encoder session through the encoder interface ``generate`` uses.

    Only ``pixel_values`` is used. The remaining arguments are accepted for API
    compatibility and ignored.
    """

    def __init__(self, encoder_sess):
        super().__init__()
        self.encoder = encoder_sess
        # Name of the input tensor the generation utilities pass to the encoder.
        self.main_input_name = 'pixel_values'

    def forward(
        self,
        pixel_values,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        feed = {"pixel_values": pixel_values.cpu().numpy()}
        session_outputs = self.encoder.run(None, feed)
        # The first session output is the encoder's last hidden state.
        hidden = torch.from_numpy(session_outputs[0])
        return BaseModelOutput(hidden)

In [None]:
class OnnxVisionDecoder(torch.nn.Module):
    """Run the exported decoder-with-past ONNX session for one cached decoding step.

    Session inputs are resolved at call time from ``get_inputs()``. Inputs named
    ``input_ids``, ``encoder_attention_mask`` or ``encoder_hidden_states`` are fed
    by name. Every other input is fed, in order, from the flattened
    ``past_key_values``. The exported graph may not expose every name given at
    export time; unused inputs can be pruned, presumably ``encoder_hidden_states``
    once the cross-attention cache is supplied (verify via ``get_inputs()``).
    Matching by name keeps the feed aligned either way.
    """

    def __init__(self, decoder_sess):
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, attention_mask, encoder_hidden_states, past_key_values, output_attentions=None, output_hidden_states=None, return_dict=None, encoder_attention_mask=None, **kwargs):
        """Run one decoding step.

        Args:
            input_ids: decoder token ids for this step.
            attention_mask: decoder attention mask; ignored here.
            encoder_hidden_states: encoder output. It is only fed if the session has
                an input with that name. It also sizes the default encoder mask.
            past_key_values: per-layer tuples of 4 cached tensors.
            encoder_attention_mask: mask over encoder positions. When None, an
                all-ones int64 mask of shape ``encoder_hidden_states.shape[:2]``
                is used.

        Returns:
            ``(logits, past_key_values)``, with the returned cache regrouped into
            per-layer 4-tuples.

        Raises:
            ValueError: if the session expects more cache tensors than were given.
        """
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(
                encoder_hidden_states.shape[:2], dtype=torch.int64
            )

        named_tensors = {
            "input_ids": input_ids,
            "encoder_attention_mask": encoder_attention_mask,
            "encoder_hidden_states": encoder_hidden_states,
        }
        flat_cache = (tensor for layer in past_key_values for tensor in layer)

        feed = {}
        for node in self.decoder.get_inputs():
            if node.name in named_tensors:
                tensor = named_tensors[node.name]
            else:
                tensor = next(flat_cache, None)
                if tensor is None:
                    raise ValueError(
                        f"Not enough past key/value tensors to feed ONNX input '{node.name}'."
                    )
            feed[node.name] = tensor.detach().cpu().numpy()

        logits, *flat_outputs = self.decoder.run(None, feed)

        cache = tuple(torch.from_numpy(array) for array in flat_outputs)
        out_past_key_values = tuple(
            cache[i : i + 4] for i in range(0, len(cache), 4)
        )

        return torch.from_numpy(logits), out_past_key_values

In [None]:
class OnnxVisionDecoderInit(torch.nn.Module):
    """Run the exported initial-decoder ONNX session (first step, no cache)."""

    def __init__(self, decoder_sess):
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, encoder_hidden_states, attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs):
        """Run the first decoding step.

        Args:
            input_ids: decoder token ids.
            encoder_hidden_states: encoder output.
            attention_mask: fed to the graph's ``encoder_attention_mask`` input, so
                it should cover the encoder positions. When None, an all-ones int64
                mask of shape ``encoder_hidden_states.shape[:2]`` is used.

        Returns:
            ``(logits, past_key_values)``, with the cache regrouped into per-layer
            4-tuples.
        """
        if attention_mask is None:
            attention_mask = torch.ones(encoder_hidden_states.shape[:2], dtype=torch.int64)

        logits, *flat_outputs = self.decoder.run(
            None,
            {
                'input_ids': input_ids.detach().cpu().numpy(),
                'encoder_attention_mask': attention_mask.detach().cpu().numpy(),
                "encoder_hidden_states": encoder_hidden_states.detach().cpu().numpy(),
            },
        )

        cache = tuple(torch.from_numpy(array) for array in flat_outputs)
        out_past_key_values = tuple(
            cache[i : i + 4] for i in range(0, len(cache), 4)
        )

        return torch.from_numpy(logits), out_past_key_values

In [None]:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """Shift ``input_ids`` one position to the right to build decoder inputs from labels.

    The first column becomes ``decoder_start_token_id`` and the last column is
    dropped. Any ``-100`` values (ignored label positions) are replaced by
    ``pad_token_id``.

    Raises:
        ValueError: if ``pad_token_id`` or ``decoder_start_token_id`` is None.
    """
    # Validate before doing any tensor work.
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")

    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

In [None]:
from transformers.generation_utils import GenerationMixin

class OnnxVision(GenerationMixin):
    """ONNX Runtime-backed stand-in for a vision encoder-decoder model, usable with ``generate``.

    Wraps three sessions: the encoder, the initial decoder (first step, no cache)
    and the decoder that consumes cached past key/values. It supports inference
    only; labels and loss are not handled.
    """

    def __init__(self, config, encoder_sess, decoder_sess, decoder_sess_init):
        self.config = config

        self.encoder = OnnxVisionEncoder(encoder_sess)
        self.decoder = OnnxVisionDecoder(decoder_sess)
        self.decoder_init = OnnxVisionDecoderInit(decoder_sess_init)

        # Name of the model input that generation passes to the encoder.
        self.main_input_name = 'pixel_values'

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Build decoder input ids from ``labels`` by shifting them one position right."""
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Assemble the keyword arguments ``generate`` passes to ``forward`` at each step.

        Once a cache (``past``) exists, only the last decoder token is fed. The
        decoder attention mask is rebuilt as all ones over the full (untrimmed)
        decoder sequence.
        """

        def _prepare_inputs_for_generation(input_ids, past=None, attention_mask=None, **model_kwargs):
            input_shape = input_ids.shape
            # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
            if attention_mask is None:
                attention_mask = input_ids.new_ones(input_shape)

            # cut decoder_input_ids if past is used
            if past is not None:
                input_ids = input_ids[:, -1:]

            return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

        decoder_inputs = _prepare_inputs_for_generation(input_ids, past=past)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder the self-attention cache along the batch axis for beam search."""
        reordered_past = ()
        for layer_past in past:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past

    @property
    def device(self):
        # The ONNX sessions consume host (numpy) arrays, so generation works on CPU.
        return "cpu"

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_output_embeddings(self):
        # The LM head is part of the exported decoder graph.
        return None

    def forward(
        self,
        pixel_values=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Run one decoding step and return a ``Seq2SeqLMOutput`` with logits and the cache.

        ``decoder_inputs_embeds``, ``labels``, ``use_cache``, ``output_attentions``,
        ``output_hidden_states`` and ``return_dict`` are accepted for API
        compatibility but ignored.

        Raises:
            ValueError: if neither ``encoder_outputs`` nor ``pixel_values`` is given.
            NotImplementedError: if the config requires an encoder-to-decoder
                projection, which this wrapper does not provide.
        """
        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")
            # Convert encoder inputs in embeddings if needed
            # (when using generate, we already get encoder_outputs generated
            #  by _prepare_encoder_decoder_kwargs_for_generation)
            encoder_outputs = self.encoder(pixel_values)

        encoder_hidden_states = encoder_outputs[0]

        if (
            self.config.encoder.hidden_size != self.config.decoder.hidden_size
            and self.config.decoder.cross_attention_hidden_size is None
        ):
            # This class defines no `enc_to_dec_proj`, so such models are unsupported.
            raise NotImplementedError(
                "Encoder and decoder hidden sizes differ and cross_attention_hidden_size "
                "is not set; the encoder-to-decoder projection is not supported by the "
                "ONNX wrapper."
            )

        # Both exported decoders take a mask over the *encoder* positions as
        # `encoder_attention_mask`. Every encoder position is attended, so one
        # all-ones (batch, encoder_seq_len) mask serves both the initial and the
        # cached step. (Passing the decoder mask here only worked by broadcasting.)
        encoder_attention_mask = torch.ones(
            (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]), dtype=torch.int64
        )

        if past_key_values is None:
            logits, past_key_values = self.decoder_init(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
        else:
            logits, past_key_values = self.decoder(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
            )

        return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

## Setup Utils

In [None]:
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
)

def get_onnx_runtime_sessions(
    path_to_encoder,
    path_to_decoder,
    path_to_initial_decoder,
    provider=None,
) -> "tuple[InferenceSession, InferenceSession, InferenceSession]":
    """Create ONNX Runtime sessions for the encoder, decoder and initial decoder.

    Args:
        path_to_encoder: path to the exported encoder model.
        path_to_decoder: path to the exported decoder-with-past model.
        path_to_initial_decoder: path to the exported initial decoder model.
        provider: list of execution providers used for every session. Defaults to
            ``["CPUExecutionProvider"]``.

    Returns:
        ``(encoder_sess, decoder_sess, decoder_sess_init)``, in argument order.
    """
    if provider is None:
        # Build a fresh list per call instead of sharing a mutable default.
        provider = ["CPUExecutionProvider"]

    options = SessionOptions()
    options.intra_op_num_threads = 1
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    # 0 is the most verbose severity level (per ORT's SessionOptions docs).
    options.log_severity_level = 0

    sessions = []
    for path in (path_to_encoder, path_to_decoder, path_to_initial_decoder):
        session = InferenceSession(path, options, provider)
        # Turn off ORT's execution-provider fallback for this session.
        session.disable_fallback()
        sessions.append(session)

    return tuple(sessions)

## Running the Model

### Setup Utils

In [None]:
from transformers import AutoConfig, AutoTokenizer
from PIL import Image

def try_inference(model, image, tokenizer, feature_extractor):
    """Generate text for one image with ``model`` and return the first decoded sequence."""
    pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
    generated_ids = model.generate(pixel_values)
    decoded_texts = tokenizer.batch_decode(generated_ids)
    return decoded_texts[0]

### Inference

In [None]:
if run_inference:
    # Load utils.
    feature_extractor = ViTFeatureExtractor.from_pretrained(huggingface_model)
    tokenizer = AutoTokenizer.from_pretrained(huggingface_model)
    config = AutoConfig.from_pretrained(huggingface_model)
    # The original PyTorch model, used as the reference to compare against.
    hf_model = VisionEncoderDecoderModel.from_pretrained(huggingface_model)

    # Load ONNX model (the quantized files when run_quantize is set).
    encoder_path, decoder_path, decoder_init_path = get_quantized_paths() if run_quantize else get_exported_paths()
    encoder_sess, decoder_sess, decoder_init_sess = get_onnx_runtime_sessions(encoder_path, decoder_path, decoder_init_path)
    onnx_model = OnnxVision(config, encoder_sess, decoder_sess, decoder_init_sess)

In [None]:
from PIL import Image

# Decode the sample image with both the ONNX and PyTorch models, print both
# texts, then print whether they are identical.
if run_inference:
    image = Image.open(image_path).convert('RGB')

    onnx_output_text = try_inference(onnx_model, image, tokenizer, feature_extractor)
    print('ONNX:')
    print(onnx_output_text)

    print('PYTORCH:')
    pytorch_output_text = try_inference(hf_model, image, tokenizer, feature_extractor)
    print(pytorch_output_text)

    print(pytorch_output_text == onnx_output_text)