In [None]:
import torch
from torch.optim.optimizer import Optimizer, required

class LARS(Optimizer):
    """SGD with momentum whose step is rescaled per parameter tensor (LARS).

    For every parameter tensor the step is multiplied by a local rate
    ``eta * ||w|| / (||g|| + eps)``, where ``g`` is the gradient after the
    optional L2 weight-decay term has been added.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: global learning rate (must be supplied, non-negative).
        momentum: momentum factor in ``[0, 1)``.
        weight_decay: L2 coefficient added to the gradient (non-negative).
        eta: trust coefficient scaling the local rate (strictly positive).
        eps: term added to the gradient norm to avoid division by zero.

    Raises:
        ValueError: if any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=required, momentum=0.9, weight_decay=0.0, eta=0.001, eps=1e-8):
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0 or momentum >= 1.0:
            raise ValueError(f"Invalid momentum: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay: {weight_decay}")
        if eta <= 0.0:
            raise ValueError(f"Invalid eta: {eta}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, eta=eta, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            eta = group['eta']
            eps = group['eps']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                if weight_decay != 0:
                    # Out-of-place on purpose: param.grad belongs to the
                    # caller and must not be modified (otherwise the decay
                    # term compounds when gradients are accumulated).
                    grad = grad.add(param, alpha=weight_decay)

                # Layer-wise learning rate. When either norm is zero the
                # ratio is meaningless (0 would freeze zero-initialised
                # parameters forever, ||w|| / eps would explode), so fall
                # back to a neutral factor of 1.
                param_norm = param.norm(2)
                grad_norm = grad.norm(2)
                local_lr = torch.where(
                    (param_norm > 0) & (grad_norm > 0),
                    eta * param_norm / (grad_norm + eps),
                    torch.ones_like(param_norm),
                )

                # Momentum buffer is seeded with the first gradient.
                state = self.state[param]
                buf = state.get('momentum_buffer')
                if buf is None:
                    buf = state['momentum_buffer'] = torch.clone(grad).detach()
                else:
                    buf.mul_(momentum).add_(grad)

                param.add_(buf * local_lr, alpha=-lr)

        return loss


class AdamW(Optimizer):
    """Adam with decoupled weight decay.

    Each step shrinks the parameter by ``1 - lr * weight_decay`` and then
    applies the bias-corrected Adam update
    ``lr * m_hat / (sqrt(v_hat) + eps)``.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate (non-negative).
        betas: decay rates for the first and second moment estimates,
            each in ``[0, 1)``.
        eps: term added to the denominator for numerical stability.
        weight_decay: decoupled weight-decay coefficient (non-negative).

    Raises:
        ValueError: if any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']
            beta1, beta2 = group['betas']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')

                state = self.state[param]

                # Lazily create the per-parameter state on first use.
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(param)
                    state['exp_avg_sq'] = torch.zeros_like(param)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                step = state['step']

                # Update the biased first and second moment estimates.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # sqrt(v_hat) = sqrt(v) / sqrt(1 - beta2^t): the correction
                # term must be square-rooted because it sits under the sqrt.
                denom = (exp_avg_sq.sqrt() / bias_correction2 ** 0.5).add_(eps)
                step_size = lr / bias_correction1

                # Decoupled weight decay: shrink the weights directly
                # instead of adding an L2 term to the gradient.
                if weight_decay != 0:
                    param.mul_(1 - lr * weight_decay)

                param.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

class LAMB(Optimizer):
    """Layer-wise Adaptive Moments optimizer (LAMB).

    For each parameter tensor the bias-corrected Adam direction plus the
    weight-decay term forms the update ``u``; the step actually applied is
    ``lr * (||w|| / ||u||) * u``, so every tensor moves by a distance
    proportional to its own norm.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate (non-negative).
        betas: decay rates for the first and second moment estimates,
            each in ``[0, 1)``.
        eps: term added to the denominator for numerical stability.
        weight_decay: coefficient of the weight-decay term added to the
            update before the trust ratio is computed (non-negative).

    Raises:
        ValueError: if any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']
            beta1, beta2 = group['betas']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                if grad.is_sparse:
                    raise RuntimeError('LAMB does not support sparse gradients')

                state = self.state[param]

                # Lazily create the per-parameter state on first use.
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(param)
                    state['exp_avg_sq'] = torch.zeros_like(param)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                step = state['step']

                # Update the biased first and second moment estimates.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # Bias-corrected Adam direction: m_hat / (sqrt(v_hat) + eps).
                denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)
                update = (exp_avg / bias_correction1).div_(denom)

                # Weight decay is part of the update so that it is scaled
                # by both lr and the trust ratio, like the gradient term.
                if weight_decay != 0:
                    update.add_(param, alpha=weight_decay)

                # Trust ratio compares the weights with the update that is
                # actually applied; degenerate norms fall back to 1.
                weight_norm = param.norm(2)
                update_norm = update.norm(2)
                trust_ratio = torch.where(
                    (weight_norm > 0) & (update_norm > 0),
                    weight_norm / update_norm,
                    torch.ones_like(weight_norm),
                )

                # `update` is a temporary, so scaling it in place is safe.
                param.add_(update.mul_(trust_ratio), alpha=-lr)

        return loss