In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [23]:


def scaling_selection(g, H, sigma, constant_learning_rate=True):
    """Choose a scaled steepest-descent direction ``p = -s * g`` from the curvature along ``g``.

    The curvature test compares ``g^T H g`` with ``sigma * ||g||^2``:

    * ``g^T H g > sigma ||g||^2``       -> "SPC" (sufficient positive curvature):
      ``s`` is picked uniformly at random among the CG scaling ``||g||^2 / g^T H g``,
      the MR scaling ``g^T H g / ||H g||^2`` and their geometric mean.
    * ``0 < g^T H g <= sigma ||g||^2``  -> "LPC" (limited positive curvature):
      ``s ~ U[s_min, 1/sigma]``.
    * ``g^T H g <= 0``                  -> "NC" (non-positive curvature):
      ``s ~ U[s_min, 1/sigma]``.

    Args:
        g: 1-D gradient tensor.
        H: ``(n, n)`` matrix (Hessian or an approximation), ``n = len(g)``.
        sigma: curvature threshold, expected ``0 < sigma << 1``.
        constant_learning_rate: if True, ``s_min = 1/sigma`` so LPC/NC steps use
            exactly ``1/sigma``; otherwise ``s_min = U[0, 1) / sigma``.

    Returns:
        ``(p, flag)`` with ``p = -s * g`` and ``flag`` in {"SPC", "LPC", "NC"}.
    """
    Hg = torch.matmul(H, g)
    gHg = torch.dot(g, Hg)
    g_norm_sq = torch.norm(g) ** 2

    s_max = 1 / sigma
    if constant_learning_rate:
        s_min = s_max
    else:
        s_min = s_max * torch.rand(1).item()

    if gHg > sigma * g_norm_sq:
        # The SPC scalings are only finite/real when gHg > 0, so compute them here
        # rather than unconditionally (sqrt of a negative product would be nan).
        s_cg = g_norm_sq / gHg
        s_mr = gHg / torch.norm(Hg) ** 2
        s_gm = torch.sqrt(s_cg * s_mr)
        # stack keeps g's dtype/device (torch.tensor([...]) would not).
        candidates = torch.stack([s_cg, s_mr, s_gm])
        s = candidates[torch.randint(0, 3, (1,)).item()]
        return -s * g, "SPC"

    # LPC includes the boundary gHg == sigma * ||g||^2; only gHg <= 0 is NC.
    flag = "LPC" if gHg > 0 else "NC"
    s = torch.empty(1).uniform_(s_min, s_max).item()
    return -s * g, flag


In [24]:
# Algorithm 3: backtracking line search

def backtracking_LS(model, theta, rho, x, g, p):
    """Backtracking (Armijo) line search -- Algorithm 3.

    Starting from ``alpha = 1``, shrink ``alpha`` by ``theta`` until the
    sufficient-decrease condition

        model(x + alpha * p) <= model(x) + alpha * rho * g^T p

    holds.

    Args:
        model: callable returning the scalar objective value at a point.
        theta: shrink factor, ``0 < theta < 1``.
        rho: Armijo constant, ``0 < rho < 1/2``.
        x: current iterate.
        g: 1-D gradient at ``x``.
        p: 1-D search direction. It should be a descent direction (``g^T p < 0``);
            otherwise the loop may only stop once ``alpha`` is numerically negligible.

    Returns:
        The accepted step size ``alpha`` (Python float).
    """
    # model(x) and g^T p do not depend on alpha: evaluate them once, not per trial.
    f_x = model(x)
    g_dot_p = torch.dot(g, p)

    alpha = 1.0
    while model(x + alpha * p) > f_x + alpha * rho * g_dot_p:
        alpha *= theta
    return alpha


In [25]:
# Algorithm 4: forward/backward tracking line search

def forward_backward_LS(model, theta, rho, x, g, p, max_iter=100):
    """Forward/backward tracking line search -- Algorithm 4.

    Let the Armijo condition be
    ``model(x + alpha * p) <= model(x) + alpha * rho * g^T p``.

    If the unit step violates it, shrink ``alpha`` by ``theta`` until it holds
    (backward tracking, same as Algorithm 3 / ``backtracking_LS``). Otherwise grow
    ``alpha`` by ``1/theta`` while the condition keeps holding. Then return the
    largest step that satisfied it (``theta * alpha`` for the first rejected ``alpha``).

    Args:
        model: callable returning the scalar objective value at a point.
        theta: tracking factor, ``0 < theta < 1``.
        rho: Armijo constant, ``0 < rho < 1/2``.
        x: current iterate.
        g: 1-D gradient at ``x``.
        p: 1-D search direction.
        max_iter: maximum number of forward (expansion) trials. This bounds the
            loop when ``model`` is unbounded below along ``p``, which can happen
            for negative-curvature directions; there expansion would never stop.

    Returns:
        The accepted step size (Python float).
    """
    # Loop-invariant quantities: evaluate once.
    f_x = model(x)
    g_dot_p = torch.dot(g, p)

    def armijo_fails(alpha):
        # True when step alpha does NOT give sufficient decrease.
        return model(x + alpha * p) > f_x + alpha * rho * g_dot_p

    alpha = 1.0
    if armijo_fails(alpha):
        # Backward tracking, continued from the already-rejected unit step.
        alpha *= theta
        while armijo_fails(alpha):
            alpha *= theta
        return alpha

    # Forward tracking: alpha is known to be acceptable; try larger steps.
    for _ in range(max_iter):
        trial = alpha / theta
        if armijo_fails(trial):
            break
        alpha = trial
    return alpha

    

