<a href="https://colab.research.google.com/github/1316827294/-/blob/master/lbfgs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
os.chdir("/content/drive/MyDrive/lbfgs")

In [None]:
# !pip install pycutest

In [None]:
# !git clone https://github.com/hjmshi/PyTorch-LBFGS.git

In [None]:
# !kill -9 91

In [None]:
import os
os.chdir("/content/drive/MyDrive/lbfgs")
from torch import matmul

import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce
from copy import deepcopy
from torch.optim import Optimizer


def is_legal(v):
    """
    Checks that tensor is not NaN or Inf.

    Inputs:
        v (tensor): tensor to be checked (any shape, including 0-dim scalars)

    Output:
        legal (bool): True if no entry of v is NaN or +/-Inf

    """
    # Reduce both element-wise masks with .any() so tensors with more than one
    # element yield a single truth value instead of raising an ambiguity error.
    has_nan = bool(torch.isnan(v).any())
    has_inf = bool(torch.isinf(v).any())
    legal = not has_nan and not has_inf

    return legal


def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
    """
    Gives the minimizer of the interpolating polynomial over given points
    based on function and derivative information. Defaults to bisection if no critical
    points are valid.

    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
    modifications.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 12/6/18.

    Inputs:
        points (nparray): two-dimensional array with each point of form [x f g]
        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
        plot (bool): plot interpolating polynomial

    Outputs:
        x_sol (float): minimizer of interpolating polynomial (only the minimizer
            is returned; the minimum value is used internally and for plotting)

    Note:
      . Set f or g to np.nan if they are unknown

    """
    no_points = points.shape[0]
    # polynomial order = (number of known f/g values) - 1
    order = int(np.count_nonzero(~np.isnan(points[:, 1:3]))) - 1

    x_min = np.min(points[:, 0])
    x_max = np.max(points[:, 0])

    # compute bounds of interpolation area
    if x_min_bound is None:
        x_min_bound = x_min
    if x_max_bound is None:
        x_max_bound = x_max

    midpoint = (x_min_bound + x_max_bound) / 2  # bisection fallback

    # explicit formula for quadratic interpolation
    if no_points == 2 and order == 2 and not plot:
        # Solution to quadratic interpolation is given by:
        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
        # x_min = x1 - g1/(2a)
        # if x1 = 0, then is given by:
        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
        x1, f1, g1 = points[0, 0], points[0, 1], points[0, 2]
        x2, f2 = points[1, 0], points[1, 1]

        if x1 == 0:
            x_sol = -g1 * x2 ** 2 / (2 * (f2 - f1 - g1 * x2))
        else:
            a = -(f1 - f2 - g1 * (x1 - x2)) / (x1 - x2) ** 2
            x_sol = x1 - g1 / (2 * a)

        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)

        # a degenerate fit (0/0) yields NaN; bisect instead of propagating it
        if np.isnan(x_sol):
            x_sol = midpoint

    # explicit formula for cubic interpolation
    elif no_points == 2 and order == 3 and not plot:
        # Solution to cubic interpolation is given by:
        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
        # d2 = sqrt(d1^2 - g1*g2)
        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
        x1, f1, g1 = points[0, 0], points[0, 1], points[0, 2]
        x2, f2, g2 = points[1, 0], points[1, 1], points[1, 2]

        d1 = g1 + g2 - 3 * ((f1 - f2) / (x1 - x2))
        d2_square = d1 ** 2 - g1 * g2

        # np.sqrt of a negative real returns NaN (never a complex number), so
        # the sign of the discriminant has to be tested before taking the root.
        if d2_square >= 0:
            d2 = np.sqrt(d2_square)
            x_sol = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
            if np.isnan(x_sol):
                x_sol = midpoint
        else:
            x_sol = midpoint

    # solve linear system
    else:
        # define linear constraints A * coeff = b, with coeff ordered from the
        # highest power down to the constant term (np.polyval convention)
        rows = []
        rhs = []
        value_powers = np.arange(order, -1, -1)
        deriv_powers = np.arange(order, 0, -1)

        # add linear constraints on function values
        for i in range(no_points):
            if not np.isnan(points[i, 1]):
                rows.append(points[i, 0] ** value_powers)
                rhs.append(points[i, 1])

        # add linear constraints on gradient values
        for i in range(no_points):
            if not np.isnan(points[i, 2]):
                # derivative of x^k is k*x^(k-1); the constant term contributes 0
                rows.append(np.append(deriv_powers * points[i, 0] ** (deriv_powers - 1), 0.0))
                rhs.append(points[i, 2])

        A = np.array(rows, dtype=float).reshape(len(rows), order + 1)
        b = np.array(rhs, dtype=float)

        # check if system is solvable
        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
            x_sol = midpoint
            f_min = np.inf
        else:
            # solve linear system for interpolating polynomial
            coeff = np.linalg.solve(A, b)

            # coefficients of the derivative polynomial
            dcoeff = coeff[:-1] * deriv_powers

            # candidate minimizers: the bounds, the given points, and stationary points
            crit_pts = np.array([x_min_bound, x_max_bound])
            crit_pts = np.append(crit_pts, points[:, 0])

            if np.isfinite(dcoeff).all():
                roots = np.roots(dcoeff)
                crit_pts = np.append(crit_pts, roots)

            # test critical points
            f_min = np.inf
            x_sol = midpoint  # defaults to bisection
            for crit_pt in crit_pts:
                if not np.isreal(crit_pt):
                    continue
                # compare on the real part so complex-typed entries are ordered safely
                candidate = np.real(crit_pt)
                if x_min_bound <= candidate <= x_max_bound:
                    F_cp = np.real(np.polyval(coeff, candidate))
                    if F_cp < f_min:
                        x_sol = candidate
                        f_min = F_cp

            if (plot):
                plt.figure()
                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound) / 10000)
                f = np.polyval(coeff, x)
                plt.plot(x, f)
                plt.plot(x_sol, f_min, 'x')

    return x_sol


