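In this notebook, we'll export the model as an ONNX graph so that it can be used on platforms other than Python. First, set the model name, the path to the training checkpoint, and the directory where exports are written.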

In [None]:
# Export settings: the model's name (also used as its export subdirectory),
# the checkpoint to convert, and the root directory for exported artifacts.
model_name, checkpoint_path, exports_path = (
    "UltraZoom-2X",
    "./checkpoints/checkpoint.pt",
    "./exports",
)

Next, we'll load the base model checkpoint from disk into memory.

In [None]:
import torch

from src.ultrazoom.model import UltraZoom

# Load on the CPU; weights_only=True restricts unpickling to tensors and
# plain containers.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

model = UltraZoom(**checkpoint["model_args"])

# The saved weights presumably include weight-norm parameters, so the
# weight norms must be added before the state dict is loaded.
model.add_weight_norms()

# A checkpoint written from a torch.compile()'d model prefixes every key with
# "_orig_mod.". Strip that prefix so the weights load into the plain module.
# This avoids wrapping the model in torch.compile just to make the keys match,
# which would leave an OptimizedModule wrapper around the model handed to the
# exporter. Keys without the prefix pass through unchanged.
state_dict = {
    key.removeprefix("_orig_mod."): value
    for key, value in checkpoint["model"].items()
}

model.load_state_dict(state_dict)

# Presumably folds the weight-norm parametrizations back into plain weights
# for inference/export — confirm against UltraZoom.remove_weight_norms.
model.remove_weight_norms()

model.eval()

print("Base checkpoint loaded successfully")

Lastly, we'll trace the model into an ONNX graph using an example input and save it to disk.

In [None]:
from os import makedirs, path

from torch.onnx import export as export_onnx

onnx_path = path.join(exports_path, model_name, "model.onnx")

# Create exports/<model_name>/ up front; saving into a directory that does not
# exist yet fails.
makedirs(path.dirname(onnx_path), exist_ok=True)

# torch.compile() wraps a module in an OptimizedModule that keeps the original
# under `_orig_mod`. Export the underlying module so the exporter traces the
# model itself rather than the compile wrapper. A plain module passes through
# unchanged.
export_model = getattr(model, "_orig_mod", model)

# Example arguments used for tracing, passed as a tuple of positional args.
# NOTE(review): no `dynamic_shapes` is given, so the exported graph is
# presumably specialized to this 1x3x128x128 input shape — confirm before
# running the ONNX model at other resolutions.
example_inputs = (torch.randn(1, 3, 128, 128),)

onnx_model = export_onnx(export_model, example_inputs, dynamo=True)

onnx_model.save(onnx_path)