import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.callbacks import Callback


class LRFinder(Callback):
    """`Callback` that exponentially adjusts the learning rate after each training batch, between `start_lr` and
    `end_lr`, for a maximum number of batches: `max_steps`. The loss and learning rate are recorded at each step so
    that a good learning rate can be found visually via the `plot` method, as described in
    https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html.
    """

    def __init__(self, start_lr: float = 1e-7, end_lr: float = 10, max_steps: int = 100, smoothing: float = 0.9):
        super(LRFinder, self).__init__()
        self.start_lr, self.end_lr = start_lr, end_lr
        self.max_steps = max_steps
        self.smoothing = smoothing
        self.step, self.best_loss, self.avg_loss, self.lr = 0, 0, 0, 0
        self.lrs, self.losses = [], []

    def on_train_begin(self, logs=None):
        # Reset state so the same finder instance can be reused across `fit` calls.
        self.step, self.best_loss, self.avg_loss, self.lr = 0, 0, 0, 0
        self.lrs, self.losses = [], []

    def on_train_batch_begin(self, batch, logs=None):
        self.lr = self.exp_annealing(self.step)
        tf.keras.backend.set_value(self.model.optimizer.lr, self.lr)

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        step = self.step
        if loss is not None:
            # Exponential moving average of the loss, with bias correction for
            # the early steps (the same correction Adam applies to its moments).
            self.avg_loss = self.smoothing * self.avg_loss + (1 - self.smoothing) * loss
            smooth_loss = self.avg_loss / (1 - self.smoothing ** (self.step + 1))
            self.losses.append(smooth_loss)
            self.lrs.append(self.lr)

            # Track the best *smoothed* loss, so the divergence test below
            # compares like with like.
            if step == 0 or smooth_loss < self.best_loss:
                self.best_loss = smooth_loss

            # Stop early once the smoothed loss clearly diverges.
            if smooth_loss > 4 * self.best_loss or tf.math.is_nan(smooth_loss):
                self.model.stop_training = True

        if step == self.max_steps:
            self.model.stop_training = True

        self.step += 1

    def exp_annealing(self, step):
        # Exponential interpolation: lr(step) = start_lr * (end_lr / start_lr) ** (step / max_steps),
        # so the learning rate sweeps from start_lr to end_lr on a log scale.
        return self.start_lr * (self.end_lr / self.start_lr) ** (step * 1. / self.max_steps)

    def plot(self):
        fig, ax = plt.subplots(1, 1)
        ax.set_ylabel('Loss')
        ax.set_xlabel('Learning Rate (log scale)')
        ax.set_xscale('log')
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
        ax.plot(self.lrs, self.losses)
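

if __name__ == '__main__':
    # Usage sketch (not part of the original callback): run the finder for a
    # single epoch, then plot smoothed loss against learning rate. The model
    # and synthetic data below are illustrative placeholders only; substitute
    # your own compiled model and training set.
    import numpy as np

    x = np.random.rand(512, 8).astype('float32')
    y = np.random.rand(512, 1).astype('float32')

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mse')

    lr_finder = LRFinder(start_lr=1e-7, end_lr=10, max_steps=100)
    # One epoch with more than max_steps batches suffices: the callback stops
    # training after max_steps batches or as soon as the smoothed loss diverges.
    model.fit(x, y, batch_size=4, epochs=1, callbacks=[lr_finder])

    lr_finder.plot()
    plt.show()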