# Pruning ResNet

In [1]:
import logging
logging.basicConfig(level=logging.INFO)

import numpy as np
import pandas as pd

from chainer.serializers import save_npz
from chainercv.links import ResNet50

from chainerpruner import Pruner, Graph
from chainerpruner.serializers import load_npz
from chainerpruner.masks import NormMask
from chainerpruner.utils import calc_computational_cost

## model definition

In [2]:
# All-zero float32 array of shape (1, 3, 224, 224). The same array is reused in
# later cells as the input for Graph/Pruner, the cost measurement and the timings.
x = np.zeros((1, 3, 224, 224), dtype=np.float32)
# ResNet50 built with pretrained_model=None (no weight file is loaded here),
# 1000 output classes and the 'he' architecture variant.
model = ResNet50(pretrained_model=None, n_class=1000, arch='he')

In [3]:
# Toggle for the link-name listing cell: set to True to print every link name
# in the model (handy when choosing entries for the pruning target list).
show_name = False

In [4]:
# Optional listing of every link path in the model (the root link is skipped).
if show_name:
    link_names = [link_name for link_name, _link in model.namedlinks(skipself=True)]
    for link_name in link_names:
        print(link_name)

In [5]:
# Measure the cost of the unpruned model on the dummy input and start the
# results table with its totals as the 'before pruning' row.
# For a per-layer breakdown, run: cch.show_report(unit='G', mode='table')
cch = calc_computational_cost(model, x)
df = pd.DataFrame(
    data=[cch.total_report],
    index=['before pruning'],
)

In [6]:
# benchmark
result = %timeit -o -n 5 model(x)
df['time str'] = str(result)
df['time'] = result.average

219 ms ± 9.57 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [7]:
# Display the baseline row (cost totals plus timing) collected so far.
df

