In [1]:
import json
import torch
import onnx
from transformers import (
    VisionEncoderDecoderModel,
    VisionEncoderDecoderConfig,
    TrOCRProcessor,
)
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


To convert an `ORTModelForVision2Seq` ONNX model into a `VisionEncoderDecoderModel` Torch model, we must find a mapping between the weights of the ONNX model and those of the Torch model. To do so, we follow the steps below:

Initialize a `VisionEncoderDecoderModel` model with identical configuration and assign designed weights to it.

In [2]:
# Read the original model's configuration. JSON files are UTF-8 by
# specification, so the encoding is stated explicitly instead of relying on
# the platform default.
with open("pix2text-mfr/config.json", "r", encoding="utf-8") as config_file:
    config_data = json.load(config_file)

# Build a Torch model with the same architecture. Its weights are whatever
# the constructor initializes them to; they are overwritten in the next step.
config = VisionEncoderDecoderConfig.from_dict(config_data)
torch_model = VisionEncoderDecoderModel(config=config)

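Here, "designed" means that we can reliably find each weight again in the converted ONNX model. To satisfy that requirement, we fill every parameter with its own constant value and overwrite the second element of each row with a second, different constant, so that a tensor can be identified by its values alone, even after it has been transposed.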

In [3]:
def modify_weights(model):
    """Overwrite every parameter of ``model`` with a recognisable probe pattern.

    Parameters are visited in ``model.parameters()`` order. The first one is
    filled with 0.01, the next with 0.02, and so on (the fill value grows by
    0.01 per parameter). When a parameter holds more than one element, the
    flat positions ``1, 1 + d, 1 + 2d, ...`` (``d`` being the size of its last
    dimension) are set to the fill value plus 0.01 instead. For a tensor whose
    last dimension is larger than one, that is the second element of every
    row, which makes the tensor distinguishable from its own transpose.

    Args:
        model: a ``torch.nn.Module``; it is modified in place.

    Returns:
        None.
    """
    counter1 = 0.01
    counter2 = 0.02
    with torch.no_grad():
        for param in model.parameters():
            # Keep the parameter's own dtype and device; a bare torch.full()
            # would silently fall back to the default dtype on the CPU.
            new_values = torch.full(
                param.shape, counter1, dtype=param.dtype, device=param.device
            )

            # Ensure there is a second element to modify.
            if param.numel() > 1:
                # Strided slice over the flattened view: same positions as
                # range(1, numel, last_dim), without a Python-level loop.
                new_values.view(-1)[1 :: param.shape[-1]] = counter2

            param.copy_(new_values)
            counter1 += 0.01
            counter2 += 0.01

# Replace the freshly initialized weights of torch_model (in place) with the
# probe pattern, so every tensor can later be recognised by value.
modify_weights(torch_model)

Then we can simply save the model and convert it into an ONNX model.

In [4]:
# Write the probe model to the "torch_model" directory so it can be exported.
torch_model.save_pretrained("torch_model")

# The string below is a bare expression: this cell does not execute it, it
# only echoes it as the cell's output. The command has to be run separately
# in a shell, because the following cells read files from "onnx_model".
"""
optimum-cli export onnx --task image-to-text --model torch_model onnx_model
"""

'\noptimum-cli export onnx --task image-to-text --model torch_model onnx_model\n'

We are looking for the mapping between the ONNX model's weights and the Torch model's weights, so let's load the ONNX model first.

In [9]:
# Locations of the two graphs expected in the "onnx_model" directory.
encoder_model_path = "onnx_model/encoder_model.onnx"
decoder_model_path = "onnx_model/decoder_model.onnx"

# Load the encoder first, then the decoder.
encoder_onnx, decoder_onnx = (
    onnx.load(graph_path)
    for graph_path in (encoder_model_path, decoder_model_path)
)

Then extract everything

In [10]:
def extract_weights_from_onnx(onnx_model):
    """Collect the initializers of an ONNX model as Torch tensors.

    Args:
        onnx_model: a loaded ONNX model; only ``graph.initializer`` is read.

    Returns:
        A dict mapping each initializer's name to a ``torch.Tensor`` built
        from its array. If a name occurs more than once, the last occurrence
        is kept.
    """
    return {
        initializer.name: torch.tensor(onnx.numpy_helper.to_array(initializer))
        for initializer in onnx_model.graph.initializer
    }


# Extract the initializers of both graphs, encoder first.
encoder_weights, decoder_weights = map(
    extract_weights_from_onnx, (encoder_onnx, decoder_onnx)
)

# Merge into a single lookup table.
# NOTE(review): if the two graphs reuse an initializer name, the decoder's
# tensor replaces the encoder's here - confirm that no names collide.
onnx_weights = dict(encoder_weights)
onnx_weights.update(decoder_weights)

We also need weights of Torch model

In [11]:
# Name -> parameter lookup for the probe model, in named_parameters() order.
torch_weights = dict(torch_model.named_parameters())

Match them.

Why do I also try to match the transposed Torch weights? It rests on a simple assumption:

When exporting the Torch model to an ONNX model, optimum rewrites some inference steps, for example $AB = \left(B^T A^T\right)^T$, and as a result the names of those weights ($A$ or $B$) are lost (you'll find something like `onnx::MatMul_1234` instead).

So my solution is simply to transpose them back, and, surprisingly, it works.

In [12]:
def _reverse_dims(tensor):
    """Return ``tensor`` with its dimension order reversed (``.T`` for 2-D)."""
    if tensor.dim() < 2:
        return tensor
    return tensor.permute(*reversed(range(tensor.dim())))


def _first_match(weight, candidates, tolerance):
    """Return the name of the first candidate equal to ``weight``, else None.

    A candidate is equal when it has the same shape and every element differs
    from ``weight`` by less than ``tolerance``.
    """
    for name, candidate in candidates.items():
        if weight.shape != candidate.shape:
            continue
        # all(...) instead of max(...): same verdict for non-empty tensors,
        # and it does not raise on tensors with zero elements.
        if bool(torch.all(torch.abs(weight - candidate) < tolerance)):
            return name
    return None


def find_matching_weights(torch_weights, onnx_weights, tolerance=1e-3):
    """Pair every Torch weight with the ONNX weight holding the same values.

    Each Torch weight is first compared as-is against the ONNX weights that
    are still unclaimed. If none matches and the weight has at least two
    dimensions, it is compared again with its dimension order reversed (a
    transpose for matrices). A matched ONNX weight is removed from the pool
    so it cannot be assigned twice.

    Args:
        torch_weights: dict mapping Torch parameter names to tensors.
        onnx_weights: dict mapping ONNX initializer names to tensors. It is
            not modified.
        tolerance: largest element-wise absolute difference (exclusive) that
            still counts as a match.

    Returns:
        ``(matches, mismatches)``: ``matches`` is a list of
        ``(torch_name, onnx_name, transposed)`` tuples in iteration order of
        ``torch_weights``; ``mismatches`` lists the Torch names for which no
        ONNX weight was found.
    """
    matches = []
    mismatches = []
    remaining_onnx_weights = dict(onnx_weights)

    for torch_name, torch_weight in torch_weights.items():
        # Compare plain values; no autograd graph is needed for the search.
        weight = torch_weight.detach()

        onnx_name = _first_match(weight, remaining_onnx_weights, tolerance)
        transposed = False

        # Reversing the dimensions of a 0-D or 1-D tensor changes nothing,
        # so a second pass could not find anything the first one missed.
        if onnx_name is None and weight.dim() >= 2:
            onnx_name = _first_match(
                _reverse_dims(weight), remaining_onnx_weights, tolerance
            )
            transposed = onnx_name is not None

        if onnx_name is None:
            mismatches.append(torch_name)
            continue

        matches.append((torch_name, onnx_name, transposed))
        del remaining_onnx_weights[onnx_name]

    return matches, mismatches

matches, mismatches = find_matching_weights(torch_weights, onnx_weights)

# Only report when something actually failed to match; an unconditional
# message would show a warning with an empty list on a perfect run.
if mismatches:
    print(f"Warning: weight {mismatches} mismatched!")
# Don't worry about warnings. It's okay missing those weights for our model.



Now that we have the mapping, let's extract the weights from the original ONNX model.

In [16]:
# From here on the paths point at the "pix2text-mfr" directory. The names
# below are rebound, so earlier values held under them are replaced.
encoder_model_path = "pix2text-mfr/encoder_model.onnx"
decoder_model_path = "pix2text-mfr/decoder_model.onnx"

# Load both graphs, encoder first.
encoder_onnx, decoder_onnx = (
    onnx.load(graph_path)
    for graph_path in (encoder_model_path, decoder_model_path)
)

# Pull the initializers out of each graph, then merge them into one table.
encoder_weights, decoder_weights = (
    extract_weights_from_onnx(graph) for graph in (encoder_onnx, decoder_onnx)
)
# NOTE(review): on a name collision the decoder's tensor wins - confirm the
# two graphs do not share initializer names.
onnx_weights = dict(encoder_weights)
onnx_weights.update(decoder_weights)

Simply load weights into a new torch model

In [17]:
def load_weights_to_torch(weight_map, onnx_weights, torch_model):
    """Copy ONNX weights into ``torch_model`` according to ``weight_map``.

    Args:
        weight_map: iterable of ``(torch_name, onnx_name, transpose)`` tuples.
            When ``transpose`` is true, the ONNX tensor's dimension order is
            reversed (a transpose for matrices) before it is copied.
        onnx_weights: dict mapping ONNX initializer names to tensors.
        torch_model: the ``torch.nn.Module`` whose state dict receives the
            values; it is modified in place.

    Raises:
        KeyError: if an ``onnx_name`` is not present in ``onnx_weights``.
        ValueError: if the (possibly transposed) ONNX tensor does not have
            the same shape as the target entry of the state dict.

    A ``torch_name`` that is absent from the state dict is reported with a
    printed warning and skipped.
    """
    # Build the state dict once; it does not change between iterations.
    torch_model_state_dict = torch_model.state_dict()

    for torch_name, onnx_name, transpose in weight_map:
        if onnx_name not in onnx_weights:
            raise KeyError(
                f"ONNX weight {onnx_name!r} (mapped to {torch_name!r}) "
                "is missing from onnx_weights"
            )
        onnx_weight = onnx_weights[onnx_name]

        if transpose and onnx_weight.dim() >= 2:
            # Explicit dimension reversal; `.T` is deprecated for tensors
            # that are not 2-D.
            onnx_weight = onnx_weight.permute(*reversed(range(onnx_weight.dim())))

        if torch_name not in torch_model_state_dict:
            print(f"Warning: {torch_name} not found in the PyTorch model state dict.")
            continue

        target = torch_model_state_dict[torch_name]
        # copy_() broadcasts its source, so a wrong-but-broadcastable tensor
        # would otherwise be accepted without any error.
        if target.shape != onnx_weight.shape:
            raise ValueError(
                f"Shape mismatch for {torch_name!r}: model expects "
                f"{tuple(target.shape)}, ONNX weight {onnx_name!r} has "
                f"{tuple(onnx_weight.shape)}"
            )
        target.copy_(onnx_weight)

# Fresh model built from the same config object as torch_model.
model = VisionEncoderDecoderModel(config=config)



# Copy the tensors in onnx_weights into `model`, following the
# (torch_name, onnx_name, transpose) entries in `matches`.
load_weights_to_torch(matches, onnx_weights, model)

Test our Torch model

In [18]:
processor = TrOCRProcessor.from_pretrained("breezedeus/pix2text-mfr")

image_fps = [
    "testimg/1.png",
    "testimg/2.png",
    "testimg/3.png",
]

# A model constructed directly from a config starts in training mode, which
# keeps dropout active. Switch to evaluation mode so generation is
# deterministic and uses the loaded weights as-is.
model.eval()

# NOTE(review): the recorded outputs for images 2 and 3 look cut off; this is
# presumably the default generation length, since no generation settings are
# passed to generate() - confirm.
for image_fp in image_fps:
    images = [Image.open(image_fp).convert("RGB")]
    pixel_values = processor(images=images, return_tensors="pt").pixel_values

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(pixel_values)

    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)

    print(f"Image: {image_fp}")
    print(f"Decoded Output: {generated_text}")
    print("-" * 50)

Image: testimg/1.png
Decoded Output: ['x ^ { 2 } + 1 = 0 \\Rightarrow x = \\pm i']
--------------------------------------------------
Image: testimg/2.png
Decoded Output: ['\\begin{array} { r l } { } & { { } \\operatorname* {']
--------------------------------------------------
Image: testimg/3.png
Decoded Output: ['\\begin{aligned} { } & { { } \\operatorname* { l i m }']
--------------------------------------------------


Save, of course

In [19]:
# Write the converted model to the "converted_torch" directory.
model.save_pretrained("converted_torch")

Also note that, to make sure the model is standalone and works correctly, you need to copy some configuration files of the original model, such as `generation_config.json`, into the `converted_torch` folder.