In [1]:
import argparse

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from tqdm import tqdm, trange

import descent_directions
import visualization       

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_args():
    """Build the run configuration for the descent experiments.

    An explicit empty argument list is parsed, so inside a notebook the
    declared defaults are always used and the kernel's own command line
    is ignored.

    Returns:
        argparse.Namespace with ``epsilon`` (float, default 1e-4) and
        ``max_iter`` (int, default 256).
    """
    arg_parser = argparse.ArgumentParser()
    # (flag, type converter, default value)
    option_specs = (
        ("--epsilon", float, 1e-4),
        ("--max_iter", int, 256),
    )
    for flag, converter, fallback in option_specs:
        arg_parser.add_argument(flag, default=fallback, type=converter)
    return arg_parser.parse_args(args=[])

In [3]:
def descent(args, init_x, init_y, f, grad_f, hessian_f, descent_dir_fn):
    """Run an iterative 2-D descent from (init_x, init_y).

    At each step ``descent_dir_fn(x, y, f, grad_f, hessian_f)`` returns a step
    ``(d_x, d_y)``, which is added to the current point. Iteration stops
    after ``args.max_iter`` steps, when the objective changes by less than
    ``args.epsilon`` in absolute value, or when the objective becomes
    non-finite (divergence).

    Args:
        args: namespace providing ``epsilon`` (float) and ``max_iter`` (int).
        init_x, init_y: starting point.
        f: objective, called as ``f(x, y)``.
        grad_f, hessian_f: passed through unchanged to ``descent_dir_fn``.
        descent_dir_fn: callable returning the step ``(d_x, d_y)``.

    Returns:
        ``(x_traj, y_traj, val_traj, last)``: the trajectory arrays trimmed to
        the iterates actually computed (length ``last + 1``, starting point
        included), and ``last``, the index of the final iterate.
    """
    # One slot for the starting point plus one per possible step. The old
    # buffers held only max_iter slots and overflowed on the last step.
    n_slots = args.max_iter + 1
    x_traj = np.zeros(n_slots)
    y_traj = np.zeros(n_slots)
    val_traj = np.zeros(n_slots)

    x_traj[0] = init_x
    y_traj[0] = init_y
    val_traj[0] = f(x_traj[0], y_traj[0])

    last = 0
    for i in trange(args.max_iter):
        d_x, d_y = descent_dir_fn(x_traj[i], y_traj[i], f, grad_f, hessian_f)

        x_traj[i + 1] = x_traj[i] + d_x
        y_traj[i + 1] = y_traj[i] + d_y
        val_traj[i + 1] = f(x_traj[i + 1], y_traj[i + 1])
        last = i + 1

        # Stop on divergence, or once the objective has effectively stopped
        # changing between consecutive iterates.
        if not np.isfinite(val_traj[last]) or abs(val_traj[last] - val_traj[i]) < args.epsilon:
            break

    # Drop unfilled buffer slots so val_traj[-1] is the final objective value.
    return x_traj[:last + 1], y_traj[:last + 1], val_traj[:last + 1], last

In [None]:
# Plot style and run configuration (epsilon / max_iter come from get_args defaults).
sns.set(rc={'figure.figsize': (12.0, 8.0)})
args = get_args()

SEED = 0                        # fixed so the random starting point is reproducible on Restart & Run All
X_RANGE = [-2.5, 2.5]           # contour-plot window, shared by both contour plots below
Y_RANGE = [-2.5, 2.5]
GRID_NX, GRID_NY = 1000, 1000   # contour-plot grid resolution
CMAP = 'RdBu'


def f(x, y):
    """Objective f(x, y) = (1 - x)^2 + 2 (x^2 - y)^2."""
    return (1 - x)**2 + 2 * (x**2 - y)**2


def grad_f(x, y):
    """Analytic gradient of f as the array [df/dx, df/dy]."""
    return np.array([2 * (4 * x**3 - 4 * x * y + x - 1), 4 * y - 4 * x**2])


def hessian_f(x, y):
    """Analytic 2x2 Hessian of f at (x, y).

    Takes (x, y) to match f and grad_f.
    NOTE(review): assumes descent_directions calls hessian_f(x, y) - confirm.
    """
    # Both rows go inside one nested list. np.asarray's second positional
    # parameter is dtype, so passing the rows separately raised.
    return np.array([
        [24 * x**2 - 8 * y + 2, -8 * x],
        [-8 * x, 4],
    ])


rng = np.random.default_rng(SEED)
init_x, init_y = rng.standard_normal(2)
visualization.contour_plot(f, X_RANGE, Y_RANGE, GRID_NX, GRID_NY, CMAP)

methods_dict = {
    "1-norm": descent_directions.descent_1n,
    "2-norm": descent_directions.descent_2n,
    "inf-norm": descent_directions.descent_in,
    "fr": descent_directions.descent_fr,
    "pr": descent_directions.descent_pr
}

labels = []
x_trajs = []
y_trajs = []
val_trajs = []

for method, descent_fn in methods_dict.items():
    x_traj, y_traj, val_traj, last_it = descent(
        args,
        init_x, init_y,
        f, grad_f, hessian_f,
        descent_fn
    )

    # Keep the entry at last_it too; slicing to :last_it dropped it.
    labels.append(method)
    x_trajs.append(x_traj[:last_it + 1])
    y_trajs.append(y_traj[:last_it + 1])
    val_trajs.append(val_traj[:last_it + 1])

    # Index explicitly rather than with [-1], which could hit an unfilled
    # slot of a preallocated buffer.
    print(f"{method}:\t f(x, y) = {val_traj[last_it]}")

# NOTE(review): x_trajs / y_trajs are collected above but not passed here.
# Check contour_plot_with_tours' signature and pass the trajectories if it accepts them.
visualization.contour_plot_with_tours(f, X_RANGE, Y_RANGE, GRID_NX, GRID_NY, CMAP)
visualization.val_descent(val_trajs, labels)