Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,981 changes: 1,697 additions & 284 deletions app/gradio_demo.py

Large diffs are not rendered by default.

1,906 changes: 1,671 additions & 235 deletions app/gradio_demo_zh.py

Large diffs are not rendered by default.

53 changes: 48 additions & 5 deletions lightx2v/common/ops/mm/mm_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@
deep_gemm = None

try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
from torchao.quantization.utils import quant_int8_per_token_matmul as torchao_int8_gemm
from torchao.quantization.utils import quantize_activation_per_token_absmax as torchao_int8_quant
except ImportError:
quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
try:
from torchao.quantization.utils import _quant_int8_per_token_matmul as torchao_int8_gemm
from torchao.quantization.utils import _quantize_activation_per_token_absmax as torchao_int8_quant
except ImportError:
torchao_int8_gemm, torchao_int8_quant = None, None

try:
import gguf
Expand Down Expand Up @@ -595,9 +600,16 @@ def per_block_cast_to_fp8(self, x):
# act quant kernels
# =========================
def act_quant_int8_perchannel_sym_torchao(self, x):
input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
input_tensor_quant, input_tensor_scale = torchao_int8_quant(x)
return input_tensor_quant, input_tensor_scale

def act_quant_fp8_perchannel_sym_torchao(self, x):
abs_max = x.abs().max(dim=-1, keepdim=True)[0]
abs_max = torch.clamp(abs_max, min=1e-8)
scale = abs_max / 448.0
quantized = torch.clamp(x / scale, -448, 448).to(torch.float8_e4m3fn)
return quantized, scale.float()

def act_quant_fp8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
return input_tensor_quant, input_tensor_scale
Expand Down Expand Up @@ -1109,6 +1121,37 @@ def apply(self, input_tensor):
return output_tensor


@MM_WEIGHT_REGISTER("fp8-torchao")
class MMWeightWfp8channelAfp8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Torchao

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Torchao (torch._scaled_mm)
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perchannel_sym
        # Weight is stored transposed for the column-major layout _scaled_mm expects.
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_torchao

    def apply(self, input_tensor):
        """Quantize activations per-token to FP8 and run a scaled FP8 GEMM.

        Dispatches through self.act_quant_func (set in __init__) for
        consistency with the int8-torchao variant, rather than hard-coding
        the quant kernel here.
        """
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        out = torch._scaled_mm(
            input_tensor_quant,
            self.weight,
            scale_a=input_tensor_scale,
            scale_b=self.weight_scale.t(),
            bias=self.bias,
            out_dtype=self.infer_dtype,
            use_fast_accum=True,
        )
        return out


