This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit be977e6

Add weight decay

CyberZHG committed Apr 29, 2019
1 parent c995940
Showing 1 changed file with 9 additions and 0 deletions.
keras_bert/bert.py: 9 additions & 0 deletions
@@ -36,6 +36,7 @@ def get_model(token_num,
               head_num=12,
               feed_forward_dim=3072,
               dropout_rate=0.1,
+              weight_decay=0.01,
               attention_activation=None,
               feed_forward_activation=gelu,
               custom_layers=None,
@@ -54,6 +55,7 @@ def get_model(token_num,
     :param head_num: Number of heads in multi-head attention in each transformer.
     :param feed_forward_dim: Dimension of the feed forward layer in each transformer.
     :param dropout_rate: Dropout rate.
+    :param weight_decay: Weight decay rate.
     :param attention_activation: Activation for attention layers.
     :param feed_forward_activation: Activation for feed-forward layers.
     :param custom_layers: A function that takes the embedding tensor and returns the tensor after feature extraction.
@@ -118,6 +120,13 @@ def get_model(token_num,
             name='NSP',
         )(nsp_dense_layer)
         model = keras.models.Model(inputs=inputs, outputs=[masked_layer, nsp_pred_layer])
+        if weight_decay:
+            weight_decay *= 0.5
+            for layer in model.layers:
+                if hasattr(layer, 'embeddings_regularizer'):
+                    layer.embeddings_regularizer = keras.regularizers.l2(weight_decay)
+                if hasattr(layer, 'kernel_regularizer'):
+                    layer.kernel_regularizer = keras.regularizers.l2(weight_decay)
         model.compile(
             optimizer=keras.optimizers.Adam(lr=lr),
             loss=keras.losses.sparse_categorical_crossentropy,
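The added block attaches an L2 penalty to every layer that exposes an embeddings_regularizer or kernel_regularizer attribute. The 0.5 factor presumably accounts for Keras's l2 regularizer adding l2 * sum(w^2) to the loss, so the effective penalty becomes the conventional 0.5 * weight_decay * ||w||^2. A minimal usage sketch follows; it is not part of the commit, and the vocabulary size is a hypothetical placeholder:

from keras_bert import get_model  # package-level import assumed; the function lives in keras_bert/bert.py

# Build the pre-training model with the new regularization argument.
model = get_model(
    token_num=30000,      # hypothetical vocabulary size
    head_num=12,
    feed_forward_dim=3072,
    dropout_rate=0.1,
    weight_decay=0.01,    # pass 0 or None to skip the L2 regularizers entirely
)
model.summary()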
