@@ -120,7 +120,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, normalization and activation layers is skipped in post-processing.
         cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
-        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         num_attention_heads (`int`, *optional*): The number of attention heads.
     """

@@ -148,16 +147,10 @@ def __init__(
         layers_per_block: int = 2,
         norm_num_groups: Optional[int] = 32,
         cross_attention_dim: int = 1024,
-        attention_head_dim: Union[int, Tuple[int]] = None,
         num_attention_heads: Optional[Union[int, Tuple[int]]] = 64,
     ):
         super().__init__()

-        # We didn't define `attention_head_dim` when we first integrated this UNet. As a result,
-        # we had to use `num_attention_heads` in to pass values for arguments that actually denote
-        # attention head dimension. This is why we correct it here.
-        attention_head_dim = num_attention_heads or attention_head_dim
-
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -179,7 +172,7 @@ def __init__(

         self.transformer_in = TransformerTemporalModel(
             num_attention_heads=8,
-            attention_head_dim=attention_head_dim,
+            attention_head_dim=num_attention_heads,
             in_channels=block_out_channels[0],
             num_layers=1,
             norm_num_groups=norm_num_groups,
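
For reference, a minimal usage sketch of the constructor after this change (the import path and the rejection behavior are assumptions, not part of the diff): with `attention_head_dim` removed from the signature, callers configure heads only through `num_attention_heads`, and that value is what `transformer_in` now receives as its `attention_head_dim`.

```python
# Hypothetical sketch, not part of this PR: illustrates the constructor surface
# after `attention_head_dim` is removed.
from diffusers import I2VGenXLUNet  # top-level export assumed

# Builds the default I2VGen-XL UNet. The 64 passed here is what the temporal
# transformer (`transformer_in`) receives as
# TransformerTemporalModel(attention_head_dim=64).
unet = I2VGenXLUNet(num_attention_heads=64)

# The removed keyword is no longer part of the signature, so a call like
#     I2VGenXLUNet(attention_head_dim=8)
# would now be rejected as an unexpected keyword argument.
```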