Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WarmupLR Scheduler 单测问题 #37

Closed
rentainhe opened this issue Dec 31, 2021 · 15 comments
Closed

WarmupLR Scheduler 单测问题 #37

rentainhe opened this issue Dec 31, 2021 · 15 comments
Labels
bug Something isn't working

Comments

@rentainhe
Copy link
Contributor

rentainhe commented Dec 31, 2021

待讨论的问题以及可能的解决方案

问题描述

WarmupLR Scheduler在和其他Scheduler组合的时候,好像在Warmup结束的阶段学习率无法到预设的值,不知道是我代码的问题还是oneflow里的WarmUpLR的问题,可以一起帮忙我看看

import oneflow as flow
import oneflow.nn as nn

p = nn.Parameter(flow.zeros(0))
opt = flow.optim.SGD([p], lr=5.0)

multi_step = flow.optim.lr_scheduler.MultiStepLR(opt, [10], gamma=0.1)
sched = flow.optim.lr_scheduler.WarmUpLR(multi_step, 
                                         warmup_factor=0.001, 
                                         warmup_iters=5, 
                                         warmup_method="linear")

p.sum().backward()
opt.step()

lrs = [0.005]
for _ in range(20):
    sched.step()
    lrs.append(opt.param_groups[0]["lr"])

print(lrs)
>>> [0.005, 1.004, 2.003, 3.002, 4.001, 4.001, 4.001, 4.001, 4.001, 4.001, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007, 0.40010000000000007]

按理来说在warmup结束的时候学习率应该是5.0才对,但是这里直接是4.001

观察到的现象

  • warmup的steps更新有一个step的差距: comment
  • 如果warmup_iters设置为0的时候会报错:comment

可能的解决方案

目前libai里实现的Scheduler都用WarmUpLR进行了封装,举例如下:

@SCHEDULER_REGISTRY.register()
def WarmupMultiStepLR(optimizer: flow.optim.Optimizer,
                      warmup_factor: float,
                      warmup_iters: int,
                      milestones: list,
                      gamma: float = 0.1,
                      warmup_method: str = "linear",
                      **kwargs):
    multistep_lr = flow.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=gamma
    )
    warmup_multistep_lr = flow.optim.lr_scheduler.WarmUpLR(
        multistep_lr, warmup_factor=warmup_factor, warmup_iters=warmup_iters, warmup_method=warmup_method, **kwargs
    )
    return warmup_multistep_lr
  • 关于warmup steps更新相差一个steps的问题需要看代码确认一下
  • 关于warmup iters是否可以设置为0,有以下两个解决方案:
    • warmup iters设置为0,目的就是为了不使用任何warmup操作,但是我看了一下oneflow里是把WarmUpLR单独作为了一个Scheduler,可以传入optimizer或者一个别的scheduler,如果传入的是别的scheduler的话,我认为可以考虑一下warmup_iters设置为0的情况,这样表示我虽然用WarmUpLR进行了封装,但是我并不去调用这个WarmUp操作,而是保留了原来的scheduler
    • warmup iters设置为0,其实也可以通过我这里改一下判断,如果warmup_iters = 0的话,直接return原来的scheduler,不知道哪种更加合适
@SCHEDULER_REGISTRY.register()
def WarmupMultiStepLR(optimizer: flow.optim.Optimizer,
                      warmup_factor: float,
                      warmup_iters: int,
                      milestones: list,
                      gamma: float = 0.1,
                      warmup_method: str = "linear",
                      **kwargs):
    multistep_lr = flow.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=gamma
    )
    # 在这里做个判断直接return
    if warmup_iters == 0:
        return multistep_lr
    warmup_multistep_lr = flow.optim.lr_scheduler.WarmUpLR(
        multistep_lr, warmup_factor=warmup_factor, warmup_iters=warmup_iters, warmup_method=warmup_method, **kwargs
    )
    return warmup_multistep_lr
@rentainhe rentainhe added the bug Something isn't working label Dec 31, 2021
@CPFLAME
Copy link
Contributor

CPFLAME commented Dec 31, 2021

把warmup_factor设置成0试试看呢?

@rentainhe
Copy link
Contributor Author

把warmup_factor设置成0试试看呢?

设置为0好像会报错,不能设置为0貌似,这个地方我感觉可以修一下。。应该是可以设置为0才对,如果传入的是一个LR Scheduler而不是optimizer的话

@CPFLAME
Copy link
Contributor

CPFLAME commented Dec 31, 2021

我本地跑了一下, 可以设置为0, 但是lr还是不对,达不到5.0, 可以让框架组的人来看看

@rentainhe
Copy link
Contributor Author

我本地跑了一下, 可以设置为0, 但是lr还是不对,达不到5.0, 可以让框架组的人来看看

  • 我好像有个报错
AssertionError: warmup_iters must greater than zero, but got 0

@CPFLAME
Copy link
Contributor

CPFLAME commented Dec 31, 2021

估计我们oneflow版本不一样

