In [1]:
import matplotlib.patches as mpatches
import os
import autograd.numpy as np
import matplotlib.pylab as pylab
import numpy as np
import matplotlib.pyplot as plt
import autograd.numpy as np

from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from autograd import elementwise_grad, value_and_grad
from mpl_toolkits.mplot3d import Axes3D
from problems import func1, func2, func3, func4

import warnings

# Notebook-wide configuration: silence every warning and set plotting defaults.
warnings.filterwarnings('ignore')
plt.rcParams.update({'font.size': 14})
# Colour list of matplotlib's current property cycle, kept for later reuse.
def_colors=(plt.rcParams['axes.prop_cycle'].by_key()['color'])
import seaborn as sns
# sns.set_theme()
# White figure background.
plt.rcParams['figure.facecolor'] = 'white'

In [2]:
# Set matplotlib's 'mathtext.default' option to 'regular' for all later figures.
params = {'mathtext.default': 'regular' } 
plt.rcParams.update(params)

In [3]:
def simgd(problem, x0, y0, iteration, lr, k=0):
    """Simultaneous gradient descent-ascent.

    Each step evaluates ``problem.grad`` once at the current point and then
    moves ``x`` against its gradient and ``y`` along its gradient.

    Parameters
    ----------
    problem : object exposing ``grad(x, y) -> (g_x, g_y)``, ``loss(x, y)``
        and the attributes ``xopt`` / ``yopt``.
    x0, y0 : starting point (scalars or NumPy arrays). Never modified.
    iteration : number of update steps.
    lr : step size used for both variables.
    k : unused; kept so all solvers share one call signature.

    Returns
    -------
    (loss, x_hist, y_hist) : three lists of length ``iteration + 1``.
        ``loss[0]`` is the Euclidean distance from the start to
        ``(xopt, yopt)``; later entries are ``problem.loss`` at each iterate.
    """
    x, y = x0, y0
    xopt, yopt = problem.xopt, problem.yopt
    x_hist, y_hist = [x], [y]
    loss = [np.sqrt((x - xopt)**2 + (y - yopt)**2)]
    for _ in range(iteration):
        g_x, g_y = problem.grad(x, y)
        # Rebind rather than use `-=` / `+=`: augmented assignment updates a
        # NumPy array in place, which would overwrite the caller's x0/y0 and
        # make every entry of the history alias the same buffer.
        x = x - lr * g_x
        y = y + lr * g_y
        x_hist.append(x)
        y_hist.append(y)
        loss.append(problem.loss(x, y))
    return loss, x_hist, y_hist


def altgd(problem, x0, y0, iteration, lr, k=0):
    """Alternating gradient descent-ascent.

    Each step first updates ``x`` with the x-gradient at the current point,
    then re-evaluates ``problem.grad`` at the new ``x`` and updates ``y``
    with that fresh y-gradient.

    Parameters
    ----------
    problem : object exposing ``grad(x, y) -> (g_x, g_y)``, ``loss(x, y)``
        and the attributes ``xopt`` / ``yopt``.
    x0, y0 : starting point (scalars or NumPy arrays). Never modified.
    iteration : number of update steps.
    lr : step size used for both variables.
    k : unused; kept so all solvers share one call signature.

    Returns
    -------
    (loss, x_hist, y_hist) : three lists of length ``iteration + 1``.
        ``loss[0]`` is the Euclidean distance from the start to
        ``(xopt, yopt)``; later entries are ``problem.loss`` at each iterate.
    """
    x, y = x0, y0
    xopt, yopt = problem.xopt, problem.yopt
    x_hist, y_hist = [x], [y]
    loss = [np.sqrt((x - xopt)**2 + (y - yopt)**2)]
    for _ in range(iteration):
        g_x, _unused = problem.grad(x, y)
        # Rebind rather than use `-=` / `+=`: augmented assignment updates a
        # NumPy array in place, which would overwrite the caller's x0/y0 and
        # make every entry of the history alias the same buffer.
        x = x - lr * g_x
        # The y-step sees the x that was just updated.
        _unused, g_y = problem.grad(x, y)
        y = y + lr * g_y
        x_hist.append(x)
        y_hist.append(y)
        loss.append(problem.loss(x, y))
    return loss, x_hist, y_hist


