Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
ResNet50 Dynamic Surgery: new schedule with improved results
Browse files Browse the repository at this point in the history
Top1: 75.52% (-0.63% from TorchVision dense ResNet50)
Total sparsity: 82.6%
  • Loading branch information
nzmora committed Dec 3, 2018
1 parent a27aabe commit 37d5774
Showing 1 changed file with 203 additions and 0 deletions.
203 changes: 203 additions & 0 deletions examples/network_surgery/resnet50.network_surgery2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# This schedule follows the methodology proposed by Intel Labs China in the paper:
# Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
# NIPS 2016, https://arxiv.org/abs/1608.04493.
#
# Top1 is 75.518 (on Epoch: 99) vs the published Top1: 76.15 (https://pytorch.org/docs/stable/torchvision/models.html)
# Total sparsity: 82.6%
#
# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.001 --compress=resnet50.network_surgery2.yaml --validation-size=0 --masks-sparsity --num-best-scores=10
#
#
# Parameters:
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# | | Name | Shape | NNZ (dense) | NNZ (sparse) | Cols (%) | Rows (%) | Ch (%) | 2D (%) | 3D (%) | Fine (%) | Std | Mean | Abs-Mean |
# |----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
# | 0 | module.conv1.weight | (64, 3, 7, 7) | 9408 | 9408 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.10739 | -0.00040 | 0.06567 |
# | 1 | module.layer1.0.conv1.weight | (64, 64, 1, 1) | 4096 | 812 | 0.00000 | 0.00000 | 3.12500 | 80.17578 | 7.81250 | 80.17578 | 0.05457 | -0.00405 | 0.02019 |
# | 2 | module.layer1.0.conv2.weight | (64, 64, 3, 3) | 36864 | 5052 | 0.00000 | 0.00000 | 7.81250 | 51.00098 | 6.25000 | 86.29557 | 0.02182 | 0.00054 | 0.00697 |
# | 3 | module.layer1.0.conv3.weight | (256, 64, 1, 1) | 16384 | 2477 | 0.00000 | 0.00000 | 6.25000 | 84.88159 | 13.28125 | 84.88159 | 0.02654 | 0.00022 | 0.00923 |
# | 4 | module.layer1.0.downsample.0.weight | (256, 64, 1, 1) | 16384 | 2975 | 0.00000 | 0.00000 | 1.56250 | 81.84204 | 14.06250 | 81.84204 | 0.04410 | -0.00247 | 0.01580 |
# | 5 | module.layer1.1.conv1.weight | (64, 256, 1, 1) | 16384 | 2026 | 0.00000 | 0.00000 | 14.45312 | 87.63428 | 6.25000 | 87.63428 | 0.02121 | 0.00072 | 0.00704 |
# | 6 | module.layer1.1.conv2.weight | (64, 64, 3, 3) | 36864 | 4064 | 0.00000 | 0.00000 | 6.25000 | 52.31934 | 0.00000 | 88.97569 | 0.01952 | 0.00019 | 0.00595 |
# | 7 | module.layer1.1.conv3.weight | (256, 64, 1, 1) | 16384 | 1997 | 0.00000 | 0.00000 | 0.00000 | 87.81128 | 5.85938 | 87.81128 | 0.02324 | 0.00021 | 0.00751 |
# | 8 | module.layer1.2.conv1.weight | (64, 256, 1, 1) | 16384 | 2994 | 0.00000 | 0.00000 | 9.37500 | 81.72607 | 0.00000 | 81.72607 | 0.02169 | -0.00005 | 0.00874 |
# | 9 | module.layer1.2.conv2.weight | (64, 64, 3, 3) | 36864 | 4551 | 0.00000 | 0.00000 | 0.00000 | 45.41016 | 0.00000 | 87.65462 | 0.02076 | -0.00029 | 0.00698 |
# | 10 | module.layer1.2.conv3.weight | (256, 64, 1, 1) | 16384 | 1938 | 0.00000 | 0.00000 | 0.00000 | 88.17139 | 10.15625 | 88.17139 | 0.02266 | -0.00103 | 0.00724 |
# | 11 | module.layer2.0.conv1.weight | (128, 256, 1, 1) | 32768 | 5757 | 0.00000 | 0.00000 | 6.25000 | 82.43103 | 0.00000 | 82.43103 | 0.02551 | -0.00083 | 0.00988 |
# | 12 | module.layer2.0.conv2.weight | (128, 128, 3, 3) | 147456 | 23222 | 0.00000 | 0.00000 | 0.00000 | 43.48755 | 0.00000 | 84.25157 | 0.01525 | -0.00010 | 0.00572 |
# | 13 | module.layer2.0.conv3.weight | (512, 128, 1, 1) | 65536 | 6978 | 0.00000 | 0.00000 | 0.00000 | 89.35242 | 28.90625 | 89.35242 | 0.01970 | 0.00022 | 0.00584 |
# | 14 | module.layer2.0.downsample.0.weight | (512, 256, 1, 1) | 131072 | 13839 | 0.00000 | 0.00000 | 0.00000 | 89.44168 | 14.06250 | 89.44168 | 0.01643 | -0.00022 | 0.00459 |
# | 15 | module.layer2.1.conv1.weight | (128, 512, 1, 1) | 65536 | 6780 | 0.00000 | 0.00000 | 17.18750 | 89.65454 | 0.00000 | 89.65454 | 0.01183 | 0.00018 | 0.00345 |
# | 16 | module.layer2.1.conv2.weight | (128, 128, 3, 3) | 147456 | 15531 | 0.00000 | 0.00000 | 0.00000 | 60.41260 | 2.34375 | 89.46737 | 0.01378 | 0.00027 | 0.00402 |
# | 17 | module.layer2.1.conv3.weight | (512, 128, 1, 1) | 65536 | 6229 | 0.00000 | 0.00000 | 0.00000 | 90.49530 | 19.72656 | 90.49530 | 0.01613 | -0.00081 | 0.00447 |
# | 18 | module.layer2.2.conv1.weight | (128, 512, 1, 1) | 65536 | 9000 | 0.00000 | 0.00000 | 1.95312 | 86.26709 | 0.00000 | 86.26709 | 0.01634 | -0.00031 | 0.00554 |
# | 19 | module.layer2.2.conv2.weight | (128, 128, 3, 3) | 147456 | 16032 | 0.00000 | 0.00000 | 0.00000 | 52.87476 | 0.00000 | 89.12760 | 0.01431 | -0.00007 | 0.00440 |
# | 20 | module.layer2.2.conv3.weight | (512, 128, 1, 1) | 65536 | 6783 | 0.00000 | 0.00000 | 0.00000 | 89.64996 | 5.85938 | 89.64996 | 0.01736 | -0.00006 | 0.00516 |
# | 21 | module.layer2.3.conv1.weight | (128, 512, 1, 1) | 65536 | 8544 | 0.00000 | 0.00000 | 1.75781 | 86.96289 | 0.00000 | 86.96289 | 0.01625 | -0.00028 | 0.00555 |
# | 22 | module.layer2.3.conv2.weight | (128, 128, 3, 3) | 147456 | 23301 | 0.00000 | 0.00000 | 0.00000 | 33.40454 | 0.00000 | 84.19800 | 0.01532 | -0.00025 | 0.00578 |
# | 23 | module.layer2.3.conv3.weight | (512, 128, 1, 1) | 65536 | 7932 | 0.00000 | 0.00000 | 0.00000 | 87.89673 | 17.57812 | 87.89673 | 0.01693 | -0.00015 | 0.00545 |
# | 24 | module.layer3.0.conv1.weight | (256, 512, 1, 1) | 131072 | 19673 | 0.00000 | 0.00000 | 0.00000 | 84.99069 | 0.00000 | 84.99069 | 0.02137 | -0.00038 | 0.00765 |
# | 25 | module.layer3.0.conv2.weight | (256, 256, 3, 3) | 589824 | 101140 | 0.00000 | 0.00000 | 0.00000 | 47.00165 | 0.00000 | 82.85251 | 0.01212 | -0.00015 | 0.00467 |
# | 26 | module.layer3.0.conv3.weight | (1024, 256, 1, 1) | 262144 | 29483 | 0.00000 | 0.00000 | 0.00000 | 88.75313 | 6.05469 | 88.75313 | 0.01549 | 0.00009 | 0.00486 |
# | 27 | module.layer3.0.downsample.0.weight | (1024, 512, 1, 1) | 524288 | 52743 | 0.00000 | 0.00000 | 0.00000 | 89.94007 | 4.98047 | 89.94007 | 0.01071 | 0.00006 | 0.00310 |
# | 28 | module.layer3.1.conv1.weight | (256, 1024, 1, 1) | 262144 | 28594 | 0.00000 | 0.00000 | 8.10547 | 89.09225 | 0.00000 | 89.09225 | 0.01041 | -0.00005 | 0.00319 |
# | 29 | module.layer3.1.conv2.weight | (256, 256, 3, 3) | 589824 | 65069 | 0.00000 | 0.00000 | 0.00000 | 54.10919 | 0.00000 | 88.96807 | 0.00993 | -0.00002 | 0.00310 |
# | 30 | module.layer3.1.conv3.weight | (1024, 256, 1, 1) | 262144 | 27368 | 0.00000 | 0.00000 | 0.00000 | 89.55994 | 1.95312 | 89.55994 | 0.01346 | -0.00056 | 0.00400 |
# | 31 | module.layer3.2.conv1.weight | (256, 1024, 1, 1) | 262144 | 26238 | 0.00000 | 0.00000 | 1.75781 | 89.99100 | 0.00000 | 89.99100 | 0.01042 | -0.00007 | 0.00305 |
# | 32 | module.layer3.2.conv2.weight | (256, 256, 3, 3) | 589824 | 67618 | 0.00000 | 0.00000 | 0.00000 | 45.94727 | 0.00000 | 88.53590 | 0.00971 | -0.00023 | 0.00312 |
# | 33 | module.layer3.2.conv3.weight | (1024, 256, 1, 1) | 262144 | 28073 | 0.00000 | 0.00000 | 0.00000 | 89.29100 | 0.97656 | 89.29100 | 0.01248 | -0.00014 | 0.00381 |
# | 34 | module.layer3.3.conv1.weight | (256, 1024, 1, 1) | 262144 | 27645 | 0.00000 | 0.00000 | 0.48828 | 89.45427 | 0.00000 | 89.45427 | 0.01131 | -0.00005 | 0.00343 |
# | 35 | module.layer3.3.conv2.weight | (256, 256, 3, 3) | 589824 | 69321 | 0.00000 | 0.00000 | 0.00000 | 44.19861 | 0.00000 | 88.24717 | 0.00961 | -0.00017 | 0.00315 |
# | 36 | module.layer3.3.conv3.weight | (1024, 256, 1, 1) | 262144 | 29057 | 0.00000 | 0.00000 | 0.00000 | 88.91563 | 3.61328 | 88.91563 | 0.01201 | -0.00033 | 0.00376 |
# | 37 | module.layer3.4.conv1.weight | (256, 1024, 1, 1) | 262144 | 28934 | 0.00000 | 0.00000 | 0.09766 | 88.96255 | 0.00000 | 88.96255 | 0.01172 | -0.00016 | 0.00366 |
# | 38 | module.layer3.4.conv2.weight | (256, 256, 3, 3) | 589824 | 70785 | 0.00000 | 0.00000 | 0.00000 | 44.31305 | 0.00000 | 87.99896 | 0.00958 | -0.00025 | 0.00319 |
# | 39 | module.layer3.4.conv3.weight | (1024, 256, 1, 1) | 262144 | 29261 | 0.00000 | 0.00000 | 0.00000 | 88.83781 | 1.46484 | 88.83781 | 0.01205 | -0.00054 | 0.00379 |
# | 40 | module.layer3.5.conv1.weight | (256, 1024, 1, 1) | 262144 | 30074 | 0.00000 | 0.00000 | 0.00000 | 88.52768 | 0.00000 | 88.52768 | 0.01263 | -0.00009 | 0.00405 |
# | 41 | module.layer3.5.conv2.weight | (256, 256, 3, 3) | 589824 | 72988 | 0.00000 | 0.00000 | 0.00000 | 44.97986 | 0.00000 | 87.62546 | 0.00984 | -0.00028 | 0.00332 |
# | 42 | module.layer3.5.conv3.weight | (1024, 256, 1, 1) | 262144 | 31194 | 0.00000 | 0.00000 | 0.00000 | 88.10043 | 1.66016 | 88.10043 | 0.01284 | -0.00089 | 0.00420 |
# | 43 | module.layer4.0.conv1.weight | (512, 1024, 1, 1) | 524288 | 114432 | 0.00000 | 0.00000 | 0.00000 | 78.17383 | 0.00000 | 78.17383 | 0.01710 | -0.00038 | 0.00754 |
# | 44 | module.layer4.0.conv2.weight | (512, 512, 3, 3) | 2359296 | 461529 | 0.00000 | 0.00000 | 0.00000 | 41.99524 | 0.00000 | 80.43785 | 0.00872 | -0.00015 | 0.00370 |
# | 45 | module.layer4.0.conv3.weight | (2048, 512, 1, 1) | 1048576 | 190377 | 0.00000 | 0.00000 | 0.00000 | 81.84423 | 0.00000 | 81.84423 | 0.01097 | -0.00013 | 0.00443 |
# | 46 | module.layer4.0.downsample.0.weight | (2048, 1024, 1, 1) | 2097152 | 296214 | 0.00000 | 0.00000 | 0.00000 | 85.87542 | 0.00000 | 85.87542 | 0.00690 | -0.00001 | 0.00243 |
# | 47 | module.layer4.1.conv1.weight | (512, 2048, 1, 1) | 1048576 | 235460 | 0.00000 | 0.00000 | 0.00000 | 77.54478 | 0.00000 | 77.54478 | 0.01123 | -0.00028 | 0.00503 |
# | 48 | module.layer4.1.conv2.weight | (512, 512, 3, 3) | 2359296 | 569044 | 0.00000 | 0.00000 | 0.00000 | 27.84805 | 0.00000 | 75.88077 | 0.00897 | -0.00042 | 0.00423 |
# | 49 | module.layer4.1.conv3.weight | (2048, 512, 1, 1) | 1048576 | 193763 | 0.00000 | 0.00000 | 0.00000 | 81.52132 | 0.00000 | 81.52132 | 0.01092 | 0.00017 | 0.00445 |
# | 50 | module.layer4.2.conv1.weight | (512, 2048, 1, 1) | 1048576 | 254128 | 0.00000 | 0.00000 | 0.00000 | 75.76447 | 0.00000 | 75.76447 | 0.01357 | -0.00013 | 0.00634 |
# | 51 | module.layer4.2.conv2.weight | (512, 512, 3, 3) | 2359296 | 537393 | 0.00000 | 0.00000 | 0.00000 | 47.85385 | 0.00000 | 77.22232 | 0.00767 | -0.00029 | 0.00354 |
# | 52 | module.layer4.2.conv3.weight | (2048, 512, 1, 1) | 1048576 | 162045 | 0.00000 | 0.00000 | 0.00000 | 84.54618 | 0.14648 | 84.54618 | 0.00990 | 0.00027 | 0.00362 |
# | 53 | module.fc.weight | (1000, 2048) | 2048000 | 396407 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 80.64419 | 0.03125 | 0.00427 | 0.01213 |
# | 54 | Total sparsity: | - | 25502912 | 4434272 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 82.61268 | 0.00000 | 0.00000 | 0.00000 |
# +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
# 2018-11-14 09:44:15,161 - Total sparsity: 82.61
#
# 2018-11-14 09:44:15,304 - --- validate (epoch=99)-----------
# 2018-11-14 09:44:15,305 - 50000 samples (256 per mini-batch)
# 2018-11-14 09:44:34,609 - Epoch: [99][ 50/ 195] Loss 0.697465 Top1 81.437500 Top5 95.703125
# 2018-11-14 09:44:42,914 - Epoch: [99][ 100/ 195] Loss 0.816492 Top1 78.804688 Top5 94.542969
# 2018-11-14 09:44:51,516 - Epoch: [99][ 150/ 195] Loss 0.930595 Top1 76.380208 Top5 93.135417
# 2018-11-14 09:44:58,448 - ==> Top1: 75.518 Top5: 92.620 Loss: 0.975
#
# 2018-11-14 09:44:58,508 - ==> Best Top1: 75.518 on Epoch: 99
# 2018-11-14 09:44:58,508 - ==> Best Top1: 75.494 on Epoch: 83
# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.480 on Epoch: 89
# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.462 on Epoch: 91
# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.454 on Epoch: 97
# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.452 on Epoch: 93
# 2018-11-14 09:44:58,509 - ==> Best Top1: 75.450 on Epoch: 96
# 2018-11-14 09:44:58,510 - ==> Best Top1: 75.448 on Epoch: 90
# 2018-11-14 09:44:58,510 - ==> Best Top1: 75.438 on Epoch: 94
# 2018-11-14 09:44:58,510 - ==> Best Top1: 75.436 on Epoch: 73
# 2018-11-14 09:44:58,510 - Saving checkpoint to: logs/resnet50_lr_0.001_mult_0.005___2018.11.12-041119/resnet50_lr_0.001_mult_0.005_checkpoint.pth.tar
# 2018-11-14 09:44:59,539 - --- test ---------------------
# 2018-11-14 09:44:59,540 - 50000 samples (256 per mini-batch)
# 2018-11-14 09:45:18,661 - Test: [ 50/ 195] Loss 0.697465 Top1 81.437500 Top5 95.703125
# 2018-11-14 09:45:27,176 - Test: [ 100/ 195] Loss 0.816492 Top1 78.804688 Top5 94.542969
# 2018-11-14 09:45:36,202 - Test: [ 150/ 195] Loss 0.930595 Top1 76.380208 Top5 93.135417
# 2018-11-14 09:45:43,168 - ==> Top1: 75.518 Top5: 92.620 Loss: 0.975
#
# NOTE(review): the log excerpt below (epoch=359, 10000 samples per epoch) appears to be
# leftover from a different run (likely CIFAR-10) and does not match this ResNet50/ImageNet
# schedule — confirm and consider removing.
# --- validate (epoch=359)-----------
# 10000 samples (256 per mini-batch)
# ==> Top1: 91.480 Top5: 99.600 Loss: 0.363
#
# ==> Best Top1: 91.790 (0.0 sparsity) on Epoch: 181
#
# Saving checkpoint to: logs/2018.10.31-232827/checkpoint.pth.tar
# --- test ---------------------
# 10000 samples (256 per mini-batch)
# ==> Top1: 91.480 Top5: 99.600 Loss: 0.363
#
#
# Log file for this run: /home/cvds_lab/nzmora/pytorch_workspace/distiller/examples/classifier_compression/logs/2018.10.31-232827/2018.10.31-232827.log
#
# real 64m27.317s
# user 118m46.020s
# sys 14m3.627s

