In [2]:
import torchvision
from torch.optim import Optimizer
import torch
import numpy as np

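The goal of the L-BFGS two-loop recursion is to use the $m$ most recently stored pairs $(s_i, y_i)$, where $s_i = x_{i+1} - x_i$ and $y_i = \nabla f(x_{i+1}) - \nabla f(x_i)$, to compute the product $H_k \nabla f(x_k)$. Here $H_k$ is the approximation of the *inverse* Hessian, and the product is computed without ever forming $H_k$ explicitly. This lets us update $x_k$ with $O(mn)$ memory and computation, where $n$ is the number of parameters, instead of the $O(n^2)$ required by full BFGS. Inside the algorithm, every operation is a vector dot product or a scaled vector addition.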

In [5]:
class COptimizer(Optimizer):
    """``torch.optim.Optimizer`` base class that can build a functional loss ``f(*params)``.

    ``set_f`` is a no-op unless ``self.has_f`` is True.  ``__init__`` sets it to
    False *before* delegating to ``Optimizer.__init__``, so a subclass that wants
    ``self.f`` must set ``has_f = True`` after calling ``super().__init__``.
    """

    def __init__(self, *args, **kwargs):
        self.f = None        # functional loss built by set_f (None until then)
        self.has_f = False   # opt-in flag checked by set_f
        # Zero-argument super() binds to this class object rather than to the
        # global name, so subclasses keep working if this cell is re-run.
        super().__init__(*args, **kwargs)

    def set_f(self, model, data, target, criterion):
        """Store ``self.f(*params) -> criterion(model(data; params), target)``.

        The returned closure evaluates ``model`` on ``data`` with its parameters
        replaced by ``params`` (via ``torch.func.functional_call``), matched
        positionally to ``model.named_parameters()`` as captured now, and
        applies ``criterion`` against ``target``.

        Args:
            model: ``torch.nn.Module`` whose parameters ``f`` substitutes.
            data: input passed to ``model``.
            target: second argument to ``criterion``.
            criterion: callable ``criterion(output, target) -> loss``.

        Does nothing when ``self.has_f`` is False.
        """
        if not self.has_f:
            return
        names = [name for name, _ in model.named_parameters()]

        def f(*params):
            out = torch.func.functional_call(model, dict(zip(names, params)), data)
            return criterion(out, target)

        self.f = f

### Implementation using the built-in `torch.optim.LBFGS`

In [3]:
class LBFGS(torch.optim.LBFGS):
    """Thin wrapper around ``torch.optim.LBFGS`` adding a no-op ``set_f``.

    ``set_f`` matches the ``set_f(model, data, target, criterion)`` signature of
    ``COptimizer``, presumably so callers can treat both kinds of optimizer
    alike.  The constructor is inherited unchanged from ``torch.optim.LBFGS``.

    NOTE(review): a later cell also defines a class named ``LBFGS``; whichever
    cell runs last binds the name.
    """

    def step(self, closure=None):
        """Run one ``torch.optim.LBFGS`` step and return the starting loss.

        ``closure`` must re-evaluate the model, call ``backward()`` and return
        the loss.  The parent implementation calls it unconditionally, so
        leaving it as ``None`` raises a ``TypeError``.
        """
        # Zero-argument super() binds to this class object, so existing
        # instances keep working after another cell rebinds the name ``LBFGS``.
        return super().step(closure)

    def set_f(self, model, data, target, criterion):
        """No-op: the built-in optimizer only needs ``closure``, not a functional loss."""
        return None

### Implementation from scratch
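**TODO:** adapt this implementation to the reference implementation in [hjmshi/PyTorch-LBFGS](https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py). Note that the cell below redefines the name `LBFGS`, shadowing the built-in wrapper above.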

In [None]:
class LBFGS(COptimizer):
    """Limited-memory BFGS optimizer written from scratch.

    All parameters of the single parameter group are treated as one flat
    vector x.  The optimizer keeps the ``history_size`` most recent curvature
    pairs s_i = x_{i+1} - x_i and y_i = g_{i+1} - g_i (oldest first).  It uses
    the two-loop recursion to compute H_k g_k, where H_k is the inverse-Hessian
    approximation, without forming any matrix.  The search direction is
    -H_k g_k.

    NOTE(review): this cell redefines ``LBFGS`` and therefore shadows the
    ``torch.optim.LBFGS`` wrapper defined earlier; whichever cell runs last
    binds the name.

    Args:
        params: iterable of parameters to optimise (one parameter group only).
        lr: initial trial step length; used unchanged when ``line_search='None'``.
        f: stored as ``self.f`` and in the group defaults; not called by this class.
        history_size: number of (s, y) pairs kept (0 keeps none).
        line_search: ``'None'``, ``'Armijo'`` or ``'Wolfe'`` (see :meth:`line_search`).
        T: stored as ``self.T`` and in the group defaults; not used by ``step``.
        max_ls: maximum number of closure evaluations per line search.
        c1, c2: sufficient-decrease and curvature constants, 0 < c1 < c2 < 1.

    Raises:
        ValueError: on invalid hyper-parameters or more than one parameter group.
    """

    _LINE_SEARCHES = ('None', 'Armijo', 'Wolfe')
    _CURVATURE_EPS = 1e-10  # pairs with y^T s at or below this are rejected

    def __init__(self, params, lr=1e-1, f=lambda x: x, history_size=10, line_search='Wolfe', T=100,
                 max_ls=20, c1=1e-4, c2=0.9):
        if lr <= 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if history_size < 0:
            raise ValueError(f"Invalid history size: {history_size}")
        if max_ls < 1:
            raise ValueError(f"Invalid max_ls: {max_ls}")
        if not 0 < c1 < c2 < 1:
            raise ValueError(f"Line-search constants must satisfy 0 < c1 < c2 < 1, got c1={c1}, c2={c2}")
        defaults = dict(lr=lr, f=f, T=T, history_size=history_size, line_search=line_search,
                        max_ls=max_ls, c1=c1, c2=c2)
        super().__init__(params, defaults)
        if len(self.param_groups) != 1:
            raise ValueError("LBFGS does not support per-parameter options (parameter groups)")

        # COptimizer.__init__ resets ``f``/``has_f``, so assign them after it.
        self.has_f = False
        self.f = f
        self.lr = lr
        self.T = T
        self.count = 0
        self.line_search(line_search)  # validates and stores the option

        state = self.state['global_state']
        state.setdefault('n_iter', 0)
        state.setdefault('curv_skips', 0)  # pairs rejected because y^T s was not safely positive
        state.setdefault('fail_skips', 0)  # line searches that exhausted max_ls
        state.setdefault('H_diag', 1)      # scalar initial inverse-Hessian H_k^0 = H_diag * I
        # ``fail`` means "do not form a curvature pair from the previous step";
        # it starts True because there is no previous step yet.
        state.setdefault('fail', True)

        state['old_dirs'] = []  # s_i = x_{i+1} - x_i, oldest first
        state['old_stps'] = []  # y_i = g_{i+1} - g_i, oldest first

    def line_search(self, line_search):
        """
        Switches line search option.

        Inputs:
            line_search (str): designates line search to use
                Options:
                    'None': uses steplength designated in algorithm (``lr``)
                    'Armijo': uses Armijo backtracking line search
                    'Wolfe': uses weak Armijo-Wolfe bracketing line search

        Raises:
            ValueError: if ``line_search`` is not one of the options above.
        """
        if line_search not in self._LINE_SEARCHES:
            raise ValueError(f"line_search must be one of {self._LINE_SEARCHES}, got {line_search!r}")
        self.param_groups[0]['line_search'] = line_search

    def _gather_flat_grad(self):
        """Concatenate all parameter gradients into one new 1-D tensor (zeros where ``grad`` is None)."""
        views = []
        for p in self.param_groups[0]['params']:
            views.append(p.new_zeros(p.numel()) if p.grad is None else p.grad.reshape(-1))
        return torch.cat(views)

    def _set_params(self, origin, t, direction):
        """Overwrite the parameters with ``origin + t * direction`` (``direction`` is flat)."""
        offset = 0
        for p, p0 in zip(self.param_groups[0]['params'], origin):
            numel = p.numel()
            p.copy_(p0 + t * direction[offset:offset + numel].view_as(p))
            offset += numel

    def _update_curvature(self, flat_grad):
        """Append the previous step's (s, y) pair if y^T s is safely positive."""
        state = self.state['global_state']
        if state['fail'] or 'prev_flat_grad' not in state:
            return
        s = state['prev_step']
        y = flat_grad - state['prev_flat_grad']
        ys = y.dot(s)
        if ys <= self._CURVATURE_EPS:
            # Storing this pair would break positive-definiteness of H_k.
            state['curv_skips'] += 1
            return
        state['old_dirs'].append(s)
        state['old_stps'].append(y)
        while len(state['old_dirs']) > self.param_groups[0]['history_size']:
            state['old_dirs'].pop(0)
            state['old_stps'].pop(0)
        state['H_diag'] = ys / y.dot(y)

    def _two_loop_recursion(self, grad):
        """Return a new tensor equal to H_k @ grad using the stored (s, y) pairs."""
        state = self.state['global_state']
        s_list, y_list = state['old_dirs'], state['old_stps']
        rho = [1.0 / y.dot(s) for s, y in zip(s_list, y_list)]
        alpha = [None] * len(s_list)
        q = grad.clone()
        for i in reversed(range(len(s_list))):  # newest -> oldest
            alpha[i] = rho[i] * s_list[i].dot(q)
            q.sub_(alpha[i] * y_list[i])
        r = q.mul_(state['H_diag'])  # H_k^0 is the scalar H_diag times identity
        for i in range(len(s_list)):  # oldest -> newest
            beta = rho[i] * y_list[i].dot(r)
            r.add_((alpha[i] - beta) * s_list[i])
        return r

    def _armijo_search(self, closure, x0, direction, f0, gtd0):
        """Backtracking (halving) search for sufficient decrease; returns ``(t, failed)``.

        The parameters are left at ``x0 + t * direction`` for the returned ``t``.
        """
        group = self.param_groups[0]
        t = group['lr']
        for k in range(group['max_ls']):
            if k:
                t *= 0.5
            self._set_params(x0, t, direction)
            if float(closure()) <= f0 + group['c1'] * t * gtd0:
                return t, False
        return t, True

    def _wolfe_search(self, closure, x0, direction, f0, gtd0):
        """Weak-Wolfe bracketing (bisection/doubling) search; returns ``(t, failed)``.

        The parameters are left at ``x0 + t * direction`` for the returned ``t``.
        """
        group = self.param_groups[0]
        c1, c2 = group['c1'], group['c2']
        lo, hi = 0.0, float('inf')
        t = group['lr']
        for k in range(group['max_ls']):
            if k:
                t = 0.5 * (lo + hi) if hi < float('inf') else 2.0 * lo
            self._set_params(x0, t, direction)
            f_t = float(closure())
            if f_t > f0 + c1 * t * gtd0:
                hi = t  # sufficient decrease violated: step too long
            elif float(self._gather_flat_grad().dot(direction)) < c2 * gtd0:
                lo = t  # curvature condition violated: step too short
            else:
                return t, False
        return t, True

    @torch.no_grad()
    def step(self, closure=None, **kwargs):
        """Perform one L-BFGS iteration.

        Args:
            closure (callable, optional): re-evaluates the model, calls
                ``backward()`` and returns the loss.  Required for the
                'Armijo' and 'Wolfe' line searches.  With 'None' and no
                closure, the gradients already stored in ``p.grad`` are used.
            **kwargs: accepted for backward compatibility and ignored.

        Returns:
            The loss returned by ``closure`` at the starting point, or None
            when no closure is given.

        Raises:
            ValueError: if a line search is selected but ``closure`` is None.

        If a line search exhausts ``max_ls`` evaluations, the parameters stay
        at the last trial point, ``fail_skips`` is incremented, and no
        curvature pair is formed from this step.
        """
        group = self.param_groups[0]
        state = self.state['global_state']
        method = group['line_search']
        if closure is not None:
            closure = torch.enable_grad()(closure)  # step itself runs under no_grad
        elif method != 'None':
            raise ValueError(f"line_search={method!r} requires a closure")

        loss = closure() if closure is not None else None
        flat_grad = self._gather_flat_grad()
        self._update_curvature(flat_grad)

        direction = self._two_loop_recursion(flat_grad).neg_()
        if flat_grad.dot(direction) >= 0:
            # Not a descent direction (numerical loss of positive-definiteness):
            # forget the history and fall back to steepest descent.
            state['old_dirs'].clear()
            state['old_stps'].clear()
            state['H_diag'] = 1
            direction = flat_grad.neg()
        gtd = float(flat_grad.dot(direction))

        x0 = [p.detach().clone() for p in group['params']]
        failed = False
        if method == 'None':
            t = group['lr']
            self._set_params(x0, t, direction)
        elif method == 'Armijo':
            t, failed = self._armijo_search(closure, x0, direction, float(loss), gtd)
        else:
            t, failed = self._wolfe_search(closure, x0, direction, float(loss), gtd)

        state['prev_flat_grad'] = flat_grad
        state['prev_step'] = direction.mul(t)
        state['fail'] = failed
        if failed:
            state['fail_skips'] += 1
        state['n_iter'] += 1
        return loss