class LBFGS2(Optimizer):
    """
    Implements the L-BFGS algorithm. Compatible with multi-batch and full-overlap
    L-BFGS implementations and (stochastic) Powell damping. Partly based on the
    original L-BFGS implementation in PyTorch, Mark Schmidt's minFunc MATLAB code,
    and Michael Overton's weak Wolfe line search MATLAB code.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 10/20/20.

    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.

    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode

    Notes on naming:
      . state['old_dirs'] holds the iterate differences s and state['old_stps']
        holds the (possibly damped) gradient differences y; see curvature_update.

    References:
    [1] Berahas, Albert S., Jorge Nocedal, and Martin Takác. "A Multi-Batch L-BFGS
        Method for Machine Learning." Advances in Neural Information Processing
        Systems. 2016.
    [2] Bollapragada, Raghu, et al. "A Progressive Batching L-BFGS Method for Machine
        Learning." International Conference on Machine Learning. 2018.
    [3] Lewis, Adrian S., and Michael L. Overton. "Nonsmooth Optimization via Quasi-Newton
        Methods." Mathematical Programming 141.1-2 (2013): 135-163.
    [4] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for
        Large Scale Optimization." Mathematical Programming 45.1-3 (1989): 503-528.
    [5] Nocedal, Jorge. "Updating Quasi-Newton Matrices With Limited Storage."
        Mathematics of Computation 35.151 (1980): 773-782.
    [6] Nocedal, Jorge, and Stephen J. Wright. "Numerical Optimization." Springer New York,
        2006.
    [7] Schmidt, Mark. "minFunc: Unconstrained Differentiable Multivariate Optimization
        in Matlab." Software available at http://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
        (2005).
    [8] Schraudolph, Nicol N., Jin Yu, and Simon Günter. "A Stochastic Quasi-Newton
        Method for Online Convex Optimization." Artificial Intelligence and Statistics.
        2007.
    [9] Wang, Xiao, et al. "Stochastic Quasi-Newton Methods for Nonconvex Stochastic
        Optimization." SIAM Journal on Optimization 27.2 (2017): 927-956.

    """

    def __init__(self, params, lr=1., history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):

        # ensure inputs are valid
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0 <= history_size:
            raise ValueError("Invalid history size: {}".format(history_size))
        if line_search not in ['Armijo', 'Wolfe', 'None']:
            raise ValueError("Invalid line search: {}".format(line_search))

        defaults = dict(lr=lr, history_size=history_size, line_search=line_search, dtype=dtype, debug=debug)
        super(LBFGS2, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("L-BFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        state = self.state['global_state']
        state.setdefault('n_iter', 0)
        state.setdefault('curv_skips', 0)
        state.setdefault('fail_skips', 0)
        state.setdefault('H_diag', 1)
        # start in the "failed" state so no curvature pair is formed before a step exists
        state.setdefault('fail', True)

        state['old_dirs'] = []
        state['old_stps'] = []

    def _numel(self):
        """Returns the total number of scalar parameters (cached after first call)."""
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenates all parameter gradients into one 1-D tensor (zeros where grad is None)."""
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_update(self, step_size, update):
        """Adds step_size * update (a flat 1-D tensor) to the parameters in place."""
        offset = 0
        for p in self._params:
            numel = p.numel()
            # keyword alpha replaces the deprecated add_(scalar, tensor) overload
            p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _copy_params(self):
        """Returns independent copies of the current parameter tensors."""
        return [param.data.clone() for param in self._params]

    def _load_params(self, current_params):
        """Restores parameters from a list produced by _copy_params."""
        for param, saved in zip(self._params, current_params):
            # copy_ also works for 0-dim parameters, where slice assignment fails
            param.data.copy_(saved)

    def line_search(self, line_search):
        """
        Switches line search option.

        Inputs:
            line_search (str): designates line search to use
                Options:
                    'None': uses steplength designated in algorithm
                    'Armijo': uses Armijo backtracking line search
                    'Wolfe': uses Armijo-Wolfe bracketing line search

        Note:
          . The value is stored as given; any value other than 'Armijo' or
            'Wolfe' makes the step use the fixed steplength.

        """

        group = self.param_groups[0]
        group['line_search'] = line_search

        return

    def two_loop_recursion(self, vec):
        """
        Performs two-loop recursion on given vector to obtain Hv.

        Inputs:
            vec (tensor): 1-D tensor to apply two-loop recursion to
                (left unmodified)

        Output:
            r (tensor): matrix-vector product Hv

        """

        group = self.param_groups[0]
        history_size = group['history_size']

        state = self.state['global_state']
        old_dirs = state.get('old_dirs')  # change in iterates (s)
        old_stps = state.get('old_stps')  # change in gradients (y)
        H_diag = state.get('H_diag')

        # compute the product of the inverse Hessian approximation and the gradient
        num_old = len(old_dirs)

        if 'rho' not in state:
            state['rho'] = [None] * history_size
            state['alpha'] = [None] * history_size
        rho = state['rho']
        alpha = state['alpha']

        for i in range(num_old):
            rho[i] = 1. / old_stps[i].dot(old_dirs[i])

        # work on a copy so the caller's vector is not overwritten by the in-place updates
        q = vec.clone()
        for i in range(num_old - 1, -1, -1):
            alpha[i] = old_dirs[i].dot(q) * rho[i]
            q.add_(old_stps[i], alpha=-alpha[i])

        # multiply by initial Hessian
        # r/d is the final direction
        r = torch.mul(q, H_diag)
        for i in range(num_old):
            beta = old_stps[i].dot(r) * rho[i]
            r.add_(old_dirs[i], alpha=alpha[i] - beta)

        return r

    def curvature_update(self, flat_grad, eps=1e-2, damping=False):
        """
        Performs curvature update.

        Inputs:
            flat_grad (tensor): 1-D tensor of flattened gradient for computing
                gradient difference with previously stored gradient
            eps (float): constant for curvature pair rejection or damping (default: 1e-2)
            damping (bool): flag for using Powell damping (default: False)

        Raises:
            ValueError: if eps is not positive

        Note:
          . Nothing is stored if the previous step's line search failed; the
            skip is counted in state['fail_skips'] (or state['curv_skips'] when
            the curvature condition ys > eps * sBs is not met without damping).
        """

        assert len(self.param_groups) == 1

        # load parameters
        if eps <= 0:
            raise ValueError('Invalid eps; must be positive.')

        group = self.param_groups[0]
        history_size = group['history_size']
        debug = group['debug']

        # variables cached in state (for tracing)
        state = self.state['global_state']
        fail = state.get('fail')

        # check if line search failed
        if fail:
            # save skip
            state['fail_skips'] += 1
            if debug:
                print('Line search failed; curvature pair update skipped')
            return

        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        prev_flat_grad = state.get('prev_flat_grad')
        Bs = state.get('Bs')

        # compute y's
        y = flat_grad.sub(prev_flat_grad)
        s = d.mul(t)
        sBs = s.dot(Bs)
        ys = y.dot(s)  # y*s

        # update L-BFGS matrix
        if ys > eps * sBs or damping:

            # perform Powell damping
            if damping and ys < eps * sBs:
                if debug:
                    print('Applying Powell damping...')
                theta = ((1 - eps) * sBs) / (sBs - ys)
                y = theta * y + (1 - theta) * Bs

            # updating memory (nothing can be stored when history_size is 0)
            if history_size > 0:
                if len(old_dirs) >= history_size:
                    # shift history by one (limited-memory)
                    old_dirs.pop(0)
                    old_stps.pop(0)

                # store new direction/step
                old_dirs.append(s)
                old_stps.append(y)

            # update scale of initial Hessian approximation
            H_diag = ys / y.dot(y)  # (y*y)

            state['old_dirs'] = old_dirs
            state['old_stps'] = old_stps
            state['H_diag'] = H_diag

        else:
            # save skip
            state['curv_skips'] += 1
            if debug:
                print('Curvature pair skipped due to failed criterion')

        return

    @staticmethod
    def _get_option(options, key, default, is_invalid=None, message=None):
        """Returns options[key] (default when absent); raises ValueError(message) if is_invalid(value)."""
        if key not in options:
            return default
        value = options[key]
        if is_invalid is not None and is_invalid(value):
            raise ValueError(message)
        return value

    @staticmethod
    def _nan_scalar(dtype, device):
        """Returns a 0-dim NaN tensor used as the 'unknown value' marker for interpolation."""
        return torch.tensor(float('nan'), dtype=dtype, device=device)

    def _armijo_search(self, d, t, g_Sk, options):
        """Runs Armijo backtracking along d from steplength t; returns (public result tuple, t, fail)."""
        group = self.param_groups[0]
        dtype = group['dtype']
        debug = group['debug']

        # closure evaluation counter
        closure_eval = 0

        # load options
        if not options:
            raise ValueError('Options are not specified; need closure evaluating function.')
        if 'closure' not in options:
            raise ValueError('closure option not specified.')
        closure = options['closure']

        gtd = options['gtd'] if 'gtd' in options else g_Sk.dot(d)

        if 'current_loss' in options:
            F_k = options['current_loss']
        else:
            F_k = closure()
            closure_eval += 1

        eta = self._get_option(options, 'eta', 2, lambda v: v <= 0,
                               'Invalid eta; must be positive.')
        c1 = self._get_option(options, 'c1', 1e-4, lambda v: v >= 1 or v <= 0,
                              'Invalid c1; must be strictly between 0 and 1.')
        max_ls = self._get_option(options, 'max_ls', 10, lambda v: v <= 0,
                                  'Invalid max_ls; must be positive.')
        interpolate = self._get_option(options, 'interpolate', True)
        inplace = self._get_option(options, 'inplace', True)
        ls_debug = self._get_option(options, 'ls_debug', False)

        # initialize values; NaN marks "no previous trial point yet"
        F_prev = self._nan_scalar(dtype, d.device)

        ls_step = 0
        t_prev = 0  # old steplength
        fail = False  # failure flag

        # begin print for debug mode
        if ls_debug:
            print(
                '==================================== Begin Armijo line search ===================================')
            print('F(x): %.8e  g*d: %.8e' % (F_k, gtd))

        # check if search direction is descent direction
        if gtd >= 0:
            desc_dir = False
            if debug:
                print('Not a descent direction!')
        else:
            desc_dir = True

        # store values if not in-place
        if not inplace:
            current_params = self._copy_params()

        # update and evaluate at new point
        self._add_update(t, d)
        F_new = closure()
        closure_eval += 1

        # print info if debugging
        if ls_debug:
            print('LS Step: %d  t: %.8e  F(x+td): %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                  % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k))

        # check Armijo condition
        while F_new > F_k + c1 * t * gtd or not is_legal(F_new):

            # check if maximum number of iterations reached
            if ls_step >= max_ls:
                # undo the trial step and report failure with a zero steplength
                if inplace:
                    self._add_update(-t, d)
                else:
                    self._load_params(current_params)

                t = 0
                F_new = closure()
                closure_eval += 1
                fail = True
                break

            # store current steplength
            t_new = t

            # compute new steplength

            # if first step or not interpolating, then multiply by factor
            if ls_step == 0 or not interpolate or not is_legal(F_new):
                t = t / eta

            # if second step, use function value at new point along with
            # gradient and function at current iterate
            elif ls_step == 1 or not is_legal(F_prev):
                t = polyinterp(np.array([[0, float(F_k), float(gtd)],
                                         [t_new, float(F_new), np.nan]]))

            # otherwise, use function values at new point, previous point,
            # and gradient and function at current iterate
            else:
                t = polyinterp(np.array([[0, float(F_k), float(gtd)],
                                         [t_new, float(F_new), np.nan],
                                         [t_prev, float(F_prev), np.nan]]))

            # if values are too extreme, adjust t
            if interpolate:
                if t < 1e-3 * t_new:
                    t = 1e-3 * t_new
                elif t > 0.6 * t_new:
                    t = 0.6 * t_new

                # store old point
                F_prev = F_new
                t_prev = t_new

            # update iterate and reevaluate
            if inplace:
                self._add_update(t - t_new, d)
            else:
                self._load_params(current_params)
                self._add_update(t, d)

            F_new = closure()
            closure_eval += 1
            ls_step += 1  # iterate

            # print info if debugging
            if ls_debug:
                print('LS Step: %d  t: %.8e  F(x+td):   %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                      % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k))

        # print final steplength
        if ls_debug:
            print('Final Steplength:', t)
            print(
                '===================================== End Armijo line search ====================================')

        return (F_new, t, ls_step, closure_eval, desc_dir, fail), t, fail

    def _wolfe_search(self, d, t, g_Sk, options):
        """Runs the weak Wolfe bracketing search along d from steplength t; returns (public result tuple, t, fail)."""
        group = self.param_groups[0]
        dtype = group['dtype']
        debug = group['debug']

        # closure evaluation counter
        closure_eval = 0

        # load options
        if not options:
            raise ValueError('Options are not specified; need closure evaluating function.')
        if 'closure' not in options:
            raise ValueError('closure option not specified.')
        closure = options['closure']

        if 'current_loss' in options:
            F_k = options['current_loss']
        else:
            F_k = closure()
            closure_eval += 1

        gtd = options['gtd'] if 'gtd' in options else g_Sk.dot(d)

        eta = self._get_option(options, 'eta', 2, lambda v: v <= 1,
                               'Invalid eta; must be greater than 1.')
        c1 = self._get_option(options, 'c1', 1e-4, lambda v: v >= 1 or v <= 0,
                              'Invalid c1; must be strictly between 0 and 1.')
        c2 = self._get_option(options, 'c2', 0.9, lambda v: v >= 1 or v <= 0,
                              'Invalid c2; must be strictly between 0 and 1.')
        if 'c2' in options and c2 <= c1:
            raise ValueError('Invalid c2; must be strictly larger than c1.')
        max_ls = self._get_option(options, 'max_ls', 10, lambda v: v <= 0,
                                  'Invalid max_ls; must be positive.')
        interpolate = self._get_option(options, 'interpolate', True)
        inplace = self._get_option(options, 'inplace', True)
        ls_debug = self._get_option(options, 'ls_debug', False)

        # initialize counters
        ls_step = 0
        grad_eval = 0  # tracks gradient evaluations
        t_prev = 0  # old steplength

        # initialize bracketing variables and flag
        alpha = 0
        beta = float('Inf')
        fail = False

        # initialize values for line search; NaN marks "upper end not known yet"
        F_a = F_k
        g_a = gtd
        F_b = self._nan_scalar(dtype, d.device)
        g_b = self._nan_scalar(dtype, d.device)

        # begin print for debug mode
        if ls_debug:
            print(
                '==================================== Begin Wolfe line search ====================================')
            print('F(x): %.8e  g*d: %.8e' % (F_k, gtd))

        # check if search direction is descent direction
        if gtd >= 0:
            desc_dir = False
            if debug:
                print('Not a descent direction!')
        else:
            desc_dir = True

        # store values if not in-place
        if not inplace:
            current_params = self._copy_params()

        # update and evaluate at new point
        self._add_update(t, d)
        F_new = closure()
        closure_eval += 1

        # main loop
        while True:

            # check if maximum number of line search steps have been reached
            if ls_step >= max_ls:
                # undo the trial step and report failure with a zero steplength
                if inplace:
                    self._add_update(-t, d)
                else:
                    self._load_params(current_params)

                t = 0
                F_new = closure()
                F_new.backward()
                g_new = self._gather_flat_grad()
                closure_eval += 1
                grad_eval += 1
                fail = True
                break

            # print info if debugging
            if ls_debug:
                print('LS Step: %d  t: %.8e  alpha: %.8e  beta: %.8e'
                      % (ls_step, t, alpha, beta))
                print('Armijo:  F(x+td): %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                      % (F_new, F_k + c1 * t * gtd, F_k))

            # check Armijo condition
            if F_new > F_k + c1 * t * gtd:

                # set upper bound
                beta = t
                t_prev = t

                # update interpolation quantities (gradient at beta is unknown)
                F_b = F_new
                g_b = self._nan_scalar(dtype, d.device)

            else:

                # compute gradient
                F_new.backward()
                g_new = self._gather_flat_grad()

                grad_eval += 1
                gtd_new = g_new.dot(d)

                # print info if debugging
                if ls_debug:
                    print('Wolfe: g(x+td)*d: %.8e  c2*g*d: %.8e  gtd: %.8e'
                          % (gtd_new, c2 * gtd, gtd))

                # check curvature condition
                if gtd_new < c2 * gtd:

                    # set lower bound
                    alpha = t
                    t_prev = t

                    # update interpolation quantities
                    F_a = F_new
                    g_a = gtd_new

                else:
                    # both weak Wolfe conditions hold
                    break

            # compute new steplength

            # if first step or not interpolating, then bisect or multiply by factor
            if not interpolate or not is_legal(F_b):
                if beta == float('Inf'):
                    t = eta * t
                else:
                    t = (alpha + beta) / 2.0

            # otherwise interpolate between a and b
            else:
                t = polyinterp(np.array([[alpha, float(F_a), float(g_a)],
                                         [beta, float(F_b), float(g_b)]]))

                # if values are too extreme, adjust t
                if beta == float('Inf'):
                    if t > 2 * eta * t_prev:
                        t = 2 * eta * t_prev
                    elif t < eta * t_prev:
                        t = eta * t_prev
                else:
                    # keep the trial point inside the bracket, between 20% of
                    # the way in and the midpoint; the midpoint is
                    # (alpha + beta) / 2, since (beta - alpha) / 2 can lie below alpha
                    lower = alpha + 0.2 * (beta - alpha)
                    upper = (alpha + beta) / 2.0
                    if t < lower:
                        t = lower
                    elif t > upper:
                        t = upper

                # if we obtain nonsensical value from interpolation
                if not np.isfinite(t) or t <= 0:
                    if beta == float('Inf'):
                        t = eta * t_prev
                    else:
                        t = (alpha + beta) / 2.0

            # update parameters
            if inplace:
                self._add_update(t - t_prev, d)
            else:
                self._load_params(current_params)
                self._add_update(t, d)

            # evaluate closure
            F_new = closure()
            closure_eval += 1
            ls_step += 1

        # print final steplength
        if ls_debug:
            print('Final Steplength:', t)
            print(
                '===================================== End Wolfe line search =====================================')

        return (F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail), t, fail

    def _step(self, p_k, g_Ok, g_Sk=None, options=None):
        """
        Takes one step along the search direction p_k using the configured line search.

        Inputs:
            p_k (tensor): 1-D tensor specifying search direction
            g_Ok (tensor): 1-D tensor of flattened gradient, stored for the next
                curvature update
            g_Sk (tensor): 1-D tensor of flattened gradient used for g*d and Bs
                (default: copy of g_Ok)
            options (dict): line search options; 'closure' is required for the
                'Armijo' and 'Wolfe' line searches. Optional keys: 'current_loss',
                'gtd', 'eta', 'c1', 'c2' (Wolfe only), 'max_ls', 'interpolate',
                'inplace', 'ls_debug'.

        Outputs (depends on line search):
            . 'None':   t
            . 'Armijo': F_new, t, ls_step, closure_eval, desc_dir, fail
            . 'Wolfe':  F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail

        Raises:
            ValueError: if a line search is selected and options are missing or invalid
        """

        if options is None:
            options = {}
        assert len(self.param_groups) == 1

        # load parameter options
        group = self.param_groups[0]
        lr = group['lr']
        line_search = group['line_search']

        # variables cached in state (for tracing)
        state = self.state['global_state']
        prev_flat_grad = state.get('prev_flat_grad')
        Bs = state.get('Bs')

        # keep track of nb of iterations
        state['n_iter'] += 1

        # set search direction
        d = p_k

        # modify previous gradient
        if prev_flat_grad is None:
            prev_flat_grad = g_Ok.clone()
        else:
            prev_flat_grad.copy_(g_Ok)

        # set initial step size
        t = lr

        if g_Sk is None:
            g_Sk = g_Ok.clone()

        if line_search == 'Armijo':
            result, t, fail = self._armijo_search(d, t, g_Sk, options)
        elif line_search == 'Wolfe':
            result, t, fail = self._wolfe_search(d, t, g_Sk, options)
        else:
            # perform update with the fixed steplength
            self._add_update(t, d)
            result, fail = t, False

        # store Bs
        if Bs is None:
            Bs = (g_Sk.mul(-t)).clone()
        else:
            Bs.copy_(g_Sk.mul(-t))

        state['d'] = d
        state['prev_flat_grad'] = prev_flat_grad
        state['t'] = t
        state['Bs'] = Bs
        state['fail'] = fail

        return result

    def step(self, p_k, g_Ok, g_Sk=None, options=None):
        """Public entry point; forwards to _step (see _step for inputs and outputs)."""
        return self._step(p_k, g_Ok, g_Sk, options)


class FullBatchLBFGS(LBFGS2):
    """
    Convenience wrapper around LBFGS2 whose ``step`` gathers the gradient,
    updates the curvature information and computes the search direction itself
    before delegating to ``_step``.

    Inputs:
        params: parameters to optimize, forwarded to LBFGS2
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
        dtype: data type (default: torch.float)
        debug (bool): debugging mode

    """

    def __init__(self, params, lr=1, history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):
        super(FullBatchLBFGS, self).__init__(params, lr, history_size, line_search,
                                             dtype, debug)

    def step(self, options=None):
        """
        Performs a single optimization step.

        Inputs:
            options (dict): options forwarded to ``_step``; two keys are also read here:
                'damping' (bool): passed to ``curvature_update`` (default: False)
                'eps' (float): passed to ``curvature_update`` (default: 1e-2)

        Outputs:
            whatever ``_step`` returns

        """
        # The signature advertises ``options=None``, so the default must not be
        # dereferenced before being replaced by an empty dict.
        if options is None:
            options = {}

        # load options for damping and eps
        damping = options.get('damping', False)
        eps = options.get('eps', 1e-2)

        # gather gradient
        grad = self._gather_flat_grad()

        # update curvature if after 1st iteration
        state = self.state['global_state']
        if state['n_iter'] > 0:
            self.curvature_update(grad, eps, damping)

        # compute search direction
        p = self.two_loop_recursion(-grad)

        # take step
        return self._step(p, grad, options=options)

In [None]:

import sys
sys.path.append('/content/drive/MyDrive/lbfgs/PyTorch-LBFGS/functions')

import numpy as np
import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from keras.datasets import cifar10 # to load dataset
from time import process_time

from utils import compute_stats, get_grad
# from LBFGS import LBFGS

# Hyper-parameters of the L-BFGS training run
max_iter, ghost_batch, batch_size = 100, 128, 8192

# Load CIFAR-10, convert the images to float32 scaled by 1/255 and reorder the
# axes to (0, 3, 1, 2) so the last axis of the loaded arrays comes second
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = (X_train.astype('float32') / 255).transpose(0, 3, 1, 2)
X_test = (X_test.astype('float32') / 255).transpose(0, 3, 1, 2)

# Define networks
class ConvNet(nn.Module):
    """
    Small convolutional classifier: two convolution + ReLU + 2x2 max-pooling
    stages followed by two fully connected layers with 10 outputs.

    The first fully connected layer expects 16 * 5 * 5 features per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=1000)
        self.fc2 = nn.Linear(in_features=1000, out_features=10)

    def forward(self, x):
        """Maps a batch of inputs ``x`` to the 10 raw output scores per sample."""
        features = x
        # both convolution stages share the same ReLU + pooling treatment
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        # flatten to rows of 16 * 5 * 5 values for the fully connected layers
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

def make_vgg_block(in_channel, out_channel, convs, pool=True):
    """
    Creates a VGG block: a stack of 3x3 convolutions, each followed by batch
    normalization and ReLU, optionally closed by 2x2 max pooling.

    Inputs:
        in_channel (int): number of input channels of the first convolution
        out_channel (int): number of output channels of every convolution
        convs (int): number of convolution layers (at least one is always created)
        pool (bool): append 2x2 max pooling, halving width and height (default: True)

    Outputs:
        block (nn.Sequential): the assembled layers

    """
    layers = []
    channels = in_channel

    # at least one convolution is always built, even for convs < 1
    for _ in range(max(convs, 1)):
        # kernel_size=3 with padding=1 keeps the image size unchanged
        layers.append(nn.Conv2d(channels, out_channel, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_channel))
        layers.append(nn.ReLU(inplace=True))
        # every convolution after the first maps out_channel -> out_channel
        channels = out_channel

    if pool:
        # 2*2 max pooling, image becomes w/2 * h/2
        layers.append(nn.MaxPool2d(2))

    return nn.Sequential(*layers)


# 定义网络模型
class VGG19Net(nn.Module):
    """
    VGG-style network built from five ``make_vgg_block`` stages followed by a
    three-layer fully connected classifier with 10 outputs.
    """

    def __init__(self):
        super().__init__()

        # Arguments of make_vgg_block for each stage, in order. For a 32*32 input
        # the four pooled stages output 16*16, 8*8, 4*4 and 2*2; the last stage
        # has no pooling layer, so the output stays 2*2.
        stage_args = (
            (3, 64, 2),
            (64, 128, 2),
            (128, 256, 4),
            (256, 512, 4),
            (512, 512, 4, False),
        )
        self.cnn = nn.Sequential(*[make_vgg_block(*args) for args in stage_args])

        # 512 features, each feature 2*2
        classifier_layers = [
            nn.Linear(512 * 2 * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        ]
        self.fc = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Runs the convolutional stages, flattens per sample and classifies."""
        feature_maps = self.cnn(x)

        # keep the batch dimension, flatten everything else
        batch = feature_maps.size(0)
        flattened = feature_maps.view(batch, -1)

        return self.fc(flattened)


# Detect whether CUDA can be used
cuda = torch.cuda.is_available()

# Seed the matching random number generator, then build the model
# (moved to the GPU when CUDA is available)
if cuda:
    torch.cuda.manual_seed(2018)
else:
    torch.manual_seed(2018)
model = VGG19Net()
if cuda:
    model = model.cuda()

# Helper functions

# opfun: forward pass of the model on a numpy batch
# predsfun: index of the largest output along axis 1, as a numpy array
if cuda:
    def opfun(X):
        return model.forward(torch.from_numpy(X).cuda())

    def predsfun(op):
        return np.argmax(op.cpu().data.numpy(), 1)
else:
    def opfun(X):
        return model.forward(torch.from_numpy(X))

    def predsfun(op):
        return np.argmax(op.data.numpy(), 1)


# accfun: percentage of predictions that equal the (squeezed) labels
def accfun(op, y):
    matches = np.equal(predsfun(op), y.squeeze())
    return np.mean(matches) * 100

# Define optimizer
# optimizer1 = LBFGS1(model.parameters(), lr=1., history_size=10, line_search='Wolfe', debug=True)
optimizer2 = LBFGS2(model.parameters(), lr=1., history_size=10, line_search='Wolfe', debug=True)

# per-iteration history of the reported statistics
train_loss_list, test_loss_list, test_acc_list, x_list = [], [], [], []

# process_time() returns the CPU time consumed by the whole process so far, so
# the value at the start of training is kept and subtracted at the end;
# printing the raw value would also count everything executed before this point.
start_time = process_time()

# Main training loop
for n_iter in range(max_iter):

    # training mode
    model.train()

    # sample batch: the first batch_size indices of a random permutation
    random_index = np.random.permutation(range(X_train.shape[0]))
    Sk = random_index[0:batch_size]

    # compute initial gradient and objective
    grad, obj = get_grad(optimizer2, X_train[Sk], y_train[Sk], opfun)

    # two-loop recursion to compute search direction
    p = optimizer2.two_loop_recursion(-grad)

    # define closure for line search: loss over the sampled batch Sk, accumulated
    # chunk by chunk so that only one chunk goes through the model at a time
    def closure():

        optimizer2.zero_grad()

        if cuda:
            loss_fn = torch.tensor(0, dtype=torch.float).cuda()
        else:
            loss_fn = torch.tensor(0, dtype=torch.float)

        for subsmpl in np.array_split(Sk, max(int(batch_size / ghost_batch), 1)):

            ops = opfun(X_train[subsmpl])

            if cuda:
                tgts = torch.from_numpy(y_train[subsmpl]).cuda().long().squeeze()
            else:
                tgts = torch.from_numpy(y_train[subsmpl]).long().squeeze()

            # weight each chunk by its share of the batch
            loss_fn += F.cross_entropy(ops, tgts) * (len(subsmpl) / batch_size)

        return loss_fn

    # perform line search step
    options = {'closure': closure, 'current_loss': obj}
    obj, grad, lr, _, _, _, _, _ = optimizer2.step(p, grad, options=options)

    # curvature update
    optimizer2.curvature_update(grad)

    # compute statistics
    model.eval()
    train_loss, test_loss, test_acc = compute_stats(X_train, y_train, X_test, y_test, opfun, accfun,
                                                    ghost_batch=128)

    train_loss_list.append(float(train_loss))
    test_loss_list.append(float(test_loss))
    test_acc_list.append(float(test_acc))
    x_list.append(n_iter + 1)

    # print data
    print('Iter:', n_iter + 1, 'lr:', lr, 'Training Loss:', train_loss, 'Test Loss:', test_loss,
          'Test Accuracy:', test_acc)

# report the CPU time spent in the training loop only
print("运行时间是: {:9.9}s".format(process_time() - start_time))
# Load the curves stored on disk as plain lists
# NOTE(review): presumably these files were saved by a separate run - confirm
train_loss_list_new = np.load('train_loss_list.npy').tolist()
test_loss_list_new = np.load('test_loss_list.npy').tolist()
a = np.load('test_acc_list.npy')  # `a` stays bound to the last array loaded
test_acc_list_new = a.tolist()

# Side-by-side comparison: accuracy in the left panel, loss in the right panel
for _panel, _new_curve, _new_label, _old_curve, _old_label, _title in (
        (1, test_acc_list_new, 'New test Accuracy', test_acc_list, 'Test Accuracy',
         'New and old test Accuracy'),
        (2, test_loss_list_new, 'New test Loss', test_loss_list, 'Test Loss',
         'New and old test Loss'),
):
    plt.subplot(1, 2, _panel)
    plt.plot(_new_curve, label=_new_label)
    plt.plot(_old_curve, label=_old_label)
    plt.title(_title)
    plt.legend()
plt.show()



# plt.plot(x_list, train_loss_list, color='red', marker='o', linestyle='dashed', linewidth=2, markersize=1)
# plt.plot(x_list, test_loss_list, color='green', marker='o', linestyle='dashed', linewidth=2, markersize=1)
# plt.plot(x_list, test_acc_list, color='blue', marker='o', linestyle='dashed', linewidth=2, markersize=1)
# plt.legend(labels=('Training Loss', 'Test Loss', 'Test Accuracy'))
# plt.savefig('./a2.jpg')
# plt.show()




Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1420.)
  p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))


OutOfMemoryError: ignored

In [None]:
# !ps -aux|grep python

In [None]:
%reset

In [None]:
from torch import matmul

import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce
from copy import deepcopy
from torch.optim import Optimizer


def is_legal(v):
    """
    Checks that tensor contains no NaN or Inf entries.

    Inputs:
        v (tensor): tensor to be checked

    Outputs:
        legal (bool): True if every entry of v is finite

    """
    # isfinite is False exactly for NaN and +/-Inf. Reducing with all() keeps the
    # check valid for tensors with more than one element, where converting the
    # element-wise isinf result to a bool directly raises an error.
    legal = bool(torch.isfinite(v).all())

    return legal


def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
    """
    Gives the minimizer of the interpolating polynomial over given points
    based on function and derivative information. Defaults to bisection if no critical
    points are valid.

    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
    modifications.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 12/6/18.

    Inputs:
        points (nparray): two-dimensional array with each point of form [x f g]
        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
        plot (bool): plot interpolating polynomial

    Outputs:
        x_sol (float): minimizer of interpolating polynomial

    Note:
      . Set f or g to np.nan if they are unknown
      . Only x_sol is returned. The minimum value of the polynomial is computed
        only when the linear system is solved, where it is used for the plot.

    """
    no_points = points.shape[0]
    # polynomial order = number of known function/derivative values minus one
    order = int(np.count_nonzero(~np.isnan(points[:, 1:3]))) - 1

    x_min = np.min(points[:, 0])
    x_max = np.max(points[:, 0])

    # compute bounds of interpolation area
    if x_min_bound is None:
        x_min_bound = x_min
    if x_max_bound is None:
        x_max_bound = x_max

    # explicit formula for quadratic interpolation
    if no_points == 2 and order == 2 and plot is False:
        # Solution to quadratic interpolation is given by:
        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
        # x_min = x1 - g1/(2a)
        # if x1 = 0, then is given by:
        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))

        if points[0, 0] == 0:
            x_sol = -points[0, 2] * points[1, 0] ** 2 / (
                        2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
        else:
            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (
                        points[0, 0] - points[1, 0]) ** 2
            x_sol = points[0, 0] - points[0, 2] / (2 * a)

        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)

    # explicit formula for cubic interpolation
    elif no_points == 2 and order == 3 and plot is False:
        # Solution to cubic interpolation is given by:
        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
        # d2 = sqrt(d1^2 - g1*g2)
        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
        discriminant = d1 ** 2 - points[0, 2] * points[1, 2]

        # np.sqrt of a negative real yields NaN rather than a complex number, and
        # np.isreal(NaN) is True, so the sign has to be tested before the root is
        # taken; otherwise the bisection fallback is never reached and NaN is returned.
        if discriminant >= 0:
            d2 = np.sqrt(discriminant)
            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * (
                        (points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
        else:
            x_sol = (x_max_bound + x_min_bound) / 2

    # solve linear system
    else:
        # define linear constraints
        A = np.zeros((0, order + 1))
        b = np.zeros((0, 1))

        # add linear constraints on function values
        for i in range(no_points):
            if not np.isnan(points[i, 1]):
                constraint = np.zeros((1, order + 1))
                for j in range(order, -1, -1):
                    constraint[0, order - j] = points[i, 0] ** j
                A = np.append(A, constraint, 0)
                b = np.append(b, points[i, 1])

        # add linear constraints on gradient values
        for i in range(no_points):
            if not np.isnan(points[i, 2]):
                constraint = np.zeros((1, order + 1))
                for j in range(order):
                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
                A = np.append(A, constraint, 0)
                b = np.append(b, points[i, 2])

        # check if system is solvable
        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
            x_sol = (x_min_bound + x_max_bound) / 2
            f_min = np.inf
        else:
            # solve linear system for interpolating polynomial
            coeff = np.linalg.solve(A, b)

            # compute critical points
            dcoeff = np.zeros(order)
            for i in range(len(coeff) - 1):
                dcoeff[i] = coeff[i] * (order - i)

            crit_pts = np.array([x_min_bound, x_max_bound])
            crit_pts = np.append(crit_pts, points[:, 0])

            # np.roots rejects both Inf and NaN coefficients
            if np.isfinite(dcoeff).all():
                roots = np.roots(dcoeff)
                crit_pts = np.append(crit_pts, roots)

            # test critical points
            f_min = np.inf
            x_sol = (x_min_bound + x_max_bound) / 2  # defaults to bisection
            for crit_pt in crit_pts:
                if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
                    F_cp = np.polyval(coeff, crit_pt)
                    if np.isreal(F_cp) and F_cp < f_min:
                        x_sol = np.real(crit_pt)
                        f_min = np.real(F_cp)

            if plot:
                plt.figure()
                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound) / 10000)
                f = np.polyval(coeff, x)
                plt.plot(x, f)
                plt.plot(x_sol, f_min, 'x')

    return x_sol


class LBFGS1(Optimizer):
    """
    Implements the L-BFGS algorithm. Compatible with multi-batch and full-overlap
    L-BFGS implementations and (stochastic) Powell damping. Partly based on the
    original L-BFGS implementation in PyTorch, Mark Schmidt's minFunc MATLAB code,
    and Michael Overton's weak Wolfe line search MATLAB code.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 10/20/20.

    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.

    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode

    References:
    [1] Berahas, Albert S., Jorge Nocedal, and Martin Takác. "A Multi-Batch L-BFGS
        Method for Machine Learning." Advances in Neural Information Processing
        Systems. 2016.
    [2] Bollapragada, Raghu, et al. "A Progressive Batching L-BFGS Method for Machine
        Learning." International Conference on Machine Learning. 2018.
    [3] Lewis, Adrian S., and Michael L. Overton. "Nonsmooth Optimization via Quasi-Newton
        Methods." Mathematical Programming 141.1-2 (2013): 135-163.
    [4] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for
        Large Scale Optimization." Mathematical Programming 45.1-3 (1989): 503-528.
    [5] Nocedal, Jorge. "Updating Quasi-Newton Matrices With Limited Storage."
        Mathematics of Computation 35.151 (1980): 773-782.
    [6] Nocedal, Jorge, and Stephen J. Wright. "Numerical Optimization." Springer New York,
        2006.
    [7] Schmidt, Mark. "minFunc: Unconstrained Differentiable Multivariate Optimization
        in Matlab." Software available at http://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
        (2005).
    [8] Schraudolph, Nicol N., Jin Yu, and Simon Günter. "A Stochastic Quasi-Newton
        Method for Online Convex Optimization." Artificial Intelligence and Statistics.
        2007.
    [9] Wang, Xiao, et al. "Stochastic Quasi-Newton Methods for Nonconvex Stochastic
        Optimization." SIAM Journal on Optimization 27.2 (2017): 927-956.

    """

    def __init__(self, params, lr=1., history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):

        # ensure inputs are valid
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0 <= history_size:
            raise ValueError("Invalid history size: {}".format(history_size))
        if line_search not in ['Armijo', 'Wolfe', 'None']:
            raise ValueError("Invalid line search: {}".format(line_search))

        defaults = dict(lr=lr, history_size=history_size, line_search=line_search, dtype=dtype, debug=debug)
        super(LBFGS1, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("L-BFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        state = self.state['global_state']
        state.setdefault('n_iter', 0)
        state.setdefault('curv_skips', 0)
        state.setdefault('fail_skips', 0)
        state.setdefault('H_diag', 1)
        state.setdefault('fail', True)

        state['old_dirs'] = []
        state['old_stps'] = []

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_update(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))
            offset += numel
        assert offset == self._numel()

    def _copy_params(self):
        current_params = []
        for param in self._params:
            current_params.append(deepcopy(param.data))
        return current_params

    def _load_params(self, current_params):
        i = 0
        for param in self._params:
            param.data[:] = current_params[i]
            i += 1

    def line_search(self, line_search):
        """
        Switches line search option.

        Inputs:
            line_search (str): designates line search to use
                Options:
                    'None': uses steplength designated in algorithm
                    'Armijo': uses Armijo backtracking line search
                    'Wolfe': uses Armijo-Wolfe bracketing line search

        """

        group = self.param_groups[0]
        group['line_search'] = line_search

        return

    def two_loop_recursion(self, vec):
        """
        Performs two-loop recursion on given vector to obtain Hv.

        Inputs:
            vec (tensor): 1-D tensor to apply two-loop recursion to

        Output:
            r (tensor): matrix-vector product Hv

        """

        group = self.param_groups[0]
        history_size = group['history_size']

        state = self.state['global_state']
        old_dirs = state.get('old_dirs')  # change in gradients
        old_stps = state.get('old_stps')  # change in iterates
        H_diag = state.get('H_diag')

        # compute the product of the inverse Hessian approximation and the gradient
        num_old = len(old_dirs)

        if 'rho' not in state:
            state['rho'] = [None] * history_size
            state['alpha'] = [None] * history_size
        rho = state['rho']
        alpha = state['alpha']

        for i in range(num_old):
            rho[i] = 1. / old_stps[i].dot(old_dirs[i])

        q = vec
        for i in range(num_old - 1, -1, -1):
            alpha[i] = old_dirs[i].dot(q) * rho[i]
            q.add_(-alpha[i], old_stps[i])

        # multiply by initial Hessian
        # r/d is the final direction
        r = torch.mul(q, H_diag)
        for i in range(num_old):
            beta = old_stps[i].dot(r) * rho[i]
            r.add_(alpha[i] - beta, old_dirs[i])

        return r

    def curvature_update(self, flat_grad, eps=1e-2, damping=False):
        """
        Performs curvature update.

        Inputs:
            flat_grad (tensor): 1-D tensor of flattened gradient for computing
                gradient difference with previously stored gradient
            eps (float): constant for curvature pair rejection or damping (default: 1e-2)
            damping (bool): flag for using Powell damping (default: False)
        """

        assert len(self.param_groups) == 1

        # load parameters
        if (eps <= 0):
            raise (ValueError('Invalid eps; must be positive.'))

        group = self.param_groups[0]
        history_size = group['history_size']
        debug = group['debug']

        # variables cached in state (for tracing)
        state = self.state['global_state']
        fail = state.get('fail')

        # check if line search failed
        if not fail:

            d = state.get('d')
            t = state.get('t')
            old_dirs = state.get('old_dirs')
            old_stps = state.get('old_stps')
            H_diag = state.get('H_diag')
            prev_flat_grad = state.get('prev_flat_grad')
            Bs = state.get('Bs')

            # compute y's
            y = flat_grad.sub(prev_flat_grad)
            s = d.mul(t)
            sBs = s.dot(Bs)
            ys = y.dot(s)  # y*s

            # update L-BFGS matrix
            if ys > eps * sBs or damping == True:

                # perform Powell damping
                if damping == True and ys < eps * sBs:
                    if debug:
                        print('Applying Powell damping...')
                    theta = ((1 - eps) * sBs) / (sBs - ys)
                    y = theta * y + (1 - theta) * Bs

                # updating memory
                if len(old_dirs) == history_size:
                    # shift history by one (limited-memory)
                    old_dirs.pop(0)
                    old_stps.pop(0)

                # store new direction/step
                old_dirs.append(s)
                old_stps.append(y)

                # update scale of initial Hessian approximation
                H_diag = ys / y.dot(y)  # (y*y)

                state['old_dirs'] = old_dirs
                state['old_stps'] = old_stps
                state['H_diag'] = H_diag

            else:
                # save skip
                state['curv_skips'] += 1
                if debug:
                    print('Curvature pair skipped due to failed criterion')

        else:
            # save skip
            state['fail_skips'] += 1
            if debug:
                print('Line search failed; curvature pair update skipped')

        return

    def _step(self, p_k, g_Ok, g_Sk=None, options=None):

        if options is None:
            options = {}
        assert len(self.param_groups) == 1

        # load parameter options
        group = self.param_groups[0]
        lr = group['lr']
        line_search = group['line_search']
        dtype = group['dtype']
        debug = group['debug']

        # variables cached in state (for tracing)
        state = self.state['global_state']
        d = state.get('d')
        t = state.get('t')
        prev_flat_grad = state.get('prev_flat_grad')
        Bs = state.get('Bs')

        # keep track of nb of iterations
        state['n_iter'] += 1

        d = p_k
        

        # modify previous gradient
        if prev_flat_grad is None:
            prev_flat_grad = g_Ok.clone()
        else:
            prev_flat_grad.copy_(g_Ok)
        # d =-prev_flat_grad
        # if state['n_iter']==1:
        #     d =-prev_flat_grad
        # else:
        #     # set search direction
        #     d = p_k
        # set initial step size
        t = lr

        # closure evaluation counter
        closure_eval = 0

        if g_Sk is None:
            g_Sk = g_Ok.clone()

        # perform Armijo backtracking line search
        if line_search == 'Armijo':

            # load options
            if options:
                if 'closure' not in options.keys():
                    raise (ValueError('closure option not specified.'))
                else:
                    closure = options['closure']

                if 'gtd' not in options.keys():
                    # if closure_eval==0:
                    #     d =-prev_flat_grad
                    gtd = g_Sk.dot(d)
                else:
                    gtd = options['gtd']

                if 'current_loss' not in options.keys():
                    F_k = closure()
                    closure_eval += 1
                else:
                    F_k = options['current_loss']

                if 'eta' not in options.keys():
                    eta = 2
                elif options['eta'] <= 0:
                    raise (ValueError('Invalid eta; must be positive.'))
                else:
                    eta = options['eta']

                if 'c1' not in options.keys():
                    c1 = 1e-4
                elif options['c1'] >= 1 or options['c1'] <= 0:
                    raise (ValueError('Invalid c1; must be strictly between 0 and 1.'))
                else:
                    c1 = options['c1']

                if 'max_ls' not in options.keys():
                    max_ls = 10
                elif options['max_ls'] <= 0:
                    raise (ValueError('Invalid max_ls; must be positive.'))
                else:
                    max_ls = options['max_ls']

                if 'interpolate' not in options.keys():
                    interpolate = True
                else:
                    interpolate = options['interpolate']

                if 'inplace' not in options.keys():
                    inplace = True
                else:
                    inplace = options['inplace']

                if 'ls_debug' not in options.keys():
                    ls_debug = False
                else:
                    ls_debug = options['ls_debug']

            else:
                raise (ValueError('Options are not specified; need closure evaluating function.'))

            # initialize values
            if interpolate:
                if torch.cuda.is_available():
                    F_prev = torch.tensor(np.nan, dtype=dtype).cuda()
                else:
                    F_prev = torch.tensor(np.nan, dtype=dtype)

            ls_step = 0
            t_prev = 0  # old steplength
            fail = False  # failure flag

            # begin print for debug mode
            if ls_debug:
                print(
                    '==================================== Begin Armijo line search ===================================')
                print('F(x): %.8e  g*d: %.8e' % (F_k, gtd))

            # check if search direction is descent direction
            if gtd >= 0:
                desc_dir = False
                if debug:
                    print('Not a descent direction!')
            else:
                desc_dir = True

            # store values if not in-place
            if not inplace:
                current_params = self._copy_params()
            # if closure_eval == 0:
            #     d = -prev_flat_grad
            # update and evaluate at new point
            self._add_update(t, d)
            F_new = closure()
            closure_eval += 1

            # print info if debugging
            if ls_debug:
                print('LS Step: %d  t: %.8e  F(x+td): %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                      % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k))

            # check Armijo condition
            while F_new > F_k + c1 * t * gtd or not is_legal(F_new):

                # check if maximum number of iterations reached
                if ls_step >= max_ls:
                    if inplace:
                        self._add_update(-t, d)
                    else:
                        self._load_params(current_params)

                    t = 0
                    F_new = closure()
                    closure_eval += 1
                    fail = True
                    break

                else:
                    # store current steplength
                    t_new = t

                    # compute new steplength

                    # if first step or not interpolating, then multiply by factor
                    if ls_step == 0 or not interpolate or not is_legal(F_new):
                        t = t / eta

                    # if second step, use function value at new point along with
                    # gradient and function at current iterate
                    elif ls_step == 1 or not is_legal(F_prev):
                        t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan]]))

                    # otherwise, use function values at new point, previous point,
                    # and gradient and function at current iterate
                    else:
                        t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan],
                                                 [t_prev, F_prev.item(), np.nan]]))

                    # if values are too extreme, adjust t
                    if interpolate:
                        if t < 1e-3 * t_new:
                            t = 1e-3 * t_new
                        elif t > 0.6 * t_new:
                            t = 0.6 * t_new

                        # store old point
                        F_prev = F_new
                        t_prev = t_new

                    # update iterate and reevaluate

                    if inplace:
                        self._add_update(t - t_new, d)
                    else:
                        self._load_params(current_params)
                        self._add_update(t, d)

                    F_new = closure()
                    closure_eval += 1
                    ls_step += 1  # iterate

                    # print info if debugging
                    if ls_debug:
                        print('LS Step: %d  t: %.8e  F(x+td):   %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                              % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k))

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            # print final steplength
            if ls_debug:
                print('Final Steplength:', t)
                print(
                    '===================================== End Armijo line search ====================================')

            state['d'] = d
            state['prev_flat_grad'] = prev_flat_grad
            state['t'] = t
            state['Bs'] = Bs
            state['fail'] = fail

            return F_new, t, ls_step, closure_eval, desc_dir, fail

        # perform weak Wolfe line search
        elif line_search == 'Wolfe':

            # load options
            if options:
                if 'closure' not in options.keys():
                    raise (ValueError('closure option not specified.'))
                else:
                    closure = options['closure']

                if 'current_loss' not in options.keys():
                    F_k = closure()
                    closure_eval += 1
                else:
                    F_k = options['current_loss']

                if 'gtd' not in options.keys():
                    gtd = g_Sk.dot(d)
                else:
                    gtd = options['gtd']

                if 'eta' not in options.keys():
                    eta = 2
                elif options['eta'] <= 1:
                    raise (ValueError('Invalid eta; must be greater than 1.'))
                else:
                    eta = options['eta']

                if 'c1' not in options.keys():
                    c1 = 1e-4
                elif options['c1'] >= 1 or options['c1'] <= 0:
                    raise (ValueError('Invalid c1; must be strictly between 0 and 1.'))
                else:
                    c1 = options['c1']

                if 'c2' not in options.keys():
                    c2 = 0.9
                elif options['c2'] >= 1 or options['c2'] <= 0:
                    raise (ValueError('Invalid c2; must be strictly between 0 and 1.'))
                elif options['c2'] <= c1:
                    raise (ValueError('Invalid c2; must be strictly larger than c1.'))
                else:
                    c2 = options['c2']

                if 'max_ls' not in options.keys():
                    max_ls = 10
                elif options['max_ls'] <= 0:
                    raise (ValueError('Invalid max_ls; must be positive.'))
                else:
                    max_ls = options['max_ls']

                if 'interpolate' not in options.keys():
                    interpolate = True
                else:
                    interpolate = options['interpolate']

                if 'inplace' not in options.keys():
                    inplace = True
                else:
                    inplace = options['inplace']

                if 'ls_debug' not in options.keys():
                    ls_debug = False
                else:
                    ls_debug = options['ls_debug']

            else:
                raise (ValueError('Options are not specified; need closure evaluating function.'))

            # initialize counters
            ls_step = 0
            grad_eval = 0  # tracks gradient evaluations
            t_prev = 0  # old steplength

            # initialize bracketing variables and flag
            alpha = 0
            beta = float('Inf')
            fail = False

            # initialize values for line search
            if (interpolate):
                F_a = F_k
                g_a = gtd

                if (torch.cuda.is_available()):
                    F_b = torch.tensor(np.nan, dtype=dtype).cuda()
                    g_b = torch.tensor(np.nan, dtype=dtype).cuda()
                else:
                    F_b = torch.tensor(np.nan, dtype=dtype)
                    g_b = torch.tensor(np.nan, dtype=dtype)

            # begin print for debug mode
            if ls_debug:
                print(
                    '==================================== Begin Wolfe line search ====================================')
                print('F(x): %.8e  g*d: %.8e' % (F_k, gtd))

            # check if search direction is descent direction
            if gtd >= 0:
                desc_dir = False
                if debug:
                    print('Not a descent direction!')
            else:
                desc_dir = True

            # store values if not in-place
            if not inplace:
                current_params = self._copy_params()

            # update and evaluate at new point
            self._add_update(t, d)
            F_new = closure()
            closure_eval += 1

            # main loop
            while True:
                
                # check if maximum number of line search steps have been reached
                if ls_step >= max_ls:
                    if inplace:
                        self._add_update(-t, d)
                    else:
                        self._load_params(current_params)

                    t = 0
                    F_new = closure()
                    F_new.backward()
                    g_new = self._gather_flat_grad()
                    closure_eval += 1
                    grad_eval += 1
                    fail = True
                    break

                # print info if debugging
                if ls_debug:
                    print('LS Step: %d  t: %.8e  alpha: %.8e  beta: %.8e'
                          % (ls_step, t, alpha, beta))
                    print('Armijo:  F(x+td): %.8e  F-c1*t*g*d: %.8e  F(x): %.8e'
                          % (F_new, F_k + c1 * t * gtd, F_k))
                global g_baocun

                # check Armijo condition
                if F_new > F_k + c1 * t * gtd:

                    # set upper bound
                    beta = t
                    t_prev = t
                    # if ls_step>=1 and ls_step < 200:
                    #     g_new = gtd/d
                    #     old_d =d
                    #     bj = (matmul(g_new.t(), (g_new - g_baocun))) / (
                    #         matmul(g_baocun.t(), g_baocun))
                    #     d = -g_new + (bj * d)

                    #     print('1',old_d,d)
                    #     g_baocun = g_new
                    # else:
                    
                    g_baocun = gtd/d#
                    # update interpolation quantities
                    if interpolate:
                        F_b = F_new
                        if torch.cuda.is_available():
                            g_b = torch.tensor(np.nan, dtype=dtype).cuda()
                        else:
                            g_b = torch.tensor(np.nan, dtype=dtype)

                else:

                    # compute gradient
                    F_new.backward()
                    g_new = self._gather_flat_grad()
                    if ls_step>1 and ls_step < 200:
                        bj = (matmul(g_new.t(), (g_new - g_baocun))) / (
                            matmul(g_baocun.t(), g_baocun))
                        d = -g_new + (bj * d)
                        # print('2')
                        g_baocun = g_new
                    else:
                        g_baocun = g_new

                    grad_eval += 1
                    gtd_new = g_new.dot(d)

                    # print info if debugging
                    if ls_debug:
                        print('Wolfe: g(x+td)*d: %.8e  c2*g*d: %.8e  gtd: %.8e'
                              % (gtd_new, c2 * gtd, gtd))

                    # check curvature condition
                    if gtd_new < c2 * gtd:

                        # set lower bound
                        alpha = t
                        t_prev = t

                        # update interpolation quantities
                        if interpolate:
                            F_a = F_new
                            g_a = gtd_new

                    else:
                        break

                # compute new steplength

                # if first step or not interpolating, then bisect or multiply by factor
                if not interpolate or not is_legal(F_b):
                    if beta == float('Inf'):
                        t = eta * t
                    else:
                        t = (alpha + beta) / 2.0

                # otherwise interpolate between a and b
                else:
                    t = polyinterp(np.array([[alpha, F_a.item(), g_a.item()], [beta, F_b.item(), g_b.item()]]))

                    # if values are too extreme, adjust t
                    if beta == float('Inf'):
                        if t > 2 * eta * t_prev:
                            t = 2 * eta * t_prev
                        elif t < eta * t_prev:
                            t = eta * t_prev
                    else:
                        if t < alpha + 0.2 * (beta - alpha):
                            t = alpha + 0.2 * (beta - alpha)
                        elif t > (beta - alpha) / 2.0:
                            t = (beta - alpha) / 2.0

                    # if we obtain nonsensical value from interpolation
                    if t <= 0:
                        t = (beta - alpha) / 2.0

                
                    # gtd = g_Sk.dot(d)
                # update parameters
                if inplace:
                    self._add_update(t - t_prev, d)
                else:
                    self._load_params(current_params)
                    self._add_update(t, d)

                # evaluate closure
                F_new = closure()
                closure_eval += 1
                ls_step += 1

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            # print final steplength
            if ls_debug:
                print('Final Steplength:', t)
                print(
                    '===================================== End Wolfe line search =====================================')

            state['d'] = d
            state['prev_flat_grad'] = prev_flat_grad
            state['t'] = t
            state['Bs'] = Bs
            state['fail'] = fail
            return F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail

        else:

            # perform update
            self._add_update(t, d)

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            state['d'] = d
            state['prev_flat_grad'] = prev_flat_grad
            state['t'] = t
            state['Bs'] = Bs
            state['fail'] = False

            return t

    def step(self, p_k, g_Ok, g_Sk=None, options=None):
        """
        Public entry point of the optimizer; forwards all arguments to ``_step``.

        Inputs:
            p_k: forwarded unchanged as the first argument of ``_step``
            g_Ok: forwarded unchanged as the second argument of ``_step``
            g_Sk: forwarded unchanged as the third argument of ``_step`` (default: None)
            options (dict): line search options forwarded to ``_step``; a fresh empty
                dict is used when omitted (default: None)

        Outputs:
            whatever ``_step`` returns

        """
        # a literal ``{}`` default would be one dict object shared by every call
        if options is None:
            options = {}
        return self._step(p_k, g_Ok, g_Sk, options)


class FullBatchLBFGS(LBFGS1):
    """
    Convenience subclass of LBFGS1 whose ``step`` gathers the gradient, updates the
    curvature information, computes the search direction and takes the step in one call.

    Inputs:
        params: forwarded to the LBFGS1 constructor
        lr: forwarded to the LBFGS1 constructor (default: 1)
        history_size: forwarded to the LBFGS1 constructor (default: 10)
        line_search: forwarded to the LBFGS1 constructor (default: 'Wolfe')
        dtype: forwarded to the LBFGS1 constructor (default: torch.float)
        debug: forwarded to the LBFGS1 constructor (default: False)

    """

    def __init__(self, params, lr=1, history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):
        super(FullBatchLBFGS, self).__init__(params, lr, history_size, line_search,
                                             dtype, debug)

    def step(self, options=None):
        """
        Performs a single optimization step.

        Inputs:
            options (dict): options forwarded to ``_step``; two keys are read here:
                'damping': passed to ``curvature_update`` (default: False)
                'eps': passed to ``curvature_update`` (default: 1e-2)
                An empty dict is used when omitted (default: None).

        Outputs:
            whatever ``_step`` returns

        """
        # the declared default is None, which has no keys to look up
        if options is None:
            options = {}

        # load options for damping and eps
        damping = options.get('damping', False)
        eps = options.get('eps', 1e-2)

        # gather gradient
        grad = self._gather_flat_grad()

        # update curvature if after 1st iteration
        state = self.state['global_state']
        if state['n_iter'] > 0:
            self.curvature_update(grad, eps, damping)

        # compute search direction
        p = self.two_loop_recursion(-grad)

        # take step
        return self._step(p, grad, options=options)

In [None]:
"""
Full-Overlap L-BFGS Implementation with Stochastic Wolfe Line Search

Demonstrates how to implement full-overlap L-BFGS with stochastic weak Wolfe line
search without Powell damping to train a simple convolutional neural network using the 
LBFGS optimizer. Full-overlap L-BFGS is a stochastic quasi-Newton method that uses 
the same sample as the one used in the stochastic gradient to perform quasi-Newton 
updating, then resamples an entirely independent new sample in the next iteration.

This implementation is CUDA-compatible.

Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
Last edited 10/20/20.

Requirements:
    - Keras (for CIFAR-10 dataset)
    - NumPy
    - PyTorch

Run Command:
    python full_overlap_lbfgs_example.py

Based on stable quasi-Newton updating introduced by Schraudolph, Yu, and Gunter in
"A Stochastic Quasi-Newton Method for Online Convex Optimization" (2007)

"""
from time import process_time

import sys
sys.path.append('/content/drive/MyDrive/lbfgs/PyTorch-LBFGS/functions')

import numpy as np
import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from keras.datasets import cifar10 # to load dataset

from utils import compute_stats, get_grad
# from LBFGS import LBFGS

# Hyper-parameters of the L-BFGS run: number of outer iterations, size of the
# sub-batches each closure evaluation is split into, and size of the sampled batch
max_iter, ghost_batch, batch_size = 100, 128, 8192

# Load CIFAR-10, convert the images to float32 and divide by 255, then move the
# last axis to position 1 (axis order (0, 3, 1, 2))
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = np.transpose(X_train.astype('float32') / 255, (0, 3, 1, 2))
X_test = np.transpose(X_test.astype('float32') / 255, (0, 3, 1, 2))
# class Alexnet(nn.Module):
#     def __init__(self):
#         super(Alexnet, self).__init__()
#         self.conv1 = nn.Conv2d(3,64,3,2,1)
#         self.pool = nn.MaxPool2d(3, 2)
#         self.conv2 = nn.Conv2d(64,192, 5, 1, 2)
#         self.conv3 = nn.Conv2d(192, 384, 3, 1, 1)
#         self.conv4 = nn.Conv2d(384,256, 3, 1, 1)
#         self.conv5 = nn.Conv2d(256,256, 3, 1, 1)
#         self.drop = nn.Dropout(0.5)
#         self.fc1 = nn.Linear(256*6*6, 4096)
#         self.fc2 = nn.Linear(4096, 4096)
#         self.fc3 = nn.Linear(4096, 1000)
#         self.fc4 = nn.Linear(1000, 10)

#     def forward(self, x):
#         print('1')
#         x = self.pool(F.relu(self.conv1(x)))
#         print('1')
#         x = self.pool(F.relu(self.conv2(x)))
#         print('1')
#         x = F.relu(self.conv3(x))
#         x = F.relu(self.conv4(x))
#         x = self.pool(F.relu(self.conv5(x)))
#         x = x.view(-1, self.num_flat_features(x))
#         x = F.relu(self.fc1(x))
#         x = self.drop(F.relu(self.fc1(x)))
#         x = self.drop(F.relu(self.fc2(x)))
#         x = self.fc3(x)
#         x = self.fc4(x)
#         return x
# Define network
class ConvNet(nn.Module):
    """
    Small convolutional classifier.

    Two 5x5 convolutions (3 -> 6 and 6 -> 16 channels), each followed by ReLU and
    2x2 max pooling with stride 2, then two linear layers (400 -> 1000 -> 10).
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        # layers are created in this order on purpose: it fixes both the parameter
        # ordering and the order in which random initial weights are drawn
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=1000)
        self.fc2 = nn.Linear(in_features=1000, out_features=10)

    def forward(self, x):
        """Return the 10 output scores for each row of the flattened features."""
        # both convolution stages share the same ReLU + pooling treatment
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


# Check cuda availability
cuda = torch.cuda.is_available()

# Create neural network model.
# The network is built first and only then moved with .cuda(), so the generator
# seeded by torch.manual_seed has to be seeded on both paths for the initial
# weights to be reproducible; the CUDA seed is still set as before.
torch.manual_seed(2018)
if cuda:
    torch.cuda.manual_seed(2018)
    model = ConvNet().cuda()
else:
    model = ConvNet()


# Define helper functions

def opfun(X):
    """Forward pass: run the model on the NumPy array X (on the GPU when available)."""
    inputs = torch.from_numpy(X)
    if cuda:
        inputs = inputs.cuda()
    return model.forward(inputs)


def predsfun(op):
    """Predictions: index of the largest entry along axis 1 of the output tensor op."""
    if cuda:
        op = op.cpu()
    return np.argmax(op.data.numpy(), 1)


def accfun(op, y):
    """Accuracy in percent: share of predictions from op that equal the labels y."""
    return np.mean(np.equal(predsfun(op), y.squeeze())) * 100


# Define optimizer
optimizer = LBFGS1(model.parameters(), lr=1., history_size=10, line_search='Wolfe', debug=True)
# optimizer2 = LBFGS2(model.parameters(), lr=1., history_size=10, line_search='Wolfe', debug=True)

# per-iteration history used for the plots and the saved arrays
train_loss_list, test_loss_list, test_acc_list, x_list = [], [], [], []
# Main training loop
# process_time() has an undefined reference point, so keep the starting value and
# report the difference at the end
start_time = process_time()
for n_iter in range(max_iter):

    # training mode
    model.train()

    # sample batch
    random_index = np.random.permutation(range(X_train.shape[0]))
    Sk = random_index[0:batch_size]

    # compute initial gradient and objective
    grad, obj = get_grad(optimizer, X_train[Sk], y_train[Sk], opfun)

    # two-loop recursion to compute search direction
    p = optimizer.two_loop_recursion(-grad)

    # define closure for line search
    def closure():
        """Return the loss on the batch Sk, accumulated over ghost-batch sized pieces."""

        optimizer.zero_grad()

        if cuda:
            loss_fn = torch.tensor(0, dtype=torch.float).cuda()
        else:
            loss_fn = torch.tensor(0, dtype=torch.float)

        for subsmpl in np.array_split(Sk, max(int(batch_size / ghost_batch), 1)):

            ops = opfun(X_train[subsmpl])

            if cuda:
                tgts = torch.from_numpy(y_train[subsmpl]).cuda().long().squeeze()
            else:
                tgts = torch.from_numpy(y_train[subsmpl]).long().squeeze()

            # weight each piece by its share of the batch
            loss_fn += F.cross_entropy(ops, tgts) * (len(subsmpl) / batch_size)

        return loss_fn

    # perform line search step
    options = {'closure': closure, 'current_loss': obj}
    obj, grad, lr, _, _, _, _, _ = optimizer.step(p, grad, options=options)

    # curvature update
    optimizer.curvature_update(grad)

    # compute statistics
    model.eval()
    train_loss, test_loss, test_acc = compute_stats(X_train, y_train, X_test, y_test, opfun, accfun,
                                                    ghost_batch=128)
    train_loss_list.append(float(train_loss))
    test_loss_list.append(float(test_loss))
    test_acc_list.append(float(test_acc))
    x_list.append(n_iter + 1)

    # print data
    print('Iter:', n_iter + 1, 'lr:', lr, 'Training Loss:', train_loss, 'Test Loss:', test_loss,
          'Test Accuracy:', test_acc)

# report the processor time spent in the loop above
print("运行时间是: {:9.9}s".format(process_time() - start_time))

# plot the recorded curves against the iteration number and save the figure
plt.plot(x_list, train_loss_list, color='red', marker='o', linestyle='dashed', linewidth=2, markersize=1)
plt.plot(x_list, test_loss_list, color='green', marker='o', linestyle='dashed', linewidth=2, markersize=1)
plt.plot(x_list, test_acc_list, color='blue', marker='o', linestyle='dashed', linewidth=2, markersize=1)
plt.legend(labels=('Training Loss', 'Test Loss', 'Test Accuracy'))
plt.savefig('./a1.jpg')
plt.show()

# store the recorded curves as NumPy arrays
import numpy as np
train_loss_list = np.array(train_loss_list)
np.save('train_loss_list.npy', train_loss_list)
test_loss_list = np.array(test_loss_list)
np.save('test_loss_list.npy', test_loss_list)
test_acc_list = np.array(test_acc_list)
np.save('test_acc_list.npy', test_acc_list)

In [None]:
test_loss_list_new

In [None]:
sdsd

In [None]:
!pip install backpack-for-pytorch

In [None]:
# !git clone https://github.com/youli-jlu/PyTorch_Adam_vs_LBFGS.git

In [None]:
import torch
from functools import reduce
from torch.optim.optimizer import Optimizer
from torch.nn.functional import normalize
from torch import  matmul
import math
be_verbose = False


class LBFGSNew(Optimizer):
    """Implements L-BFGS algorithm.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Arguments:
        lr (float): learning rate (fallback value when line search fails. not really needed) (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 10)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-5).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 7).
        line_search_fn: if True, use cubic interpolation to find the step size; if False, use a fixed step size
        batch_mode: True for stochastic version (default False)

        Example usage for full batch mode:

          optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=100, line_search_fn=True, batch_mode=False)

        Example usage for batch mode (stochastic):

          optimizer = LBFGSNew(net.parameters(), history_size=7, max_iter=4, line_search_fn=True,batch_mode=True)
          Note: when using a closure(), only do backward() after checking the gradient is available,
          Eg:
            def closure():
             optimizer.zero_grad()
             outputs=net(inputs)
             loss=criterion(outputs,labels)
             if loss.requires_grad:
               loss.backward()
             return loss

    """

    def __init__(self, params, lr=1, max_iter=10, max_eval=None,
                 tolerance_grad=1e-5, tolerance_change=1e-9, history_size=7,
                 line_search_fn=False, batch_mode=False):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
                        tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
                        history_size=history_size, line_search_fn=line_search_fn,
                        batch_mode=batch_mode)
        super(LBFGSNew, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().contiguous().view(-1)
            else:
                view = p.grad.data.contiguous().view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    # FF copy the parameter values out, create a single vector
    def _copy_params_out(self):
        offset = 0
        new_params = []
        for p in self._params:
            numel = p.numel()
            new_param1 = p.data.clone().contiguous().view(-1)
            offset += numel
            new_params.append(new_param1)
        assert offset == self._numel()
        return torch.cat(new_params, 0)

    # FF copy the parameter values back, dividing the vector into a list
    def _copy_params_in(self, new_params):
        offset = 0
        for p in self._params:
            numel = p.numel()
            p.data.copy_(new_params[offset:offset + numel].view_as(p.data))
            offset += numel
        assert offset == self._numel()

    # FF line search xk=self._params, pk=step direction, gk=gradient, alphabar=max. step size
    def _linesearch_backtrack(self, closure, pk, gk, alphabar):
        """Line search (backtracking)

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
            pk: step direction vector
            gk: gradient vector
            alphabar: max step size
        """

        # constants (FIXME) find proper values
        # c1: large values better for small batch sizes
        c1 = 1e-4
        citer = 35
        alphak = alphabar  # default return step

        # state parameter
        state = self.state[self._params[0]]

        # make a copy of original params
        xk = self._copy_params_out()

        f_old = float(closure())
        # param = param + alphak * pk
        self._add_grad(alphak, pk)
        f_new = float(closure())

        # prod = c1 * ( alphak ) * gk^T pk = alphak * prodterm
        s = gk
        prodterm = c1 * (s.dot(pk))

        ci = 0
        if be_verbose:
            print('LN %d alpha=%f fnew=%f fold=%f prod=%f' % (ci, alphak, f_new, f_old, prodterm))
        # catch cases where f_new is NaN
        while (ci < citer and (math.isnan(f_new) or f_new > f_old + alphak * prodterm)):
            alphak = 0.5 * alphak
            self._copy_params_in(xk)
            self._add_grad(alphak, pk)
            f_new = float(closure())
            if be_verbose:
                print('LN %d alpha=%f fnew=%f fold=%f' % (ci, alphak, f_new, f_old))
            ci = ci + 1

        # if the cost is not sufficiently decreased, also try -ve steps
        if (f_old - f_new < torch.abs(prodterm)):
            alphak1 = -alphabar
            self._copy_params_in(xk)
            self._add_grad(alphak1, pk)
            f_new1 = float(closure())
            if be_verbose:
                print('NLN fnew=%f' % f_new1)
            while (ci < citer and (math.isnan(f_new1) or f_new1 > f_old + alphak1 * prodterm)):
                alphak1 = 0.5 * alphak1
                self._copy_params_in(xk)
                self._add_grad(alphak1, pk)
                f_new1 = float(closure())
                if be_verbose:
                    print('NLN %d alpha=%f fnew=%f fold=%f' % (ci, alphak1, f_new1, f_old))
                ci = ci + 1

            if f_new1 < f_new:
                # select -ve step
                alphak = alphak1

        # recover original params
        self._copy_params_in(xk)
        # update state
        state['func_evals'] += ci
        return alphak

    # FF line search xk=self._params, pk=gradient
    def _linesearch_cubic(self, closure, pk, step):
        """Line search (strong-Wolfe)

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
            pk: gradient vector
            step: step size for differencing
        """

        # constants
        alpha1 = 10 * self.param_groups[0]['lr']  # 10.0
        sigma = 0.1
        rho = 0.01
        t1 = 9
        t2 = 0.1
        t3 = 0.5
        alphak = self.param_groups[0]['lr']  # default return step

        # state parameter
        state = self.state[self._params[0]]

        # make a copy of original params
        xk = self._copy_params_out()

        phi_0 = float(closure())
        tol = min(phi_0 * 0.01, 1e-6)

        # xp <- xk+step. pk
        self._add_grad(step, pk)  # FF param = param + t * grad
        p01 = float(closure())
        # xp <- xk-step. pk
        self._add_grad(-2.0 * step, pk)  # FF param = param - t * grad
        p02 = float(closure())

        ##print("p01="+str(p01)+" p02="+str(p02))
        gphi_0 = (p01 - p02) / (2.0 * step)
        ##print("tol="+str(tol)+" phi_0="+str(phi_0)+" gphi_0="+str(gphi_0))
        # catch instances when step size is too small
        if abs(gphi_0) < 1e-12:
            return 1.0

        mu = (tol - phi_0) / (rho * gphi_0)
        # catch if mu is not finite
        if math.isnan(mu):
            return 1.0

        ##print("mu="+str(mu))

        # counting function evals
        closure_evals = 3

        ci = 1
        alphai = alpha1  # initial value for alpha(i) : check if 0<alphai<=mu
        alphai1 = 0.0
        phi_alphai1 = phi_0
        while (ci < 4):  # FIXME
            # evalualte phi(alpha(i))=f(xk+alphai pk)
            self._copy_params_in(xk)  # original
            # xp <- xk+alphai. pk
            self._add_grad(alphai, pk)  #
            phi_alphai = float(closure())
            if phi_alphai < tol:
                alphak = alphai
                if be_verbose:
                    print("Linesearch: condition 0 met")
                break
            if (phi_alphai > phi_0 + alphai * gphi_0) or (ci > 1 and phi_alphai >= phi_alphai1):
                # ai=alphai1, bi=alphai bracket
                if be_verbose:
                    print("bracket " + str(alphai1) + "," + str(alphai))
                alphak = self._linesearch_zoom(closure, xk, pk, alphai1, alphai, phi_0, gphi_0, sigma, rho, t1, t2, t3,
                                               step)
                if be_verbose:
                    print("Linesearch: condition 1 met")
                break

            # evaluate grad(phi(alpha(i))) */
            # note that self._params already is xk+alphai. pk, so only add the missing term
            # xp <- xk+(alphai+step). pk
            self._add_grad(step, pk)  # FF param = param - t * grad
            p01 = float(closure())
            # xp <- xk+(alphai-step). pk
            self._add_grad(-2.0 * step, pk)  # FF param = param - t * grad
            p02 = float(closure())
            gphi_i = (p01 - p02) / (2.0 * step);

            if (abs(gphi_i) <= -sigma * gphi_0):
                alphak = alphai
                if be_verbose:
                    print("Linesearch: condition 2 met")
                break

            if gphi_i >= 0.0:
                # ai=alphai, bi=alphai1 bracket
                if be_verbose:
                    print("bracket " + str(alphai) + "," + str(alphai1))
                alphak = self._linesearch_zoom(closure, xk, pk, alphai, alphai1, phi_0, gphi_0, sigma, rho, t1, t2, t3,
                                               step)
                if be_verbose:
                    print("Linesearch: condition 3 met")
                break
            # else preserve old values
            if (mu <= 2.0 * alphai - alphai1):
                alphai1 = alphai
                alphai = mu
            else:
                # choose by interpolation in [2*alphai-alphai1,min(mu,alphai+t1*(alphai-alphai1)]
                p01 = 2.0 * alphai - alphai1;
                p02 = min(mu, alphai + t1 * (alphai - alphai1))
                alphai = self._cubic_interpolate(closure, xk, pk, p01, p02, step)

            phi_alphai1 = phi_alphai;
            # update function evals
            closure_evals += 3
            ci = ci + 1

        # recover original params
        self._copy_params_in(xk)
        # update state
        state['func_evals'] += closure_evals
        return alphak

    def _cubic_interpolate(self, closure, xk, pk, a, b, step):
        """Choose a trial step inside [a,b] or [b,a] (a>b is possible).

        The loss along ``xk + alpha * pk`` is sampled at both end points of
        the interval, and the slope at each end point is estimated with a
        central difference of half-width ``step``.  A cubic-interpolation
        style candidate ``z0`` is derived from those four numbers and the
        best of ``a``, ``b`` and ``z0`` is returned.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
            xk: copy of parameter values
            pk: direction along which the parameters are moved
                (passed to ``_add_grad``)
            a/b:  interval for interpolation
            step: step size for differencing

        Returns:
            One of ``a``, ``b``, the interior candidate ``z0``, or the
            midpoint ``(a + b) / 2`` when the candidate formula degenerates.
            For a zero-width interval (``a == b``) ``a`` is returned without
            evaluating the closure.

        Side effects:
            The parameters are reset to ``xk`` and then moved along ``pk``;
            they are NOT restored before returning.  ``state['func_evals']``
            is increased by the number of closure calls made here.
        """
        # state parameter
        state = self.state[self._params[0]]

        self._copy_params_in(xk)

        # A zero-width interval has exactly one candidate.  Returning it here
        # also avoids the division by (b - a) further down.
        if a == b:
            return a

        def probe(offset):
            # Move the parameters by offset*pk, then return the loss there
            # together with a central-difference estimate of its slope.
            # Leaves the parameters at (probed point - step*pk).
            self._add_grad(offset, pk)
            value = float(closure())
            self._add_grad(step, pk)
            ahead = float(closure())
            self._add_grad(-2.0 * step, pk)
            behind = float(closure())
            return value, (ahead - behind) / (2.0 * step)

        # xp <- xk+a. pk  (parameters end at xk+(a-step). pk)
        f0, f0d = probe(a)
        # xp <- xk+b. pk  (parameters end at xk+(b-step). pk)
        f1, f1d = probe(-a + step + b)

        # count function evals: 3 per probed end point
        closure_evals = 6

        # NOTE(review): textbook cubic interpolation uses f0d + f1d in this
        # term; the sign used here is kept as-is -- confirm it is intended.
        aa = 3.0 * (f0 - f1) / (b - a) + f1d - f0d
        disc = aa * aa - f0d * f1d
        if disc > 0.0:
            cc = math.sqrt(disc)
            denom = f1d - f0d + 2.0 * cc
            if denom == 0.0:
                # Candidate formula is undefined: fall back to the midpoint,
                # but still account for the six evaluations already made.
                state['func_evals'] += closure_evals
                return (a + b) * 0.5

            z0 = b - (f1d + cc - aa) * (b - a) / denom
            if z0 > max(a, b) or z0 < min(a, b):
                # Candidate outside the interval: give it a value that is
                # larger than at least one end point (when the losses are
                # positive) so that an end point is normally preferred.
                fz0 = f0 + f1
            else:
                # xp <- xk+(a+z0*(b-a))*pk
                # NOTE(review): z0 is range-checked and returned as a position
                # in [a, b], yet the loss is sampled at a + z0*(b-a); confirm
                # which of the two is intended.
                self._add_grad(-b + step + a + z0 * (b - a), pk)
                fz0 = float(closure())
                closure_evals += 1

            # update state
            state['func_evals'] += closure_evals

            if f0 < f1 and f0 < fz0:
                return a
            if f1 < fz0:
                return b
            return z0

        # No usable candidate (non-positive or NaN discriminant): return the
        # end point with the smaller loss.
        state['func_evals'] += closure_evals
        return a if f0 < f1 else b

    # FF bracket [a,b]
    # xk: copy of parameters, use it to refresh self._param
    def _linesearch_zoom(self, closure, xk, pk, a, b, phi_0, gphi_0, sigma, rho, t1, t2, t3, step):
        """Zoom phase of the line search: shrink the bracket and pick a step.

        At most four refinement rounds are run.  Each round asks
        ``_cubic_interpolate`` for a trial step inside a sub-interval of the
        current bracket, evaluates the loss there and at the lower bracket
        end, and either shrinks the bracket or accepts the trial step.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
            xk: copy of parameter values
            pk: direction along which the parameters are moved
                (passed to ``_add_grad``)
            a/b:  bracket interval for line search,
            phi_0: phi(0)
            gphi_0: grad(phi(0))
            sigma,rho,t1,t2,t3: line search parameters (from Fletcher);
                t1 is accepted but not used in this method
            step: step size for differencing

        Returns:
            The trial step of the last round that was run, whether or not
            one of the acceptance tests fired.

        Side effects:
            Parameters are left at the last evaluated point, and
            ``state['func_evals']`` grows by the closure calls made directly
            in this method.
        """
        opt_state = self.state[self._params[0]]
        n_evals = 0

        lo, hi = a, b
        for _ in range(4):  # FIXME original 10
            # trial step from [lo + t2*(hi-lo), hi - t3*(hi-lo)]
            width = hi - lo
            trial = self._cubic_interpolate(closure, xk, pk, lo + t2 * width, hi - t3 * width, step)

            # phi(trial): restart from xk, then xp <- xk + trial*pk
            self._copy_params_in(xk)
            self._add_grad(trial, pk)
            phi_trial = float(closure())

            # phi(lo): xp <- xk + lo*pk
            self._add_grad(-trial + lo, pk)
            phi_lo = float(closure())

            n_evals += 2

            violates_decrease = phi_trial > phi_0 + rho * trial * gphi_0
            if violates_decrease or phi_trial >= phi_lo:
                # trial is not good enough: it becomes the new upper end,
                # the lower end stays where it is
                hi = trial
                continue

            # slope at trial by central differences:
            # xp <- xk + (trial+step)*pk, then xp <- xk + (trial-step)*pk
            self._add_grad(-lo + trial + step, pk)
            ahead = float(closure())
            self._add_grad(-2.0 * step, pk)
            behind = float(closure())
            slope = (ahead - behind) / (2.0 * step)

            n_evals += 2

            # accept on the roundoff test (pp. 38, Fletcher) or on the
            # curvature test
            if (lo - trial) * slope <= step or abs(slope) <= -sigma * gphi_0:
                break

            if slope * (hi - lo) >= 0.0:
                hi = lo
            # otherwise the upper end is unchanged
            lo = trial

        # update state
        opt_state['func_evals'] += n_evals

        return trial

    def step(self, closure):
        """Performs a single optimization step.

        The step runs in two phases:

        1. Warm-up.  The closure is evaluated once; if the L1 norm of the
           gradient is at most 0.001 the loss is returned immediately.
           Otherwise one line-searched move along the negative gradient is
           taken, followed by up to 200 direction updates of the form
           ``d = -g + bj * d`` on normalized gradients.  That loop stops as
           soon as the L1 norm of the gradient is at most 1000 or is NaN.
           The warm-up produces the scalar ``rk`` that seeds ``H_diag`` on
           the very first L-BFGS iteration.
        2. The L-BFGS loop (two-loop recursion over the stored history),
           with an optional line search and the termination tests on
           iteration count, function evaluations, gradient size, step size
           and loss change.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by the closure: from the very first call when
            the warm-up returns early, otherwise from the call made right
            after the warm-up phase.
        """
        assert len(self.param_groups) == 1

        def search_without_grad(search, *args):
            # Line searches only need loss values, so autograd is switched
            # off around them.  try/finally makes sure it is switched back on
            # even when the search (or the closure) raises or is interrupted;
            # otherwise gradient tracking would stay disabled process-wide.
            torch.set_grad_enabled(False)
            try:
                return search(*args)
            finally:
                torch.set_grad_enabled(True)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        batch_mode = group['batch_mode']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # ---------------------------------------------------
        # Phase 1: warm-up (these closure calls are not added to
        # state['func_evals'] here)
        # ---------------------------------------------------
        first_total = 0.001
        orig_loss = closure()
        flat_grad_first = self._gather_flat_grad()
        abs_grad_sum = flat_grad_first.abs().sum()
        if abs_grad_sum <= first_total:
            return orig_loss
        d = -flat_grad_first
        t = search_without_grad(self._linesearch_cubic, closure, d, 1e-6)

        if math.isnan(t):
            print('Warning: stepsize nan')
            t = lr
        self._add_grad(t, d)  # FF param = param + t * d
        maxk = 200
        k = 0
        grad_list = []  # last (at most 3) differences of normalized gradients
        num_list = []  # last (at most 3) values of -t * d
        flat_grad_first = normalize(flat_grad_first, p=2.0, dim=0)
        second_out = 1000
        while k < maxk:
            # the closure call refreshes the gradients read just below
            loss = float(closure())
            flat_grad = self._gather_flat_grad()
            abs_grad_sum = flat_grad.abs().sum()
            if math.isnan(abs_grad_sum):
                print('Warning: gradient nan')
                break
            if abs_grad_sum <= second_out:
                break
            flat_grad = normalize(flat_grad, p=2.0, dim=0)
            # bj = g^T (g - g_prev) / (g_prev^T g_prev) on normalized gradients
            bj = (matmul(flat_grad.t(), (flat_grad - flat_grad_first))) / (matmul(flat_grad_first.t(), flat_grad_first))
            grad_list.append(flat_grad - flat_grad_first)
            if len(grad_list) > 3:
                grad_list.pop(0)
            flat_grad_first = flat_grad
            d = -flat_grad + (bj * d)
            d = normalize(d, p=2.0, dim=0)
            t = search_without_grad(self._linesearch_cubic, closure, d, 1e-6)
            # NOTE(review): the parameters move by +t*d below while -t*d is
            # what gets recorded here -- confirm the sign is intended.
            num_list.append(-t * d)
            if len(num_list) > 3:
                num_list.pop(0)
            self._add_grad(t, d)  # FF param = param + t * d
            k += 1
        if len(grad_list)>=1:
            # scale from the oldest retained pair: (s^T y) / (y^T y)
            yk_1 = grad_list[0]
            rk = (matmul(num_list[0].t(), yk_1)) / (matmul(yk_1.t(), yk_1))
        else:
            rk =1

        # ----------------------------------------------------
        # Phase 2: L-BFGS
        # ----------------------------------------------------
        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        abs_grad_sum = flat_grad.abs().sum()

        if abs_grad_sum <= tolerance_grad:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0

        # Only assigned in batch mode; kept as None otherwise so the final
        # bookkeeping can tell whether the loop ever set them.
        running_avg = None
        running_avg_sq = None

        if batch_mode:
            alphabar = lr
            lm0 = 1e-6

        # optimize for a max of max_iter iterations
        grad_nrm = flat_grad.norm().item()
        while n_iter < max_iter and not math.isnan(grad_nrm):
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                H_diag = rk
                if batch_mode:
                    running_avg = torch.zeros_like(flat_grad.data)
                    running_avg_sq = torch.zeros_like(flat_grad.data)
            else:
                if batch_mode:
                    running_avg = state.get('running_avg')
                    running_avg_sq = state.get('running_avg_sq')
                    if running_avg is None:
                        running_avg = torch.zeros_like(flat_grad.data)
                        running_avg_sq = torch.zeros_like(flat_grad.data)

                # do lbfgs update (update memory)
                # what happens if current and prev grad are equal, ||y||->0 ??
                y = flat_grad.sub(prev_flat_grad)

                s = d.mul(t)

                if batch_mode:  # y = y+ lm0 * s, to have a trust region
                    y.add_(s, alpha=lm0)

                ys = y.dot(s)  # y^T*s
                sn = s.norm().item()  # ||s||
                # FIXME batch_changed does not work for full batch mode (data might be the same)
                batch_changed = batch_mode and (n_iter == 1 and state['n_iter'] > 1)
                if batch_changed:  # batch has changed
                    # online estimate of mean,variance of gradient (inter-batch, not intra-batch)
                    # newmean <- oldmean + (grad - oldmean)/niter
                    # moment <- oldmoment + (grad-oldmean)(grad-newmean)
                    # variance = moment/(niter-1)

                    g_old = flat_grad.clone()
                    g_old.add_(running_avg, alpha=-1.0)  # grad-oldmean
                    running_avg.add_(g_old, alpha=1.0 / state['n_iter'])  # newmean
                    g_new = flat_grad.clone()
                    g_new.add_(running_avg, alpha=-1.0)  # grad-newmean
                    running_avg_sq.addcmul_(g_new, g_old, value=1)  # +(grad-newmean)(grad-oldmean)
                    alphabar = 1 / (1 + running_avg_sq.sum() / ((state['n_iter'] - 1) * (grad_nrm)))
                    if be_verbose:
                        print('iter %d |mean| %f |var| %f ||grad|| %f step %f y^Ts %f alphabar=%f' % (
                            state['n_iter'], running_avg.sum(), running_avg_sq.sum() / (state['n_iter'] - 1), grad_nrm,
                            t,
                            ys, alphabar))

                if ys > 1e-10 * sn * sn and not batch_changed:
                    # updating memory (only when we have y within a single batch)
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                if math.isnan(H_diag):
                    print('Warning H_diag nan')

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'ro' not in state:
                    state['ro'] = [None] * history_size
                    state['al'] = [None] * history_size
                ro = state['ro']
                al = state['al']

                for i in range(num_old):
                    ro[i] = 1. / old_dirs[i].dot(old_stps[i])

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone()

            else:
                prev_flat_grad.copy_(flat_grad)

            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / abs_grad_sum) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            if math.isnan(gtd.item()):
                print('Warning grad norm infinite')
                print('iter %d' % state['n_iter'])
                print('||grad||=%f' % grad_nrm)
                print('||d||=%f' % d.norm().item())
            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn:
                # perform line search, using user function
                # Note: gradient calculation is disabled during line search
                # because it is not needed
                if not batch_mode:
                    t = search_without_grad(self._linesearch_cubic, closure, d, 1e-6)
                else:
                    t = search_without_grad(self._linesearch_backtrack, closure, d, flat_grad, alphabar)

                if math.isnan(t):
                    print('Warning: stepsize nan')
                    t = lr
                self._add_grad(t, d)  # FF param = param + t * d
                if be_verbose:
                    print('step size=%f' % (t))
            else:
                # FF Here, t = stepsize,  d = -grad, in cache
                # no line search, simply move with fixed-step
                self._add_grad(t, d)  # FF param = param + t * d
            if n_iter != max_iter:
                # re-evaluate function only if not in last iteration
                # the reason we do this: in a stochastic setting,
                # no use to re-evaluate that function here
                loss = float(closure())
                flat_grad = self._gather_flat_grad()
                abs_grad_sum = flat_grad.abs().sum()
                if math.isnan(abs_grad_sum):
                    print('Warning: gradient nan')
                    break
                ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            if abs_grad_sum <= tolerance_grad:
                break

            if gtd > -tolerance_change:
                break

            if d.mul(t).abs_().sum() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        if batch_mode:
            # the loop may not have run at all, in which case the running
            # statistics still have to be created before they are stored
            if running_avg is None:
                running_avg = torch.zeros_like(flat_grad.data)
                running_avg_sq = torch.zeros_like(flat_grad.data)
            state['running_avg'] = running_avg
            state['running_avg_sq'] = running_avg_sq

        return orig_loss