@MM_WEIGHT_REGISTER("int8-torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
"""
Expand All @@ -1129,9 +1172,9 @@ def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_
def apply(self, input_tensor):
input_tensor = input_tensor
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
output_tensor = torchao_int8_gemm(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
if self.bias is not None:
output_tensor = output_tensor + self.bias
output_tensor.add_(self.bias)

return output_tensor

Expand Down
3 changes: 3 additions & 0 deletions lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Q8FQuantLinearInt8, # noqa E402
SglQuantLinearFp8, # noqa E402
TorchaoQuantLinearInt8, # noqa E402
TorchaoQuantLinearFp8, # noqa E402
VllmQuantLinearInt8, # noqa E402
)
from lightx2v_platform.base.global_var import AI_DEVICE # noqa E402
Expand Down Expand Up @@ -131,6 +132,8 @@ def load_text_encoder(
linear_cls = SglQuantLinearFp8
elif text_encoder_quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
elif text_encoder_quant_scheme == "fp8-torchao":
linear_cls = TorchaoQuantLinearFp8
elif text_encoder_quant_scheme == "int8-q8f":
linear_cls = Q8FQuantLinearInt8
elif text_encoder_quant_scheme == "fp8-q8f":
Expand Down
65 changes: 60 additions & 5 deletions lightx2v/models/input_encoders/hf/q_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
sgl_kernel = None

try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
from torchao.quantization.utils import quant_int8_per_token_matmul as torchao_int8_gemm
from torchao.quantization.utils import quantize_activation_per_token_absmax as torchao_int8_quant
except ImportError:
quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
try:
from torchao.quantization.utils import _quant_int8_per_token_matmul as torchao_int8_gemm
from torchao.quantization.utils import _quantize_activation_per_token_absmax as torchao_int8_quant
except ImportError:
torchao_int8_gemm, torchao_int8_quant = None, None

try:
from q8_kernels.functional.linear import q8_linear
Expand Down Expand Up @@ -194,15 +199,15 @@ def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
self.register_buffer("bias", None)

def act_quant_func(self, x):
input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
input_tensor_quant, input_tensor_scale = torchao_int8_quant(x)
return input_tensor_quant, input_tensor_scale

def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
output_tensor = torchao_int8_gemm(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
if self.bias is not None:
output_tensor = output_tensor + self.bias
output_tensor = output_tensor.add_(self.bias)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The add_ method performs an in-place addition and returns the modified tensor. The assignment back to output_tensor is redundant. You can simplify this to just output_tensor.add_(self.bias) for better clarity.

Suggested change
output_tensor = output_tensor.add_(self.bias)
output_tensor.add_(self.bias)


return output_tensor.unsqueeze(0)

Expand All @@ -221,6 +226,56 @@ def maybe_cast(t):
return self


class TorchaoQuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features

self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))

if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)

def act_quant_func(self, x):
abs_max = x.abs().max(dim=-1, keepdim=True)[0]
abs_max = torch.clamp(abs_max, min=1e-8)
scale = abs_max / 448.0
quantized = torch.clamp(x / scale, -448, 448).to(torch.float8_e4m3fn)
return quantized, scale.float()
Comment on lines +243 to +248
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This act_quant_func implementation is identical to act_quant_fp8_perchannel_sym_torchao in lightx2v/common/ops/mm/mm_weight.py. To avoid code duplication and improve maintainability, consider moving this logic to a shared utility function and importing it in both places.


def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
out = torch._scaled_mm(
input_tensor_quant,
self.weight.t(),
scale_a=input_tensor_scale,
scale_b=self.weight_scale.t(),
bias=self.bias.to(torch.bfloat16) if self.bias is not None else None,
out_dtype=torch.bfloat16,
use_fast_accum=True,
)
return out.unsqueeze(0)

def _apply(self, fn):
for module in self.children():
module._apply(fn)

def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t

self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self


class Q8FQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
super().__init__()
Expand Down
5 changes: 5 additions & 0 deletions lightx2v/models/input_encoders/hf/wan/t5/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Q8FQuantLinearInt8, # noqa E402
SglQuantLinearFp8, # noqa E402
TorchaoQuantLinearInt8, # noqa E402
TorchaoQuantLinearFp8, # noqa E402
VllmQuantLinearInt8, # noqa E402,
VllmQuantLinearFp8, # noqa E402
)
Expand Down Expand Up @@ -200,6 +201,8 @@ def __init__(
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
elif quant_scheme == "fp8-torchao":
linear_cls = TorchaoQuantLinearFp8
elif quant_scheme == "int8-q8f":
linear_cls = Q8FQuantLinearInt8
elif quant_scheme == "fp8-q8f":
Expand Down Expand Up @@ -275,6 +278,8 @@ def __init__(
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
elif quant_scheme == "fp8-torchao":
linear_cls = TorchaoQuantLinearFp8
elif quant_scheme == "int8-q8f":
linear_cls = Q8FQuantLinearInt8
elif quant_scheme == "fp8-q8f":
Expand Down
6 changes: 5 additions & 1 deletion lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# from lightx2v.attentions import attention
from lightx2v.common.ops.attn import TorchSDPAWeight
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
from lightx2v.utils.utils import load_weights
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8
Expand Down Expand Up @@ -69,6 +69,8 @@ def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
elif quant_scheme == "fp8-torchao":
linear_cls = TorchaoQuantLinearFp8
elif quant_scheme == "int8-q8f":
linear_cls = Q8FQuantLinearInt8
elif quant_scheme == "fp8-q8f":
Expand Down Expand Up @@ -153,6 +155,8 @@ def __init__(
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
elif quant_scheme == "fp8-torchao":
linear_cls = TorchaoQuantLinearFp8
elif quant_scheme == "int8-q8f":
linear_cls = Q8FQuantLinearInt8
elif quant_scheme == "fp8-q8f":
Expand Down
1 change: 1 addition & 0 deletions lightx2v/models/networks/hunyuan_video/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, model_path, config, device):
"fp8-sgl",
"int8-sgl",
"int8-torchao",
"fp8-torchao",
"nvfp4",
"mxfp4",
"mxfp6-mxfp8",
Expand Down
4 changes: 2 additions & 2 deletions lightx2v/models/networks/wan/audio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def __init__(self, model_path, config, device):
def _load_adapter_ckpt(self):
if self.config.get("adapter_model_path", None) is None:
if self.config.get("adapter_quantized", False):
if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl"]:
if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl", "fp8-torchao"]:
adapter_model_name = "audio_adapter_model_fp8.safetensors"
elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-sgl", "int8-tmo"]:
elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-torchao", "int8-sgl", "int8-tmo"]:
adapter_model_name = "audio_adapter_model_int8.safetensors"
elif self.config.get("adapter_quant_scheme", None) in ["mxfp4"]:
adapter_model_name = "audio_adapter_model_mxfp4.safetensors"
Expand Down
1 change: 1 addition & 0 deletions lightx2v/models/networks/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, model_path, config, device, model_type="wan2.1"):
"fp8-sgl",
"int8-sgl",
"int8-torchao",
"fp8-torchao",
"nvfp4",
"mxfp4",
"mxfp6-mxfp8",
Expand Down
14 changes: 14 additions & 0 deletions lightx2v/models/runners/default_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,17 @@ def run_pipeline(self, input_info):
if GET_RECORDER_MODE():
monitor_cli.lightx2v_worker_request_success.inc()
return gen_video_final

def __del__(self):
    """Best-effort release of model components and cached GPU memory.

    NOTE(review): __del__ is not a reliable cleanup hook — it may run late,
    never (reference cycles), or during interpreter shutdown. An explicit
    cleanup()/close() method called by the owner would be more deterministic;
    this remains as a safety net.
    """
    # At interpreter shutdown, module globals (torch, gc) may already be
    # cleared and attribute access can fail; __del__ must never raise,
    # since exceptions here are only printed to stderr and can mask errors.
    try:
        for attr in ("model", "text_encoders", "image_encoder", "vae_encoder", "vae_decoder"):
            if hasattr(self, attr):
                delattr(self, attr)
        torch.cuda.empty_cache()
        gc.collect()
    except Exception:
        pass
Comment on lines +421 to +433
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using __del__ for resource cleanup, especially for GPU memory, is unreliable. The __del__ method is not guaranteed to be called when you expect it, due to Python's garbage collection behavior (e.g., circular references). This can lead to resource leaks. It's better to implement an explicit cleanup method, like cleanup() or close(), and ensure it's called deterministically when the runner is no longer needed.