Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
absl-py
accelerate
aqtp
av
chex
datasets
einops
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ astroid>=4.0.4
astunparse>=1.6.3
attrs>=25.4.0
auditwheel>=6.6.0
av>=17.0.1
black>=25.12.0
build>=1.4.0
certifi>=2026.1.4
Expand Down
2 changes: 1 addition & 1 deletion maxdiffusion_dependencies.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN python -m pip install --upgrade pip uv --no-warn-script-location

# Install system dependencies
RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool && rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool g++ && rm -rf /var/lib/apt/lists/*

# Add the Google Cloud SDK package repository
RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
Expand Down
22 changes: 19 additions & 3 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,22 @@ def __init__(
):
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim

tpu_type = get_tpu_type()
is_ironwood = tpu_type == TpuType.TPU_7X

# Hardware-aware sharding specs: Ironwood (v7x) keeps the embedding dimension (embed)
# replicated (None) to minimize cross-device communication, while other hardware (default)
# shards it to prevent OOM issues.
if is_ironwood:
net0_kernel_spec = (None, "mlp")
net2_kernel_spec = ("mlp", None)
net2_bias_spec = (None,)
else:
net0_kernel_spec = ("embed", "mlp")
net2_kernel_spec = ("mlp", "embed")
net2_bias_spec = ("embed",)

self.net_0 = nnx.Linear(
dim,
inner_dim,
Expand All @@ -1142,7 +1158,7 @@ def __init__(
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net0_kernel_spec),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
)
self.act = get_activation(activation_fn)
Expand All @@ -1154,8 +1170,8 @@ def __init__(
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net2_kernel_spec),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, net2_bias_spec),
)

def __call__(self, hidden_states: Array) -> Array:
Expand Down
Loading