In [0]:
import torch
from torch.optim.optimizer import Optimizer
from math import ceil

In [0]:
class GradSliding(Optimizer):
    """Lan's Gradient Sliding algorithm."""
    def __init__(self, params, L, M, D_tilde):
        defaults = dict(L=L, M=M, D_tilde=D_tilde)
        super().__init__(params, defaults)
        self.k = 0
        self.t = 0
        self.mode = 'main'
    
    def compute_gamma(self):
        """Compute gamma according to formula (8.1.42) in Lan's book."""
        self.gamma = 3 / (self.k + 2)
        self.gamma_next = 3 / (self.k + 3)
    
    def compute_T(self):
        """Compute T according to formula (8.1.42) in Lan's book."""
        L = self.defaults['L']
        M = self.defaults['M']
        D_tilde = self.defaults['D_tilde']
        T = ceil(M**2 * (self.k + 1)**3 / (D_tilde * L**2))
        self.T = int(T)
        print("T=", T)
    
    def compute_P(self):
        """Compute P_Tk according to formula (8.1.44) in Lan's book."""
        self.P = 2 / ((self.T + 1) * (self.T + 2))
    
    def compute_theta(self):
        """Compute theta according to formula (8.1.39) in Lan's book."""
        self.theta = 2 * (self.t + 1) / (self.t * (self.t + 3))
    
    def compute_p(self):
        """Compute p according to formula (8.1.39) in Lan's book."""
        self.p = self.t / 2
    
    def compute_beta(self):
        """Compute beta according to formula (8.1.42) in Lan's book."""
        L = self.defaults['L']
        self.beta = 9 * L * (1 - self.P) / (2 * (self.k + 1))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform Gradient Sliding step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        # Part of main loop before PS (prox-sliding) procedure.
        # In this branch, par is x_underbar in notation of Lan's book.
        if self.mode == 'main':
            self.k += 1
            self.compute_gamma()
            self.compute_T()
            self.compute_P()
            self.compute_beta()
            for group in self.param_groups:
                for par in group['params']:
                    if par.grad is None:
                        continue
                    
                    state = self.state[par]
                    # State initialization.
                    if len(state) == 0:
                        state['x'] = par.clone()
                        state['x_bar'] = par.clone()
                    
                    state['df_x'] = par.grad
                    # At the beginning of PS procedure, gradient of h
                    # will be calculated at u0 = x.
                    par_copy = par.clone()
                    par.add_(state['x'])
                    par.sub_(par_copy)
            self.mode = 'PS'
                
        # PS procedure.
        # In this branch, par is u in notation of Lan's book.
        elif self.mode == 'PS':
            self.t += 1
            self.compute_p()
            self.compute_theta()
            for group in self.param_groups:
                for par in group['params']:
                    if par.grad is None:
                        continue
                    
                    state = self.state[par]
                    if self.t == 1:
                        state['u_tilde'] = par.clone()
                    
                    dh_u = par.grad

                    # Formula (1) from our report.
                    numerator = self.beta * (state['x'] + self.p * par) \
                              - state['df_x'] - dh_u
                    u = numerator / (self.beta * (1 + self.p))
                    par_copy = par.clone()
                    par.add_(u)
                    par.sub_(par_copy)
                    
                    state['u_tilde'] = (1 - self.theta) * state['u_tilde'] \
                                     + self.theta * par
                    
                    if self.t % self.T == 0:
                        # Finish PS procedure.
                        state['x'] = par
                        state['x_tilde'] = state['u_tilde']
                
                        # Part of main loop after PS procedure.
                        state['x_bar'] = (1 - self.gamma) * state['x_bar'] \
                                       + self.gamma * state['x_tilde']
                        # Beginning of main loop of new iteration.
                        # Now par is again x_underbar.
                        par_copy = par.clone()
                        x_underbar = (1 - self.gamma_next) * state['x_bar'] \
                            + self.gamma_next * state['x']
                        par.add_(x_underbar)
                        par.sub_(par_copy)
            if self.t % self.T == 0:
                self.t = 0
                self.mode = 'main'
                
        return loss

In [109]:
from scipy.sparse.linalg import svds
import numpy as np

n_obj = 100
n_feat = 10
noise_std = 0.01
reg_coef = 1.

np.random.seed(0)
A = np.random.rand(n_obj, n_feat)
np.random.seed(0)
x_true = np.random.rand(n_feat)
np.random.seed(0)
b = A @ x_true + noise_std * np.random.rand(n_obj)

L = 2 * reg_coef / n_obj
max_norm = np.linalg.norm(b) / np.sqrt(reg_coef)
M = 2 * np.linalg.norm(A, ord=2)**2 * diam / n_obj
M /= 1000
D_tilde = max_norm**2 * 3 / 2
print(f"L={L:.3f}, M={M:.1f}, D_tilde={D_tilde:.1f}")

L=0.020, M=1.0, D_tilde=1470.5


In [0]:
X_train = torch.tensor(A, dtype=torch.float)
y_train = torch.tensor(b.reshape(-1, 1), dtype=torch.float)

In [0]:
import torch.nn as nn

model = nn.Linear(10, 1)
opt = GradSliding(model.parameters(), L, M, D_tilde)

loss1 = nn.MSELoss()
loss2 = nn.MSELoss()

In [112]:
for i in range(1000):

    y_pred = model(X_train)
    if opt.mode == 'main':
        # f
        loss = reg_coef * loss1(y_pred, torch.zeros_like(y_pred))
        with torch.no_grad():
            sum_loss = loss2(y_pred, y_train) \
                     + reg_coef * loss1(y_pred, torch.zeros_like(y_pred))
        print(f"k = {opt.k}, sum loss = {sum_loss:.2f}")
    else:
        # h
        loss = loss2(y_pred, y_train)
    opt.zero_grad()
    loss.backward()
    opt.step()

k = 0, sum loss = 8.11
T= 15
k = 1, sum loss = inf
T= 49
k = 2, sum loss = nan
T= 114
k = 3, sum loss = nan
T= 223
k = 4, sum loss = nan
T= 385
k = 5, sum loss = nan
T= 610
