
feat: enable bnb with torch compile#45

Merged
llcnt merged 2 commits into main from feat/enable_bnb_with_torch_compile
Apr 8, 2025

Conversation


@llcnt llcnt commented Apr 1, 2025

Description

Small PR that enables the combination of the diffusers_int8 quantizer and torch_compile for diffusers models.
The code change is minimal, but the results are interesting:

  • following the previous work on HQQ (feat: enable combination of torch_compile and hqq_diffusers #23), we quantize and compile the denoiser inside the diffusers pipeline;
  • I tested compiling the forward function of the denoiser instead (i.e. the call function of the UNet or transformer), but inference is not accelerated;
  • I tested the inference time for 50 denoising steps of a Flux-8B-freepik model on an L40S GPU and obtained $23.63s$. After bnb quantization, this number increases to $28.48s$; after quantization + compilation, it drops back to $23.69s$;
  • Note that the default compilation (i.e. capturing the full graph) raises an error on a bnb-quantized model, because bnb dynamically looks for outliers in the inputs, and JIT compilation does not support such data-dependent behavior with the default torch.compile parameters. To bypass this, set smash_config['torch_compile_fullgraph'] = False;
  • I also tested compiling (both the model and the forward) a model quantized with quanto, but it fails because quanto uses 'fake_tensors' that are not compatible with torch.compile.
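
The fullgraph point above can be illustrated with a minimal, self-contained sketch. The toy module below is a hypothetical stand-in (it is not the bnb int8 layer or the Flux denoiser): its data-dependent branch mimics the kind of dynamic behavior bnb's outlier search introduces, which forces a graph break in torch.compile. With fullgraph=True such a break is an error; with fullgraph=False Dynamo falls back to eager around it and inference still runs. backend="eager" is used here only to keep the sketch toolchain-free.

```python
import torch
import torch.nn as nn

class OutlierAwareLinear(nn.Module):
    """Toy stand-in for a layer with data-dependent control flow."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # Data-dependent branch (similar in spirit to bnb's dynamic
        # outlier search): reading .item() forces a graph break.
        if x.abs().max().item() > 6.0:
            return self.linear(x) * 0.5
        return self.linear(x)

model = OutlierAwareLinear()
x = torch.randn(2, 4)

# fullgraph=False tolerates the graph break and still runs:
compiled = torch.compile(model, backend="eager", fullgraph=False)
out = compiled(x)
print(out.shape)  # torch.Size([2, 4])

# fullgraph=True on the same module would instead raise, because
# the whole forward cannot be captured as a single graph.
```

This mirrors the PR's recommendation: keep compilation enabled but relax full-graph capture for bnb-quantized denoisers.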

Related Issue

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

I added a combination test on an SD model.
Other unit tests are still valid, and I have tested saving and re-loading a model that was quantized + compiled in this notebook.
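
For reference, the combination exercised by the new test can be configured roughly as below. This is a hedged sketch assuming pruna's SmashConfig/smash API and item-style assignment; the 'quantizer' and 'compiler' key names and the checkpoint id are assumptions for illustration, while 'torch_compile_fullgraph' is the flag named in this PR.

```python
from diffusers import StableDiffusionPipeline
from pruna import SmashConfig, smash

# Hypothetical checkpoint, chosen only for illustration.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

smash_config = SmashConfig()
smash_config["quantizer"] = "diffusers_int8"   # bnb-backed int8 quantization
smash_config["compiler"] = "torch_compile"
# Full-graph capture fails on bnb's dynamic outlier search, so disable it:
smash_config["torch_compile_fullgraph"] = False

smashed_pipe = smash(model=pipe, smash_config=smash_config)
```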

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

Similar results can be obtained in pruna_pro by combining the HIGGS quantizer with torch.compile (but with better memory reduction: you can go down to $3$ or even $2$ bits):

  • Initial inference time after quantization is $31.40s$, which can be reduced to $25.65s$ after compilation;
  • experiments are ongoing to improve this speedup and to also obtain speedups for batch inference.

@llcnt llcnt force-pushed the feat/enable_bnb_with_torch_compile branch from 8e47460 to 48c555d on April 2, 2025 07:58
@llcnt llcnt marked this pull request as ready for review April 4, 2025 08:07
@llcnt llcnt requested review from johnrachwan123 and sharpenb April 4, 2025 08:08

@sharpenb sharpenb left a comment


Thanks for the PR! LGTM :)

@llcnt llcnt mentioned this pull request Apr 7, 2025

@johnrachwan123 johnrachwan123 left a comment


LGTM!

@llcnt llcnt merged commit 0d739b0 into main Apr 8, 2025
9 checks passed
@llcnt llcnt deleted the feat/enable_bnb_with_torch_compile branch April 8, 2025 12:51