Unnamed: 0,flops,mread,mrw,mwrite,n_layers,name,type,time str,time
before pruning,3884871583,269242688,413171840,143929152,176,total,total,219 ms ± 9.57 ms per loop (mean ± std. dev. of...,0.219364


In [8]:
# Save the unpruned model's parameters to disk; a later cell reloads this file
# and compares its size with the pruned checkpoint.
save_npz('/tmp/model.npz', model)

## pruning

In [9]:
# Pruning ratio passed to NormMask. The rebuild report recorded in this notebook
# shows 64 -> 13 channels at 0.8, so this is the fraction removed, not the fraction kept.
percent = 0.8

In [10]:
# Pruning target link names: the `conv2/conv` link of every block in stages
# res2-res5. Each stage has one 'a' block followed by numbered 'b' blocks,
# so the names are generated from the number of blocks per stage.
stage_block_counts = (('res2', 3), ('res3', 4), ('res4', 6), ('res5', 3))
target_layers = [
    '/{}/{}/conv2/conv'.format(stage, block)
    for stage, n_blocks in stage_block_counts
    for block in ['a'] + ['b{}'.format(i) for i in range(1, n_blocks)]
]

In [11]:
# Build the chainerpruner objects from the model and the dummy input `x`.
# Graph is constructed from (model, x) and handed to the mask.
graph = Graph(model, x)
# Mask over the target layers at the configured ratio.
# NOTE(review): presumably selects channels by weight norm — confirm in chainerpruner docs.
mask = NormMask(model, graph, target_layers, percent=percent)
# The pruner ties together the model, the input, the target layers and the mask.
pruner = Pruner(model, x, target_layers, mask)

In [12]:
# Apply the mask, then rebuild. `model` itself is modified: the later cost
# measurement and the saved checkpoint for the same object are both smaller.
# NOTE(review): re-running this cell on an already pruned model would presumably
# prune it again — restart from the model-definition cell before re-running.
pruner.apply_mask()
# The rebuild returns one record per target layer with its channel count
# before and after pruning (see the output below).
info = pruner.apply_rebuild()
info # pruning results

[{'name': '/res2/a/conv2/conv', 'before': 64, 'after': 13},
 {'name': '/res2/b1/conv2/conv', 'before': 64, 'after': 13},
 {'name': '/res2/b2/conv2/conv', 'before': 64, 'after': 13},
 {'name': '/res3/a/conv2/conv', 'before': 128, 'after': 26},
 {'name': '/res3/b1/conv2/conv', 'before': 128, 'after': 26},
 {'name': '/res3/b2/conv2/conv', 'before': 128, 'after': 26},
 {'name': '/res3/b3/conv2/conv', 'before': 128, 'after': 26},
 {'name': '/res4/a/conv2/conv', 'before': 256, 'after': 52},
 {'name': '/res4/b1/conv2/conv', 'before': 256, 'after': 52},
 {'name': '/res4/b2/conv2/conv', 'before': 256, 'after': 52},
 {'name': '/res4/b3/conv2/conv', 'before': 256, 'after': 52},
 {'name': '/res4/b4/conv2/conv', 'before': 256, 'after': 52},
 {'name': '/res4/b5/conv2/conv', 'before': 256, 'after': 52},
 {'name': '/res5/a/conv2/conv', 'before': 512, 'after': 103},
 {'name': '/res5/b1/conv2/conv', 'before': 512, 'after': 103},
 {'name': '/res5/b2/conv2/conv', 'before': 512, 'after': 103}]

In [13]:
# Save the pruned model's parameters next to the unpruned checkpoint.
save_npz('/tmp/model_pruned.npz', model)

In [14]:
# Measure the cost of the (now pruned) model on the same input; the row label
# records the pruning ratio that was used.
# For a per-layer breakdown, run: cch.show_report(unit='G', mode='table')
cch = calc_computational_cost(model, x)
row_label = 'after pruning ({})'.format(percent)
df_ = pd.DataFrame(data=[cch.total_report], index=[row_label])

In [15]:
# benchmark
result = %timeit -o -n 5 model(x)
df_['time str'] = str(result)
df_['time'] = result.average
df_

177 ms ± 7.27 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


Unnamed: 0,flops,mread,mrw,mwrite,n_layers,name,type,time str,time
after pruning (0.8),1752625705,203835740,334568408,130732668,176,total,total,177 ms ± 7.27 ms per loop (mean ± std. dev. of...,0.177442


In [16]:
# merge results
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, so the two
# frames are stacked with pd.concat instead. Any existing row that carries the
# same label as the new one is dropped first, which makes this cell safe to
# re-run: it replaces the "after pruning" row rather than duplicating it.
df = pd.concat([df.drop(index=df_.index, errors='ignore'), df_])
# Time relative to the unpruned model (1.0 means no speed-up). A single
# .loc[row, column] lookup avoids chained indexing and returns a scalar.
df['time ratio'] = df['time'] / df.loc['before pruning', 'time']
df[['flops', 'time str', 'time ratio', 'n_layers']]

Unnamed: 0,flops,time str,time ratio,n_layers
before pruning,3884871583,219 ms ± 9.57 ms per loop (mean ± std. dev. of...,1.0,176
after pruning (0.8),1752625705,177 ms ± 7.27 ms per loop (mean ± std. dev. of...,0.808894,176


## load weight

In [17]:
# Compare the on-disk size of the unpruned and pruned checkpoints.
!ls -lha /tmp/model*.npz

-rw-r--r-- 1 docker docker 91M Dec 15 10:32 /tmp/model.npz
-rw-r--r-- 1 docker docker 45M Dec 15 10:32 /tmp/model_pruned.npz


In [18]:
# Sanity check: load the unpruned checkpoint into a freshly built ResNet50 with
# chainerpruner's load_npz and run one forward pass.
model_new = ResNet50(pretrained_model=None, n_class=1000, arch='he')
load_npz('/tmp/model.npz', model_new)
# Trailing ';' suppresses the notebook's display of the returned value.
model_new(x);

In [19]:
# Load the pruned checkpoint into a freshly built, full-size ResNet50.
# chainerpruner's load_npz logs each parameter whose shape changes on load
# (see the INFO lines below), then one forward pass is run.
model_new = ResNet50(pretrained_model=None, n_class=1000, arch='he')
load_npz('/tmp/model_pruned.npz', model_new)
# Trailing ';' suppresses the notebook's display of the returned value.
model_new(x);

INFO:chainerpruner.serializers.npz:load res2/a/conv3/conv/W: (256, 64, 1, 1) to (256, 13, 1, 1)
INFO:chainerpruner.serializers.npz:load res2/a/conv2/bn/gamma: (64,) to (13,)
INFO:chainerpruner.serializers.npz:load res2/a/conv2/bn/beta: (64,) to (13,)
INFO:chainerpruner.serializers.npz:load res2/a/conv2/bn/avg_mean: (64,) to (13,)
INFO:chainerpruner.serializers.npz:load res2/a/conv2/bn/avg_var: (64,) to (13,)
INFO:chainerpruner.serializers.npz:load res2/a/conv2/conv/W: (64, 64, 3, 3) to (13, 64, 3, 3)
INFO:chainerpruner.serializers.npz:load res2/b1/conv3/conv/W: (256, 64, 1, 1) to (256, 13, 1, 1)
INFO:chainerpruner.serializers.npz:load res2/b1/conv2/bn/gamma: (64,) to (13,)
INFO:chainerpruner.serializers.npz:load res2/b1/conv2/bn/beta: (64,) to (13,)
INFO:chainerpruner.serializers.npz:load res2/b1/conv2/bn/avg_mean: (64,) to (13,)
INFO:chainerpruner.serializers.npz:load res2/b1/conv2/bn/avg_var: (64,) to (13,)
INFO:chainerpruner.serializers.npz:load res2/b1/conv2/conv/W: (64, 64, 3, 3) 