version: 1

# Dynamic Network Surgery: a SplicingPruner prunes (and may re-splice) weights
# every mini-batch for the first 47 epochs; the masks are then kept fixed
# (keep_mask: True) while training continues with a step-decayed LR.
pruners:
  pruner1:
    class: SplicingPruner
    low_thresh_mult: 0.9  # 0.6
    hi_thresh_mult: 1.1  # 0.7
    sensitivity_multiplier: 0.005  # 0.015
    # Per-layer sensitivity values: the pruning threshold for each weights
    # tensor is sensitivity * std(tensor).  conv1 is intentionally left dense.
    sensitivities:
      # 'module.conv1.weight': 0.60
      module.layer1.0.conv1.weight: 0.10
      module.layer1.0.conv2.weight: 0.40
      module.layer1.0.conv3.weight: 0.40
      module.layer1.0.downsample.0.weight: 0.20
      module.layer1.1.conv1.weight: 0.60
      module.layer1.1.conv2.weight: 0.60
      module.layer1.1.conv3.weight: 0.60
      module.layer1.2.conv1.weight: 0.30
      module.layer1.2.conv2.weight: 0.60
      module.layer1.2.conv3.weight: 0.60

      module.layer2.0.conv1.weight: 0.30
      module.layer2.0.conv2.weight: 0.40
      module.layer2.0.conv3.weight: 0.60
      module.layer2.0.downsample.0.weight: 0.50
      module.layer2.1.conv1.weight: 0.60
      module.layer2.1.conv2.weight: 0.60
      module.layer2.1.conv3.weight: 0.60
      module.layer2.2.conv1.weight: 0.40
      module.layer2.2.conv2.weight: 0.60
      module.layer2.2.conv3.weight: 0.60
      module.layer2.3.conv1.weight: 0.50
      module.layer2.3.conv2.weight: 0.40
      module.layer2.3.conv3.weight: 0.50

      module.layer3.0.conv1.weight: 0.40
      module.layer3.0.conv2.weight: 0.30
      module.layer3.0.conv3.weight: 0.60
      module.layer3.0.downsample.0.weight: 0.60
      module.layer3.1.conv1.weight: 0.60
      module.layer3.1.conv2.weight: 0.60
      module.layer3.1.conv3.weight: 0.60
      module.layer3.2.conv1.weight: 0.60
      module.layer3.2.conv2.weight: 0.60
      module.layer3.2.conv3.weight: 0.60
      module.layer3.3.conv1.weight: 0.60
      module.layer3.3.conv2.weight: 0.60
      module.layer3.3.conv3.weight: 0.60
      module.layer3.4.conv1.weight: 0.60
      module.layer3.4.conv2.weight: 0.60
      module.layer3.4.conv3.weight: 0.60
      module.layer3.5.conv1.weight: 0.60
      module.layer3.5.conv2.weight: 0.60
      module.layer3.5.conv3.weight: 0.60

      module.layer4.0.conv1.weight: 0.20
      module.layer4.0.conv2.weight: 0.30
      module.layer4.0.conv3.weight: 0.30
      module.layer4.0.downsample.0.weight: 0.40
      module.layer4.1.conv1.weight: 0.15
      module.layer4.1.conv2.weight: 0.15
      module.layer4.1.conv3.weight: 0.30
      module.layer4.2.conv1.weight: 0.15
      module.layer4.2.conv2.weight: 0.30
      module.layer4.2.conv3.weight: 0.45
      module.fc.weight: 0.50

lr_schedulers:
  training_lr:
    class: StepLR
    step_size: 45
    gamma: 0.10

policies:
  # Prune every mini-batch (frequency: 1) during epochs [0, 47); after that the
  # final masks are retained because keep_mask is True.
  - pruner:
      instance_name: pruner1
      args:
        keep_mask: True
        # mini_batch_pruning_frequency: 1
        mask_on_forward_only: True
    starting_epoch: 0
    ending_epoch: 47
    frequency: 1

  # Step LR decay (x0.10 every 45 epochs) applied for the entire run.
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 400
    frequency: 1

0 comments on commit 37d5774

Please sign in to comment.