In [0]:
import torch
from torch.optim.optimizer import Optimizer
from math import ceil

In [0]:
class GradSliding(Optimizer):
    """Lan's Gradient Sliding algorithm.

    The optimizer is a small per-parameter state machine with two modes:

    * ``'main'`` -- the parameter currently holds ``x_underbar``; the gradient
      found in ``par.grad`` is stored as the gradient of ``f`` and the
      parameter is moved to ``u_0 = x`` for the prox-sliding (PS) procedure.
    * ``'PS'`` -- the parameter currently holds ``u_{t-1}``; the gradient
      found in ``par.grad`` is used as the gradient of ``h`` and one PS
      iteration is performed.  After ``T`` PS iterations the outer iterate is
      updated and the parameter is moved to the next ``x_underbar``.

    NOTE(review): the optimizer only reads ``par.grad``; it cannot tell which
    term the gradient belongs to.  Presumably the caller is expected to
    backpropagate only ``f`` before a ``'main'`` step and only ``h`` before a
    ``'PS'`` step -- confirm against the training code.

    Args:
        params: iterable of parameters (or parameter groups) to optimize.
        L: constant used in the ``T`` and ``beta`` formulas (must be > 0).
        M: constant used in the ``T`` formula (must be >= 0).
        D_tilde: constant used in the ``T`` formula (must be > 0).

    Raises:
        ValueError: if ``L`` or ``D_tilde`` is not positive, or ``M`` is
            negative.
    """
    def __init__(self, params, L, M, D_tilde):
        # Fail early: L == 0 or D_tilde == 0 would otherwise surface later as
        # a ZeroDivisionError inside compute_T.
        if L <= 0:
            raise ValueError(f"Invalid L: {L} (must be positive)")
        if M < 0:
            raise ValueError(f"Invalid M: {M} (must be non-negative)")
        if D_tilde <= 0:
            raise ValueError(f"Invalid D_tilde: {D_tilde} (must be positive)")
        defaults = dict(L=L, M=M, D_tilde=D_tilde)
        super().__init__(params, defaults)
    
    def compute_gamma(self, k):
        """Compute gamma according to formula (8.1.42) in Lan's book."""
        return 3 / (k + 2)
    
    def compute_T(self, k):
        """Compute T according to formula (8.1.42) in Lan's book.

        Reads ``L``, ``M`` and ``D_tilde`` from ``self.defaults`` (per-group
        overrides are not consulted).  The result is clamped to at least 1 so
        that the PS procedure always performs one iteration, even if
        ``M == 0``.
        """
        L = self.defaults['L']
        M = self.defaults['M']
        D_tilde = self.defaults['D_tilde']
        n_inner = ceil(M**2 * (k + 1)**3 / (D_tilde * L**2))
        return max(1, int(n_inner))
    
    def compute_P(self, t):
        """Compute P according to formula (8.1.44) in Lan's book."""
        return 2 / ((t + 1) * (t + 2))
    
    def compute_theta(self, t):
        """Compute theta according to formula (8.1.39) in Lan's book.

        ``t`` must be >= 1 (``t == 0`` would divide by zero).
        """
        return 2 * (t + 1) / (t * (t + 3))
    
    def compute_p(self, t):
        """Compute p according to formula (8.1.39) in Lan's book."""
        return t / 2
    
    def compute_beta(self, P, k):
        """Compute beta according to formula (8.1.42) in Lan's book.

        Reads ``L`` from ``self.defaults``.
        """
        L = self.defaults['L']
        return 9 * L * (1 - P) / (2 * (k + 1))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one Gradient Sliding step.

        Each call consumes the gradient currently stored in ``par.grad`` and
        moves the parameter, in place, to the point at which the next
        gradient has to be evaluated.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        for group in self.param_groups:
            for par in group['params']:
                if par.grad is None:
                    continue
                
                state = self.state[par]

                # State initialization: x_0 = x_bar_0 = initial parameter.
                if len(state) == 0:
                    state['k'] = 0
                    state['t'] = 0
                    state['mode'] = 'main'
                    state['x'] = par.clone()
                    state['x_bar'] = par.clone()
                
                # Part of main loop before PS (prox-sliding) procedure.
                # In this branch, par holds x_underbar in notation of Lan's
                # book: on the very first call x_underbar = x_0 (because
                # x_bar_0 = x_0), afterwards it was written at the end of the
                # previous PS procedure.
                if state['mode'] == 'main':
                    state['k'] += 1
                    state['gamma'] = self.compute_gamma(state['k'])
                    # Clone: the caller may zero par.grad in place before the
                    # PS iterations that still need this gradient.
                    state['df_x'] = par.grad.clone()
                    state['mode'] = 'PS'
                    # At the beginning of PS procedure, gradient of h
                    # will be calculated at u0 = x.  Write in place so the
                    # model parameter itself moves.
                    par.copy_(state['x'])
                    
                    state['T'] = self.compute_T(state['k'])
                    P = self.compute_P(state['T'])
                    state['beta'] = self.compute_beta(P, state['k'])
                
                # PS procedure.
                # In this branch, par is u in notation of Lan's book.
                elif state['mode'] == 'PS':
                    if state['t'] == 0:
                        # u_tilde_0 = u_0 = x.
                        state['u_tilde'] = par.clone()
                    state['t'] += 1
                    
                    dh_u = par.grad

                    p = self.compute_p(state['t'])
                    beta = state['beta']

                    # Formula (1) from our report.  The result is a fresh
                    # tensor, so it can be kept in the state without
                    # aliasing the parameter.
                    u_new = (beta*(state['x'] + p*par) - state['df_x'] - dh_u) \
                        / (beta*(1 + p))
                    theta = self.compute_theta(state['t'])
                    state['u_tilde'] = (1 - theta) * state['u_tilde'] \
                                     + theta * u_new
                    
                    if state['t'] % state['T'] == 0:
                        # Finish PS procedure.
                        state['t'] = 0
                        state['x'] = u_new
                        state['x_tilde'] = state['u_tilde']
                
                        # Part of main loop after PS procedure.
                        gamma = state['gamma']
                        state['x_bar'] = (1 - gamma) * state['x_bar'] \
                                       + gamma * state['x_tilde']
                        state['mode'] = 'main'
                        # At the beginning of main loop, gradient of f
                        # will be calculated at x_underbar of the NEXT outer
                        # iteration, hence gamma_{k+1}.
                        gamma_next = self.compute_gamma(state['k'] + 1)
                        state['x_underbar'] = (1 - gamma_next) * state['x_bar'] \
                                            + gamma_next * state['x']
                        par.copy_(state['x_underbar'])
                    else:
                        # Next gradient of h will be calculated at u_t.
                        par.copy_(u_new)
                
        return loss

In [0]:
# Common training loop: one optimizer step per mini-batch.
for batch_n, (x_batch, y_batch) in enumerate(train_loader):
    # Move the mini-batch to the target device.
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    # Forward pass and loss evaluation.
    y_pred = model(x_batch)
    loss = criterion(y_pred, y_batch)

    # Clear stale gradients, backpropagate, then update the parameters.
    opt.zero_grad()
    loss.backward()
    opt.step()

In [0]:
# what if our loss is a sum of two terms?