def adam(problem, x0, y0, iteration, lr, k=0):
    """Adam-style descent on ``x`` / ascent on ``y``.

    First-moment decay is 0.5, second-moment decay is 0.99 and the
    denominator offset is 1e-8. The first moment is bias-corrected by
    ``1 - 0.5**t``; the second moment is used as-is, without bias
    correction.

    Parameters
    ----------
    problem : object exposing ``grad(x, y) -> (g_x, g_y)``, ``loss(x, y)``
        and the attributes ``xopt`` / ``yopt``.
    x0, y0 : starting point.
    iteration : number of update steps.
    lr : step size.
    k : unused; kept so all solvers share one call signature.

    Returns
    -------
    (loss, x_hist, y_hist) : three lists of length ``iteration + 1``.
        ``loss[0]`` is the Euclidean distance from the start to
        ``(xopt, yopt)``; later entries are ``problem.loss`` at each iterate.
    """
    beta1, beta2, eps = 0.5, 0.99, 1e-8
    x, y = x0, y0
    x_hist, y_hist = [x], [y]
    loss = [np.sqrt((x - problem.xopt)**2 + (y - problem.yopt)**2)]
    mom_x, mom_y = 0., 0.
    sq_x, sq_y = 0., 0.
    for step in range(1, iteration + 1):
        g_x, g_y = problem.grad(x, y)
        # Exponential moving averages of the gradient and its square.
        mom_x = beta1*mom_x + (1-beta1)*g_x
        mom_y = beta1*mom_y + (1-beta1)*g_y
        sq_x = beta2*sq_x + (1-beta2)*g_x**2
        sq_y = beta2*sq_y + (1-beta2)*g_y**2
        # Only the first moment is bias-corrected.
        debias = 1-beta1**step
        corrected_x = mom_x/debias
        corrected_y = mom_y/debias
        x = x - lr*corrected_x/(np.sqrt(sq_x)+eps)
        y = y + lr*corrected_y/(np.sqrt(sq_y)+eps)
        x_hist.append(x)
        y_hist.append(y)
        loss.append(problem.loss(x, y))
    return loss, x_hist, y_hist


def avg(problem, x0, y0, iteration, lr, k=0):
    """Alternating updates with a 1/sqrt(t) step and running-average iterates.

    The update does not call ``problem.grad``: ``x`` is moved by ``-step*y``
    and ``y`` by ``+step*x`` (using the already-updated ``x``). The recorded
    history and losses are those of the running average of the iterates,
    not of the raw iterates.

    Parameters
    ----------
    problem : object exposing ``loss(x, y)`` and the attributes
        ``xopt`` / ``yopt``.
    x0, y0 : starting point.
    iteration : number of update steps.
    lr : base step size, divided by ``sqrt(t)`` at step ``t`` (1-based).
    k : unused; kept so all solvers share one call signature.

    Returns
    -------
    (loss, x_hist, y_hist) : three lists of length ``iteration + 1``.
        ``loss[0]`` is the Euclidean distance from the start to
        ``(xopt, yopt)``; later entries are ``problem.loss`` at the averages.
    """
    x, y = x0, y0
    loss = [np.sqrt((x - problem.xopt)**2 + (y - problem.yopt)**2)]
    x_bar, y_bar = x, y
    x_hist, y_hist = [x_bar], [y_bar]
    for t in range(1, iteration + 1):
        step = lr/np.sqrt(t)
        x = x - step*(y)
        y = y + step*(x)
        # Running mean over the start point and the t iterates seen so far.
        x_bar = x_bar*t/(t+1) + x/(t+1)
        y_bar = y_bar*t/(t+1) + y/(t+1)
        x_hist.append(x_bar)
        y_hist.append(y_bar)
        loss.append(problem.loss(x_bar, y_bar))
    return loss, x_hist, y_hist

