-
Notifications
You must be signed in to change notification settings - Fork 169
add torch._scaled_mm #606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add torch._scaled_mm #606
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,9 +12,14 @@ | |
| sgl_kernel = None | ||
|
|
||
| try: | ||
| from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax | ||
| from torchao.quantization.utils import quant_int8_per_token_matmul as torchao_int8_gemm | ||
| from torchao.quantization.utils import quantize_activation_per_token_absmax as torchao_int8_quant | ||
| except ImportError: | ||
| quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None | ||
| try: | ||
| from torchao.quantization.utils import _quant_int8_per_token_matmul as torchao_int8_gemm | ||
| from torchao.quantization.utils import _quantize_activation_per_token_absmax as torchao_int8_quant | ||
| except ImportError: | ||
| torchao_int8_gemm, torchao_int8_quant = None, None | ||
|
|
||
| try: | ||
| from q8_kernels.functional.linear import q8_linear | ||
|
|
@@ -194,15 +199,15 @@ def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): | |
| self.register_buffer("bias", None) | ||
|
|
||
def act_quant_func(self, x):
    """Quantize activations ``x`` with the torchao per-token int8 helper.

    The diff hunk carried both the removed call (old helper name) and the
    added one, which would quantize twice and throw the first result away.
    Only the post-change call is kept.

    Returns:
        The ``(quantized, scale)`` pair produced by ``torchao_int8_quant``.
    """
    input_tensor_quant, input_tensor_scale = torchao_int8_quant(x)
    return input_tensor_quant, input_tensor_scale
|
|
||
def forward(self, input_tensor):
    """Run the int8 linear layer on ``input_tensor``.

    Dimension 0 is squeezed before the matmul and restored afterwards.
    The diff hunk carried both the removed and the added statements, which
    would run the GEMM twice and add the bias twice; only the post-change
    statements are kept.

    Returns:
        The bfloat16 GEMM output (bias added when present) with
        dimension 0 restored.
    """
    # NOTE(review): presumably the caller passes a leading batch dim of 1 - confirm.
    input_tensor = input_tensor.squeeze(0)
    input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
    output_tensor = torchao_int8_gemm(
        input_tensor_quant,
        input_tensor_scale,
        self.weight.t(),
        self.weight_scale.t().float(),
        output_dtype=torch.bfloat16,
    )
    if self.bias is not None:
        # add_ mutates output_tensor in place, so no re-assignment is needed.
        output_tensor.add_(self.bias)

    return output_tensor.unsqueeze(0)
|
|
||
|
|
@@ -221,6 +226,56 @@ def maybe_cast(t): | |
| return self | ||
|
|
||
|
|
||
| class TorchaoQuantLinearFp8(nn.Module): | ||
| def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16): | ||
| super().__init__() | ||
| self.in_features = in_features | ||
| self.out_features = out_features | ||
|
|
||
| self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn)) | ||
| self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32)) | ||
|
|
||
| if bias: | ||
| self.register_buffer("bias", torch.empty(out_features, dtype=dtype)) | ||
| else: | ||
| self.register_buffer("bias", None) | ||
|
|
||
| def act_quant_func(self, x): | ||
| abs_max = x.abs().max(dim=-1, keepdim=True)[0] | ||
| abs_max = torch.clamp(abs_max, min=1e-8) | ||
| scale = abs_max / 448.0 | ||
| quantized = torch.clamp(x / scale, -448, 448).to(torch.float8_e4m3fn) | ||
| return quantized, scale.float() | ||
|
Comment on lines
+243
to
+248
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def forward(self, input_tensor): | ||
| input_tensor = input_tensor.squeeze(0) | ||
| input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) | ||
| out = torch._scaled_mm( | ||
| input_tensor_quant, | ||
| self.weight.t(), | ||
| scale_a=input_tensor_scale, | ||
| scale_b=self.weight_scale.t(), | ||
| bias=self.bias.to(torch.bfloat16) if self.bias is not None else None, | ||
| out_dtype=torch.bfloat16, | ||
| use_fast_accum=True, | ||
| ) | ||
| return out.unsqueeze(0) | ||
|
|
||
| def _apply(self, fn): | ||
| for module in self.children(): | ||
| module._apply(fn) | ||
|
|
||
| def maybe_cast(t): | ||
| if t is not None and t.device != fn(t).device: | ||
| return fn(t) | ||
| return t | ||
|
|
||
| self.weight = maybe_cast(self.weight) | ||
| self.weight_scale = maybe_cast(self.weight_scale) | ||
| self.bias = maybe_cast(self.bias) | ||
| return self | ||
|
|
||
|
|
||
| class Q8FQuantLinearInt8(nn.Module): | ||
| def __init__(self, in_features, out_features, bias=True, dtype=torch.float32): | ||
| super().__init__() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -417,3 +417,17 @@ def run_pipeline(self, input_info): | |
| if GET_RECORDER_MODE(): | ||
| monitor_cli.lightx2v_worker_request_success.inc() | ||
| return gen_video_final | ||
|
|
||
def __del__(self):
    """Drop references to the loaded sub-models and release cached CUDA memory.

    Each of the listed attributes is deleted if it is present on the
    instance. The garbage collector runs before the CUDA cache is emptied,
    so memory freed by the collector can also be returned by ``empty_cache``.
    """
    for attr_name in ("model", "text_encoders", "image_encoder", "vae_encoder", "vae_decoder"):
        if hasattr(self, attr_name):
            delattr(self, attr_name)
    # Collect first: empty_cache() can only release blocks no tensor still holds.
    gc.collect()
    torch.cuda.empty_cache()
|
Comment on lines
+421
to
+433
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `add_` method performs an in-place addition and returns the modified tensor. The assignment back to `output_tensor` is redundant. You can simplify this to just `output_tensor.add_(self.bias)` for better clarity.