In [1]:
!pip install labml-nn # Install the labml-nn package, which contains the required modules.

import math
from typing import Dict, Any, Tuple, Optional
import torch
from labml import tracker
from torch import nn
from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay


Collecting labml-nn
  Downloading labml_nn-0.4.137-py3-none-any.whl.metadata (9.2 kB)
Collecting labml==0.4.168 (from labml-nn)
  Downloading labml-0.4.168-py3-none-any.whl.metadata (7.5 kB)
Collecting labml-helpers==0.4.89 (from labml-nn)
  Downloading labml_helpers-0.4.89-py3-none-any.whl.metadata (1.4 kB)
Collecting torchtext (from labml-nn)
  Downloading torchtext-0.18.0-cp311-cp311-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting fairscale (from labml-nn)
  Downloading fairscale-0.4.13.tar.gz (266 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Downloading labml_nn-0.4.137-py3-none-any.whl (443 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m443.9/443.9 kB

In [2]:
class Adam(GenericAdaptiveOptimizer):
    r"""
    ## Adam optimizer

    Implements the *Adam* update rule on top of `GenericAdaptiveOptimizer`.

    Every hook (`init_state`, `get_mv`, `get_lr`, `adam_update`, `step_param`) is defined
    *inside* this class. `step_param` reaches its helpers through `self`, and the base
    optimizer presumably dispatches to `init_state` / `step_param` the same way. Functions
    defined at module level in later notebook cells would therefore never be called.
    """

    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        r"""
        ### Initialize the optimizer

        * `params` is passed straight through to the base optimizer (parameters or parameter groups)
        * `lr` is the learning rate $\alpha$
        * `betas` is the pair $(\beta_1, \beta_2)$
        * `eps` is used as $\hat{\epsilon}$ when `optimized_update` is `True` and as $\epsilon$
          otherwise (see `adam_update`)
        * `weight_decay` is the weight-decay rule; its `defaults()` are merged into the group defaults
        * `optimized_update` selects the cheaper update that folds the bias correction of $v_t$
          into the step size
        * `defaults` holds extra per-group defaults (e.g. from a subclass); it is copied, so the
          caller's dictionary is never modified
        """
        # Work on a copy: the `update` below (and anything the base class does with the dict)
        # must not leak into a dictionary the caller still owns.
        defaults = {} if defaults is None else dict(defaults)
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        r"""
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        r"""
        ### Calculate $m_t$ and $v_t$

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

        Both moments are updated in place, so the returned tensors are the same objects
        stored in `state`.
        """

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $m_{t-1}$ and $v_{t-1}$
        m, v = state['exp_avg'], state['exp_avg_sq']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # In-place calculation of $v_t$
        # $$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        r"""
        ### Get learning-rate

        This returns the modified learning rate based on the state.
        For *Adam* this is just the specified learning rate for the parameter group,
        $\alpha$. Subclasses can override it (e.g. for warm-up schedules).
        """
        return group['lr']

    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
        r"""
        ### Do the *Adam* parameter update

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        * `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$.

        This computes the following

        \begin{align}
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
        \end{align}

        Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and others are tensors
        we modify this calculation to optimize the computation.

        \begin{align}
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot
                \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\
        \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
                \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
        \end{align}

        where
        $$\hat{\epsilon} = \sqrt{1-\beta_2^t} \, \epsilon$$

        The optimized branch uses `group['eps']` directly as the constant $\hat{\epsilon}$, so it is
        what we specify as the hyper-parameter there. The unoptimized branch uses `group['eps']`
        as $\epsilon$. The two branches therefore coincide only when `eps` is rescaled accordingly.
        """

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # `step_param` increments `state['step']` before calling this, so $t \ge 1$ here.
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Get learning rate
        lr = self.get_lr(state, group)

        # Whether to optimize the computation
        if self.optimized_update:
            # $\sqrt{v_t} + \hat{\epsilon}$
            denominator = v.sqrt().add_(group['eps'])
            # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
            #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
            param.data.addcdiv_(m, denominator, value=-step_size)
        # Computation without optimization
        else:
            # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            # $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot
            # \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
            param.data.addcdiv_(m, denominator, value=-step_size)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        r"""
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Apply weight decay: `self.weight_decay` receives the parameter and returns the
        # gradient to use from here on (it may fold the decay into the gradient or apply it
        # to `param` directly, depending on its configuration)
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps. This happens before `adam_update`
        # so the bias corrections there use $t \ge 1$.
        state['step'] += 1

        # Perform *Adam* update
        self.adam_update(state, group, param, m, v)