Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file modified code_style.sh
100644 → 100755
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import inspect
from importlib import import_module
from typing import Any, Dict, Optional, Tuple
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import time

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import argparse
import tensorflow as tf
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/pedagogical_examples/parameter_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Sequence
from absl import app
import jax
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/controlnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import TYPE_CHECKING

from ...utils import (
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/flux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

_import_structure = {"pipeline_jflux": "JfluxPipeline"}

from .flux_pipeline import (
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import TYPE_CHECKING

from ...utils import (
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import TYPE_CHECKING

from ...utils import (
Expand Down
3 changes: 2 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,9 @@ def prepare_latents_i2v_base(

vae_dtype = getattr(self.vae, "dtype", jnp.float32)
video_condition = video_condition.astype(vae_dtype)

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec)
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()

# Normalize latents
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/tests/configuration_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import os

Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/tests/flop_calculations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import unittest
from unittest.mock import Mock
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/tests/generate_flux_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import unittest
import pytest
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/tests/generate_sdxl_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import unittest
import pytest
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/tests/generate_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import unittest
import pytest
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/utils/deprecation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import inspect
import warnings
from typing import Any, Dict, Optional, Union
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import io
import random
import struct
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/utils/loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
from typing import Callable, List, Optional, Union

Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/utils/pil_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import List

import PIL.Image
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import functools
import importlib
import inspect
Expand Down
Loading