Adds a new command `find-lr` that lets one search for a learning rate range in which the loss drops rapidly. This addresses feature request #537. See the following [blog post](https://medium.com/@surmenok/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0) linked in that issue for an overview of how the finder works; a minimal sketch of the idea appears below. The main change is making a few of the fields in `Trainer` public (i.e., removing the leading underscore from their names). I have used matplotlib to plot the learning-rate-vs-loss graph. I haven't written unit tests yet; if the current code is OK, I will, though I am a little unsure about what exactly to test.
1 parent 63836c4 · commit 9fcc795 · 8 changed files with 521 additions and 44 deletions
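For readers who do not want to follow the blog post, here is a minimal, self-contained sketch of the technique in plain PyTorch (the toy model and data are illustrative only, not part of this commit): train for a fixed number of batches while geometrically increasing the learning rate, record the loss at each step, and pick a rate from the region where the loss falls fastest.

import torch

# Toy model and data for illustration only.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
loss_fn = torch.nn.MSELoss()

start_lr, end_lr, num_batches = 1e-5, 10.0, 100
factor = (end_lr / start_lr) ** (1.0 / num_batches)  # exponential growth per batch

lrs, losses = [], []
for i in range(num_batches):
    lr = start_lr * factor ** i
    for group in optimizer.param_groups:  # set this batch's learning rate
        group['lr'] = lr
    inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    lrs.append(lr)
    losses.append(loss.item())
# A good learning rate lies where the (smoothed) loss curve drops most
# steeply, well before the loss starts to diverge.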
@@ -0,0 +1,293 @@
""" | ||
The ``find-lr`` subcommand can be used to find a good learning rate for a model. | ||
It requires a configuration file and a directory in | ||
which to write the results. | ||
.. code-block:: bash | ||
$ allennlp find-lr --help | ||
usage: allennlp train [-h] -s SERIALIZATION_DIR [-o OVERRIDES] | ||
[--start-lr START_LR] [--end-lr END_LR] | ||
[--num-batches NUM_BATCHES] [--linear] | ||
param_path | ||
Train the specified model on the specified dataset. | ||
positional arguments: | ||
param_path path to parameter file describing the model to be | ||
trained | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
-s SERIALIZATION_DIR, --serialization-dir SERIALIZATION_DIR | ||
directory in which to save Learning rate vs loss | ||
-o OVERRIDES, --overrides OVERRIDES | ||
a JSON structure used to override the experiment | ||
configuration | ||
--start-lr START_LR | ||
Learning rate to start the search. | ||
--end-lr END_LR | ||
Learning rate up to which search is done. | ||
--num-batches NUM_BATCHES | ||
Number of mini-batches to run Learning rate finder | ||
--linear Increase learning rate linearly instead of exponential increase | ||
""" | ||
from typing import List, Optional, Tuple
import argparse
import re
import os
import math
import logging

import matplotlib
matplotlib.use('Agg')  # must be set before pyplot is imported
import matplotlib.pyplot as plt  # pylint: disable=wrong-import-position

from allennlp.commands.subcommand import Subcommand  # pylint: disable=wrong-import-position
from allennlp.commands.train import datasets_from_params  # pylint: disable=wrong-import-position
from allennlp.common.checks import ConfigurationError, check_for_gpu  # pylint: disable=wrong-import-position
from allennlp.common import Params, Tqdm  # pylint: disable=wrong-import-position
from allennlp.common.util import prepare_environment  # pylint: disable=wrong-import-position
from allennlp.data import Vocabulary, DataIterator  # pylint: disable=wrong-import-position
from allennlp.models import Model  # pylint: disable=wrong-import-position
from allennlp.training import Trainer  # pylint: disable=wrong-import-position


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

class FindLearningRate(Subcommand):
    def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        # pylint: disable=protected-access
        description = '''Find a learning rate range where the loss decreases quickly
        for the specified model and dataset.'''
        subparser = parser.add_parser(name, description=description,
                                      help='Find a learning rate range.')

        subparser.add_argument('param_path',
                               type=str,
                               help='path to parameter file describing the model to be trained')
        subparser.add_argument('-s', '--serialization-dir',
                               required=True,
                               type=str,
                               help='The directory in which to save results.')
        subparser.add_argument('-o', '--overrides',
                               type=str,
                               default="",
                               help='a JSON structure used to override the experiment configuration')
        subparser.add_argument('--start-lr',
                               type=float,
                               default=1e-5,
                               help='learning rate to start the search')
        subparser.add_argument('--end-lr',
                               type=float,
                               default=10,
                               help='learning rate up to which the search is done')
        subparser.add_argument('--num-batches',
                               type=int,
                               default=100,
                               help='number of mini-batches to run the learning rate finder')
        subparser.add_argument('--stopping-factor',
                               type=float,
                               default=4.0,
                               help='stop the search when the current loss exceeds the best loss '
                                    'recorded multiplied by the stopping factor')
        subparser.add_argument('--linear',
                               action='store_true',
                               help='increase the learning rate linearly instead of exponentially')

        subparser.set_defaults(func=find_learning_rate_from_args)

        return subparser

def find_learning_rate_from_args(args: argparse.Namespace) -> None:
    """
    Starts the learning rate finder for the given arguments.
    """
    params = Params.from_file(args.param_path, args.overrides)
    find_learning_rate_model(params, args.serialization_dir,
                             args.start_lr, args.end_lr,
                             args.num_batches, args.linear, args.stopping_factor)

def find_learning_rate_model(params: Params,
                             serialization_dir: str,
                             start_lr: float,
                             end_lr: float,
                             num_batches: int,
                             linear_steps: bool,
                             stopping_factor: Optional[float]) -> None:
    """
    Runs the learning rate search for the given ``num_batches`` and saves the results
    in ``serialization_dir``.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : ``str``
        The directory in which to save results.
    start_lr : ``float``
        The learning rate to start the search from.
    end_lr : ``float``
        The learning rate up to which the search is done.
    num_batches : ``int``
        The number of mini-batches to run the learning rate finder for.
    linear_steps : ``bool``
        If ``True``, increase the learning rate linearly; otherwise exponentially.
    stopping_factor : ``float``
        Stop the search when the current loss exceeds the best loss recorded
        multiplied by ``stopping_factor``. If ``None``, the search proceeds all
        the way to ``end_lr``.
    """

    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        raise ConfigurationError(f'Serialization directory {serialization_dir} already exists and is '
                                 'not empty.')

    prepare_environment(params)
    os.makedirs(serialization_dir, exist_ok=True)

    check_for_gpu(params.get('trainer').get('cuda_device', -1))

    all_datasets = datasets_from_params(params)
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info("From dataset instances, %s will be considered for vocabulary creation.",
                ", ".join(datasets_for_vocab_creation))
    vocab = Vocabulary.from_params(
            params.pop("vocabulary", {}),
            (instance for key, dataset in all_datasets.items()
             for instance in dataset
             if key in datasets_for_vocab_creation)
    )

    model = Model.from_params(vocab=vocab, params=params.pop('model'))
    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)

    train_data = all_datasets['train']

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
                                  train_data,
                                  params=trainer_params,
                                  validation_data=None,
                                  validation_iterator=None)

    logger.info(f'Starting learning rate search from {start_lr} to {end_lr} in {num_batches} iterations.')
    learning_rates, losses = search_learning_rate(trainer, start_lr,
                                                  end_lr, num_batches,
                                                  linear_steps, stopping_factor)
    logger.info('Finished learning rate search.')
    losses = _smooth(losses, 0.98)

    _save_plot(learning_rates, losses, os.path.join(serialization_dir, 'lr-losses.png'))

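# Hypothetical programmatic use, mirroring ``find_learning_rate_model`` above,
# given an already-constructed ``trainer``:
#
#   learning_rates, losses = search_learning_rate(trainer, start_lr=1e-5,
#                                                 end_lr=10, num_batches=100)
#   losses = _smooth(losses, 0.98)
#   _save_plot(learning_rates, losses, 'lr-losses.png')
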
def search_learning_rate(trainer: Trainer,
                         start_lr: float = 1e-5,
                         end_lr: float = 10,
                         num_batches: int = 100,
                         linear_steps: bool = False,
                         stopping_factor: Optional[float] = 4.0) -> Tuple[List[float], List[float]]:
    """
    Runs a training loop on the model using :class:`~allennlp.training.trainer.Trainer`,
    increasing the learning rate from ``start_lr`` to ``end_lr`` and recording the losses.

    Parameters
    ----------
    trainer : :class:`~allennlp.training.trainer.Trainer`
    start_lr : ``float``
        The learning rate to start the search from.
    end_lr : ``float``
        The learning rate up to which the search is done.
    num_batches : ``int``
        The number of batches to run the learning rate finder for.
    linear_steps : ``bool``
        If ``True``, increase the learning rate linearly; otherwise exponentially.
    stopping_factor : ``float``
        Stop the search when the current loss exceeds the best loss recorded
        multiplied by ``stopping_factor``. If ``None``, the search proceeds all
        the way to ``end_lr``.

    Returns
    -------
    (learning_rates, losses) : ``Tuple[List[float], List[float]]``
        Lists of learning rates and the corresponding losses.
        Note: the losses are recorded before applying the corresponding learning rate.
    """
    if num_batches <= 10:
        raise ConfigurationError('The number of iterations for the learning rate finder '
                                 'should be greater than 10.')

    trainer.model.train()

    train_generator = trainer.iterator(trainer.train_data,
                                       shuffle=trainer.shuffle)
    train_generator_tqdm = Tqdm.tqdm(train_generator,
                                     total=num_batches)

    learning_rates = []
    losses = []
    best = 1e9
    if linear_steps:
        lr_update_factor = (end_lr - start_lr) / num_batches
    else:
        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)
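    # For the defaults (start_lr=1e-5, end_lr=10, num_batches=100), the
    # exponential factor is (10 / 1e-5) ** (1 / 100) = 10 ** 0.06 ≈ 1.148,
    # i.e. the learning rate grows by roughly 15% every batch.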

    for i, batch in enumerate(train_generator_tqdm):

        if linear_steps:
            current_lr = start_lr + (lr_update_factor * i)
        else:
            current_lr = start_lr * (lr_update_factor ** i)

        for param_group in trainer.optimizer.param_groups:
            param_group['lr'] = current_lr

        trainer.optimizer.zero_grad()
        loss = trainer.batch_loss(batch, for_training=True)
        loss.backward()
        loss = loss.detach().cpu().item()

        if stopping_factor is not None and (math.isnan(loss) or loss > stopping_factor * best):
            logger.info(f'Loss ({loss}) exceeds stopping_factor * lowest recorded loss.')
            break
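        # Example of the check above: with the default stopping_factor of 4.0
        # and a best loss of 0.5 so far, the search aborts as soon as a batch
        # loss exceeds 4.0 * 0.5 = 2.0 (or becomes NaN).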

        trainer.rescale_gradients()
        trainer.optimizer.step()

        learning_rates.append(current_lr)
        losses.append(loss)

        if loss < best and i > 10:
            best = loss

        if i == num_batches:
            break

    return learning_rates, losses


def _smooth(values: List[float], beta: float) -> List[float]:
    """Exponential smoothing of values."""
    avg_value = 0.
    smoothed = []
    for i, value in enumerate(values):
        avg_value = beta * avg_value + (1 - beta) * value
        smoothed.append(avg_value / (1 - beta ** (i + 1)))
    return smoothed
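# For example, _smooth([1.0, 2.0, 3.0], 0.5) returns approximately
# [1.0, 1.667, 2.429]: a bias-corrected exponential moving average (the same
# bias correction used by optimizers such as Adam).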

def _save_plot(learning_rates: List[float], losses: List[float], save_path: str):
    plt.ylabel('loss')
    plt.xlabel('learning rate (log10 scale)')
    plt.xscale('log')
    plt.plot(learning_rates, losses)
    logger.info(f'Saving learning_rate vs loss plot to {save_path}.')
    plt.savefig(save_path)