<a target="_blank" href="https://colab.research.google.com/github/PacktPublishing/Deep-Learning-Model-Visualization/blob/main/Chapter03/DLMV_Chapter03_04_GradientDescent_bqplot.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Chapter 03 - 04 - Gradient Descent with `bqplot`

## Gradient Descent

In [1]:
import numpy as np
def gradient_descent(start, grad_fn, lr, max_iters, tol):
  """Run plain gradient descent and record every visited point.

  Args:
    start: initial value of x.
    grad_fn: callable returning the gradient at a given x.
    lr: learning rate that scales the gradient into an update.
    max_iters: upper bound on the number of updates performed.
    tol: stop as soon as the absolute size of an update is <= tol.

  Returns:
    np.ndarray holding the start point followed by one entry per update.
  """
  current = start
  trajectory = [current]
  n_updates = 0
  while n_updates < max_iters:
    delta = lr*grad_fn(current)
    current = current - delta
    trajectory.append(current)
    n_updates += 1
    # the update that meets the tolerance is still recorded before stopping
    if np.abs(delta) <= tol:
      break
  return np.array(trajectory)

## Functions and Data

In [2]:
# function/data

def f(x):
  """Objective function: x**2 - 10*sin(x)."""
  quadratic_term = x**2
  sine_term = 10*np.sin(x)
  return quadratic_term - sine_term

def dfdx(x):
  """Derivative of the objective: 2*x - 10*cos(x)."""
  quadratic_slope = 2*x
  sine_slope = 10*np.cos(x)
  return quadratic_slope - sine_slope

# sample grid for the objective curve: 100 evenly spaced points on [-2*pi, 2*pi]
x = np.linspace(start=-2*np.pi, stop=2*np.pi, num=100)

## Exploration

In [3]:
# The only change is in here
import bqplot.pyplot as plt
import bqplot as bq
# Turn on Colab's custom widget manager so that custom (third-party) widgets
# can be displayed in this notebook.
# NOTE: the name `output` imported here is rebound to a `widgets.Output()`
# instance further down in this notebook, so it only refers to the Colab
# module within this cell.
from google.colab import output
output.enable_custom_widget_manager()

In [4]:
# print the installed bqplot version
print(bq.__version__)

0.12.43


In [5]:
def plot_history(step, x, f, history, lr, tol):
  """Draw the objective curve and the gradient-descent path up to `step`.

  Args:
    step: index of the last iteration to show; points history[0..step]
      are drawn.
    x: x-values at which the objective curve is sampled.
    f: objective function; called on `x` and on the visited points.
    history: visited points in order; history[0] is the starting point.
    lr: learning rate (only shown in the title).
    tol: tolerance (only shown in the title).
  """
  # Step 1: create a figure and plot the objective function
  fig = plt.figure(
      layout=dict(width="500px", height="500px"))
  title = f'learning rate: {lr:.3f}, starting: {history[0]:.2f}, tolerance: {tol:.3f}'
  fig.title = title
  plt.plot(x, f(x))
  # Step 2: plot points and connection lines
  # visited points up to and including this step; kept under their own names
  # so the curve grid `x` is not overwritten
  visited_x = history[:step+1]
  visited_y = f(visited_x)
  # tooltip for the points
  tooltip = bq.Tooltip(fields=["x", "y"],
                        formats=[".2f", ".2f"])
  # plot the points
  plt.scatter(visited_x, visited_y, stroke='black', tooltip=tooltip)
  # plot the connections
  plt.plot(visited_x, visited_y, 'r--')
  # Step 3: plot the annotations
  # the label sits at a fixed position (-2, 30) and a dashed line links it
  # to the most recent point
  msg = f'iteration {step}, result {visited_y[-1]:.2f}'
  plt.label(text=[msg], x=[-2], y=[30])
  plt.plot([-2, visited_x[-1]], [30, visited_y[-1]],'m--')
  # show
  display(fig)

