In [0]:
import torch
from torch.optim.optimizer import Optimizer

In [0]:
class GradSliding(Optimizer):
    r"""Lan's Gradient Sliding (GS) algorithm.

    Targets composite problems :math:`\min_x f(x) + h(x)` and uses the
    Euclidean prox-function :math:`V(x, u) = \|u - x\|^2 / 2`, so every
    prox-sliding (PS) subproblem has a closed-form solution.

    The optimizer is a per-parameter state machine advanced by successive
    ``step()`` calls. One outer iteration takes ``T + 1`` calls:

    * one ``'main'`` call. The parameter holds :math:`\underline{x}` and its
      ``.grad`` is stored as :math:`f'(\underline{x})`. The parameter is then
      moved (in place) to :math:`u_0 = x`.
    * ``T`` ``'PS'`` calls. The parameter holds :math:`u_{t-1}` and its
      ``.grad`` is used as :math:`h'(u_{t-1})`. The parameter is moved to
      :math:`u_t`. After the ``T``-th PS call the outer iterates ``x``,
      ``x_tilde`` and ``x_bar`` are updated, and the parameter is moved to the
      next :math:`\underline{x}`.

    NOTE(review): each call uses whatever ``.grad`` currently holds. The
    algorithm expects the gradient of ``f`` alone on ``'main'`` calls and of
    ``h`` alone on ``'PS'`` calls. Backpropagating the full loss on every call
    does not match that split. The current phase is exposed as
    ``optimizer.state[param]['mode']`` so a training loop can choose which
    term to differentiate.

    Args:
        params: iterable of parameters, or dicts defining parameter groups.
        beta (float): prox weight :math:`\beta > 0` of the PS subproblem.
        gamma (float): outer averaging weight :math:`\gamma`.
        T (int): number of PS iterations per outer iteration (``T >= 1``).
        p (float, optional): PS proximity weight :math:`p_t`. If ``None``
            (default), the schedule :math:`p_t = t / 2` is used.
        theta (float, optional): PS averaging weight :math:`\theta_t`. If
            ``None`` (default), the schedule
            :math:`\theta_t = 2(t + 1) / (t(t + 3))` is used. Its value at
            ``t = 1`` is 1, so the average starts at :math:`u_1`.
            NOTE(review): both default schedules are intended to follow Lan's
            paper; confirm them against the reference.

    Raises:
        ValueError: if ``beta <= 0`` or ``T`` is not a positive integer.
    """

    def __init__(self, params, beta, gamma, T, p=None, theta=None):
        if not beta > 0:
            raise ValueError(f"Invalid beta: {beta} (must be > 0)")
        if int(T) != T or T < 1:
            raise ValueError(f"Invalid T: {T} (must be a positive integer)")
        defaults = dict(beta=beta, gamma=gamma, T=int(T), p=p, theta=theta)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Groups pickled before `p`/`theta` existed fall back to the schedules.
        for group in self.param_groups:
            group.setdefault('p', None)
            group.setdefault('theta', None)

    @torch.no_grad()
    def step(self, closure=None):
        """Advance the Gradient Sliding state machine by one call.

        Args:
            closure (callable, optional): re-evaluates the model and returns
                the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta = group['beta']
            gamma = group['gamma']
            T = group['T']
            for par in group['params']:
                if par.grad is None:
                    continue

                state = self.state[par]

                # State initialization. With x_0 = x_bar_0 = par we get
                # x_underbar_1 = (1 - gamma) * x_bar_0 + gamma * x_0 = par,
                # so the first gradient is already taken at x_underbar_1.
                if len(state) == 0:
                    state['step'] = 0  # total number of step() calls
                    state['mode'] = 'main'
                    state['x'] = par.clone()
                    state['x_bar'] = par.clone()
                    state['x_underbar'] = par.clone()
                    state['x_tilde'] = par.clone()

                state['step'] += 1

                if state['mode'] == 'main':
                    # Main loop, before PS: par holds x_underbar. Clone the
                    # gradient so a later zero_grad() cannot alter the copy.
                    state['df_x'] = par.grad.clone()
                    # The PS procedure starts from u_0 = u_tilde_0 = x.
                    state['u_tilde'] = state['x'].clone()
                    state['t'] = 0
                    state['mode'] = 'PS'
                    # Write into the parameter in place so the next gradient
                    # is evaluated at u_0 (rebinding `par` would not move it).
                    par.copy_(state['x'])
                    continue

                # PS (prox-sliding) procedure: par holds u_{t-1}.
                t = state['t'] + 1
                state['t'] = t
                p = group['p'] if group['p'] is not None else t / 2
                theta = (group['theta'] if group['theta'] is not None
                         else 2 * (t + 1) / (t * (t + 3)))
                dh_u = par.grad

                # Closed-form minimizer of
                # <df_x + dh_u, u> + beta/2 |u - x|^2 + beta*p/2 |u - u_{t-1}|^2
                u = (beta * (state['x'] + p * par) - state['df_x'] - dh_u) \
                    / (beta * (p + 1))
                state['u_tilde'].mul_(1 - theta).add_(u, alpha=theta)

                if t < T:
                    par.copy_(u)
                    continue

                # Finish PS: x_k = u_T and x_tilde_k = u_tilde_T. `u` is a
                # fresh tensor, so it does not alias the parameter.
                state['x'] = u
                state['x_tilde'] = state['u_tilde']

                # Main loop, after PS.
                state['x_bar'] = (1 - gamma) * state['x_bar'] \
                    + gamma * state['x_tilde']
                # The next main-loop gradient is taken at x_underbar_{k+1},
                # built from the freshly updated x_bar and x.
                state['x_underbar'] = (1 - gamma) * state['x_bar'] \
                    + gamma * state['x']
                state['mode'] = 'main'
                par.copy_(state['x_underbar'])

        return loss

In [0]:
# Common training loop: one optimizer update per mini-batch.
# Assumes `train_loader`, `device`, `model`, `criterion` and `opt` are
# defined in earlier cells.
for batch_n, (x_batch, y_batch) in enumerate(train_loader):
    # Move the mini-batch onto the target device.
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    # Clear stale gradients, then run the forward pass and the loss.
    opt.zero_grad()
    y_pred = model(x_batch)
    loss = criterion(y_pred, y_batch)

    # Backpropagate and let the optimizer update the parameters.
    loss.backward()
    opt.step()

In [0]:
# What if our loss is a sum of two terms, f(x) + h(x)?