@leaves-zwx
Copy link
Collaborator

AssertionError: warmup_iters must greater than zero, but got 0

这里设置的是 warmup_iters 为 0,但其实你是希望 warmup_factor 设为 0,这里是笔误吗?

@strint
Copy link
Collaborator

strint commented Dec 31, 2021

class WarmUpLR(WarmUpLrScheduler):
    def __init__(
        self,
        lrsch_or_optimizer,
        warmup_factor: float = 1.0 / 3,
        warmup_iters: int = 5,
        warmup_method="linear",
        last_step=-1,
        verbose=False,  # 打开这个看一下呢
    ):

@rentainhe
Copy link
Contributor Author

AssertionError: warmup_iters must greater than zero, but got 0

这里设置的是 warmup_iters 为 0,但其实你是希望 warmup_factor 设为 0,这里是笔误吗?

我这里是想设置warmup_iters为0,因为我的传参是另一个scheduler,我如果不希望进行任何的Warmup,这里应该需要设置为0对吧

@rentainhe
Copy link
Contributor Author

rentainhe commented Dec 31, 2021

class WarmUpLR(WarmUpLrScheduler):
    def __init__(
        self,
        lrsch_or_optimizer,
        warmup_factor: float = 1.0 / 3,
        warmup_iters: int = 5,
        warmup_method="linear",
        last_step=-1,
        verbose=False,  # 打开这个看一下呢
    ):

我尝试一下看看

Last step 0 adjusting learning rate of param_groups[0] to 0.005
Last step 1 adjusting learning rate of param_groups[0] to 1.004
Last step 2 adjusting learning rate of param_groups[0] to 2.003
Last step 3 adjusting learning rate of param_groups[0] to 3.002
Last step 4 adjusting learning rate of param_groups[0] to 4.001

这里应该是把第一步也算进去了吧,就是 5.0 * warmup_factor,这里面也算一个steps了,但是其实不应该这么算感觉

@L1aoXingyu
Copy link
Collaborator

增加一下 sched.get_lr()[0] 的调用这个函数呢,看看是不是一样的 @rentainhe

@leaves-zwx
Copy link
Collaborator

leaves-zwx commented Dec 31, 2021

PyTorch 框架本身没提供 lr warmup 的实现,但一般模型库都会自己实现一个,tianhe 发给我他们参考的 fvcore (detectron2) 的 lr warmup 的实现,链接如下:

通过比较 fvcore 和 oneflow 的 lr warmup 的实现,发现虽然计算方式不同,但结果是一致的,测试的脚本如下:

def of_warmup(factor, base_lr, steps):
    step_lrs = []
    for step in range(steps):
        multiplier = factor + (1.0 - factor) * (float(step) / steps)
        step_lrs.append(base_lr * multiplier)
    return step_lrs

def fvcore_warmup(begin_lr, end_lr, steps):
    step_lrs = []
    for step in range(steps):
        where = step / steps
        lr = end_lr * where + begin_lr * (1 - where)
        step_lrs.append(lr)
    return step_lrs

of_lrs = of_warmup(0.5, 0.01, 5)
fv_lrs = fvcore_warmup(0.005, 0.01, 5)

print(of_lrs)  // [0.005, 0.006, 0.006999999999999999, 0.008, 0.009000000000000001]
print(fv_lrs)  // [0.005, 0.006, 0.007, 0.008, 0.009]

上面都是从 0.005 到 0.01 的 5 轮的 linear warmup 计算,可以看到 5 轮都是一致的结果,说明计算公式本身并没问题。

从 issue 给出的 warmup 的测试结果也可以发现,从第 6 轮开始 lr 就与第 5 轮 lr 相等,一直到第 11 轮到达 multi-step 的 milestone 后才变化。我自己把参数调整了一下,与上面的测试一致,结果如下:

[0.005, 0.006, 0.006999999999999999, 0.008, 0.009000000000000001, 
 0.009000000000000001, 0.009000000000000001, 0.009000000000000001, 0.009000000000000001, 0.009000000000000001, 
 0.0045000000000000005, 0.0045000000000000005, 0.0045000000000000005, 0.0045000000000000005, 0.0045000000000000005, 
 0.0045000000000000005, 0.0045000000000000005, 0.0045000000000000005, 0.0045000000000000005, 0.0045000000000000005, 
 0.0045000000000000005]

发现结果与前面一样,就是第6轮开始结果不正确,且影响到了第11轮开始后 multi-step 的结果。

同样的测试移植到 graph 中后再测试了一遍,发现结果正确。

train_step, lr
0, 0.005000
1, 0.006000
2, 0.007000
3, 0.008000
4, 0.009000
5, 0.010000
6, 0.010000
7, 0.010000
8, 0.010000
9, 0.010000
10, 0.005000
11, 0.005000
12, 0.005000
13, 0.005000
14, 0.005000
15, 0.005000
16, 0.005000
17, 0.005000
18, 0.005000
19, 0.005000

已经说明是 eager 下 warmup 独有的问题。

