Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
ruff check .
- name: PyTest
run: |
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py --deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
Expand Down
11 changes: 10 additions & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@ attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ri
flash_min_seq_length: 4096
dropout: 0.1

flash_block_sizes: {}
flash_block_sizes: {
"block_q" : 1024,
"block_kv_compute" : 256,
"block_kv" : 1024,
"block_q_dkv" : 1024,
"block_kv_dkv" : 1024,
"block_kv_dkv_compute" : 256,
"block_q_dq" : 1024,
"block_kv_dq" : 1024
}
# Use on v6e
# flash_block_sizes: {
# "block_q" : 3024,
Expand Down
10 changes: 6 additions & 4 deletions src/maxdiffusion/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import Any, Dict, Tuple, Union
from . import max_logging
import numpy as np
from dataclasses import asdict, is_dataclass

from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
Expand Down Expand Up @@ -54,16 +55,17 @@ class CustomEncoder(json.JSONEncoder):
"""

def default(self, o):
# This will catch the `dtype[bfloat16]` object and convert it to the string "bfloat16"
if isinstance(o, type(jnp.dtype("bfloat16"))):
return str(o)
# Add fallbacks for other numpy types if needed
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
# Let the base class default method raise the TypeError for other types
return super().default(o)
if is_dataclass(o):
return asdict(o)
else:
max_logging.log(f"Warning: {o} of type {type(o)} is not JSON serializable")
return None


class FrozenDict(OrderedDict):
Expand Down
42 changes: 42 additions & 0 deletions src/maxdiffusion/tests/configuration_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import json
import os

from maxdiffusion import pyconfig
from maxdiffusion.configuration_utils import ConfigMixin
from maxdiffusion import __version__

class DummyConfigMixin(ConfigMixin):
config_name = "config.json"

def __init__(self, **kwargs):
self.register_to_config(**kwargs)

def test_to_json_string_with_config():
# Load the YAML config file
config_path = os.path.join(os.path.dirname(__file__), "..", "configs", "base_wan_14b.yml")

# Initialize pyconfig with the YAML config
pyconfig.initialize([None, config_path], unittest=True)
config = pyconfig.config

# Create a DummyConfigMixin instance
dummy_config = DummyConfigMixin(**config.get_keys())

# Get the JSON string
json_string = dummy_config.to_json_string()

# Parse the JSON string
parsed_json = json.loads(json_string)

# Assertions
assert parsed_json["_class_name"] == "DummyConfigMixin"
assert parsed_json["_diffusers_version"] == __version__

# Check a few values from the config
assert parsed_json["run_name"] == config.run_name
assert parsed_json["pretrained_model_name_or_path"] == config.pretrained_model_name_or_path
assert parsed_json["flash_block_sizes"]["block_q"] == config.flash_block_sizes["block_q"]

# The following keys are explicitly removed in to_json_string, so we assert they are not present
assert "weights_dtype" not in parsed_json
assert "precision" not in parsed_json
2 changes: 2 additions & 0 deletions src/maxdiffusion/tests/input_pipeline_interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from absl.testing import absltest

import numpy as np
import pytest
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import jax
Expand Down Expand Up @@ -69,6 +70,7 @@ class InputPipelineInterface(unittest.TestCase):
def setUp(self):
InputPipelineInterface.dummy_data = {}

@pytest.mark.skip(reason="Debug segfault")
def test_make_dreambooth_train_iterator(self):

instance_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/instance_class"
Expand Down
Loading