In [1]:
import caffe
import numpy as np
import shutil

In [2]:
# Directory that receives the pruned prototxt and caffemodel.
root = "/media/vanilla/Fun/zhongxing/Test/"
# Trained weights and network definition of the model to prune.
model = "/media/vanilla/Fun/zhongxing/banben/TestModel.caffemodel"
prototxt = "/media/vanilla/Fun/zhongxing/banben/TestModel.prototxt"
# Select GPU mode before the network is instantiated.
caffe.set_mode_gpu()
# Load the original network in TEST phase; its params are the pruning input.
net = caffe.Net(prototxt, model, caffe.TEST)

In [3]:
class Pruner(object):
    """Prune convolution filters and input channels of a loaded Caffe network.

    Every ``prune_*`` method stores its result in ``conv_data[name]`` as a
    tuple ``(weight, bias, del_filters, origin_channels)``:

    * ``weight`` / ``bias`` -- the pruned arrays for layer ``name``;
    * ``del_filters`` -- integer indices of the output filters removed by the
      L1 criterion (an empty integer array when filter pruning was skipped);
    * ``origin_channels`` -- number of output filters before pruning.

    Later layers look up ``del_filters`` of their bottom layer(s) to drop the
    matching input channels.
    """

    # Shape the flattened "fc5_" weight is viewed as while input channels are
    # removed along axis 1. Hard-coded for this particular model; the last two
    # axes are presumably the spatial size of the incoming feature map.
    FC_WEIGHT_SHAPE = (512, 1536, 9, 9)
    # Layers handled specially by prune_concat.
    FC_LAYERS = ('fc5_',)
    CONCAT_LAYERS_WITH_FILTER_PRUNING = ('conv5_1_1', 'conv5_1_1b')

    def __init__(self, net, prune_ratio=0.07):
        """Remember the source network.

        Args:
            net: object exposing ``params[name] -> (weight_blob, bias_blob)``
                where each blob has a ``.data`` array (a ``caffe.Net``).
            prune_ratio: fraction of a layer's filters (lowest L1 norm first)
                removed whenever filter pruning is enabled.
        """
        self._net = net
        self.conv_data = {}
        self.prune_ratio = prune_ratio

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _as_index_array(indices):
        """Return ``indices`` as a flat integer array.

        np.delete rejects non-integer index arrays, and an empty ``np.array([])``
        is float64 by default, so every index collection is normalised here.
        """
        return np.asarray(indices, dtype=np.intp).reshape(-1)

    def _smallest_l1_filters(self, weight):
        """Indices of the ``prune_ratio`` fraction of filters with smallest L1 norm.

        The indices are ordered by increasing L1 norm; ties are broken by
        filter index (stable sort), so equal norms never shrink the selection.
        """
        n_filters = weight.shape[0]
        n_del = int(n_filters * self.prune_ratio)
        if n_del == 0:
            return np.array([], dtype=np.intp)
        l1_norms = np.abs(weight).reshape(n_filters, -1).sum(axis=1, dtype=np.float64)
        return np.argsort(l1_norms, kind="mergesort")[:n_del].astype(np.intp)

    def _delete_filters_and_kernels(self, weight, bias, del_kernels, not_del_filters):
        """Shared core of ``_prune`` and ``_prune_res``.

        Optionally removes low-L1 output filters (axis 0 of ``weight`` and the
        matching ``bias`` entries), then removes the input channels listed in
        ``del_kernels`` (axis 1). Returns ``(weight, bias, del_filters)``.
        """
        if not_del_filters:
            del_filters = np.array([], dtype=np.intp)
        else:
            del_filters = self._smallest_l1_filters(weight)
            weight = np.delete(weight, del_filters, axis=0)
            bias = np.delete(bias, del_filters, axis=0)
        print("weight cut nums:"+str(weight.shape))
        if del_kernels is not None:
            weight = np.delete(weight, self._as_index_array(del_kernels), axis=1)
        return weight, bias, del_filters

    # ------------------------------------------------------------ core pruning
    def _prune_res(self, conv_param, del_kernels=None, not_del_filters=False,ref=None):
        """Prune a layer that sits inside a residual branch.

        Args:
            conv_param: ``(weight_blob, bias_blob)`` of the layer.
            del_kernels: input-channel indices to remove (axis 1), or None.
            not_del_filters: when True, skip the L1-based filter pruning.
            ref: output-filter indices to remove first (axis 0 of the weight
                and the bias), or None. These are NOT reported in the returned
                ``del_filters``.

        Returns:
            ``(weight, bias, del_filters, origin_channels)``.
        """
        weight, bias = conv_param
        weight = weight.data
        bias = bias.data
        origin_channels = weight.shape[0]
        print("weight original:"+str(weight.shape))

        if ref is not None:
            ref_filters = self._as_index_array(ref)
            weight = np.delete(weight, ref_filters, axis=0)
            bias = np.delete(bias, ref_filters, axis=0)

        weight, bias, del_filters = self._delete_filters_and_kernels(
            weight, bias, del_kernels, not_del_filters)
        print("weight cut channels:"+str(weight.shape))

        return weight, bias, del_filters, origin_channels

    def _prune(self, conv_param, del_kernels=None, not_del_filters=False,ref=None):
        """Prune a plain layer.

        Args:
            conv_param: ``(weight_blob, bias_blob)`` of the layer.
            del_kernels: input-channel indices to remove (axis 1), or None.
            not_del_filters: when True, skip the L1-based filter pruning.
            ref: any non-None value marks the layer as the flattened fc layer:
                its weight is viewed as ``FC_WEIGHT_SHAPE`` while channels are
                removed and flattened back to 2-D afterwards.

        Returns:
            ``(weight, bias, del_filters, origin_channels)``.
        """
        weight, bias = conv_param
        weight = weight.data
        bias = bias.data
        origin_channels = weight.shape[0]
        print("weight original:"+str(weight.shape))
        if ref is not None:
            weight = np.reshape(weight, self.FC_WEIGHT_SHAPE)
            print("weight original reshape:"+str(weight.shape))

        weight, bias, del_filters = self._delete_filters_and_kernels(
            weight, bias, del_kernels, not_del_filters)

        if ref is not None:
            # Flatten back to (outputs, inputs); use the current row count so
            # the reshape stays valid even if filters were removed.
            weight = np.reshape(weight, (weight.shape[0], -1))
            print("weight original reshape:"+str(weight.shape))
        print("weight cut channels:"+str(weight.shape))
        return weight, bias, del_filters, origin_channels

    # ------------------------------------------------------------- public API
    def prune_conv(self, name, bottom=None):
        """Prune filters of ``name``; also drop the input channels that were
        removed as filters from ``bottom`` (when a bottom layer is given)."""
        if bottom is None:
            self.conv_data[name] = self._prune(self._net.params[name])
        else:
            self.conv_data[name] = self._prune(self._net.params[name], self.conv_data[bottom][2])

    def prune_res_conv(self, name, bottom=None, ref='conv3_1_1b'):
        """Prune ``name`` without L1 filter selection: drop the input channels
        removed from ``bottom`` and the output filters removed from ``ref``."""
        self.conv_data[name] = self._prune_res(
            self._net.params[name],
            del_kernels=self.conv_data[bottom][2],
            not_del_filters=True,
            ref=self.conv_data[ref][2])

    def prune_concat(self, name, bottoms):
        """Prune ``name`` whose input is the channel-wise concatenation of ``bottoms``.

        The deleted-filter indices of each bottom are shifted by the summed
        original channel counts of the bottoms before it, giving indices into
        the concatenated input.
        """
        offsets = [0]
        for b in bottoms:
            offsets.append(offsets[-1] + self.conv_data[b][3])
        shifted = [self._as_index_array(self.conv_data[b][2]) + offsets[i]
                   for i, b in enumerate(bottoms)]
        if shifted:
            del_kernels = np.concatenate(shifted)
        else:
            del_kernels = np.array([], dtype=np.intp)

        params = self._net.params[name]
        if name in self.FC_LAYERS:
            self.conv_data[name] = self._prune(params, del_kernels, not_del_filters=True, ref=name)
        elif name in self.CONCAT_LAYERS_WITH_FILTER_PRUNING:
            self.conv_data[name] = self._prune(params, del_kernels, not_del_filters=False)
        else:
            self.conv_data[name] = self._prune(params, del_kernels, not_del_filters=True)

    def prune_sum(self, name, ref):
        """Drop from ``name`` the input channels removed as filters from ``ref``;
        the layer's own filters are left untouched."""
        del_kernels = self.conv_data[ref][2]
        self.conv_data[name] = self._prune_res(
            self._net.params[name], del_kernels=del_kernels, not_del_filters=True, ref=None)

    def save(self, new_model, output_weights):
        """Instantiate the pruned prototxt ``new_model``, fill it and save it.

        Pruned layers receive the arrays stored in ``conv_data``; every other
        layer gets all of its blobs copied from the network passed to the
        constructor. The result is written to ``output_weights``.
        """
        net2 = caffe.Net(new_model, caffe.TEST)
        for key in net2.params.keys():
            dst_blobs = net2.params[key]
            if key in self.conv_data:
                dst_blobs[0].data[...] = self.conv_data[key][0]
                dst_blobs[1].data[...] = self.conv_data[key][1]
            else:
                # Copy from the wrapped network (not a notebook global) and do
                # not assume exactly two blobs per layer.
                src_blobs = self._net.params[key]
                for blob_idx in range(len(dst_blobs)):
                    dst_blobs[blob_idx].data[...] = src_blobs[blob_idx].data
        net2.save(output_weights)

