
supernet training #4

Closed · pawopawo opened this issue Sep 21, 2019 · 15 comments
Labels: good first issue (Good for newcomers)

@pawopawo

Training with the updated version of the supernet results in the following error:

File "train_imagenet.py", line 512, in <module>
    main()
  File "train_imagenet.py", line 508, in main
    train(context)
  File "train_imagenet.py", line 393, in train
    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
  File "/usr/local/lib/python3.6/site-packages/mxnet-1.5.0-py3.6.egg/mxnet/gluon/trainer.py", line 100, in __init__
    self._contexts = self._check_contexts()
  File "/usr/local/lib/python3.6/site-packages/mxnet-1.5.0-py3.6.egg/mxnet/gluon/trainer.py", line 113, in _check_contexts
    ctx = param.list_ctx()
  File "/usr/local/lib/python3.6/site-packages/mxnet-1.5.0-py3.6.egg/mxnet/gluon/parameter.py", line 539, in list_ctx
    raise RuntimeError("Parameter '%s' has not been initialized"%self.name)
RuntimeError: Parameter 'shufflenasoneshot0_features_fc_weight' has not been initialized
@CanyonWind (Owner)

Please pull this commit.

@pawopawo (Author)

When search_supernet.py is executed, the batch size is set to 200. I found that the same model gets a different test accuracy on each run; if update_bn is removed, the accuracy does not change.

I think the reason may be a problem with the processing of the 20,000 images: the train_data should not be shuffled.
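
For illustration, a minimal sketch of such a deterministic loader for the 20,000 images (the rec paths are placeholders; the parameters follow mx.io.ImageRecordIter):

import mxnet as mx

# Hypothetical BN-update loader: shuffle=False so that update_bn sees
# the same images in the same order on every run.
bn_update_data = mx.io.ImageRecordIter(
    path_imgrec='/path/to/supernet_train.rec',
    path_imgidx='/path/to/supernet_train.idx',
    batch_size=200,
    data_shape=(3, 224, 224),
    shuffle=False)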

@CanyonWind (Owner)

You are right, very good observation! Indeed, if the training data is shuffled, the same architecture may show different performance across runs, because BN will be updated with different images each time.

I thought shuffling or not would both be fine because of question 3 of this issue.

@pawopawo (Author) commented Sep 24, 2019

I found that even if the train_data is not shuffled, the test accuracy obtained each time is still different. Could the update_bn function have a small problem?

In addition, I re-read the paper. The reason for updating BN seems to be that supernet training is unstable, so shouldn't update_bn be used during supernet training rather than at one-shot model evaluation time?

@CanyonWind (Owner) commented Sep 24, 2019

> I found that even if the train_data is not shuffled, the test accuracy obtained each time is still different. Could the update_bn function have a small problem?

Did you use search_supernet.py and set shuffle in train_data to False when evaluating the same structure's performance? If so, the problem might be that, after the first run of BN updating, the following BN updates are applied on top of the previous runs' results. That can cause inconsistency in the moving statistics.

To eliminate this concern, I would suggest, in addition to setting train_data shuffle to False, also reloading the pre-trained supernet parameters every time before update_bn. Please see the pseudocode below.

supernet = ShuffleNas()

for _ in range(repeat_times):
    # Reload the pre-trained weights on every repetition so that each
    # update_bn starts from identical moving statistics
    supernet.load_parameters(supernet_params, ctx=context)
    update_bn(supernet)
    val_acc = get_accuracy(supernet, fixed_block_choice, fixed_channel_choice)
    print(val_acc)

If it still does not work as you expect, then the NasBN part could be the problem. Any reproducible unit test case for NasBN, like this, would be appreciated. I will do my best to help debug.

@CanyonWind (Owner) commented Sep 24, 2019

> In addition, I re-read the paper. The reason for updating BN seems to be that supernet training is unstable, so shouldn't update_bn be used during supernet training rather than at one-shot model evaluation time?

The BN statistics are only mentioned in Section 3.4, Evolutionary Architecture Search, and I assumed it is supposed to work like this. Sorry, it's hard for me to answer on behalf of the paper...

@pawopawo (Author)

I solved the problem of the accuracy changing after update_bn for the same model.

The 20,000 images should not be processed with the train_data preprocessing, because it includes random_resized_crop. When I follow the test_data preprocessing instead, the accuracy is the same every time.
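
A minimal sketch of that fix, assuming Gluon's transforms module (bn_update_dataset is a hypothetical Dataset over the 20,000 images):

from mxnet import gluon
from mxnet.gluon.data.vision import transforms

# Deterministic test-style pipeline: Resize + CenterCrop instead of
# random_resized_crop, so update_bn is fed identical pixels every run.
bn_update_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

bn_update_data = gluon.data.DataLoader(
    bn_update_dataset.transform_first(bn_update_transform),
    batch_size=200, shuffle=False)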

pawopawo reopened this Sep 27, 2019

@pawopawo (Author) commented Sep 27, 2019

shufflenas_supernet.log

python3 train_imagenet.py \
    --rec-train /ILSVRC2012_img_train_supernet_rec/_train.rec --rec-train-idx /ILSVRC2012_img_train_supernet_rec/_train.idx \
    --rec-val /supernet_val_rec/_val.rec --rec-val-idx /supernet_val_rec/_val.idx \
    --model ShuffleNas --mode imperative \
    --lr 0.5 --wd 0.00004 --lr-mode cosine --dtype float16 \
    --num-epochs 120 --batch-size 256 --num-gpus 4 -j 30 \
    --label-smoothing --no-wd --warmup-epochs 10 --use-rec \
    --save-dir params_shufflenas_supernet --logging-file shufflenas_supernet.log --epoch-start-cs 60 \
    --use-se --cs-warm-up

After training the supernet with the above script, the val acc dropped sharply after 60 epochs, but the train acc did not. Your shufflenas_supernet.log file shows [Epoch 118] validation: err-top1 = 0.371019. Did you use channel search?

@CanyonWind (Owner) commented Sep 27, 2019

Yes, I did use channel selection. Please check whether you were using all channels in the test function. If so, pulling the latest code should help.

@pawopawo (Author)

Thank you!
After using all channels, the accuracy became normal, but I don't understand why all channels should be used when testing. Shouldn't it be randomly sampled, like the block choice?

In addition, wouldn't it be better to change select_all_channels=opt.use_all_channels to select_all_channels=True there?

@CanyonWind (Owner)

You are right: for the validation evaluation, we are not supposed to use all channels. I think you might have misunderstood the code. Before epoch_start_cs, use_all_channels is set to True, and after epoch_start_cs it is set to False. So the val accuracy drop in your experiment was caused by using all channels, and it was actually fixed by not using them.
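
A minimal sketch of that control flow (the method names random_block_choices and random_channel_mask and the loop structure are assumptions, not the exact training code; select_all_channels is the flag mentioned above):

# Illustrative training-loop logic: all channels are used only before
# epoch_start_cs; afterwards channels are randomly selected.
for epoch in range(num_epochs):
    use_all_channels = epoch < opt.epoch_start_cs
    block_choices = net.random_block_choices()
    channel_mask = net.random_channel_mask(select_all_channels=use_all_channels)
    # ... forward/backward with (block_choices, channel_mask) ...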

@pawopawo (Author) commented Sep 28, 2019

python3 train_imagenet.py \
    --rec-train /ILSVRC2012_img_train_supernet_rec/_train.rec --rec-train-idx /ILSVRC2012_img_train_supernet_rec/_train.idx \
    --rec-val /supernet_val_rec/_val.rec --rec-val-idx /supernet_val_rec/_val.idx \
    --model ShuffleNas --mode imperative \
    --lr 0.5 --wd 0.00004 --lr-mode cosine --dtype float16 \
    --num-epochs 120 --batch-size 256 --num-gpus 4 -j 30 \
    --label-smoothing --no-wd --warmup-epochs 10 --use-rec \
    --save-dir params_shufflenas_supernet_929_1048 --logging-file shufflenas_supernet_929_1048.log \
    --epoch-start-cs 0

Thank you, but if I run the program with the script above, starting channel search from the beginning, the accuracy does not improve; it stays at 0.001.
shufflenas_supernet.log

@CanyonWind (Owner)

I had only been trying --epoch-start-cs 0 --use-se, and it works okay. Indeed, after a quick experiment, --epoch-start-cs 0 alone somehow doesn't converge. I will take a look at this later. Thanks for the information.

@CanyonWind (Owner) commented Sep 30, 2019

Update: I ran various experiments over the weekend; here is a short summary:

  1. With SE, channel selection with the full choice range (0.2 ~ 2.0) can be used from the first epoch, and the model converges.
    • However, doing channel selection from the first epoch seems to harm accuracy. Compared with the same se-supernet trained with block selection alone for the first n epochs plus --cs-warm-up, the from-scratch channel-selection se-supernet only reached ~33% training accuracy, while the warmed-up se-supernet reached ~44%, both at around epoch 70.
    • Another issue is that the validation accuracy of the from-scratch channel-selection se-supernet stays under 1%, while the warmed-up se-supernet's validation accuracy increases reasonably from 0.1% to 63%.
  2. Without SE, doing block selection alone for the first n epochs and using --cs-warm-up are necessary for the model to converge (a rough sketch of such a warm-up schedule follows below).
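
For context, a rough sketch of what such a channel-selection warm-up could look like; the schedule below is illustrative only, not the repository's exact implementation:

# Illustrative cs-warm-up: begin near the widest channel scales and
# gradually admit smaller scales as training progresses.
CANDIDATE_SCALES = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

def allowed_scales(epoch, epoch_start_cs, cs_warmup_epochs):
    if epoch < epoch_start_cs:
        return CANDIDATE_SCALES[-1:]   # block selection only, all channels
    progress = min(1.0, (epoch - epoch_start_cs + 1) / cs_warmup_epochs)
    n = max(1, int(progress * len(CANDIDATE_SCALES)))
    return CANDIDATE_SCALES[-n:]       # choice range widens toward 0.2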

@pawopawo (Author) commented Oct 1, 2019

Thank you very much. I will try it and report the results promptly.
