<img src="../rsag_convex.png" alt="algoconvex" />
<img src="../x_update.png" alt="x_update" />
<img src="../mean.png" alt="mean" />
<img src="../rsag_composite.png" alt="algo" />

__Parameters :__
- $\alpha$: mixing weight; the aggregated x enters the current state with weight $(1-\alpha)$, i.e. it acts as momentum
- $\lambda$: learning rate
- $\beta$: step size used to update the aggregated x
- $p_k$: termination probability



In [None]:
from torch.optim.optimizer import Optimizer, required
import copy

class RSAG(Optimizer):
    r"""Randomized stochastic accelerated gradient (RSAG) style optimizer.

    Three sequences are tracked for every parameter:

    * ``x_md`` -- the point at which the gradient is evaluated; this is the
      value stored in the parameter tensor itself,
    * ``x``    -- the plain gradient-descent iterate (state key ``'x'``),
    * ``x_ag`` -- the aggregated iterate (state key ``'momentum_aggr'``).

    With ``g`` the gradient evaluated at ``x_md``, one call to :meth:`step`
    performs::

        x    <- x    - lr   * g
        x_ag <- x_md - beta * g
        x_md <- (1 - alpha) * x_ag + alpha * x

    Both ``x`` and ``x_ag`` start out as a copy of the parameter, so the first
    gradient is taken at the initial point. With ``alpha = 1`` the update
    reduces to plain gradient descent with step size ``lr``.

    Only the per-step update is implemented here; any randomized choice of
    the stopping iteration is left to the caller.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (lambda) applied to ``x``
            (default: 0.01)
        alpha (float, optional): weight of ``x`` in the next evaluation point;
            ``x_ag`` gets weight ``1 - alpha``. Must lie in [0, 1]
            (default: 0.1)
        beta (float, optional): step size used for the aggregated iterate
            (default: 0.1)
        weight_decay (float, optional): L2 penalty coefficient added to the
            gradient as ``weight_decay * param`` (default: 0.0)

    Raises:
        ValueError: if ``lr``, ``beta`` or ``weight_decay`` is negative, or
            ``alpha`` is outside [0, 1].

    Example:
        >>> optimizer = RSAG(model.parameters(), lr=0.1, alpha=0.1, beta=0.1)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """

    def __init__(self, params, lr=0.01, alpha=0.1, beta=0.1, weight_decay=0.0):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("Invalid alpha value: {}".format(alpha))
        if beta < 0.0:
            raise ValueError("Invalid beta value: {}".format(beta))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, alpha=alpha, beta=beta, weight_decay=weight_decay)
        super(RSAG, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RSAG, self).__setstate__(state)
        # Groups restored from a state saved without weight decay still need
        # the key that step() reads.
        for group in self.param_groups:
            group.setdefault('weight_decay', 0.0)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            alpha, beta = group['alpha'], group['beta']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if weight_decay != 0:
                    # Out-of-place so the caller's .grad tensor is not modified.
                    grad = grad.add(p.data, alpha=weight_decay)

                param_state = self.state[p]
                if 'momentum_aggr' not in param_state:
                    param_state['momentum_aggr'] = p.data.clone()
                if 'x' not in param_state:
                    param_state['x'] = p.data.clone()
                aggr = param_state['momentum_aggr']
                x = param_state['x']

                # x <- x - lr * g
                x.add_(grad, alpha=-lr)
                # x_ag <- x_md - beta * g   (p.data still holds x_md here)
                aggr.copy_(p.data).add_(grad, alpha=-beta)
                # x_md <- (1 - alpha) * x_ag + alpha * x
                p.data.copy_(aggr).mul_(1.0 - alpha).add_(x, alpha=alpha)

        return loss

## Initialize Optimizer and test

# Build the optimizer instance over the model's parameters.
# NOTE(review): neither `Adan` nor `model` is defined in the visible part of
# this file, while the optimizer class defined above is `RSAG` (which takes
# lr/alpha/beta, not `betas`) — confirm which optimizer is intended here and
# where `Adan` and `model` come from.
optim = Adan(
    model.parameters(),
    lr = 1e-3,                  # learning rate (can be much higher than Adam, up to 5-10x)
    betas = (0.02, 0.08, 0.01), # beta 1-2-3 as described in paper - author says most sensitive to beta3 tuning
    weight_decay = 0.02         # weight decay 0.02 is optimal per author
)