Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/z_image/z_image_turbo_t2i.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
"enable_cfg": false,
"sample_guide_scale": 0.0,
"patch_size": 2
}
}
8 changes: 4 additions & 4 deletions lightx2v/common/modules/weight_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from lightx2v_platform.base.global_var import AI_DEVICE


class WeightModule:
def __init__(self):
self._modules = {}
Expand Down Expand Up @@ -122,11 +124,10 @@ def to_cpu(self, non_blocking=False):

def to_cuda(self, non_blocking=False):
"""Move parameters to GPU device (supports cuda/intel xpu)"""
target_device = AI_DEVICE
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cuda"):
self._parameters[name] = param.to(target_device, non_blocking=non_blocking)
self._parameters[name] = param.to(AI_DEVICE, non_blocking=non_blocking)
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda()
setattr(self, name, self._parameters[name])
Expand Down Expand Up @@ -166,11 +167,10 @@ def to_cpu_async(self, non_blocking=True):
module.to_cpu(non_blocking=True)

def to_cuda_async(self, non_blocking=True):
target_device = AI_DEVICE
for name, param in self._parameters.items():
if param is not None:
if hasattr(param, "cuda"):
self._parameters[name] = param.to(target_device, non_blocking=non_blocking)
self._parameters[name] = param.to(AI_DEVICE, non_blocking=non_blocking)
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda(non_blocking=True)
setattr(self, name, self._parameters[name])
Expand Down
24 changes: 19 additions & 5 deletions lightx2v/common/ops/conv/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from lightx2v.common.ops.utils import *
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE


class Conv3dWeightTemplate(metaclass=ABCMeta):
Expand Down Expand Up @@ -35,8 +34,23 @@ def _get_lora_attr_mapping(self):
"weight_diff": "weight_diff_name",
"bias_diff": "bias_diff_name",
}
self.weight_diff = torch.tensor(0.0, dtype=GET_DTYPE(), device=AI_DEVICE)
self.bias_diff = torch.tensor(0.0, dtype=GET_DTYPE(), device=AI_DEVICE)

def _get_actual_weight(self):
if not hasattr(self, "weight_diff"):
return self.weight
return self.weight + self.weight_diff

def _get_actual_bias(self, bias=None):
if bias is not None:
if not hasattr(self, "bias_diff"):
return bias
return bias + self.bias_diff
else:
if not hasattr(self, "bias") or self.bias is None:
return None
if not hasattr(self, "bias_diff"):
return self.bias
return self.bias + self.bias_diff
Comment on lines +43 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic in _get_actual_bias can be simplified to improve readability and reduce duplication. The current implementation repeats the check for bias_diff. Consider refactoring to first determine the base bias and then apply the diff if it exists.

Suggested change
def _get_actual_bias(self, bias=None):
if bias is not None:
if not hasattr(self, "bias_diff"):
return bias
return bias + self.bias_diff
else:
if not hasattr(self, "bias") or self.bias is None:
return None
if not hasattr(self, "bias_diff"):
return self.bias
return self.bias + self.bias_diff
def _get_actual_bias(self, bias=None):
if bias is None:
if not hasattr(self, "bias") or self.bias is None:
return None
base_bias = self.bias
else:
base_bias = bias
if hasattr(self, "bias_diff"):
return base_bias + self.bias_diff
return base_bias


def register_diff(self, weight_dict):
if self.weight_diff_name in weight_dict:
Expand Down Expand Up @@ -74,8 +88,8 @@ def load(self, weight_dict):
def apply(self, input_tensor):
output_tensor = torch.nn.functional.conv3d(
input_tensor,
weight=self.weight + self.weight_diff,
bias=self.bias + self.bias_diff,
weight=self._get_actual_weight(),
bias=self._get_actual_bias(),
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
Expand Down
2 changes: 0 additions & 2 deletions lightx2v/models/networks/wan/sf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
class WanSFModel(WanModel):
def __init__(self, model_path, config, device, lora_path=None, lora_strength=1.0):
super().__init__(model_path, config, device, lora_path=lora_path, lora_strength=lora_strength)
if config["model_cls"] not in ["wan2.1_sf_mtxg2"]:
self.to_cuda()

def _load_ckpt(self, unified_dtype, sensitive_layer):
sf_confg = self.config["sf_config"]
Expand Down
5 changes: 3 additions & 2 deletions lightx2v_platform/base/intel_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
Expand Down Expand Up @@ -58,5 +57,7 @@ def get_device() -> str:
# """
# dist.init_process_group(backend="ccl")
# torch.xpu.set_device(dist.get_rank())


# Register alias "xpu" for backward compatibility
PLATFORM_DEVICE_REGISTER._dict["xpu"] = IntelXpuDevice
PLATFORM_DEVICE_REGISTER._dict["xpu"] = IntelXpuDevice