def omd(problem, x0, y0, iteration, lr, k=0):
    """Optimistic update: twice the current gradient minus the previous one.

    The "previous" gradient for the first step is ``problem.grad`` evaluated
    at ``(0.5*x0, 0.5*y0)``.

    Parameters
    ----------
    problem : object exposing ``grad(x, y) -> (g_x, g_y)``, ``loss(x, y)``
        and the attributes ``xopt`` / ``yopt``.
    x0, y0 : starting point.
    iteration : number of update steps.
    lr : step size.
    k : unused; kept so all solvers share one call signature.

    Returns
    -------
    (loss, x_hist, y_hist) : three lists of length ``iteration + 1``.
        ``loss[0]`` is the Euclidean distance from the start to
        ``(xopt, yopt)``; later entries are ``problem.loss`` at each iterate.
    """
    x, y = x0, y0
    prev_gx, prev_gy = problem.grad(0.5*x0, 0.5*y0)
    x_hist, y_hist = [x], [y]
    loss = [np.sqrt((x - problem.xopt)**2 + (y - problem.yopt)**2)]
    for _ in range(iteration):
        cur_gx, cur_gy = problem.grad(x, y)
        # Descent in x, ascent in y, each corrected by the stale gradient.
        x = x - 2 * lr * cur_gx + lr * prev_gx
        y = y + 2 * lr * cur_gy - lr * prev_gy
        prev_gx, prev_gy = cur_gx, cur_gy
        x_hist.append(x)
        y_hist.append(y)
        loss.append(problem.loss(x, y))
    return loss, x_hist, y_hist

def eg(problem, x0, y0, iteration, lr, k=0):
    """Extragradient descent-ascent.

    Each step takes a trial descent/ascent step to an intermediate point,
    evaluates ``problem.grad`` there, and applies that gradient to the
    original point.

    Parameters
    ----------
    problem : object exposing ``grad(x, y) -> (g_x, g_y)``, ``loss(x, y)``
        and the attributes ``xopt`` / ``yopt``.
    x0, y0 : starting point (scalars or NumPy arrays). Never modified.
    iteration : number of update steps.
    lr : step size used for both the trial and the real step.
    k : unused; kept so all solvers share one call signature.

    Returns
    -------
    (loss, x_hist, y_hist) : three lists of length ``iteration + 1``.
        ``loss[0]`` is the Euclidean distance from the start to
        ``(xopt, yopt)``; later entries are ``problem.loss`` at each iterate.
    """
    x, y = x0, y0
    xopt, yopt = problem.xopt, problem.yopt
    x_hist, y_hist = [x], [y]
    loss = [np.sqrt((x - xopt)**2 + (y - yopt)**2)]
    for _ in range(iteration):
        # Trial (extrapolation) step from the current point.
        g_x, g_y = problem.grad(x, y)
        xe = x - lr * g_x
        ye = y + lr * g_y
        # Real step: gradient taken at the trial point, applied at (x, y).
        g_x, g_y = problem.grad(xe, ye)
        # Rebind rather than use `-=` / `+=`: augmented assignment updates a
        # NumPy array in place, which would overwrite the caller's x0/y0 and
        # make every entry of the history alias the same buffer.
        x = x - lr * g_x
        y = y + lr * g_y
        x_hist.append(x)
        y_hist.append(y)
        loss.append(problem.loss(x, y))
    return loss, x_hist, y_hist

