In [None]:
import os

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
from set_optimizer import build_optimizer, OptimizerSetting

In [None]:

# Shared experiment configuration:
#   (init_x, init_y) -- starting point used by every optimizer run
#   n_iter           -- number of iterations each run performs
init_x, init_y = 2.0, -4.0
n_iter = 400

# Main Function: alternating gradient descent on x / ascent on y

In [None]:
def run(init_x, init_y, lr_x, lr_y, n_iterm, optim_name, momentum=0.5):
    """Alternating gradient descent (x) / ascent (y) on a saddle-shaped objective.

    The objective is f(x, y) = (1 + x**2) * (100 - y**2). Each iteration first
    updates x to decrease f. It then re-evaluates f at the new x and updates y
    to increase f: y's learning rate is negated, so the optimizer ascends in y.

    Args:
        init_x, init_y: Floating-point starting values for x and y.
        lr_x: Learning rate for x.
        lr_y: Learning-rate magnitude for y (its sign is flipped internally).
        n_iterm: Number of iterations to run. The name is kept as-is so
            existing callers keep working.
        optim_name: Optimizer name passed to ``build_optimizer``. Names that
            contain 'cgd' must look like 'cgd_<beta_update_rule>_<coeff>',
            e.g. 'cgd_FR_0.5'.
        momentum: Momentum value placed in ``OptimizerSetting``. The default
            0.5 matches the previously hard-coded value.

    Returns:
        Tuple ``(x_hist, y_hist, obj_list, norm_list)`` with one entry per
        iteration, recorded after both the x and the y update of that iteration.

    Raises:
        ValueError: If a 'cgd' name does not have three '_'-separated parts.
    """
    x = torch.tensor(init_x, requires_grad=True)
    y = torch.tensor(init_y, requires_grad=True)

    def objective():
        return (1 + x**2) * (100 - y**2)

    # Parse 'cgd_<rule>_<coeff>' names once; other optimizers take no beta settings.
    beta_update_rule = None
    beta_momentum_coeff = None
    if 'cgd' in optim_name:
        parts = optim_name.split('_')
        if len(parts) < 3:
            raise ValueError(
                f"expected 'cgd_<beta_update_rule>_<coeff>', got {optim_name!r}")
        optim_name, beta_update_rule = parts[0], parts[1]
        beta_momentum_coeff = float(parts[2])

    def make_optimizer(param, lr):
        # Build one optimizer over a single scalar parameter.
        return build_optimizer(
            OptimizerSetting(name=optim_name,
                             weight_decay=0,
                             lr=lr,
                             momentum=momentum,
                             beta_update_rule=beta_update_rule,
                             beta_momentum_coeff=beta_momentum_coeff,
                             model=[param]))

    optimizer_x = make_optimizer(x, lr_x)
    optimizer_y = make_optimizer(y, lr_y)
    # Negate y's learning rate after construction so the same descent rule
    # performs gradient *ascent* on y.
    optimizer_y.param_groups[0]['lr'] *= -1

    # The lambda always returns 1, so these schedulers keep the learning rates constant.
    lr_schedule_x = torch.optim.lr_scheduler.LambdaLR(optimizer_x, lr_lambda=lambda steps: 1)
    lr_schedule_y = torch.optim.lr_scheduler.LambdaLR(optimizer_y, lr_lambda=lambda steps: 1)

    def backward_fresh():
        # Clear gradients before every backward pass. Without this, .grad
        # accumulates across both half-steps and all previous iterations,
        # and each step would use a running sum of stale gradients.
        x.grad = None
        y.grad = None
        objective().backward()

    x_hist, y_hist, obj_list, norm_list = [], [], [], []

    for _ in range(n_iterm):
        # Descent step on x.
        backward_fresh()
        optimizer_x.step()
        lr_schedule_x.step()

        # Ascent step on y, using the gradient at the already-updated x.
        backward_fresh()
        optimizer_y.step()
        lr_schedule_y.step()

        np_x = x.detach().numpy().copy()
        np_y = y.detach().numpy().copy()
        with torch.no_grad():
            obj = objective()

        x_hist.append(np_x)
        y_hist.append(np_y)
        obj_list.append(obj.numpy())
        norm_list.append(np.sqrt(np_x ** 2 + np_y ** 2))

    # Final summary; computed here so it also works when n_iterm == 0.
    with torch.no_grad():
        final_obj = objective()
    final_norm = np.sqrt(x.item() ** 2 + y.item() ** 2)
    print(x, y, final_obj, final_norm)
    return x_hist, y_hist, obj_list, norm_list

# Plotter

## 2D

