
Sensitivity wrt LR restarts #8

Open
depthwise opened this issue Aug 17, 2019 · 12 comments

@depthwise

I'm observing sensitivity w.r.t. LR restarts in a typical SGDR schedule with cosine annealing, as in Loshchilov & Hutter. RAdam still seems to be doing better than AdamW so far, but the jumps suggest possible numerical instability at the LR discontinuities.
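
Concretely, the schedule looks roughly like this (a minimal sketch of my setup; the radam import path and train_one_epoch are placeholders, and the hyperparameters are just illustrative):

import torch
from radam import RAdam  # placeholder import path for the RAdam optimizer

optimizer = RAdam(model.parameters(), lr=0.5, weight_decay=4e-5)  # model defined elsewhere
# SGDR-style cosine annealing with warm restarts (Loshchilov & Hutter)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

for epoch in range(90):
    train_one_epoch(model, optimizer)  # placeholder for one epoch of training
    scheduler.step()                   # LR jumps back to its maximum at each restart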

Here's the training loss compared to AdamW (PyTorch 1.2.0 version):
[image: radam_jumps]

Here's the validation loss:
[image: radam_val]

What's the recommendation here? Should I use warmup in every cycle rather than just in the beginning? I thought RAdam was supposed to obviate the need for warmup. Is this a bug?

@depthwise
Author

depthwise commented Aug 17, 2019

The model is ShuffleNet V2 and the dataset is ImageNet-1K. This isn't necessarily a bug; I just wanted to bring it to the author's attention. Feel free to close.

@LiyuanLucasLiu
Owner

Thanks for bringing this up. In our analysis & experiments we haven't tried any learning rate restarts. I agree this issue may be due to numerical instability or the algorithm design. Will look into it later.

@LiyuanLucasLiu LiyuanLucasLiu self-assigned this Aug 17, 2019
@LiyuanLucasLiu LiyuanLucasLiu added the question Further information is requested label Aug 17, 2019
@LiyuanLucasLiu
Owner

Also, RAdam doesn't obviate all need for warmup :( We found that in some cases adding additional warmup gives better performance (some discussion is at: https://github.com/LiyuanLucasLiu/RAdam#questions-and-discussions).
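
For example, a minimal sketch of adding a short linear warmup on top of RAdam with a standard PyTorch LR lambda (the import path and the 500-step warmup length are just illustrative, not a recommendation):

import torch
from radam import RAdam  # import path may differ

optimizer = RAdam(model.parameters(), lr=1e-3)  # model defined elsewhere
warmup_steps = 500
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))
# call scheduler.step() after each optimizer.step() so the LR ramps up linearly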

@e-sha

e-sha commented Oct 3, 2019

I found that the problem is in the initial 5 steps.
It comes from this code:

if N_sma >= 5:
    step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
else:
    step_size = 1.0 / (1 - beta1 ** state['step'])

The default setting for betas is (0.9, 0.999), so the internal variables change as follows:

state['step']| step_size
------------------------------
        1    |     10
        2    |5.26315789
        3    |3.6900369
        4    |2.90782204
        5    |2.44194281
        6    |0.00426327
        7    |0.00524248
        8    |0.00607304
        9    |0.00681674
       10    |0.00750596
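
For reference, here is a small standalone sketch (my own, assuming the usual RAdam formulas N_sma_max = 2 / (1 - beta2) - 1 and N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)) that reproduces the numbers above:

import math

beta1, beta2 = 0.9, 0.999
N_sma_max = 2.0 / (1.0 - beta2) - 1.0  # maximum length of the approximated SMA

for step in range(1, 11):
    beta2_t = beta2 ** step
    N_sma = N_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)
    if N_sma >= 5:
        step_size = math.sqrt(
            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4)
            * (N_sma - 2) / N_sma
            * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** step)
    else:
        step_size = 1.0 / (1 - beta1 ** step)
    print(step, step_size)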

Note that step_size doesn't depend on the gradient value, and it scales the learning rate.
Thus RAdam aggressively moves the weights away from their initial values, even if they have a good initialization.

Is it better to set step_size equal to 0 if N_sma < 5?

@LiyuanLucasLiu
Owner

@e-sha the step_size here is not the learning rate, but more like a step-size ratio. When N_sma < 5, the adaptive learning rate is turned off, and step_size is set to the bias correction for the first momentum.

For example, in the first update the first momentum is 0.1 times the gradient; thus step_size is set to 10 to correct for this bias.

You can refer to the implementation of the Adam algorithm for more details.
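
A tiny numeric illustration of that bias correction (my own sketch, not code from the repo):

beta1 = 0.9
grad = 1.0
m = (1 - beta1) * grad        # first-moment EMA after step 1, i.e. 0.1 * grad
m_hat = m / (1 - beta1 ** 1)  # bias-corrected estimate: multiplied by 10, recovering grad
print(m, m_hat)               # ~0.1 1.0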

@e-sha

e-sha commented Oct 4, 2019

I got it. You are right.
So at the initialization stage RAdam works exactly like SGD with momentum.
Nevertheless, it seems that the optimal learning rates for SGD with momentum and Adam (or RAdam) differ considerably. Consider my toy example for training, where the model weights are initialized to be near a local optimum:

# model, train_loader, combined_loss and device are defined earlier in my script
#optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5, amsgrad=False)
optimizer = RAdam(model.parameters(), lr=0.001, weight_decay=1e-5)

n = len(train_loader)
for i, (data, start, stop, image1, image2) in enumerate(train_loader):
    data, image1, image2 = map(lambda x: x.to(device), (data, image1, image2))
    optimizer.zero_grad()
    shape = image1.size()[-2:]
    prediction = model(data, start, stop, shape, raw=False)
    loss, _ = combined_loss(prediction, image1, image2)
    loss.backward()
    optimizer.step()
    print(f'{loss}')
    if i == 10: break  # only look at the first 11 training steps

For RAdam it outputs

18.729469299316406
111.54562377929688
114.46490478515625
119.26341247558594
94.38755798339844
122.42045593261719
89.53224182128906
105.24909973144531
100.99745178222656
81.69304656982422
84.50157165527344

The output for Adam is:

18.50310707092285
32.51015853881836
19.18273162841797
21.213653564453125
14.611763000488281
13.85060977935791
21.966794967651367
17.23381233215332
19.077804565429688
19.467504501342773
19.318479537963867

The latter is much better.
Should we apply Adam instead of SGD with momentum when N_sma < 5?
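
Roughly what I mean, as a hypothetical sketch rather than the repo's code (variable names mirror a typical Adam step()):

# Hypothetical alternative for the N_sma < 5 branch: take a bias-corrected,
# normalized Adam-style step instead of the un-normalized SGDM-style one.
eps = 1e-8
if N_sma < 5:
    m_hat = exp_avg / (1 - beta1 ** state['step'])     # corrected first moment
    v_hat = exp_avg_sq / (1 - beta2 ** state['step'])  # corrected second moment
    p.data.add_(-group['lr'] * m_hat / (v_hat.sqrt() + eps))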

@LiyuanLucasLiu
Owner

LiyuanLucasLiu commented Oct 4, 2019

Thanks for letting us know @e-sha. Can you provide a full script to reproduce the result? I'm not sure why RAdam behaves this way. Intuitively, SGDM should be more stable than Adam (since it updates less).

@e-sha

e-sha commented Oct 7, 2019

I found that the problem is in the gradient values.
The gradients w.r.t. some of the parameters are larger than 1. Thus on the first iteration of training SGDM just multiplies them by 10 * learning_rate, but Adam makes a smaller step (equal to learning_rate) due to normalization.

@LiyuanLucasLiu
Owner

@e-sha Thanks for letting us know :-)

I guess you mean the problem is in the parameter values? Or the gradient values? I think in the first iteration SGDM, even with the bias adjustment, should have smaller updates compared to Adam. For example, even if the gradient is larger than one, it will first be multiplied by 0.1, then by 10 * learning_rate.

@e-sha

e-sha commented Oct 9, 2019

Yes, you are right. In the first iteration SGDM makes a step equal to learning_rate * gradient.
In my particular case some values of the gradient are >> 1. Thus SGDM makes a bigger step than Adam, since Adam's step is exactly equal to learning_rate.
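
A tiny numeric illustration of what I mean (my own sketch, with an arbitrary gradient value):

import math

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
grad = 50.0                            # a gradient value >> 1

m = (1 - beta1) * grad                 # first-moment EMA after step 1
v = (1 - beta2) * grad ** 2            # second-moment EMA after step 1
sgdm_step = lr * m / (1 - beta1)       # bias-corrected SGDM-style step = lr * grad = 0.05
adam_step = lr * (m / (1 - beta1)) / (math.sqrt(v / (1 - beta2)) + eps)  # ~= lr = 0.001
print(sgdm_step, adam_step)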

@LiyuanLucasLiu
Owner

I see, I understand it now, thanks for sharing. BTW, people find that gradient clipping also helps to stabilize model training.
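
For example, something like this slotted into the training loop above (loss, model and optimizer as in the earlier snippet; the max_norm value is just illustrative):

import torch

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip the global gradient norm
optimizer.step()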

@LiyuanLucasLiu
Owner

@e-sha I added an option to decide whether to use SGDM: 373b3e4
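
A hypothetical usage sketch (I'm writing the new switch as a constructor flag named degenerated_to_sgd; the exact name and import path may differ from the commit):

from radam import RAdam  # import path may differ

# SGDM-style steps while N_sma < 5 (previous behavior)
optimizer = RAdam(model.parameters(), lr=1e-3, degenerated_to_sgd=True)
# disable them so no un-normalized steps are taken during the first iterations
optimizer = RAdam(model.parameters(), lr=1e-3, degenerated_to_sgd=False)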
