-
Notifications
You must be signed in to change notification settings - Fork 13
/
keras_callback.py
69 lines (54 loc) · 2.43 KB
/
keras_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from keras.callbacks import Callback
import keras.backend as K
import numpy as np
import matplotlib.pyplot as plt
class LRFinder(Callback):
    """Learning-rate range test (Smith, 2017) as a Keras callback.

    Sweeps the learning rate geometrically from ``min_lr`` to ``max_lr``
    over the course of one training run, recording a momentum-smoothed
    loss for each tried rate, and plots loss vs. learning rate at the end
    so a good training LR can be picked from the curve.
    """

    def __init__(self, min_lr, max_lr, mom=0.9, stop_multiplier=None,
                 reload_weights=True, batches_lr_update=5,
                 weights_path='tmp.hdf5'):
        """Configure the sweep.

        Args:
            min_lr: Smallest learning rate to try.
            max_lr: Largest learning rate to try.
            mom: Momentum used to exponentially smooth the recorded loss.
            stop_multiplier: Abort training once the smoothed loss exceeds
                ``best_loss * stop_multiplier``. If ``None``, it is derived
                from ``mom`` (4 when mom=0.9, 10 when mom=0).
            reload_weights: If True, snapshot the model's weights at train
                start and restore them before each new LR (and at the end),
                so every rate is evaluated from the same starting point.
            batches_lr_update: Number of batches each learning rate is
                evaluated for before moving to the next one.
            weights_path: Where the temporary weight snapshot is stored.
        """
        super().__init__()
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.mom = mom
        self.reload_weights = reload_weights
        self.batches_lr_update = batches_lr_update
        self.weights_path = weights_path
        if stop_multiplier is None:
            # Linear in mom: 4 if mom=0.9, 10 if mom=0. Higher momentum
            # smooths the loss more, so a tighter stop threshold suffices.
            self.stop_multiplier = -20 * self.mom / 3 + 10
        else:
            self.stop_multiplier = stop_multiplier

    def on_train_begin(self, logs=None):
        """Precompute the LR schedule and snapshot the initial weights."""
        p = self.params
        try:
            # Older Keras exposes 'samples' and 'batch_size' in params.
            n_iterations = p['epochs'] * p['samples'] // p['batch_size']
        except (KeyError, TypeError):
            # Newer Keras exposes 'steps' (batches per epoch) instead.
            n_iterations = p['steps'] * p['epochs']
        # One LR per group of `batches_lr_update` batches, spaced
        # geometrically so the sweep is uniform on a log axis.
        self.learning_rates = np.geomspace(
            self.min_lr, self.max_lr,
            num=n_iterations // self.batches_lr_update + 1)
        self.losses = []
        self.iteration = 0
        self.best_loss = 0
        if self.reload_weights:
            self.model.save_weights(self.weights_path)

    def on_batch_end(self, batch, logs=None):
        """Record smoothed loss; advance the LR every few batches."""
        logs = logs or {}
        loss = logs.get('loss')
        if self.iteration != 0:
            # Exponential smoothing with the previous recorded loss.
            loss = self.losses[-1] * self.mom + loss * (1 - self.mom)
        if self.iteration == 0 or loss < self.best_loss:
            self.best_loss = loss
        if self.iteration % self.batches_lr_update == 0:
            # Start a fresh LR trial: optionally reset the weights so each
            # rate is judged from the same initial state.
            if self.reload_weights:
                self.model.load_weights(self.weights_path)
            lr = self.learning_rates[self.iteration // self.batches_lr_update]
            K.set_value(self.model.optimizer.lr, lr)
        self.losses.append(loss)
        # Stop early once the loss has clearly diverged.
        if loss > self.best_loss * self.stop_multiplier:
            self.model.stop_training = True
        self.iteration += 1

    def on_train_end(self, logs=None):
        """Restore the original weights and plot loss vs. learning rate."""
        if self.reload_weights:
            self.model.load_weights(self.weights_path)
        plt.figure(figsize=(12, 6))
        # Truncate the schedule in case training stopped early.
        plt.plot(self.learning_rates[:len(self.losses)], self.losses)
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        plt.xscale('log')
        plt.show()