GNN-RL Compression: Topology-Aware Network Pruning using Multi-stage Graph Embedding and Reinforcement Learning

yusx-swapp/GNN-RL-Model-Compression

GNN-RL-Model-Compression

Website

Dependencies

The current code base is tested under the following environment:

  1. Python 3.8
  2. PyTorch 1.8.0 (cuda 11.1)
  3. torchvision 0.7.0
  4. torch-geometric 1.6.1

GNN-RL Channel Pruning

In this work, we compress DNNs by performing FLOPs-constrained channel pruning (pruning Conv filters) on Conv layers. GNN-RL has been tested on over-parameterized and mobile-friendly DNNs with different datasets (CIFAR-10 and ImageNet).
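
As a rough illustration of why channel pruning reduces FLOPs, the sketch below counts multiply-accumulates for a plain chain of convolutions. The helper names (`conv_flops`, `chain_flops`) and the layer widths are hypothetical and are not taken from this repository; the point is only that removing filters from one layer also narrows the input of the next layer.

```python
def conv_flops(c_in, c_out, kernel, out_h, out_w):
    """Multiply-accumulates of one standard 2D convolution (bias ignored)."""
    return c_in * c_out * kernel * kernel * out_h * out_w


def chain_flops(channels, kernel=3, out_size=32):
    """Total FLOPs of a plain chain of conv layers.

    channels[i] is the input width of layer i and channels[i + 1] its output
    width, so removing filters from a layer also narrows the next layer's input.
    """
    return sum(
        conv_flops(channels[i], channels[i + 1], kernel, out_size, out_size)
        for i in range(len(channels) - 1)
    )


# Hypothetical widths: RGB input followed by three conv layers.
dense_widths = [3, 16, 32, 64]
# Keep half of the filters in the first two layers; the last layer is untouched.
pruned_widths = [3, 8, 16, 64]

flops_ratio = chain_flops(pruned_widths) / chain_flops(dense_widths)
print(f"remaining FLOPs: {flops_ratio:.1%}")
```

In this toy chain, halving the first two layers leaves roughly 45% of the original FLOPs, because the second and third layers each see fewer input channels as well.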

CIFAR-10

In this subsection, the DNNs are trained on CIFAR-10. On this dataset we observe a positive correlation between the pre-fine-tuning accuracy and the post-fine-tuning accuracy: pruning policies that obtain higher validation accuracy before fine-tuning also obtain higher fine-tuned accuracy. This lets us predict the final model accuracy without fine-tuning, which makes policy exploration faster and more efficient.

To search for a pruning strategy on ResNet-110 with channel pruning (filter pruning) on Conv layers and a 20% FLOPs reduction, run:

python -W ignore gnnrl_network_pruning.py --dataset cifar10 --model resnet110 --compression_ratio 0.2 --log_dir ./logs

To search for a pruning strategy on ResNet-56 with channel pruning (filter pruning) on Conv layers and a 30% FLOPs reduction, run:

python -W ignore gnnrl_network_pruning.py --dataset cifar10 --model resnet56 --compression_ratio 0.3 --log_dir ./logs

ImageNet (ILSVRC-2012)

To evaluate GNN-RL on ImageNet (ILSVRC-2012), you first need to download the dataset from ImageNet and export the data.

The validation accuracy on ImageNet is sensitive to the compression ratio: at high compression ratios, the accuracy drops considerably without fine-tuning (in some cases, the pruned model without fine-tuning has less than 1% validation accuracy). We therefore highly recommend that you decompose the pruning into several stages. For instance, to obtain a model with 49% of the original FLOPs, prune the target DNN twice, each time with a 70% FLOPs constraint (i.e., 70% × 70% = 49% of the original FLOPs).

If you have enough GPU resources, we also recommend enabling fine-tuning in each RL search episode to ensure that GNN-RL receives a meaningful reward.
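
The multi-stage decomposition recommended above comes down to simple arithmetic: if every stage keeps the same fraction of FLOPs, that fraction is the n-th root of the overall target. The following is a minimal sketch of that calculation; the function names are hypothetical and not part of this repository, and the values are fractions of FLOPs kept, which are not necessarily what the `--compression_ratio` flag expects.

```python
def per_stage_keep_ratio(target_keep, n_stages):
    """Fraction of FLOPs each stage must keep so the product equals target_keep."""
    if not 0.0 < target_keep <= 1.0:
        raise ValueError("target_keep must be in (0, 1]")
    if n_stages < 1:
        raise ValueError("n_stages must be at least 1")
    return target_keep ** (1.0 / n_stages)


def staged_schedule(target_keep, n_stages):
    """Cumulative fraction of the original FLOPs remaining after each stage."""
    ratio = per_stage_keep_ratio(target_keep, n_stages)
    return [ratio ** stage for stage in range(1, n_stages + 1)]


# The example from the text: reach 49% of the original FLOPs in two stages.
stage_ratio = per_stage_keep_ratio(0.49, 2)
schedule = staged_schedule(0.49, 2)
print(f"keep {stage_ratio:.2f} of the FLOPs per stage")
print([f"{remaining:.2f}" for remaining in schedule])
```

With more stages, each individual stage prunes less aggressively, which is the property the staged approach relies on to keep the pre-fine-tuning accuracy informative.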

To search for a pruning strategy on VGG-16 with channel pruning on convolutional layers and fine-grained pruning on dense layers, targeting an 80% FLOPs reduction on the convolutional layers, run:

python -W ignore gnnrl_network_pruning.py --dataset imagenet --model vgg16 --compression_ratio 0.8 --log_dir ./logs --data_root [your dataset dir] 

To search for a pruning strategy on MobileNet-V1 with channel pruning on convolutional layers and fine-grained pruning on dense layers, targeting a 25% FLOPs reduction on the convolutional layers, run:

python -W ignore gnnrl_network_pruning.py --dataset imagenet --model mobilenet --compression_ratio 0.25 --val_size 5000  --log_dir ./logs --data_root [your imagenet dataset dir]

Pruning tools

We use the PyTorch built-in pruning tools in torch.nn.utils.prune to prune a given DNN. This package prunes a neural network by masking the pruned weights, so it does not accelerate inference. If you want to accelerate inference or save memory, discard the weights that have zero masks.

We also provide functions to extract these weights. For example, for VGG-16, you can extract the weights and evaluate the resulting model by running:

python gnnrl_real_pruning.py --dataset imagenet --model vgg16 --data_root [your imagenet data dir] --ckpt_path data/pretrained_models

Before you run the above command, download the pre-trained weights from Google Drive and move them to data/pretrained_models.
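
To make the difference between masking and physically discarding weights concrete, here is a minimal single-layer sketch using torch.nn.utils.prune. It is not the extraction logic of gnnrl_real_pruning.py; the layer sizes, the L2-norm criterion, and the variable names are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Mask the 50% of output filters (dim=0) with the smallest L2 norm.
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# The layer keeps its full shape; pruned filters are only zeroed by the mask.
mask = conv.weight_mask                       # shape (8, 3, 3, 3)
keep = mask.flatten(1).sum(dim=1) > 0         # True for filters that survive
masked_shape = tuple(conv.weight.shape)       # still (8, 3, 3, 3)

# Make the masking permanent, then copy the surviving filters into a
# physically smaller layer. Only this step saves memory and compute.
prune.remove(conv, "weight")
n_keep = int(keep.sum())
small = nn.Conv2d(in_channels=3, out_channels=n_keep, kernel_size=3)
with torch.no_grad():
    small.weight.copy_(conv.weight[keep])
    small.bias.copy_(conv.bias[keep])

# The small layer reproduces the surviving channels of the masked layer.
x = torch.randn(1, 3, 16, 16)
same = torch.allclose(conv(x)[:, keep], small(x), atol=1e-5)
print(masked_shape, tuple(small.weight.shape), same)
```

In a full network, the input channels of the following layer (and any batch-norm parameters) would have to be sliced with the same `keep` index, which this single-layer sketch leaves out.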

Fine-tuning

To fine-tune the pruned 50% FLOPs ResNet-110, run:

python -W ignore gnnrl_fine_tune.py \
    --model=resnet110 \
    --dataset=cifar10 \
    --lr=0.005 \
    --n_gpu=4 \
    --batch_size=256 \
    --n_worker=32 \
    --lr_type=cos \
    --n_epoch=200 \
    --wd=4e-5 \
    --seed=2018 \
    --data_root=data/datasets \
    --ckpt_path=[pruned model checkpoint path] \
    --finetuning

Evaluate the compressed model

After searching, you can evaluate the compressed model, which is saved in the default directory ./logs. We also provide pruned and fine-tuned models on Google Drive for you to evaluate. You can download them and move them to data/pretrained_models. To evaluate the performance of a compressed model, run:

python -W ignore gnnrl_fine_tune.py \
    --model=[model name] \
    --dataset=cifar10 \
    --n_gpu=4 \
    --batch_size=256 \
    --n_worker=32 \
    --data_root=data/datasets \
    --ckpt_path=[pruned model checkpoint path] \
    --eval
     

Results on CIFAR-10

| Models | FLOPs ratio | Top1 Acc. (%) | Dataset |
| --- | --- | --- | --- |
| ResNet-110 | 50% FLOPs | 94.31 | CIFAR-10 |
| ResNet-56 | 50% FLOPs | 93.49 | CIFAR-10 |
| ResNet-44 | 50% FLOPs | 93.23 | CIFAR-10 |

Results on ImageNet

| Models | FLOPs ratio | Top1 Acc. (%) | Δ Acc. | Dataset |
| --- | --- | --- | --- | --- |
| MobileNet-v1 | 40% FLOPs | 69.50 | -1.40 | ImageNet |
| MobileNet-v1 | 70% FLOPs | 70.70 | -0.20 | ImageNet |
| MobileNet-v2 | 58% FLOPs | 70.04 | -1.83 | ImageNet |
| VGG-16 | 20% FLOPs | 70.992 | +0.49 | ImageNet |
| ResNet-50 | 47% FLOPs | 74.28 | -1.82 | ImageNet |
| ResNet-18 | 50% FLOPs | 68.66 | -1.10 | ImageNet |

Note: We will continuously update results on ImageNet and add support for other popular networks.
