In [120]:
from autograd import grad
import autograd.numpy as np

In [121]:
# Starting point: x is updated by gradient descent (minimizing player),
# y by gradient ascent (maximizing player).
x = 1.5
y = 0.5
# Step size shared by both players.
lr = 0.05
# Weight of the ||∇f||^2 gradient-penalty term in each update.
gamma = 0.1
# Which test objective to optimize: 0, 1 or 2.
equation = 2

In [122]:
# Select the objective z = f(x, y) optimized by the loop below.
# `np` must be autograd.numpy so that np.exp stays differentiable.
if equation == 0:
    # z = -3x^2 - y^2 + 4xy
    loss_function = lambda x, y: -3*x*x - y*y + 4*x*y
elif equation == 1:
    # z = 3x^2 + y^2 + 4xy
    loss_function = lambda x, y: 3*x*x + y*y + 4*x*y
elif equation == 2:
    # z = (4x^2 - (y - 3x + 0.05x^3)^2 - 0.1y^4) * exp(-0.01(x^2 + y^2))
    loss_function = lambda x, y: (4*x*x - (y-3*x+0.05*x**3)**2 - 0.1*y**4)*np.exp(-0.01*(x**2+y**2))
else:
    # Fail loudly: otherwise a loss_function left in the kernel by an earlier
    # run would be silently reused (or a NameError would surface later).
    raise ValueError(f'equation must be 0, 1 or 2, got {equation!r}')

In [123]:
# Gradient descent on x / ascent on y for z = f(x, y), with both players also
# descending the penalty ||∇f(x, y)||^2 scaled by gamma.
n_steps = 300

# The derivative functions depend only on loss_function, so build them once
# instead of re-creating them on every iteration.
# ∇f(x, y) = (df/dx, df/dy)
grad_f = grad(loss_function, (0, 1))
# Hessian rows: ∇(df/dx) = (f_xx, f_xy) and ∇(df/dy) = (f_yx, f_yy)
grad_grad_fx = grad(grad(loss_function, 0), (0, 1))
grad_grad_fy = grad(grad(loss_function, 1), (0, 1))

# autograd cannot differentiate with respect to Python ints, so make sure the
# start point is floating point even if it was written as e.g. `x, y = 1, 0`.
x, y = float(x), float(y)

for _ in range(n_steps):
    grad_fx, grad_fy = grad_f(x, y)
    grad_fxy = np.asarray((grad_fx, grad_fy))

    # ∇||∇f(x,y)||^2 = 2 * H(x,y) · ∇f(x,y), where H is the (symmetric) Hessian
    grad_grad_fxy = np.asarray((grad_grad_fx(x, y), grad_grad_fy(x, y)))
    grad_fnorm_squared_x, grad_fnorm_squared_y = 2 * np.dot(grad_grad_fxy, grad_fxy)

    # x steps down f, y steps up f; both step down the gradient-norm penalty.
    x = x - lr*grad_fx - gamma*lr*grad_fnorm_squared_x
    y = y + lr*grad_fy - gamma*lr*grad_fnorm_squared_y
    z = loss_function(x, y)

    print(f'x: {x:.4f}, y: {y:.4f}, z: {z:.4f}')

x: 1.3578, y: 1.4501, z: 0.6649
x: 1.1832, y: 1.9343, z: 1.7587
x: 1.1885, y: 1.9013, z: 1.7556
x: 1.1787, y: 1.9046, z: 1.7501
x: 1.1740, y: 1.8957, z: 1.7445
x: 1.1678, y: 1.8909, z: 1.7389
x: 1.1622, y: 1.8849, z: 1.7332
x: 1.1565, y: 1.8793, z: 1.7277
x: 1.1509, y: 1.8736, z: 1.7221
x: 1.1454, y: 1.8681, z: 1.7165
x: 1.1400, y: 1.8625, z: 1.7110
x: 1.1347, y: 1.8571, z: 1.7055
x: 1.1295, y: 1.8517, z: 1.7000
x: 1.1243, y: 1.8463, z: 1.6945
x: 1.1192, y: 1.8410, z: 1.6891
x: 1.1142, y: 1.8358, z: 1.6837
x: 1.1093, y: 1.8306, z: 1.6783
x: 1.1045, y: 1.8255, z: 1.6730
x: 1.0997, y: 1.8204, z: 1.6677
x: 1.0950, y: 1.8154, z: 1.6624
x: 1.0904, y: 1.8104, z: 1.6572
x: 1.0858, y: 1.8055, z: 1.6520
x: 1.0814, y: 1.8007, z: 1.6468
x: 1.0770, y: 1.7959, z: 1.6417
x: 1.0726, y: 1.7911, z: 1.6367
x: 1.0684, y: 1.7865, z: 1.6316
x: 1.0642, y: 1.7819, z: 1.6266
x: 1.0601, y: 1.7773, z: 1.6217
x: 1.0560, y: 1.7728, z: 1.6167
x: 1.0520, y: 1.7683, z: 1.6119
x: 1.0481, y: 1.7639, z: 1.6070
x: 1.044