In [4]:
# Wrap the loaded network; pruning results accumulate in pruner.conv_data.
pruner = Pruner(net)

In [5]:
# conv1_* layers: (layer, bottom) pairs, pruned in this order. A bottom of
# None means only the layer's own filters are pruned.
conv1_plan = [
    ("conv1_1_1", None),
    ("conv1_2_1", None),
    ("conv1_2_2", "conv1_2_1"),
    ("conv1_3_1", None),
    ("conv1_3_2", "conv1_3_1"),
    ("conv1_3_3", "conv1_3_2"),
]
for layer_name, bottom_name in conv1_plan:
    pruner.prune_conv(layer_name, bottom_name)

weight original:(32, 3, 3, 3)
weight cut nums:(30, 3, 3, 3)
weight cut channels:(30, 3, 3, 3)
weight original:(32, 3, 3, 3)
weight cut nums:(30, 3, 3, 3)
weight cut channels:(30, 3, 3, 3)
weight original:(32, 32, 3, 3)
weight cut nums:(30, 32, 3, 3)
weight cut channels:(30, 30, 3, 3)
weight original:(32, 3, 3, 3)
weight cut nums:(30, 3, 3, 3)
weight cut channels:(30, 3, 3, 3)
weight original:(32, 32, 3, 3)
weight cut nums:(30, 32, 3, 3)
weight cut channels:(30, 30, 3, 3)
weight original:(32, 32, 3, 3)
weight cut nums:(30, 32, 3, 3)
weight cut channels:(30, 30, 3, 3)