In [None]:
#!/home/youli/miniconda3/bin/python3
# coding=utf8
"""
# Author: youli
# Created Time : 2021-12-27 15:38:05

# File Name: model_construct.py
# Description:
    test for Pytorch model

"""
print(f"pytorch test")

import numpy as np
import pandas as pd
import time
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.optim as optim

# Size constants: train_size is 90% of input_size, test_size is the rest.
input_size = 20000
train_size = int(0.9 * input_size)
test_size = input_size - train_size
batch_size = 64


# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")


 
# Build the SENet-based network structure used for both the Conv Block and the Identity Block
class Block(nn.Module):
    """Bottleneck residual block with a squeeze-and-excitation (SE) gate.

    The main path is 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by
    batch normalisation (and ReLU for the first two).  Its output is scaled
    per channel by the SE gate, added to the shortcut and passed through a
    final ReLU.

    Arguments:
        in_channels: number of channels of the input tensor
        filters: triple with the output channels of the three convolutions
        stride: stride of the first 1x1 convolution (and of the shortcut
            convolution when there is one)
        is_1x1conv: True builds a "Conv Block" whose shortcut is a strided
            1x1 convolution + batch norm; False builds an "Identity Block"
            whose shortcut is the input itself
    """

    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super().__init__()

        # output channels of the three stacked convolutions
        first_out, middle_out, last_out = filters

        self.is_1x1conv = is_1x1conv  # whether this is a Conv Block
        self.relu = nn.ReLU(inplace=True)  # applied after the residual add

        # first unit: 1x1 convolution carrying the block's stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, first_out, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(first_out),
            nn.ReLU(),
        )

        # middle unit: 3x3 convolution that keeps the spatial size
        self.conv2 = nn.Sequential(
            nn.Conv2d(first_out, middle_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(middle_out),
            nn.ReLU(),
        )

        # last unit: 1x1 convolution, no ReLU here
        self.conv3 = nn.Sequential(
            nn.Conv2d(middle_out, last_out, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(last_out),
        )

        # a Conv Block projects its input so that it matches the main path
        if is_1x1conv:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, last_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(last_out),
            )

        # SE gate: global average pool, then two 1x1 convolutions (used in
        # place of fully connected layers) with a reduction ratio of 16
        squeezed = last_out // 16
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(last_out, squeezed, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(squeezed, last_out, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return relu(main_path(x) * se_gate + shortcut(x))."""
        # main path through the three convolution units
        out = self.conv3(self.conv2(self.conv1(x)))

        # per-channel weights from the SE gate, applied to the main path
        out = out * self.se(out)

        # projected input for a Conv Block, the raw input otherwise
        shortcut = self.shortcut(x) if self.is_1x1conv else x

        return self.relu(out + shortcut)

# Build SEResNet50
class SEResnet(nn.Module):
    """ResNet-style classifier assembled from SE bottleneck ``Block`` modules.

    Arguments:
        cfg: dict with
            'classes': number of output classes of the final linear layer
            'num': four block counts, one per stage (e.g. (3, 4, 6, 3));
                the first block of a stage is a Conv Block, the remaining
                ones are Identity Blocks
    """

    def __init__(self, cfg):
        super().__init__()
        n_classes = cfg['classes']  # number of target classes
        depths = cfg['num']  # blocks per stage

        # Stem: 7x7 stride-2 convolution on 3-channel input, then max pooling
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Four stages; each one starts from the previous stage's output width
        self.Stage1 = self._make_layer(in_channels=64, filters=(64, 64, 256), num=depths[0], stride=1)
        self.Stage2 = self._make_layer(in_channels=256, filters=(128, 128, 512), num=depths[1], stride=2)
        self.Stage3 = self._make_layer(in_channels=512, filters=(256, 256, 1024), num=depths[2], stride=2)
        self.Stage4 = self._make_layer(in_channels=1024, filters=(512, 512, 2048), num=depths[3], stride=2)

        # Adaptive average pooling; (1, 1) is the output size (H x W)
        self.global_average_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Classification head applied after the four stages
        self.fc = nn.Sequential(
            nn.Linear(2048, n_classes)
        )

    def _make_layer(self, in_channels, filters, num, stride=1):
        """Build one stage: a Conv Block followed by ``num - 1`` Identity Blocks."""
        # The Conv Block changes the channel count and carries the stride
        blocks = [Block(in_channels, filters, stride=stride, is_1x1conv=True)]

        # Identity Blocks keep the stage's output width (filters[2])
        blocks.extend(
            Block(filters[2], filters, stride=1, is_1x1conv=False)
            for _ in range(1, num)
        )

        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run stem, the four stages, global pooling and the linear head."""
        x = self.conv1(x)

        for stage in (self.Stage1, self.Stage2, self.Stage3, self.Stage4):
            x = stage(x)

        x = self.global_average_pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
 
    
# SeResNet50 parameters (note: calling this function indirectly calls SEResnet; it is a separate function so the structure is easy to change to other ResNet variants)
def SeResNet50(classes=10):
    """Build an ``SEResnet`` with the ResNet-50 stage layout.

    Arguments:
        classes (int): number of output classes (default: 10, which is what
            the previously hard-coded configuration used)

    Returns:
        A new ``SEResnet`` instance.
    """
    cfg = {
        # blocks per stage: one Conv Block each, the rest Identity Blocks
        'num': (3, 4, 6, 3),
        # number of dataset classes
        'classes': classes,
    }

    return SEResnet(cfg)


# Build the network, move it to the selected device and show its layers.
model = SeResNet50().to(device)
print(model)

# Cross-entropy is the active loss; nn.MSELoss() is the commented-out alternative.
loss_fn = nn.CrossEntropyLoss()

# Custom L-BFGS in batch (stochastic) mode with its line search switched on.
# Commented-out alternative that was here before:
#   torch.optim.LBFGS(model.parameters(), lr=1, history_size=100, max_iter=20,
#                     line_search_fn="strong_wolfe")
optimizer_lbfgs = LBFGSNew(
    model.parameters(),
    history_size=100,
    max_iter=20,
    line_search_fn=True,
    batch_mode=True,
)
# device = "cuda" if torch.cuda.is_available() else "cpu"
def train(dataloader, model, loss_fn, optimizer, log_every=100):
    """Train ``model`` for one pass over ``dataloader`` with a closure-based optimizer.

    For every batch a closure (forward pass, plus zero_grad/backward while
    gradients are enabled) is handed to ``optimizer.step``.  After the step
    the loss of the same batch is evaluated once more, without gradients,
    for logging and for the returned average.

    Arguments:
        dataloader: iterable of ``(X, y)`` batches exposing ``.dataset``
        model: the network to train; it is moved to the module-level ``device``
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor
        optimizer: optimizer whose ``step`` accepts a closure
        log_every (int): print the loss every ``log_every`` batches
            (default: 100); a falsy value disables the progress output

    Returns:
        float: mean of the post-update batch losses, or NaN when the
        dataloader yields no batch.
    """
    size = len(dataloader.dataset)
    model.train()
    lm_lbfgs = model.to(device)

    total_loss = 0.0
    num_batches = 0
    for batch, (X, y) in enumerate(dataloader):
        # Inputs do not need gradients; only the model parameters do.
        x_ = X.to(device)
        y_ = y.to(device)

        def closure():
            # The optimizer also calls this with gradients disabled (during
            # its line search); in that case only the loss value is needed.
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            y_pred = lm_lbfgs(x_)
            loss = loss_fn(y_pred, y_)
            if loss.requires_grad:
                loss.backward()
            return loss

        optimizer.step(closure)

        # Loss of this batch after the update; no backward pass is needed.
        with torch.no_grad():
            batch_loss = closure().item()
        total_loss += batch_loss
        num_batches += 1

        if log_every and batch % log_every == 0:
            current = batch * len(X)
            print(f"loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")

    return total_loss / num_batches if num_batches else float("nan")

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the mean batch loss.

    The model is switched to eval mode and run without gradients; every
    batch is moved to the module-level ``device`` first.

    Arguments:
        dataloader: sized iterable of ``(X, y)`` batches
        model: the network to evaluate
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor

    Returns:
        float: sum of the batch losses divided by the number of batches, or
        NaN when the dataloader has no batches.
    """
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            test_loss += loss_fn(model(X), y).item()
    # an empty loader has no meaningful average; avoid dividing by zero
    test_loss = test_loss / num_batches if num_batches else float("nan")
    print(f"Test Error: \n  Avg loss: {test_loss:>8f} \n")
    return test_loss


# Download the CIFAR-10 training split
train_set = torchvision.datasets.CIFAR10(
    root="data/cifar-10",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Download the CIFAR-10 test split
test_set = torchvision.datasets.CIFAR10(
    root="data/cifar-10",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Both loaders shuffle and use two worker processes
train_iter = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=2
)
test_iter = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=True, num_workers=2
)


# training

opt_label = 'lbfgs_original-t20-t20'
epochs = 50
print(f"test for {opt_label}")
optimizer = optimizer_lbfgs
loss_train = []
loss_test = []

t1 = time.perf_counter()
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_iter, model, loss_fn, optimizer)
    # after every epoch, record the average loss on both splits
    loss_train.append(test(train_iter, model, loss_fn))
    loss_test.append(test(test_iter, model, loss_fn))
    print("Done!")
t2 = time.perf_counter()
print("Elapsed time: ", t2 - t1)

# Per-epoch loss curves, written space-separated to a file named after the run
record = pd.DataFrame({
    'epochs': np.arange(epochs),
    'loss_train': np.array(loss_train),
    'loss_test': np.array(loss_test),
})
record.to_csv(f"{opt_label}", sep=' ')

# Persist the trained weights
torch.save(model.state_dict(), f"model{opt_label}.pth")
print(f"Saved PyTorch Model State to model{opt_label}.pth")

In [None]:
import torch
from functools import reduce
from torch.optim.optimizer import Optimizer

import math

be_verbose=False

class LBFGSNew(Optimizer):
    """Implements L-BFGS algorithm.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Arguments:
        lr (float): learning rate (fallback value when line search fails. not really needed) (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 10)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-5).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 7).
        line_search_fn: if True, use a line search to find the step size (cubic interpolation, or backtracking when batch_mode is True); if False, use a fixed step size
        batch_mode: True for stochastic version (default False)

        Example usage for full batch mode:

          optimizer = LBFGSNew(model.parameters(), history_size=7, max_iter=100, line_search_fn=True, batch_mode=False)

        Example usage for batch mode (stochastic):

          optimizer = LBFGSNew(net.parameters(), history_size=7, max_iter=4, line_search_fn=True,batch_mode=True)
          Note: when using a closure(), only do backward() after checking the gradient is available,
          Eg: 
            def closure():
             optimizer.zero_grad()
             outputs=net(inputs)
             loss=criterion(outputs,labels)
             if loss.requires_grad:
               loss.backward()
             return loss

    """

    def __init__(self, params, lr=1, max_iter=10, max_eval=None,
                 tolerance_grad=1e-5, tolerance_change=1e-9, history_size=7,
                 line_search_fn=False, batch_mode=False):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
                        tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
                        history_size=history_size, line_search_fn=line_search_fn,
                        batch_mode=batch_mode)
        super(LBFGSNew, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None
        torch.set_grad_enabled(True)
    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().contiguous().view(-1)
            else:
                view = p.grad.data.contiguous().view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    #FF copy the parameter values out, create a single vector
    def _copy_params_out(self):
        offset = 0
        new_params = []
        for p in self._params:
            numel = p.numel()
            new_param1=p.data.clone().contiguous().view(-1)
            offset += numel
            new_params.append(new_param1)
        assert offset == self._numel()
        return torch.cat(new_params,0)

    #FF copy the parameter values back, dividing the vector into a list
    def _copy_params_in(self,new_params):
        offset = 0
        for p in self._params:
            numel = p.numel()
            p.data.copy_(new_params[offset:offset+numel].view_as(p.data))
            offset += numel
        assert offset == self._numel()

    #FF line search xk=self._params, pk=step direction, gk=gradient, alphabar=max. step size
    def _linesearch_backtrack(self,closure,pk,gk,alphabar):
        """Backtracking line search along ``pk`` starting from step ``alphabar``.

        The step is halved until ``f_new <= f_old + alphak*c1*(gk . pk)`` (and
        ``f_new`` is not NaN) or the iteration budget ``citer`` is exhausted.
        If the resulting decrease is smaller than ``|c1*(gk . pk)|``, steps in
        the opposite direction (starting at ``-alphabar``) are tried as well,
        and the negative step is selected when it gives a lower loss.

        The parameters are restored to their entry values before returning;
        applying the returned step is left to the caller.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
            pk: step direction vector
            gk: gradient vector
            alphabar: max step size

        Returns:
            The selected step size (``alphabar`` scaled by a power of 0.5,
            possibly negated).
        """


        # constants (FIXME) find proper values
        # c1: large values better for small batch sizes
        c1=1e-4
        citer=35
        alphak=alphabar# default return step
 
        # state parameter 
        state = self.state[self._params[0]]

        # make a copy of original params
        xk=self._copy_params_out()

   
        # loss at the current point and at the full trial step
        f_old=float(closure())
        # param = param + alphak * pk
        self._add_grad(alphak, pk)
        f_new=float(closure())

        # prod = c1 * ( alphak ) * gk^T pk = alphak * prodterm
        s=gk
        prodterm=c1*(s.dot(pk))

        ci=0
        if be_verbose:
         print('LN %d alpha=%f fnew=%f fold=%f prod=%f'%(ci,alphak,f_new,f_old,prodterm))
        # catch cases where f_new is NaN
        # halve the step until the sufficient-decrease test passes
        while (ci<citer and (math.isnan(f_new) or  f_new > f_old + alphak*prodterm)):
           alphak=0.5*alphak
           self._copy_params_in(xk)
           self._add_grad(alphak, pk)
           f_new=float(closure())
           if be_verbose:
             print('LN %d alpha=%f fnew=%f fold=%f'%(ci,alphak,f_new,f_old))
           ci=ci+1

        # if the cost is not sufficiently decreased, also try -ve steps
        # NOTE(review): ci is not reset here, so the negative search only gets
        # whatever is left of the citer budget after the positive search.
        if (f_old-f_new < torch.abs(prodterm)):
          alphak1=-alphabar
          self._copy_params_in(xk)
          self._add_grad(alphak1, pk)
          f_new1=float(closure())
          if be_verbose:
            print('NLN fnew=%f'%f_new1)
          while (ci<citer and (math.isnan(f_new1) or  f_new1 > f_old + alphak1*prodterm)):
             alphak1=0.5*alphak1
             self._copy_params_in(xk)
             self._add_grad(alphak1, pk)
             f_new1=float(closure())
             if be_verbose:
               print('NLN %d alpha=%f fnew=%f fold=%f'%(ci,alphak1,f_new1,f_old))
             ci=ci+1

          if f_new1<f_new:
            # select -ve step
            alphak=alphak1

        # recover original params
        self._copy_params_in(xk)
        # update state
        # NOTE(review): only loop iterations are counted; the closure calls
        # made outside the two loops (f_old, first f_new, first f_new1) are not.
        state['func_evals'] += ci
        return alphak



    #FF line search xk=self._params, pk=gradient
    def _linesearch_cubic(self,closure,pk,step):
        """Line search (strong-Wolfe)

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
            pk: gradient vector 
            step: step size for differencing 
        """

        # constants
        alpha1=10*self.param_groups[0]['lr']#10.0
        sigma=0.1
        rho=0.01
        t1=9 
        t2=0.1
        t3=0.5
        alphak=self.param_groups[0]['lr']# default return step
 
        # state parameter 
        state = self.state[self._params[0]]

        # make a copy of original params
        xk=self._copy_params_out()

   
        phi_0=float(closure())
        tol=min(phi_0*0.01,1e-6)

        # xp <- xk+step. pk
        self._add_grad(step, pk) #FF param = param + t * grad 
        p01=float(closure())
        # xp <- xk-step. pk
        self._add_grad(-2.0*step, pk) #FF param = param - t * grad 
        p02=float(closure())

        ##print("p01="+str(p01)+" p02="+str(p02))
        gphi_0=(p01-p02)/(2.0*step)
        ##print("tol="+str(tol)+" phi_0="+str(phi_0)+" gphi_0="+str(gphi_0))
        # catch instances when step size is too small 
        if abs(gphi_0)<1e-12:
          return 1.0

        mu=(tol-phi_0)/(rho*gphi_0)
        # catch if mu is not finite
        if math.isnan(mu):
           return 1.0

        ##print("mu="+str(mu))
        
        # counting function evals
        closure_evals=3

        ci=1
        alphai=alpha1 # initial value for alpha(i) : check if 0<alphai<=mu 
        alphai1=0.0
        phi_alphai1=phi_0
        while (ci<4) : # FIXME
          # evalualte phi(alpha(i))=f(xk+alphai pk)
          self._copy_params_in(xk) # original
          # xp <- xk+alphai. pk
          self._add_grad(alphai, pk) #
          phi_alphai=float(closure())
          if phi_alphai<tol:
             alphak=alphai 
             if be_verbose:
              print("Linesearch: condition 0 met")
             break
          if (phi_alphai>phi_0+alphai*gphi_0) or (ci>1 and phi_alphai>=phi_alphai1) :
             # ai=alphai1, bi=alphai bracket
             if be_verbose:
              print("bracket "+str(alphai1)+","+str(alphai))
             alphak=self._linesearch_zoom(closure,xk,pk,alphai1,alphai,phi_0,gphi_0,sigma,rho,t1,t2,t3,step)
             if be_verbose:
              print("Linesearch: condition 1 met") 
             break

          # evaluate grad(phi(alpha(i))) */
          # note that self._params already is xk+alphai. pk, so only add the missing term
          # xp <- xk+(alphai+step). pk
          self._add_grad(step, pk) #FF param = param - t * grad 
          p01=float(closure())
          # xp <- xk+(alphai-step). pk
          self._add_grad(-2.0*step, pk) #FF param = param - t * grad 
          p02=float(closure())
          gphi_i=(p01-p02)/(2.0*step);
        
          if (abs(gphi_i)<=-sigma*gphi_0):
             alphak=alphai
             if be_verbose:
              print("Linesearch: condition 2 met") 
             break

          if gphi_i>=0.0 :
             # ai=alphai, bi=alphai1 bracket
             if be_verbose:
              print("bracket "+str(alphai)+","+str(alphai1))
             alphak=self._linesearch_zoom(closure,xk,pk,alphai,alphai1,phi_0,gphi_0,sigma,rho,t1,t2,t3,step)
             if be_verbose:
              print("Linesearch: condition 3 met") 
             break
          # else preserve old values
          if (mu<=2.0*alphai-alphai1):
             alphai1=alphai
             alphai=mu
          else:
             # choose by interpolation in [2*alphai-alphai1,min(mu,alphai+t1*(alphai-alphai1)] 
            p01=2.0*alphai-alphai1;
            p02=min(mu,alphai+t1*(alphai-alphai1))
            alphai=self._cubic_interpolate(closure,xk,pk,p01,p02,step)


          phi_alphai1=phi_alphai;
          # update function evals
          closure_evals +=3
          ci=ci+1

          


        # recover original params
        self._copy_params_in(xk)
        # update state
        state['func_evals'] += closure_evals
        return alphak


    def _cubic_interpolate(self,closure,xk,pk,a,b,step):
        """ Cubic interpolation within interval [a,b] or [b,a] (a>b is possible)
          
           Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
            xk: copy of parameter values 
            pk: gradient vector 
            a/b:  interval for interpolation
            step: step size for differencing 
        """


        self._copy_params_in(xk)

        # state parameter 
        state = self.state[self._params[0]]
        # count function evals
        closure_evals=0

        # xp <- xk+a. pk
        self._add_grad(a, pk) #FF param = param + t * grad 
        f0=float(closure())
        # xp <- xk+(a+step). pk
        self._add_grad(step, pk) #FF param = param + t * grad 
        p01=float(closure())
        # xp <- xk+(a-step). pk
        self._add_grad(-2.0*step, pk) #FF param = param - t * grad 
        p02=float(closure())
        f0d=(p01-p02)/(2.0*step)

        # xp <- xk+b. pk
        self._add_grad(-a+step+b, pk) #FF param = param + t * grad 
        f1=float(closure())
        # xp <- xk+(b+step). pk
        self._add_grad(step, pk) #FF param = param + t * grad 
        p01=float(closure())
        # xp <- xk+(b-step). pk
        self._add_grad(-2.0*step, pk) #FF param = param - t * grad 
        p02=float(closure())
        f1d=(p01-p02)/(2.0*step)

        closure_evals=6

        aa=3.0*(f0-f1)/(b-a)+f1d-f0d
        p01=aa*aa-f0d*f1d
        if (p01>0.0):
           cc=math.sqrt(p01)
           #print('f0='+str(f0d)+' f1='+str(f1d)+' cc='+str(cc))
           if (f1d-f0d+2.0*cc)==0.0:
             return (a+b)*0.5
           z0=b-(f1d+cc-aa)*(b-a)/(f1d-f0d+2.0*cc)
           aa=max(a,b)
           cc=min(a,b)
           if z0>aa or z0<cc:
             fz0=f0+f1
           else:
             # xp <- xk+(a+z0*(b-a))*pk
             self._add_grad(-b+step+a+z0*(b-a), pk) #FF param = param + t * grad 
             fz0=float(closure())
             closure_evals +=1

           # update state
           state['func_evals'] += closure_evals

           if f0<f1 and f0<fz0:
             return a

           if f1<fz0:
             return b
           # else
           return z0
        else:

           # update state
           state['func_evals'] += closure_evals

           if f0<f1:
             return a
           else:
             return b

        # update state
        state['func_evals'] += closure_evals

        # fallback value
        return (a+b)*0.5
     



    #FF bracket [a,b]
    # xk: copy of parameters, use it to refresh self._param 
    def _linesearch_zoom(self,closure,xk,pk,a,b,phi_0,gphi_0,sigma,rho,t1,t2,t3,step):
        """Zoom step in line search: shrink the bracket [a,b] to a step length.

        Runs at most 4 iterations.  In each one a trial step ``alphaj`` is
        chosen by ``_cubic_interpolate`` inside a sub-interval of the current
        bracket ``[aj,bj]``, and the bracket is then updated from the loss
        (and, when needed, a central-difference slope) at ``alphaj``.

        The parameters are not restored on return; the caller resets them
        from ``xk``.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
            xk: copy of parameter values 
            pk: flat search-direction vector
            a/b:  bracket interval for line search, 
            phi_0: phi(0)
            gphi_0: grad(phi(0))
            sigma,rho,t1,t2,t3: line search parameters (from Fletcher) 
            step: step size for differencing 

        Returns:
            The step that met a termination test, or the last trial step if
            none did within the iteration limit.
        """

        # state parameter 
        state = self.state[self._params[0]]
        # count function evals
        # (evaluations made inside _cubic_interpolate are added to the state
        # by that method itself)
        closure_evals=0

        aj=a
        bj=b
        ci=0
        found_step=False
        while ci<4: # FIXME original 10
           # choose alphaj from [a+t2(b-a),b-t3(b-a)]
           p01=aj+t2*(bj-aj)
           p02=bj-t3*(bj-aj)
           alphaj=self._cubic_interpolate(closure,xk,pk,p01,p02,step)

           # evaluate phi(alphaj)
           self._copy_params_in(xk)
           # xp <- xk+alphaj. pk
           self._add_grad(alphaj, pk) #FF param = param + t * grad 
           phi_j=float(closure())
          
           # evaluate phi(aj)
           # xp <- xk+aj. pk
           self._add_grad(-alphaj+aj, pk) #FF param = param + t * grad 
           phi_aj=float(closure())

           closure_evals +=2

           # insufficient decrease at alphaj, or no improvement over aj:
           # the trial step becomes the new far end of the bracket
           if (phi_j>phi_0+rho*alphaj*gphi_0) or phi_j>=phi_aj :
              bj=alphaj # aj is unchanged
           else:
              # evaluate grad(alphaj)
              # xp <- xk+(alphaj+step). pk
              self._add_grad(-aj+alphaj+step, pk) #FF param = param + t * grad 
              p01=float(closure())
              # xp <- xk+(alphaj-step). pk
              self._add_grad(-2.0*step, pk) #FF param = param + t * grad 
              p02=float(closure())
              gphi_j=(p01-p02)/(2.0*step)
        

              closure_evals +=2

              # termination due to roundoff/other errors pp. 38, Fletcher
              if (aj-alphaj)*gphi_j <= step:
                 alphak=alphaj
                 found_step=True
                 break
             
              # slope at alphaj small enough relative to the initial slope
              if abs(gphi_j)<=-sigma*gphi_0 :
                 alphak=alphaj
                 found_step=True
                 break

              if gphi_j*(bj-aj)>=0.0:
                 bj=aj
              # else bj is unchanged
              aj=alphaj


           ci=ci+1
        
        # iteration limit reached: fall back to the last trial step
        if not found_step:
          alphak=alphaj

        # update state
        state['func_evals'] += closure_evals

        return alphak


    def step(self, closure):
        """Performs a single optimization step.

        One call runs up to ``max_iter`` L-BFGS iterations.  The curvature
        history, the last direction/step and the previous gradient are kept
        in the optimizer state, so they carry over between calls.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by the first ``closure()`` call of this step
            (the loss before any update was applied).
        """
        assert len(self.param_groups) == 1

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        batch_mode = group['batch_mode']


        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)


        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        abs_grad_sum = flat_grad.abs().sum()

        # already at a stationary point (L1 norm of the gradient): nothing to do
        if abs_grad_sum <= tolerance_grad:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        # iterations within this call (state['n_iter'] counts across calls)
        n_iter = 0

        if batch_mode:
          # initial max step for the backtracking search, and the damping
          # factor added to y below
          alphabar=lr
          lm0=1e-6

        # optimize for a max of max_iter iterations
        grad_nrm=flat_grad.norm().item()
        while n_iter < max_iter and not math.isnan(grad_nrm):
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                # very first iteration: steepest descent, empty history
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                H_diag = 1
                if batch_mode:
                 running_avg=torch.zeros_like(flat_grad.data)
                 running_avg_sq=torch.zeros_like(flat_grad.data)
            else:
                if batch_mode:
                 running_avg=state.get('running_avg')
                 running_avg_sq=state.get('running_avg_sq')
                 if running_avg is None:
                  running_avg=torch.zeros_like(flat_grad.data)
                  running_avg_sq=torch.zeros_like(flat_grad.data)

                # do lbfgs update (update memory) 
                # what happens if current and prev grad are equal, ||y||->0 ??
                # y: gradient difference, s: parameter displacement of the
                # previous iteration (direction d times step t)
                y = flat_grad.sub(prev_flat_grad)

                s = d.mul(t)

                if batch_mode: # y = y+ lm0 * s, to have a trust region
                  y.add_(s,alpha=lm0)

                ys = y.dot(s)  # y^T*s
                sn = s.norm().item()  # ||s||
                # FIXME batch_changed does not work for full batch mode (data might be the same)
                # first iteration of a later step() call is treated as the
                # start of a new batch
                batch_changed= batch_mode and (n_iter==1 and state['n_iter']>1)
                if batch_changed: # batch has changed
                   # online estimate of mean,variance of gradient (inter-batch, not intra-batch)
                   # newmean <- oldmean + (grad - oldmean)/niter
                   # moment <- oldmoment + (grad-oldmean)(grad-newmean)
                   # variance = moment/(niter-1)

                   g_old=flat_grad.clone()
                   g_old.add_(running_avg,alpha=-1.0) # grad-oldmean
                   running_avg.add_(g_old,alpha=1.0/state['n_iter']) # newmean
                   g_new=flat_grad.clone()
                   g_new.add_(running_avg,alpha=-1.0) # grad-newmean
                   running_avg_sq.addcmul_(g_new,g_old,value=1) # +(grad-newmean)(grad-oldmean)
                   # max step shrinks as the gradient variance estimate
                   # grows relative to the gradient norm
                   alphabar=1/(1+running_avg_sq.sum()/((state['n_iter']-1)*(grad_nrm)))
                   if be_verbose:
                     print('iter %d |mean| %f |var| %f ||grad|| %f step %f y^Ts %f alphabar=%f'%(state['n_iter'],running_avg.sum(),running_avg_sq.sum()/(state['n_iter']-1),grad_nrm,t,ys,alphabar))


                # curvature check: only accept pairs with y^T s sufficiently positive
                if ys > 1e-10*sn*sn and not batch_changed :
                    # updating memory (only when we have y within a single batch)
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)

                    # store new direction/step
                    # (old_dirs holds the y vectors, old_stps the s vectors)
                    old_dirs.append(y)
                    old_stps.append(s)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                if math.isnan(H_diag):
                  print('Warning H_diag nan')

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'ro' not in state:
                    state['ro'] = [None] * history_size
                    state['al'] = [None] * history_size
                ro = state['ro']
                al = state['al']

                for i in range(num_old):
                    ro[i] = 1. / old_dirs[i].dot(old_stps[i])

                # iteration in L-BFGS loop collapsed to use just one buffer
                # first loop of the two-loop recursion: newest pair to oldest
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i],alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                # second loop: oldest pair to newest (d and r alias one tensor)
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i],alpha=al[i] - be_i)

            # remember the gradient this direction was computed from
            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone()

            else:
                prev_flat_grad.copy_(flat_grad)

            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / abs_grad_sum) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            if math.isnan(gtd.item()):
              print('Warning grad norm infinite')
              print('iter %d'%state['n_iter'])
              print('||grad||=%f'%grad_nrm)
              print('||d||=%f'%d.norm().item())
            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn:
                # perform line search, using user function
                ##raise RuntimeError("line search function is not supported yet")
                #FF#################################
                # Note: we disable gradient calculation during line search
                # because it is not needed
                torch.set_grad_enabled(False)
                # full batch: cubic line search with differencing step 1e-6;
                # batch mode: backtracking from alphabar
                if not batch_mode:
                 t=self._linesearch_cubic(closure,d,1e-6) 
                else:
                 t=self._linesearch_backtrack(closure,d,flat_grad,alphabar)
                torch.set_grad_enabled(True)

                if math.isnan(t):
                  print('Warning: stepsize nan')
                  t=lr
                self._add_grad(t, d) #FF param = param + t * d 
                if be_verbose:
                 print('step size=%f'%(t))
                #FF#################################
            else:
                #FF Here, t = stepsize,  d = -grad, in cache
                # no line search, simply move with fixed-step
                self._add_grad(t, d) #FF param = param + t * d 
            if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    abs_grad_sum = flat_grad.abs().sum()
                    if math.isnan(abs_grad_sum):
                       print('Warning: gradient nan')
                       break
                    ls_func_evals = 1

            # update func eval
            # (the line-search helpers add their own evaluations to
            # state['func_evals'] directly)
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            if abs_grad_sum <= tolerance_grad:
                break

            # direction is no longer a (sufficient) descent direction
            if gtd > -tolerance_change:
                break

            # parameter change too small
            if d.mul(t).abs_().sum() <= tolerance_change:
                break

            # loss change too small
            if abs(loss - prev_loss) < tolerance_change:
                break

        # persist everything the next call needs to continue the recursion
        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        if batch_mode:
         # running_avg is unbound if the loop body never ran in this call
         if 'running_avg' not in locals() or running_avg is None:
           running_avg=torch.zeros_like(flat_grad.data)
           running_avg_sq=torch.zeros_like(flat_grad.data)
         state['running_avg']=running_avg
         state['running_avg_sq']=running_avg_sq
   

        return orig_loss

In [None]:
sdsds

In [None]:
#!/home/youli/miniconda3/bin/python3
# coding=utf8
"""
# Author: youli
# Created Time : 2021-12-27 15:38:05

# File Name: model_construct.py
# Description:
    test for Pytorch model

"""
print(f"pytorch test")

import numpy as np
import pandas as pd
import time
import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt



# Problem size: 90% of the samples are used for training, the rest for testing.
input_size = 20000
train_size = int(input_size*0.9)
test_size = input_size - train_size
batch_size = 1000

# Evenly spaced points on [-1, 1], drawn without replacement so that they
# end up in random order.
x_total = np.linspace(-1.0, 1.0, input_size, dtype=np.float32)
x_total = np.random.choice(x_total, size=input_size, replace=False)

# The first train_size points form the training inputs, the remainder the
# test inputs; both are reshaped to column vectors.
x_train = torch.from_numpy(x_total[:train_size].reshape((train_size, 1)))
x_test = torch.from_numpy(x_total[train_size:].reshape((test_size, 1)))

# Regression targets: np.sinc evaluated at 10*x.
y_train = torch.from_numpy(np.sinc(10.0 * x_train))
y_test = torch.from_numpy(np.sinc(10.0 * x_test))

training_data = TensorDataset(x_train, y_train)
test_data = TensorDataset(x_test, y_test)

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Show the shape of the first batch of each loader (test loader first).
X, y = next(iter(test_dataloader))
print("Shape of X: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)
X, y = next(iter(train_dataloader))
print("Shape of X: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)

# Get cpu or gpu device for training.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    """Small fully connected network: Linear(1, 20) -> Tanh -> Linear(20, 1)."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(1, 20),
            nn.Tanh(),
            # An optional second hidden layer was left disabled here:
            # nn.Linear(20,20),
            # nn.Tanh(),
            nn.Linear(20, 1),
        ]
        self.tanh_linear = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return its output."""
        return self.tanh_linear(x)

# Build the network and move it to the selected device.
model = NeuralNetwork().to(device)
print(model)

# Mean-squared-error loss, and an Adam optimizer (learning rate 1e-2) over
# all model parameters.
loss_fn = nn.MSELoss()
optimizer_adam = torch.optim.Adam(model.parameters(), lr=1e-2)

def train(dataloader, model, loss_fn, optimizer):
    """Run one pass over ``dataloader``, taking one optimizer step per batch.

    Each batch is moved to the module-level ``device``.  The loss is printed
    for every batch whose index is a multiple of the module-level
    ``train_size``.

    Returns:
        The module-level ``loss_train`` object.
        NOTE(review): this is not something the function computes; callers
        in this file ignore the return value - confirm before relying on it.
    """
    n_samples = len(dataloader.dataset)
    model.train()
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass and prediction error.
        batch_loss = loss_fn(model(inputs), targets)

        # Backpropagation and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if batch_idx % train_size != 0:
            continue
        loss_value = batch_loss.item()
        seen = batch_idx * len(inputs)
        print(f"loss: {loss_value:>7f}  [{seen:>5d}/{n_samples:>5d}]")
    return loss_train

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the mean per-batch loss.

    The model is switched to eval mode and gradients are disabled while
    iterating.  Every batch is moved to the module-level ``device`` before
    the forward pass.

    Arguments:
        dataloader: iterable of ``(X, y)`` batches that also supports ``len()``.
        model: callable mapping a batch ``X`` to predictions.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.

    Returns:
        float: sum of the per-batch losses divided by the number of batches.
    """
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            test_loss += loss_fn(model(X), y).item()
    # Average of the batch means (not a per-sample mean when the last batch
    # is smaller than the others).
    test_loss /= num_batches
    print(f"Test Error: \n  Avg loss: {test_loss:>8f} \n")
    return test_loss