In [None]:
def plot_trajectory_2D(res_list: list, optim_name_list: list, lr_x, lr_y,
                       save_dir='../figs/toy_example'):
    """Plot objective, norm and (x, y) trajectory for several optimizer runs.

    Draws three panels side by side: objective vs. step, norm vs. step, and
    the (x, y) iterates. Saves the figure as a PDF, then shows and closes it.

    Args:
        res_list: One entry per run, each indexable as
            ``(x_hist, y_hist, obj_list, norm_list)`` as returned by ``run``.
        optim_name_list: Legend label for each entry of ``res_list``.
        lr_x, lr_y: Learning rates; used only to build the output file name.
        save_dir: Directory for the PDF. It is created if missing.
    """
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9. max(..., 1) avoids a zero-size LUT.
    cmap_label = plt.get_cmap('tab10', max(len(res_list), 1))

    fig, ax = plt.subplots(1, 3, figsize=[22, 5])

    for counter, (res, optim_name) in enumerate(zip(res_list, optim_name_list)):
        color = cmap_label(counter)
        label = f'{optim_name}'
        ax[0].plot(res[2], color=color, label=label)
        ax[1].plot(res[3], color=color, label=label)
        # Pass color= rather than c=. An RGBA tuple given as c= is value-mapped
        # when the trajectory has exactly 4 points, and warns otherwise.
        ax[2].scatter(res[0], res[1], s=1, alpha=1.0, color=color, label=label)

    # Titles and labels do not depend on the run, so set them once.
    ax[0].set_title("objective")
    ax[0].set_xlabel('steps')
    ax[1].set_title("norms")
    ax[1].set_xlabel('steps')
    ax[2].set_title("x vs y")
    ax[2].set_xlabel('x')
    ax[2].set_ylabel('y')
    ax[0].legend()

    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f'2D_{lr_x}_{lr_y}.pdf'))
    plt.show()
    # Free the figure; the driver cell creates many figures in a loop.
    plt.close(fig)

## 3D

In [None]:
def plot_trajectory_3D(res_list: list, optim_name_list: list, lr_x, lr_y,
                       save_dir='../figs/toy_example'):
    """Plot optimizer trajectories over the surface f(x, y) = (1+x^2)(100-y^2).

    Draws the objective surface, marks the saddle point (0, 0, 100), and marks
    the first recorded point of the first run. It then plots each run's 3D path
    and end point. Saves the figure as a PDF, then shows and closes it.

    Args:
        res_list: Non-empty list with one entry per run, each indexable as
            ``(x_hist, y_hist, obj_list, norm_list)`` as returned by ``run``.
        optim_name_list: Legend label for each entry of ``res_list``.
        lr_x, lr_y: Learning rates; used only to build the output file name.
        save_dir: Directory for the PDF. It is created if missing.

    Raises:
        ValueError: If ``res_list`` is empty.
    """
    if not res_list:
        raise ValueError("res_list must contain at least one run result")

    # plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
    cmap_label = plt.get_cmap('tab10', len(res_list))

    fig = plt.figure(figsize=[14, 10])
    ax = fig.add_subplot(projection='3d')

    grid_x = np.arange(-2.5, 2.5, 0.25)
    grid_y = np.arange(-12, 12, 0.25)
    X, Y = np.meshgrid(grid_x, grid_y)
    Z = (1 + X**2) * (100 - Y**2)

    ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.7)
    ax.scatter(0, 0, 100, marker='*', c='black', s=50, label='saddle point')
    # NOTE: run() records iterates only after each update, so this is the
    # first recorded point of the first run, not exactly (init_x, init_y).
    first = res_list[0]
    ax.scatter(first[0][0], first[1][0], first[2][0], marker='+', c='black', s=50,
               label='start point')

    for counter, (res, optim_name) in enumerate(zip(res_list, optim_name_list)):
        color = cmap_label(counter)
        ax.plot(res[0], res[1], res[2], alpha=1, color=color, label=f'{optim_name}')
        # color= avoids the c= RGBA value-mapping ambiguity for single points.
        ax.scatter(res[0][-1], res[1][-1], res[2][-1], marker='*', s=50, color=color,
                   label=f'end point of {optim_name}')

    ax.legend(loc='upper right', fontsize=12)
    ax.view_init(elev=60, azim=240)
    ax.set_title(r'$f(x,y) = (1+x^2) \cdot (100-y^2)$')
    ax.set_xlim(-3, 3)
    ax.set_ylim(-15, 15)
    ax.set_zlim(-300, 700)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f'3D_{lr_x}_{lr_y}.pdf'))
    plt.show()
    # Free the figure; the driver cell creates many figures in a loop.
    plt.close(fig)


# Run the Experiments and Plot the Results

In [None]:
# Optimizers to compare. Names of the form 'cgd_<rule>_<coeff>' are parsed
# inside `run` into a beta update rule and a beta momentum coefficient.
optim_names = ['vanilla_sgd', 'momentum_sgd', 'cgd_FR_0.5', 'cgd_FR_1.0', 'cgd_PRP_0.5', 'cgd_PRP_1.0']

# Learning-rate grid: every (lr_x, lr_y) pair is run for every optimizer.
# init_x, init_y and n_iter come from the configuration cell above.
lr_x_list = [0.000005, 0.00001, 0.000025, 0.00005]
lr_y_list = [0.000005, 0.00001, 0.000025, 0.00005]

for lr_x in lr_x_list:
    for lr_y in lr_y_list:
        res_list = []
        for optim_name in optim_names:
            print(optim_name)
            res_list.append(run(init_x, init_y, lr_x, lr_y, n_iter, optim_name))

        plot_trajectory_2D(res_list, optim_names, lr_x, lr_y)
        plot_trajectory_3D(res_list, optim_names, lr_x, lr_y)