
add torch._scale_mm #606

Merged
helloyongyang merged 1 commit into main from dev/fixbug on Dec 12, 2025

Conversation

@gushiqiao
Contributor

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @gushiqiao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly expands the quantization capabilities of the system by integrating FP8 (Float8) precision for matrix multiplications. It leverages torch._scaled_mm and torchao utilities to enable efficient FP8 quantization for both weights and activations, introducing new classes and updating existing model configurations to support this new scheme. The changes aim to reduce memory footprint and potentially accelerate computations across various model architectures. Additionally, explicit resource management has been added to ensure proper cleanup of GPU memory.

Highlights

  • FP8 Quantization Support: Introduced a new FP8 quantization scheme using torch._scaled_mm for improved efficiency in matrix multiplications.
  • torchao Integration Enhancements: Enhanced integration with torchao.quantization.utils by aliasing functions and adding robust import handling for both standard and internal (_quant_) versions.
  • New Quantization Classes: Added MMWeightWfp8channelAfp8channeldynamicTorchao and TorchaoQuantLinearFp8 classes to manage FP8 weight and activation quantization.
  • Model Configuration Updates: Updated various model loading and configuration files across different model architectures to recognize and utilize the new fp8-torchao quantization scheme.
  • Resource Management: Implemented explicit resource cleanup in default_runner.py by adding a __del__ method to free model components and clear CUDA memory, improving memory management.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for FP8 quantization using torch._scaled_mm, which is a great addition for performance. The changes are well-structured, adding new quantization schemes and linear layer implementations. I've identified a few areas for improvement:

  • An unreliable __del__ method for resource cleanup.
  • Some minor code redundancies and duplications.
  • A redundant assignment with an in-place operation.

Overall, the changes are good, and addressing these points will improve the code's robustness and maintainability.

Comment on lines +421 to +433
```python
def __del__(self):
    if hasattr(self, "model"):
        del self.model
    if hasattr(self, "text_encoders"):
        del self.text_encoders
    if hasattr(self, "image_encoder"):
        del self.image_encoder
    if hasattr(self, "vae_encoder"):
        del self.vae_encoder
    if hasattr(self, "vae_decoder"):
        del self.vae_decoder
    torch.cuda.empty_cache()
    gc.collect()
```


high

Using __del__ for resource cleanup, especially for GPU memory, is unreliable. The __del__ method is not guaranteed to be called when you expect it, due to Python's garbage collection behavior (e.g., circular references). This can lead to resource leaks. It's better to implement an explicit cleanup method, like cleanup() or close(), and ensure it's called deterministically when the runner is no longer needed.
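
The explicit-cleanup alternative the reviewer recommends can be sketched like this. The class below is a stand-in, not the PR's `DefaultRunner` (only the attribute names match the snippet above); combining `cleanup()` with the context-manager protocol makes the release point deterministic.

```python
import gc

class DefaultRunner:
    """Illustrative stand-in for the runner in default_runner.py."""

    _COMPONENTS = ("model", "text_encoders", "image_encoder",
                   "vae_encoder", "vae_decoder")

    def __init__(self):
        for name in self._COMPONENTS:
            setattr(self, name, object())  # placeholder for loaded modules

    def cleanup(self):
        """Deterministic release point, unlike __del__."""
        for name in self._COMPONENTS:
            if hasattr(self, name):
                delattr(self, name)
        gc.collect()
        # torch.cuda.empty_cache() would follow here when CUDA is in use.

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.cleanup()
        return False

with DefaultRunner() as runner:
    pass  # run inference; components are freed when the block exits
```

Callers that cannot use `with` can still invoke `runner.cleanup()` directly, which sidesteps the garbage-collection timing issues `__del__` suffers from.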

```diff
 output_tensor = torchao_int8_gemm(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
 if self.bias is not None:
-    output_tensor = output_tensor + self.bias
+    output_tensor = output_tensor.add_(self.bias)
```


medium

The add_ method performs an in-place addition and returns the modified tensor. The assignment back to output_tensor is redundant. You can simplify this to just output_tensor.add_(self.bias) for better clarity.

Suggested change:

```diff
-output_tensor = output_tensor.add_(self.bias)
+output_tensor.add_(self.bias)
```
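
A quick check of the claim: `Tensor.add_` mutates in place and returns the very same tensor object, so assigning the return value back is a no-op.

```python
import torch

out = torch.zeros(3)
bias = torch.ones(3)
returned = out.add_(bias)      # in-place: `out` now holds the sums
same_object = returned is out  # add_ returns self, not a copy
```

This is why `output_tensor = output_tensor.add_(self.bias)` and the bare `output_tensor.add_(self.bias)` are behaviorally identical, and the latter is clearer.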

Comment on lines +243 to +248
```python
def act_quant_func(self, x):
    abs_max = x.abs().max(dim=-1, keepdim=True)[0]
    abs_max = torch.clamp(abs_max, min=1e-8)
    scale = abs_max / 448.0
    quantized = torch.clamp(x / scale, -448, 448).to(torch.float8_e4m3fn)
    return quantized, scale.float()
```


medium

This act_quant_func implementation is identical to act_quant_fp8_perchannel_sym_torchao in lightx2v/common/ops/mm/mm_weight.py. To avoid code duplication and improve maintainability, consider moving this logic to a shared utility function and importing it in both places.

@helloyongyang helloyongyang merged commit e383ac2 into main Dec 12, 2025
2 checks passed
@gushiqiao gushiqiao deleted the dev/fixbug branch December 30, 2025 06:57
helloyongyang pushed a commit that referenced this pull request Mar 6, 2026