In [None]:
!pip install optax

Collecting optax
  Downloading optax-0.1.2-py3-none-any.whl (140 kB)
[?25l[K     |██▎                             | 10 kB 21.2 MB/s eta 0:00:01[K     |████▋                           | 20 kB 19.3 MB/s eta 0:00:01[K     |███████                         | 30 kB 18.4 MB/s eta 0:00:01[K     |█████████▎                      | 40 kB 17.3 MB/s eta 0:00:01[K     |███████████▋                    | 51 kB 15.8 MB/s eta 0:00:01[K     |██████████████                  | 61 kB 17.7 MB/s eta 0:00:01[K     |████████████████▎               | 71 kB 19.0 MB/s eta 0:00:01[K     |██████████████████▋             | 81 kB 20.0 MB/s eta 0:00:01[K     |█████████████████████           | 92 kB 21.3 MB/s eta 0:00:01[K     |███████████████████████▎        | 102 kB 21.9 MB/s eta 0:00:01[K     |█████████████████████████▋      | 112 kB 21.9 MB/s eta 0:00:01[K     |████████████████████████████    | 122 kB 21.9 MB/s eta 0:00:01[K     |██████████████████████████████▎ | 133 kB 21.9 MB/s eta 0:00:

In [None]:
import jax
import jax.numpy as jnp

In [None]:
# Enable 64-bit floating point in JAX.
# The `jax.config` submodule (`from jax.config import config`) has been removed
# from recent JAX releases, so the config object is reached through the
# top-level `jax` package instead. `config` is kept as an alias for any code
# that still refers to that name.
config = jax.config
config.update("jax_enable_x64", True)

In [None]:
@jax.jit
def loss_fn(param, train, labels):
  """Negated mean of `labels * log_sigmoid(train @ W + b)`.

  Args:
    param: pair `(W, b)` holding the weights and the bias.
    train: input features; multiplied on the right by `W`.
    labels: targets that weight the log-sigmoid of the predictions.

  Returns:
    The scalar loss (mean taken over every element of the product).

  NOTE(review): only the `labels * log_sigmoid(preds)` term is present; a full
  binary cross-entropy would also include `(1 - labels) * log_sigmoid(-preds)`.
  Also, labels shaped `(n, 1)` broadcast against predictions shaped `(n,)` into
  an `(n, n)` array before the mean. Confirm both are intended.
  """
  weights, bias = param
  logits = train @ weights + bias
  weighted_log_probs = labels * jax.nn.log_sigmoid(logits)
  return -jnp.mean(weighted_log_probs)

In [None]:
def hessian(f):
    """Return a jitted function computing the Hessian of `f`.

    The Hessian is built as the forward-mode Jacobian of the reverse-mode
    Jacobian of `f` (differentiating with respect to the first argument, the
    default for both transforms), and the result is wrapped in `jax.jit`.
    """
    reverse_jacobian = jax.jacrev(f)
    forward_over_reverse = jax.jacfwd(reverse_jacobian)
    return jax.jit(forward_over_reverse)

In [None]:
from jax.scipy.sparse.linalg import cg

@jax.jit
def solve_newton(hessian, grad):
  """Solve `hessian . step = grad` for `step` with conjugate gradient.

  Args:
    hessian: 2x2 nested structure of blocks, indexed as `hessian[i][j]`, where
      index 0 refers to the first parameter and index 1 to the second.
    grad: pair `(grad_0, grad_1)` used as the right-hand side.

  Returns:
    The first element of the `cg` result, i.e. the solution, structured like
    `grad`.
  """
  h_ww, h_wb = hessian[0][0], hessian[0][1]
  h_bw, h_bb = hessian[1][0], hessian[1][1]

  def block_matvec(vec):
    # Block-wise product of the Hessian with the pair `vec`.
    v_w, v_b = vec[0], vec[1]
    upper = jnp.dot(h_ww, v_w) + h_wb * v_b
    lower = jnp.dot(h_bw, v_w) + h_bb * v_b
    return upper, lower

  solution = cg(block_matvec, grad)
  return solution[0]

In [None]:
from optax import adam
import optax

In [None]:
import tqdm

def run_xp(dtype, method, num_steps, verbose=0, step_size=1., use_2d_data=False):
  """Optimise a tiny linear model on a fixed toy dataset and record the trajectory.

  Args:
    dtype: dtype used for the data and for the parameters.
    method: one of 'grad', 'momentum', 'adam' or 'newton'.
    num_steps: number of optimiser steps to run.
    verbose: if truthy, print one line per step after the run.
    step_size: multiplier applied to every update (passed to `adam` as its
      learning rate when `method == 'adam'`).
    use_2d_data: if true, use the four-point two-feature dataset instead of
      the default two-point one-feature dataset.

  Returns:
    A tuple `(steps, losses, grad_norms, params, params_norms)`. `steps` is
    `jnp.arange(num_steps)`. `losses` and `grad_norms` hold one entry per step,
    evaluated before that step's update. `params` and `params_norms` hold
    `num_steps + 1` entries because they also include the initial parameters.

  Raises:
    ValueError: if `method` is not one of the supported names.
  """
  known_methods = ('grad', 'momentum', 'adam', 'newton')
  if method not in known_methods:
    raise ValueError(f'Unknown method {method!r}; expected one of {known_methods}.')

  if use_2d_data:
    X = jnp.array([[-1.,-1.], [-1., 1.], [1., -1.], [1., 1.]], dtype=dtype)
    Y = jnp.array([0, 0, 1, 1], dtype=dtype)
    W, b = jnp.array([-1., -3.], dtype=dtype), jnp.array(2., dtype=dtype)
  else:
    X = jnp.array([[-1.], [1.]], dtype=dtype)
    Y = jnp.array([[0], [1]], dtype=dtype)
    W, b = jnp.array([-3.], dtype=dtype), jnp.array(2., dtype=dtype)

  losses = []
  grad_norms = []
  params = [(W, b)]
  params_norms = [(jnp.sum(W**2) + b**2) ** 0.5]
  value_and_grad = jax.value_and_grad(loss_fn)
  steps = jnp.arange(num_steps)
  # Running average used by 'momentum'; starts at zero.
  step_W, step_b = jnp.array(0.), jnp.array(0.)

  # Build only the machinery the selected method needs.
  if method == 'adam':
    optimizer = adam(step_size)
    opt_state = optimizer.init((W, b))
  elif method == 'newton':
    hessian_fun = hessian(loss_fn)

  for _ in tqdm.trange(num_steps):
    loss, (grad_W, grad_b) = value_and_grad((W, b), X, Y)
    grad_norm = (jnp.sum(grad_W**2) + grad_b**2)**0.5

    if method == 'adam':
      # optax computes and applies the update itself; there is no manual step.
      updates, opt_state = optimizer.update((grad_W, grad_b), opt_state, (W, b))
      W, b = optax.apply_updates((W, b), updates)
    else:
      if method == 'grad':
        step_W, step_b = grad_W, grad_b
      elif method == 'momentum':
        step_W = 0.9 * step_W + 0.1 * grad_W
        step_b = 0.9 * step_b + 0.1 * grad_b
      else:  # 'newton'
        # The Hessian is only needed (and therefore only computed) here.
        H = hessian_fun((W, b), X, Y)
        step_W, step_b = solve_newton(H, (grad_W, grad_b))
      # Replace NaNs in the step so a failed solve does not poison the parameters.
      step_W, step_b = jnp.nan_to_num(step_W), jnp.nan_to_num(step_b)
      W = W - step_size * step_W
      b = b - step_size * step_b

    params_norm = (jnp.sum(W**2) + b**2) ** 0.5
    losses.append(loss)
    params_norms.append(params_norm)
    grad_norms.append(grad_norm)
    params.append((W, b))

  print('')
  if verbose:
    print('Loss'.ljust(25), 'Gradient Norm'.ljust(25), '(         W        ,           b       )'.ljust(20))
    # zip stops at the shortest list, so each loss is paired with the
    # parameters it was evaluated at (the final parameters are not listed here).
    for loss, gradnorm, (W, b) in zip(losses, grad_norms, params):
      print(f'{loss}'.ljust(25), f'{gradnorm}'.ljust(25), f'{tuple(float(w) for w in W) + (float(b),)}'.ljust(20))
  if losses:
    # Summary line: last recorded loss / gradient norm (evaluated before the
    # final update) together with the final parameters.
    print(f'{losses[-1]}'.ljust(25), f'{grad_norms[-1]}'.ljust(25), f'{tuple(float(w) for w in params[-1][0]) + (float(params[-1][1]),)}'.ljust(20))
  return steps, losses, grad_norms, params, params_norms

In [None]:
# Experiment configuration: number of optimiser steps and float precision.
num_steps = 50
nb_bits = 64
# Only 32- and 64-bit floats are supported; any other width fails the assertion.
assert nb_bits in (32, 64)
dtype = jnp.float32 if nb_bits == 32 else jnp.float64

In [None]:
# Newton's method with float64 data and parameters.
steps, losses64, grad_norms64, params64, params_norms64 = run_xp(jnp.float64, 'newton', num_steps=num_steps, verbose=0)

100%|██████████| 50/50 [00:00<00:00, 59.33it/s]


2.2113448852757038e-23    3.0062780963628e-23       (-1.5953748631338067, 53.4152540064954)





In [None]:
# Newton's method with float32 data and parameters.
steps, losses32, grad_norms32, params32, params_norms32 = run_xp(jnp.float32, 'newton', num_steps=num_steps, verbose=0)

100%|██████████| 50/50 [00:00<00:00, 59.06it/s]


4.328328494229705e-12     5.8842748382192944e-12    (-1.5953762531280518, 26.415239334106445)





In [None]:
# Plain gradient descent.
# NOTE: despite the `32` in the result names, this run uses the configured
# `dtype` (selected from `nb_bits`), which is not necessarily float32.
steps, losses32_g, grad_norms32_g, params32_g, params_norms32_g = run_xp(dtype, 'grad', num_steps=num_steps, verbose=0)

100%|██████████| 50/50 [00:00<00:00, 155.54it/s]


0.012801957035884002      0.01583625888567597       (-1.0320903170025462, 4.1229148626892105)





In [None]:
# Adam.
# NOTE: despite the `32` in the result names, this run uses the configured
# `dtype` (selected from `nb_bits`), which is not necessarily float32.
steps, losses32_a, grad_norms32_a, params32_a, params_norms32_a = run_xp(dtype, 'adam', num_steps=num_steps, verbose=0)

100%|██████████| 50/50 [00:02<00:00, 17.64it/s]


0.00025723878355591635    0.00036350932064699873    (4.1214506808844, 11.027601687533423)





In [None]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Plot the parameter norm of every run against the optimiser step (primary
# y-axis) and, optionally, the loss on a log-scaled secondary y-axis.
fig = make_subplots(specs=[[{"secondary_y": True}]])

g_blue = "#4285F4"
g_green = "#38A854"
g_yellow = "#FBBC05"
g_red = "#EA4335"

plot_loss = False
if plot_loss:
  # (history, legend label, colour) for each loss curve, in legend order.
  loss_traces = [
      (losses64, f"Loss (Newton {64} bits)", g_green),
      (losses32, f"Loss (Newton {32} bits)", g_yellow),
      (losses32_g, f"Loss (GD {nb_bits} bits)", g_blue),
      (losses32_a, f"Loss (Adam {nb_bits} bits)", g_red),
  ]
  for loss_history, trace_name, trace_color in loss_traces:
    fig.add_trace(
        go.Scatter(x=list(range(len(loss_history))),
                   y=[float(v) for v in loss_history],
                   name=trace_name,
                   line=dict(color=trace_color, width=4, dash='dot')),
        secondary_y=True,
    )

msg = 'Weight Norm'

# (history, legend label, colour) for each parameter-norm curve, in legend order.
norm_traces = [
    (params_norms32_g, f"Gradient Descent float{nb_bits}", g_blue),
    (params_norms32_a, f"Adam float{nb_bits}", g_red),
    (params_norms64, f"Newton's method float{64}", g_green),
    (params_norms32, f"Newton's method float{32}", g_yellow),
]
for norm_history, trace_name, trace_color in norm_traces:
  # Each norm history holds the initial norm plus one entry per step, so it is
  # one element longer than `steps`; build the x-axis from the history itself
  # so that x and y always have the same length.
  fig.add_trace(
      go.Scatter(x=list(range(len(norm_history))),
                 y=[float(v) for v in norm_history],
                 name=trace_name,
                 line=dict(color=trace_color, width=4)),
      secondary_y=False,
  )

# Figure size, fonts and legend placement.
fig.update_layout(title_text="", autosize=True,
                  width=1200, height=500,
                  font=dict(size=24),
                  legend=dict(
                    x=0.05,
                    y=0.9,
                    traceorder="reversed",
                    title_font_family="Computer Modern",
                    font=dict(
                        family="Computer Modern",
                        size=24,
                        color="black"
                    ),
                    bgcolor="White",
                    bordercolor="White",
                    borderwidth=2)
                  )

# Set x-axis title
fig.update_xaxes(title_text="Optimizer Step t", title_font_family='Computer Modern', title_font_size=32)

# Set y-axes titles
if plot_loss:
  fig.update_yaxes(title_text="<b>Loss</b> (logscale)", secondary_y=True, type='log')
fig.update_yaxes(title=dict(text="Parameters Norm", font_family='Computer Modern', font_size=36), secondary_y=False)

fig.show()