
Implemented model iteration averaging to reduce model variance #901

Merged
merged 16 commits into awslabs:master on Jul 17, 2020

Conversation

xcgoner (Contributor) commented Jun 29, 2020

Issue #, if available:

Description of changes:

  1. In model_iteration_averaging.py, implemented model averaging across iterations during training, instead of averaging over epochs after training
  2. Implemented 3 different averaging triggers: NTA (NTA_V1 is the ICLR version: https://openreview.net/pdf?id=SyyGPP0TZ, NTA_V2 is the arxiv version: https://arxiv.org/pdf/1708.02182.pdf), and Alpha Suffix (https://arxiv.org/pdf/1109.5647.pdf)
  3. Integrated both epoch averaging and iteration averaging in Trainer (mx/trainer/_base.py)
  4. Wrote test in test/trainer/test_model_iteration_averaging.py

The overall goal is to reduce model variance.
We test iteration averaging on DeepAR anomaly detection (examples/anomaly_detection.py, electricity data).
We train the model with 20 different random seeds and report the variance over the same batch of target sequences (compute the variance at each timestamp, then average over the entire sequence and all samples).
The results are as follows:

| Method | n or alpha | var | var/mean | std | std/mean | RMSE |
|---|---|---|---|---|---|---|
| SelectNBestMean | 1 | 9552.24 | 0.508395 | 22.5279 | 0.0318269 | 414.924 |
| SelectNBestMean | 5 | 8236.13 | 0.41966 | 19.9947 | 0.0253164 | 411.92 |
| NTA_V1 | 5 | 5888.36 | 0.387781 | 16.7624 | 0.0253107 | 412.792 |
| NTA_V2 | 5 | 6422.11 | 0.394004 | 17.7947 | 0.0237186 | 416.328 |
| Alpha_Suffix | 0.2 | 5877.92 | 0.384664 | 16.6868 | 0.030484 | 408.711 |
| Alpha_Suffix | 0.4 | 5814.86 | 0.378298 | 16.6081 | 0.0290987 | 409.952 |

Although we haven't tuned the hyperparameters, we've already obtained smaller variance and better RMSE.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

(Review threads on src/gluonts/mx/trainer/model_iteration_averaging.py and src/gluonts/mx/trainer/_base.py: outdated, resolved.)
benidis (Contributor) commented Jun 29, 2020

One basic question that we need to answer is what the computational overhead of iteration-based averaging is; it seems quite heavy to me. Could we run some basic experiments comparing running times without model averaging and with the various techniques?

benidis (Contributor) commented Jun 29, 2020

Also, I am not sure I fully understand what the numbers in the table are. What does "take variance on each timestamp" mean exactly? The output of the model is a distribution; which value did you use?

xcgoner (Contributor, Author) commented Jun 30, 2020

@benidis For each run, I use the mean as the prediction value for the target sequence. Then, for the same sequence, I run the training multiple times with different random seeds, measure the variance at each timestamp across the runs, and then average over the entire sequence.
Say we use the mean as the prediction and train 20 times, producing 20 predictions of the same sequence stored in a list prediction_list.
Then I report np.var(prediction_list, axis=0).mean().
I use this to measure the model variance, hoping that iteration averaging can mitigate the influence of the random seed during training.
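For concreteness, a minimal sketch of this metric, assuming `predictions` holds the per-seed point forecasts for one sequence (the array shape, values, and names here are illustrative, not taken from the PR):

```python
import numpy as np

# Hypothetical point forecasts: 20 runs (random seeds) x 24 timestamps,
# all for the same target sequence.
predictions = np.random.default_rng(0).normal(loc=400.0, scale=20.0, size=(20, 24))

# Variance across runs at each timestamp, then averaged over the sequence.
model_variance = np.var(predictions, axis=0).mean()
print(model_variance)
```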

xcgoner (Contributor, Author) commented Jun 30, 2020

@benidis I haven't recorded the training time yet, but with iteration averaging the throughput looks similar to training without it.
Basically, the overhead of iteration averaging should be no greater than adding one more sample to the training mini-batch; if the batch size is large, the extra overhead is negligible.

benidis (Contributor) commented Jun 30, 2020

@xcgoner just a verification of how things work based on the code. Please correct me if I am wrong:

  1. The variations of NTA basically check the loss after each epoch: once we have seen at least n epochs and the loss of the current epoch is larger than the best (minimum) loss seen in a selected previous window (the window depends on the version), the averaging is triggered. From then on, the averaging runs over all iterations inside an epoch and over all subsequent epochs. For example, with 100 batches per epoch and 100 epochs, if we set n=5 (as in your experiments) the averaging can be triggered from the 6th epoch onwards. Assume it is triggered at the 20th epoch (with the arxiv version this can actually happen much earlier). Then the final averaged model is an average over the remaining 80 x 100 (remaining epochs x batches per epoch) iterations, which seems a bit too much.

  2. With the Alpha_Suffix method we average all iterations of the selected fraction of epochs. For example, in the same setting as above with alpha=0.2, we average all the models of the last 20 epochs (see the sketch after this list). This seems a bit more controlled, although if the best epoch comes earlier we will miss it, which makes me wonder whether we can somehow combine the best-epoch information with model averaging, i.e., start averaging from the best epoch onwards. This would probably require averaging from the beginning and, whenever a new epoch is better, overwriting the previous averaged model.
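To make item 2 concrete, here is a minimal sketch of where the alpha-suffix average would start under this reading (the helper name and the exact rounding are my own assumptions, not the PR's code):

```python
import math

def alpha_suffix_start_epoch(num_epochs: int, alpha: float) -> int:
    """First epoch (1-indexed) whose iterations enter the alpha-suffix average."""
    # Average the iterations of the last `alpha` fraction of epochs.
    return math.floor(num_epochs * (1.0 - alpha)) + 1

# With 100 epochs and alpha=0.2, iterations of epochs 81..100 are averaged.
print(alpha_suffix_start_epoch(100, 0.2))  # 81
```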

benidis (Contributor) commented Jun 30, 2020

@xcgoner One comment about the variations of NTA: Consider the following simple example:

n=3
epochs=12
Losses = [5, 4, 3, 2, 2.01, 1, 1.01, 0.99, 1.03, 1.01, 1, 0.99]

The arxiv version (NTA_V2) at epoch t looks at the losses in the interval [t-n, t-1]; if the current loss is larger than the minimum over that interval, the averaging is triggered. This is equivalent to the rule "after n epochs, start averaging the first time an epoch does not improve". In the example above, this happens at the 5th epoch, where the loss is 2.01.

On the other hand, the iclr version (NTA_V1) allows for a buffer of length n, since at epoch t it looks at the losses in [1, t-n-1]. Therefore minimal variations will not trigger the averaging while there is still a trend in the loss and it has not reached a plateau. In the example above, this method triggers at the 10th epoch, with loss 1.01.

Basically, the arxiv version can be triggered by noise alone, while the iclr version triggers only when the trend is not significant enough to overcome the noise (which is also an indication of a plateau). Overall, the arxiv version seems too sensitive to me and the iclr version much more robust.
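For reference, a minimal sketch of the two trigger rules applied to this example (variable and function names are mine, not the PR's):

```python
def nta_v2_trigger_epoch(losses, n):
    # arxiv rule: at epoch t, trigger if the current loss exceeds the
    # minimum over the previous n epochs, i.e. the interval [t-n, t-1].
    for t in range(n + 1, len(losses) + 1):
        if losses[t - 1] > min(losses[t - n - 1 : t - 1]):
            return t
    return None

def nta_v1_trigger_epoch(losses, n):
    # iclr rule: at epoch t, trigger if the current loss exceeds the
    # minimum over epochs [1, t-n-1], i.e. excluding the most recent n.
    for t in range(n + 2, len(losses) + 1):
        if losses[t - 1] > min(losses[: t - n - 1]):
            return t
    return None

losses = [5, 4, 3, 2, 2.01, 1, 1.01, 0.99, 1.03, 1.01, 1, 0.99]
print(nta_v2_trigger_epoch(losses, n=3))  # 5  (loss 2.01)
print(nta_v1_trigger_epoch(losses, n=3))  # 10 (loss 1.01)
```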

As you said, GluonNLP/MXNet uses V2, but Salesforce/PyTorch uses V1, which is the implementation of the people who invented this method. Based on this, and on your experiments that already show some advantage of V1, I think we could discard V2 entirely.

Any thoughts on this @lostella , @vafl ?

benidis requested review from lostella and vafl on Jun 30, 2020
xcgoner (Contributor, Author) commented Jun 30, 2020

@benidis The intuition behind iteration averaging is that, once the evaluation metrics stop making much progress, the model moves back and forth randomly in a small region around the optimal point. Averaging over more models therefore cancels out this random noise and reduces the variance: the more random samples (models) we average, the better the variance reduction. We just need to figure out when training gets stuck, and then the averaging is triggered.
Alpha suffix is a relatively old strategy, but it is still widely used and serves as an important baseline. There are also other variants based on alpha suffix, which I will implement and test later. NTA is a more recently developed strategy.

xcgoner (Contributor, Author) commented Jun 30, 2020

@benidis I also think NTA_V2 is the correct one, as it matches the description in the paper. However, V1 also shows positive results in some cases.
You can refer to the issue I opened in GluonNLP, dmlc/gluon-nlp#1253, where I tested both strategies on the AWD-LSTM model. That's why I could not decide which strategy to keep, so I simply included both in this PR.

codecov-commenter commented Jun 30, 2020

Codecov Report

Merging #901 into master will decrease coverage by 0.02%.
The diff coverage is 83.33%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #901      +/-   ##
==========================================
- Coverage   85.85%   85.82%   -0.03%     
==========================================
  Files         195      196       +1     
  Lines       11886    11989     +103     
==========================================
+ Hits        10205    10290      +85     
- Misses       1681     1699      +18     
| Impacted Files | Coverage Δ |
|---|---|
| src/gluonts/trainer.py | 100.00% <ø> (ø) |
| src/gluonts/mx/trainer/_base.py | 83.00% <50.00%> (-5.89%) ⬇️ |
| ...rc/gluonts/mx/trainer/model_iteration_averaging.py | 91.66% <91.66%> (ø) |
| src/gluonts/mx/trainer/__init__.py | 87.50% <100.00%> (+1.78%) ⬆️ |

benidis (Contributor) commented Jul 1, 2020

> @benidis The intuition behind iteration averaging is that, once the evaluation metrics stop making much progress, the model moves back and forth randomly in a small region around the optimal point. Averaging over more models therefore cancels out this random noise and reduces the variance: the more random samples (models) we average, the better the variance reduction. We just need to figure out when training gets stuck, and then the averaging is triggered.

I agree with the note that averaging more models is better if all of these models have converged to the same valley in parameter space. My point was that this needs some tuning, because with the n=5 set in the experiments, especially for NTA_V1 which triggers easily, the averaging would start too early. We should also take into account that we have a learning rate scheduler, and there can be plateaus before the learning rate adjustments; starting the averaging too early would result in averaging over the wrong plateaus.

xcgoner (Contributor, Author) commented Jul 1, 2020

>> @benidis The intuition behind iteration averaging is that, once the evaluation metrics stop making much progress, the model moves back and forth randomly in a small region around the optimal point. Averaging over more models therefore cancels out this random noise and reduces the variance: the more random samples (models) we average, the better the variance reduction. We just need to figure out when training gets stuck, and then the averaging is triggered.
>
> I agree with the note that averaging more models is better if all of these models have converged to the same valley in parameter space. My point was that this needs some tuning, because with the n=5 set in the experiments, especially for NTA_V1 which triggers easily, the averaging would start too early. We should also take into account that we have a learning rate scheduler, and there can be plateaus before the learning rate adjustments; starting the averaging too early would result in averaging over the wrong plateaus.

I guess you are talking about V2. I agree that V2 is more sensitive to the randomness in training, which is the opposite of the description in the paper, and that's why I think V1 is the correct one. But currently I have no choice but to keep both of them.
I will combine V1 and V2 into a single class, though.

lostella (Contributor) commented Jul 1, 2020

@xcgoner what is RMSE in the table? I understand you’re running 20 trainings for each method, is that the best/worst/average RMSE? Or am I missing something?

xcgoner (Contributor, Author) commented Jul 1, 2020

> @xcgoner what is RMSE in the table? I understand you're running 20 trainings for each method, is that the best/worst/average RMSE? Or am I missing something?

I use the "mean" as the prediction, compute the RMSE against the target values, and then average over the multiple runs.

xcgoner (Contributor, Author) commented Jul 3, 2020

I recently added a new parameter, eta, to the iteration averaging.
When eta=0, it is equivalent to model averaging with equal weights.
I tuned some parameters and report the results below. Note that, to save time, I only ran 10 trials for each setting this time.
It seems that a larger eta can improve RMSE, but in some cases it increases the variance. A sketch of the weighting scheme follows the table.

| Method | n or alpha | eta | var | var/mean | var/target | std | std/mean | std/target | RMSE |
|---|---|---|---|---|---|---|---|---|---|
| NTA_V1 | 5 | 0 | 5458.46 | 0.352691 | 1.19761 | 16.1872 | 0.0256935 | 0.0528544 | 413.327 |
| NTA_V1 | 5 | 1 | 5708.1 | 0.339342 | 1.59306 | 16.1602 | 0.0214149 | 0.0545339 | 411.838 |
| NTA_V1 | 5 | 2 | 5908.85 | 0.388605 | 1.73945 | 16.2166 | 0.0222913 | 0.0553765 | 409.101 |
| NTA_V1 | 5 | 3 | 5888.47 | 0.398614 | 1.8202 | 16.2464 | 0.025246 | 0.0559798 | 408.909 |
| NTA_V1 | 5 | 4 | 5900 | 0.421298 | 1.72878 | 16.2848 | 0.032133 | 0.0558771 | 409.077 |
| NTA_V2 | 5 | 0 | 5744 | 0.360641 | 1.27146 | 17.0476 | 0.0223395 | 0.0540415 | 418.241 |
| NTA_V2 | 5 | 1 | 5399.63 | 0.368877 | 1.26411 | 16.1605 | 0.0310405 | 0.0527432 | 411.307 |
| NTA_V2 | 5 | 2 | 5740.23 | 0.344187 | 1.49426 | 16.1744 | 0.0222479 | 0.0541305 | 411.152 |
| NTA_V2 | 5 | 3 | 5801.03 | 0.34815 | 1.54092 | 16.1734 | 0.0223275 | 0.0543931 | 409.525 |
| NTA_V2 | 5 | 4 | 5964.81 | 0.392968 | 1.88651 | 16.269 | 0.0243954 | 0.0559012 | 409.076 |
| Alpha_Suffix | 0.2 | 0 | 6001.52 | 0.36584 | 1.51896 | 16.3647 | 0.0254731 | 0.0547222 | 409.258 |
| Alpha_Suffix | 0.2 | 1 | 5850.05 | 0.346833 | 1.90304 | 16.3297 | 0.0216271 | 0.055543 | 409.422 |
| Alpha_Suffix | 0.2 | 2 | 5823.71 | 0.349354 | 2.0767 | 16.3069 | 0.022622 | 0.0561412 | 409.361 |
| Alpha_Suffix | 0.2 | 3 | 5825.78 | 0.368147 | 2.33742 | 16.3022 | 0.0237156 | 0.0567312 | 408.816 |
| Alpha_Suffix | 0.2 | 4 | 5822.47 | 0.393744 | 2.29471 | 16.2869 | 0.0256854 | 0.0566127 | 408.65 |
| Alpha_Suffix | 0.4 | 0 | 5576.68 | 0.353762 | 1.53918 | 16.0972 | 0.0228856 | 0.054568 | 409.943 |
| Alpha_Suffix | 0.4 | 1 | 5768.16 | 0.3996 | 1.77385 | 16.1941 | 0.0245187 | 0.0559052 | 408.703 |
| Alpha_Suffix | 0.4 | 2 | 5982.12 | 0.402727 | 1.61943 | 16.3217 | 0.0293872 | 0.0549535 | 408.755 |
| Alpha_Suffix | 0.4 | 3 | 5853.89 | 0.351401 | 1.75941 | 16.3109 | 0.0231192 | 0.0552866 | 409.515 |
| Alpha_Suffix | 0.4 | 4 | 5811.82 | 0.342626 | 2.09126 | 16.2952 | 0.0213535 | 0.0561871 | 409.319 |
| Alpha_Suffix | 0.6 | 0 | 5336.78 | 0.350267 | 1.21213 | 16.0173 | 0.024014 | 0.0524983 | 412.528 |
| Alpha_Suffix | 0.6 | 1 | 5722.41 | 0.513281 | 1.65458 | 16.155 | 0.042226 | 0.0548066 | 411.144 |
| Alpha_Suffix | 0.6 | 2 | 5798.98 | 0.393895 | 1.82252 | 16.219 | 0.0232389 | 0.0554649 | 408.961 |
| Alpha_Suffix | 0.6 | 3 | 5886.16 | 0.446194 | 1.77374 | 16.2488 | 0.0332345 | 0.0557463 | 408.957 |
| Alpha_Suffix | 0.6 | 4 | 5976.12 | 0.55639 | 1.7828 | 16.3173 | 0.0524773 | 0.0557558 | 409.295 |
| Alpha_Suffix | 1 | 0 | 7815.14 | 0.438992 | 2.81825 | 19.9796 | 0.0272006 | 0.0619424 | 429.21 |
| Alpha_Suffix | 1 | 1 | 5392.7 | 0.409184 | 1.22162 | 16.468 | 0.0299096 | 0.0530986 | 415.042 |
| Alpha_Suffix | 1 | 2 | 5528.93 | 0.342512 | 1.30177 | 16.1637 | 0.024349 | 0.0529704 | 411.819 |
| Alpha_Suffix | 1 | 3 | 5876.35 | 0.39015 | 1.4833 | 16.2047 | 0.029879 | 0.054061 | 411.139 |
| Alpha_Suffix | 1 | 4 | 5837.22 | 0.383768 | 1.69184 | 16.188 | 0.0223831 | 0.0548478 | 409.833 |
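For intuition, a minimal sketch of one way an eta-weighted running average can be updated; the step size alpha = (eta + 1) / (eta + t) is my reading rather than the PR's exact formula, but it shows the intended behaviour: eta=0 recovers the equal-weight average, and larger eta weights recent iterations more heavily.

```python
def update_running_average(avg, new_params, counter, eta):
    """One update step of an eta-weighted iteration average (illustrative)."""
    # counter is the 1-based index of the iteration being folded in.
    # eta = 0 gives alpha = 1 / counter, i.e. the plain running mean.
    alpha = (eta + 1.0) / (eta + counter)
    return avg + alpha * (new_params - avg)

# Tiny usage example on scalar "parameters".
avg = 0.0
for t, param in enumerate([1.0, 2.0, 3.0, 4.0], start=1):
    avg = update_running_average(avg, param, t, eta=0)
print(avg)  # 2.5, the uniform average of 1, 2, 3, 4
```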

szhengac (Contributor) commented Jul 3, 2020 via email

benidis (Contributor) commented Jul 9, 2020

@xcgoner first of all, thanks for all the experiments. One thing that is not really clear to me is which column we should look at.

Obviously, the ideal case is low variance and low RMSE. But which of var, var/mean, and var/target makes more sense to focus on? Any thoughts @lostella? Do you think the geometric mean approach applies here? Just trying to make sense of these numbers in a fair way...

xcgoner (Contributor, Author) commented Jul 9, 2020

> @xcgoner first of all, thanks for all the experiments. One thing that is not really clear to me is which column we should look at.
>
> Obviously, the ideal case is low variance and low RMSE. But which of var, var/mean, and var/target makes more sense to focus on? Any thoughts @lostella? Do you think the geometric mean approach applies here? Just trying to make sense of these numbers in a fair way...

Oh, sorry, I attached var/target and std/target by mistake; you don't need to compare that part.
I think var or var/mean is the most important criterion, with RMSE as the second priority.
Basically, we want smaller var or var/mean while keeping RMSE in a reasonable range.

xcgoner requested a review from benidis on Jul 15, 2020
xcgoner (Contributor, Author) commented Jul 15, 2020

@benidis I've merged the two versions of NTA and changed the definition of "update_average_trigger", so that the trainer no longer needs to check whether the averaging strategy is NTA or alpha suffix.
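Roughly, the unified interface described here could look like the following sketch, where the Trainer only ever calls update_average_trigger() and never branches on the strategy type (class names and signatures are illustrative assumptions, not the PR's exact API):

```python
class IterationAveragingStrategy:
    """Base class: each strategy decides on its own when averaging starts."""
    averaging_started = False

    def update_average_trigger(self, metric: float, epoch: int) -> None:
        raise NotImplementedError

class AlphaSuffixStrategy(IterationAveragingStrategy):
    def __init__(self, num_epochs: int, alpha: float):
        self.start_epoch = int(num_epochs * (1.0 - alpha)) + 1

    def update_average_trigger(self, metric: float, epoch: int) -> None:
        # Ignores the metric; starts purely based on the epoch index.
        if epoch >= self.start_epoch:
            self.averaging_started = True

strategy = AlphaSuffixStrategy(num_epochs=100, alpha=0.2)
strategy.update_average_trigger(metric=1.23, epoch=81)
print(strategy.averaging_started)  # True
```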

xcgoner (Contributor, Author) commented Jul 15, 2020

@benidis The unit tests report an error, "conda: command not found", which I have no idea how to fix.
I've restarted the check workflow twice but the error is still there.
Could you help check what's wrong?

xcgoner (Contributor, Author) commented Jul 16, 2020

@benidis I've resolved your recent comments; please review them.
And somehow the conda error in the CI test doesn't show up this time.

@benidis benidis merged commit aaad207 into awslabs:master Jul 17, 2020