# Training run: optimise with ``optimizer_adam`` and log the loss on both
# loaders after every epoch.

opt_label = f'adam_t20'
epochs = 1000
print(f"test for {opt_label}")
optimizer = optimizer_adam
loss_train = []
loss_test = []

t1 = time.perf_counter()
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    # Evaluate on the training loader first, then on the test loader.
    loss_train.append(test(train_dataloader, model, loss_fn))
    loss_test.append(test(test_dataloader, model, loss_fn))
    print("Done!")
t2 = time.perf_counter()
print("Elapsed time: ", t2- t1)

# One row per epoch, written space-separated to a file named after the label.
record = pd.DataFrame({
    'epochs': np.arange(epochs),
    'loss_train': np.array(loss_train),
    'loss_test': np.array(loss_test),
})
record.to_csv(f"{opt_label}", sep=' ')

torch.save(model.state_dict(), f"model{opt_label}.pth")
print(f"Saved PyTorch Model State to model{opt_label}.pth")

In [None]:
sdfs

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import math
import time

# How many models (==slaves)
K = 10
# train K models by Federated learning
# each iteration over a subset of parameters: 1) average 2) pass back average to slaves 3) SGD step
# initialize with pre-trained models (better to use common initialization)
# loop order: loop 0: parameters/layers   {
#               loop 1 : {  averaging (part of the model)
#                loop 2: { epochs/databatches  { train; } } } }
# repeat this Nloop times


