Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hunyuan Video adjustments #11140

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
21 changes: 12 additions & 9 deletions examples/community/pipeline_stg_hunyuan_video.py
Original file line number Diff line number Diff line change
@@ -189,11 +189,14 @@ def retrieve_timesteps(
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
logger.warning(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
f" sigmas schedules. Please check whether you are using the correct scheduler. The pipeline"
f" will continue without setting sigma values"
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
else:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
@@ -723,9 +726,9 @@ def __call__(
timestep = t.expand(latents.shape[0]).to(latents.dtype)

if self.do_spatio_temporal_guidance:
for i in stg_applied_layers_idx:
self.transformer.transformer_blocks[i].forward = types.MethodType(
forward_without_stg, self.transformer.transformer_blocks[i]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note forward_without_stg vs forward_with_stg. Let's use something like stg_idx here so it doesn't conflict with index of enumerate(timesteps).

However, any results you have to share using this PR would be interesting, as it is using forward_with_stg for both noise_pred and noise_pred_perturb.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks for the catch, I'll update. I found that this implementation became scheduler agnostic (or stg became the 'scheduler'?). I'll test that a bit more and see what the exact side effects are

for stg_idx in stg_applied_layers_idx:
self.transformer.transformer_blocks[stg_idx].forward = types.MethodType(
forward_without_stg, self.transformer.transformer_blocks[stg_idx]
)

noise_pred = self.transformer(
@@ -740,9 +743,9 @@ def __call__(
)[0]

if self.do_spatio_temporal_guidance:
for i in stg_applied_layers_idx:
self.transformer.transformer_blocks[i].forward = types.MethodType(
forward_with_stg, self.transformer.transformer_blocks[i]
for stg_idx in stg_applied_layers_idx:
self.transformer.transformer_blocks[stg_idx].forward = types.MethodType(
forward_with_stg, self.transformer.transformer_blocks[stg_idx]
)

noise_pred_perturb = self.transformer(
Original file line number Diff line number Diff line change
@@ -139,11 +139,14 @@ def retrieve_timesteps(
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
logger.warning(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
f" sigmas schedules. Please check whether you are using the correct scheduler. The pipeline"
f" will continue without setting sigma values"
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
scheduler.set_timesteps(num_inference_steps, device=device)
else:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
Original file line number Diff line number Diff line change
@@ -128,11 +128,14 @@ def retrieve_timesteps(
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
logger.warning(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
f" sigmas schedules. Please check whether you are using the correct scheduler. The pipeline"
f" will continue without setting sigma values"
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
scheduler.set_timesteps(num_inference_steps, device=device)
else:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
Original file line number Diff line number Diff line change
@@ -141,11 +141,14 @@ def retrieve_timesteps(
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
logger.warning(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
f" sigmas schedules. Please check whether you are using the correct scheduler. The pipeline"
f" will continue without setting sigma values"
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
scheduler.set_timesteps(num_inference_steps, device=device)
else:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else: