In [11]:
import os
import sys
from pathlib import Path

import torch

# 确保能 import 到你的 LAR_IQA 工程
sys.path.append("..")  # 根据你的 ipynb 所在路径调整

from python.packages.LAR_IQA.scripts.utils import load_model


def export_lar_iqa_onnx(
    checkpoint_path: str = "../python/packages/LAR_IQA/checkpoint_epoch_3.pt",
    out_dir: str = "./out",
    onnx_name: str = "lar_iqa.onnx",
    use_cuda: bool = True,
):
    # 1. 选择设备
    device = "cuda" if (use_cuda and torch.cuda.is_available()) else "cpu"
    print(f"[INFO] Using device: {device}")

    # 2. 加载模型
    ckpt_path = Path(checkpoint_path).resolve()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    print(f"[INFO] Loading checkpoint from: {ckpt_path}")
    model = load_model(str(ckpt_path), False, device)
    model.eval()

    # 3. 构造 dummy 输入（与 preprocess_image 输出形状一致）
    #
    # preprocess_image 中：
    #   image_authentic: Resize 到 (384, 384)
    #   image_synthetic: CenterCrop 到 (1280, 1280)
    #
    # 所以 dummy 输入分别是 [1, 3, 384, 384] 和 [1, 3, 1280, 1280]
    image_authentic = torch.randn(1, 3, 384, 384, device=device)
    image_synthetic = torch.randn(1, 3, 1280, 1280, device=device)

    # 4. 确保导出目录存在
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    onnx_path = out_path / onnx_name

    print(f"[INFO] Exporting ONNX to: {onnx_path}")

    # 5. 导出 ONNX
    torch.onnx.export(
        model,
        (image_authentic, image_synthetic),  # 模型的两个输入
        onnx_path.as_posix(),
        export_params=True,  # 保存权重到 ONNX
        opset_version=17,  # 常用的较新 opset 版本（你也可以改成 16/18）
        do_constant_folding=True,  # 常量折叠优化
        input_names=["image_authentic", "image_synthetic"],
        output_names=["score"],
        dynamic_axes={  # 只把 batch 维做成动态，空间尺寸固定
            "image_authentic": {0: "batch_size"},
            "image_synthetic": {0: "batch_size"},
            "score": {0: "batch_size"},
        },
    )

    print("[INFO] ONNX export finished.")
    print(f"[INFO] ONNX model saved at: {onnx_path}")
    return onnx_path


# 在 ipynb 中直接跑这一段即可导出
if __name__ == "__main__":
    export_lar_iqa_onnx()


[INFO] Using device: cuda
[INFO] Loading checkpoint from: F:\ML\PythonAIProject\SMARKMediaTools_web\electron-media-toolbox\python\packages\LAR_IQA\checkpoint_epoch_3.pt


  torch.onnx.export(
W1120 22:48:02.254000 30040 site-packages\torch\onnx\_internal\exporter\_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 17 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features


[INFO] Exporting ONNX to: out\lar_iqa.onnx
[torch.onnx] Obtain model graph for `MobileNetMerged([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `MobileNetMerged([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...


The model version conversion is not supported by the onnxscript version converter and fallback is enabled. The model will be converted using the onnx C API (target version: 17).


[torch.onnx] Translate the graph into ONNX... ✅


Failed to convert the model to the target version 17 using the ONNX C API. The model was not modified
Traceback (most recent call last):
  File "d:\ProgramData\miniforge3\envs\ML\lib\site-packages\onnxscript\version_converter\__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
  File "d:\ProgramData\miniforge3\envs\ML\lib\site-packages\onnxscript\version_converter\_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
  File "d:\ProgramData\miniforge3\envs\ML\lib\site-packages\onnxscript\version_converter\__init__.py", line 122, in _partial_convert_version
    return onnx.version_converter.convert_version(
  File "d:\ProgramData\miniforge3\envs\ML\lib\site-packages\onnx\version_converter.py", line 37, in convert_version
    converted_model_str = C.convert_version(model_str, target_version)
RuntimeError: D:\a\onnx\onnx\onnx\onnx/version_converter/adapters/axes_input_to_attribute.h:65: adapt: Assertion `node->hasAttribute(kaxes)` failed: No ini

Applied 184 of general pattern rewrite rules.
[INFO] ONNX export finished.
[INFO] ONNX model saved at: out\lar_iqa.onnx