torch.manual_seed(69)
# minibatch size
default_batch = 128  # no. of batches per model is (50000/K)/default_batch
Nloop = 12  # how many loops over the whole network
Nepoch = 1  # how many epochs?
Nadmm = 5  # how many ADMM iterations

# regularization
lambda1 = 0.0001  # L1 sweet spot 0.00031
lambda2 = 0.0001  # L2 sweet spot ?
admm_rho0 = 0.1  # ADMM penalty, default value
# note that per each slave, and per each layer, there will be a unique rho value

load_model = False
init_model = True
save_model = True
check_results = True
# if input is biased, each 1/K training data will have
# (slightly) different normalization. Otherwise, same normalization
biased_input = True
be_verbose = False

bb_update = False  # if true, use adaptive ADMM (Barzilai-Borwein) update
if bb_update:
    # periodicity for the rho update, normally > 1
    bb_period_T = 2
    bb_alphacorrmin = 0.2  # minimum correlation required before an update is done
    bb_epsilon = 1e-3  # threshold to stop updating
    bb_rhomax = 0.1  # keep regularization below a safe upper limit


# Set this to true for using ResNet instead of simpler models
# In that case, instead of one layer, one block will be trained
use_resnet = False

# (try to) use a GPU for computation?
use_cuda = True
if use_cuda and torch.cuda.is_available():
    mydevice = torch.device('cuda')
else:
    mydevice = torch.device('cpu')


# split 50000 training data into K subsets (last one will be smaller if K is not a divisor)
K_perslave = math.floor((50000+K-1)/K)
subsets_dict = {}
for ck in range(K):
    # range() already excludes its end point, so the end is the start of the
    # next subset (clamped to the dataset size); subtracting 1 would drop a sample
    subset_start = K_perslave*ck
    subset_end = min(K_perslave*(ck+1), 50000)
    subsets_dict[ck] = range(subset_start, subset_end)

transforms_dict = {}
for ck in range(K):
    if biased_input:
        # slightly different normalization for each subset
        channel_stats = (0.5+ck/100, 0.5-ck/100, 0.5)
    else:
        # same normalization for all training data
        channel_stats = (0.5, 0.5, 0.5)
    # the same triple is passed as both arguments of Normalize
    transforms_dict[ck] = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ])


trainset_dict = {}
testset_dict = {}
trainloader_dict = {}
testloader_dict = {}
for ck in range(K):
    worker_transform = transforms_dict[ck]
    trainset_dict[ck] = torchvision.datasets.CIFAR10(
        root='./torchdata', train=True, download=True, transform=worker_transform)
    testset_dict[ck] = torchvision.datasets.CIFAR10(
        root='./torchdata', train=False, download=True, transform=worker_transform)
    # training batches are drawn only from this model's index subset
    trainloader_dict[ck] = torch.utils.data.DataLoader(
        trainset_dict[ck],
        batch_size=default_batch,
        shuffle=False,
        sampler=torch.utils.data.SubsetRandomSampler(subsets_dict[ck]),
        num_workers=1,
    )
    # every model is evaluated on the full test split, using its own transform
    testloader_dict[ck] = torch.utils.data.DataLoader(
        testset_dict[ck],
        batch_size=default_batch,
        shuffle=False,
        num_workers=0,
    )

import numpy as np

# define a cnn
from simple_models import *

net_dict = {}

# every model index gets its own instance of the selected architecture
make_net = ResNet18 if use_resnet else Net
for ck in range(K):
    net_dict[ck] = make_net().to(mydevice)
    if not load_model:
        continue
    # update from saved models
    checkpoint = torch.load('./s'+str(ck)+'.model', map_location=mydevice)
    net_dict[ck].load_state_dict(checkpoint['model_state_dict'])
    net_dict[ck].train()

########################################################################### helper functions
from simple_utils import *

def verification_error_check(net_dict):
    """
    Prints the test-set accuracy (in percent) of every model in ``net_dict``.

    Inputs:
        net_dict (dict): maps the model index ``0..K-1`` to a network

    Note:
      . ``K``, ``testloader_dict`` and ``mydevice`` are module-level names.
      . The networks' train/eval mode is left untouched.

    """
    for ck in range(K):
        correct = 0
        total = 0
        net = net_dict[ck]
        # inference only: do not record autograd graphs for the forward passes
        with torch.no_grad():
            for images, labels in testloader_dict[ck]:
                outputs = net(images.to(mydevice))
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels.to(mydevice)).sum().item()
                total += labels.size(0)

        # true division: the value is printed with %f, so keep its fractional part
        print('Accuracy of the network %d on the %d test images:%%%f'%
            (ck, total, 100.0*correct/total))
##############################################################################################

if init_model:
    for ck in range(K):
        # note: use same seed for random number generation
        torch.manual_seed(0)
        net_dict[ck].apply(init_weights)

# one loss object per model
criteria_dict = {ck: nn.CrossEntropyLoss() for ck in range(K)}

# get layer ids in given order 0..L-1 for selective training
np.random.seed(0)  # get same list
Li = net_dict[0].train_order_block_ids()
L = len(Li)

# regularization (per layer, per slave)
# Note: need to scale rho down when starting from scratch
# this will be updated when using adaptive ADMM (bb_update=True)
rho = torch.ones(L, 3).to(mydevice)*admm_rho0


