Fix of issue huggingface#13327: Wrong weight initialization for TF t5 model (huggingface#14241)

* Fix of issue huggingface#13327: Wrong weight initialization for TF t5 model

* run black formatter

* fix typo

* remove my name tag from comments

Co-authored-by: Shirron <dan.shirron@intel.com>
2 people authored and Alberto Bégué committed Jan 27, 2022
1 parent 9c73480 commit 57c3551
Showing 1 changed file with 64 additions and 11 deletions.
src/transformers/models/t5/modeling_tf_t5.py (64 additions, 11 deletions)
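For context, a minimal sketch of what the change means in practice (not part of the commit; the dimensions below are illustrative stand-ins for T5Config values): before this fix the affected Dense layers fell back to Keras' default glorot_uniform kernel initializer, whereas the diff below switches them to zero-mean normal weights whose standard deviation scales with the inverse square root of the layer's input dimension, following the flax/Mesh TensorFlow T5 scheme.

import tensorflow as tf

# Illustrative dimensions only (roughly T5-small); the real values come from T5Config.
d_model, d_ff, initializer_factor = 512, 2048, 1.0

# Old behaviour: tf.keras.layers.Dense defaults to kernel_initializer="glorot_uniform".
old_wi = tf.keras.layers.Dense(d_ff, use_bias=False)

# New behaviour (as added in this diff): RandomNormal with stddev = factor * d_model ** -0.5.
wi_initializer = tf.keras.initializers.RandomNormal(
    mean=0, stddev=initializer_factor * (d_model ** -0.5)
)
new_wi = tf.keras.layers.Dense(d_ff, use_bias=False, kernel_initializer=wi_initializer)

# Build both layers so their kernels are materialized, then compare the weight scales.
old_wi.build((None, d_model))
new_wi.build((None, d_model))
print(float(tf.math.reduce_std(old_wi.kernel)))  # ~sqrt(2 / (512 + 2048)) ~= 0.028
print(float(tf.math.reduce_std(new_wi.kernel)))  # ~512 ** -0.5 ~= 0.044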
@@ -93,8 +93,18 @@ def call(self, hidden_states):
class TFT5DenseReluDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
-        self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi")
-        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
+        wi_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
+        )
+        wo_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
+        )
+        self.wi = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wo = tf.keras.layers.Dense(
+            config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
+        )  # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.act = tf.keras.activations.relu

@@ -109,9 +119,21 @@ def call(self, hidden_states, training=False):
class TFT5GatedGeluDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
-        self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0")
-        self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1")
-        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
+        wi_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_model ** -0.5)
+        )
+        wo_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (config.d_ff ** -0.5)
+        )
+        self.wi_0 = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wi_1 = tf.keras.layers.Dense(
+            config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer
+        )  # Update init weights as in flax
+        self.wo = tf.keras.layers.Dense(
+            config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
+        )  # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.act = get_tf_activation("gelu_new")

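The call method of TFT5GatedGeluDense is untouched by this diff, but it helps explain why wi_0 and wi_1 share the same initializer: both project from d_model to d_ff. The sketch below is a hypothetical standalone layer showing how such a gated feed-forward block is typically wired (it uses Keras' standard gelu, whereas the library layer uses the "gelu_new" variant via get_tf_activation):

import tensorflow as tf

class GatedGeluFFN(tf.keras.layers.Layer):
    """Toy gated-GELU feed-forward block mirroring the wi_0/wi_1/wo layout above."""

    def __init__(self, d_model, d_ff, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        init_in = tf.keras.initializers.RandomNormal(mean=0, stddev=d_model ** -0.5)
        init_out = tf.keras.initializers.RandomNormal(mean=0, stddev=d_ff ** -0.5)
        self.wi_0 = tf.keras.layers.Dense(d_ff, use_bias=False, kernel_initializer=init_in)
        self.wi_1 = tf.keras.layers.Dense(d_ff, use_bias=False, kernel_initializer=init_in)
        self.wo = tf.keras.layers.Dense(d_model, use_bias=False, kernel_initializer=init_out)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.act = tf.keras.activations.gelu

    def call(self, hidden_states, training=False):
        hidden_gelu = self.act(self.wi_0(hidden_states))  # gate branch
        hidden_linear = self.wi_1(hidden_states)          # linear branch
        hidden_states = hidden_gelu * hidden_linear       # element-wise gating
        hidden_states = self.dropout(hidden_states, training=training)
        return self.wo(hidden_states)

# Example usage on a dummy batch.
layer = GatedGeluFFN(d_model=512, d_ff=2048)
out = layer(tf.random.normal((2, 16, 512)))
print(out.shape)  # (2, 16, 512)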
@@ -163,10 +185,34 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs):
self.inner_dim = self.n_heads * self.key_value_proj_dim

# Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
-        self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="k")
-        self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="v")
-        self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
+        q_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
+        )
+        k_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        v_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        o_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+        self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal(
+            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
+        )
+
+        self.q = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer
+        )  # Update init weights as in flax
+        self.k = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer
+        )  # Update init weights as in flax
+        self.v = tf.keras.layers.Dense(
+            self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer
+        )  # Update init weights as in flax
+        self.o = tf.keras.layers.Dense(
+            self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer
+        )  # Update init weights as in flax
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

self.pruned_heads = set()
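A quick sanity check of the projection scales added above (illustrative numbers only; the real values come from T5Config). Per the "Mesh TensorFlow initialization to avoid scaling before softmax" comment, T5 does not divide attention logits by sqrt(d_kv), so the corresponding factor is folded into the query projection's initializer, which is why q gets the extra key_value_proj_dim term while k, v, and o do not:

# Illustrative head geometry; in the library these come from T5Config.
initializer_factor = 1.0
n_heads = 8
key_value_proj_dim = 64                    # d_kv
inner_dim = n_heads * key_value_proj_dim   # 512

q_std = initializer_factor * ((inner_dim * key_value_proj_dim) ** -0.5)
kvo_std = initializer_factor * (inner_dim ** -0.5)

print(q_std)    # ~0.0055
print(kvo_std)  # ~0.0442
print(q_std * key_value_proj_dim ** 0.5)  # equals kvo_std: the sqrt(d_kv) scaling is
                                          # absorbed into q instead of the softmax logits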
@@ -177,6 +223,7 @@ def build(self, input_shape):
self.relative_attention_bias = self.add_weight(
name="embeddings",
shape=[self.relative_attention_num_buckets, self.n_heads],
+                    initializer=self.relative_attention_bias_initializer,  # Add initializer
)

return super().build(input_shape)
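The relative attention bias is created in build() via add_weight rather than as a Keras sub-layer, so there is no kernel_initializer argument to pass at construction time; the initializer has to be stashed in __init__ (relative_attention_bias_initializer above) and handed to add_weight explicitly, otherwise add_weight falls back to its glorot_uniform default for float weights. A minimal sketch of that pattern (hypothetical layer, not the library code):

import tensorflow as tf

class BucketedBias(tf.keras.layers.Layer):
    """Toy layer showing the add_weight + explicit initializer pattern used above."""

    def __init__(self, num_buckets, n_heads, stddev, **kwargs):
        super().__init__(**kwargs)
        self.num_buckets = num_buckets
        self.n_heads = n_heads
        # Keep the initializer around so build() can hand it to add_weight.
        self.bias_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=stddev)

    def build(self, input_shape):
        self.bias = self.add_weight(
            name="embeddings",
            shape=[self.num_buckets, self.n_heads],
            initializer=self.bias_initializer,  # without this, add_weight defaults to glorot_uniform
        )
        return super().build(input_shape)

    def call(self, bucket_ids):
        # Look up a per-head bias for each relative-position bucket id.
        return tf.gather(self.bias, bucket_ids)

layer = BucketedBias(num_buckets=32, n_heads=8, stddev=512 ** -0.5)
print(layer(tf.constant([0, 3, 31])).shape)  # (3, 8)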
@@ -1266,7 +1313,10 @@ def __init__(self, config, *inputs, **kwargs):
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")

if not config.tie_word_embeddings:
-            self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
+            lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)
+            self.lm_head = tf.keras.layers.Dense(
+                config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
+            )  # Update init weights as in flax

def get_output_embeddings(self):
if self.config.tie_word_embeddings:
@@ -1280,7 +1330,10 @@ def set_output_embeddings(self, value):
if self.config.tie_word_embeddings:
self.set_input_embeddings(value)
else:
-            self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head")
+            lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor)
+            self.lm_head = tf.keras.layers.Dense(
+                shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer
+            )  # Update init weights as in flax
# in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens)
# value has a shape (num_tokens, dim) then needs to be transposed
transposed_value = tf.transpose(value)
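When embeddings are untied, lm_head is a plain projection to vocab_size, so its initializer above uses stddev = initializer_factor with no dimension scaling. The transpose in the surrounding code follows from Keras' kernel layout: a Dense kernel has shape (last_input_dim, units), i.e. (d_model, vocab_size) here, while an embedding table is (vocab_size, d_model). A small sketch of that shape relationship (illustrative names and sizes, not the library code):

import tensorflow as tf

vocab_size, d_model, initializer_factor = 1000, 512, 1.0  # illustrative sizes

lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=initializer_factor)
lm_head = tf.keras.layers.Dense(vocab_size, use_bias=False, kernel_initializer=lm_head_initializer)
lm_head.build((None, d_model))
print(lm_head.kernel.shape)  # (512, 1000): Dense stores (last_input_dim, units)

# An embedding table is laid out the other way round, (num_tokens, dim), so copying
# it into the Dense kernel requires a transpose, as in set_output_embeddings above.
new_embeddings = tf.random.normal((vocab_size, d_model))
lm_head.kernel.assign(tf.transpose(new_embeddings))
print(lm_head.kernel.shape)  # still (512, 1000)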
