This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-594] Added Learning Rate Finder tutorial #11304

Merged: 7 commits, Jul 4, 2018
Changes from 2 commits
321 changes: 321 additions & 0 deletions docs/tutorials/gluon/learning_rate_finder.md
@@ -0,0 +1,321 @@

# Learning Rate Finder

Setting the learning rate for stochastic gradient descent (SGD) is crucially important when training neural networks because it controls both the speed of convergence and the ultimate performance of the network. Set the learning rate too low and you could be twiddling your thumbs for quite some time as the parameters update very slowly. Set it too high and the updates will skip over optimal solutions, or worse, the optimizer might not converge at all!

Leslie Smith from the U.S. Naval Research Laboratory presented a method for finding a good learning rate in a paper called ["Cyclical Learning Rates for Training Neural Networks"](https://arxiv.org/abs/1506.01186). We take a look at the central idea of the paper, cyclical learning rate schedules, in the tutorial found here, but in this tutorial we implement a 'Learning Rate Finder' in MXNet with the Gluon API that you can use while training your own networks.
Contributor: "found here" - Did you mean to provide a link here?
Contributor Author: Good catch, unpublished tutorial so replaced with name of tutorial.

Contributor: "but in this tutorial" - "but in this tutorial,"
Contributor Author: Comma overload if add.


## Simple Idea

Given an initialized network, a defined loss and a training dataset, we take the following steps:

1. Train one batch at a time (a.k.a. an iteration)
2. Start with a very small learning rate (e.g. 0.000001) and slowly increase it every iteration
3. Record the training loss and continue until we see the training loss diverge (a minimal sketch of this loop is given below)

Contributor: nit: Train / Start / Record
Contributor Author: Changed.
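To make the three steps above concrete, here is a minimal sketch of the search loop. This is not part of the original tutorial: `train_one_batch` is a hypothetical stand-in for a single training iteration, and the full MXNet implementation follows in the rest of this tutorial.

```python
import math

def train_one_batch(lr):
    # Hypothetical stand-in: in practice this runs a forward and backward pass
    # on one batch at learning rate `lr` and returns the training loss.
    # Here we fake a loss curve that is flat for small lr and blows up for large lr.
    return 2.3 + max(0.0, math.log10(lr) + 1) ** 2

lr, results = 1e-6, []
while True:
    loss = train_one_batch(lr)
    results.append((lr, loss))
    if loss > 2 * results[0][1]:  # stop once the loss has clearly diverged
        break
    lr *= 1.1  # slowly increase the learning rate every iteration
print("Loss diverged at a learning rate of {:.3g}".format(lr))
```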

We then analyse the results by plotting a graph of the learning rate against the training loss as seen below (taking note of the log scales).
Contributor: "analyse" -> "analyze"?
Contributor Author: It's the Queen's English Indu! 🇬🇧


![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_finder/finder_plot.png) <!--notebook-skip-line-->

As expected, for very small learning rates we don't see much change in the loss as the parameter updates are negligible. At a learning rate of 0.001, we start to see the loss fall. Setting the initial learning rate here is reasonable, but we still have the potential to learn faster. We observe a drop in the loss up until 0.1, where the loss appears to diverge. We want to set the initial learning rate as high as possible before the loss becomes unstable, so we choose a learning rate of 0.05.
Contributor: "paramater" -> "parameter"
Contributor: "At a learning rate of 0.001" -> "At a learning rate of 0.001,"
Contributor Author: Corrected.


## Epoch to Iteration

Usually, our unit of work is an epoch (a full pass through the dataset) and the learning rate would typically be held constant throughout the epoch. With the Learning Rate Finder (and cyclical learning rate schedules) we are required to vary the learning rate every iteration. As such, we structure our training code so that a single iteration can be run with a given learning rate. You can implement `Learner` as you wish. Just initialize the network, define the loss and trainer in `__init__`, and keep your training logic for a single batch in `iteration`.
Contributor: "Usually" -> "Usually,"
Contributor Author: Changed.



```python
import mxnet as mx

# Set seed for reproducibility
mx.random.seed(42)

class Learner():
    def __init__(self, net, data_loader, ctx):
        """
        net: network (mx.gluon.Block)
        data_loader: training data loader (mx.gluon.data.DataLoader)
        ctx: context (mx.gpu or mx.cpu)
        """
        self.net = net
        self.data_loader = data_loader
        self.ctx = ctx
        # So we don't need to be in `for batch in data_loader` scope
        # and can call for next batch in `iteration`
        self.data_loader_iter = iter(self.data_loader)
        self.net.collect_params().initialize(mx.init.Xavier(), ctx=self.ctx)
        self.loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss()
        self.trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .001})

    def iteration(self, lr=None, take_step=True):
        """
        lr: learning rate to use for iteration (float)
        take_step: take trainer step to update weights (boolean)
        """
        # Update learning rate if different this iteration
        if lr and (lr != self.trainer.learning_rate):
            self.trainer.set_learning_rate(lr)
        # Get next batch, and move to context (e.g. to GPU if set)
        data, label = next(self.data_loader_iter)
        data = data.as_in_context(self.ctx)
        label = label.as_in_context(self.ctx)
        # Standard forward and backward pass
        with mx.autograd.record():
            output = self.net(data)
            loss = self.loss_fn(output, label)
        loss.backward()
        # Update parameters
        if take_step: self.trainer.step(data.shape[0])
        # Set and return loss.
        # Although notice this is still an MXNet NDArray to avoid blocking
        self.iteration_loss = mx.nd.mean(loss)
        return self.iteration_loss

    def close(self):
        # Close open iterator and associated workers
        self.data_loader_iter.shutdown()
```

Review comments on the code above:

Contributor (on `class Learner()`): I am not sure about this naming but at the same time can't really find a better one. IterationRunner? I don't know.
Contributor Author: Me neither, but going to leave to avoid confusion with the videos that have been recorded.

Contributor (on the `__init__` docstring): Please use standard pydoc format, see the Gluon code for examples. Could the loss be passed as an argument?
Contributor Author: Changed to use reStructuredText Docstring Format (PEP 287).

Contributor (on the parameter initialization): net.initialize()
Contributor Author: Changed.

Contributor (on the non-blocking loss): I think blocking would be good here since we are not trying to optimize performance and we shouldn't want to risk cuda malloc.
Contributor Author: Was anticipating a comment, so made non blocking :) Changed to be non-blocking.

Contributor (on `close`): Is simple self.data_loader_iter.close() not enough?
Contributor Author: After experimenting, determined that shutdown was the best option for reliable closing.

We also adjust our `DataLoader` so that it continuously provides batches of data and doesn't stop after a single epoch. We can then call `iteration` as many times as required for the loss to diverge as part of the Learning Rate Finder process. We implement a custom `BatchSampler` for this, which keeps returning random indices of samples to be included in the next batch. We use the CIFAR-10 dataset for image classification to test our Learning Rate Finder.


```python
from multiprocessing import cpu_count
from mxnet.gluon.data.vision import transforms

transform = transforms.Compose([
    # Switches HWC to CHW, and converts to `float32`
    transforms.ToTensor(),
    # Channel-wise, using pre-computed means and stds
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010])
])

dataset = mx.gluon.data.vision.datasets.CIFAR10(train=True).transform_first(transform)

class ContinuousBatchSampler():
    def __init__(self, sampler, batch_size):
        self._sampler = sampler
        self._batch_size = batch_size

    def __iter__(self):
        batch = []
        while True:
            for i in self._sampler:
                batch.append(i)
                if len(batch) == self._batch_size:
                    yield batch
                    batch = []

sampler = mx.gluon.data.RandomSampler(len(dataset))
batch_sampler = ContinuousBatchSampler(sampler, batch_size=128)
data_loader = mx.gluon.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=cpu_count())
```
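As a quick sanity check (not part of the original tutorial), we can pull a couple of batches from this loader and confirm that it keeps producing data rather than stopping after a single pass. The shapes in the comment assume the CIFAR-10 setup above with a batch size of 128.

```python
# Illustrative check only: the loader should yield batches indefinitely.
loader_iter = iter(data_loader)
for _ in range(2):
    data, label = next(loader_iter)
    print(data.shape, label.shape)  # expect (128, 3, 32, 32) and (128,)
```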

## Implementation

With preparation complete, we're ready to write our Learning Rate Finder that wraps the `Learner` we defined above. We implement a `find` method for the procedure, and `plot` for the visualization. Starting with a very low learning rate, as defined by `lr_start`, we train one iteration at a time and keep multiplying the learning rate by `lr_multiplier`. We analyse the loss and continue until it diverges according to `LRFinderStoppingCriteria` (which is defined later on). You may also notice that we save the parameters and state of the optimizer before the process and restore them afterwards. This is so the Learning Rate Finder process doesn't impact the state of the model and can be used at any point during training.
Contributor: "analyse" -> "analyze"?
Contributor Author: Same as above.



```python
from matplotlib import pyplot as plt

class LRFinder():
    def __init__(self, learner):
        """
        learner: able to take single iteration with given learning rate and return loss
                 and save and load parameters of the network (Learner)
        """
        self.learner = learner

    def find(self, lr_start=1e-6, lr_multiplier=1.1, smoothing=0.3):
        """
        lr_start: learning rate to start search (float)
        lr_multiplier: factor the learning rate is multiplied by at each step of search (float)
        smoothing: amount of smoothing applied to loss for stopping criteria (float)
        """
        # Used to initialize weights; pass data, but don't take step.
        # Would expect for new model with lazy weight initialization
        self.learner.iteration(take_step=False)
        # Used to initialize trainer (if no step has been taken)
        if not self.learner.trainer._kv_initialized:
            self.learner.trainer._init_kvstore()
        # Store params and optimizer state for restore after lr_finder procedure
        self.learner.net.save_params("lr_finder.params")
        self.learner.trainer.save_states("lr_finder.state")
        lr = lr_start
        self.results = []  # List of (lr, loss) tuples
        stopping_criteria = LRFinderStoppingCriteria(smoothing)
        while True:
            # Run iteration, and block until loss is calculated.
            loss = self.learner.iteration(lr).asscalar()
            self.results.append((lr, loss))
            if stopping_criteria(loss):
                break
            lr = lr * lr_multiplier
        # Restore params (as finder changed them)
        self.learner.net.load_params("lr_finder.params", ctx=self.learner.ctx)
        self.learner.trainer.load_states("lr_finder.state")
        self.plot()

    def plot(self):
        lrs = [e[0] for e in self.results]
        losses = [e[1] for e in self.results]
        plt.figure(figsize=(6,8))
        plt.scatter(lrs, losses)
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        plt.xscale('log')
        plt.yscale('log')
        axes = plt.gca()
        axes.set_xlim([lrs[0], lrs[-1]])
        y_lower = min(losses) * 0.8
        y_upper = losses[0] * 4
        axes.set_ylim([y_lower, y_upper])
        plt.show()
```

Review comments on the code above:

Contributor (on `self.learner = learner`): Private fields better to start with underscore "_".
Contributor Author: Will leave for this tutorial since stylistic, but will do in future!

Contributor (@safrooze, Jun 22, 2018, on saving and restoring parameters): I find the code for saving and restoring parameters adding unnecessary complication. Given that the LR finding is a completely separate step before actual learning and its output plot must be manually analyzed, why not just create a completely separate network/trainer for LR Finder and once this process is over, create another network/trainer for confirming that correct LR has been selected.
Contributor Author: It's been suggested that this method can be used during training too, which would require save and load to get back to original state. Will add a comment to explain this. And also helps with reproducibility of results of learning rate finder (for different settings) since the same initialization values are used each time.

Contributor (on `save_params`): "save_params" is deprecated. Should we use save_parameters instead?
Contributor Author: Changed.

Contributor (on `while True:`): Try to rewrite to while not stopping_criteria(loss), as it looks cleaner.
Contributor Author: Would need duplication of two lines of code, so prefer as is.

Contributor (on `load_params`): load_params is deprecated. Should we use load_parameters instead?
Contributor Author: Changed.

Contributor (on `self.plot()`): Can you return results here rather than calling plot? Users might not be in a graphical environment.
Contributor Author: Removed plot, returned results. Called plot separately.


You can define the `LRFinderStoppingCriteria` as you wish, but empirical testing suggests using a smoothed average gives a more consistent stopping rule (see `smoothing`). We stop when the smoothed average of the loss exceeds twice the initial loss, assuming there have been a minimum number of iterations (see `min_iter`).


```python
class LRFinderStoppingCriteria():
    def __init__(self, smoothing=0.3, min_iter=20):
        """
        smoothing: applied to running mean which is used for thresholding (float)
        min_iter: minimum number of iterations before early stopping can occur (int)
        """
        self.smoothing = smoothing
        self.min_iter = min_iter
        self.first_loss = None
        self.running_mean = None
        self.counter = 0

    def __call__(self, loss):
        """
        loss: from single iteration (float)
        """
        self.counter += 1
        if self.first_loss is None:
            self.first_loss = loss
        if self.running_mean is None:
            self.running_mean = loss
        else:
            self.running_mean = ((1 - self.smoothing) * loss) + (self.smoothing * self.running_mean)
        return (self.running_mean > self.first_loss * 2) and (self.counter >= self.min_iter)
```

Contributor (on `self.smoothing`): Underscore as a prefix.
Contributor Author: Same as before.
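As a quick illustration (not part of the original tutorial), we can feed the criteria a synthetic loss sequence that stays flat and then blows up, and check that it only asks us to stop once the smoothed loss passes twice the initial loss and at least `min_iter` iterations have passed.

```python
# Hypothetical check with synthetic losses: flat at first, then diverging.
criteria = LRFinderStoppingCriteria(smoothing=0.3, min_iter=20)
for idx, loss in enumerate([2.3] * 25 + [2.5, 3.0, 4.0, 6.0, 10.0]):
    if criteria(loss):
        print("Stopping at iteration", idx)
        break
```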

## Usage

Using a Pre-activation ResNet-18 from the Gluon model zoo, we instantiate our Learner and fire up our Learning Rate Finder!


```python
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
net = mx.gluon.model_zoo.vision.resnet18_v2(classes=10)
learner = Learner(net=net, data_loader=data_loader, ctx=ctx)
lr_finder = LRFinder(learner)
lr_finder.find(lr_start=1e-6)
```
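If you aren't working in a graphical environment (a point raised in the review discussion on `plot` above), the raw learning rate and loss pairs collected by the finder can be inspected directly. This snippet is illustrative only and assumes `find` has already populated `lr_finder.results`.

```python
# Illustrative only: print every 10th (learning rate, loss) pair collected by the finder.
for lr, loss in lr_finder.results[::10]:
    print("lr: {:.2e}, loss: {:.3f}".format(lr, loss))
```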


![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/lr_finder/finder_plot.png) <!--notebook-skip-line-->


As discussed before, we should select a learning rate where the loss is falling (i.e. from 0.001 to 0.05) but before the loss starts to diverge (i.e. 0.1). We prefer higher learning rates where possible, so we select an initial learning rate of 0.05. Just as a test, we will run 500 iterations using this learning rate and evaluate the loss on the final batch. As we're working with a single batch of 128 samples, the variance of the loss estimates will be reasonably high, but it will give us a general idea. We save the initialized parameters for a later comparison with other learning rates.


```python
learner.net.save_params("net.params")

lr = 0.05

for iter_idx in range(500):
    learner.iteration(lr=lr)
    if ((iter_idx % 100) == 0):
        print("Iteration: {}, Loss: {:.5g}".format(iter_idx, learner.iteration_loss.asscalar()))
print("Final Loss: {:.5g}".format(learner.iteration_loss.asscalar()))
```

Contributor (on `save_params`): save_params is deprecated. Should we use save_parameters instead?
Contributor Author: Changed.

Iteration: 0, Loss: 2.785 <!--notebook-skip-line-->

Iteration: 100, Loss: 1.6653 <!--notebook-skip-line-->

Iteration: 200, Loss: 1.4891 <!--notebook-skip-line-->

Iteration: 300, Loss: 1.0846 <!--notebook-skip-line-->

Iteration: 400, Loss: 1.0633 <!--notebook-skip-line-->

Final Loss: 1.1812 <!--notebook-skip-line-->


We see a sizable drop in the loss from approx. 2.7 to 1.2.

Now that we have a baseline, let's see what happens when we train with a learning rate that's higher than advisable, at 0.5.


```python
net = mx.gluon.model_zoo.vision.resnet18_v2(classes=10)
learner = Learner(net=net, data_loader=data_loader, ctx=ctx)
learner.net.load_params("net.params", ctx=ctx)

lr = 0.5

for iter_idx in range(500):
    learner.iteration(lr=lr)
    if ((iter_idx % 100) == 0):
        print("Iteration: {}, Loss: {:.5g}".format(iter_idx, learner.iteration_loss.asscalar()))
print("Final Loss: {:.5g}".format(learner.iteration_loss.asscalar()))
```

Contributor (on `load_params`): load_params is deprecated. Should we use load_parameters instead?
Contributor Author: Changed.

Iteration: 0, Loss: 2.6469 <!--notebook-skip-line-->

Iteration: 100, Loss: 1.9666 <!--notebook-skip-line-->

Iteration: 200, Loss: 1.6919 <!--notebook-skip-line-->

Iteration: 300, Loss: 1.3643 <!--notebook-skip-line-->

Iteration: 400, Loss: 1.4743 <!--notebook-skip-line-->

Final Loss: 1.366 <!--notebook-skip-line-->


We still observe a fall in the loss but aren't able to reach as low as before.

And lastly, we see how the model trains with a more conservative learning rate of 0.005.


```python
net = mx.gluon.model_zoo.vision.resnet18_v2(classes=10)
learner = Learner(net=net, data_loader=data_loader, ctx=ctx)
learner.net.load_params("net.params", ctx=ctx)

lr = 0.005

for iter_idx in range(500):
    learner.iteration(lr=lr)
    if ((iter_idx % 100) == 0):
        print("Iteration: {}, Loss: {:.5g}".format(iter_idx, learner.iteration_loss.asscalar()))
print("Final Loss: {:.5g}".format(learner.iteration_loss.asscalar()))
```

Contributor (on `load_params`): load_params is deprecated. Should we use load_parameters instead?
Contributor Author: Changed.

Iteration: 0, Loss: 2.605 <!--notebook-skip-line-->

Iteration: 100, Loss: 1.8621 <!--notebook-skip-line-->

Iteration: 200, Loss: 1.6316 <!--notebook-skip-line-->

Iteration: 300, Loss: 1.6295 <!--notebook-skip-line-->

Iteration: 400, Loss: 1.4019 <!--notebook-skip-line-->

Final Loss: 1.2919 <!--notebook-skip-line-->


Although we get quite similar results to when we set the learning rate at 0.05 (because we're still in the region of falling loss on the Learning Rate Finder plot), we can still optimize our network faster using a slightly higher rate.

## Wrap Up

Give the Learning Rate Finder a try on your current projects, and experiment with different learning rate schedules too.
Contributor: I went back to look for the "different learning rate schedules" and they are not explicitly defined. Or maybe I missed that point, but I wasn't really able to find it.


<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
1 change: 1 addition & 0 deletions docs/tutorials/index.md
@@ -42,6 +42,7 @@ Select API:&nbsp;
* [Inference using an ONNX model](/tutorials/onnx/inference_on_onnx_model.html)
* [Fine-tuning an ONNX model on Gluon](/tutorials/onnx/fine_tuning_gluon.html)
* [Visualizing Decisions of Convolutional Neural Networks](/tutorials/vision/cnn_visualization.html)
* [Learning Rate Finder](/tutorials/gluon/learning_rate_finder.html)
* [Profiling MXNet Models](/tutorials/python/profiler.html)
* API Guides
* Core APIs
3 changes: 3 additions & 0 deletions tests/tutorials/test_tutorials.py
@@ -157,6 +157,9 @@ def test_gluon_hybrid():
def test_gluon_pretrained_models():
    assert _test_tutorial_nb('gluon/pretrained_models')

def test_gluon_learning_rate_finder():
    assert _test_tutorial_nb('gluon/learning_rate_finder')

def test_nlp_cnn():
    assert _test_tutorial_nb('nlp/cnn')
