# In [10]:
import functools
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

import numpy as np
import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.optim import SGD, Optimizer

if TYPE_CHECKING:  # pragma: no cover
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any


class AdaScale(Optimizer):
    """Optimizer wrapper that scales the learning rate by an adaptive gain.

    The wrapper registers a gradient hook on every parameter of the wrapped
    optimizer.  At the end of each (accumulated) backward pass it updates
    per-param-group running estimates of the squared gradient norm and of the
    gradient variance, which are stored under ``optimizer.state["adascale"]``.
    :meth:`step` temporarily multiplies every param group's ``lr`` by
    :meth:`gain` and then delegates to the wrapped optimizer.

    Attributes that are not found on the wrapper are forwarded to the wrapped
    optimizer.

    Args:
        optimizer: The optimizer to wrap.
        world_size: Number of workers.  Defaults to the ``torch.distributed``
            world size when a process group is initialized, otherwise 1.
        scale: Scaling factor (must be >= 1).  Defaults to
            ``world_size * num_gradients_to_accumulate``.
        smoothing: Smoothing factor of the moving averages.  Defaults to
            ``max(1 - world_size * num_gradients_to_accumulate / 1000, 0)``.
        num_gradients_to_accumulate: Number of backward passes that are
            accumulated before the statistics are updated.
        debias_ewma: If True, use a bias-corrected exponential moving
            average; otherwise use a plain mean for the first
            ``1 / (1 - smoothing)`` updates and an exponential moving
            average afterwards.

    Raises:
        RuntimeError: If ``world_size * num_gradients_to_accumulate <= 1``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        world_size: Optional[int] = None,
        scale: Optional[float] = None,
        smoothing: Optional[float] = None,
        num_gradients_to_accumulate: int = 1,
        debias_ewma: bool = True,
    ):
        # These are set first because ``__getattr__`` and ``unhook`` (called
        # from ``__del__``) must work even if construction fails below.
        self._optimizer = optimizer
        self._hook_handles: List[Any] = []
        # A non-None value means "a backward phase is in progress".
        self._local_grad_sqr: Optional[torch.Tensor] = None
        self._final_callback_queued = False

        if world_size is not None:
            self._world_size: int = world_size
        elif dist.is_available() and dist.is_initialized():
            self._world_size = dist.get_world_size()
        else:
            self._world_size = 1

        self._num_backward_calls = 0
        self._last_final_backward_call = 0
        self._debias_ewma = debias_ewma

        # Proxy the param_groups so that `torch.optim.lr_scheduler` can work.
        self.param_groups = self._optimizer.param_groups

        # Sets ``_num_grads_to_accum`` and the default ``_smoothing``.
        self.set_num_gradients_to_accumulate(num_gradients_to_accumulate, update_smoothing=True)

        # An explicitly passed smoothing overrides the default computed above.
        if smoothing is not None:
            self._smoothing = smoothing

        if self._world_size * self._num_grads_to_accum <= 1:
            # gain will be NaN since we will be dividing by zero in paper's B.3 where (S-1) == 0.
            raise RuntimeError("AdaScale does not support a single worker without grad accumulation.")

        # Per-param-group sqr & var states (sigma^2 & mu^2 in the paper).
        # ``setdefault`` keeps statistics already present on the optimizer.
        self._optimizer.state.setdefault(
            "adascale",
            {
                "grad_sqr_avg": np.ones(len(optimizer.param_groups)),
                "grad_var_avg": np.zeros(len(optimizer.param_groups)),
            },
        )

        self._scale = 1.0  # Assign to inform mypy about the typing of this variable.
        self.set_scale(self._world_size * self._num_grads_to_accum if scale is None else scale)

        self._hook()

    def _hook(self) -> None:
        """Register ``_backward_hook`` on every parameter of every param group."""
        assert self._hook_handles == [], "Must run unhook first"
        for idx, param_group in enumerate(self._optimizer.param_groups):
            for param in param_group["params"]:
                # Bind the param-group index so the hook knows where to accumulate.
                handle = param.register_hook(functools.partial(self._backward_hook, idx))
                self._hook_handles.append(handle)

    def __del__(self) -> None:
        self.unhook()

    def unhook(self) -> None:
        """Remove all gradient hooks registered by this wrapper."""
        # Read from ``__dict__`` directly: on a partially constructed instance
        # the attribute may be missing and must not be forwarded elsewhere.
        for handle in self.__dict__.get("_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    @property
    def _state(self) -> Dict[str, np.ndarray]:
        """The statistics dict stored on the wrapped optimizer."""
        return self._optimizer.state["adascale"]

    @property
    def scale(self) -> float:
        """The current scaling factor."""
        return self._scale

    @property
    def smoothing(self) -> float:
        """The current smoothing factor of the moving averages."""
        return self._smoothing

    def set_scale(self, scale: float, update_estimate: bool = True) -> None:
        """Set the scaling factor.

        Args:
            scale: New scaling factor, must be >= 1.
            update_estimate: If True, multiply the stored gradient-variance
                estimates by ``old_scale / scale`` to account for the change.
        """
        assert self._local_grad_sqr is None, "Don't change scale in backward phase"
        assert scale >= 1, "Scale must be at least 1"
        if update_estimate and hasattr(self, "_scale"):
            assert self._scale >= 1, "bug: old scale isn't valid"
            # Rescale grad_var_avg to account for the change in scale
            ratio = self._scale / scale
            if self._debias_ewma and "grad_var_avg_biased" in self._state:
                self._state["grad_var_avg_biased"] *= ratio
            elif "grad_var_avg_total" in self._state:  # _debias_ewma==False
                self._state["grad_var_avg_total"] *= ratio
            self._state["grad_var_avg"] *= ratio
        self._scale = scale

    def _grad_sqr_avg(self, pg_idx: Optional[int] = None) -> float:
        """Return the squared-gradient estimate of one group, or the sum over all groups."""
        if pg_idx is not None:
            return self._state["grad_sqr_avg"][pg_idx]
        return float(np.sum(self._state["grad_sqr_avg"]))

    def _grad_var_avg(self, pg_idx: Optional[int] = None) -> float:
        """Return the gradient-variance estimate of one group, or the sum over all groups."""
        if pg_idx is not None:
            return self._state["grad_var_avg"][pg_idx]
        return float(np.sum(self._state["grad_var_avg"]))

    def gain(self, pg_idx: Optional[int] = None) -> float:
        """Return the learning-rate gain ``(var + sqr) / (var / scale + sqr)``.

        Args:
            pg_idx: Param-group index; when None the estimates are summed
                over all param groups.
        """
        var = self._grad_var_avg(pg_idx)
        sqr = self._grad_sqr_avg(pg_idx)
        return (var + sqr) / (var / self.scale + sqr)

    def _update_avg(self, name: str, value: np.ndarray, factor: float) -> None:
        """Fold ``value`` into the running average stored under ``name``.

        Args:
            name: State key of the average (``"grad_sqr_avg"`` or ``"grad_var_avg"``).
            value: New per-param-group sample.
            factor: Smoothing factor of the exponential moving average.
        """
        if self._debias_ewma:
            # Bias-corrected EWMA: track the biased average together with the
            # total weight it has received so far, and store their ratio.
            biased = self._state.get(name + "_biased", np.zeros(value.shape[0]))
            unbias = self._state.get(name + "_unbias", np.zeros(value.shape[0]))
            biased = factor * biased + (1.0 - factor) * value
            unbias = factor * unbias + (1.0 - factor)
            self._state[name + "_biased"] = biased
            self._state[name + "_unbias"] = unbias
            self._state[name] = biased / unbias
        else:
            # A single update counter is shared by all param groups.
            count = self._state.get(name + "_count", np.zeros(1))
            count[0] += 1
            self._state[name + "_count"] = count
            if count < 1 / (1 - self._smoothing):
                # Warm-up: plain mean of the samples seen so far.
                total = self._state.get(name + "_total", None)
                if total is None:
                    # Copy so that later in-place updates never touch the
                    # caller's array.
                    total = value.copy()
                else:
                    total += value
                self._state[name + "_total"] = total
                self._state[name] = total / count
            else:
                self._state[name] = factor * self._state[name] + (1.0 - factor) * value

    def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None:
        """Gradient hook: accumulate the squared norm of ``grad`` for group ``pg_idx``."""
        if self._local_grad_sqr is None:
            # First hook of this backward phase: allocate the accumulator.
            self._local_grad_sqr = torch.zeros(
                len(self._optimizer.param_groups),
                device=grad.device,
                requires_grad=False,
            )
        self._local_grad_sqr[pg_idx] += grad.pow(2).sum()
        # Two-stage queueing: ``_queue_callback`` in turn queues
        # ``_final_callback`` once, so it is appended after the callbacks
        # that were queued while the hooks ran.
        self._final_callback_queued = False
        Variable._execution_engine.queue_callback(self._queue_callback)

    def _queue_callback(self) -> None:
        """Queue ``_final_callback`` exactly once per backward pass."""
        if self._final_callback_queued:
            return
        self._final_callback_queued = True
        Variable._execution_engine.queue_callback(self._final_callback)

    def _final_callback(self) -> None:
        """Update the running statistics at the end of a backward pass.

        Returns early (staying in the backward phase) until
        ``num_gradients_to_accumulate`` backward passes have completed.
        """
        self._final_callback_queued = False
        assert isinstance(self._local_grad_sqr, torch.Tensor)
        self._num_backward_calls += 1
        assert (
            self._num_backward_calls - self._last_final_backward_call
        ) <= self._num_grads_to_accum, (
            f"bug: {self._num_backward_calls} - {self._last_final_backward_call} should <= {self._num_grads_to_accum}"
        )
        if (self._num_backward_calls - self._last_final_backward_call) % self._num_grads_to_accum != 0:
            assert self._local_grad_sqr is not None, "We should still be in backward phase"
            return

        # Start summing the per-worker squared norms asynchronously so the
        # communication overlaps with the local computation below.
        work = None
        if self._world_size > 1:
            work = dist.all_reduce(self._local_grad_sqr, async_op=True)  # SUM

        # Squared norm of the gradients currently stored on the parameters.
        # Parameters without a gradient (e.g. unused in this pass) contribute 0.
        total_grad_sqr = np.array(
            [
                sum(param.grad.pow(2).sum().item() for param in group["params"] if param.grad is not None)
                for group in self._optimizer.param_groups
            ]
        )

        if self._num_grads_to_accum > 1:
            # Out-of-place division: account for the gradients being
            # accumulated (summed) over several backward passes.
            total_grad_sqr = total_grad_sqr / (self._num_grads_to_accum ** 2)

        # Wait for all_reduce to be done and move it to cpu & np.
        if work:
            work.wait()
        local_grad_sqr = self._local_grad_sqr.cpu().numpy()

        # S is the scale, cN the number of per-pass gradients that were summed
        # into local_grad_sqr (workers times accumulation steps).
        S = self._scale
        cN = self._world_size * self._num_grads_to_accum
        grad_var = local_grad_sqr * (S / cN) / (cN - 1) - total_grad_sqr * S / (cN - 1)
        grad_sqr = total_grad_sqr - grad_var / S
        # Clamp so that the estimates stay non-negative and the variance positive.
        grad_var = np.maximum(grad_var, 1e-6)
        grad_sqr = np.maximum(grad_sqr, 0.0)
        self._update_avg("grad_sqr_avg", grad_sqr, self.smoothing)
        self._update_avg("grad_var_avg", grad_var, self.smoothing)
        self._last_final_backward_call = self._num_backward_calls
        # Indicating backward is done.
        self._local_grad_sqr = None

    def step(self, *args: Any, **kwargs: Any) -> Optional[float]:
        """Run the wrapped optimizer's ``step`` with gain-scaled learning rates.

        Every param group's ``lr`` is multiplied by its gain for the duration
        of the call and restored afterwards, even if the wrapped ``step``
        raises.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        assert self._local_grad_sqr is None, "Don't step without finishing backward phase"
        original_lr = []
        try:
            # Save the original LR and set the scaled LR.
            for idx, param_group in enumerate(self._optimizer.param_groups):
                original_lr.append(param_group["lr"])
                param_group["lr"] = self.gain(pg_idx=idx) * param_group["lr"]

            # Step it.
            return self._optimizer.step(*args, **kwargs)
        finally:
            # Restore the original LR of every group that was scaled.
            for lr, param_group in zip(original_lr, self._optimizer.param_groups):
                param_group["lr"] = lr

    def add_param_group(self, pg: Dict) -> None:
        """Add a param group to the wrapped optimizer and extend hooks and state."""
        assert self._local_grad_sqr is None, "Can't add parameter group during backward"
        self._optimizer.add_param_group(pg)
        # Update the hooks.
        self.unhook()
        self._hook()
        # Extend the states.
        for name in list(self._state):
            assert name.startswith("grad_sqr_avg") or name.startswith("grad_var_avg"), name
            if name.endswith("_count"):
                # This is the "_count" variable, should be a 1D int.
                assert self._state[name].shape == (1,), self._state[name].shape
                continue
            # must be a np array, extend it with the right value and check the shape.
            val = 1 if name == "grad_sqr_avg" else 0
            self._state[name] = np.append(self._state[name], val)  # type: ignore
            assert self._state[name].shape == (len(self._optimizer.param_groups),)

    def zero_grad(self, *args: Any, **kwargs: Any) -> None:
        """Proxy to the wrapped optimizer's ``zero_grad``, forwarding all arguments."""
        assert self._local_grad_sqr is None, "Don't zero_grad in backward"
        return self._optimizer.zero_grad(*args, **kwargs)

    def state_dict(self) -> Dict:
        """Return the wrapped optimizer's ``state_dict``."""
        assert self._local_grad_sqr is None, "Don't checkpoint in backward"
        return self._optimizer.state_dict()

    def load_state_dict(self, data: Dict) -> None:
        """Load ``data`` into the wrapped optimizer and refresh the param_groups proxy."""
        assert self._local_grad_sqr is None, "Don't load checkpoint in backward"
        res = self._optimizer.load_state_dict(data)
        # Loading may rebind the wrapped optimizer's ``param_groups`` to a new
        # list; re-proxy it so LR schedulers keep editing the live groups.
        self.param_groups = self._optimizer.param_groups
        return res

    def set_num_gradients_to_accumulate(
        self,
        num_gradients_to_accumulate: int,
        update_smoothing: bool = True,
    ) -> None:
        """Set the number of backward passes accumulated per statistics update.

        Args:
            num_gradients_to_accumulate: New value, must be >= 1.
            update_smoothing: If True, reset the smoothing factor to
                ``max(1 - world_size * num_gradients_to_accumulate / 1000, 0)``.
        """
        assert self._local_grad_sqr is None, "Don't change num_grad_to_accum in backward"
        assert num_gradients_to_accumulate >= 1, f"Invalid value {num_gradients_to_accumulate}"
        self._num_grads_to_accum = num_gradients_to_accumulate
        if update_smoothing:
            self._smoothing = max(1 - self._world_size * self._num_grads_to_accum / 1000, 0)

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped optimizer."""
        if name == "_optimizer":
            # The wrapped optimizer itself is missing (instance created
            # without ``__init__``); forwarding would recurse forever.
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        try:
            return super().__getattr__(name)  # defer to Optimizer logic
        except AttributeError:
            return getattr(self._optimizer, name)  # fallback to wrapped optim