Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,15 @@ ici_tensor_parallelism: 1
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
dataset_type: 'tfrecord'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
# only apply to small dataset that fits in memory
# prepare image latents and text encoder outputs
# Reduce memory consumption and reduce step time during training
# transformed dataset is saved at dataset_save_location
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
dataset_save_location: ''
load_tfrecord_cached: True
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
Expand Down Expand Up @@ -185,6 +186,10 @@ per_device_batch_size: 1
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
global_batch_size: 0

# For creating tfrecords from dataset
tfrecords_dir: ''
no_records_per_shard: 0

warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

Expand Down
15 changes: 15 additions & 0 deletions src/maxdiffusion/data_preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Prepare tfrecords with latents and text embeddings preprocessed.
1. Download the dataset
"""

import os
import functools
from absl import app
from typing import Sequence, Union, List
from datasets import load_dataset
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxdiffusion import pyconfig, max_utils
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from maxdiffusion.video_processor import VideoProcessor

import tensorflow as tf


def image_feature(value):
  """Return a bytes_list Feature holding *value* encoded as a JPEG."""
  encoded = tf.io.encode_jpeg(value).numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded]))


def bytes_feature(value):
  """Return a bytes_list Feature wrapping an eager string/bytes tensor."""
  raw = value.numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))


def float_feature(value):
  """Return a float_list Feature from a single float / double."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)


def int64_feature(value):
  """Return an int64_list Feature from a bool / enum / int / uint."""
  int_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int_list)


def float_feature_list(value):
  """Return a float_list Feature from a sequence of floats / doubles."""
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)


def create_example(latent, hidden_states):
  """Serialize one (latent, text-embedding) pair into a tf.train.Example byte string."""
  serialized_latent = tf.io.serialize_tensor(latent)
  serialized_hidden = tf.io.serialize_tensor(hidden_states)
  features = tf.train.Features(
      feature={
          "latents": bytes_feature(serialized_latent),
          "encoder_hidden_states": bytes_feature(serialized_hidden),
      }
  )
  return tf.train.Example(features=features).SerializeToString()


def text_encode(pipeline, prompt: Union[str, List[str]]):
  """Embed *prompt* with the pipeline's T5 text encoder and return it as a numpy array."""
  embeds = pipeline._get_t5_prompt_embeds(prompt)
  return embeds.detach().numpy()


def vae_encode(video, rng, vae, vae_cache):
  """Encode *video* with the VAE and sample a latent from the posterior distribution."""
  posterior = vae.encode(video, feat_cache=vae_cache)
  return posterior.latent_dist.sample(rng)


def generate_dataset(config, pipeline):
  """Encode the dataset's videos and captions and write them as sharded TFRecords.

  For each batch, videos are VAE-encoded into latents and captions are embedded
  with the T5 text encoder; each (latent, embedding) pair is serialized as one
  tf.train.Example via create_example. Shards are rotated every
  ``config.no_records_per_shard`` records and written under ``config.tfrecords_dir``.

  Args:
    config: HyperParameters with tfrecords_dir, no_records_per_shard, seed,
      dataset_name, caption_column, image_column, height, width, weights_dtype,
      and mesh settings.
    pipeline: WanPipeline providing the VAE, its feature cache, and the text encoder.
  """
  tfrecords_dir = config.tfrecords_dir
  os.makedirs(tfrecords_dir, exist_ok=True)

  tf_rec_num = 0
  no_records_per_shard = config.no_records_per_shard
  global_record_count = 0
  writer = tf.io.TFRecordWriter(
      tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
  )
  shard_record_count = 0

  # create mesh
  devices_array = max_utils.create_device_mesh(config)
  mesh = Mesh(devices_array, config.mesh_axes)
  rng = jax.random.key(config.seed)

  # Spatial downscale factor of the VAE ("temperal_downsample" is the upstream attribute name).
  vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
  video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)

  # jit the vae encode fn once; vae and its cache are constant across batches.
  p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))

  # Load dataset
  ds = load_dataset(config.dataset_name, split="train")
  ds = ds.shuffle(seed=config.seed)
  ds = ds.select_columns([config.caption_column, config.image_column])
  batch_size = 10
  for i in range(0, len(ds), batch_size):
    rng, new_rng = jax.random.split(rng)
    # Fix: index by the configured column names instead of hard-coded
    # "text"/"image", which broke any dataset with different column names.
    text = ds[i : i + batch_size][config.caption_column]
    videos = ds[i : i + batch_size][config.image_column]

    videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
    video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
    with mesh:
      latents = p_vae_encode(video=video, rng=new_rng)
    # Move channels to axis 1 (channels-first layout) before serialization.
    latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
    encoder_hidden_states = text_encode(pipeline, text)
    for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
      writer.write(create_example(latent, encoder_hidden_state))
      shard_record_count += 1
      global_record_count += 1

      if shard_record_count >= no_records_per_shard:
        writer.close()
        tf_rec_num += 1
        writer = tf.io.TFRecordWriter(
            tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
        )
        shard_record_count = 0

  # Fix: close the final (possibly partial) shard so buffered records are flushed.
  writer.close()