In [6]:
# conv2_1 consumes the concatenation of the three conv1 branch outputs.
pruner.prune_concat("conv2_1", ("conv1_1_1", "conv1_2_2", "conv1_3_3"))
# conv2_2 .. conv2_8 form a chain: each layer's bottom is its predecessor.
for block_idx in range(2, 9):
    pruner.prune_conv("conv2_%d" % block_idx, "conv2_%d" % (block_idx - 1))

weight original:(64, 96, 3, 3)
weight cut nums:(64, 96, 3, 3)
weight cut channels:(64, 90, 3, 3)
weight original:(64, 64, 3, 3)
weight cut nums:(60, 64, 3, 3)
weight cut channels:(60, 64, 3, 3)
weight original:(64, 64, 3, 3)
weight cut nums:(60, 64, 3, 3)
weight cut channels:(60, 60, 3, 3)
weight original:(64, 64, 3, 3)
weight cut nums:(60, 64, 3, 3)
weight cut channels:(60, 60, 3, 3)
weight original:(64, 64, 3, 3)
weight cut nums:(60, 64, 3, 3)
weight cut channels:(60, 60, 3, 3)
weight original:(64, 64, 3, 3)
weight cut nums:(60, 64, 3, 3)
weight cut channels:(60, 60, 3, 3)
weight original:(64, 64, 3, 3)
weight cut nums:(60, 64, 3, 3)
weight cut channels:(60, 60, 3, 3)
weight original:(64, 64, 3, 3)
weight cut nums:(60, 64, 3, 3)
weight cut channels:(60, 60, 3, 3)




In [7]:
# Both conv3 entry layers read the concatenation of these conv2 outputs.
# (A variant using only ("conv2_2", "conv2_4") as bottoms was tried earlier.)
conv2_bottoms = ("conv2_2", "conv2_4", "conv2_6", "conv2_8")
for head_name in ("conv3_1_1b", "conv3_1_1"):
    pruner.prune_concat(head_name, conv2_bottoms)