In [None]:
# algorithm 2: scaled gradient descent with line search

def scaled_GD_norm_squared(model, x0, sigma, rho, theta_bt, theta_fb, MAX_ITER, eps):
    """Scaled gradient descent with line search (Algorithm 2) on f(x) = ||x||^2.

    The gradient is hard-coded as ``2 * x`` (the gradient of ``||x||^2``). The line
    searches therefore assume ``model`` evaluates that same function. The curvature
    test in ``scaling_selection`` uses the identity matrix as the Hessian
    approximation.

    Args:
        model: callable returning the scalar objective at a point (expected ||x||^2).
        x0: 1-D starting point; it is cloned, not modified.
        sigma: curvature threshold for ``scaling_selection``, ``0 < sigma << 1``.
        rho: Armijo constant, ``0 < rho < 1/2``.
        theta_bt: shrink factor for backtracking on SPC/LPC steps, ``0 < theta_bt < 1``.
        theta_fb: factor for forward/backward tracking on NC steps, ``0 < theta_fb < 1``.
        MAX_ITER: maximum number of iterations.
        eps: stop as soon as ``||g|| < eps``.

    Returns:
        ``(x, flag_distribution)``: the final iterate, and how many steps took
        each of the "SPC"/"LPC"/"NC" branches.
    """
    x_k = x0.clone()
    flag_distribution = {"SPC": 0, "LPC": 0, "NC": 0}

    # The Hessian approximation is constant, so build it once, not every iteration.
    H = torch.eye(len(x_k), dtype=x_k.dtype, device=x_k.device)

    for _ in range(MAX_ITER):
        g_k = 2 * x_k  # exact gradient of ||x||^2

        if torch.norm(g_k) < eps:
            break

        p_k, flag = scaling_selection(g_k, H, sigma)
        flag_distribution[flag] += 1

        # Positive-curvature steps only need shrinking; NC steps may also expand.
        if flag in ("SPC", "LPC"):
            alpha_k = backtracking_LS(model, theta_bt, rho, x_k, g_k, p_k)
        else:
            alpha_k = forward_backward_LS(model, theta_fb, rho, x_k, g_k, p_k)

        x_k = x_k + alpha_k * p_k

    return x_k, flag_distribution


In [27]:
class NormSquaredModel(nn.Module):
    """Parameter-free module computing the objective f(x) = ||x||_2^2."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the squared Euclidean (Frobenius for matrices) norm of ``x`` as a 0-d tensor."""
        norm = torch.norm(x)
        return norm.pow(2)

In [37]:
# Sanity check: minimize f(x) = ||x||^2, whose minimizer is x* = 0.
model = NormSquaredModel()

# x_0 and the random step-size choices in scaling_selection both draw from
# torch's global RNG; seed it so the run is reproducible.
SEED = 0
torch.manual_seed(SEED)

sigma = 0.1
theta_bt = 0.5
theta_fb = 0.5
rho = 0.25

MAX_ITER = 1000
eps = 1e-12

x_0 = 10 * torch.randn(10)

x_star, flag_distribution = scaled_GD_norm_squared(model, x_0, sigma, rho, theta_bt, theta_fb, MAX_ITER, eps)

x_star, flag_distribution

(tensor([-5.4210e-20,  5.4210e-20, -1.3553e-20,  5.4210e-20,  2.7105e-20,
          2.7105e-20,  3.3881e-21, -2.7105e-20,  6.7763e-21, -5.4210e-20]),
 {'SPC': 3, 'LPC': 0, 'NC': 0})

In [None]:
# TODO: define a classification/regression task and implement scaled GD as a subclass of torch.optim.Optimizer
# https://www.geeksforgeeks.org/custom-optimizers-in-pytorch/

class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the hidden layer.
        output_dim: number of outputs.
    """

    def __init__(self, input_dim=1, hidden_dim=64, output_dim=1):
        super().__init__()
        # Keep fc1 created before fc2: that fixes the order in which the default
        # initializers draw from the RNG. The attribute names are the state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_dim)`` to ``(..., output_dim)``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
class Scaled_GD_Optimizer(torch.optim.Optimizer):
    """Placeholder for a ``torch.optim.Optimizer`` that performs scaled gradient descent.

    TODO: not implemented yet. Both construction and ``step`` raise
    ``NotImplementedError``.
    """

    def __init__(self):
        raise NotImplementedError

    def step(self):
        raise NotImplementedError