
inference speed extremely slow #26

Closed
razvypp opened this issue May 20, 2024 · 6 comments

razvypp commented May 20, 2024

Hello,

The inference speed is extremely slow.
I am running inference on the GPU, but it is the same setup I use with u2net, and the speed there is 12x faster.

Is there anything I can do to speed things up?

I have also tried to export to ONNX, but I get an error:

import torch
import torch.onnx
from models.birefnet import BiRefNet
from utils import check_state_dict
from torch.onnx import register_custom_op_symbolic

# Register a custom symbolic function for torchvision::deform_conv2d
def deform_conv2d_symbolic(g, input, weight, offset, bias, stride, padding, dilation, groups, deformable_groups, use_mask=False, mask=None):
    return g.op("DeformConv2d", input, weight, offset, bias,
                stride_i=stride, padding_i=padding, dilation_i=dilation,
                groups_i=groups, deformable_groups_i=deformable_groups)

register_custom_op_symbolic('torchvision::deform_conv2d', deform_conv2d_symbolic, 11)

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiRefNet(bb_pretrained=False).to(device)
state_dict = torch.load("/root/BiRefNet-massive-epoch_240.pth", map_location=device)
state_dict = check_state_dict(state_dict)
model.load_state_dict(state_dict)
model.eval()

# Dummy input to trace the model
dummy_input = torch.randn(1, 3, 1024, 1024).to(device)

# Ensure tensor-to-Python type conversions are handled in the model code.
# Example modification:
#     if W % self.patch_size[1] != 0:
# can be replaced with
#     if (W % self.patch_size[1]).item() != 0:

# Export the model
onnx_model_path = "/root/BiRefNet.onnx"
torch.onnx.export(
    model,                     # model being run
    dummy_input,               # model input (or a tuple for multiple inputs)
    onnx_model_path,           # where to save the model (can be a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=11,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the model's input names
    output_names=['output'],   # the model's output names
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}  # variable-length axes
)

print(f"Model has been converted to ONNX and saved at {onnx_model_path}")

@ZhengPeng7 (Owner)

Hi, could you please check your text and correct its formatting?

@ZhengPeng7 (Owner)

Regarding the comparison with u-2-net, I don't think that is a problem: it's almost impossible to gain much accuracy with the same number of parameters.

Statistics for BiRefNet with different backbones can be found in this issue.

You can also find other issues where we discussed speeding up inference (ONNX, FP16, ...). So far there is no really good method for it. You can wait for the version with the swin_v1_tiny backbone, which could be about 4 times faster than the official one.
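For reference, a minimal FP16 inference sketch with autocast, one of the options mentioned above (the 1024x1024 input and the reuse of the `model` loaded earlier are assumptions; any accuracy impact on BiRefNet's masks would need to be checked):

import torch

# Half-precision inference via autocast (sketch only; accuracy impact is unverified).
model = model.to('cuda').eval()                        # `model` as loaded in the script above
image = torch.randn(1, 3, 1024, 1024, device='cuda')   # stand-in for a preprocessed image

with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
    preds = model(image)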


razvypp commented May 21, 2024

Will the swin_v1_tiny version have the same performance on general datasets?

I failed to convert to ONNX; it crashes. Is there a guide for this?

@ZhengPeng7 (Owner)

Of course not; that's a trade-off.
Sorry, but as for the issues I mentioned above, I currently have no time for this kind of thing.

ZhengPeng7 (Owner) commented May 25, 2024

The well-trained BiRefNet with the swin_v1_tiny backbone has been uploaded to my Google Drive. Check the README for access to the weights, performance, predicted maps, and training log in the corresponding folder (exp-xxx). The performance is a bit lower than the official version, but still good (HCE↓: 1152 -> 1182 on DIS-VD). Feel free to download and use them.
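If it helps, a hedged sketch of loading that checkpoint with the same pattern as the script above (the file name is a placeholder, and the repo's config has to be switched to the swin_v1_tiny backbone before BiRefNet() is constructed; how that switch is named in config.py is not shown here):

import torch
from models.birefnet import BiRefNet
from utils import check_state_dict

# NOTE: placeholder checkpoint name; the backbone selected in config.py must match swin_v1_tiny.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiRefNet(bb_pretrained=False).to(device)
state_dict = check_state_dict(torch.load("BiRefNet-swin_v1_tiny.pth", map_location=device))
model.load_state_dict(state_dict)
model.eval()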

@ZhengPeng7 (Owner)

By the way, check the update in inference.py. Setting torch.set_float32_matmul_precision to 'high' can increase the FPS on an A100 from 5 to 12 with ~0 performance degradation (because I set it to 'high' during training).
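A minimal illustration of where that setting goes (placement before the first forward pass is what matters; `model` and `image` are reused from the snippets above):

import torch

torch.set_float32_matmul_precision('high')   # allow TF32 matmuls, matching the training setting

with torch.no_grad():
    preds = model(image.to('cuda'))          # `model` / `image` as prepared earlier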
