In [5]:
from bokeh.io import output_notebook, show
from bokeh.plotting import figure, ColumnDataSource
from bokeh.layouts import column
from bokeh.models import Slider, CustomJS
import numpy as np

# Send Bokeh output inline to this notebook.
output_notebook()

# Fixed schedule settings: the LR reached at the end of warmup, the floor used
# during decay, and the number of epochs plotted on the x axis.
initial_lr, min_lr = 1e-3, 1e-6
num_epochs = 500
epochs = np.arange(num_epochs)

# Default positions of the three interactive sliders (start factor, warmup
# length in epochs, per-epoch decay factor).
lr_start_factor_init, milestone_init, gamma_init = 0.1, 50, 0.99

# Learning-rate schedule: linear warmup up to initial_lr, then exponential decay floored at min_lr
def compute_lr_schedule(lr_start_factor, milestone, gamma, *,
                        base_lr=None, lr_floor=None, n_epochs=None):
    """Return the per-epoch learning rate for a linear-warmup + exponential-decay schedule.

    For ``epoch < milestone`` the rate ramps linearly from
    ``base_lr * lr_start_factor`` towards ``base_lr``. From ``milestone`` onward
    it decays as ``base_lr * gamma ** (epoch - milestone)`` and is clamped
    below at ``lr_floor``. The warmup phase is not clamped.

    Args:
        lr_start_factor: Fraction of ``base_lr`` used at epoch 0.
        milestone: Epoch at which warmup ends and decay begins. A value of
            ``0`` (or less) skips warmup entirely.
        gamma: Per-epoch multiplicative decay factor applied after the milestone.
        base_lr: Learning rate reached at the milestone. Defaults to the
            module-level ``initial_lr``.
        lr_floor: Lower bound applied during the decay phase. Defaults to the
            module-level ``min_lr``.
        n_epochs: Length of the schedule. Defaults to the module-level
            ``num_epochs``.

    Returns:
        A float64 NumPy array of length ``n_epochs``.

    Note:
        The slider ``CustomJS`` callback in this notebook re-implements this
        formula in JavaScript; keep the two in sync.
    """
    # Resolve defaults at call time, so later edits to the module-level globals
    # are honoured exactly as the original global-reading implementation did.
    base_lr = initial_lr if base_lr is None else base_lr
    lr_floor = min_lr if lr_floor is None else lr_floor
    n_epochs = num_epochs if n_epochs is None else n_epochs

    epoch_idx = np.arange(n_epochs)
    lrs = np.empty(n_epochs, dtype=float)

    warmup = epoch_idx < milestone
    # Boolean masking (rather than np.where) means milestone == 0 never divides
    # by zero: the warmup selection is simply empty in that case.
    lrs[warmup] = base_lr * (
        lr_start_factor + (1 - lr_start_factor) * (epoch_idx[warmup] / milestone)
    )

    decay_epochs = epoch_idx[~warmup] - milestone
    lrs[~warmup] = np.maximum(base_lr * np.power(gamma, decay_epochs), lr_floor)
    return lrs

# Initial curve, computed from the default slider positions.
lrs = compute_lr_schedule(lr_start_factor_init, milestone_init, gamma_init)
source = ColumnDataSource(data=dict(epochs=epochs, lrs=lrs))

# Plot. The old `plot_width`/`plot_height` kwargs were removed in Bokeh 3;
# `width`/`height` are the supported names for the figure size.
p = figure(title="Learning Rate Schedule", x_axis_label='Epoch', y_axis_label='Learning Rate',
           height=400, width=700)
p.line('epochs', 'lrs', source=source)

# Sliders
lr_start_factor_slider = Slider(start=0.1, end=1.0, value=lr_start_factor_init, step=0.01, title="LR Start Factor")
# NOTE: with start=10 and step=4, the largest value the slider can reach is 298.
milestone_slider = Slider(start=10, end=300, value=milestone_init, step=4, title="Milestone")
gamma_slider = Slider(start=0.985, end=0.999, value=gamma_init, step=0.0001, title="Gamma")

# Browser-side recomputation of the curve whenever a slider moves. The JS loop
# duplicates compute_lr_schedule(), so keep the two formulas in sync. It
# overwrites the existing `lrs` column in place and then emits a change event
# so the plot redraws.
callback = CustomJS(args=dict(source=source,
                              lr_start_factor=lr_start_factor_slider,
                              milestone=milestone_slider,
                              gamma=gamma_slider,
                              initial_lr=initial_lr,
                              min_lr=min_lr,
                              num_epochs=num_epochs),
                    code="""
    var data = source.data;
    var epochs = data['epochs'];
    var lrs = data['lrs'];
    var lr_start_factor_val = lr_start_factor.value;
    var milestone_val = milestone.value;
    var gamma_val = gamma.value;
    var initial_lr_val = initial_lr;
    var min_lr_val = min_lr;
    var num_epochs_val = num_epochs;

    for (var i = 0; i < num_epochs_val; i++) {
        if (i < milestone_val) {
            var lr = initial_lr_val * (lr_start_factor_val + (1 - lr_start_factor_val) * (i / milestone_val));
        } else {
            var decay_epochs = i - milestone_val;
            var lr = initial_lr_val * Math.pow(gamma_val, decay_epochs);
            if (lr < min_lr_val) {
                lr = min_lr_val;
            }
        }
        lrs[i] = lr;
    }
    source.change.emit();
""")

# Every slider triggers the same recomputation.
for _slider in (lr_start_factor_slider, milestone_slider, gamma_slider):
    _slider.js_on_change('value', callback)

# Layout
layout = column(p, lr_start_factor_slider, milestone_slider, gamma_slider)
show(layout)
