diff --git a/README.md b/README.md
index 4d776ca2..c7e621de 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

 # What's new?
+- **`2025/8/14`**: LTX-Video img2vid generation is now supported.
 - **`2025/7/29`**: LTX-Video text2vid generation is now supported.
 - **`2025/04/17`**: Flux Finetuning.
 - **`2025/02/12`**: Flux LoRA for inference.
@@ -42,7 +43,7 @@ MaxDiffusion supports
 * Load Multiple LoRA (SDXL inference).
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x,2.x.
-* LTX-Video text2vid (inference).
+* LTX-Video text2vid, img2vid (inference).

 # Table of Contents

@@ -183,7 +184,8 @@ To generate images, run the following command:
 ```bash
 python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"
 ```
-  - Other generation parameters can be set in ltx_video.yml file.
+  - Img2vid generation:
+  Add the conditioning image path as `conditioning_media_paths` in the form `["IMAGE_PATH"]`, along with the other generation parameters, in the ltx_video.yml file. Then run the same command as above.

 ## Flux
 First make sure you have permissions to access the Flux repos in Huggingface.
diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml
index 5ed82c66..71316ea1 100644
--- a/src/maxdiffusion/configs/ltx_video.yml
+++ b/src/maxdiffusion/configs/ltx_video.yml
@@ -22,7 +22,7 @@ sampler: "from_checkpoint"

 # Generation parameters
 pipeline_type: multi-scale
-prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. "
+prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" height: 512 width: 512 @@ -35,6 +35,8 @@ stg_mode: "attention_values" decode_timestep: 0.05 decode_noise_scale: 0.025 seed: 10 +conditioning_media_paths: None #["IMAGE_PATH"] +conditioning_start_frames: [0] first_pass: diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 553d6373..6ecc6666 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -16,15 +16,19 @@ import numpy as np from absl import app -from typing import Sequence +from typing import Sequence, List, Optional, Union from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline -from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline +from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem +import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor from maxdiffusion import pyconfig, max_logging +import torchvision.transforms.functional as TVF import imageio from datetime import datetime import os import time from pathlib import Path +from PIL import Image +import torch def calculate_padding( @@ -44,6 +48,79 @@ def calculate_padding( return padding +def load_image_to_tensor_with_resize_and_crop( + image_input: Union[str, Image.Image], + target_height: int = 512, + target_width: int = 768, + just_crop: bool = False, +) -> torch.Tensor: + """Load and process an image into a tensor. 
+ + Args: + image_input: Either a file path (str) or a PIL Image object + target_height: Desired height of output tensor + target_width: Desired width of output tensor + just_crop: If True, only crop the image to the target size without resizing + """ + if isinstance(image_input, str): + image = Image.open(image_input).convert("RGB") + elif isinstance(image_input, Image.Image): + image = image_input + else: + raise ValueError("image_input must be either a file path or a PIL Image object") + + input_width, input_height = image.size + aspect_ratio_target = target_width / target_height + aspect_ratio_frame = input_width / input_height + if aspect_ratio_frame > aspect_ratio_target: + new_width = int(input_height * aspect_ratio_target) + new_height = input_height + x_start = (input_width - new_width) // 2 + y_start = 0 + else: + new_width = input_width + new_height = int(input_width / aspect_ratio_target) + x_start = 0 + y_start = (input_height - new_height) // 2 + + image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height)) + if not just_crop: + image = image.resize((target_width, target_height)) + + frame_tensor = TVF.to_tensor(image) # PIL -> tensor (C, H, W), [0,1] + frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0) + frame_tensor_hwc = frame_tensor.permute(1, 2, 0) # (C, H, W) -> (H, W, C) + frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc) + frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0 # (H, W, C) -> (C, H, W) + frame_tensor = (frame_tensor / 127.5) - 1.0 + # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width) + return frame_tensor.unsqueeze(0).unsqueeze(2) + + +def prepare_conditioning( + conditioning_media_paths: List[str], + conditioning_strengths: List[float], + conditioning_start_frames: List[int], + height: int, + width: int, + padding: tuple[int, int, int, int], +) -> Optional[List[ConditioningItem]]: + """Prepare conditioning items based on input media paths and their parameters.""" + conditioning_items = [] + for path, strength, start_frame in zip(conditioning_media_paths, conditioning_strengths, conditioning_start_frames): + num_input_frames = 1 + media_tensor = load_media_file( + media_path=path, + height=height, + width=width, + max_frames=num_input_frames, + padding=padding, + just_crop=True, + ) + conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength)) + return conditioning_items + + def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: # Remove non-letters and convert to lowercase clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace()) @@ -68,6 +145,19 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: return "-".join(result) +def load_media_file( + media_path: str, + height: int, + width: int, + max_frames: int, + padding: tuple[int, int, int, int], + just_crop: bool = False, +) -> torch.Tensor: + media_tensor = load_image_to_tensor_with_resize_and_crop(media_path, height, width, just_crop=just_crop) + media_tensor = torch.nn.functional.pad(media_tensor, padding) + return media_tensor + + def get_unique_filename( base: str, ext: str, @@ -97,6 +187,25 @@ def run(config): pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt) if config.pipeline_type == "multi-scale": pipeline = LTXMultiScalePipeline(pipeline) + conditioning_media_paths = config.conditioning_media_paths if isinstance(config.conditioning_media_paths, List) else None + conditioning_start_frames = 
config.conditioning_start_frames + conditioning_strengths = None + if conditioning_media_paths: + if not conditioning_strengths: + conditioning_strengths = [1.0] * len(conditioning_media_paths) + conditioning_items = ( + prepare_conditioning( + conditioning_media_paths=conditioning_media_paths, + conditioning_strengths=conditioning_strengths, + conditioning_start_frames=conditioning_start_frames, + height=config.height, + width=config.width, + padding=padding, + ) + if conditioning_media_paths + else None + ) + s0 = time.perf_counter() images = pipeline( height=height_padded, @@ -106,6 +215,7 @@ def run(config): output_type="pt", config=config, enhance_prompt=enhance_prompt, + conditioning_items=conditioning_items, seed=config.seed, ) max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.") diff --git a/src/maxdiffusion/pipelines/ltx_video/crf_compressor.py b/src/maxdiffusion/pipelines/ltx_video/crf_compressor.py new file mode 100644 index 00000000..50cc2fef --- /dev/null +++ b/src/maxdiffusion/pipelines/ltx_video/crf_compressor.py @@ -0,0 +1,57 @@ +# Copyright 2025 Lightricks Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This implementation is based on the Torch version available at: +# https://github.com/Lightricks/LTX-Video/tree/main +import av +import torch +import io +import numpy as np + + +def _encode_single_frame(output_file, image_array: np.ndarray, crf): + container = av.open(output_file, "w", format="mp4") + try: + stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}) + stream.height = image_array.shape[0] + stream.width = image_array.shape[1] + av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p") + container.mux(stream.encode(av_frame)) + container.mux(stream.encode()) + finally: + container.close() + + +def _decode_single_frame(video_file): + container = av.open(video_file) + try: + stream = next(s for s in container.streams if s.type == "video") + frame = next(container.decode(stream)) + finally: + container.close() + return frame.to_ndarray(format="rgb24") + + +def compress(image: torch.Tensor, crf=29): + if crf == 0: + return image + + image_array = (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy() + with io.BytesIO() as output_file: + _encode_single_frame(output_file, image_array, crf) + video_bytes = output_file.getvalue() + with io.BytesIO(video_bytes) as video_file: + image_array = _decode_single_frame(video_file) + tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 + return tensor diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 1e0abe69..1b8f4deb 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -21,6 +21,7 @@ from transformers import (FlaxT5EncoderModel, AutoTokenizer) from torchax import interop 
from torchax import default_env +from dataclasses import dataclass import json import numpy as np import torch @@ -51,11 +52,12 @@ from maxdiffusion.models.ltx_video.transformers.symmetric_patchifier import SymmetricPatchifier from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from ...pyconfig import HyperParameters -from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState +from ...schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler, RectifiedFlowSchedulerState, FlaxRectifiedFlowSchedulerOutput from ...max_utils import (create_device_mesh, setup_initial_state, get_memory_allocations) from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel import functools import orbax.checkpoint as ocp +from jax.lax import dynamic_update_slice def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, encoder_attention_segment_ids): @@ -66,6 +68,26 @@ def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, encod max_logging.log(f"encoder_attention_segment_ids.shape: {encoder_attention_segment_ids.shape}") # (3, 256) int32 +@dataclass +class ConditioningItem: + """ + Defines a single frame-conditioning item - a single frame or a sequence of frames. + + Attributes: + media_item (torch.Tensor): shape=(b, 3, f, h, w). The media item to condition on. + media_frame_number (int): The start-frame number of the media item in the generated video. + conditioning_strength (float): The strength of the conditioning (1.0 = full conditioning). + media_x (Optional[int]): Optional left x coordinate of the media item in the generated frame. + media_y (Optional[int]): Optional top y coordinate of the media item in the generated frame. 
+ """ + + media_item: torch.Tensor + media_frame_number: int + conditioning_strength: float + media_x: Optional[int] = None + media_y: Optional[int] = None + + class LTXVideoPipeline: def __init__( @@ -276,33 +298,6 @@ def process(text: str): return [process(t) for t in text] - def denoising_step( - scheduler, - latents: Array, - noise_pred: Array, - current_timestep: Optional[Array], - conditioning_mask: Optional[Array], - t: float, - extra_step_kwargs: Dict, - t_eps: float = 1e-6, - stochastic_sampling: bool = False, - ) -> Array: - # Denoise the latents using the scheduler - denoised_latents = scheduler.step( - noise_pred, - t if current_timestep is None else current_timestep, - latents, - **extra_step_kwargs, - stochastic_sampling=stochastic_sampling, - ) - - if conditioning_mask is None: - return denoised_latents - - tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) - tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) - return jnp.where(tokens_to_denoise_mask, denoised_latents, latents) - def retrieve_timesteps( # currently doesn't support custom timesteps self, scheduler: FlaxRectifiedFlowMultistepScheduler, @@ -448,20 +443,141 @@ def prepare_latents( # currently no support for media item encoding, since enco return latents - def prepare_conditioning( # no support for conditioning items, conditioning mask, needs to convert to torch before patchifier + def prepare_conditioning( # needs to convert to torch before patchifier self, init_latents: jnp.ndarray, - ) -> Tuple[jnp.ndarray, jnp.ndarray, int]: - assert isinstance(self.vae, TorchaxCausalVideoAutoencoder) + conditioning_items: Optional[List[ConditioningItem]] = None, + height: int = None, + width: int = None, + num_frames: int = None, + vae_per_channel_normalize: bool = False, + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]: + if conditioning_items: + init_conditioning_mask = jnp.zeros(init_latents[:, 0, :, :, :].shape) # in jax + + # Process each conditioning item + for conditioning_item in conditioning_items: + # resize conditioning_item + media_items = conditioning_item.media_item + n_frames = media_items.shape[2] + if media_items.shape[-2:] != (height, width): + media_items = rearrange(media_items, "b c n h w -> (b n) c h w") + media_items = F.interpolate( + media_items, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + media_items = rearrange(media_items, "(b n) c h w -> b c n h w", n=n_frames) + conditioning_item.media_item = media_items + media_item = conditioning_item.media_item + media_frame_number = conditioning_item.media_frame_number + strength = conditioning_item.conditioning_strength + assert media_item.ndim == 5 # (b, c, f, h, w) + b, c, n_frames, h, w = media_item.shape + assert ( + height == h and width == w + ) or media_frame_number == 0, ( + f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0" + ) + assert n_frames % 8 == 1 + assert media_frame_number >= 0 and media_frame_number + n_frames <= num_frames + + # Encode the provided conditioning media item + media_item_latents = self.vae.encode( + jax.device_put(jnp.array(np.array(media_item)), jax.devices("tpu")[0]).astype(jnp.bfloat16), + vae_per_channel_normalize=vae_per_channel_normalize, + ) + + # # Handle the different conditioning cases + if media_frame_number == 0: + # Get the target spatial position of the latent conditioning item + media_item_latents, l_x, l_y = self.get_latent_spatial_position( + media_item_latents, + 
conditioning_item, + height, + width, + strip_latent_border=True, + ) + b, c_l, f_l, h_l, w_l = media_item_latents.shape + latent_slice = init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] + lerp_result = (1.0 - strength) * latent_slice + strength * media_item_latents + updated_latents = dynamic_update_slice(init_latents, lerp_result, (0, 0, 0, l_y, l_x)) + updated_conditioning_mask = dynamic_update_slice( + init_conditioning_mask, jnp.full((1, f_l, h_l, w_l), strength), (0, 0, l_y, l_x) + ) + init_latents = updated_latents + init_conditioning_mask = updated_conditioning_mask + init_latents = torch.from_numpy(np.array(init_latents)) init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents) init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=True) + if not conditioning_items: + return ( + jnp.array(init_latents.to(torch.float32).detach().numpy()), + jnp.array(init_pixel_coords.to(torch.float32).detach().numpy()), + None, + 0, + ) + init_conditioning_mask, _ = self.patchifier.patchify( + latents=torch.from_numpy(np.array(init_conditioning_mask)).unsqueeze(1) + ) + init_conditioning_mask = init_conditioning_mask.squeeze(-1) return ( jnp.array(init_latents.to(torch.float32).detach().numpy()), jnp.array(init_pixel_coords.to(torch.float32).detach().numpy()), + jnp.array(init_conditioning_mask.to(torch.float32).detach().numpy()), 0, ) + def get_latent_spatial_position( + self, + latents: jnp.ndarray, + conditioning_item: ConditioningItem, + height: int, + width: int, + strip_latent_border: bool, + ) -> Tuple[jnp.ndarray, int, int]: + """ + Get the spatial position of the conditioning item in the latent space. + If requested, strip the conditioning latent borders that do not align + with target borders. + """ + scale = self.vae_scale_factor + h, w = conditioning_item.media_item.shape[-2:] + + # Assertions are for verification and should be handled outside a JIT-compiled + # function or with jax.debug.check. + # The checks here are for shape and alignment validation. 
+ assert h <= height and w <= width + assert h % scale == 0 and w % scale == 0 + + # Compute the start and end spatial positions of the media item + x_start, y_start = conditioning_item.media_x, conditioning_item.media_y + x_start = (width - w) // 2 if x_start is None else x_start + y_start = (height - h) // 2 if y_start is None else y_start + x_end, y_end = x_start + w, y_start + h + + # JAX-friendly way to handle conditional stripping + if strip_latent_border: + # Determine slice indices based on the position + # JAX's slicing can handle negative and dynamic indices + x_slice_start = 1 if x_start > 0 else 0 + x_slice_end = -1 if x_end < width else None + + y_slice_start = 1 if y_start > 0 else 0 + y_slice_end = -1 if y_end < height else None + + # Update latents with a single, combined slice + latents = latents[..., y_slice_start:y_slice_end, x_slice_start:x_slice_end] + + # Adjust start positions + x_start = x_start + jnp.where(x_start > 0, scale, 0) + y_start = y_start + jnp.where(y_start > 0, scale, 0) + + # Return the modified latents and the scaled start positions + return latents, x_start // scale, y_start // scale + def denormalize(self, images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: r""" Borrowed from diffusers.image_processor @@ -536,6 +652,7 @@ def __call__( skip_final_inference_steps: int = 0, cfg_star_rescale: bool = False, seed: int = 0, + conditioning_items: Optional[List[ConditioningItem]] = None, skip_layer_strategy: Optional[SkipLayerStrategy] = None, skip_block_list: Optional[Union[List[List[int]], List[int]]] = None, **kwargs, @@ -632,7 +749,7 @@ def __call__( self.prompt_enhancer_llm_model, self.prompt_enhancer_llm_tokenizer, prompt, - None, # conditioning items set to None + None, # conditioning items set to None, not tested! max_new_tokens=text_encoder_max_tokens, ) @@ -673,11 +790,19 @@ def __call__( key=key, ) - latents, pixel_coords, num_cond_latents = self.prepare_conditioning( + latents, pixel_coords, conditioning_mask, num_cond_latents = self.prepare_conditioning( init_latents=latents, + conditioning_items=conditioning_items, + num_frames=num_frames, + height=height, + width=width, + vae_per_channel_normalize=kwargs.get("vae_per_channel_normalize", True), ) pixel_coords = jnp.concatenate([pixel_coords] * num_conds, axis=0) + orig_conditioning_mask = conditioning_mask + if conditioning_mask is not None and is_video: + conditioning_mask = jnp.concatenate([conditioning_mask] * num_conds, axis=0) fractional_coords = pixel_coords.astype(jnp.float32) fractional_coords = fractional_coords.at[:, 0].set(fractional_coords[:, 0] * (1.0 / frame_rate)) validate_transformer_inputs(prompt_embeds_batch, fractional_coords, latents, prompt_attention_mask_batch) @@ -703,6 +828,8 @@ def __call__( skip_layer_masks=skip_layer_masks, skip_layer_strategy=skip_layer_strategy, cfg_star_rescale=cfg_star_rescale, + conditioning_mask=conditioning_mask, + original_conditioning_mask=orig_conditioning_mask, ) with self.mesh: @@ -789,6 +916,47 @@ def transformer_forward_pass( return noise_pred, state +def add_noise_to_image_conditioning_latents_jax( + key: Array, + t: Union[float, Array], + init_latents: Array, + latents: Array, + noise_scale: float, + conditioning_mask: Array, + eps: float = 1e-6, +) -> Array: + """ + Add timestep-dependent noise to the hard-conditioning latents in a JAX-compatible way. 
+ """ + noise = jax.random.normal(key, latents.shape, dtype=latents.dtype) + need_to_noise = (conditioning_mask > 1.0 - eps)[jnp.newaxis, :, jnp.newaxis, :, :] + noised_latents = init_latents + noise_scale * noise * (t**2) + latents = jnp.where(need_to_noise, noised_latents, latents) + + return latents + + +def denoising_step( + scheduler, + scheduler_state, + noise_pred: Array, + current_timestep: Optional[Array], + conditioning_mask: Optional[Array], + t: float, + latents: Array, + t_eps: float = 1e-6, +) -> Union[FlaxRectifiedFlowSchedulerOutput, Tuple[jnp.ndarray, RectifiedFlowSchedulerState]]: + # Denoise the latents using the scheduler + denoised_latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + + if conditioning_mask is None: + return denoised_latents, scheduler_state + + tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_) + tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1) + return jnp.where(tokens_to_denoise_mask, denoised_latents, latents), scheduler_state + + def run_inference( transformer_state, transformer, @@ -812,10 +980,14 @@ def run_inference( skip_layer_masks, skip_layer_strategy, cfg_star_rescale, + conditioning_mask, + original_conditioning_mask, ): for i, t in enumerate(scheduler_state.timesteps): current_timestep = t + latent_model_input = jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents + if not isinstance(current_timestep, (jnp.ndarray, jax.Array)): if isinstance(current_timestep, float): dtype = jnp.float32 @@ -831,6 +1003,11 @@ def run_inference( # Broadcast to batch dimension current_timestep = jnp.broadcast_to(current_timestep, (latent_model_input.shape[0], 1)) + # if conditioning_mask is not None: + # conditioning_timestep = 1.0 - conditioning_mask + # current_timestep =jnp.where( + # conditioning_timestep < current_timestep, conditioning_timestep, current_timestep + # ) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): noise_pred, transformer_state = transformer_forward_pass( @@ -879,7 +1056,9 @@ def run_inference( noise_pred = noise_pred * factor.reshape(batch_size, 1, 1) current_timestep = current_timestep[:1] - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, current_timestep[0][0], latents).to_tuple() + latents, scheduler_state = denoising_step( + scheduler, scheduler_state, noise_pred, current_timestep, original_conditioning_mask, t, latents + ) return latents, scheduler_state @@ -966,7 +1145,16 @@ def __init__(self, video_pipeline: LTXVideoPipeline): self.vae = video_pipeline.vae def __call__( - self, height, width, num_frames, is_video, output_type, config, seed: int = 0, enhance_prompt: bool = False + self, + height, + width, + num_frames, + is_video, + output_type, + config, + seed: int = 0, + enhance_prompt: bool = False, + conditioning_items: Optional[List[ConditioningItem]] = None, ) -> Any: # first pass original_output_type = output_type @@ -988,6 +1176,7 @@ def __call__( num_inference_steps=config.first_pass["num_inference_steps"], guidance_timesteps=config.first_pass["guidance_timesteps"], cfg_star_rescale=config.first_pass["cfg_star_rescale"], + conditioning_items=conditioning_items, skip_layer_strategy=None, skip_block_list=config.first_pass["skip_block_list"], ) @@ -1017,6 +1206,7 @@ def __call__( num_inference_steps=config.second_pass["num_inference_steps"], guidance_timesteps=config.second_pass["guidance_timesteps"], 
cfg_star_rescale=config.second_pass["cfg_star_rescale"], + conditioning_items=conditioning_items, skip_layer_strategy=None, skip_block_list=config.second_pass["skip_block_list"], )
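
To make the new img2vid path easier to try out, here is a minimal sketch of the `ltx_video.yml` settings it relies on, using only the keys added in this diff; `IMAGE_PATH` is a placeholder for a local conditioning image, and every other generation parameter (prompt, height, width, frame count, ...) keeps its existing meaning. The generation command itself is the same one the README already shows for text2vid.

```yaml
# Img2vid sketch: condition generation on one image injected at frame 0.
# "IMAGE_PATH" is a placeholder; keep conditioning_media_paths: None for plain text2vid.
conditioning_media_paths: ["IMAGE_PATH"]
conditioning_start_frames: [0]   # one start frame per entry in conditioning_media_paths
```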
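
In `prepare_conditioning`, the strength-lerped conditioning latents are written into `init_latents` with `jax.lax.dynamic_update_slice` instead of in-place indexing. A toy, self-contained illustration of that pattern (the shapes and offsets below are made up, not taken from the pipeline):

```python
import jax.numpy as jnp
from jax.lax import dynamic_update_slice

# Toy latent grid (b, c, f, h, w) and a smaller conditioning patch; sizes are illustrative only.
init_latents = jnp.zeros((1, 4, 2, 8, 8))
lerp_result = jnp.ones((1, 4, 2, 4, 4))   # stands in for (1 - strength) * latent_slice + strength * cond_latents
l_y, l_x = 2, 3                           # latent-space offsets, as returned by get_latent_spatial_position

# dynamic_update_slice takes one start index per dimension and returns a new array.
updated = dynamic_update_slice(init_latents, lerp_result, (0, 0, 0, l_y, l_x))
print(updated[0, 0, 0, 2:6, 3:7])         # the 4x4 block of ones written at (l_y, l_x)
```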
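
The new `denoising_step` keeps conditioned tokens from being overwritten by masking the scheduler update. The snippet below re-creates just that gating with toy shapes and hand-picked values (no real scheduler involved), to show why a strength-1.0 token is never touched while weaker ones only start updating once the timestep falls below `1 - strength`:

```python
import jax.numpy as jnp

t = 0.7                 # current timestep; the schedule sweeps t from ~1 down to ~0
t_eps = 1e-6
conditioning_mask = jnp.array([[1.0, 0.5, 0.0]])   # per-token conditioning strength (batch, tokens)
latents = jnp.zeros((1, 3, 4))                     # latents entering this step (batch, tokens, channels)
denoised_latents = jnp.ones((1, 3, 4))             # stand-in for the scheduler.step(...) output

# Same gating as denoising_step: update a token only while t - eps < 1 - strength.
tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).astype(jnp.bool_)
tokens_to_denoise_mask = jnp.expand_dims(tokens_to_denoise_mask, axis=-1)
latents = jnp.where(tokens_to_denoise_mask, denoised_latents, latents)

print(latents[0, :, 0])   # [0. 0. 1.] -> only the strength-0.0 token took the update at t = 0.7
```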