6 changes: 4 additions & 2 deletions README.md
@@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/08/14`**: LTX-Video img2vid generation is now supported.
- **`2025/07/29`**: LTX-Video text2vid generation is now supported.
- **`2025/04/17`**: Flux Finetuning.
- **`2025/02/12`**: Flux LoRA for inference.
@@ -42,7 +43,7 @@ MaxDiffusion supports
* Load Multiple LoRA (SDXL inference).
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
* Dreambooth training support for Stable Diffusion 1.x, 2.x.
* LTX-Video text2vid (inference).
* LTX-Video text2vid, img2vid (inference).


# Table of Contents
@@ -183,7 +184,8 @@ To generate images, run the following command:
```bash
python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"
```
- Other generation parameters can be set in the ltx_video.yml file.
- Img2vid generation:
  Add the path to a conditioning image as `conditioning_media_paths` (a list of the form `["IMAGE_PATH"]`), along with the other generation parameters, in the ltx_video.yml file, then run the same command as above; see the sketch below.
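
  A minimal sketch of the img2vid-specific settings (the image path is illustrative):

  ```yml
  # img2vid conditioning (illustrative values)
  conditioning_media_paths: ["/path/to/condition.png"]  # one image per conditioning item
  conditioning_start_frames: [0]  # frame index at which each image conditions the video
  ```
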
## Flux

First make sure you have permissions to access the Flux repos in Huggingface.
4 changes: 3 additions & 1 deletion src/maxdiffusion/configs/ltx_video.yml
@@ -22,7 +22,7 @@ sampler: "from_checkpoint"

# Generation parameters
pipeline_type: multi-scale
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. "
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 512
@@ -35,6 +35,8 @@ stg_mode: "attention_values"
decode_timestep: 0.05
decode_noise_scale: 0.025
seed: 10
conditioning_media_paths: None # img2vid: set to a list of image paths, e.g. ["IMAGE_PATH"]
conditioning_start_frames: [0]


first_pass:
114 changes: 112 additions & 2 deletions src/maxdiffusion/generate_ltx_video.py
@@ -16,15 +16,19 @@

import numpy as np
from absl import app
from typing import Sequence
from typing import Sequence, List, Optional, Union
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
from maxdiffusion import pyconfig, max_logging
import torchvision.transforms.functional as TVF
import imageio
from datetime import datetime
import os
import time
from pathlib import Path
from PIL import Image
import torch


def calculate_padding(
@@ -44,6 +48,79 @@ def calculate_padding(
return padding


def load_image_to_tensor_with_resize_and_crop(
image_input: Union[str, Image.Image],
target_height: int = 512,
target_width: int = 768,
just_crop: bool = False,
) -> torch.Tensor:
"""Load and process an image into a tensor.

Args:
image_input: Either a file path (str) or a PIL Image object
target_height: Desired height of output tensor
target_width: Desired width of output tensor
just_crop: If True, only crop the image to the target size without resizing
"""
if isinstance(image_input, str):
image = Image.open(image_input).convert("RGB")
elif isinstance(image_input, Image.Image):
image = image_input
else:
raise ValueError("image_input must be either a file path or a PIL Image object")

input_width, input_height = image.size
aspect_ratio_target = target_width / target_height
aspect_ratio_frame = input_width / input_height
if aspect_ratio_frame > aspect_ratio_target:
new_width = int(input_height * aspect_ratio_target)
new_height = input_height
x_start = (input_width - new_width) // 2
y_start = 0
else:
new_width = input_width
new_height = int(input_width / aspect_ratio_target)
x_start = 0
y_start = (input_height - new_height) // 2

image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
if not just_crop:
image = image.resize((target_width, target_height))

frame_tensor = TVF.to_tensor(image) # PIL -> tensor (C, H, W), [0,1]
frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
frame_tensor_hwc = frame_tensor.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
  frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W), rescaled to [0, 255]
  frame_tensor = (frame_tensor / 127.5) - 1.0  # normalize to [-1, 1]
# Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
return frame_tensor.unsqueeze(0).unsqueeze(2)
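
# Usage sketch (illustrative path): a conditioning image loaded at the default
# 512x768 size becomes a (1, 3, 1, 512, 768) tensor, ready for a ConditioningItem:
#   cond = load_image_to_tensor_with_resize_and_crop("frame.png", 512, 768)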


def prepare_conditioning(
conditioning_media_paths: List[str],
conditioning_strengths: List[float],
conditioning_start_frames: List[int],
height: int,
width: int,
padding: tuple[int, int, int, int],
) -> Optional[List[ConditioningItem]]:
"""Prepare conditioning items based on input media paths and their parameters."""
conditioning_items = []
for path, strength, start_frame in zip(conditioning_media_paths, conditioning_strengths, conditioning_start_frames):
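    # Each conditioning image contributes a single frame placed at start_frame.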
num_input_frames = 1
media_tensor = load_media_file(
media_path=path,
height=height,
width=width,
max_frames=num_input_frames,
padding=padding,
just_crop=True,
)
conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
return conditioning_items


def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
# Remove non-letters and convert to lowercase
clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
@@ -68,6 +145,19 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
return "-".join(result)


def load_media_file(
media_path: str,
height: int,
width: int,
max_frames: int,
padding: tuple[int, int, int, int],
just_crop: bool = False,
) -> torch.Tensor:
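  """Load a single image, crop (and optionally resize) it to (height, width), and pad it.

  Note: max_frames is currently unused; only single-image conditioning media are supported.
  """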
media_tensor = load_image_to_tensor_with_resize_and_crop(media_path, height, width, just_crop=just_crop)
media_tensor = torch.nn.functional.pad(media_tensor, padding)
return media_tensor


def get_unique_filename(
base: str,
ext: str,
@@ -97,6 +187,25 @@ def run(config):
pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt)
if config.pipeline_type == "multi-scale":
pipeline = LTXMultiScalePipeline(pipeline)
  conditioning_media_paths = config.conditioning_media_paths if isinstance(config.conditioning_media_paths, list) else None
  conditioning_start_frames = config.conditioning_start_frames
  conditioning_items = None
  if conditioning_media_paths:
    # Strengths are not configurable yet; default each conditioning item to full strength.
    conditioning_strengths = [1.0] * len(conditioning_media_paths)
    conditioning_items = prepare_conditioning(
        conditioning_media_paths=conditioning_media_paths,
        conditioning_strengths=conditioning_strengths,
        conditioning_start_frames=conditioning_start_frames,
        height=config.height,
        width=config.width,
        padding=padding,
    )

s0 = time.perf_counter()
images = pipeline(
height=height_padded,
@@ -106,6 +215,7 @@
output_type="pt",
config=config,
enhance_prompt=enhance_prompt,
conditioning_items=conditioning_items,
seed=config.seed,
)
max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.")
57 changes: 57 additions & 0 deletions src/maxdiffusion/pipelines/ltx_video/crf_compressor.py
@@ -0,0 +1,57 @@
# Copyright 2025 Lightricks Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This implementation is based on the Torch version available at:
# https://github.com/Lightricks/LTX-Video/tree/main
import av
import torch
import io
import numpy as np


def _encode_single_frame(output_file, image_array: np.ndarray, crf):
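  """Encode one RGB frame into output_file as a single-frame H.264 (mp4) stream at the given CRF."""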
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
stream.height = image_array.shape[0]
stream.width = image_array.shape[1]
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
container.mux(stream.encode(av_frame))
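    # Flush any buffered packets from the encoder.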
container.mux(stream.encode())
finally:
container.close()


def _decode_single_frame(video_file):
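  """Decode the first video frame from video_file and return it as an RGB (H, W, 3) ndarray."""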
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")


def compress(image: torch.Tensor, crf=29):
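  """Round-trip an (H, W, C) image tensor in [0, 1] through single-frame H.264 at the given CRF.

  Emulates compression artifacts on the frame; crf=0 returns the input unchanged.
  """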
if crf == 0:
return image

  # x264 with yuv420p requires even dimensions; crop to even H/W and convert to uint8 RGB.
  image_array = (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file:
_encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file:
image_array = _decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor
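

# Minimal sanity sketch (illustrative values): round-trip a random frame and check
# that shape and dtype are preserved. Even H/W avoids the even-dimension crop above.
if __name__ == "__main__":
  sample = torch.rand(512, 768, 3)  # (H, W, C) float image in [0, 1]
  degraded = compress(sample, crf=29)
  assert degraded.shape == sample.shape and degraded.dtype == sample.dtype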