In [1]:
import torch

In [2]:
# Throwaway SGD optimizer over a single random 2x2 parameter; it only exists so
# that a learning-rate scheduler has param groups to drive.
optimizer = torch.optim.SGD(
    params=[torch.randn(2, 2, requires_grad=True)],
    lr=0.1,
)


In [33]:
import warnings
from collections import Counter
from torch.optim.lr_scheduler import LRScheduler
class MultiStepwithDoubleLinearWarmup(LRScheduler):
    """Multi-step decay preceded by two linear warmup ramps separated by a plateau.

    With ``w1 = warmup_iters1``, ``hold = inter_warmups_iters`` and
    ``w2 = warmup_iters2``, the learning rate as a function of ``last_epoch`` is:

    * ``0 < last_epoch <= w1``: linear ramp from ``eta_min`` to ``eta_medium``;
    * ``w1 < last_epoch <= w1 + hold``: the current learning rate is kept;
    * ``w1 + hold < last_epoch <= w1 + hold + w2``: linear ramp from
      ``eta_medium`` to ``eta_max``;
    * afterwards: the current learning rate is multiplied by ``gamma`` every
      time ``last_epoch`` hits a milestone (once per occurrence of that
      milestone in ``milestones``).

    Milestones that fall inside the warmup/plateau window are ignored. During
    the two ramps every param group receives the same learning rate.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of epoch indices at which to decay.
        gamma: Multiplicative decay factor applied at each milestone.
        eta_max: Learning rate reached at the end of the second ramp. Required.
        eta_medium: Learning rate at the end of the first ramp / start of the
            second ramp.
        eta_min: Learning rate at the start of the first ramp.
        warmup_iters2: Length of the second ramp.
        inter_warmups_iters: Length of the plateau between the two ramps.
        warmup_iters1: Length of the first ramp.
        last_epoch: Index of the last epoch; ``-1`` starts a fresh schedule.
        verbose: Forwarded to ``LRScheduler`` only when truthy.

    Raises:
        ValueError: If ``eta_max`` is missing or zero, or if
            ``eta_max >= eta_medium >= eta_min >= 0`` does not hold.
    """

    def __init__(self, optimizer, milestones=(), gamma=1e-1, eta_max=None, eta_medium=0.0, eta_min=0.0, warmup_iters2=0, inter_warmups_iters=0, warmup_iters1=0, last_epoch=-1,
                 verbose=False):
        # Explicit validation instead of `assert`: asserts vanish under `python -O`
        # and the None default would otherwise fail with an opaque TypeError.
        if eta_max is None:
            raise ValueError('eta_max must be provided')
        if not eta_max >= eta_medium >= eta_min >= 0.0:
            raise ValueError(
                'expected eta_max >= eta_medium >= eta_min >= 0.0, got '
                f'eta_max={eta_max}, eta_medium={eta_medium}, eta_min={eta_min}')
        if eta_max == 0.0:
            raise ValueError('eta_max must be greater than 0.0')

        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.eta_max = eta_max
        self.eta_medium = eta_medium
        self.eta_min = eta_min
        self.warmup_iters2 = warmup_iters2
        self.inter_warmups_iters = inter_warmups_iters
        self.warmup_iters1 = warmup_iters1

        # Learning rate used at epoch 0. The choice is driven by which phases are
        # configured, so a ramp that starts from exactly 0.0 really starts at 0.0
        # instead of spending the first iteration at a later phase's value.
        # A positive eta_min / eta_medium still selects the start value even
        # without the matching warmup length.
        if warmup_iters1 > 0 or eta_min > 0.0:
            start_lr = eta_min
        elif warmup_iters2 > 0 or inter_warmups_iters > 0 or eta_medium > 0.0:
            start_lr = eta_medium
        else:
            start_lr = None  # no warmup configured: keep the optimizer's own lr

        # Only a fresh schedule may overwrite the optimizer's learning rate; when
        # resuming (last_epoch != -1) the optimizer already carries the right value.
        if start_lr is not None and last_epoch == -1:
            for group in optimizer.param_groups:
                group['lr'] = start_lr

        # The base-class constructor performs an initial step that calls get_lr(),
        # so every attribute above must be set before this point. `verbose` is
        # deprecated in recent PyTorch releases, hence it is forwarded only on request.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``self.last_epoch``."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        epoch = self.last_epoch
        ramp1_end = self.warmup_iters1
        hold_end = ramp1_end + self.inter_warmups_iters
        ramp2_end = hold_end + self.warmup_iters2
        groups = self.optimizer.param_groups

        # Epoch 0 and the plateau between the ramps keep whatever is currently set.
        if epoch == 0 or ramp1_end < epoch <= hold_end:
            return [group['lr'] for group in groups]

        # First ramp: eta_min -> eta_medium. Reached only when warmup_iters1 >= epoch >= 1,
        # so the division is safe.
        if epoch <= ramp1_end:
            lr = self.eta_min + (self.eta_medium - self.eta_min) * epoch / self.warmup_iters1
            return [lr for _ in groups]

        # Second ramp: eta_medium -> eta_max. Reached only when epoch > hold_end,
        # which together with epoch <= ramp2_end implies warmup_iters2 >= 1.
        if epoch <= ramp2_end:
            lr = self.eta_medium + (self.eta_max - self.eta_medium) * (epoch - hold_end) / self.warmup_iters2
            return [lr for _ in groups]

        # Multi-step decay: scale the current lr when this epoch is a milestone.
        if epoch not in self.milestones:
            return [group['lr'] for group in groups]
        return [group['lr'] * self.gamma ** self.milestones[epoch]
                for group in groups]

In [34]:
# Fresh optimizer for the demo so the schedule starts from a clean state.
optimizer = torch.optim.SGD(
    params=[torch.randn(2, 2, requires_grad=True)],
    lr=0.1,
)
# Keyword arguments make the phase lengths explicit:
# 5-step ramp 0.1 -> 0.3, 20-step plateau, 10-step ramp 0.3 -> 0.5, no milestones.
scheduler = MultiStepwithDoubleLinearWarmup(
    optimizer,
    milestones=[],
    gamma=1e-1,
    eta_max=0.5,
    eta_medium=0.3,
    eta_min=0.1,
    warmup_iters2=10,
    inter_warmups_iters=20,
    warmup_iters1=5,
)

In [35]:
# Print the learning rate in effect at each of 150 iterations, advancing the
# schedule by one step after every print.
for _ in range(150):
    current_lrs = scheduler.get_last_lr()
    print(current_lrs)
    scheduler.step()

[0.1]
[0.14]
[0.18]
[0.22]
[0.26]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.3]
[0.32]
[0.33999999999999997]
[0.36]
[0.38]
[0.4]
[0.42000000000000004]
[0.44]
[0.45999999999999996]
[0.48]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
[0.5]