def fr(problem, x0, y0, iteration, lr, k=0):
    """Gradient descent-ascent with an extra correction on the ``y`` update.

    ``x`` takes a plain descent step. ``y`` takes an ascent step plus
    ``lr * problem.fr(x, y) * g_x``, with both the gradient and
    ``problem.fr`` evaluated at the point before the step.

    Parameters
    ----------
    problem : object exposing ``grad(x, y) -> (g_x, g_y)``, ``fr(x, y)``,
        ``loss(x, y)`` and the attributes ``xopt`` / ``yopt``.
    x0, y0 : starting point (scalars or NumPy arrays). Never modified.
    iteration : number of update steps.
    lr : step size.
    k : unused; kept so all solvers share one call signature.

    Returns
    -------
    (loss, x_hist, y_hist) : three lists of length ``iteration + 1``.
        ``loss[0]`` is the Euclidean distance from the start to
        ``(xopt, yopt)``; later entries are ``problem.loss`` at each iterate.
    """
    x, y = x0, y0
    xopt, yopt = problem.xopt, problem.yopt
    x_hist, y_hist = [x], [y]
    loss = [np.sqrt((x - xopt)**2 + (y - yopt)**2)]
    for _ in range(iteration):
        g_x, g_y = problem.grad(x, y)
        mod = problem.fr(x, y)
        # Rebind rather than use `-=` / `+=`: augmented assignment updates a
        # NumPy array in place, which would overwrite the caller's x0/y0 and
        # make every entry of the history alias the same buffer.
        x = x - lr * g_x
        y = y + (lr * g_y + lr * mod * g_x)
        x_hist.append(x)
        y_hist.append(y)
        loss.append(problem.loss(x, y))
    return loss, x_hist, y_hist

In [4]:
from method import NLTGCR
import torch
import numpy.linalg as nalg

In [5]:
def main(problem, iteration, x0, y0, lrset, k=5):
    """Run the four compared solvers from one start point and collect results.

    ``nltgcr`` runs first, with a fixed iteration argument of 10 and the
    ``'fr'`` learning rate; ``altgd``, ``eg`` and ``fr`` then run for
    ``iteration`` steps with their own entries of ``lrset``.

    Parameters
    ----------
    problem : problem object forwarded to every solver.
    iteration : step count for ``altgd``, ``eg`` and ``fr``.
    x0, y0 : shared starting point.
    lrset : mapping providing the keys ``'altgd'``, ``'eg'`` and ``'fr'``.
    k : forwarded to ``nltgcr`` only.

    Returns
    -------
    (allloss, allxpath, allypath) : three lists of four entries each, in the
        order altgd, eg, fr, nltgcr.
    """
    # Execution order matters only in that nltgcr is invoked first.
    loss_nl, xs_nl, ys_nl = nltgcr(problem, x0, y0, 10, lrset['fr'], k)
    loss_alt, xs_alt, ys_alt = altgd(problem, x0, y0, iteration, lr=lrset['altgd'])
    loss_eg, xs_eg, ys_eg = eg(problem, x0, y0, iteration, lr=lrset['eg'])
    loss_fr, xs_fr, ys_fr = fr(problem, x0, y0, iteration, lr=lrset['fr'])

    allloss = [loss_alt, loss_eg, loss_fr, loss_nl]
    allxpath = [xs_alt, xs_eg, xs_fr, xs_nl]
    allypath = [ys_alt, ys_eg, ys_fr, ys_nl]
    return allloss, allxpath, allypath

In [8]:
import jax.numpy as jnp

