In [None]:
from dataclasses import dataclass
import math


@dataclass
class TrainConfig:
    """Hyperparameters for the training run and its learning-rate schedule.

    Every field has a default, so the values can also be read straight off
    the class (e.g. ``TrainConfig.max_iters``).
    """

    # Total number of iterations covered by the schedule.
    max_iters: int = 10
    # Peak learning rate, reached at the end of warmup.
    learning_rate: float = 10.0
    # Learning rate at iteration 0, where warmup starts.
    min_lr: float = 2.0
    # Final learning rate after decay, as a fraction of ``learning_rate``.
    min_lr_decay_ratio: float = 0.1
    # NOTE(review): not read by any code visible in this file; presumably the
    # fraction of ``max_iters`` at which the decayed rate is meant to be
    # reached -- confirm before relying on it.
    min_lr_iter_ratio: float = 0.1
    # Fraction of ``max_iters`` spent in linear warmup.
    warmup_iter_ratio: float = 0.1
class Trainer:
    """Bundles a model, its datasets and the config driving the LR schedule."""

    def __init__(self, model, train_dataset, val_dataset=None, config=None):
        """Store the training inputs and the configuration.

        Args:
            model: Model to train; stored as-is.
            train_dataset: Training data; stored as-is.
            val_dataset: Optional validation data; stored as-is.
            config: Object exposing the ``TrainConfig`` fields. When omitted,
                a fresh ``TrainConfig()`` instance is created.
        """
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        # Resolved at call time so each trainer gets its own config instance
        # instead of sharing whatever object a signature default would bind.
        self.config = TrainConfig() if config is None else config

    def get_lr(self, cur_iter):
        r"""Return the learning rate to use at iteration ``cur_iter``.

        Schematic shape of the schedule::

            LR ^
               |
               |      /‾‾\          learning_rate
               |     /    ‾‾\
               |    /        ‾\
               |   /           ‾\
               |  /              \  min_lr_decayed
               | /
               min_lr
               +------+----------+--> iter
               0     warmup_    max_
                   end_iter   iters

        Phases, with ``warmup_end_iter = warmup_iter_ratio * max_iters`` and
        ``min_lr_decayed = min_lr_decay_ratio * learning_rate``:

        1. ``cur_iter < warmup_end_iter``: linear warmup from ``min_lr`` at
           iteration 0 up to ``learning_rate`` at ``warmup_end_iter``.
        2. ``warmup_end_iter <= cur_iter < max_iters``: cosine decay over a
           quarter period, from ``learning_rate`` down to ``min_lr_decayed``.
        3. ``cur_iter >= max_iters``: constant ``min_lr_decayed``.

        ``min_lr_iter_ratio`` is not consulted; the decay always ends at
        ``max_iters``.

        Args:
            cur_iter: Iteration index, counted from 0.

        Returns:
            The learning rate for ``cur_iter``.
        """
        config = self.config
        warmup_end_iter = config.warmup_iter_ratio * config.max_iters

        # Phase 1: linear warmup.
        if cur_iter < warmup_end_iter:
            warmup_progress = cur_iter / warmup_end_iter
            return config.min_lr \
                + (config.learning_rate - config.min_lr) * warmup_progress

        min_lr_decayed = config.min_lr_decay_ratio * config.learning_rate

        # Phase 2: cosine decay. When warmup fills the whole run this branch
        # is never reached, so decay_span cannot be zero here.
        if cur_iter < config.max_iters:
            decay_span = config.max_iters - warmup_end_iter
            decay_progress = (cur_iter - warmup_end_iter) / decay_span
            # cos goes from 1 (progress 0) to 0 (progress 1) over [0, pi/2].
            cosine = math.cos(decay_progress * math.pi / 2)
            return min_lr_decayed \
                + cosine * (config.learning_rate - min_lr_decayed)

        # Phase 3: past the end of the schedule.
        return min_lr_decayed

  """


In [7]:
# Sample the schedule once per iteration using the default configuration.
trainer = Trainer(None, None)
[trainer.get_lr(cur_iter) for cur_iter in range(trainer.config.max_iters)]

[2.0,
 10.0,
 9.863269777109872,
 9.457233587073176,
 8.794228634059948,
 7.894399988070802,
 6.785088487178855,
 5.500000000000001,
 4.0781812899310195,
 2.562833599002374]