This repository is for our NeurIPS 2022 paper:
What Makes a "Good" Data Augmentation in Knowledge Distillation -- A Statistical Perspective
Huan Wang1,2,†, Suhas Lohit2,∗, Mike Jones2, Yun Fu1
1Northeastern University, Boston, MA 2MERL, Cambridge, MA
†Work done when Huan was an intern at MERL
∗Corresponding author: slohit@merl.com
[TL;DR]
- We present (and prove) a proposition that precisely answers "What makes a good data augmentation (DA) in knowledge distillation (KD)?": a good DA should reduce the covariance of the teacher-student cross-entropy.
- We present a practical metric that needs only the teacher to measure the "goodness" of a DA in KD: the stddev of the teacher's mean probability (abbreviated as T. stddev) -- see the sketch after this list.
- Interestingly, T. stddev works very well in practice (on CIFAR100 and Tiny ImageNet), showing a strong correlation with the student's test loss despite knowing nothing about the student -- see the right figure below.
- Based on the theory, we further propose an entropy-based data-picking algorithm that boosts the prior SOTA DA scheme (CutMix) in KD, resulting in a new strong DA method, CutMixPick (the picking step is sketched after the list of supported DA methods below).
- Finally, we show how the theory can be used in practice to harvest considerable performance gains simply by using a stronger DA with more training epochs.
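For intuition, here is a minimal PyTorch sketch of how a T. stddev-style quantity can be computed with only the teacher. The exact definition we use is implemented behind the --utils.check_ce_var flag of train_student.py; the sketch assumes T. stddev is the standard deviation, across training samples, of the teacher's mean ground-truth-class probability over several augmented views, and the function name, n_views, and augment are illustrative only.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def teacher_stddev(teacher, loader, augment, n_views=10, device="cuda"):
    """Sketch of a T. stddev-style metric (see --utils.check_ce_var for the
    exact definition used in the paper). For each sample, average the
    teacher's probability on the true class over several augmented views,
    then take the stddev of these per-sample means over the dataset."""
    teacher.eval()
    per_sample_means = []
    for images, labels in loader:
        labels = labels.to(device)
        true_class_probs = []
        for _ in range(n_views):
            views = augment(images).to(device)            # the DA under study
            p = F.softmax(teacher(views), dim=1)          # teacher probabilities
            true_class_probs.append(p.gather(1, labels[:, None]))
        per_sample_means.append(torch.cat(true_class_probs, 1).mean(dim=1))
    return torch.cat(per_sample_means).std().item()
```

Per the proposition, a DA scheme with a lower value of this metric is expected to give a lower student test loss.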
Supported Data Augmentation Methods (9 in total)
- Identity
- Flip
- Flip+Crop
- Cutout
- AutoAugment (CVPR'19)
- Mixup (ICLR'18)
- CutMix (ICCV'19)
- CutMixPick (S. ent.) (ours)
- CutMixPick (T. ent.) (ours)
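The two CutMixPick variants above pick augmented samples by prediction entropy: multiple CutMix candidates are generated per batch (--mix_n_run) and a subset of them (--n_pick) is kept according to the entropy of the student's (S. ent.) or teacher's (T. ent.) prediction. Below is a minimal sketch of the picking mechanics only; the helper names and the keep_low knob are illustrative, and the exact selection rule is the one controlled by --cutmix_pick_criterion and --cutmix_pick_scheme in train_student.py.

```python
import torch
import torch.nn.functional as F

def prediction_entropy(logits):
    """Entropy of the softmax prediction, one value per sample."""
    p = F.softmax(logits, dim=1)
    return -(p * p.clamp_min(1e-12).log()).sum(dim=1)

@torch.no_grad()
def pick_by_entropy(candidates, scorer, n_pick, keep_low=True):
    """Illustrative picking step: rank CutMix candidates by the entropy of
    `scorer`'s prediction (teacher or student) and keep `n_pick` of them.
    Which end of the ranking is kept (`keep_low`) is left as a knob here;
    the paper's choice is implemented in train_student.py."""
    ent = prediction_entropy(scorer(candidates))          # (N,) entropies
    order = torch.argsort(ent, descending=not keep_low)
    return candidates[order[:n_pick]]
```

In the CutMixPick commands further below, --mix_n_run 2 roughly corresponds to generating two mixed versions of each batch and keeping a batch-sized subset of them by this criterion.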
git clone git@github.com:MingSun-Tse/Good-DA-in-KD.git
cd Good-DA-in-KD
sh scripts/set_up_pretrained_models.sh
We use Python 3.9.6 and PyTorch 1.9.0. All the dependencies are summarized in requirements.txt. Create a conda env named Good-DA-in-KD_Py3.9.6 and enter it:
sh scripts/set_up_env.sh Good-DA-in-KD 3.9.6 requirements.txt
conda activate --no-stack Good-DA-in-KD_Py3.9.6
Below we give an example with the vgg13/vgg8 pair on CIFAR100. Please refer to the following scripts for the complete commands for all pairs on CIFAR100 and Tiny ImageNet:
- scripts/S_test_loss_different_DA_cifar100.sh
- scripts/S_test_loss_different_DA_tinyimagenet.sh
- scripts/T_stddev_different_DA_cifar100.sh
- scripts/T_stddev_different_DA_tinyimagenet.sh
# (1) Get the S. test loss with KD + different DA's
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode identity --project kd__vgg13vgg8__cifar100__identity
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode flip --project kd__vgg13vgg8__cifar100__flip
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode crop+flip --project kd__vgg13vgg8__cifar100__cropflip
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode cutout --project kd__vgg13vgg8__cifar100__cutout
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode autoaugment --project kd__vgg13vgg8__cifar100__autoaugment
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode mixup --project kd__vgg13vgg8__cifar100__mixup
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode cutmix --project kd__vgg13vgg8__cifar100__cutmix
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode cutmix_pick --mix_n_run 2 --cutmix_pick_criterion student_entropy --project kd__vgg13vgg8__cifar100__cutmix_pick_Sentropy
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --mix_mode cutmix_pick --mix_n_run 2 --cutmix_pick_criterion teacher_entropy --project kd__vgg13vgg8__cifar100__cutmix_pick_Tentropy
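For reference, --mix_mode cutmix follows the standard CutMix recipe (Yun et al., ICCV'19): a Beta-distributed ratio determines a random rectangular patch that is pasted from a shuffled copy of the batch onto the original images. The sketch below is the textbook formulation, not necessarily a line-for-line copy of the rand_bbox used in this repo; judging by its name, --t_output_as_target_for_input_mix then uses the teacher's output on the mixed input as the distillation target instead of mixing hard labels.

```python
import numpy as np
import torch

def cutmix_batch(images, alpha=1.0):
    """Textbook CutMix: paste a random patch from a shuffled copy of the batch.
    Returns the mixed images, the permutation, and the effective mixing ratio."""
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0))
    _, _, H, W = images.shape

    # Patch sides scale with sqrt(1 - lam), as in the CutMix paper.
    cut_h, cut_w = int(H * np.sqrt(1.0 - lam)), int(W * np.sqrt(1.0 - lam))
    cy, cx = np.random.randint(H), np.random.randint(W)
    y1, y2 = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    x1, x2 = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)

    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (H * W)           # adjust for clipping
    return mixed, perm, lam
```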
# (2) Get the T. stddev with KD + different DA's
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode identity --project kd__vgg13vgg8__cifar100__CheckTProbStd_identity
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode flip --project kd__vgg13vgg8__cifar100__CheckTProbStd_flip
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode crop+flip --project kd__vgg13vgg8__cifar100__CheckTProbStd_cropflip
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode cutout --project kd__vgg13vgg8__cifar100__CheckTProbStd_cutout
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode autoaugment --project kd__vgg13vgg8__cifar100__CheckTProbStd_autoaugment
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode mixup --project kd__vgg13vgg8__cifar100__CheckTProbStd_mixup
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode cutmix --project kd__vgg13vgg8__cifar100__CheckTProbStd_cutmix
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --finetune_student Experiments/*-141344/weights/ckpt.pth --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode cutmix_pick --mix_n_run 2 --cutmix_pick_criterion student_entropy --project kd__vgg13vgg8__cifar100__CheckTProbStd_cutmix_pick_Sentropy
python train_student.py --path_t ./save/models/vgg13_vanilla/ckpt_epoch_240.pth --distill kd --model_s vgg8 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode cutmix_pick --mix_n_run 2 --cutmix_pick_criterion teacher_entropy --project kd__vgg13vgg8__cifar100__CheckTProbStd_cutmix_pick_Tentropy
Note that the major results in our paper are Tabs. 3-8, where we report the T. stddev and S. test loss for 9 teacher-student pairs and 9 DA schemes. Each experiment is run at least 3 times and the results are averaged. All the logs of these experiments have been released (only the log txt files are released; checkpoints are omitted due to their large size, but if you want any of them, feel free to reach out to Huan Wang at wang.huan@northeastern.edu).
We use smilelogging for logging. Each experiment is bound to a unique experiment ID and folder. The easiest way to reproduce any experiment is to check the log.txt in that experiment folder; its path is <experiment_folder>/log/log.txt. At the head of the log.txt, we document the script of that experiment, e.g.,
cd /home3/wanghuan/Projects/KD-DA
CUDA_VISIBLE_DEVICES=1 python train_student.py --dataset tinyimagenet --path_t ./save/models_tinyimagenet_v2/wrn_40_2_vanilla/ckpt_epoch_240.pth --distill kd --model_s wrn_16_2 -r 0.1 -a 0.9 -b 0 --t_output_as_target_for_input_mix --lw_mix [1,0,1] --learning_rate 0 --fix_student --utils.ON --utils.check_ce_var --epochs 10 --mix_mode mixup --project kd__wrn_40_2wrn_16_2__tinyimagenet__CheckTProbStd_mixup
('alpha': 0.9) ('amp': False) ('batch_size': 64) ('bbox': rand_bbox) ('beta': 0.0) ('branch_dropout_rate': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) ('branch_layer_S': []) ('branch_layer_T': []) ('branch_width_S': 256) ('branch_width_T': 256) ('cache_ignore': ) ('ceiling_ratio_schedule': ) ('check_cutmix_label': False) ('CodeID': ['f4e606d']) ('crd_multiheads': False) ('cut_size': 16) ('cutmix_pick_criterion': kld) ('cutmix_pick_scheme': sort) ('d_z': 1000) ('DA_pick_base': None) ('dataset': tinyimagenet) ('debug': False) ('distill': kd) ('embed': original) ('entropy_log': None) ('epoch_factor': 0) ('epoch_stop_head_kd_loss': 10000000) ('epochs': 10) ('experiments_dir': Experiments) ('feat_dim': 128) ('finetune_student': ) ('fix_embed': False) ('fix_student': True) ('fix_T_heads': False) ('floor_ratio_schedule': ) ('gamma': 0.1) ('head_init': default) ('hint_layer': 2) ('init_epochs': 30) ('input_mix_no_kld_epoch': 1000000) ('kd_S': 4) ('kd_T': 4) ('learning_rate': 0.0) ('lr_DA': 0) ('lr_decay_epochs': [150, 180, 210]) ('lr_decay_rate': 0.1) ('lw_branch_ce': 0.1) ('lw_branch_kld': 0.9) ('lw_dcs': 0.5) ('lw_mix': [1.0, 0.0, 1.0]) ('mask_zero_ratio': 0.5) ('max_min_ratio': 5) ('mix_mode': mixup) ('mix_n_run': 1) ('mode': exact) ('model_path': ./save/student_model) ('model_s': wrn_16_2) ('model_s_pretrained': None) ('model_t': wrn_40_2) ('model_t_pretrained': None) ('modify_student_input': ) ('momentum': 0.9) ('n_branch_fc_S': 1) ('n_branch_fc_T': 1) ('n_patch': 4) ('n_pick': 64) ('nce_k': 16384) ('nce_m': 0.5) ('nce_t': 0.07) ('no_DA': False) ('num_workers': 8) ('online_augment': False) ('only_test': ) ('path_t': ./save/models_tinyimagenet_v2/wrn_40_2_vanilla/ckpt_epoch_240.pth) ('pretrained_embed': ) ('print_freq': 100) ('print_interval': 100) ('project_name': kd__wrn_40_2wrn_16_2__tinyimagenet__CheckTProbStd_mixup) ('ratio_CE_loss': 1) ('reinit_student': False) ('resume_ExpID': ) ('resume_student': ) ('s_branch_target': t_branch) ('save_crd_loss': False) ('save_entropy_log_step': 0) ('save_freq': -1) ('save_img_interval': 1000000000) ('SERVER': ) ('stack_input': False) ('t_output_as_target_for_input_mix': True) ('tb_freq': 500) ('tb_path': ./save/student_tensorboards) ('test_loader_in_train': False) ('test_teacher': False) ('train_linear_classifier': False) ('trial': 1) ('two_loader': False) ('update_data_interval': 2) ('update_data_start': 100) ('use_DA': 11) ('userip': ) ('utils': <utils.EmptyClass object at 0x7f70a2b187f0>) ('weight_decay': 0.0005) ('weight_decay_schedule': None)
By running that script, you should be able to reproduce our results (results may vary from what we reported within a reasonable range due to randomness).
We plot the scatter points of the student test loss (S. test loss) vs. our proposed metric (T. stddev). Per our proposition, a lower T. stddev should lead to a lower S. test loss, i.e., they should show a positive correlation. This is verified in all of these plots -- three kinds of correlation coefficients (Pearson, Spearman, Kendall) are reported with their p-values below. The p-values are far below 5%, indicating the correlations are statistically significant. Please refer to our paper for more results.
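If you want to reproduce these statistics yourself, the three coefficients and their p-values can be computed with scipy.stats once the (T. stddev, S. test loss) pairs have been collected from the logs of steps (1) and (2); the arrays below are placeholders, not our actual numbers.

```python
from scipy import stats

# Placeholder values: one entry per DA scheme for a given teacher-student pair,
# collected from the logs of steps (1) and (2) above. Replace with your own.
t_stddev    = [0.10, 0.12, 0.15, 0.18, 0.21]
s_test_loss = [1.20, 1.25, 1.31, 1.40, 1.47]

for name, corr_fn in [("Pearson", stats.pearsonr),
                      ("Spearman", stats.spearmanr),
                      ("Kendall", stats.kendalltau)]:
    r, p = corr_fn(t_stddev, s_test_loss)
    print(f"{name}: coefficient = {r:.4f}, p-value = {p:.4g}")
```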
- [10/21/2022] Initial code release. Should be all set.
- Remove useless and deprecated code
This code heavily builds upon the wonderful code of CRD. Great thanks to them! We also greatly thank the anonymous NeurIPS'22 reviewers for their constructive comments, which helped us improve the paper.
If our work or code helps you, please consider citing our paper (and Tian's CRD paper since our code builds upon theirs). Thank you!
@inproceedings{wang2022what,
author = {Huan Wang and Suhas Lohit and Michael Jones and Yun Fu},
title = {What Makes a "Good" Data Augmentation in Knowledge Distillation -- A Statistical Perspective},
booktitle = {NeurIPS},
year = {2022}
}
@inproceedings{tian2019crd,
title = {Contrastive Representation Distillation},
author = {Yonglong Tian and Dilip Krishnan and Phillip Isola},
booktitle = {ICLR},
year = {2020}
}