In [6]:
import ipywidgets as widgets
from IPython.display import display
from IPython.display import HTML
# Inject the Font Awesome 4.7 stylesheet into the notebook page.
# NOTE(review): presumably needed so icon-based widget controls (e.g. the
# Play widget's buttons) render in Colab -- confirm.
display(HTML('''<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"> '''))

# --- hyperparameter controls ---
# starting point, limited to the interval [-2*pi, 2*pi]
startSlider = widgets.FloatSlider(
    description='start',
    min=-2*np.pi, max=2*np.pi, step=0.1,
    value=-2*np.pi,
)
# learning rate
lrSlider = widgets.FloatSlider(
    description='learning rate',
    min=0.001, max=2.0, step=0.001,
    value=0.05,
    readout_format='.3f',
)
# maximum number of gradient-descent updates
maxIterSlider = widgets.IntSlider(
    description='max iters',
    min=1, max=200, step=1,
    value=100,
)
# stopping tolerance on the update size
tolSlider = widgets.FloatSlider(
    description='tolerance',
    min=0.001, max=1.0, step=0.001,
    value=0.001,
    readout_format='.4f',
)

# --- action button and result area ---
btn = widgets.Button(description='Train')
output = widgets.Output()

# --- layout: controls stacked on the left, result area on the right ---
hyperparameters = widgets.VBox(
    [startSlider, lrSlider, maxIterSlider, tolSlider, btn])
display(widgets.HBox([hyperparameters, output]))

# clear output
def clear_output(evt):
  """Observer callback that clears the module-level `output` widget.

  Args:
    evt: the change notification passed by `observe`; it is not used.
  """
  output.clear_output()

# clear the displayed result whenever any hyperparameter slider value changes
for _slider in (startSlider, lrSlider, maxIterSlider, tolSlider):
  _slider.observe(clear_output, names=['value'], type='change')

# playback
def playback(history, output, lr, tol):
  """Show a step-by-step replay of a gradient-descent run inside `output`.

  Args:
    history: visited points; there is one playback position per entry.
    output: Output widget that receives the controls and the plot.
    lr: learning rate, forwarded to plot_history.
    tol: tolerance, forwarded to plot_history.

  The objective `f` and its sample grid `x` are read from module scope.
  """
  last_step = len(history) - 1

  # Step 1: The playback controls -- a Play widget linked to an iteration slider
  player = widgets.Play(
      description="Press play",
      value=0, min=0, max=last_step, step=1,
      interval=1000
  )
  step_slider = widgets.IntSlider(
      description='iteration',
      value=0, min=0, max=last_step, step=1
  )
  widgets.jslink((player, 'value'), (step_slider, 'value'))

  # Step 2: The visualization output
  output.clear_output()
  # only `step` is a live control; everything else is passed through unchanged
  plot_args = {
      "step": step_slider,
      "f": widgets.fixed(f),
      "x": widgets.fixed(x),
      "history": widgets.fixed(history),
      "lr": widgets.fixed(lr),
      "tol": widgets.fixed(tol)
  }
  plot_area = widgets.interactive_output(plot_history, plot_args)

  # Step 3: The GUI
  controls = widgets.HBox([step_slider, player])
  with output:
    display(widgets.VBox([controls, plot_area]))

# event handler
def on_click(btn):
  """Button callback: run gradient descent with the slider values and replay it.

  Args:
    btn: the button passed by the click callback; it is not used.
  """
  start = startSlider.value
  lr = lrSlider.value
  max_iters = maxIterSlider.value
  tol = tolSlider.value

  history = gradient_descent(start, dfdx, lr, max_iters, tol)
  playback(history, output, lr, tol)

# handle events: register on_click as the click callback of the Train button
btn.on_click(on_click);

HBox(children=(VBox(children=(FloatSlider(value=-6.283185307179586, description='start', max=6.283185307179586…