通过查看代码发了问题所在。现在的实现是 warmup lr scheduler 作为 multi-step lr scheduler 的 wrapper,然后劫持前 N 轮(warmup iters)的 get_lr 后,来返回 warmup 的 lr 结果。warmup iters 结束后再走 multi-step 的 get_lr 函数。warmup lr scheduler 计算完 lr 后会存在 param_group["lr"] 中,以供 optimizer 接下来使用。但实际 warmup 在前 5 轮走完后,它的职责就完成了,应该是 multi-step lr scheduler 从第 6 轮开始时接管 lr 的计算,但这里明显 multi-step lr scheduler 什么都没做。问题出在 multi-step lr scheduler 的 get_lr 上:

    def get_lr(self):
        if self.last_step not in self.milestones:
            return [group["lr"] for group in self._optimizer.param_groups]
        else:
            return [group["lr"] * self.gamma for group in self._optimizer.param_groups]

我们看到它这里是直接用 group["lr"] 来计算 lr,但这个 group["lr"] 是被 multi-step lr scheduler 修改过的,从第 6 轮开始到第 10 轮,由于没有到 multi-step 的 milestone,所以它什么都没做。造成了这个错误。借这个错误我们可以阐述一些“经验”来加强程序的鲁棒性:

  • lr scheduler 会嵌套,它们尽量都要独立工作,只需要根据 last_step 就能计算出当前的 lr
  • 尽量不要依赖 group["lr"] 这个”状态“来进行计算,这里 multi-step 可以通过遍历 milestone 来让 base_lrs 累乘 gamma 获得正确的 lr,milestone 一般不会超过 5,常见的就是 2 和 3,所以这个遍历完全不用 care 性能影响。这就是我们常说使用纯函数(无状态,无论调用多少次,只要参数一样结果就一样)总比有状态函数好(调用多次结果不一样)。

之前没发现的原因是我们的测试只覆盖了:

  • warmup + cosine_decay (这个是基于 base_lrs 直接计算的)
  • multi-step (这个单独工作没问题,不会有其他 scheduler 来干扰 lr 状态)

@rentainhe rentainhe mentioned this issue Jan 2, 2022
11 tasks
@rentainhe
Copy link
Contributor Author

@chengtbf 这个问题可能需要找一下对应的人来解决一下,一直没有看到相关PR,这个bug已经挂了2周了

@strint
Copy link
Collaborator

strint commented Jan 14, 2022

已经和文骁对齐了解决方式,他下周会优先改个版本。

你急用的话,可以找文骁沟通下改的方式,自己开发下也ok。

@leaves-zwx
Copy link
Collaborator

leaves-zwx commented Jan 19, 2022

关于 WarmUpLR 有以下几个问题可能需要讨论一下。

之前 pytorch 官方的 WarmUpLR 已经删除,我们之前参照 pytorch 非正式版本中的 WarmUpLR 的接口和参数现在也失去了参照物。不过我们可以从权威第三方实现中去找到一些参照物,我觉得主要参照的实现有 megatron-lm 中 warmup + annealing LR schduler 的实现detectron2 中依赖的 fvcore 的实现,其他遇到的觉得有比较好的实现的参照欢迎补充。

关于 warmup 的实现我发现有几处细节需要注意,一个 warmup 结束 step 处的 lr 到底是以什么为参照?这里有2种可能,比如我们选择了 linear warmup,那么 warmup 就是构建一个从起始 lr (warmup_factor) 到结束 lr (base_lr/inital_lr) 的线性公式来计算 warmup 每个 step 的实际 lr。但这里结束 lr 不一定是 base_lr/inital_lr,有可能是用除 warmup 之外的另外一个 lr_scheduler 在该 step 处计算出的 lr,示例见这里

这里看起来好像采用第二种做法更好,因为如果采取第一种做法(warmup end lr 是 initial_lr),那么除 warmup lr scheduler 外的另外一个 lr_scheduler 在 warmup end 的那个 step 可能计算出的 lr 并不等于 initial_lr,这样会导致 lr 对应 step 的函数在这个 step 处有一个“断裂点”。但不一定所有情况都是如此,第二种做法有其他适用场景,取决于 warmup lr scheduler + 另外一个 lr scheduler 对 step 的协同的方式不同。

我们通常认为的 warmup steps 方式是这样计算的,给定一个 lr decay steps,前 warmup steps 都是用 warmup lr scheduler 来计算 lr,从 warmup steps 一直到 lr decay steps,都用另外一个 lr scheduler 计算 lr。但实际上也有别的方式,即 warmup steps + lr decay steps = total steps。前面 warmup steps 范围内用 warmup 公式计算,warmup 结束后,另外一个 lr scheduler 认为是从 0 step 开始按照其计算公式计算 lr,两个 lr scheduler 互不相关。见示例

感觉上应该提供一个配置来选择这2种工作方式,类似这里的实现

@leaves-zwx
Copy link
Collaborator

leaves-zwx commented Jan 19, 2022

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

8 participants