## AMSGrad

Ref: <a href='https://arxiv.org/abs/1904.09237'>On the Convergence of Adam and Beyond</a>

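AMSGrad is a variant of the Adam optimizer, proposed to fix cases in which Adam fails to converge. Adam divides each step by the current second-moment estimate $v_t$, which is the exponential moving average of squared gradients. AMSGrad instead divides by the maximum of all estimates so far, $\hat{v}_t = \max(\hat{v}_{t-1}, v_t)$. Because $\hat{v}_t$ never decreases, a run of small gradients cannot shrink the denominator and inflate the step size the way it can in Adam. The goal is more stable training and better convergence.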
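A minimal sketch of a single AMSGrad update on plain tensors. It assumes constant hyperparameters and omits Adam's bias-correction terms for brevity. The function name `amsgrad_step` and its signature are illustrative only and are not part of the `optimizers` module used below.

```python
import torch


def amsgrad_step(param: torch.Tensor, grad: torch.Tensor,
                 m: torch.Tensor, v: torch.Tensor, v_max: torch.Tensor,
                 lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
                 eps: float = 1e-8) -> None:
    """One AMSGrad update, applied in place (bias correction omitted)."""
    # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    # v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # v_hat_t = max(v_hat_{t-1}, v_t): the running maximum never decreases
    torch.maximum(v_max, v, out=v_max)
    # the step is scaled by v_hat_t instead of v_t
    param.sub_(lr * m / (v_max.sqrt() + eps))
```

The only change from an Adam step is the `torch.maximum` line: after a large gradient, `v` decays again but `v_max` keeps its peak value.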

Here, AMSGrad is implemented by subclassing the Adam optimizer described earlier.

```python
from typing import Dict, Any, Tuple, Optional

import torch
from torch import nn, Tensor

# WeightDecay and Adam come from the `optimizers` module built in the earlier sections
from optimizers import WeightDecay, Adam
```

```python
class AMSGrad(Adam):
    def __init__(self, params, lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-16, weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True, amsgrad: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        '''
        Initialize the Optimizer
            - params: the list of parameters
            - lr: learning rate alpha
            - betas: tuple of (beta_1, beta_2)
            - eps: epsilon
            - weight_decay: instance of class WeightDecay
            - optimized_update: a flag whether to optimize the bias correction
                                of the second moment by doing it after adding epsilon
            - amsgrad: a flag indicating whether to use AMSGrad or fall back to plain Adam
            - defaults: a dict of default values for parameter groups
        '''
        # copy so that a dict passed in by the caller is not mutated
        defaults = {} if defaults is None else dict(defaults)
        defaults.update(dict(amsgrad=amsgrad))

        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        '''
        Initialize a parameter state
            - state: the optimizer state of the parameter (tensor)
            - group: stores optimizer attributes of the parameter group
            - param: the parameter tensor theta at t-1
        '''
        # Extend Adam's state initialization
        super().init_state(state, group, param)

        # if amsgrad = True, maintain the maximum of exponential moving average of squared gradient
        if group['amsgrad']:
            state['max_exp_avg_sqrd'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def calc_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: Tensor):
        '''
        Calculate m_t and v_t or max(v1, v2, ..., vt)
            - state: the optimizer state of the parameter (tensor)
            - group: stores optimizer attributes of the parameter group
            - grad: current gradient tensor g_t for theta at t-1
        '''
        m, v = super().calc_mv(state, group, grad)

        # if amsgrad, get max(v1, v2, ..., vt), updating the stored maximum in place
        if group['amsgrad']:
            v_max = state['max_exp_avg_sqrd']
            torch.maximum(v_max, v, out=v_max)
            return m, v_max
        else:
            return m, v
```
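For comparison, PyTorch's built-in `torch.optim.Adam` implements the same variant behind its `amsgrad=True` flag and keeps the running maximum in the per-parameter state under `'max_exp_avg_sq'`. The sketch below feeds one large gradient followed by small ones. The gradient values are arbitrary and chosen only for illustration. Afterward, `exp_avg_sq` decays while `max_exp_avg_sq` holds its peak, which is what `calc_mv` above does with `torch.maximum`.

```python
import torch

x = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([x], lr=1e-2, betas=(0.9, 0.99), amsgrad=True)

# one large gradient, then a run of small ones (illustrative values)
grads = [10.0] + [0.1] * 20
v_hist, vmax_hist = [], []
for g in grads:
    opt.zero_grad()
    x.grad = torch.tensor([g])  # set the gradient directly instead of calling backward()
    opt.step()
    state = opt.state[x]
    v_hist.append(state['exp_avg_sq'].item())        # v_t, can decrease
    vmax_hist.append(state['max_exp_avg_sq'].item())  # max(v_1, ..., v_t), cannot decrease
```

After the first step, `v_hist` shrinks toward the small-gradient level while every entry of `vmax_hist` stays at the first-step value.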

### Synthetic Experiment

The following synthetic scenario, taken from the AMSGrad paper, is one in which Adam fails.

$$
f_t(x) =
\begin{cases}
1010x, & \text{for } t \bmod 101 = 1 \\
-10x, & \text{otherwise}
\end{cases}
\quad \text{where } -1 \leq x \leq +1
$$
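Since every $f_t$ is linear in $x$, summing the slopes over one full period of 101 steps shows which endpoint of $[-1, 1]$ minimizes the cumulative loss: the slopes add up to $1010 - 100 \cdot 10 = 10 > 0$, so the smallest feasible $x$ wins. A small check of this arithmetic (the helper names `f` and `period_loss` are illustrative):

```python
def f(t: int, x: float) -> float:
    # f_t from the definition above: slope 1010 once every 101 steps, otherwise -10
    return 1010 * x if t % 101 == 1 else -10 * x


def period_loss(x: float) -> float:
    # cumulative loss over one full period t = 1, ..., 101
    return sum(f(t, x) for t in range(1, 102))


# sum of slopes over one period: 1010 - 100 * 10
slope_sum = sum(f(t, 1.0) for t in range(1, 102))
print(slope_sum)  # 10.0
print(min([-1.0, 0.0, 1.0], key=period_loss))  # -1.0
```

Within each period, the rare large positive slope outweighs the many small negative ones, even though most individual steps push $x$ toward $+1$.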

The optimal solution is $x=-1$. The optimizer's performance is measured by the regret $R(T)$, where $\theta_t$ is the value of $x$ at step $t$ and $\theta^* = -1$:

$$
R(T) = \sum^T_{t=1} [f_t(\theta_t)-f_t(\theta^*)]
$$
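As a worked example of the regret, the sketch below computes the average regret $R(T)/T$, the quantity the experiment code logs as `total_regret / (step+1)`, for a given sequence of iterates. Here $t$ starts at 1 as in the formula, and the function names are illustrative.

```python
from typing import Sequence


def f(t: int, x: float) -> float:
    # f_t(x): slope 1010 when t mod 101 == 1, otherwise -10
    return 1010 * x if t % 101 == 1 else -10 * x


def average_regret(xs: Sequence[float], x_star: float = -1.0) -> float:
    # xs[t - 1] is theta_t; returns R(T) / T
    total = sum(f(t, x) - f(t, x_star) for t, x in enumerate(xs, start=1))
    return total / len(xs)


print(average_regret([-1.0] * 101))  # 0.0: staying at the optimum has no regret
print(average_regret([1.0] * 101))   # 20 / 101: (2020 - 100 * 20) / 101
```

An iterate stuck at $x=+1$ accumulates regret 20 per period, so its average regret stays positive; an optimizer whose average regret does not go to zero is the failure the paper is concerned with.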

```python
from labml import monit, tracker, experiment

def _adam_exp(is_adam: bool):
    x = nn.Parameter(torch.tensor([.0]))

    # optimal: x^* = -1 (a constant, so no gradient is needed)
    x_star = nn.Parameter(torch.tensor([-1.]), requires_grad=False)

    # f_t(x)
    def func(t: int, x_: nn.Parameter):
        if t % 101 == 1:
            return (1010 * x_).sum()
        else:
            return (-10 * x_).sum()

    # Initialize optimizer
    if is_adam:
        optimizer = Adam([x], lr=1e-2, betas=(0.9, 0.99))
    else:
        optimizer = AMSGrad([x], lr=1e-2, betas=(0.9, 0.99))

    # R(T)
    total_regret = 0

    with experiment.record(name='synthetic', comment='Adam' if is_adam else 'AMSGrad'):
        for step in monit.loop(10_000_000):
            # f_t(theta_t) - f_t(theta^*)
            regret = func(step, x) - func(step, x_star)
            total_regret += regret.item()

            # every 1000 steps, log the current regret, x, and the average regret R(T)/T
            if (step + 1) % 1000 == 0:
                tracker.save(loss=regret, x=x, regret=total_regret / (step + 1))

            # calculate gradients
            regret.backward()

            # optimize
            optimizer.step()
            optimizer.zero_grad()

            # project x back onto the feasible set -1 <= x <= +1
            x.data.clamp_(-1., +1.)
```
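If labml is not available, a similar loop can be sketched with PyTorch's built-in `torch.optim.Adam`, whose `amsgrad` flag switches to the AMSGrad variant. This is a substitute for the `Adam`/`AMSGrad` classes above, not the notebook's own code. It drops the labml logging, uses `torch.optim.Adam`'s default `eps` of 1e-8 rather than the 1e-16 default of the `AMSGrad` class above, and makes the step count a parameter. The name `run_synthetic` is hypothetical. Step indexing starts at 0, as in the loop above.

```python
from typing import Tuple

import torch
from torch import nn


def run_synthetic(amsgrad: bool, steps: int = 100_000, lr: float = 1e-2) -> Tuple[float, float]:
    """Run the synthetic problem; returns (final x, average regret R(T)/T)."""
    x = nn.Parameter(torch.tensor([0.0]))
    x_star = torch.tensor([-1.0])  # the optimum; a plain tensor, so it gets no gradient
    opt = torch.optim.Adam([x], lr=lr, betas=(0.9, 0.99), amsgrad=amsgrad)

    def f(t: int, x_: torch.Tensor) -> torch.Tensor:
        return (1010 * x_).sum() if t % 101 == 1 else (-10 * x_).sum()

    total_regret = 0.0
    for step in range(steps):
        # f_t(theta_t) - f_t(theta^*)
        regret = f(step, x) - f(step, x_star)
        total_regret += regret.item()

        opt.zero_grad()
        regret.backward()
        opt.step()

        # project back onto the feasible set [-1, 1]
        with torch.no_grad():
            x.clamp_(-1.0, 1.0)
    return x.item(), total_regret / steps
```

Calling it with `amsgrad=False` and `amsgrad=True` mirrors the two labml runs that follow, at whatever step count is affordable.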

```python
# Adam
_adam_exp(True)
```

```python
# AMSGrad
_adam_exp(False)
```