16 changes: 16 additions & 0 deletions configs/neopp/neopp_dense_8steps.json
@@ -0,0 +1,16 @@
{
"version": "dense",
"load_kv_cache_in_pipeline_for_debug": true,
"infer_steps": 8,
"attn_type": "flash_attn3",
"timestep_shift": 3.0,
"cfg_interval": [-1, 2],
"enable_cfg": false,
"use_triton_qknorm_rope": true,
"lora_configs": [
{
"path": "/data/nvme1/yongyang/kkk/models/sensenova/SenseNova-U1-8B-MoT-LoRAs/SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors",
Contributor

medium

The LoRA path is hardcoded to a specific user's directory (/data/nvme1/yongyang/...). This makes the configuration file non-portable. Consider using a relative path or a placeholder that can be resolved at runtime.
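For example, the JSON could carry an environment-variable placeholder (e.g. "${LORA_DIR}/SenseNova-U1-8B-MoT-LoRA-8step-V1.0.safetensors") that the loader expands at runtime. A minimal sketch of such a resolver — the helper name and placeholder convention are assumptions, not part of this PR:

import os
from pathlib import Path

def resolve_lora_path(configured_path: str) -> str:
    # Expand "${LORA_DIR}"-style environment-variable placeholders and "~"
    # so the same JSON config works across machines (hypothetical helper).
    return str(Path(os.path.expandvars(configured_path)).expanduser())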

"strength": 1.0
}
]
}
87 changes: 87 additions & 0 deletions examples/neopp/neopp_dense_2k_8steps.py
@@ -0,0 +1,87 @@
from lightx2v import LightX2VPipeline

# -------------------------------------------------
# Initialize pipeline for NeoPP
# -------------------------------------------------

pipe = LightX2VPipeline(
model_path="/data/nvme1/yongyang/kkk/models/sensenova/SenseNova-U1-8B-MoT",
Contributor

medium

This example script contains several hardcoded absolute paths (e.g., lines 8, 25, 26, 39) pointing to a specific user's environment. To improve reproducibility and portability, consider using relative paths or environment variables.
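A more portable variant of the constructor call above, assuming a hypothetical LIGHTX2V_MODEL_ROOT environment variable (not defined by this PR):

import os
from lightx2v import LightX2VPipeline

MODEL_ROOT = os.environ.get("LIGHTX2V_MODEL_ROOT", "./models")  # assumed env var
pipe = LightX2VPipeline(
    model_path=os.path.join(MODEL_ROOT, "SenseNova-U1-8B-MoT"),
    model_cls="neopp",
    support_tasks=["t2i", "i2i"],
)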

model_cls="neopp",
support_tasks=["t2i", "i2i"],
)

pipe.create_generator(config_json="../../configs/neopp/neopp_dense_8steps.json")
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": True})


# -------------------------------------------------
# Load KV cache and generate
# -------------------------------------------------

# -------------------------------------------------
# TURN 0
# -------------------------------------------------
pipe.runner.load_kvcache(
"/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_cond_kv_0_298.pt",
"/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_uncond_kv_0_9.pt",
)
pipe.runner.set_inference_params(
index_offset_cond=298,
index_offset_uncond=9,
cfg_interval=(-1, 2),
cfg_scale=4.0,
Contributor

medium

The cfg_scale is set to 4.0, but the corresponding configuration file neopp_dense_8steps.json has "enable_cfg": false. In the NeoppModel.infer implementation, the CFG logic is entirely skipped if enable_cfg is False. This might be confusing for users; consider aligning the example with the config or enabling CFG if it's intended to be used.
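If CFG is actually intended, one way to align the example with the config — a sketch reusing the modify_config call the example already makes on the pipe object:

# Override the JSON config from the example script so enable_cfg matches the cfg_scale being set:
pipe.modify_config({"enable_cfg": True, "cfg_interval": [-1, 2], "cfg_scale": 4.0})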

cfg_norm="none",
timestep_shift=3.0,
)

pipe.generate(
seed=200,
save_result_path="/data/nvme1/yongyang/kkk/models/LightX2V/save_results/output_lightx2v_neopp_dense_2k_0.png",
target_shape=[2048, 2048], # Height, Width
)


# # -------------------------------------------------
# # TURN 1
# # -------------------------------------------------
# pipe.runner.load_kvcache(
# "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_cond_kv_1_360.pt",
# "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_uncond_kv_1_12.pt",
# )
# pipe.runner.set_inference_params(
# index_offset_cond=366,
# index_offset_uncond=12,
# cfg_interval=(-1, 2),
# cfg_scale=4.0,
# cfg_norm="none",
# timestep_shift=3.0,
# )

# pipe.generate(
# seed=None,
# save_result_path="/data/nvme1/yongyang/kkk/LightX2V/save_results/output_lightx2v_neopp_dense_2k_1.png",
# target_shape=[2048, 2048], # Height, Width
# )


# # -------------------------------------------------
# # TURN 2
# # -------------------------------------------------
# pipe.runner.load_kvcache(
# "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_cond_kv_2_439.pt",
# "/data/nvme1/yongyang/FL/neo_9b_new/vlm_tensor_44000_ema_2k/to_x2v_uncond_kv_2_15.pt",
# )
# pipe.runner.set_inference_params(
# index_offset_cond=441,
# index_offset_uncond=15,
# cfg_interval=(-1, 2),
# cfg_scale=4.0,
# cfg_norm="none",
# timestep_shift=3.0,
# )

# pipe.generate(
# seed=None,
# save_result_path="/data/nvme1/yongyang/kkk/LightX2V/save_results/output_lightx2v_neopp_dense_2k_2.png",
# target_shape=[2048, 2048], # Height, Width
# )
Comment on lines +44 to +87
Contributor

medium

There are large blocks of commented-out code (Turns 1 and 2). If these are not essential for the example, they should be removed to maintain code cleanliness.

78 changes: 42 additions & 36 deletions lightx2v/models/networks/neopp/model.py
@@ -18,12 +18,13 @@ class NeoppModel(BaseTransformerModel):
transformer_weight_class = NeoppTransformerWeights
post_weight_class = NeoppPostWeights

def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def __init__(self, model_path, config, device, lora_path=None, lora_strength=1.0):
super().__init__(model_path, config, device, None, lora_path, lora_strength)
self.preserved_keys = ["fm_modules", "mot_gen"]
self._init_infer_class()
self._init_infer()
self._init_weights()
self.enable_cfg = self.config.get("enable_cfg", True)
self.cfg_interval = self.config.get("cfg_interval", (-1, 2))
self.cfg_scale = self.config.get("cfg_scale", 4.0)
self.cfg_norm = self.config.get("cfg_norm", "global")
@@ -89,9 +90,6 @@ def cfg_norm_func(self, v_pred, v_pred_condition):
return v_pred

def _infer_t2i_i2i(self, inputs, pre_infer_out):
t = self.scheduler.timesteps[self.scheduler.step_index]
use_cfg = t >= self.cfg_interval[0] and t <= self.cfg_interval[1] and self.cfg_scale > 1

# Precompute the image_embeds for each pass: under seq_parallel, slice out this rank's shard; otherwise reference the original tensor directly.
# This way _infer_cond_uncond does not have to chunk/restore on every call, avoiding cross-contamination between passes.
if self.seq_p_group is not None:
@@ -109,40 +107,48 @@ def _infer_t2i_i2i(self, inputs, pre_infer_out):
pre_infer_out.image_embeds_cond = pre_infer_out.image_embeds
pre_infer_out.image_embeds_uncond = pre_infer_out.image_embeds

if self.config.get("cfg_parallel", False):
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
# assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
cfg_p_rank = dist.get_rank(cfg_p_group)

cfg_p_world_size = dist.get_world_size(cfg_p_group)
if use_cfg:
if cfg_p_rank == 0:
v_pred = self._infer_cond_uncond(inputs, pre_infer_out, True)
if self.enable_cfg:
t = self.scheduler.timesteps[self.scheduler.step_index]
use_cfg = t >= self.cfg_interval[0] and t <= self.cfg_interval[1] and self.cfg_scale > 1

if self.config.get("cfg_parallel", False):
# ==================== CFG Parallel Processing ====================
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
cfg_p_rank = dist.get_rank(cfg_p_group)
cfg_p_world_size = dist.get_world_size(cfg_p_group)

if use_cfg:
if cfg_p_rank == 0:
v_pred = self._infer_cond_uncond(inputs, pre_infer_out, infer_condition=True)
else:
v_pred = self._infer_cond_uncond(inputs, pre_infer_out, infer_condition=False)
v_pred_list = [torch.zeros_like(v_pred) for _ in range(cfg_p_world_size)]
dist.all_gather(v_pred_list, v_pred, group=cfg_p_group)
v_pred_cond, v_pred_uncond = v_pred_list[0], v_pred_list[1]
v_pred = v_pred_uncond + self.cfg_scale * (v_pred_cond - v_pred_uncond)
v_pred = self.cfg_norm_func(v_pred, v_pred_cond)
return v_pred
else:
v_pred = self._infer_cond_uncond(inputs, pre_infer_out, False)
v_pred_list = [torch.zeros_like(v_pred) for _ in range(cfg_p_world_size)]
dist.all_gather(v_pred_list, v_pred, group=cfg_p_group)
v_pred_cond, v_pred_uncond = v_pred_list[0], v_pred_list[1]
v_pred = v_pred_uncond + self.cfg_scale * (v_pred_cond - v_pred_uncond)
v_pred = self.cfg_norm_func(v_pred, v_pred_cond)
return v_pred
# Outside the cfg interval only rank 0 runs the cond inference; the other ranks receive the result via all_gather
if cfg_p_rank == 0:
v_pred = self._infer_cond_uncond(inputs, pre_infer_out, infer_condition=True)
else:
v_pred = torch.zeros_like(pre_infer_out.z)
v_pred_list = [torch.zeros_like(v_pred) for _ in range(cfg_p_world_size)]
dist.all_gather(v_pred_list, v_pred, group=cfg_p_group)
return v_pred_list[0]
else:
# Outside the cfg interval only rank 0 runs the cond inference; the other ranks receive the result via all_gather and do not need to hold the cond kvcache
if cfg_p_rank == 0:
v_pred = self._infer_cond_uncond(inputs, pre_infer_out, True)
else:
v_pred = torch.zeros_like(pre_infer_out.z)
v_pred_list = [torch.zeros_like(v_pred) for _ in range(cfg_p_world_size)]
dist.all_gather(v_pred_list, v_pred, group=cfg_p_group)
return v_pred_list[0]
# ==================== CFG Processing ====================
v_pred_cond = self._infer_cond_uncond(inputs, pre_infer_out, infer_condition=True)
if use_cfg:
v_pred_uncond = self._infer_cond_uncond(inputs, pre_infer_out, infer_condition=False)
v_pred = v_pred_uncond + self.cfg_scale * (v_pred_cond - v_pred_uncond)
v_pred = self.cfg_norm_func(v_pred, v_pred_cond)
return v_pred
return v_pred_cond
else:
v_pred_condition = self._infer_cond_uncond(inputs, pre_infer_out, True)
if use_cfg:
v_pred_uncond = self._infer_cond_uncond(inputs, pre_infer_out, False)
v_pred = v_pred_uncond + self.cfg_scale * (v_pred_condition - v_pred_uncond)
v_pred = self.cfg_norm_func(v_pred, v_pred_condition)
return v_pred
return v_pred_condition
# ==================== No CFG Processing ====================
return self._infer_cond_uncond(inputs, pre_infer_out, infer_condition=True)

def _infer_cond_uncond(self, inputs, pre_infer_out, infer_condition: bool):
self.scheduler.infer_condition = infer_condition
48 changes: 28 additions & 20 deletions lightx2v/models/networks/neopp/weights/transformer_weights.py
@@ -30,6 +30,7 @@ def __init__(self, config, lazy_load_path=None, lora_path=None):
config=self.config,
mm_type=self.mm_type,
attn_type=self.attn_type,
lora_path=lora_path,
)
for i in range(self.blocks_num)
)
@@ -47,7 +48,7 @@ def __init__(self, config, lazy_load_path=None, lora_path=None):


class NeoppDecoderLayerWeights(WeightModule):
def __init__(self, block_index, config, mm_type, attn_type="flash_attn2"):
def __init__(self, block_index, config, mm_type, attn_type="flash_attn2", lora_path=None):
super().__init__()
prefix = f"language_model.model.layers.{block_index}"

@@ -57,7 +58,7 @@ def __init__(self, block_index, config, mm_type, attn_type="flash_attn2"):
)

use_triton_qknorm_rope = config.get("use_triton_qknorm_rope", True)
attn = NeoppAttentionWeights(config, block_index, mm_type, attn_type, use_triton_qknorm_rope)
attn = NeoppAttentionWeights(config, block_index, mm_type, attn_type, use_triton_qknorm_rope, lora_path=lora_path)
self.add_module("self_attn", attn)

self.add_module(
@@ -67,26 +68,27 @@ def __init__(self, block_index, config, mm_type, attn_type="flash_attn2"):

if config["version"] == "moe":
gen_num_experts = int(config["llm_config"]["gen_num_experts"])
mlp_mot_gen = NeoppSparseMoeWeights(block_index, mm_type, "mlp_mot_gen", gen_num_experts)
mlp_mot_gen = NeoppSparseMoeWeights(block_index, mm_type, "mlp_mot_gen", gen_num_experts, lora_path=lora_path)
elif config["version"] == "dense":
mlp_mot_gen = NeoppMlpWeights(block_index, mm_type)
mlp_mot_gen = NeoppMlpWeights(block_index, mm_type, lora_path=lora_path)
else:
raise ValueError(f"Unsupported version: {config['version']}")
self.add_module("mlp_mot_gen", mlp_mot_gen)


class NeoppAttentionWeights(WeightModule):
def __init__(self, config, block_index, mm_type, attn_type="flash_attn2", use_triton_qknorm_rope=True):
def __init__(self, config, block_index, mm_type, attn_type="flash_attn2", use_triton_qknorm_rope=True, lora_path=None):
super().__init__()
prefix = f"language_model.model.layers.{block_index}.self_attn"
lora_prefix = "language_model"

self.add_module("q_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.q_proj_mot_gen.weight", None))
self.add_module("q_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.q_proj_mot_gen.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

self.add_module("k_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.k_proj_mot_gen.weight", None))
self.add_module("k_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.k_proj_mot_gen.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

self.add_module("v_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.v_proj_mot_gen.weight", None))
self.add_module("v_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.v_proj_mot_gen.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

self.add_module("o_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.o_proj_mot_gen.weight", None))
self.add_module("o_proj_mot_gen", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.o_proj_mot_gen.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

if use_triton_qknorm_rope:
# Fused triton kernel: single module holds all 4 norm weights and applies
@@ -128,14 +130,15 @@ def __init__(self, config, block_index, mm_type, attn_type="flash_attn2", use_triton_qknorm_rope=True):


class NeoppSparseMoeWeights(WeightModule):
def __init__(self, block_index, mm_type, subname, num_experts):
def __init__(self, block_index, mm_type, subname, num_experts, lora_path=None):
super().__init__()
prefix = f"language_model.model.layers.{block_index}.{subname}"
lora_prefix = "language_model"

self.add_module("gate", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate.weight", None))
self.add_module("gate", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

self.num_experts = num_experts
experts = WeightModuleList(NeoppMoeSingleExpertWeights(block_index, mm_type, subname, j) for j in range(num_experts))
experts = WeightModuleList(NeoppMoeSingleExpertWeights(block_index, mm_type, subname, j, lora_path=lora_path) for j in range(num_experts))
self.add_module("experts", experts)

def load(self, weight_dict):
@@ -157,21 +160,23 @@ def _build_flashinfer_weights(self):


class NeoppMoeSingleExpertWeights(WeightModule):
def __init__(self, block_index, mm_type, subname, expert_index):
def __init__(self, block_index, mm_type, subname, expert_index, lora_path=None):
super().__init__()
prefix = f"language_model.model.layers.{block_index}.{subname}.experts.{expert_index}"
self.add_module("gate_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate_proj.weight", None))
self.add_module("up_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.up_proj.weight", None))
self.add_module("down_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.down_proj.weight", None))
lora_prefix = "language_model"
self.add_module("gate_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))
self.add_module("up_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.up_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))
self.add_module("down_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.down_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))


class NeoppMlpWeights(WeightModule):
def __init__(self, block_index, mm_type):
def __init__(self, block_index, mm_type, lora_path=None):
super().__init__()
prefix = f"language_model.model.layers.{block_index}.mlp_mot_gen"
self.add_module("gate_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate_proj.weight", None))
self.add_module("up_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.up_proj.weight", None))
self.add_module("down_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.down_proj.weight", None))
lora_prefix = "language_model"
self.add_module("gate_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))
self.add_module("up_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.up_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))
self.add_module("down_proj", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.down_proj.weight", None, lora_prefix=lora_prefix, lora_path=lora_path))

# def load(self, weight_dict):
# super().load(weight_dict)
@@ -186,11 +191,13 @@ def __init__(self, block_index, mm_type):
class NeoppFmHeadWeights(WeightModule):
def __init__(self, mm_type):
Contributor

high

The NeoppFmHeadWeights class is missing the lora_path parameter in its __init__ method. This will prevent LoRA weights from being correctly associated with the FM head modules when using lazy loading. Note that you will also need to update the call site in NeoppTransformerWeights (around line 46) to pass this parameter.

Suggested change
def __init__(self, mm_type):
def __init__(self, mm_type, lora_path=None):
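The corresponding call-site update in NeoppTransformerWeights might then look roughly like the line below; the module name and surrounding code are assumptions, since that part of the file is not shown in this diff:

# Sketch of the call site in NeoppTransformerWeights.__init__ (assumed, not shown in this diff):
self.add_module("fm_head", NeoppFmHeadWeights(self.mm_type, lora_path=lora_path))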

super().__init__()
lora_prefix = "fm_modules"
self.add_module(
"fm_head_0",
MM_WEIGHT_REGISTER["Default"](
"fm_modules.fm_head.0.weight",
"fm_modules.fm_head.0.bias",
lora_prefix=lora_prefix,
Contributor

high

Pass the lora_path to the weight register to ensure LoRA support for the FM head modules.

Suggested change
lora_prefix=lora_prefix,
lora_prefix=lora_prefix, lora_path=lora_path,

),
)

@@ -199,5 +206,6 @@ def __init__(self, mm_type):
MM_WEIGHT_REGISTER["Default"](
"fm_modules.fm_head.2.weight",
"fm_modules.fm_head.2.bias",
lora_prefix=lora_prefix,
Contributor

high

Pass the lora_path to the weight register to ensure LoRA support for the FM head modules.

Suggested change
lora_prefix=lora_prefix,
lora_prefix=lora_prefix, lora_path=lora_path,

),
)
30 changes: 29 additions & 1 deletion lightx2v/models/runners/neopp/neopp_runner.py
@@ -6,6 +6,7 @@
import torchvision.io as io
from PIL import Image

from lightx2v.models.networks.lora_adapter import LoraAdapter
from lightx2v.models.networks.neopp.model import NeoppModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.neopp.scheduler import NeoppMoeScheduler
@@ -16,6 +17,24 @@
from lightx2v_platform.base.global_var import AI_DEVICE


def build_neopp_model_with_lora(neopp_module, config, model_kwargs, lora_configs):
lora_dynamic_apply = config.get("lora_dynamic_apply", False)

if lora_dynamic_apply:
lora_path = lora_configs[0]["path"]
lora_strength = lora_configs[0]["strength"]
model_kwargs["lora_path"] = lora_path
model_kwargs["lora_strength"] = lora_strength
Comment on lines +24 to +27
Contributor

medium

When lora_dynamic_apply is True, this function only processes the first entry in lora_configs. If multiple LoRAs are provided, the subsequent ones will be silently ignored. Consider iterating through the list or adding a check to warn the user if multiple LoRAs are supplied.
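A minimal guard along those lines, sketched with the standard warnings module inside build_neopp_model_with_lora (not part of this PR):

import warnings

if lora_dynamic_apply and len(lora_configs) > 1:
    warnings.warn(
        "lora_dynamic_apply currently applies only the first LoRA; "
        f"ignoring {len(lora_configs) - 1} additional lora_configs entries."
    )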

model = neopp_module(**model_kwargs)
else:
assert not config.get("dit_quantized", False), "Online LoRA only for quantized models; merging LoRA is unsupported."
assert not config.get("lazy_load", False), "Lazy load mode does not support LoRA merging."
model = neopp_module(**model_kwargs)
lora_adapter = LoraAdapter(model)
lora_adapter.apply_lora(lora_configs)
return model


@RUNNER_REGISTER("neopp")
class NeoppRunner(DefaultRunner):
def __init__(self, config):
@@ -53,7 +72,16 @@ def load_transformer(self):
MoT: Mixture-of-Transformer-Experts (MoT) architecture
https://arxiv.org/abs/2505.14683
"""
model = NeoppModel(self.config["model_path"], self.config, self.init_device)
neopp_model_kwargs = {
"model_path": self.config["model_path"],
"config": self.config,
"device": self.init_device,
}
lora_configs = self.config.get("lora_configs")
if not lora_configs:
model = NeoppModel(**neopp_model_kwargs)
else:
model = build_neopp_model_with_lora(NeoppModel, self.config, neopp_model_kwargs, lora_configs)
return model

def _build_inv_freq(self, half_head_dim, theta):