# Convert our BiRefNet weights to ONNX format.

> This Colab notebook is adapted from [Kazuhito00](https://github.com/Kazuhito00)'s nice work.

> Repo: https://github.com/Kazuhito00/BiRefNet-ONNX-Sample  
> Original Colab: https://colab.research.google.com/github/Kazuhito00/BiRefNet-ONNX-Sample/blob/main/Convert2ONNX.ipynb

+ Converting a standard BiRefNet on GPU requires **19.7 GB** of GPU memory.
+ Currently, Colab (12.7 GB RAM / 15 GB GPU memory) cannot hold the conversion of BiRefNet in the default setting, so I use BiRefNet with the swin_v1_tiny backbone as the example on Colab.

### Online Colab version: https://colab.research.google.com/drive/1z6OruR52LOvDDpnp516F-N4EyPGrp5om

In [1]:
# Install the ONNX tooling used by this notebook: `onnx`, `onnxscript`, and the
# GPU build of ONNX Runtime (pinned to 1.18.1), which runs the exported model below.
!pip install -q onnx onnxscript onnxruntime-gpu==1.18.1

[0m

In [2]:
cd ..

/root/autodl-tmp/BiRefNet


In [3]:
import torch


# Checkpoint to convert, loaded from the repository root as './<weights_file>'.
# The next cell picks the backbone in config.py from this name: a name containing
# 'swin_v1_tiny' selects the tiny Swin backbone (the variant used on Colab), e.g.
# https://github.com/ZhengPeng7/BiRefNet/releases/download/v1/BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth
# Any other name selects the default backbone.
weights_file = 'BiRefNet-matting-epoch_100.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Make config.py's backbone selector agree with `weights_file`: the list index
# right after 'pvt_v2_b5' becomes 3 for checkpoints whose name contains
# 'swin_v1_tiny' and 6 otherwise.
# NOTE(review): index 3 is assumed to be swin_v1_tiny (per the message printed
# below) and index 6 the repository's default backbone; confirm against config.py.
# The snippet must match config.py byte-for-byte, including indentation. If it
# is not found, a warning is printed instead of silently leaving a wrong backbone.
backbone_selector_template = '''
            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5
        ][{}]
        '''
use_tiny_backbone = 'swin_v1_tiny' in weights_file
if use_tiny_backbone:
    print('Set `swin_v1_tiny` as the backbone.')
old_index, new_index = (6, 3) if use_tiny_backbone else (3, 6)
old_selector = backbone_selector_template.format(old_index)
new_selector = backbone_selector_template.format(new_index)

with open('config.py', encoding='utf-8') as fp:
    file_lines = fp.read()
if old_selector in file_lines:
    file_lines = file_lines.replace(old_selector, new_selector)
elif new_selector not in file_lines:
    print('WARNING: backbone selector not found in config.py; it was left unchanged.')
with open('config.py', mode="w", encoding='utf-8') as fp:
    fp.write(file_lines)

In [5]:
from utils import check_state_dict
from models.birefnet import BiRefNet


# Build BiRefNet with bb_pretrained=False (presumably skipping separate backbone
# weights, since the full checkpoint is loaded right after), then load the
# checkpoint onto `device` after normalizing its keys with `check_state_dict`.
birefnet = BiRefNet(bb_pretrained=False)
state_dict = check_state_dict(
    torch.load('./{}'.format(weights_file), map_location=device, weights_only=True)
)
birefnet.load_state_dict(state_dict)

# Use 'high' float32 matmul precision; switch to 'highest' for full precision.
torch.set_float32_matmul_precision('high')

# Module.to() and Module.eval() both return the module itself.
birefnet = birefnet.to(device).eval()

  from .autonotebook import tqdm as notebook_tqdm


# Handle deform_conv2d in the conversion to ONNX

In [6]:
# Copy deform_conv2d_onnx_exporter.py (presumably an ONNX exporter for
# torchvision's deform_conv2d) into the working directory, then delete the clone.
# NOTE(review): in this notebook the custom DeformConv symbolic is registered
# instead of importing this module; confirm whether the download is still needed.
!git clone https://github.com/masamitsu-murase/deform_conv2d_onnx_exporter
%cp deform_conv2d_onnx_exporter/src/deform_conv2d_onnx_exporter.py .
!rm -rf deform_conv2d_onnx_exporter

Cloning into 'deform_conv2d_onnx_exporter'...
remote: Enumerating objects: 205, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Total 205 (delta 6), reused 6 (delta 6), pack-reused 198 (from 1)[K
Receiving objects: 100% (205/205), 36.21 KiB | 170.00 KiB/s, done.
Resolving deltas: 100% (102/102), done.


In [7]:
# Patch the downloaded exporter so its dimension-size helper can recover sizes
# that `sym_help._get_tensor_dim_size` reports as None, using the tensor's strides.
# NOTE(review): this only has an effect if `deform_conv2d_onnx_exporter` is
# imported and registered; this notebook registers a custom DeformConv symbolic instead.
# The recovery assumes a contiguous (N, C, H, W) tensor, i.e. strides (C*H*W, H*W, W, 1):
#   dim 3 -> strides[2], dim 2 -> strides[1] // strides[2],
#   dim 0 -> strides[3] (== 1 when contiguous, i.e. batch size 1). TODO confirm.
PATCH_TARGET = "return sym_help._get_tensor_dim_size(tensor, dim)"
PATCH_BODY = '''
    tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim)
    if tensor_dim_size is None and dim in (0, 2, 3):
        import typing
        from torch import _C

        x_type = typing.cast(_C.TensorType, tensor.type())
        x_strides = x_type.strides()
        if dim == 3:
            tensor_dim_size = x_strides[2]
        elif dim == 2:
            tensor_dim_size = x_strides[1] // x_strides[2]
        else:
            tensor_dim_size = x_strides[3]

    return tensor_dim_size
    '''

with open('deform_conv2d_onnx_exporter.py', encoding='utf-8') as fp:
    file_lines = fp.read()

if PATCH_TARGET in file_lines:
    file_lines = file_lines.replace(PATCH_TARGET, PATCH_BODY)
    with open('deform_conv2d_onnx_exporter.py', mode="w", encoding='utf-8') as fp:
        fp.write(file_lines)
else:
    # Already patched by an earlier run, or upstream changed: make that visible.
    print('NOTE: patch target not found in deform_conv2d_onnx_exporter.py; file left unchanged.')

In [None]:
from torch.onnx.symbolic_helper import parse_args
from torch.onnx import register_custom_op_symbolic


@parse_args(
    "v",  # arg0: input (tensor)
    "v",  # arg1: weight (tensor)
    "v",  # arg2: offset (tensor)
    "v",  # arg3: mask (tensor)
    "v",  # arg4: bias (tensor)
    "i",  # arg5: stride_h
    "i",  # arg6: stride_w
    "i",  # arg7: pad_h
    "i",  # arg8: pad_w
    "i",  # arg9: dilation_h
    "i",  # arg10: dilation_w
    "i",  # arg11: groups
    "i",  # arg12: deform_groups
    "b",  # arg13: use_mask
)
def symbolic_deform_conv_19(
    g,
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation_h,
    dilation_w,
    groups,
    deform_groups,
    maybe_bool,
):
    """Emit an ONNX ``DeformConv`` node for ``torchvision::deform_conv2d``.

    Args:
        g: Graph context supplied by the TorchScript ONNX exporter.
        input, weight, offset, mask, bias: Values for the op's tensor arguments,
            in torchvision's argument order.
        stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w: Convolution
            geometry (padding is applied symmetrically on both sides).
        groups: Number of convolution groups.
        deform_groups: Number of offset groups.
        maybe_bool: torchvision's ``use_mask`` flag. When it is false, ``mask``
            is a placeholder rather than a real modulation mask, so it is not
            forwarded and ``DeformConv`` runs unmodulated.

    Returns:
        The output value of the emitted ``DeformConv`` node.
    """
    strides = [stride_h, stride_w]
    # ONNX pads layout: [begin_h, begin_w, end_h, end_w].
    pads = [pad_h, pad_w, pad_h, pad_w]
    dilations = [dilation_h, dilation_w]

    # ONNX DeformConv takes (X, W, offset, B, mask): bias precedes the optional
    # mask, unlike torchvision's (input, weight, offset, mask, bias) order.
    inputs = [input, weight, offset, bias]
    if maybe_bool:
        inputs.append(mask)

    return g.op(
        "DeformConv",
        *inputs,
        strides_i=strides,
        pads_i=pads,
        dilations_i=dilations,
        group_i=groups,
        offset_group_i=deform_groups,
    )

In [8]:
from torchvision.ops.deform_conv import DeformConv2d

# Register the custom symbolic so torchvision's deform_conv2d exports as the
# standard ONNX DeformConv op. DeformConv first appears in opset 19; the
# registration is expected to be picked up for the higher opset used by the
# export below as well.
# (Alternative, not used here: import the downloaded `deform_conv2d_onnx_exporter`
# module and call its `register_deform_conv2d_onnx_op()`.)
register_custom_op_symbolic(
    "torchvision::deform_conv2d",  # PyTorch JIT op name
    symbolic_deform_conv_19,
    opset_version=19,
)

def convert_to_onnx(
    net,
    file_name='output.onnx',
    input_shape=(1024, 1024),
    device=device,
    opset_version=20,
):
    """Export ``net`` to an ONNX file by tracing it on a random 3-channel image.

    Args:
        net: The PyTorch model to export.
        file_name: Destination path of the ``.onnx`` file.
        input_shape: ``(height, width)`` of the dummy input image.
        device: Device for the dummy input; it must match the device ``net``
            lives on. Defaults to the notebook's global ``device`` (captured
            when this function is defined).
        opset_version: Target ONNX opset. It must be at least the opset the
            DeformConv symbolic is registered for (19).

    The graph input is named ``input_image`` (axis 0 declared dynamic) and the
    first graph output is named ``output_image``.
    """
    # Values are irrelevant for tracing; only shape, dtype and device matter.
    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)

    # NOTE(review): the tracer warnings printed by this export show batch-dependent
    # sizes (e.g. `B = int(windows.shape[0] / ...)`) becoming Python ints, so the
    # exported graph may not really support batch sizes other than 1. Verify before
    # relying on the dynamic batch axis.
    torch.onnx.export(
        net,
        dummy_input,
        file_name,
        verbose=False,
        opset_version=opset_version,
        input_names=['input_image'],
        output_names=['output_image'],
        dynamic_axes={"input_image": [0]},
    )
# Write the ONNX model under the checkpoint's name with '.pth' replaced by '.onnx'.
onnx_file_name = weights_file.replace('.pth', '.onnx')
convert_to_onnx(birefnet, onnx_file_name, input_shape=(1024, 1024), device=device)

  if W % self.patch_size[1] != 0:
  if H % self.patch_size[0] != 0:
  Hp = int(np.ceil(H / self.window_size)) * self.window_size
  Wp = int(np.ceil(W / self.window_size)) * self.window_size
  assert L == H * W, "input feature has wrong size"
  B = int(windows.shape[0] / (H * W / window_size / window_size))
  if pad_r > 0 or pad_b > 0:
  assert L == H * W, "input feature has wrong size"
  pad_input = (H % 2 == 1) or (W % 2 == 1)
  if pad_input:
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


verbose: False, log level: Level.ERROR



# Load the ONNX weights and run inference.

In [None]:
from PIL import Image
from torchvision import transforms


# Preprocessing: resize to 1024x1024, convert to a float tensor, and normalize
# per channel (the constants are presumably the usual ImageNet mean/std).
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

imagepath = './Helicopter-HR.jpg'
# Close the file handle deterministically. convert("RGB") returns a loaded copy
# (even for images that are already RGB), so `image` stays valid after closing.
with Image.open(imagepath) as img:
    image = img.convert("RGB")

# Add the batch dimension -> (1, 3, 1024, 1024).
cpu_batch = transform_image(image).unsqueeze(0)
input_images = cpu_batch.to(device)
# ONNX Runtime is fed NumPy; take it from the CPU tensor to skip a device round trip.
input_images_numpy = cpu_batch.numpy()

In [None]:
import onnxruntime
import matplotlib.pyplot as plt


# Execution providers are tried in list order. On GPU, prefer CUDA and keep CPU
# listed explicitly as the fallback, as in ONNX Runtime's documented usage.
if device == 'cpu':
    providers = ['CPUExecutionProvider']
else:
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

onnx_session = onnxruntime.InferenceSession(
    weights_file.replace('.pth', '.onnx'),
    providers=providers,
)
input_name = onnx_session.get_inputs()[0].name
# Report which device/providers ONNX Runtime actually ended up with.
print(onnxruntime.get_device(), onnx_session.get_providers())

In [None]:
from time import time
import matplotlib.pyplot as plt

from time import perf_counter

# Time one ONNX Runtime run plus sigmoid post-processing, using a monotonic clock.
# NOTE(review): the first run of a fresh session may include one-off setup cost;
# the %%timeit cells below measure steady-state latency.
# (The input is the same NumPy array on CPU and GPU, so no device branch is needed.)
time_st = perf_counter()
onnx_outputs = onnx_session.run(None, {input_name: input_images_numpy})
pred_onnx = torch.from_numpy(onnx_outputs[-1]).squeeze(0).sigmoid()
print(perf_counter() - time_st)

plt.imshow(pred_onnx.squeeze(), cmap='gray'); plt.show()

In [None]:
# Reference prediction from the PyTorch model: last output map, sigmoid-activated,
# moved to CPU for plotting and comparison.
with torch.no_grad():
    preds = torch.sigmoid(birefnet(input_images)[-1]).cpu()
plt.imshow(preds.squeeze(), cmap='gray')
plt.show()

In [None]:
# Element-wise absolute difference between the PyTorch and ONNX Runtime predictions.
diff = (preds - pred_onnx).abs()
print('sum(diff):', diff.sum())
plt.imshow(diff.squeeze(), cmap='gray')
plt.show()

# Efficiency Comparison between .pth and .onnx

In [None]:
%%timeit
# Repeated-run latency of the PyTorch model without autograd. `.cpu()` brings the
# result back to host memory, so on GPU the timing includes finishing the work
# and that copy.
with torch.no_grad():
    preds = birefnet(input_images)[-1].sigmoid().cpu()

In [None]:
%%timeit
# Repeated-run latency of ONNX Runtime plus the same sigmoid post-processing,
# for comparison with the PyTorch timing above.
pred_onnx = torch.tensor(
    onnx_session.run(None, {input_name: input_images_numpy})[-1]
).squeeze(0).sigmoid().cpu()