In [1]:
import jax.numpy as jnp
import jax

class PRNGSequence:
    def __init__(self, seed):
        self._key = jax.random.PRNGKey(seed)
    
    def next(self):
        self._key, subkey = jax.random.split(self._key)
        return subkey
keys = PRNGSequence(100)

x = jax.random.normal(keys.next(), (3,))


params = jnp.array([1.0, 2.0, 3.0])

def model(params, x):
    return x[:, 0] * params[0] + x[:, 1] * params[1] + x[:, 2] * params[2]

def loss(y_pred, y_true):
    return jnp.mean((y_pred - y_true) ** 2)

In [2]:
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y_true = jnp.array([[1.0], [2.0]])

In [3]:

def loss_at_params(params):
    y_pred = model(params, x)
    return loss(y_pred, y_true)

In [4]:
from dataclasses import dataclass


@dataclass
class GradientDescentState:
    params: jnp.ndarray
    step_size: float
    iteration: int = 0
def gradient_descent_step(state: GradientDescentState):
    grad = jax.grad(loss_at_params)(state.params)
    new_params = state.params - state.step_size * grad
    
    return GradientDescentState(new_params, state.step_size, state.iteration + 1)

In [5]:
state = GradientDescentState(params, step_size=0.01)
for i in range(10):
    state = gradient_descent_step(state)
    print(f"Iteration {state.iteration}: params = {state.params}, loss = {loss_at_params(state.params)}")

Iteration 1: params = [-0.34500003  0.22500002  0.7950001 ], loss = 5.285163879394531
Iteration 2: params = [-0.47550005  0.05445001  0.58440006], loss = 0.32736244797706604
Iteration 3: params = [-0.48943207  0.03788549  0.5652031 ], loss = 0.28132766485214233
Iteration 4: params = [-0.49216825  0.03610065  0.56436956], loss = 0.28053900599479675
Iteration 5: params = [-0.49382156  0.03573543  0.5652925 ], loss = 0.2801714539527893
Iteration 6: params = [-0.49536267  0.03550761  0.56637794], loss = 0.2798120081424713
Iteration 7: params = [-0.49688473  0.03529412  0.56747305], loss = 0.2794569134712219
Iteration 8: params = [-0.49839675  0.03508317  0.56856316], loss = 0.279106080532074
Iteration 9: params = [-0.49989966  0.0348736   0.56964695], loss = 0.2787593901157379
Iteration 10: params = [-0.50139356  0.03466528  0.57072425], loss = 0.27841684222221375


![image.png](attachment:image.png)

In [6]:
#algorithm 1: Hessian-aware scaling selection

def scaling_selection(g, H, sigma, constant_learning_rate=True):
    
    Hg =jnp.dot(H, g)
    dot_product =jnp.dot(g, Hg)
    norm_g =jnp.linalg.norm(g)

    if constant_learning_rate:
        s_lpc_min = 1 / sigma #set to 1/sigma for constant learning rate
        s_lpc_max = 1 / sigma #set to 1/sigma for constant learning rate
    else:
        s_lpc_min = 1 / sigma *jax.random.random(keys.next())

    s_CG =jnp.linalg.norm(g)**2 / dot_product
    s_MR = dot_product /jnp.linalg.norm(Hg)**2
    s_GM =jnp.sqrt(s_CG * s_MR)

    if dot_product > sigma * norm_g**2:
        spc =jax.random.choice(keys.next(), a=jnp.array([s_CG, s_MR, s_GM]))
        return -spc*g, "SPC"
    elif 0 < dot_product and dot_product < sigma * norm_g**2:
        slpc =jax.random.uniform(keys.next(), s_lpc_min, 1 / sigma)
        return -slpc * g, "LPC"
    else:
        snc =jax.random.uniform(keys.next(), s_lpc_min, s_lpc_max)
        return -snc * g, "NC"

![image.png](attachment:image.png)

![image.png](attachment:image.png)

In [7]:
#algorithm 3 backward tracking line search

def backtracking_LS(loss_at_params, theta, rho, x, g, p):

    alpha = 1.0
    while loss_at_params(x + alpha * p) > loss_at_params(x) + alpha * rho *jnp.dot(g, p):
        alpha *= theta


    return alpha


![image.png](attachment:image.png)

