Learning to Weight Samples for Dynamic Early-exiting Networks (ECCV 2022)

Yizeng Han*, Yifan Pu*, Zihang Lai, Chaofei Wang, Shiji Song, Junfeng Cao, Wenhui Huang, Chao Deng, Gao Huang.

*: Equal contribution.

Introduction

This repository contains the implementation of the paper Learning to Weight Samples for Dynamic Early-exiting Networks (ECCV 2022). The proposed method uses a weight prediction network to weight the per-sample training losses of dynamic early-exiting networks, such as MSDNet and RANet. This improves their performance in the dynamic early-exiting setting.
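A minimal sketch of the sample-weighting idea follows. It is written in plain Python so it runs without PyTorch. It assumes that the weight prediction network takes each sample's losses at all exits as input and outputs one weight per exit, as the --meta_net_input_type loss flag in the training command suggests. The single linear layer and the names predict_weights and weighted_multi_exit_loss are hypothetical illustrations, not this repository's code. The actual network's size is set by --meta_net_hidden_size and --meta_net_num_layers.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def predict_weights(losses, A, b):
    """Hypothetical stand-in for the weight prediction network.

    losses[i][k] is the loss of sample i at exit k. Each sample's whole loss
    vector is mapped to one weight per exit by a single linear layer and a
    sigmoid: w[i][k] = sigmoid(sum_j A[k][j] * losses[i][j] + b[k]).
    """
    n_exits = len(b)
    return [
        [sigmoid(sum(A[k][j] * row[j] for j in range(n_exits)) + b[k])
         for k in range(n_exits)]
        for row in losses
    ]


def weighted_multi_exit_loss(losses, weights):
    """Batch mean of the weighted sum of per-exit losses."""
    total = sum(
        weight * loss
        for loss_row, weight_row in zip(losses, weights)
        for loss, weight in zip(loss_row, weight_row)
    )
    return total / len(losses)


# Usage: 2 samples, 2 exits. With A and b at zero, every weight is 0.5.
losses = [[2.0, 1.0], [0.5, 0.25]]
weights = predict_weights(losses, [[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0])
print(weighted_multi_exit_loss(losses, weights))  # 0.9375
```

In the real method the weight network's own parameters are learned by meta-learning, not fixed by hand as in this usage example.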

Overall idea

[Figure 1: overall idea]

Training pipeline

[Figure 2: training pipeline]

Gradient flow of the meta-learning algorithm

[Figure 3: gradient flow of the meta-learning algorithm]
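The general shape of such a gradient flow can be illustrated with a generic one-step lookahead reweighting update, in the style of Meta-Weight-Net. The update has three steps:

1. Take a virtual SGD step on the weighted training loss.
2. Evaluate a meta loss with the virtually updated parameters.
3. Backpropagate that meta loss through the virtual step into the weight function's parameters.

The toy below is a sketch under heavy assumptions, not this repository's algorithm:

- The model is a scalar linear regressor theta * x with squared-error loss.
- The weight function is sigmoid(phi * loss), with a single parameter phi.
- The meta loss is plain squared error on a held-out meta batch. The paper's early-exiting meta objective is not reproduced.

All function names are hypothetical. Gradients are written out by hand so the chain rule through the virtual step is explicit.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def per_sample(theta, xs, ys):
    """Losses (theta * x - y) ** 2 and their derivatives w.r.t. theta."""
    losses = [(theta * x - y) ** 2 for x, y in zip(xs, ys)]
    grads = [2.0 * x * (theta * x - y) for x, y in zip(xs, ys)]
    return losses, grads


def weighted_sgd_step(theta, phi, xs, ys, alpha):
    """One SGD step on the weighted loss, with w_i = sigmoid(phi * l_i).

    The weights scale the per-sample gradients but are treated as constants
    w.r.t. theta (they depend on phi only).
    """
    losses, grads = per_sample(theta, xs, ys)
    weights = [sigmoid(phi * loss) for loss in losses]
    step = sum(w * g for w, g in zip(weights, grads)) / len(xs)
    return theta - alpha * step


def meta_loss(theta, phi, train, meta, alpha):
    """Squared error on the meta batch at the virtually updated theta."""
    xs, ys = train
    us, vs = meta
    theta_prime = weighted_sgd_step(theta, phi, xs, ys, alpha)
    return sum((theta_prime * u - v) ** 2 for u, v in zip(us, vs)) / len(us)


def meta_grad(theta, phi, train, meta, alpha):
    """d(meta_loss)/d(phi), chained through the virtual step by hand."""
    xs, ys = train
    us, vs = meta
    losses, grads = per_sample(theta, xs, ys)
    weights = [sigmoid(phi * loss) for loss in losses]
    theta_prime = weighted_sgd_step(theta, phi, xs, ys, alpha)
    # Meta loss -> virtually updated parameter theta'.
    dmeta_dtheta = sum(2.0 * u * (theta_prime * u - v) for u, v in zip(us, vs)) / len(us)
    # theta' -> phi, using d w_i / d phi = w_i * (1 - w_i) * l_i.
    dtheta_dphi = -alpha * sum(
        w * (1.0 - w) * loss * g for w, loss, g in zip(weights, losses, grads)
    ) / len(xs)
    return dmeta_dtheta * dtheta_dphi


def train_step(theta, phi, train, meta, alpha, beta):
    """Update phi with the meta gradient, then take the real step on theta."""
    phi = phi - beta * meta_grad(theta, phi, train, meta, alpha)
    xs, ys = train
    theta = weighted_sgd_step(theta, phi, xs, ys, alpha)
    return theta, phi


# Usage: the third training pair is inconsistent with the meta batch (y = 2x).
train = ([1.0, 2.0, -1.0], [2.0, 3.0, 0.0])
meta = ([1.0, 2.0], [2.0, 4.0])
theta, phi = 0.5, 0.3
for _ in range(3):
    theta, phi = train_step(theta, phi, train, meta, alpha=0.1, beta=0.01)
```

With PyTorch, the hand-written derivatives would usually be replaced by autograd through the virtual step, using create_graph=True.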

Usage

Dependencies

  • Python: 3.8
  • PyTorch: 1.10.0
  • Torchvision: 0.11.0

Scripts

  • Train an MSDNet (5 exits, step=4) on ImageNet with 8 GPUs:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tools/main_imagenet_DDP.py \
--train_url YOUR_SAVE_PATH \
--data_url YOUR_DATA_PATH --data ImageNet --workers 64 --seed 0 \
--arch msdnet --nBlocks 5 --stepmode even --step 4 --base 4 --nChannels 32 --growthRate 16 --grFactor 1-2-4-4 --bnFactor 1-2-4-4 \
--meta_net_hidden_size 500 --meta_net_num_layers 1 --meta_interval 100 --meta_lr 1e-4 --meta_weight_decay 1e-4 \
--epsilon 0.3 --target_p_index 15 --meta_net_input_type loss --constraint_dimension mat \
--epochs 100 --batch-size 4096 --lr 0.8 --lr-type cosine --print-freq 10
Alternatively, launch the same training job with the hfai launcher:
hfai python tools/main_imagenet_DDP_HF.py \
--train_url YOUR_SAVE_PATH \
--data_url YOUR_DATA_PATH --data ImageNet --workers 64 --seed 0 \
--arch msdnet --nBlocks 5 --stepmode even --step 4 --base 4 --nChannels 32 --growthRate 16 --grFactor 1-2-4-4 --bnFactor 1-2-4-4 \
--meta_net_hidden_size 500 --meta_net_num_layers 1 --meta_interval 100 --meta_lr 1e-4 --meta_weight_decay 1e-4 \
--epsilon 0.3 --target_p_index 15 --meta_net_input_type loss --constraint_dimension mat \
--epochs 100 --batch-size 4096 --lr 0.8 --lr-type cosine --print-freq 10 \
-- --nodes=1 --name=YOUR_EXPERIMENT_NAME
  • Evaluate in the anytime prediction setting:
CUDA_VISIBLE_DEVICES=0 python tools/eval_imagenet.py \
--data ImageNet --batch-size 512 --workers 8 --seed 0 --print-freq 10 --evalmode anytime \
--arch msdnet --nBlocks 5 --stepmode even --step 4 --base 4 --nChannels 32 --growthRate 16 --grFactor 1-2-4-4 --bnFactor 1-2-4-4 \
--data_url YOUR_DATA_PATH \
--train_url YOUR_SAVE_PATH \
--evaluate_from YOUR_CKPT_PATH
  • Evaluate in the budgeted batch classification (dynamic) setting:
CUDA_VISIBLE_DEVICES=0 python tools/eval_imagenet.py \
--data ImageNet --batch-size 512 --workers 2 --seed 0 --print-freq 10 --evalmode dynamic \
--arch msdnet --nBlocks 5 --stepmode even --step 4 --base 4 --nChannels 32 --growthRate 16 --grFactor 1-2-4-4 --bnFactor 1-2-4-4 \
--data_url YOUR_DATA_PATH \
--train_url YOUR_SAVE_PATH \
--evaluate_from YOUR_CKPT_PATH
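We understand the dynamic mode to follow the budgeted batch protocol of the MSDNet-PyTorch code, in two stages:

1. On held-out data, fit a threshold on each exit's confidence (the maximum softmax probability), so that the fraction of samples leaving at exit k is proportional to p**k.
2. At test time, each sample stops at the first exit whose confidence reaches its threshold.

The plain-Python sketch below restates that rule. The function names, the input format (one list of per-exit confidences per sample) and the parameter p are illustrative assumptions, not the interface of tools/eval_imagenet.py.

```python
import math


def exit_proportions(p, n_exits):
    """Fraction of samples assigned to each exit, proportional to p**k, k = 1..n_exits."""
    raw = [p ** k for k in range(1, n_exits + 1)]
    total = sum(raw)
    return [r / total for r in raw]


def fit_thresholds(val_conf, proportions):
    """Fit one confidence threshold per exit on held-out data.

    val_conf[i][k] is the confidence of sample i at exit k. Walking the exits
    in order, exit k takes roughly the floor(n * proportions[k]) most confident
    samples that have not exited yet; the last exit takes everything left.
    """
    n, n_exits = len(val_conf), len(proportions)
    exited = [False] * n
    thresholds = []
    for k in range(n_exits - 1):
        quota = math.floor(n * proportions[k])
        threshold = math.inf  # nobody exits here if the quota cannot be met
        count = 0
        for i in sorted(range(n), key=lambda j: val_conf[j][k], reverse=True):
            if not exited[i]:
                count += 1
                if count == quota:
                    threshold = val_conf[i][k]
                    break
        thresholds.append(threshold)
        for i in range(n):
            if val_conf[i][k] >= threshold:
                exited[i] = True
    thresholds.append(-math.inf)  # the final exit accepts every remaining sample
    return thresholds


def first_exit(conf, thresholds):
    """Index of the first exit whose confidence reaches its threshold."""
    for k, (c, t) in enumerate(zip(conf, thresholds)):
        if c >= t:
            return k
    return len(thresholds) - 1


# Usage: 4 held-out samples, 2 exits, half of them meant to leave at exit 0.
val_conf = [[0.9, 0.95], [0.6, 0.7], [0.8, 0.9], [0.3, 0.5]]
thresholds = fit_thresholds(val_conf, [0.5, 0.5])  # [0.8, -inf]
print(first_exit([0.85, 0.9], thresholds))  # 0
```

Sweeping p trades accuracy against average computation, which is how a budgeted batch accuracy-versus-FLOPs curve is typically traced.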

Results

  • CIFAR-10 and CIFAR-100

[Figure: results on CIFAR-10 and CIFAR-100]

  • ImageNet

[Figure: results on ImageNet]

Pre-trained Models on ImageNet

| Model config | Epochs | Label smoothing | Acc. exit 1 (%) | Acc. exit 2 (%) | Acc. exit 3 (%) | Acc. exit 4 (%) | Acc. exit 5 (%) | Checkpoint |
|---|---|---|---|---|---|---|---|---|
| step=4 | 100 | N/A | 59.54 | 67.22 | 71.03 | 72.33 | 73.93 | Tsinghua Cloud / Google Drive |
| step=6 | 100 | N/A | 60.05 | 69.13 | 73.33 | 75.19 | 76.30 | Tsinghua Cloud / Google Drive |
| step=7 | 100 | N/A | 59.24 | 69.65 | 73.94 | 75.66 | 76.72 | Tsinghua Cloud / Google Drive |
| step=4 | 300 | 0.1 | 61.64 | 67.89 | 71.61 | 73.82 | 75.03 | Tsinghua Cloud / Google Drive |
| step=6 | 300 | 0.1 | 61.41 | 70.70 | 74.38 | 75.80 | 76.66 | Tsinghua Cloud / Google Drive |
| step=7 | 300 | 0.1 | 60.94 | 71.88 | 75.13 | 76.03 | 76.82 | Tsinghua Cloud / Google Drive |

Contact

If you have any questions, please feel free to contact the authors.

Yizeng Han: hanyz18@mails.tsinghua.edu.cn, yizeng38@gmail.com.

Yifan Pu: pyf20@mails.tsinghua.edu.cn, yifanpu98@126.com.

Acknowledgements

We use the PyTorch implementations of MSDNet-PyTorch, RANet-PyTorch, and IMTA in our experiments.