In [9]:
def nltgcr(problem, x0, y0, iteration, lr, k):
    """Iterate on w = (x, y) using a window of stored direction pairs (P, AP).

    The residual is ``FF(w) = [-g_x, g_y]`` built from ``problem.grad``.
    The outer loop runs a fixed 50 times and the inner loop ``k`` times;
    one history/loss entry is appended per inner iteration.

    Parameters
    ----------
    problem : object exposing ``grad``, ``hgrad`` and ``loss``.
    x0, y0 : starting point.
    iteration : not read anywhere in the body.
    lr : not read anywhere in the body.
    k : number of columns kept in ``P`` / ``AP`` (copied to ``lb``).
        NOTE(review): the name ``k`` is reused as a loop index further down.

    Returns
    -------
    (loss, x_hist, y_hist) : three lists.

    NOTE(review): ``P`` and ``AP`` are created with ``jnp.zeros`` and then
    written with ``P[:, j] = ...``. JAX arrays do not support item
    assignment, so those lines are expected to raise unless the buffers are
    made NumPy arrays (or ``.at[...].set`` is used) - confirm and fix.
    """
    lb = k
    epsf = 1e-3  # unused below
    P = jnp.zeros((2, lb))
    AP = jnp.zeros((2, lb))
    w = jnp.array([x0, y0])
    def FF(w):
        # Residual: negated x-gradient and y-gradient from problem.grad.
        g_x, g_y = problem.grad(w[0],w[1])
        return  jnp.array([-g_x, g_y])
    
    def FF2(w):
        # Same layout as FF but built from problem.hgrad.
        g_x, g_y = problem.hgrad(w[0],w[1])
        return  jnp.array([-g_x, g_y])
    
    def FH(w, r):
        # Evaluates FF2 at w shifted by 1e-3j, divides by 1e-3j and returns
        # the element-wise absolute value. The argument r is not used.
        # NOTE(review): presumably meant as a complex-step derivative along
        # r - confirm the intended formula.
        v = FF2(w+1e-3*1j)/(1e-3*1j)
        return  jnp.abs(v)
    
    x_hist = [x0]
    y_hist = [y0]
    loss = [problem.loss(x0, y0)]
    ep = 1e-8
    for jj in range(50):
        # Start of a cycle: rebuild the first direction pair from the
        # current residual and reset the window counters i and i2.
        r = FF(w)
        rho = nalg.norm(r)  # computed but not read afterwards
        Ar = (FH(w,r)-r)/ep
        t = nalg.norm(Ar)
        t = 1.0/t
        # Scale so that the stored AP column has unit norm.
        P[:,0] = t*r
        AP[:,0]=  t *Ar
        i2 = 1
        i = 1
        # Estimation of optimal parameters
        for it in range(lb):
            # Coefficients of r against the stored AP columns, then step
            # along the matching P columns.
            alph = np.dot(np.transpose(AP),r)
            w = w + P@(alph)
            r = FF(w)
            rho = nalg.norm(r)  # computed but not read afterwards
            Ar = (FH(w, r)-r)/ep
            p = r
            # Choose the column index at which the sweep below starts.
            if i <= lb:
                k = 0
            else:
                k = i2
            # Subtract from (p, Ar) their components along each stored
            # (P, AP) column, wrapping around the window, until column i2
            # has been processed.
            while True:
                if k ==lb:
                    k = 0
                k +=1
                tau = np.inner(Ar, AP[:,k-1])
                p = p - tau*(P[:,k-1])
                Ar = Ar -  tau*(AP[:,k-1])
                if k == i2:
                    break
            t = nalg.norm(Ar)
            # Advance the write slot i2 cyclically over 1..lb.
            if (i2) == lb:
                i2 = 0
            i2 = i2+1
            i = i+1
            t = 1.0/t
            AP[:,i2-1] = t*Ar
            P[:,i2-1] = t*p
            x,y = w[0],w[1]
            x_hist.append(x)
            y_hist.append(y)
            # Unlike loss[0], these entries are converted with .item().
            loss.append(problem.loss(x, y).item())
    return loss, x_hist, y_hist

In [13]:
from jax import grad
# Scratch check: derivative of a test function with respect to its first
# argument at (3.0, 4.0), via jax.grad.
# NOTE: this rebinds the notebook-level names x, y and f.
x = 3.0
y= 4.0
f = lambda x, y: (4*x**2 -(y-3*x+0.05*x**3)**2-0.1*y**4) * jnp.exp(-0.01 * (x**2+y**2))
grad(f)(x, y)

