In [1]:
class One_Cycle_Policy_LR(lr_scheduler._LRScheduler):
    """One-cycle learning-rate schedule with optional cyclic momentum.

    Starting from each param group's initial learning rate, the learning rate
    rises linearly to `max_lr` during the first half of the cycle and falls
    linearly back to the initial value during the second half. Once the cycle
    is over, the learning rate keeps decreasing linearly, decays
    exponentially, or is held constant, depending on `lr_mode_after_cycle`.
    (The original one-cycle policy decreases the learning rate linearly
    during the annihilation phase.)

    Optionally the momentum is varied in the opposite direction: it falls
    linearly to `min_momentum` at mid-cycle and rises back to its initial
    value by the end of the cycle, where it stays afterwards.

    The schedule is a pure function of the epoch index, so ``step(epoch)`` may
    be called with fractional epochs (e.g. once per batch).

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (int or float): Maximum learning rate, reached at mid-cycle.
            Must not be smaller than any param group's initial learning rate.
        cycle_length (int or float): Number of epochs for the cycle. The value
            can be a float as well.
        total_epochs (int or float): Total number of epochs during the training.
            Validated only if `lr_mode_after_cycle` is ``linear``. Can be a
            float. Should be greater than `cycle_length`, though not strictly
            prohibited. Default: 0.
        gamma (int or float): Multiplicative factor of learning rate decay,
            within the range [0, 1]. Validated only if `lr_mode_after_cycle`
            is ``exponential``. Default: 0.8.
        const_lr (int or float): Constant learning rate value. Validated only
            if `lr_mode_after_cycle` is ``constant``. Default: 3e-3.
        min_lr (int or float): Lower bound applied to the learning rate after
            the cycle (in every mode). Default: 0.
        lr_mode_after_cycle (str): One of ``linear``, ``exponential``,
            ``constant`` (case-insensitive). In ``linear`` mode, the learning
            rate decreases linearly from the initial learning rate towards 0
            (never going below `min_lr`) until `total_epochs`, and equals
            `min_lr` from `total_epochs` on. In ``exponential`` mode, the
            initial learning rate is multiplied by `gamma` for every epoch
            elapsed since the end of the cycle. In ``constant`` mode, the
            learning rate is set to `const_lr` after the cycle.
            Default: ``linear``.
        min_momentum (int or float): Minimum momentum value within the range
            [0, 1]. Used only if `cyclic_momentum` is ``True``. Default: 0.85.
        cyclic_momentum (bool): Whether to vary the momentum value during the
            cycle. If ``True``, every param group must have a ``momentum``
            entry (or ``betas``, whose first value is then used, as for Adam).
            Default: ``False``.
        last_epoch (int or float): The index of last epoch. Can be a float
            depending on your purpose. When it is not -1, every param group
            must already contain ``initial_lr`` (and ``initial_momentum`` if
            `cyclic_momentum` is set). Default: -1.

    Raises:
        TypeError: If an argument has the wrong type.
        ValueError: If an argument is out of range, the mode is unknown, a
            group's learning rate exceeds `max_lr`, or a group's momentum is
            below `min_momentum`.
        KeyError: If a param group lacks a required entry.
    """

    def __init__(self, optimizer, max_lr, cycle_length, total_epochs=0, gamma=0.8, const_lr=3e-3, min_lr=0.,
                 lr_mode_after_cycle="linear", min_momentum=0.85, cyclic_momentum=False, last_epoch=-1):
        # NOTE: the base-class __init__ is intentionally not called; this class
        # does its own param-group bookkeeping and overrides step().
        if not isinstance(optimizer, optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        self._require_number('max_lr', max_lr)
        self._require_non_negative('max_lr', max_lr)
        self.max_lr = float(max_lr)

        self._require_number('min_lr', min_lr)
        self._require_non_negative('min_lr', min_lr)
        self.min_lr = float(min_lr)

        self._require_number('cycle_length', cycle_length)
        if cycle_length < 0:
            raise ValueError('`cycle_length` must be non-negative')
        self.cycle_length = cycle_length

        # `min_momentum` is only validated (and stored) when it will be used.
        self.cyclic_momentum = bool(cyclic_momentum)
        if self.cyclic_momentum:
            self._require_number('min_momentum', min_momentum)
            self._require_unit_interval('min_momentum', min_momentum)
            self.min_momentum = float(min_momentum)
        else:
            self.min_momentum = None

        if not isinstance(lr_mode_after_cycle, str):
            raise TypeError('expected a str for `lr_mode_after_cycle`, but {} was given'.format(type(lr_mode_after_cycle)))
        self.lr_mode_after_cycle = lr_mode_after_cycle.lower()

        # Only the parameter that belongs to the selected mode is validated;
        # the other two are stored untouched, as documented.
        if self.lr_mode_after_cycle == 'linear':
            self._require_number('total_epochs', total_epochs)
        elif self.lr_mode_after_cycle == 'exponential':
            self._require_number('gamma', gamma)
            self._require_unit_interval('gamma', gamma)
            gamma = float(gamma)
        elif self.lr_mode_after_cycle == 'constant':
            self._require_number('const_lr', const_lr)
            self._require_non_negative('const_lr', const_lr)
            const_lr = float(const_lr)
        else:
            raise ValueError('expected one of (`linear`, `exponential`, `constant`), but `{}` was given'.format(self.lr_mode_after_cycle))
        self.total_epochs = total_epochs
        self.gamma = gamma
        self.const_lr = const_lr

        if last_epoch == -1:
            self._register_initial_values()
        else:
            self._check_resumed_values()

        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        if self.cyclic_momentum:
            self.base_momentums = [group['initial_momentum'] for group in optimizer.param_groups]
        else:
            self.base_momentums = None

        # Apply the schedule for the upcoming epoch once, then rewind so that
        # the first explicit step() call lands on `last_epoch + 1`.
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    @staticmethod
    def _require_number(name, value):
        """Raise TypeError unless `value` is an int or a float."""
        if not isinstance(value, (int, float)):
            raise TypeError('expected an int or float for `{}`, but {} was given'.format(name, type(value)))

    @staticmethod
    def _require_non_negative(name, value):
        """Raise ValueError if `value` is negative."""
        if value < 0:
            raise ValueError('expected a non-negative value for `{}`, but {} was given'.format(name, value))

    @staticmethod
    def _require_unit_interval(name, value):
        """Raise ValueError unless `value` lies within [0, 1]."""
        if not 0 <= value <= 1:
            raise ValueError('expected a value within the range [0, 1] for `{}`, but {} was given'.format(name, value))

    def _register_initial_values(self):
        """Validate a fresh optimizer and record ``initial_lr`` / ``initial_momentum`` in each group."""
        optimizer_name = type(self.optimizer).__name__
        for i, group in enumerate(self.optimizer.param_groups):
            if self.cyclic_momentum:
                if 'momentum' in group:
                    if self.min_momentum > group['momentum']:
                        raise ValueError("`momentum` value in param_groups[{}] of the given optimizer {} "
                                         "is below `min_momentum` given".format(i, optimizer_name))
                    group.setdefault('initial_momentum', group['momentum'])
                elif 'betas' in group:
                    # Adam-style optimizers: the first beta plays the role of momentum.
                    if self.min_momentum > group['betas'][0]:
                        raise ValueError("first beta value in `betas` of param_groups[{}] of the given optimizer {} "
                                         "is below `min_momentum` given".format(i, optimizer_name))
                    group.setdefault('initial_momentum', group['betas'][0])
                else:
                    raise KeyError("param 'momentum' or 'betas' is not present "
                                   "in param_groups[{}] of the given optimizer {}".format(i, optimizer_name))
            if self.max_lr < group['lr']:
                raise ValueError("`lr` value in param_groups[{}] of the given optimizer {} "
                                 "exceeds `max_lr` given".format(i, optimizer_name))
            # setdefault keeps a value recorded earlier (e.g. by another scheduler).
            group.setdefault('initial_lr', group['lr'])

    def _check_resumed_values(self):
        """Validate that a resumed optimizer already carries the recorded initial values."""
        optimizer_name = type(self.optimizer).__name__
        for i, group in enumerate(self.optimizer.param_groups):
            if 'initial_lr' not in group:
                raise KeyError("param 'initial_lr' is not specified "
                               "in param_groups[{}] when resuming an optimizer".format(i))
            if self.max_lr < group['initial_lr']:
                raise ValueError("`initial_lr` value in param_groups[{}] of the given optimizer {} "
                                 "exceeds `max_lr` given".format(i, optimizer_name))
            if self.cyclic_momentum and 'initial_momentum' not in group:
                raise KeyError("param 'initial_momentum' is not specified "
                               "in param_groups[{}] when resuming an optimizer".format(i))

    def _cycle_factor(self):
        """Return the triangular weight of the base value: 1 at both ends of the cycle, 0 at mid-cycle."""
        return abs(-2. * (self.last_epoch / self.cycle_length) + 1)

    def get_lr(self):
        """Return the learning rate of every param group for ``self.last_epoch``.

        Returns:
            list: One learning rate per param group (a new list on every call).
        """
        if self.last_epoch <= 0:
            # Copy so callers cannot mutate the stored base values.
            return list(self.base_lrs)

        if self.last_epoch <= self.cycle_length:
            factor = self._cycle_factor()
            return [self.max_lr + (base_lr - self.max_lr) * factor
                    for base_lr in self.base_lrs]

        # Past the cycle: behaviour depends on the configured mode.
        elapsed = self.last_epoch - self.cycle_length
        if self.lr_mode_after_cycle == 'linear':
            if self.last_epoch < self.total_epochs:
                # total_epochs > last_epoch > cycle_length here, so the
                # denominator is strictly positive.
                remaining = 1 - elapsed / (self.total_epochs - self.cycle_length)
                return [max(base_lr * remaining, self.min_lr)
                        for base_lr in self.base_lrs]
            return [self.min_lr for _ in self.base_lrs]
        if self.lr_mode_after_cycle == 'exponential':
            decay = self.gamma ** elapsed
            return [max(base_lr * decay, self.min_lr)
                    for base_lr in self.base_lrs]
        return [max(self.const_lr, self.min_lr) for _ in self.base_lrs]

    def get_momentum(self):
        """Return the momentum of every param group for ``self.last_epoch``.

        Returns:
            list or None: One momentum value per param group (a new list on
            every call), or ``None`` when cyclic momentum is disabled.
        """
        if not self.cyclic_momentum:
            return None
        if 0 <= self.last_epoch < self.cycle_length:
            factor = self._cycle_factor()
            return [self.min_momentum + (base_momentum - self.min_momentum) * factor
                    for base_momentum in self.base_momentums]
        # Outside the cycle the momentum stays at its initial value.
        return list(self.base_momentums)

    def step(self, epoch=None):
        """Move to `epoch` and write the scheduled values into the optimizer.

        Args:
            epoch (int or float, optional): Epoch index to move to. When
                omitted, ``self.last_epoch + 1`` is used.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

        if self.cyclic_momentum:
            for param_group, momentum in zip(self.optimizer.param_groups, self.get_momentum()):
                if 'momentum' in param_group:
                    param_group['momentum'] = momentum
                else:
                    # Adam-style group: replace only the first beta.
                    param_group['betas'] = (momentum, param_group['betas'][1])

        # Recorded so the base class's get_last_lr() keeps working even though
        # this class overrides step() and skips the base __init__.
        self._last_lr = [param_group['lr'] for param_group in self.optimizer.param_groups]


NameError: name 'lr_scheduler' is not defined