Merged
Changes from all commits (77 commits)
3776190  set up files for ltxvid (Serenagu525, Jun 26, 2025)
13656fb  ltx-video-transformer-setup (Serenagu525, Jun 26, 2025)
7bed4f9  formatting (Serenagu525, Jun 26, 2025)
b31a97b  conversion script added (Serenagu525, Jun 26, 2025)
7e098c5  format fixed (Serenagu525, Jun 26, 2025)
9a9f5db  conversion script checked (Serenagu525, Jun 27, 2025)
d1c304d  comments removed (Serenagu525, Jun 27, 2025)
f93c3bd  Added running instructions (Serenagu525, Jun 27, 2025)
e0327e5  edited instruction (Serenagu525, Jun 27, 2025)
c369302  ruff check error fixed (Serenagu525, Jun 27, 2025)
991a44e  mesh edit (Serenagu525, Jun 27, 2025)
b0e9bab  key error fix (Serenagu525, Jun 27, 2025)
e18128c  transformer step and test (Serenagu525, Jun 30, 2025)
1c55452  removed diffusers import (Serenagu525, Jun 30, 2025)
fd4af91  fixed mesh (Serenagu525, Jun 30, 2025)
5e17a62  changed path (Serenagu525, Jul 1, 2025)
fc60b27  changed path (Serenagu525, Jul 1, 2025)
3243535  changed config path (Serenagu525, Jul 1, 2025)
e873a17  ruff check (Serenagu525, Jul 1, 2025)
d06dee3  changed back pyconfig (Serenagu525, Jul 2, 2025)
1ea6590  ruff check (Serenagu525, Jul 2, 2025)
aa7befd  changed sharding back (Serenagu525, Jul 2, 2025)
d9a3502  removed testing for now (Serenagu525, Jul 5, 2025)
a1ad421  Update pyconfig.py (Serenagu525, Jul 5, 2025)
615174f  Update max_utils.py (Serenagu525, Jul 5, 2025)
7469c62  Update ltx_video.yml (Serenagu525, Jul 5, 2025)
6de4424  Delete src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred (Serenagu525, Jul 5, 2025)
18ec247  Delete src/maxdiffusion/tests/ltx_transformer_step_test.py (Serenagu525, Jul 5, 2025)
2737877  added header (Serenagu525, Jul 8, 2025)
8a043f6  sharding back (Serenagu525, Jul 9, 2025)
35a3337  added test (Serenagu525, Jul 9, 2025)
546ecab  ruff fixed (Serenagu525, Jul 9, 2025)
12a247f  added header (Serenagu525, Jul 9, 2025)
1062c72  license headers (Serenagu525, Jul 9, 2025)
535c75e  exclude test (Serenagu525, Jul 9, 2025)
f6115df  auto script (Serenagu525, Jul 10, 2025)
8bf24a3  headers (Serenagu525, Jul 10, 2025)
0f8483e  pulled (Serenagu525, Jul 10, 2025)
eaa7196  auto script for file downloading (Serenagu525, Jul 10, 2025)
7af151a  change base branch (Serenagu525, Jul 10, 2025)
634591b  save now (Serenagu525, Jul 10, 2025)
a272d08  load transformer error (Serenagu525, Jul 10, 2025)
4bcffd1  later (Serenagu525, Jul 11, 2025)
f5afa91  changed repeatable layer (Serenagu525, Jul 11, 2025)
e805034  Update max_utils.py (Serenagu525, Jul 11, 2025)
bb61ecb  functional (Serenagu525, Jul 11, 2025)
7d4b2a9  moved upsampler (Serenagu525, Jul 11, 2025)
972e316  initial cleaning (Serenagu525, Jul 11, 2025)
c375471  multiscale pipeline (Serenagu525, Jul 16, 2025)
f63a6fa  remove init (Serenagu525, Jul 16, 2025)
b3874f5  new empty folders (Serenagu525, Jul 16, 2025)
3e6499c  downloaded files (Serenagu525, Jul 20, 2025)
0b67a19  changed upsampler (Serenagu525, Jul 20, 2025)
443243d  kept latents as jnp (Serenagu525, Jul 20, 2025)
fefe18e  prepare latents (Serenagu525, Jul 20, 2025)
4bad196  save (Serenagu525, Jul 21, 2025)
b1e5b0c  fixed transformer init (Serenagu525, Jul 21, 2025)
fd9eb11  error attribute weight already exist (Serenagu525, Jul 21, 2025)
4c9be69  baseline pipeline cleaned (Serenagu525, Jul 21, 2025)
0577d3e  pipeline cleaned (Serenagu525, Jul 22, 2025)
072982c  added timing (Serenagu525, Jul 22, 2025)
8042df0  pipeline cleaned, licence added (Serenagu525, Jul 23, 2025)
0c48524  changed output to cmd line (Serenagu525, Jul 23, 2025)
d4c6738  added init file (Serenagu525, Jul 23, 2025)
36242d2  changed input format (Serenagu525, Jul 23, 2025)
8fc3626  Merge branch 'conversion-script' of https://github.com/AI-Hypercomput… (Serenagu525, Jul 23, 2025)
774e2c4  updated requirements (Serenagu525, Jul 24, 2025)
b4bd96e  merged conversion (Serenagu525, Jul 24, 2025)
f23eeef  merged in conversion script (Serenagu525, Jul 24, 2025)
c18c0c6  fixed importing error (Serenagu525, Jul 25, 2025)
cfe2c64  fixed importing issue (Serenagu525, Jul 25, 2025)
0229dd6  requirement change (Serenagu525, Jul 25, 2025)
740d403  merged from main (Serenagu525, Jul 25, 2025)
e34d47e  Delete myenv directory (Serenagu525, Jul 25, 2025)
460f7db  changed ckpt name (Serenagu525, Jul 25, 2025)
0d7f68f  Merge branch 'vae-pipeline-cleaned' of https://github.com/AI-Hypercom… (Serenagu525, Jul 25, 2025)
8df8dbd  style fix (Serenagu525, Jul 25, 2025)
Empty file added
Empty file.
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
@@ -50,7 +50,7 @@ jobs:
ruff check .
- name: PyTest
run: |
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
2 changes: 2 additions & 0 deletions requirements.txt
@@ -23,6 +23,8 @@ pytest==8.2.2
tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
git+https://github.com/Lightricks/LTX-Video
git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax
opencv-python-headless==4.10.0.84
orbax-checkpoint==0.10.3
tokenizers==0.21.0
2 changes: 1 addition & 1 deletion setup.sh
@@ -112,4 +112,4 @@ else
fi

# Install maxdiffusion
pip3 install -U . || echo "Failed to install maxdiffusion" >&2
pip3 install -U . || echo "Failed to install maxdiffusion" >&2
2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
@@ -374,6 +374,7 @@
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
[
@@ -453,6 +454,7 @@
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .models.ltx_video.transformers.transformer3d import Transformer3DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
7 changes: 5 additions & 2 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -213,8 +213,11 @@ def load_state_if_possible(
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
try:
if not enable_single_replica_ckpt_restoring:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
if checkpoint_item == "ltxvid_transformer":
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
else:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))

def map_to_pspec(data):
pspec = data.sharding.spec
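The hunk above routes the LTX-Video transformer through Orbax's flat `StandardRestore` path while other checkpoint items keep the `Composite`-of-`PyTreeRestore` path. A minimal sketch of that branch, with plain tuples standing in for the real Orbax arg objects (the stand-in return values are assumptions for illustration, not the Orbax API):

```python
def select_restore_args(checkpoint_item, abstract_state):
    """Mirror the restore-path branch: the LTX-Video transformer state is
    restored as a flat tree (StandardRestore), while every other item is
    nested under its name inside a Composite of PyTreeRestore args."""
    if checkpoint_item == "ltxvid_transformer":
        # StandardRestore takes the abstract (shape/dtype) tree directly.
        return ("StandardRestore", abstract_state)
    # Composite wraps one PyTreeRestore per named checkpoint item.
    return ("Composite", {checkpoint_item: abstract_state})
```

The practical difference is the shape of the restored object: the `Composite` path returns a mapping keyed by item name, which is why `max_utils.py` below has to unwrap `state[checkpoint_item]` for non-LTX items.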
99 changes: 99 additions & 0 deletions src/maxdiffusion/configs/ltx_video.yml
@@ -0,0 +1,99 @@
#hardware
hardware: 'tpu'
skip_jax_distributed_system: False

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: ''
config_path: ''
save_config_to_gcs: False

#Checkpoints
text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax"
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
frame_rate: 30
max_sequence_length: 512
sampler: "from_checkpoint"

# Generation parameters
pipeline_type: multi-scale
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. "
#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 512
num_frames: 88
flow_shift: 5.0
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
prompt_enhancement_words_threshold: 120
stg_mode: "attention_values"
decode_timestep: 0.05
decode_noise_scale: 0.025
seed: 10


first_pass:
guidance_scale: [1, 1, 6, 8, 6, 1, 1]
stg_scale: [0, 0, 4, 4, 4, 2, 1]
rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
num_inference_steps: 30
skip_final_inference_steps: 3
skip_initial_inference_steps: 0
cfg_star_rescale: True

second_pass:
guidance_scale: [1]
stg_scale: [1]
rescaling_scale: [1]
guidance_timesteps: [1.0]
skip_block_list: [27]
num_inference_steps: 30
skip_initial_inference_steps: 17
skip_final_inference_steps: 0
cfg_star_rescale: True

#parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_heads', 'fsdp'],
['activation_batch', 'data'],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['norm', 'fsdp'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_in', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
jit_initializers: True
enable_single_replica_ckpt_restoring: False
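The parallelism block above names a three-axis mesh (`data`, `fsdp`, `tensor`) and logical-axis rules that map model dimensions onto it (e.g. `embed` → `fsdp`, `mlp` → `tensor`). A minimal single-device sketch of how such a mesh and one of those rules resolve in JAX (mesh shape `(1, 1, 1)` is an assumption for a one-device host; on a TPU slice the ICI/DCN settings above determine the real shape):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a (data, fsdp, tensor) mesh; with a single device every axis has size 1.
devices = np.array(jax.devices()[:1]).reshape(1, 1, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Applying the rules 'embed' -> 'fsdp' and 'mlp' -> 'tensor' from the config,
# a weight with logical axes (embed, mlp) gets this mesh-axis sharding:
sharding = NamedSharding(mesh, P("fsdp", "tensor"))
w = jax.device_put(np.zeros((8, 8), np.float32), sharding)
```

With `ici_fsdp_parallelism: -1`, the `fsdp` mesh axis absorbs whatever device count remains after the other axes are fixed, so the same rules shard the embedding dimension across the slice without config changes.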
161 changes: 161 additions & 0 deletions src/maxdiffusion/generate_ltx_video.py
@@ -0,0 +1,161 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np
from absl import app
from typing import Sequence
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
from maxdiffusion import pyconfig, max_logging
import imageio
from datetime import datetime
import os
import time
from pathlib import Path


def calculate_padding(
source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:

# Calculate total padding needed
pad_height = target_height - source_height
pad_width = target_width - source_width

# Calculate padding for each side
pad_top = pad_height // 2
pad_bottom = pad_height - pad_top # Handles odd padding
pad_left = pad_width // 2
pad_right = pad_width - pad_left # Handles odd padding
padding = (pad_left, pad_right, pad_top, pad_bottom)
return padding


def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
# Remove non-letters and convert to lowercase
clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())

# Split into words
words = clean_text.split()

# Build result string keeping track of length
result = []
current_length = 0

for word in words:
# Add word length plus 1 for underscore (except for first word)
new_length = current_length + len(word)

if new_length <= max_len:
result.append(word)
current_length += len(word)
else:
break

return "-".join(result)


def get_unique_filename(
base: str,
ext: str,
prompt: str,
resolution: tuple[int, int, int],
dir: Path,
endswith=None,
index_range=1000,
) -> Path:
base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
for i in range(index_range):
filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
if not os.path.exists(filename):
return filename
raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")


def run(config):
height_padded = ((config.height - 1) // 32 + 1) * 32
width_padded = ((config.width - 1) // 32 + 1) * 32
num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
padding = calculate_padding(config.height, config.width, height_padded, width_padded)
prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold
prompt_word_count = len(config.prompt.split())
enhance_prompt = prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold

pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt)
if config.pipeline_type == "multi-scale":
pipeline = LTXMultiScalePipeline(pipeline)
s0 = time.perf_counter()
images = pipeline(
height=height_padded,
width=width_padded,
num_frames=num_frames_padded,
is_video=True,
output_type="pt",
config=config,
enhance_prompt=enhance_prompt,
seed=config.seed,
)
max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.")

(pad_left, pad_right, pad_top, pad_bottom) = padding
pad_bottom = -pad_bottom
pad_right = -pad_right
if pad_bottom == 0:
pad_bottom = images.shape[3]
if pad_right == 0:
pad_right = images.shape[4]
images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]
output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
output_dir.mkdir(parents=True, exist_ok=True)

for i in range(images.shape[0]):
# Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy()
# Unnormalizing images to [0, 255] range
video_np = (video_np * 255).astype(np.uint8)
fps = config.frame_rate
height, width = video_np.shape[1:3]
# In case a single image is generated
if video_np.shape[0] == 1:
output_filename = get_unique_filename(
f"image_output_{i}",
".png",
prompt=config.prompt,
resolution=(height, width, config.num_frames),
dir=output_dir,
)
imageio.imwrite(output_filename, video_np[0])
else:
output_filename = get_unique_filename(
f"video_output_{i}",
".mp4",
prompt=config.prompt,
resolution=(height, width, config.num_frames),
dir=output_dir,
)
# Write video
with imageio.get_writer(output_filename, fps=fps) as video:
for frame in video_np:
video.append_data(frame)


def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
run(pyconfig.config)


if __name__ == "__main__":
app.run(main)
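The generation script above rounds height and width up to multiples of 32, rounds `num_frames` to the nearest value of the form `8k + 1` at or above it, and then splits the excess symmetrically with `calculate_padding` (odd remainders go to the bottom/right). A quick standalone check of that arithmetic, with the helpers re-stated compactly:

```python
def round_to_32(x):
    # Same rounding as height_padded / width_padded in run(): ceil to a multiple of 32.
    return ((x - 1) // 32 + 1) * 32

def round_frames(n):
    # num_frames_padded: smallest value of the form 8*k + 1 at or above n (for n >= 1).
    return ((n - 2) // 8 + 1) * 8 + 1

def calculate_padding(source_height, source_width, target_height, target_width):
    # Mirrors calculate_padding() above: split total padding evenly,
    # giving the extra pixel of any odd difference to the bottom/right.
    pad_h = target_height - source_height
    pad_w = target_width - source_width
    pad_top, pad_left = pad_h // 2, pad_w // 2
    return (pad_left, pad_w - pad_left, pad_top, pad_h - pad_top)
```

This is why the config's `height: 512, width: 512, num_frames: 88` actually generates an 89-frame clip internally, then crops back to the requested frame count before writing the video.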
5 changes: 4 additions & 1 deletion src/maxdiffusion/max_utils.py
@@ -405,7 +405,10 @@ def setup_initial_state(
config.enable_single_replica_ckpt_restoring,
)
if state:
state = state[checkpoint_item]
if checkpoint_item == "ltxvid_transformer":
state = state
else:
state = state[checkpoint_item]
if not state:
max_logging.log(f"Could not find the item in orbax, creating state...")
init_train_state_partial = functools.partial(
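The branch above unwraps the restored state only for non-LTX items, matching the two restore paths added in `checkpointing_utils.py` (the `state = state` arm is a no-op kept for symmetry). The same logic as a single helper, sketched for illustration:

```python
def extract_item_state(state, checkpoint_item):
    """Stand-in for the branch above: an LTX-Video transformer checkpoint
    is restored flat, so it is returned as-is; every other item comes back
    nested under its name and must be unwrapped."""
    if checkpoint_item == "ltxvid_transformer":
        return state
    return state[checkpoint_item]
```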
5 changes: 2 additions & 3 deletions src/maxdiffusion/models/__init__.py
@@ -13,9 +13,7 @@
# limitations under the License.

from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available

from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available

_import_structure = {}

@@ -32,6 +30,7 @@
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .ltx_video.transformers.transformer3d import Transformer3DModel

else:
import sys
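The `_import_structure` dict that the new `Transformer3DModel` entry is added to drives diffusers-style lazy imports: names are resolved to their modules only on first attribute access. A minimal stand-in using `importlib` (the registry entries here are stdlib substitutes for the real maxdiffusion modules, used purely for illustration):

```python
import importlib

# Stand-in registry shaped like _import_structure above, with stdlib
# modules substituted for the real maxdiffusion submodules.
_import_structure = {"math": ["sqrt"], "json": ["dumps"]}

def lazy_getattr(name):
    # Resolve a public name to its defining module only when first
    # requested, as _LazyModule does for the package's __getattr__.
    for module_name, exported in _import_structure.items():
        if name in exported:
            return getattr(importlib.import_module(module_name), name)
    raise AttributeError(name)
```

Registering the class in both `_import_structure` and the eager `TYPE_CHECKING` import block, as the diff does, keeps runtime imports lazy while still giving type checkers the real symbol.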
15 changes: 15 additions & 0 deletions src/maxdiffusion/models/ltx_video/__init__.py
@@ -0,0 +1,15 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
16 changes: 16 additions & 0 deletions src/maxdiffusion/models/ltx_video/autoencoders/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2025 Lightricks Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This implementation is based on the Torch version available at:
# https://github.com/Lightricks/LTX-Video/tree/main