Merge bda47bc into df90cdc
JohnGiorgi committed Sep 9, 2019
Commit 66c3480 (2 parents: df90cdc + bda47bc)
Showing 5 changed files with 10 additions and 11 deletions.
saber/config.ini: 2 changes (1 addition, 1 deletion)
@@ -29,7 +29,7 @@ grad_norm = 1.0
# For certain optimizers, these values are ignored. See compile_model() in
# saber/utils/model_utils.py.
learning_rate = 0.0
-decay = 0.0
+weight_decay = 0.1

# Three dropout values must be specified (separated by a comma), corresponding to the dropout rate
# to apply to the input, output and recurrent connections respectively. Must be a value between 0.0
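As an aside, a minimal sketch of how the comma-separated dropout setting described in the comment above might be parsed; the helper name and validation are hypothetical, not part of this commit:

    # Hypothetical sketch (not part of this commit): parse the comma-separated
    # dropout setting described above into its three named parts.
    def parse_dropout_rate(raw):
        """Parse a value such as '0.3, 0.3, 0.1' into three floats."""
        rates = [float(part) for part in raw.split(',')]
        if len(rates) != 3:
            raise ValueError('dropout_rate must contain exactly three values')
        return dict(zip(('input', 'output', 'recurrent'), rates))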
saber/config.py: 7 changes (3 additions, 4 deletions)
@@ -154,7 +154,7 @@ def _parse_config_args(self, config):
args['optimizer'] = config['training']['optimizer']
args['activation'] = config['training']['activation']
args['learning_rate'] = config['training'].getfloat('learning_rate')
-args['decay'] = config['training'].getfloat('decay')
+args['weight_decay'] = config['training'].getfloat('weight_decay')
args['grad_norm'] = config['training'].getfloat('grad_norm')
args['dropout_rate'] = config['training'].getfloat('dropout_rate')
args['batch_size'] = config['training'].getint('batch_size')
@@ -245,9 +245,8 @@ def _parse_cli_args(self):
parser.add_argument('--debug', required=False, action='store_true',
help=('If provided, only a small proportion of the dataset, and any '
'provided embeddings, are loaded. Useful for debugging.'))
-parser.add_argument('--decay', required=False, type=float,
-help=('float >= 0. Learning rate decay over each update. Note that for '
-'certain optimizers this value is ignored. Defaults to 0.'))
+parser.add_argument('--weight_decay', required=False, type=float,
+help=('0 <= float <= 1. Weight decay.'))
parser.add_argument('--dropout_rate', required=False, type=float,
help='float between 0 and 1. Fraction of the input units to drop.')
parser.add_argument('--grad_norm', required=False, type=float,
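For context, a minimal sketch of the override behaviour implied by the two hunks above: a value parsed from config.ini is replaced by the matching CLI argument when one is supplied. The merge step is an assumption about Saber's behaviour, not code from this commit:

    # Sketch (assumed behaviour, not from this diff): a CLI value, when
    # provided, overrides the corresponding entry parsed from config.ini.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--weight_decay', required=False, type=float,
                        help=('0 <= float <= 1. Weight decay.'))
    cli_args = vars(parser.parse_args(['--weight_decay', '0.05']))

    args = {'weight_decay': 0.1}  # as parsed from config.ini
    args.update({k: v for k, v in cli_args.items() if v is not None})
    assert args['weight_decay'] == 0.05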
saber/tests/resources/constants.py: 8 changes (4 additions, 4 deletions)
@@ -114,7 +114,7 @@
'activation': 'relu',
'learning_rate': '0.0',
'grad_norm': '1.0',
-'decay': '0.0',
+'weight_decay': '0.1',
'dropout_rate': '0.1',
'batch_size': '32',
'validation_split': '0.0',
@@ -141,7 +141,7 @@
'optimizer': 'nadam',
'activation': 'relu',
'learning_rate': 0.0,
-'decay': 0.0,
+'weight_decay': 0.1,
'grad_norm': 1.0,
'dropout_rate': 0.1,
'batch_size': 32,
@@ -161,7 +161,7 @@
DUMMY_COMMAND_LINE_ARGS = {'optimizer': 'sgd',
'grad_norm': 1.0,
'learning_rate': 0.05,
-'decay': 0.5,
+'weight_decay': 0.5,
'dropout_rate': 0.6,
# the datasets are used for test purposes so they must
# point to the correct resources, this can be ensured by passing their
@@ -179,7 +179,7 @@
'optimizer': 'sgd',
'activation': 'relu',
'learning_rate': 0.05,
-'decay': 0.5,
+'weight_decay': 0.5,
'grad_norm': 1.0,
'dropout_rate': 0.6,
'batch_size': 32,
saber/tests/resources/dummy_config.ini: 2 changes (1 addition, 1 deletion)
@@ -29,7 +29,7 @@ grad_norm = 1.0
# For certain optimizers, these values are ignored. See compile_model() in
# saber/utils/model_utils.py.
learning_rate = 0.0
-decay = 0.0
+weight_decay = 0.1

# Three dropout values must be specified (separated by a comma), corresponding to the dropout rate
# to apply to the input, output and recurrent connections respectively. Must be a value between 0.0
saber/utils/bert_utils.py: 2 changes (1 addition, 1 deletion)
@@ -274,7 +274,7 @@ def get_bert_optimizer(config, model):
decay.append(param)

grouped_parameters = [
-{'params': no_decay, 'weight_decay': 0.0}, {'params': decay, 'weight_decay': 0.01}
+{'params': no_decay, 'weight_decay': 0.0}, {'params': decay, 'weight_decay': config.weight_decay}
]

optimizer = AdamW(grouped_parameters, lr=config.learning_rate, correct_bias=False)
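To make the final hunk concrete, here is a sketch of the conventional parameter split that would produce the no_decay and decay lists; the name-matching tokens ('bias', 'LayerNorm.weight') follow common BERT fine-tuning practice and are an assumption here, since the actual loop in get_bert_optimizer is only partly visible in the hunk:

    # Sketch (assumed; only the tail of the real loop is visible above):
    # exempt biases and LayerNorm weights from weight decay, and apply
    # config.weight_decay to every other parameter in the model.
    no_decay, decay = [], []
    for name, param in model.named_parameters():
        if any(token in name for token in ('bias', 'LayerNorm.weight')):
            no_decay.append(param)
        else:
            decay.append(param)

    grouped_parameters = [
        {'params': no_decay, 'weight_decay': 0.0},
        {'params': decay, 'weight_decay': config.weight_decay},
    ]

The net effect of the change is that the decay group's rate, previously hard-coded to 0.01, now comes from the new weight_decay setting.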