weight original:(128, 256, 1, 1)
weight cut nums:(128, 256, 1, 1)
weight cut channels:(128, 240, 1, 1)
weight original:(128, 256, 3, 3)
weight cut nums:(128, 256, 3, 3)
weight cut channels:(128, 240, 3, 3)


In [8]:
# Residual layers of conv3_*: every call uses conv3_1_1b as the reference layer.
conv3_ref = "conv3_1_1b"
pruner.prune_res_conv("conv3_1_2", "conv3_1_1", conv3_ref)
for block_idx in range(2, 7):
    first_conv = "conv3_%d_1" % block_idx
    second_conv = "conv3_%d_2" % block_idx
    pruner.prune_sum(first_conv, conv3_ref)
    pruner.prune_res_conv(second_conv, first_conv, conv3_ref)

weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight cut channels:(128, 128, 3, 3)
weight original:(128, 128, 3, 3)
weight cut nums:(128, 128, 3, 3)
weight 



In [9]:
# Both conv4 entry layers are pruned against conv3_1_1b listed three times as bottoms.
conv3_bottoms = ("conv3_1_1b",) * 3
for head_name in ("conv4_1_1b", "conv4_1_1"):
    pruner.prune_concat(head_name, conv3_bottoms)

weight original:(256, 384, 1, 1)
weight cut nums:(256, 384, 1, 1)
weight cut channels:(256, 384, 1, 1)
weight original:(256, 384, 3, 3)
weight cut nums:(256, 384, 3, 3)
weight cut channels:(256, 384, 3, 3)




In [10]:
# Residual layers of conv4_*: every call uses conv4_1_1b as the reference layer.
conv4_ref = "conv4_1_1b"
pruner.prune_res_conv("conv4_1_2", "conv4_1_1", conv4_ref)
for block_idx in range(2, 7):
    first_conv = "conv4_%d_1" % block_idx
    second_conv = "conv4_%d_2" % block_idx
    pruner.prune_sum(first_conv, conv4_ref)
    pruner.prune_res_conv(second_conv, first_conv, conv4_ref)

weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight cut channels:(256, 256, 3, 3)
weight original:(256, 256, 3, 3)
weight cut nums:(256, 256, 3, 3)
weight 



In [11]:
# Both conv5 entry layers are pruned against conv4_1_1b listed three times as bottoms.
conv4_bottoms = ("conv4_1_1b",) * 3
for head_name in ("conv5_1_1b", "conv5_1_1"):
    pruner.prune_concat(head_name, conv4_bottoms)

weight original:(512, 768, 1, 1)
weight cut nums:(477, 768, 1, 1)
weight cut channels:(477, 768, 1, 1)
weight original:(512, 768, 1, 1)
weight cut nums:(477, 768, 1, 1)
weight cut channels:(477, 768, 1, 1)




In [12]:
# Residual layers of conv5_*: every call uses conv5_1_1b as the reference layer.
conv5_ref = "conv5_1_1b"
pruner.prune_res_conv("conv5_1_2", "conv5_1_1", conv5_ref)
for block_idx in range(2, 7):
    first_conv = "conv5_%d_1" % block_idx
    second_conv = "conv5_%d_2" % block_idx
    pruner.prune_sum(first_conv, conv5_ref)
    pruner.prune_res_conv(second_conv, first_conv, conv5_ref)