# from lbfgsnew import LBFGSNew # custom optimizer
import torch.optim as optim
############### loop 00 (over the full net)
# Main training loop. For each of the Nloop passes and each block index ci:
#   1. unfreeze block ci in every model and size the consensus/multiplier vectors,
#   2. run Nadmm ADMM rounds, each consisting of
#        step 1: local optimisation of every model on its own loader,
#        step 2: recomputing the shared vector z from all models,
#        step 3: updating the per-model Lagrange multipliers y_dict[ck].
# NOTE(review): unfreeze_one_block and get_trainable_values are defined outside this
# cell (presumably in simple_utils) -- their exact semantics are not visible here.
for nloop in range(Nloop):
  ############ loop 0 (over layers of the network)
  for ci in range(L):
   for ck in range(K):
      unfreeze_one_block(net_dict[ck],ci)
   # model 0 is used to measure the size of the currently trainable parameter vector
   trainable=filter(lambda p: p.requires_grad, net_dict[0].parameters())
   params_vec1=torch.cat([x.view(-1) for x in list(trainable)])
  
   # number of parameters trained
   N=params_vec1.numel()
   # z: shared vector that every model is pulled towards, starts at zero
   z=torch.empty(N,dtype=torch.float,requires_grad=False).to(mydevice)
   z.fill_(0.0)
   # y_dict[ck]: Lagrange multiplier of model ck, starts at zero
   y_dict={}
   for ck in range(K):
      y_dict[ck]=torch.empty(N,dtype=torch.float,requires_grad=False).to(mydevice)
      y_dict[ck].fill_(0.0)

   if bb_update: # extra storage for adaptive ADMM
      yhat_dict={}
      yhat0_dict={}
      x0_dict={}
      for ck in range(K):
         yhat_dict[ck]=torch.empty(N,dtype=torch.float,requires_grad=False).to(mydevice)
         yhat_dict[ck].fill_(0.0)
         # x0_dict[ck] is left uninitialised here; it is overwritten at nadmm==0 below
         x0_dict[ck]=torch.empty(N,dtype=torch.float,requires_grad=False).to(mydevice)
         yhat0_dict[ck]=get_trainable_values(net_dict[ck],mydevice)
      
  
   # a fresh optimizer per model, built over the parameters that currently require grad
   opt_dict={}
   for ck in range(K):
    opt_dict[ck]=LBFGSNew(filter(lambda p: p.requires_grad, net_dict[ck].parameters()), history_size=10, max_iter=4, line_search_fn=True,batch_mode=True)
    # opt_dict[ck]=optim.Adam(filter(lambda p: p.requires_grad, net_dict[ck].parameters()),lr=0.001)
  
   ############# loop 1 (ADMM for subset of model)
   for nadmm in range(Nadmm):
     ##### loop 2 (data) (all network updates are done per epoch, because K is large
     ##### and data per host is assumed to be small)
     for epoch in range(Nepoch):

        #### loop 3 (models)
        for ck in range(K):
          running_loss=0.0
  
          for i,data1 in enumerate(trainloader_dict[ck],0):
            # get the inputs
            inputs1,labels1=data1
            # wrap them in variable
            inputs1,labels1=Variable(inputs1).to(mydevice),Variable(labels1).to(mydevice)
    
 
            # Closure handed to the optimizer: returns the loss for the current batch
            # and runs backward only when the loss requires grad.
            def closure1():
                 if torch.is_grad_enabled():
                    opt_dict[ck].zero_grad()
                 outputs=net_dict[ck](inputs1)
                 # augmented lagrangian terms y^T (x-z) + rho/2 ||x-z||^2
                 # the parameter vector is re-flattened on every call so it reflects the current weights
                 trainable=filter(lambda p: p.requires_grad, net_dict[ck].parameters())
                 params_vec1=torch.cat([x.view(-1) for x in list(trainable)])
                 xdelta=params_vec1-z
                 augmented_terms=(torch.dot(y_dict[ck],xdelta))+0.5*rho[ci,0]*(torch.norm(xdelta,2)**2)
                 loss=criteria_dict[ck](outputs,labels1)+augmented_terms
                 # L1 + squared-L2 regularization is added only for blocks listed by linear_layer_ids()
                 if ci in net_dict[ck].linear_layer_ids():
                    loss+=lambda1*torch.norm(params_vec1,1)+lambda2*(torch.norm(params_vec1,2)**2)
                 if loss.requires_grad:
                    loss.backward()
                 return loss
  
            # ADMM step 1
            opt_dict[ck].step(closure1)
  
            # only for diagnostics
            outputs1=net_dict[ck](inputs1)
            loss1=criteria_dict[ck](outputs1,labels1).data.item()
            running_loss +=loss1
           
            if be_verbose:
              print('model=%d block=[%d,%d] %d(%d) minibatch=%d epoch=%d loss %e'%(ck,Li[ci][0],Li[ci][1],nloop,N,i,epoch,loss1))
         
        # ADMM step 2 update global z
        x_dict={}
        for ck in range(K):
          x_dict[ck]=get_trainable_values(net_dict[ck],mydevice)

        # decide and update rho for this ADMM iteration (not the first iteration)
        if bb_update:
          if nadmm==0:
            # store for next use
            for ck in range(K):
              x0_dict[ck]=x_dict[ck]
          elif (nadmm%bb_period_T)==0:
            for ck in range(K):
              yhat_1=y_dict[ck]+rho[ci,0]*(x_dict[ck]-z)
              deltay1=yhat_1-yhat0_dict[ck]
              deltax1=x_dict[ck]-x0_dict[ck]
              # inner products
              d11=torch.dot(deltay1,deltay1)
              d12=torch.dot(deltay1,deltax1) # note: can be negative
              d22=torch.dot(deltax1,deltax1)

              print('admm %d deltas=(%e,%e,%e)'%(nadmm,d11,d12,d22))
              # default: keep the current rho unless the checks below accept a new value
              rhonew=rho[ci,0]
              # catch situation where denominator is very small
              if torch.abs(d12).item()>bb_epsilon and d11.item()>bb_epsilon and d22.item()>bb_epsilon:
                 # alpha: normalised correlation between deltay1 and deltax1
                 alpha=d12/torch.sqrt(d11*d22)
                 alphaSD=d11/d22
                 alphaMG=d12/d22

                 if 2.0*alphaMG>alphaSD:
                   alphahat=alphaMG
                 else:
                   alphahat=alphaSD-0.5*alphaMG
                 if alpha>=bb_alphacorrmin and alphahat<bb_rhomax: # catches d12 being negative
                   rhonew=alphahat
                 print('admm %d alphas=(%e,%e,%e)'%(nadmm,alpha,alphaSD,alphaMG))

              # note: rho[ci,0] is shared, so it is overwritten once per model inside this loop
              rho[ci,0]=rhonew
              ###############

              # carry forward current values for the next update
              yhat0_dict[ck]=yhat_1
              x0_dict[ck]=x_dict[ck]


        # z = sum_k (y_k + rho*x_k) / (K*rho)
        znew=torch.zeros(x_dict[0].shape).to(mydevice)
        for ck in range(K):
         # sum (y+rho x)
         znew=znew+y_dict[ck]+rho[ci,0]*x_dict[ck]
        znew=znew/(K*rho[ci,0])

        # change of z between consecutive rounds
        dual_residual=torch.norm(z-znew).item()/N # per parameter
        z=znew

        # -> master will send z to all slaves
        # ADMM step 3 update Lagrange multiplier 
        primal_residual=0.0
        for ck in range(K):
          ydelta=rho[ci,0]*(x_dict[ck]-z)
          primal_residual=primal_residual+torch.norm(ydelta)
          y_dict[ck].add_(ydelta)
        primal_residual=primal_residual/N # per parameter

        print('block=[%d,%d](%d,%f) ADMM=%d/%d primal=%e dual=%e'%(Li[ci][0],Li[ci][1],N,torch.mean(rho).item(),nadmm,nloop,primal_residual,dual_residual))

        # optional accuracy report for all models after each round
        if check_results:
          verification_error_check(net_dict)
  

print('Finished Training')


if save_model:
    for ck in range(K):
        # one checkpoint file per model: weights, optimizer state and the last
        # values of the loop variables `epoch` and `running_loss`
        snapshot = {
            'model_state_dict': net_dict[ck].state_dict(),
            'epoch': epoch,
            'optimizer_state_dict': opt_dict[ck].state_dict(),
            'running_loss': running_loss,
        }
        torch.save(snapshot, './s'+str(ck)+'.model')

In [None]:
#!/home/youli/miniconda3/bin/python3
# coding=utf8
"""
# Author: youli
# Created Time : 2021-12-27 15:38:05

# File Name: model_construct.py
# Description:
    test for Pytorch model

"""
print(f"pytorch test")

import numpy as np
import pandas as pd
import time
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt



input_size = 20000
train_size = int(input_size*0.9)
test_size = input_size-train_size
batch_size = 1000

# evenly spaced float32 grid on [-1, 1], visited in random order (random sampling)
x_total = np.linspace(-1.0, 1.0, input_size, dtype=np.float32)
x_total = np.random.choice(x_total, size=input_size, replace=False)

# the first train_size points are used for training, the rest for testing;
# both are reshaped to column vectors before becoming tensors
x_train = torch.from_numpy(x_total[0:train_size].reshape((train_size, 1)))
x_test = torch.from_numpy(x_total[train_size:input_size].reshape((test_size, 1)))

# regression target: sinc(10 x)
y_train = torch.from_numpy(np.sinc(10.0 * x_train))
y_test = torch.from_numpy(np.sinc(10.0 * x_test))

training_data = TensorDataset(x_train, y_train)
test_data = TensorDataset(x_test, y_test)


# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# show the shapes of the first batch of each loader (test first, then train)
for _loader in (test_dataloader, train_dataloader):
    for X, y in _loader:
        print("Shape of X: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break


# Get cpu or gpu device for training.
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    """Fully connected 1-20-20-1 network with tanh activations between the linear layers."""

    def __init__(self):
        super().__init__()
        hidden = 20
        layers = [
            nn.Linear(1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        ]
        self.tanh_linear = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        return self.tanh_linear(x)

model = NeuralNetwork().to(device)
print(model)

loss_fn = nn.MSELoss()
# stock PyTorch alternative, kept for reference:
# optimizer_lbfgs= torch.optim.LBFGS(model.parameters(), lr=1,
#         history_size=100, max_iter=20,
#         line_search_fn="strong_wolfe"
#         )
optimizer_lbfgs = LBFGSNew(
    model.parameters(),
    history_size=100,
    max_iter=20,
    line_search_fn=True,
    batch_mode=True,
)

def train(dataloader, model, loss_fn, optimizer):
    """
    Runs one closure-based optimisation pass over ``dataloader``.

    Inputs:
        dataloader: iterable of ``(X, y)`` batches that exposes a ``dataset`` attribute
        model: network to optimise (switched to training mode, moved to ``device``)
        loss_fn: callable mapping ``(prediction, target)`` to a scalar loss tensor
        optimizer: optimizer whose ``step`` accepts a closure

    Outputs:
        mean_loss (float): mean of the losses measured after each batch's step,
                           ``nan`` if no batch was seen

    Note:
      . ``device`` and ``train_size`` are module-level names defined outside this function.

    """
    size = len(dataloader.dataset)
    model.train()
    lm_lbfgs = model.to(device)
    batch_losses = []
    for batch, (X, y) in enumerate(dataloader):
        # keep the batch on the same device as the model
        x_, y_ = X.to(device), y.to(device)

        def closure():
            # the optimizer may evaluate this several times per step; gradients are
            # reset and recomputed only when autograd is enabled
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            y_pred = lm_lbfgs(x_)
            loss = loss_fn(y_pred, y_)
            if loss.requires_grad:
                loss.backward()
            return loss

        optimizer.step(closure)

        # post-step loss for reporting only: a forward pass without gradients
        with torch.no_grad():
            loss_value = closure().item()
        batch_losses.append(loss_value)

        # progress line; the period is the module-level ``train_size``
        if batch % train_size == 0:
            current = batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

    # report what this pass measured instead of handing back an unrelated global
    if not batch_losses:
        return float("nan")
    return sum(batch_losses) / len(batch_losses)

def test(dataloader, model, loss_fn):
    """
    Evaluates ``model`` on ``dataloader`` and returns the average per-batch loss.

    Inputs:
        dataloader: sized iterable of ``(X, y)`` batches
        model: network to evaluate (switched to evaluation mode)
        loss_fn: callable mapping ``(prediction, target)`` to a scalar loss tensor

    Outputs:
        test_loss (float): sum of the batch losses divided by the number of batches

    Note:
      . ``device`` is a module-level name defined outside this function.
      . No classification accuracy is computed: the value was never used, and
        comparing ``argmax`` indices with the targets only makes sense for class labels.

    """
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            test_loss += loss_fn(model(X), y).item()
    test_loss /= num_batches
    print(f"Test Error: \n  Avg loss: {test_loss:>8f} \n")
    return test_loss

# training

opt_label = 'lbfgsnew-t20-t20'
epochs = 50
print(f"test for {opt_label}")
optimizer = optimizer_lbfgs
loss_train = []
loss_test = []

t1 = time.perf_counter()
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    # after every epoch, record the average loss on both splits
    loss_train.append(test(train_dataloader, model, loss_fn))
    loss_test.append(test(test_dataloader, model, loss_fn))
    print("Done!")
t2 = time.perf_counter()

print("Elapsed time: ", t2- t1)

# one row per epoch, written space-separated to a file named after the label
record = pd.DataFrame({
    'epochs': np.arange(epochs),
    'loss_train': np.array(loss_train),
    'loss_test': np.array(loss_test),
})
record.to_csv(f"{opt_label}", sep=' ')

torch.save(model.state_dict(), f"model{opt_label}.pth")
print(f"Saved PyTorch Model State to model{opt_label}.pth")

In [None]:
#!/home/youli/miniconda3/bin/python
# coding=utf8
"""
# Author: youli
# Created Time : 2021-12-28 20:58:28

# File Name: summary.py
# Description:

"""
print(f"plot test")
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Each file is a space-separated log with `epochs` and `loss_train` columns.
# Note: the `adam*` names hold the "lbfgs_original" runs.
adam1 = pd.read_csv("lbfgs_original-t20", sep=' ')
lbfgs1 = pd.read_csv("lbfgsnew-t20", sep=' ')
adam2 = pd.read_csv("lbfgs_original-t20-t20", sep=' ')
lbfgs2 = pd.read_csv("lbfgsnew-t20-t20", sep=' ')

# the curves are drawn as the natural log of the training loss
for _log in (adam1, adam2, lbfgs1, lbfgs2):
    _log.loss_train = np.log(_log.loss_train)


plt.figure(figsize=(6, 4), dpi=200)

plt.title(" training error in log() ")

lwidth = 3.0

plt.plot(adam1.epochs, adam1.loss_train, 'g--', label="lbfgs_original 1-20-1", linewidth=lwidth)
plt.plot(adam2.epochs, adam2.loss_train, 'g', label="lbfgs_original 1-20-20-1", linewidth=lwidth)
# the lbfgsnew curves are stretched by 10 along x (see the text annotation below)
plt.plot(lbfgs1.epochs*10, lbfgs1.loss_train, 'r--', label="lbfgsnew 1-20-1", linewidth=lwidth)
plt.plot(lbfgs2.epochs*10, lbfgs2.loss_train, 'r', label="lbfgsnew 1-20-20-1", linewidth=lwidth)
#plt.plot(lbfgs1.epochs, lbfgs1.loss_train, 'r--', label="lbfgs 1-20-1" ,linewidth=lwidth)
#plt.plot(lbfgs2.epochs, lbfgs2.loss_train, 'r', label="lbfgs 1-20-20-1",linewidth=lwidth )

plt.text(0.0, -12, "L-BFGSnew epochs*10", color='r', size='large')
plt.xlabel("epochs")
# NOTE(review): the label says MAE -- confirm it matches the loss the logged runs used
plt.ylabel("log(MAE)")
plt.legend()
plt.show()

In [None]:
saasdas

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import math
import time

# (try to) use a GPU for computation?
use_cuda = True
mydevice = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')


# try replacing relu with elu
torch.manual_seed(69)
default_batch = 128  # no. of batches per epoch 50000/default_batch
batches_for_report = 10

# convert to tensor, then normalize each channel with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


trainset = torchvision.datasets.CIFAR10(
    root='./torchdata', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=default_batch, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./torchdata', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=default_batch, shuffle=False, num_workers=0)


# class names indexed by label
classes = (
    'plane', 'car', 'bird', 'cat',
    'deer', 'dog', 'frog', 'horse', 'ship', 'truck',
)



import matplotlib.pyplot as plt
import numpy as np

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


'''ResNet in PyTorch.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
 
From: https://github.com/kuangliu/pytorch-cifar
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a shortcut connection and ELU activations."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion*planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # identity shortcut unless the stride or channel count changes; then a
        # 1x1 conv + batch-norm brings the input to the output's shape
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the two conv stages, add the shortcut branch, and activate."""
        hidden = F.elu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.elu(hidden)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 conv block (output channels = 4 * planes) with a shortcut and ELU."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion*planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # the stride is applied by the 3x3 convolution
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # identity shortcut unless the stride or channel count changes; then a
        # 1x1 conv + batch-norm brings the input to the output's shape
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the three conv stages, add the shortcut branch, and activate."""
        hidden = F.elu(self.bn1(self.conv1(x)))
        hidden = F.elu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.elu(hidden)


class ResNet(nn.Module):
    """Stem convolution, four stages of ``block`` modules, average pooling and a linear classifier."""

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # channel count fed to the next block; advanced by _make_layer
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem and the four stages, pool with a 4x4 window, flatten, classify."""
        feats = F.elu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        feats = F.avg_pool2d(feats, 4)
        flat = feats.view(feats.size(0), -1)
        return self.linear(flat)

def ResNet9():
    """Return a ResNet of BasicBlocks with stage depths 1-1-1-1."""
    return ResNet(BasicBlock, num_blocks=[1, 1, 1, 1])

def ResNet18():
    """Return a ResNet of BasicBlocks with stage depths 2-2-2-2."""
    return ResNet(BasicBlock, num_blocks=[2, 2, 2, 2])

def ResNet34():
    """Return a ResNet of BasicBlocks with stage depths 3-4-6-3."""
    return ResNet(BasicBlock, num_blocks=[3, 4, 6, 3])

def ResNet50():
    """Return a ResNet of Bottleneck blocks with stage depths 3-4-6-3."""
    return ResNet(Bottleneck, num_blocks=[3, 4, 6, 3])

def ResNet101():
    """Return a ResNet of Bottleneck blocks with stage depths 3-4-23-3."""
    return ResNet(Bottleneck, num_blocks=[3, 4, 23, 3])

def ResNet152():
    """Return a ResNet of Bottleneck blocks with stage depths 3-8-36-3."""
    return ResNet(Bottleneck, num_blocks=[3, 8, 36, 3])


# enable this to use wide ResNet
wide_resnet = False
if wide_resnet:
    # use wide residual net https://arxiv.org/abs/1605.07146
    net = torchvision.models.resnet.wide_resnet50_2().to(mydevice)
else:
    net = ResNet18().to(mydevice)


#####################################################
def verification_error_check(net):
    """
    Returns the accuracy of ``net`` over ``testloader`` in percent.

    Inputs:
        net: network to evaluate; its train/eval mode is left untouched

    Outputs:
        accuracy (tensor): zero-dimensional tensor, 100 * correct / total

    Note:
      . ``testloader`` and ``mydevice`` are module-level names.

    """
    correct = 0
    total = 0
    # inference only: do not record autograd graphs for the forward passes
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images.to(mydevice))
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels.to(mydevice)).sum()
            total += labels.size(0)

    # true division keeps the fractional part that callers print with %.3f
    return 100.0*correct/total
#####################################################

# weights of the L1 (lambda1) and L2 (lambda2) penalties used in the training closure
lambda1 = 0.000001
lambda2 = 0.001

# loss function and optimizer
import torch.optim as optim
from lbfgsnew import LBFGSNew # custom optimizer
criterion = nn.CrossEntropyLoss()
#optimizer=optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#optimizer=optim.Adam(net.parameters(), lr=0.001)
optimizer = LBFGSNew(
    net.parameters(),
    history_size=7,
    max_iter=2,
    line_search_fn=True,
    batch_mode=True,
)


load_model = False
# update from a saved model
if load_model:
    checkpoint = torch.load('./res18.model', map_location=mydevice)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.train()  # initialize for training (BN,dropout)

start_time = time.time()
use_lbfgs = True
# train network
# Train for 20 epochs. With use_lbfgs the optimizer drives a closure that may be
# evaluated several times per step; otherwise a plain zero_grad/backward/step is used.
for epoch in range(20):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs, labels = inputs.to(mydevice), labels.to(mydevice)

        if not use_lbfgs:
            # zero gradients
            optimizer.zero_grad()
            # forward+backward optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        else:
            def closure():
                if torch.is_grad_enabled():
                    optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                if not wide_resnet:
                    # Flatten the layer parameters inside the closure: torch.cat copies
                    # the values, so vectors built before optimizer.step() would go stale
                    # once the weights move, and would make the loss require grad even
                    # when the closure is evaluated with autograd disabled.
                    stages = [
                        torch.cat([x.view(-1) for x in layer.parameters()])
                        for layer in (net.layer1, net.layer2, net.layer3, net.layer4)
                    ]
                    l1_penalty = lambda1*sum(torch.norm(vec, 1) for vec in stages)
                    l2_penalty = lambda2*sum(torch.norm(vec, 2) for vec in stages)
                    loss = loss+l1_penalty+l2_penalty
                if loss.requires_grad:
                    loss.backward()
                    #print('loss %f l1 %f l2 %f'%(loss,l1_penalty,l2_penalty))
                return loss
            optimizer.step(closure)

        # only for diagnostics
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss_value = loss.data.item()
        running_loss += loss_value

        # NOTE(review): this leaves only the batch loop; the next epoch still starts
        if math.isnan(loss_value):
            print('loss became nan at %d'%i)
            break

        # print statistics
        if i%(batches_for_report) == (batches_for_report-1): # after every 'batches_for_report'
            print('%f: [%d, %5d] loss: %.5f accuracy: %.3f'%
                (time.time()-start_time, epoch+1, i+1, running_loss/batches_for_report,
                 verification_error_check(net)))
            running_loss = 0.0

print('Finished Training')


# save model (and other extra items)
final_state = {
    'model_state_dict': net.state_dict(),
    'epoch': epoch,
    'optimizer_state_dict': optimizer.state_dict(),
    'running_loss': running_loss,
}
torch.save(final_state, './res.model')


# whole dataset: same counting loop for the training split, then the test split
for _loader, _report in (
        (trainloader, 'Accuracy of the network on the %d train images: %d %%'),
        (testloader, 'Accuracy of the network on the %d test images: %d %%')):
    correct = 0
    total = 0
    for data in _loader:
        images, labels = data
        # predictions are moved back to the CPU to be compared with the labels
        outputs = net(Variable(images).to(mydevice)).cpu()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

    print(_report % (total, 100*correct//total))


# Per-class accuracy over the whole test loader.
class_correct = [0.0] * 10
class_total = [0.0] * 10
# inference only: do not record autograd graphs for the forward passes
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images.to(mydevice)).cpu()
        _, predicted = torch.max(outputs, 1)
        hits = (predicted == labels)
        # count every sample of the batch, not just the first four
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' %
        (classes[i], 100*float(class_correct[i])/float(class_total[i])))

In [None]:
!ls

In [None]:
# scratch check: matmul of a 1-D tensor with itself is its dot product
a = torch.from_numpy(np.array([1, 2, 3, 4, 5, 6] + [7] * 7))
b = a.t()
new = torch.matmul(a, b)
print(new)