Support neopp model 8-step inference #1060
New file: configs/neopp/neopp_dense_8steps.json

```json
{
    "version": "dense",
    "load_kv_cache_in_pipeline_for_debug": true,
    "infer_steps": 8,
    "attn_type": "flash_attn3",
    "timestep_shift": 3.0,
    "cfg_interval": [-1, 2],
    "enable_cfg": false,
    "use_triton_qknorm_rope": true,
    "lora_configs": [
        {
            "path": "/data/nvme1/yongyang/kkk/models/sensenova/SenseNova-U1-8B-MoT-LoRAs/SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors",
            "strength": 1.0
        }
    ]
}
```
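The config pins `infer_steps` to 8 and sets `timestep_shift` to 3.0. How the shift reshapes the 8-step schedule is not shown in this PR; the sketch below assumes the widely used flow-matching shift t' = s*t / (1 + (s - 1)*t), which warps the few available steps toward the high-noise end. Treat the formula as an assumption about LightX2V's scheduler, not a confirmed detail.

```python
# Sketch only: assumes the common flow-matching shift t' = s*t / (1 + (s-1)*t);
# LightX2V's actual scheduler implementation may differ.
def shifted_timesteps(num_steps: int = 8, shift: float = 3.0) -> list:
    ts = [(num_steps - i) / num_steps for i in range(num_steps)]  # 1.0 down to 0.125
    return [shift * t / (1.0 + (shift - 1.0) * t) for t in ts]

print(shifted_timesteps())  # 8 values warped toward t=1 (the high-noise end)
```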
New file: example script for NeoPP 8-step inference (87 lines added; the file path is not shown in this extract).

```python
from lightx2v import LightX2VPipeline

# -------------------------------------------------
# Initialize pipeline for NeoPP
# -------------------------------------------------

pipe = LightX2VPipeline(
    model_path="/data/nvme1/yongyang/kkk/models/sensenova/SenseNova-U1-8B-MoT",
    model_cls="neopp",
    support_tasks=["t2i", "i2i"],
)

pipe.create_generator(config_json="../../configs/neopp/neopp_dense_8steps.json")
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": True})

# -------------------------------------------------
# Load KV cache and generate
# -------------------------------------------------

# -------------------------------------------------
# TURN 0
# -------------------------------------------------
pipe.runner.load_kvcache(
    "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_cond_kv_0_298.pt",
    "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_uncond_kv_0_9.pt",
)
pipe.runner.set_inference_params(
    index_offset_cond=298,
    index_offset_uncond=9,
    cfg_interval=(-1, 2),
    cfg_scale=4.0,
    cfg_norm="none",
    timestep_shift=3.0,
)

pipe.generate(
    seed=200,
    save_result_path="/data/nvme1/yongyang/kkk/models/LightX2V/save_results/output_lightx2v_neopp_dense_2k_0.png",
    target_shape=[2048, 2048],  # Height, Width
)

# # -------------------------------------------------
# # TURN 1
# # -------------------------------------------------
# pipe.runner.load_kvcache(
#     "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_cond_kv_1_360.pt",
#     "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_uncond_kv_1_12.pt",
# )
# pipe.runner.set_inference_params(
#     index_offset_cond=366,
#     index_offset_uncond=12,
#     cfg_interval=(-1, 2),
#     cfg_scale=4.0,
#     cfg_norm="none",
#     timestep_shift=3.0,
# )

# pipe.generate(
#     seed=None,
#     save_result_path="/data/nvme1/yongyang/kkk/LightX2V/save_results/output_lightx2v_neopp_dense_2k_1.png",
#     target_shape=[2048, 2048],  # Height, Width
# )


# # -------------------------------------------------
# # TURN 2
# # -------------------------------------------------
# pipe.runner.load_kvcache(
#     "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_cond_kv_2_439.pt",
#     "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_uncond_kv_2_15.pt",
# )
# pipe.runner.set_inference_params(
#     index_offset_cond=441,
#     index_offset_uncond=15,
#     cfg_interval=(-1, 2),
#     cfg_scale=4.0,
#     cfg_norm="none",
#     timestep_shift=3.0,
# )

# pipe.generate(
#     seed=None,
#     save_result_path="/data/nvme1/yongyang/kkk/LightX2V/save_results/output_lightx2v_neopp_dense_2k_2.png",
#     target_shape=[2048, 2048],  # Height, Width
# )
```
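The three turns differ only in KV-cache files, index offsets, and seed, so the same flow can be driven from a table. A sketch (paths and offsets copied from the turns above; output paths simplified to a relative directory; that `load_kvcache`/`set_inference_params` can be re-invoked per turn is an assumption based on this example):

```python
# Sketch: table-driven version of the three turns above. Values are copied
# verbatim from the example; the relative output directory is a simplification.
KV_DIR = "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k"
turns = [
    # (cond_kv, uncond_kv, index_offset_cond, index_offset_uncond, seed)
    (f"{KV_DIR}/to_x2v_cond_kv_0_298.pt", f"{KV_DIR}/to_x2v_uncond_kv_0_9.pt", 298, 9, 200),
    (f"{KV_DIR}/to_x2v_cond_kv_1_360.pt", f"{KV_DIR}/to_x2v_uncond_kv_1_12.pt", 366, 12, None),
    (f"{KV_DIR}/to_x2v_cond_kv_2_439.pt", f"{KV_DIR}/to_x2v_uncond_kv_2_15.pt", 441, 15, None),
]
for i, (cond_kv, uncond_kv, off_cond, off_uncond, seed) in enumerate(turns):
    pipe.runner.load_kvcache(cond_kv, uncond_kv)
    pipe.runner.set_inference_params(
        index_offset_cond=off_cond,
        index_offset_uncond=off_uncond,
        cfg_interval=(-1, 2),
        cfg_scale=4.0,
        cfg_norm="none",
        timestep_shift=3.0,
    )
    pipe.generate(
        seed=seed,
        save_result_path=f"save_results/output_lightx2v_neopp_dense_2k_{i}.png",
        target_shape=[2048, 2048],  # Height, Width
    )
```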
Modified file: the NeoPP weights module (file path not shown in this extract). `lora_path` is threaded from the model down into every weight class, and each `MM_WEIGHT_REGISTER` weight now receives a `lora_prefix`/`lora_path` pair.

```diff
@@ -30,6 +30,7 @@ def __init__(self, config, lazy_load_path=None, lora_path=None):
                 config=self.config,
                 mm_type=self.mm_type,
                 attn_type=self.attn_type,
+                lora_path=lora_path,
             )
             for i in range(self.blocks_num)
         )
@@ -47,7 +48,7 @@ def __init__(self, config, lazy_load_path=None, lora_path=None):
 class NeoppDecoderLayerWeights(WeightModule):
-    def __init__(self, block_index, config, mm_type, attn_type="flash_attn2"):
+    def __init__(self, block_index, config, mm_type, attn_type="flash_attn2", lora_path=None):
         super().__init__()
         prefix = f"language_model.model.layers.{block_index}"
@@ -57,7 +58,7 @@ def __init__(self, block_index, config, mm_type, attn_type="flash_attn2"):
         )

         use_triton_qknorm_rope = config.get("use_triton_qknorm_rope", True)
-        attn = NeoppAttentionWeights(config, block_index, mm_type, attn_type, use_triton_qknorm_rope)
+        attn = NeoppAttentionWeights(config, block_index, mm_type, attn_type, use_triton_qknorm_rope, lora_path=lora_path)
         self.add_module("self_attn", attn)

         self.add_module(
@@ -67,26 +68,27 @@ def __init__(self, block_index, config, mm_type, attn_type="flash_attn2"):
         if config["version"] == "moe":
             gen_num_experts = int(config["llm_config"]["gen_num_experts"])
-            mlp_mot_gen = NeoppSparseMoeWeights(block_index, mm_type, "mlp_mot_gen", gen_num_experts)
+            mlp_mot_gen = NeoppSparseMoeWeights(block_index, mm_type, "mlp_mot_gen", gen_num_experts, lora_path=lora_path)
         elif config["version"] == "dense":
-            mlp_mot_gen = NeoppMlpWeights(block_index, mm_type)
+            mlp_mot_gen = NeoppMlpWeights(block_index, mm_type, lora_path=lora_path)
         else:
             raise ValueError(f"Unsupported version: {config['version']}")
         self.add_module("mlp_mot_gen", mlp_mot_gen)


 class NeoppAttentionWeights(WeightModule):
-    def __init__(self, config, block_index, mm_type, attn_type="flash_attn2", use_triton_qknorm_rope=True):
+    def __init__(self, config, block_index, mm_type, attn_type="flash_attn2", use_triton_qknorm_rope=True, lora_path=None):
         super().__init__()
         prefix = f"language_model.model.layers.{block_index}.self_attn"
+        lora_prefix = "language_model"

-        self.add_module("q_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.q_proj_mot_gen.weight", None))
+        self.add_module("q_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.q_proj_mot_gen.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

-        self.add_module("k_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.k_proj_mot_gen.weight", None))
+        self.add_module("k_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.k_proj_mot_gen.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

-        self.add_module("v_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.v_proj_mot_gen.weight", None))
+        self.add_module("v_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.v_proj_mot_gen.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

-        self.add_module("o_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.o_proj_mot_gen.weight", None))
+        self.add_module("o_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.o_proj_mot_gen.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

         if use_triton_qknorm_rope:
             # Fused triton kernel: single module holds all 4 norm weights and applies
@@ -128,14 +130,15 @@ def __init__(self, config, block_index, mm_type, attn_type="flash_attn2", use_tr
 class NeoppSparseMoeWeights(WeightModule):
-    def __init__(self, block_index, mm_type, subname, num_experts):
+    def __init__(self, block_index, mm_type, subname, num_experts, lora_path=None):
         super().__init__()
         prefix = f"language_model.model.layers.{block_index}.{subname}"
+        lora_prefix = "language_model"

-        self.add_module("gate", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate.weight", None))
+        self.add_module("gate", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

         self.num_experts = num_experts
-        experts = WeightModuleList(NeoppMoeSingleExpertWeights(block_index, mm_type, subname, j) for j in range(num_experts))
+        experts = WeightModuleList(NeoppMoeSingleExpertWeights(block_index, mm_type, subname, j, lora_path=lora_path) for j in range(num_experts))
         self.add_module("experts", experts)

     def load(self, weight_dict):
@@ -157,21 +160,23 @@ def _build_flashinfer_weights(self):
 class NeoppMoeSingleExpertWeights(WeightModule):
-    def __init__(self, block_index, mm_type, subname, expert_index):
+    def __init__(self, block_index, mm_type, subname, expert_index, lora_path=None):
         super().__init__()
         prefix = f"language_model.model.layers.{block_index}.{subname}.experts.{expert_index}"
-        self.add_module("gate_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate_proj.weight", None))
-        self.add_module("up_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.up_proj.weight", None))
-        self.add_module("down_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.down_proj.weight", None))
+        lora_prefix = "language_model"
+        self.add_module("gate_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))
+        self.add_module("up_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.up_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))
+        self.add_module("down_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.down_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))


 class NeoppMlpWeights(WeightModule):
-    def __init__(self, block_index, mm_type):
+    def __init__(self, block_index, mm_type, lora_path=None):
         super().__init__()
         prefix = f"language_model.model.layers.{block_index}.mlp_mot_gen"
-        self.add_module("gate_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate_proj.weight", None))
-        self.add_module("up_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.up_proj.weight", None))
-        self.add_module("down_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.down_proj.weight", None))
+        lora_prefix = "language_model"
+        self.add_module("gate_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))
+        self.add_module("up_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.up_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))
+        self.add_module("down_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.down_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

     # def load(self, weight_dict):
     #     super().load(weight_dict)
@@ -186,11 +191,13 @@ def __init__(self, block_index, mm_type):
 class NeoppFmHeadWeights(WeightModule):
     def __init__(self, mm_type):
         super().__init__()
+        lora_prefix = "fm_modules"
         self.add_module(
             "fm_head_0",
             MM_WEIGHT_REGISTER["Default"](
                 "fm_modules.fm_head.0.weight",
                 "fm_modules.fm_head.0.bias",
+                lora_prefix=lora_prefix,
             ),
         )
@@ -199,5 +206,6 @@ def __init__(self, mm_type):
             MM_WEIGHT_REGISTER["Default"](
                 "fm_modules.fm_head.2.weight",
                 "fm_modules.fm_head.2.bias",
+                lora_prefix=lora_prefix,
             ),
         )
```
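Every weight above now receives `lora_prefix`/`lora_path`, but the `MM_WEIGHT_REGISTER` wrapper that consumes them is outside this diff. Below is a hypothetical sketch of what such a wrapper could do with the pair; the `lora_A`/`lora_B` key scheme and the strength handling are assumptions, not LightX2V's confirmed format.

```python
import torch
from safetensors.torch import load_file

# Hypothetical sketch only: the real MM_WEIGHT_REGISTER classes are not part
# of this diff, and the "...lora_A.weight"/"...lora_B.weight" key scheme is an
# assumption about how deltas are stored in the LoRA safetensors file.
class LoRALinearWeight:
    def __init__(self, weight_name, bias_name=None, lora_prefix=None, lora_path=None, strength=1.0):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.lora_prefix = lora_prefix   # e.g. "language_model" or "fm_modules"
        self.lora_path = lora_path       # .safetensors file with LoRA deltas
        self.strength = strength
        self.weight = None
        self.bias = None

    def load(self, weight_dict):
        w = weight_dict[self.weight_name]
        if self.lora_path and self.lora_prefix and self.weight_name.startswith(self.lora_prefix):
            lora = load_file(self.lora_path)
            base = self.weight_name[: -len(".weight")]
            a_key, b_key = f"{base}.lora_A.weight", f"{base}.lora_B.weight"
            if a_key in lora and b_key in lora:
                # Standard LoRA merge: W' = W + strength * (B @ A)
                delta = lora[b_key].to(torch.float32) @ lora[a_key].to(torch.float32)
                w = w + self.strength * delta.to(w.dtype)
        self.weight = w
        if self.bias_name is not None:
            self.bias = weight_dict.get(self.bias_name)
```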
Modified file: the NeoPP runner (file path not shown in this extract). It gains a helper that builds the model with LoRA either applied dynamically at load time or merged up front.

```diff
@@ -6,6 +6,7 @@
 import torchvision.io as io
 from PIL import Image

+from lightx2v.models.networks.lora_adapter import LoraAdapter
 from lightx2v.models.networks.neopp.model import NeoppModel
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.neopp.scheduler import NeoppMoeScheduler
@@ -16,6 +17,24 @@
 from lightx2v_platform.base.global_var import AI_DEVICE


+def build_neopp_model_with_lora(neopp_module, config, model_kwargs, lora_configs):
+    lora_dynamic_apply = config.get("lora_dynamic_apply", False)
+
+    if lora_dynamic_apply:
+        lora_path = lora_configs[0]["path"]
+        lora_strength = lora_configs[0]["strength"]
+        model_kwargs["lora_path"] = lora_path
+        model_kwargs["lora_strength"] = lora_strength
+        model = neopp_module(**model_kwargs)
+    else:
+        assert not config.get("dit_quantized", False), "Merging LoRA into quantized models is unsupported; use online (dynamic) LoRA apply instead."
+        assert not config.get("lazy_load", False), "Lazy load mode does not support LoRA merging."
+        model = neopp_module(**model_kwargs)
+        lora_adapter = LoraAdapter(model)
+        lora_adapter.apply_lora(lora_configs)
+    return model
+
+
 @RUNNER_REGISTER("neopp")
 class NeoppRunner(DefaultRunner):
     def __init__(self, config):
```
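`build_neopp_model_with_lora` has two modes: with `lora_dynamic_apply` it forwards the path and strength of `lora_configs[0]` (only the first entry is consulted) into the model so each weight can apply its delta at load time, as plumbed through the weights module above; otherwise it builds the model first and lets `LoraAdapter.apply_lora` merge the deltas in place. The merge itself is the standard low-rank update; a worked sketch (the adapter's real internals are not in this diff):

```python
import torch

# Standard LoRA merge, W' = W + strength * (B @ A): a numerical sketch of what
# an offline merge like LoraAdapter.apply_lora conceptually computes.
out_dim, in_dim, rank = 4096, 4096, 64
W = torch.randn(out_dim, in_dim, dtype=torch.bfloat16)
A = 0.01 * torch.randn(rank, in_dim)    # lora_A: (rank, in_features)
B = torch.randn(out_dim, rank)          # lora_B: (out_features, rank)
strength = 1.0                          # matches "strength" in lora_configs

delta = (B @ A).to(W.dtype)             # low-rank update, same shape as W
W_merged = W + strength * delta
assert W_merged.shape == W.shape
```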
```diff
@@ -53,7 +72,16 @@ def load_transformer(self):
         MoT: Mixture-of-Transformer-Experts (MoT) architecture
         https://arxiv.org/abs/2505.14683
         """
-        model = NeoppModel(self.config["model_path"], self.config, self.init_device)
+        neopp_model_kwargs = {
+            "model_path": self.config["model_path"],
+            "config": self.config,
+            "device": self.init_device,
+        }
+        lora_configs = self.config.get("lora_configs")
+        if not lora_configs:
+            model = NeoppModel(**neopp_model_kwargs)
+        else:
+            model = build_neopp_model_with_lora(NeoppModel, self.config, neopp_model_kwargs, lora_configs)
         return model

     def _build_inv_freq(self, half_head_dim, theta):
```
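`_build_inv_freq(self, half_head_dim, theta)` appears only as unchanged context, but the signature matches the standard RoPE inverse-frequency table. A sketch under that assumption (the method's real body is not in this diff):

```python
import torch

# Assumed standard RoPE inverse frequencies; illustrative only, since
# _build_inv_freq's body is not part of this PR.
def build_inv_freq(half_head_dim: int, theta: float) -> torch.Tensor:
    # inv_freq[i] = theta ** (-i / half_head_dim), one frequency per rotary pair
    exponents = torch.arange(half_head_dim, dtype=torch.float32) / half_head_dim
    return 1.0 / (theta ** exponents)

inv_freq = build_inv_freq(half_head_dim=64, theta=10000.0)  # shape (64,)
```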
Reviewer comment on the new config file:

The LoRA path is hardcoded to a specific user's directory (`/data/nvme1/yongyang/...`). This makes the configuration file non-portable. Consider using a relative path or a placeholder that can be resolved at runtime.
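One way to address this, sketched below and not part of the PR: keep a placeholder such as `${LORA_DIR}` in the JSON and expand it from the environment before the config reaches the runner. The placeholder convention and the helper are hypothetical.

```python
import json
import os
import string

# Hypothetical helper (not in LightX2V): expand ${VARS} in config string values
# so a path like "${LORA_DIR}/SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors"
# resolves per machine instead of hardcoding /data/nvme1/yongyang/...
def resolve_config_paths(config_path: str) -> dict:
    with open(config_path) as f:
        cfg = json.load(f)

    def expand(value):
        if isinstance(value, str):
            return string.Template(value).safe_substitute(os.environ)
        if isinstance(value, list):
            return [expand(v) for v in value]
        if isinstance(value, dict):
            return {k: expand(v) for k, v in value.items()}
        return value

    return expand(cfg)

# Usage: export LORA_DIR=/path/to/loras, then pass the resolved dict onward.
cfg = resolve_config_paths("configs/neopp/neopp_dense_8steps.json")
```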