DeviceArray(9.447126, dtype=float32, weak_type=True)

In [14]:
# Same check, differentiating with respect to the second argument (index 1).
grad(f,1)(x, y)

DeviceArray(-14.069971, dtype=float32, weak_type=True)

In [10]:
# Experiment on func3: run the solvers from (4, 3) and draw their paths
# over a filled contour plot of problem.f.
iteration =100
markevery= 10

x0, y0 = 4.,3.
problem = func3()
# main() reads only the 'altgd', 'eg' and 'fr' entries of this mapping.
lr_set = {'simgd':0.05, 'altgd':0.1, 'avg':1, 'adam':0.01, 'eg':0.05,'omd':0.05, 'fr':0.05,'AA':0.2}
f = problem.f

type2=True
# Each result is a list of four runs, in the order altgd, eg, fr, nltgcr.
loss_f3, xpath_f3, ypath_f3 = main(problem, iteration, x0, y0, lr_set, k=3)
# Evaluation grid for the contour plot; note that x and y are rebound to
# the mesh arrays here.
xmin, xmax, xstep = [-7, 7, .1]
ymin, ymax, ystep = [-7, 7, .1]
x, y = np.meshgrid(np.arange(xmin, xmax + xstep, xstep), np.arange(ymin, ymax + ystep, ystep))
z = f(x, y)
# Gradient fields; only referenced by the commented-out quiver call below.
dz_dx = elementwise_grad(f, argnum=0)(x, y)
dz_dy = elementwise_grad(f, argnum=1)(x, y)
# fig3 = plot(loss_f3, xpath_f3, ypath_f3, iteration, k, [x0, y0])
fig, ax3 = plt.subplots(figsize=(5,5))
markevery=10
# Unpack per-solver results: 1 = altgd, 2 = eg, 3 = fr, 4 = nltgcr.
xpath1, xpath2, xpath3, xpath4 = xpath_f3
ypath1, ypath2, ypath3, ypath4 = ypath_f3
loss1, loss2, loss3, loss4 = loss_f3
ax3.contourf(x, y, z)
# ax3.quiver(x, y, x - dz_dx, y - dz_dy, alpha=.5)
ax3.scatter(x0, y0, marker='s', s=100, c='k',alpha=0.6,zorder=20, label='Start')
ax3.plot(xpath1, ypath1, 'b--', linewidth=2, label='AltGDA',markevery=markevery)
ax3.plot(xpath2, ypath2, 'm--', linewidth=2, label='EG',markevery=markevery)
ax3.plot(xpath3, ypath3, 'k--', linewidth=2, label='FR',markevery=markevery)
ax3.plot(xpath4, ypath4, 'r-^', linewidth=2, label='nlTGCR',markevery=markevery)

# ax3.legend([x_init],['Start'], markerscale=1, loc=4, fancybox=True, framealpha=1., fontsize=20)
ax3.set_xlabel('x')
ax3.set_ylabel('y')  
ax3.set_xlim([xmin,xmax])
ax3.set_ylim([ymin,ymax])
ax3.legend()



Quad


TypeError: Can't differentiate w.r.t. type <class 'jaxlib.xla_extension.DeviceArray'>

In [None]:
from jax import grad

In [None]:
# Scratch: rebind x to a complex scalar for the holomorphic-gradient trial.
x = 1+1j

In [None]:
# Scratch: builds (but does not call) a holomorphic gradient function.
# NOTE(review): the lambda ignores its argument w and reads the global x
# instead - presumably problem.loss(w, 1) was intended; confirm.
grad(lambda w: problem.loss(x, 1), holomorphic=True)

In [None]:
# Compare the recorded loss curves of the FR run (loss3) and the nltgcr
# run (loss4); both names are set by the func3 experiment cell.
plt.plot(loss3,'k',label='FR')
plt.plot(loss4,'r',label='nltcgr')
plt.legend()

In [None]:
# Display the raw loss history of the nltgcr run.
loss4