Commit 4f1df69 (1 parent: 15f6b22)

Revert "add attention_head_dim"

This reverts commit 15f6b22.

2 files changed (+1, -9 lines)

src/diffusers/models/attention.py (-1)

@@ -158,7 +158,6 @@ def __init__(
         super().__init__()
         self.only_cross_attention = only_cross_attention
 
-        # We keep these boolean flags for backwards-compatibility.
         self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
         self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
         self.use_ada_layer_norm_single = norm_type == "ada_norm_single"

src/diffusers/models/unets/unet_i2vgen_xl.py (+1, -8)

@@ -120,7 +120,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, normalization and activation layers is skipped in post-processing.
         cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
-        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         num_attention_heads (`int`, *optional*): The number of attention heads.
     """
 
@@ -148,16 +147,10 @@ def __init__(
         layers_per_block: int = 2,
         norm_num_groups: Optional[int] = 32,
         cross_attention_dim: int = 1024,
-        attention_head_dim: Union[int, Tuple[int]] = None,
         num_attention_heads: Optional[Union[int, Tuple[int]]] = 64,
     ):
         super().__init__()
 
-        # We didn't define `attention_head_dim` when we first integrated this UNet. As a result,
-        # we had to use `num_attention_heads` in to pass values for arguments that actually denote
-        # attention head dimension. This is why we correct it here.
-        attention_head_dim = num_attention_heads or attention_head_dim
-
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -179,7 +172,7 @@ def __init__(
 
         self.transformer_in = TransformerTemporalModel(
             num_attention_heads=8,
-            attention_head_dim=attention_head_dim,
+            attention_head_dim=num_attention_heads,
            in_channels=block_out_channels[0],
            num_layers=1,
            norm_num_groups=norm_num_groups,
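
For context: with this revert, the value passed as `num_attention_heads` once again reaches `TransformerTemporalModel` as its `attention_head_dim`, as it did before 15f6b22. A minimal sketch of that behaviour, assuming a diffusers checkout at this commit; the default `num_attention_heads=64` is taken from the diff above, all other constructor defaults are assumed, not shown here:

from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet

# Build the UNet; `num_attention_heads=64` is the value that ends up feeding
# `attention_head_dim` of the temporal transformer below.
unet = I2VGenXLUNet(num_attention_heads=64)

# After the revert, the temporal transformer receives that value directly as
# `attention_head_dim`, while `num_attention_heads=8` stays hard-coded in the call.
print(type(unet.transformer_in).__name__)              # TransformerTemporalModel
print(unet.transformer_in.config.attention_head_dim)   # 64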
