<a href="https://www.kaggle.com/code/aisuko/implement-stochastic-optimisation-with-pytorch?scriptVersionId=164626041" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

Adam (the name is derived from adaptive moment estimation) was introduced in the paper *Adam: A Method for Stochastic Optimization*. According to the paper, it is an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments. It is straightforward to implement, computationally efficient, has low memory requirements, is invariant to diagonal rescaling of the gradients, and is well suited to problems that are large in terms of data and/or parameters. It is also appropriate for non-stationary objectives and for problems with very noisy and/or sparse gradients. Let's look at how it is implemented in PyTorch.


# Algorithm Description

<div style="text-align: center"><img src="https://hostux.social/system/media_attachments/files/111/533/123/322/407/873/original/c0b7ea2c4824b18d.png" width="80%" height="50%" alt="Adam algorithm from the paper"><figure>Source: Adam: A Method for Stochastic Optimization</figure></div>

The Adam update is:

\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{align}

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyperparameters and $g_t$ is the gradient at step $t$.
$m_t$ and $v_t$ are exponential moving averages that estimate the first and second moments of the gradient.
$\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moment estimates.
$\epsilon$ guards against division by zero, and also acts as a hyperparameter that works against variance in the gradients.
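
To make the update rule above concrete, here is a minimal scalar sketch in plain Python. The function name `adam_step` and the toy objective $f(\theta)=\theta^2$ are illustrative assumptions, not part of PyTorch; the defaults mirror the hyperparameters used later in this notebook.

```python
import math

def adam_step(theta, m, v, grad, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam update; returns the new (theta, m, v). Illustrative helper."""
    m = beta1 * m + (1 - beta1) * grad           # first-moment estimate m_t
    v = beta2 * v + (1 - beta2) * grad * grad    # second-moment estimate v_t
    m_hat = m / (1 - beta1 ** t)                 # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)                 # bias-corrected second moment
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# Toy problem: minimize f(theta) = theta^2, whose gradient is 2 * theta.
theta, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    theta, m, v = adam_step(theta, m, v, grad=2 * theta, t=t)
```

On the very first step the bias correction gives $\hat{m}_1 = g_1$ and $\hat{v}_1 = g_1^2$, so the parameter moves by almost exactly $\alpha$ in the direction opposite to the gradient, regardless of the gradient's scale.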


The effective step taken, assuming $\epsilon=0$, is
$$\Delta_t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$

Its magnitude is bounded by
$$\vert \Delta_t \vert \le \alpha \cdot \frac{1-\beta_1}{\sqrt{1-\beta_2}}$$
when $1 -\beta_1 \gt \sqrt{1-\beta_2}$,
and
$$ \vert \Delta_t \vert \le \alpha$$
otherwise.
In most common scenarios, $\vert \Delta_t \vert \approx \alpha$.
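
The bounds above can be checked numerically. The sketch below is plain Python with $\epsilon = 0$; the helper name `effective_step` and the two gradient sequences are assumptions chosen for illustration. A constant gradient gives a step of about $\alpha$, while a single non-zero gradient after a long run of zeros (the sparse case) approaches the larger bound.

```python
import math

def effective_step(grads, lr=1e-3, beta1=0.9, beta2=0.999):
    """Return the size of the last Adam step after feeding `grads`, with eps = 0."""
    m = v = 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return abs(lr * m_hat / math.sqrt(v_hat))

lr, beta1, beta2 = 1e-3, 0.9, 0.999
bound = lr * (1 - beta1) / math.sqrt(1 - beta2)    # applies since 1 - beta1 > sqrt(1 - beta2)

dense_step = effective_step([5.0] * 100)            # same gradient at every step
sparse_step = effective_step([0.0] * 9999 + [5.0])  # zero everywhere except the last step
```

With the default betas the bound is roughly $3.16\,\alpha$, so a sparse gradient can produce a step a few times larger than the learning rate, whereas the dense case stays at about $\alpha$.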



In [1]:
from typing import List, Optional

# `torch` itself is referenced throughout the cells below
import torch
from torch import Tensor
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

# Implementation

Here is a description of the input parameters:

* `params` is an iterable of parameters to optimize, or of dicts defining parameter groups
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$, a term added to the denominator to improve numerical stability
* `weight_decay` is the weight decay (L2 penalty)
* `amsgrad` is whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
* `foreach` is whether the foreach (multi-tensor) implementation of the optimizer is used
* `maximize` is whether to maximize the objective with respect to the params, instead of minimizing
* `capturable` is whether this instance is safe to capture in a CUDA graph
* `differentiable` is whether autograd should occur through the optimizer step in training
* `fused` is whether the fused implementation is used
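
As a quick illustration of how these arguments are passed in practice, here is a minimal sketch using the built-in `torch.optim.Adam`, which accepts the same core arguments. The tiny linear model and random data are assumptions made only for the demo.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)            # toy model, only for illustration
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
    amsgrad=False,
)

x, y = torch.randn(8, 4), torch.randn(8, 1)   # random stand-in data
before = model.weight.detach().clone()

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()

# Largest per-element change of the weights after one step.
max_change = (model.weight.detach() - before).abs().max().item()
```

After a single step, no weight should have moved by noticeably more than `lr`, which matches the step-size discussion in the algorithm description.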

In [2]:
from torch.optim.optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling, _dispatch_sqrt, _default_to_fused_or_foreach)

class Adam(Optimizer):

    def __init__(
        self, params, lr=1e-3, betas=(0.9,0.999), 
        eps=1e-8, weight_decay=0, amsgrad=False, *,
        foreach:Optional[bool]=None, maximize: bool=False, 
        capturable:bool=False, 
        differentiable:bool=False, fused:Optional[bool]=None):
        
        
        if not 0.0<=lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <=eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        
        defaults=dict(
            lr=lr, betas=betas, eps=eps, 
            weight_decay=weight_decay, amsgrad=amsgrad,
            maximize=maximize, foreach=foreach, capturable=capturable,
            differentiable=differentiable, fused=fused)

        super().__init__(params, defaults)
        
        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling=True
            if not all(p.is_cuda and torch.is_floating_point(p) 
                       for pg in self.param_groups 
                       for p in pg['params']):
                raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
    
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('capturable', False)
            group.setdefault('differentiable', False)
            group.setdefault('fused', None)
        state_values=list(self.state.values())
        step_is_tensor=(len(state_values)!=0 and torch.is_tensor(state_values[0]['step']))
        if not step_is_tensor:
            for s in state_values:
                s['step']=torch.tensor(float(s['step']))
    
    
    def _init_group(
        self, group, params_with_grad, 
        grads, exp_avgs, exp_avg_sqs, 
        max_exp_avg_sqs, state_steps):
        

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead')
                grads.append(p.grad)
                
                state=self.state[p]
                
                # Lazy state initialization
                if len(state)==0:
                    state['step']=(
                        torch.zeros((1,), dtype=torch.float, device=p.device)
                        if group['capturable'] or group['fused'] else torch.tensor(0.)
                    )
                    # Exponential moving average of gradient values
                    state['exp_avg'] =torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq']=torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group['amsgrad']:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq']=torch.zeros_like(
                            p, 
                            memory_format=torch.preserve_format)
                
                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                
                if group['amsgrad']:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
                if group['differentiable'] and state['step'].requires_grad:
                    raise RuntimeError('`requires_grad` is not supported for `step` in differentiable mode')
                state_steps.append(state['step'])
    
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """
        Performs a single optimization step.
        
        Args:
            closure (Callable, optional): A closure that reevaluates the model and returns the loss
        """
        
        self._cuda_graph_capture_health_check()
        
        loss=None
        if closure is not None:
            with torch.enable_grad():
                loss=closure()
                
        for group in self.param_groups:
            params_with_grad=[]
            grads=[]
            exp_avgs=[]
            exp_avg_sqs=[]
            max_exp_avg_sqs=[]
            state_steps=[]
            beta1, beta2=group['betas']
            
            self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps
            )
            
            adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=group['amsgrad'],
                beta1=beta1,
                beta2=beta2,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                eps=group['eps'],
                maximize=group['maximize'],
                foreach=group['foreach'],
                capturable=group['capturable'],
                differentiable=group['differentiable'],
                fused=group['fused'],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )
            
        return loss
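
The lazy state initialization in `_init_group` can be observed from the outside. The sketch below uses the built-in `torch.optim.Adam` rather than the class defined above, and assumes it keeps the same state keys; the hand-set gradient is a stand-in for a real backward pass.

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
optimizer = torch.optim.Adam([w], lr=1e-3, amsgrad=True)

state_before = len(optimizer.state)     # no per-parameter state exists yet

w.grad = torch.full_like(w, 0.5)        # stand-in gradient for the demo
optimizer.step()                        # state is created on the first step

state_keys = sorted(optimizer.state[w].keys())
step_count = float(optimizer.state[w]["step"])
first_moment = optimizer.state[w]["exp_avg"]
```

After the first step, `exp_avg` holds $(1-\beta_1)\,g_1$, which is 0.05 for this gradient, and `max_exp_avg_sq` is present only because `amsgrad=True` was requested.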


In [3]:
def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[Tensor],
         foreach: Optional[bool] = None,
         capturable: bool = False,
         differentiable: bool = False,
         fused: Optional[bool] = None,
         grad_scale: Optional[Tensor] = None,
         found_inf: Optional[Tensor] = None,
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float,
         maximize: bool
        ):
    
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(params,
         grads,
         exp_avgs,
         exp_avg_sqs,
         max_exp_avg_sqs,
         state_steps,
         amsgrad=amsgrad,
         beta1=beta1,
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
         eps=eps,
         maximize=maximize,
         capturable=capturable,
         differentiable=differentiable,
         grad_scale=grad_scale,
         found_inf=found_inf)

In [4]:
def _single_tensor_adam(params: List[Tensor],
                        grads: List[Tensor],
                        exp_avgs: List[Tensor],
                        exp_avg_sqs: List[Tensor],
                        max_exp_avg_sqs: List[Tensor],
                        state_steps: List[Tensor],
                        grad_scale: Optional[Tensor],
                        found_inf: Optional[Tensor],
                        *,
                        amsgrad: bool,
                        beta1: float,
                        beta2: float,
                        lr: float,
                        weight_decay: float,
                        eps: float,
                        maximize: bool,
                        capturable: bool,
                        differentiable: bool):

    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):

        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        if capturable:
            assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            param = torch.view_as_real(param)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            # 1 - beta1 ** step can't be captured in a CUDA graph, even if step is a CUDA tensor
            # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing")
            bias_correction1 = 1 - torch.pow(beta1, step)
            bias_correction2 = 1 - torch.pow(beta2, step)

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sqs_i = max_exp_avg_sqs[i]
                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq))
                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)
            else:
                denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)
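
The single-tensor path folds the bias corrections into `step_size` and `denom`, which can obscure the link to the equations in the algorithm description. As a sanity check, the sketch below compares a direct transcription of those equations against the built-in `torch.optim.Adam`. The helper `reference_adam` and the random gradients are illustrative assumptions.

```python
import torch

def reference_adam(param, grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Textbook Adam written straight from the update equations (illustrative)."""
    m = torch.zeros_like(param)
    v = torch.zeros_like(param)
    p = param.clone()
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        p = p - lr * m_hat / (v_hat.sqrt() + eps)
    return p

torch.manual_seed(0)
start = torch.randn(5, dtype=torch.float64)
grads = [torch.randn(5, dtype=torch.float64) for _ in range(10)]

p = torch.nn.Parameter(start.clone())
opt = torch.optim.Adam([p], lr=1e-3)
for g in grads:
    p.grad = g.clone()      # feed the same gradients to the optimizer
    opt.step()

max_diff = (p.detach() - reference_adam(start, grads)).abs().max().item()
```

The two results should agree up to floating-point rounding, since dividing $\sqrt{v_t}$ by $\sqrt{1-\beta_2^t}$ is the same as taking $\sqrt{\hat{v}_t}$.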

In [5]:
def _multi_tensor_adam(params: List[Tensor],
                       grads: List[Tensor],
                       exp_avgs: List[Tensor],
                       exp_avg_sqs: List[Tensor],
                       max_exp_avg_sqs: List[Tensor],
                       state_steps: List[Tensor],
                       grad_scale: Optional[Tensor],
                       found_inf: Optional[Tensor],
                       *,
                       amsgrad: bool,
                       beta1: float,
                       beta2: float,
                       lr: float,
                       weight_decay: float,
                       eps: float,
                       maximize: bool,
                       capturable: bool,
                       differentiable: bool):
    if len(params) == 0:
        return

    if capturable:
        assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
            "If capturable=True, params and state_steps must be CUDA tensors."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
         device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():

        if maximize:
            device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]

        # Handle complex parameters
        device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
        device_exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avgs]
        device_exp_avg_sqs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs]
        params_ = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]

        # update steps
        torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_mul_(device_exp_avgs, beta1)
        torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

        if capturable:
            # TODO: use foreach_pow if/when foreach_pow is added
            bias_correction1 = [torch.pow(beta1, step) for step in device_state_steps]
            bias_correction2 = [torch.pow(beta2, step) for step in device_state_steps]
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            step_size = torch._foreach_div(bias_correction1, lr)
            torch._foreach_reciprocal_(step_size)
            torch._foreach_neg_(step_size)

            bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]

                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                torch._foreach_div_(max_exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size))
                eps_over_step_size = torch._foreach_div(step_size, eps)
                torch._foreach_reciprocal_(eps_over_step_size)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
                torch._foreach_div_(exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size))
                eps_over_step_size = torch._foreach_div(step_size, eps)
                torch._foreach_reciprocal_(eps_over_step_size)
                denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

            torch._foreach_addcdiv_(params_, device_exp_avgs, denom)
        else:
            bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
            bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
                torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
                torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
                denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

            torch._foreach_addcdiv_(params_, device_exp_avgs, denom, step_size)
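
The multi-tensor path is meant to compute the same update as the per-parameter loop, only batched across tensors. A minimal way to check this, assuming a PyTorch version whose `torch.optim.Adam` accepts the `foreach` keyword, is to run both variants from the same starting point. The helper `run` and the quadratic toy loss are assumptions for the demo.

```python
import torch

def run(foreach):
    """Train two small tensors for a few steps with the chosen implementation."""
    torch.manual_seed(0)    # same initial values for both runs
    params = [
        torch.nn.Parameter(torch.randn(4)),
        torch.nn.Parameter(torch.randn(2, 3)),
    ]
    opt = torch.optim.Adam(params, lr=1e-2, foreach=foreach)
    for _ in range(5):
        opt.zero_grad()
        loss = sum((p ** 2).sum() for p in params)
        loss.backward()
        opt.step()
    return [p.detach().clone() for p in params]

loop_result = run(foreach=False)     # single-tensor implementation
batched_result = run(foreach=True)   # multi-tensor (foreach) implementation

max_diff = max(
    (a - b).abs().max().item() for a, b in zip(loop_result, batched_result)
)
```

Any difference should be at the level of float32 rounding, because the two paths may order the arithmetic slightly differently.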

In [6]:
def _fused_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device, dtype) in grouped_tensors:
        (
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
        ) = grouped_tensors[(device, dtype)]
        if grad_scale is not None and found_inf is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)
            device_grad_scale = grad_scale_dict[device]
            device_found_inf = found_inf_dict[device]
        else:
            device_grad_scale = None
            device_found_inf = None
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adam_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps))

# Citation List

* https://huggingface.co/docs/transformers/v4.35.2/en/main_classes/optimizer_schedules#transformers.AdamW
* https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW
* https://nn.labml.ai/optimizers/adam.html
* https://arxiv.org/pdf/1412.6980.pdf
* https://arxiv.org/abs/1711.05101