In [8]:
# algorithm 4 forward/backward tracking line search

def forward_backward_LS(loss_at_params, theta, rho, x, g, p):
    alpha = 1.0
    if loss_at_params(x + alpha * p) > loss_at_params(x) + alpha * rho *jnp.dot(g, p):
        backtracking_LS(loss_at_params, theta, rho, x, g, p)
    else:
        while loss_at_params(x + alpha * p) >= loss_at_params(x) + alpha * rho *jnp.dot(g, p):
            alpha /= theta

    return alpha * theta

    

![image.png](attachment:image.png)

In [9]:
# algorithm 2: scaled gradient descent with line search


def scaled_GD(loss_at_params, x0, sigma, rho, theta_bt, theta_fb, MAX_ITER, eps):
    """
    sigma <<< 1
    0 < theta < 1
    0 < rho < 1/2
    """

    x_k = x0
    flag_distribution = {"SPC": 0, "LPC": 0, "NC": 0}


    for _ in range(MAX_ITER):

        g_k = 2 * x_k

        if jnp.linalg.norm(g_k) < eps:
            break
        

        p_k, FLAG = scaling_selection(g_k,jnp.eye(len(x_k)), sigma)
        flag_distribution[FLAG] += 1

        if FLAG == "SPC" or FLAG == "LPC":
            alpha_k = backtracking_LS(loss_at_params, theta_bt, rho, x_k, g_k, p_k)

        else:
            alpha_k = forward_backward_LS(loss_at_params, theta_fb, rho, x_k, g_k, p_k)

        x_k += alpha_k * p_k

    return x_k, flag_distribution

In [10]:
#algorithm 3 backward tracking line search

def backtracking_LS(loss_at_params, theta, rho, x, g, p):

    alpha = 1.0
    while loss_at_params(x + alpha * p) > loss_at_params(x) + alpha * rho *jnp.dot(g, p):
        alpha *= theta


    return alpha


In [17]:
@dataclass
class FirstishOrderMethodState:
    params: jnp.ndarray
    sigma: float = 0.1
    rho: float = 0.25
    theta_bt: float = 0.5
    theta_fb: float = 0.5
    iteration: int = 0

def firstish_order_step(state: FirstishOrderMethodState) -> FirstishOrderMethodState:
    grad = jax.grad(loss_at_params)(state.params)
    hess = jax.hessian(loss_at_params)(state.params)
    p, flag = scaling_selection(grad, hess, state.sigma)
    
    if flag == "SPC" or flag == "LPC":
        alpha = backtracking_LS(loss_at_params, state.theta_bt, state.rho, state.params, grad, p)
    else:
        alpha = forward_backward_LS(loss_at_params, state.theta_fb, state.rho, state.params, grad, p)

    new_params = state.params + alpha * p
    return FirstishOrderMethodState(new_params, state.sigma, state.rho, state.theta_bt, state.theta_fb, state.iteration + 1)

In [18]:

state = FirstishOrderMethodState(params)
for i in range(10):
    state = firstish_order_step(state)
    print(f"Iteration {state.iteration}: params = {state.params}, loss = {loss_at_params(state.params)}")

Iteration 1: params = [-0.48778844  0.03656185  0.5609119 ], loss = 0.281612366437912
Iteration 2: params = [-7.499069e-01  1.077950e-04  7.501223e-01], loss = 0.25000157952308655
Iteration 3: params = [-7.4998719e-01  1.8469073e-06  7.4999070e-01], loss = 0.25
Iteration 4: params = [-7.4998826e-01  1.7124412e-06  7.4999154e-01], loss = 0.25
Iteration 5: params = [-7.4998844e-01  1.6258019e-06  7.4999154e-01], loss = 0.25
Iteration 6: params = [-7.4998850e-01  1.6472990e-06  7.4999166e-01], loss = 0.25
Iteration 7: params = [-7.4999940e-01  1.4632974e-07  7.4999952e-01], loss = 0.25
Iteration 8: params = [-7.4999940e-01  1.4218756e-07  7.4999952e-01], loss = 0.25
Iteration 9: params = [-7.4999940e-01  1.3574382e-07  7.4999952e-01], loss = 0.25
Iteration 10: params = [-7.4999940e-01  1.3160164e-07  7.4999952e-01], loss = 0.25