def run(config):
  """Build the preprocessed tfrecord dataset from a WAN pipeline."""
  # The transformer is not needed for preprocessing, so skip loading it.
  pipeline = WanPipeline.from_pretrained(config, load_transformer=False)
  generate_dataset(config, pipeline)


def main(argv: Sequence[str]) -> None:
  """absl entry point: initialize the config from flags, then run preprocessing."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)


if __name__ == "__main__":
app.run(main)
7 changes: 4 additions & 3 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from maxdiffusion.utils import export_to_video


def run(config):
def run(config, pipeline=None, filename_prefix=""):
print("seed: ", config.seed)
pipeline = WanPipeline.from_pretrained(config)
if pipeline is None:
pipeline = WanPipeline.from_pretrained(config)
s0 = time.perf_counter()

# Skip layer guidance
Expand Down Expand Up @@ -59,7 +60,7 @@ def run(config):

print("compile time: ", (time.perf_counter() - s0))
for i in range(len(videos)):
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
s0 = time.perf_counter()
videos = pipeline(
prompt=prompt,
Expand Down
49 changes: 19 additions & 30 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,43 +73,26 @@ def make_tf_iterator(
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
return train_iter


def make_cached_tfrecord_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
):
"""
New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
latents, input_ids, prompt_embeds, and text_embeds.
"""
feature_description = {
"pixel_values": tf.io.FixedLenFeature([], tf.string),
"input_ids": tf.io.FixedLenFeature([], tf.string),
"prompt_embeds": tf.io.FixedLenFeature([], tf.string),
"text_embeds": tf.io.FixedLenFeature([], tf.string),
}

def _parse_tfrecord_fn(example):
return tf.io.parse_single_example(example, feature_description)

def prepare_sample(features):
pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)

return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}

# This pipeline reads the sharded files and applies the parsing and preparation.
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))

train_ds = (
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
.shuffle(global_batch_size * 10)
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
.repeat(-1)
Expand All @@ -123,11 +106,7 @@ def prepare_sample(features):

# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def make_tfrecord_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
):
"""Iterator for TFRecord format. For Laion dataset,
check out preparation script
Expand All @@ -136,12 +115,22 @@ def make_tfrecord_iterator(

# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
# Datset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
if (config.cache_latents_text_encoder_outputs
# Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
if (
config.cache_latents_text_encoder_outputs
and os.path.isdir(config.dataset_save_location)
and 'load_tfrecord_cached'in config.get_keys()
and config.load_tfrecord_cached):
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
and "load_tfrecord_cached" in config.get_keys()
and config.load_tfrecord_cached
):
return make_cached_tfrecord_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
feature_description,
prepare_sample_fn,
)

feature_description = {
"moments": tf.io.FixedLenFeature([], tf.string),
Expand Down
19 changes: 19 additions & 0 deletions src/maxdiffusion/input_pipeline/input_pipeline_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,25 @@ def make_data_iterator(
global_batch_size,
tokenize_fn=None,
image_transforms_fn=None,
feature_description=None,
prepare_sample_fn=None,
):
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""

if config.dataset_type == "hf" or config.dataset_type == "tf":
if tokenize_fn is None or image_transforms_fn is None:
raise ValueError(f"dataset type {config.dataset_type} needs to pass a tokenize_fn and image_transforms_fn")

if (
config.dataset_type == "tfrecord"
and config.cache_latents_text_encoder_outputs
and feature_description is None
and prepare_sample_fn is None
):
raise ValueError(
f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True."
)

if config.dataset_type == "hf":
return _hf_data_processing.make_hf_streaming_iterator(
config,
Expand Down Expand Up @@ -87,6 +104,8 @@ def make_data_iterator(
dataloading_host_count,
mesh,
global_batch_size,
feature_description,
prepare_sample_fn,
)
else:
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
Expand Down
10 changes: 6 additions & 4 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
def basic_clean(text):
if is_ftfy_available():
import ftfy

text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
Expand Down Expand Up @@ -221,7 +222,7 @@ def load_scheduler(cls, config):
return scheduler, scheduler_state

@classmethod
def from_pretrained(cls, config: HyperParameters, vae_only=False):
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
rng = jax.random.key(config.seed)
Expand All @@ -232,8 +233,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False):
scheduler_state = None
text_encoder = None
if not vae_only:
with mesh:
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
if load_transformer:
with mesh:
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)

text_encoder = cls.load_text_encoder(config=config)
tokenizer = cls.load_tokenizer(config=config)
Expand Down Expand Up @@ -397,7 +399,7 @@ def __call__(
num_channels_latents=num_channel_latents,
)

data_sharding = NamedSharding(self.devices_array, P())
data_sharding = NamedSharding(self.mesh, P())
if len(prompt) % jax.device_count() == 0:
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))

Expand Down
3 changes: 2 additions & 1 deletion src/maxdiffusion/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
_import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
_import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_flax"] = ["FlaxFlowMatchScheduler"]
_import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
_import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
_import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
Expand All @@ -70,6 +70,7 @@
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler
from .scheduling_flow_match_flax import FlowMatchScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
Expand Down
Loading
Loading