weight original:(512, 512, 1, 1)
weight cut nums:(477, 512, 1, 1)
weight cut channels:(477, 477, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(512, 512, 1, 1)
weight cut channels:(512, 477, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(477, 512, 1, 1)
weight cut channels:(477, 512, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(512, 512, 1, 1)
weight cut channels:(512, 477, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(477, 512, 1, 1)
weight cut channels:(477, 512, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(512, 512, 1, 1)
weight cut channels:(512, 477, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(477, 512, 1, 1)
weight cut channels:(477, 512, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(512, 512, 1, 1)
weight cut channels:(512, 477, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(477, 512, 1, 1)
weight cut channels:(477, 512, 1, 1)
weight original:(512, 512, 1, 1)
weight cut nums:(512, 512, 1, 1)
weight 



In [13]:
# fc5_ only loses input channels: those removed from conv5_1_1b, with that
# layer listed three times as bottoms.
conv5_bottoms = ("conv5_1_1b",) * 3
pruner.prune_concat("fc5_", conv5_bottoms)

weight original:(512, 124416)
weight original reshape:(512, 1536, 9, 9)
weight cut nums:(512, 1536, 9, 9)
weight original reshape:(512, 115911)
weight cut channels:(512, 115911)


In [14]:
# Number of output filters kept per pruned layer.
[
    (layer_name, entry[0].shape[0])
    for layer_name, entry in pruner.conv_data.items()
    if entry[0] is not None
]

[('conv4_2_1', 256),
 ('conv4_5_1', 256),
 ('conv3_5_1', 128),
 ('conv4_4_1', 256),
 ('fc5_', 512),
 ('conv5_1_2', 477),
 ('conv5_4_1', 512),
 ('conv5_4_2', 477),
 ('conv4_1_1b', 256),
 ('conv1_1_1', 30),
 ('conv5_2_1', 512),
 ('conv5_3_1', 512),
 ('conv3_1_2', 128),
 ('conv4_3_1', 256),
 ('conv3_4_2', 128),
 ('conv5_2_2', 477),
 ('conv2_8', 60),
 ('conv4_6_2', 256),
 ('conv3_6_1', 128),
 ('conv2_5', 60),
 ('conv3_2_2', 128),
 ('conv3_3_1', 128),
 ('conv4_3_2', 256),
 ('conv3_4_1', 128),
 ('conv5_3_2', 477),
 ('conv3_5_2', 128),
 ('conv3_2_1', 128),
 ('conv4_1_1', 256),
 ('conv5_5_1', 512),
 ('conv1_3_3', 30),
 ('conv1_2_2', 30),
 ('conv2_1', 64),
 ('conv5_5_2', 477),
 ('conv4_1_2', 256),
 ('conv3_1_1', 128),
 ('conv1_3_1', 30),
 ('conv3_3_2', 128),
 ('conv1_2_1', 30),
 ('conv5_1_1b', 477),
 ('conv1_3_2', 30),
 ('conv5_1_1', 477),
 ('conv2_4', 60),
 ('conv5_6_2', 477),
 ('conv3_6_2', 128),
 ('conv2_3', 60),
 ('conv3_1_1b', 128),
 ('conv2_7', 60),
 ('conv2_2', 60),
 ('conv4_2_2', 256),


In [15]:
def get_prototxt(pk, pro_n):
    """Rewrite ``num_output`` of every pruned layer in the prototxt at ``pro_n``.

    The file is read, transformed in memory and only then written back, so a
    failure never leaves a truncated prototxt behind.

    Args:
        pk: mapping ``layer name -> sequence`` whose first element is the
            pruned weight; ``len()`` of that weight becomes the new
            ``num_output``.
        pro_n: path of the prototxt file, modified in place.

    Raises:
        ValueError: if a pruned layer's ``name:`` line is not followed by any
            ``num_output:`` line before the end of the file. The file is left
            unchanged in that case.
    """
    with open(pro_n, "r") as p:
        lines = p.readlines()

    rewritten = []
    pending = None  # pruned layer whose num_output line has not been seen yet
    for line in lines:
        if pending is None and 'name:' in line:
            quoted = line.split('"')
            # Lines without a double-quoted value cannot name a pruned layer.
            if len(quoted) > 1 and quoted[1] in pk:
                pending = quoted[1]
        if pending is not None and "num_output:" in line:
            # Keep the indentation the line already had.
            indent = line[:len(line) - len(line.lstrip())]
            rewritten.append(indent + "num_output: " + str(len(pk[pending][0])) + "\n")
            pending = None
        else:
            rewritten.append(line)

    if pending is not None:
        raise ValueError("no num_output found after layer name: " + pending)

    with open(pro_n, "w") as p:
        p.writelines(rewritten)
    print("deploy_rebirth_prune.prototxt已写好")

In [16]:
# Copy the original definition into the output directory, then rewrite the
# copy's num_output values in place from the pruning results.
pro_n = "".join((root, "TestModel.prototxt"))
shutil.copyfile(prototxt, pro_n)
get_prototxt(pk=pruner.conv_data, pro_n=pro_n)

deploy_rebirth_prune.prototxt已写好


In [17]:
# The channel counts in the new prototxt must already be updated before saving.
pruned_weights_path = "/media/vanilla/Fun/zhongxing/Test/TestModel.caffemodel"
pruner.save(pro_n, pruned_weights_path)

In [18]:
# import caffe
# net = caffe.Net('/media/vanilla/Fun/zhongxing/Test/TestModel_deploy.prototxt', '/media/vanilla/Fun/zhongxing/Test/TestModel.caffemodel', caffe.TEST)
# net.save('/media/vanilla/Fun/zhongxing/Test/TestModel_remove.caffemodel')