In [7]:
import numpy as np
import torch
import torch.nn as nn
from scipy.spatial import distance
class CNN(nn.Module):
    """Stack of three 2-D convolutions (3 -> 8 -> 16 -> 256 channels).

    Every layer uses a (4, 5) kernel and otherwise default ``nn.Conv2d``
    settings; nothing (no activation, no pooling) is applied between them.
    """

    def __init__(self):
        super().__init__()
        kernel = (4, 5)  # shared kernel size for all three layers
        self.conv1 = nn.Conv2d(3, 8, kernel)
        self.conv2 = nn.Conv2d(8, 16, kernel)
        self.conv3 = nn.Conv2d(16, 256, kernel)

    def forward(self, x):
        """Pass ``x`` through conv1, conv2 and conv3 in that order and return the result."""
        for layer in (self.conv1, self.conv2, self.conv3):
            x = layer(x)
        return x

# Instantiate the network and dump every learnable tensor it owns
# (the weight and bias of each of the three conv layers).
m = CNN()
print(list(m.parameters()))

[Parameter containing:
tensor([[[[-0.0031, -0.0178, -0.0434, -0.0886, -0.0829],
          [ 0.1275, -0.1271, -0.0019,  0.0055, -0.1037],
          [-0.0668, -0.1140,  0.0868,  0.0518,  0.0449],
          [-0.0061, -0.0160,  0.0027,  0.0258, -0.0541]],

         [[ 0.0254, -0.0016,  0.0715,  0.0558, -0.0968],
          [-0.1265,  0.1005,  0.0294,  0.0823, -0.0860],
          [ 0.1222, -0.1251,  0.1162,  0.0257, -0.0735],
          [ 0.0691,  0.0515,  0.0627,  0.0450, -0.0885]],

         [[ 0.0706, -0.0774, -0.0017,  0.0791, -0.0662],
          [-0.0790, -0.0302,  0.1079, -0.0727, -0.1196],
          [-0.0168, -0.1265,  0.0890,  0.0426,  0.0610],
          [ 0.1094,  0.0257,  0.0716,  0.0497, -0.1005]]],


        [[[ 0.0524, -0.1210, -0.0591,  0.0904, -0.0531],
          [-0.0234, -0.0213,  0.0103,  0.0561, -0.0294],
          [-0.1107,  0.1132,  0.0806,  0.0888,  0.0768],
          [-0.0161,  0.1048,  0.0172, -0.0102,  0.0477]],

         [[-0.1044,  0.0692,  0.0874,  0.0479, -0.0597]

In [8]:

# Walk every submodule of the model and, for each convolution layer,
# show its configuration and the Python type of its weight attribute.
for i in m.modules():
    if isinstance(i, nn.Conv2d):
        print(i)
        print(type(i.weight))

Conv2d(3, 8, kernel_size=(4, 5), stride=(1, 1))
<class 'torch.nn.parameter.Parameter'>
Conv2d(8, 16, kernel_size=(4, 5), stride=(1, 1))
<class 'torch.nn.parameter.Parameter'>
Conv2d(16, 256, kernel_size=(4, 5), stride=(1, 1))
<class 'torch.nn.parameter.Parameter'>


In [9]:
# Filled by init_length(): parameter index -> torch.Size of that parameter.
model_size = {}
# Filled by init_length(): parameter index -> product of that parameter's dimensions.
model_length = {}
def init_length():
    """Record the shape and element count of every parameter of ``m``.

    Writes into the module-level dicts: ``model_size[k]`` receives the
    ``torch.Size`` of the k-th parameter, and ``model_length[k]`` the running
    product of its dimensions. Each step of the product is traced with an
    'if' line (first dimension) or an 'else' line (every later dimension).
    """
    model_size.update((idx, param.size()) for idx, param in enumerate(m.parameters()))

    for idx, shape in model_size.items():
        for axis, extent in enumerate(shape):
            if axis:
                # Later dimensions: fold into the running product.
                model_length[idx] *= extent
                print('else', model_length[idx])
            else:
                # First dimension seeds the product.
                model_length[idx] = extent
                print('if', extent)
# Populate model_size / model_length for the current model, then show the results.
init_length()
# NOTE(review): this prints the bound method object (whose repr happens to include
# the model layout), not the parameters; m.parameters() would have to be called for that.
print(m.parameters)
print(model_size)
print(model_length)

if 8
else 24
else 96
else 480
if 8
if 16
else 128
else 512
else 2560
if 16
if 256
else 4096
else 16384
else 81920
if 256
<bound method Module.parameters of CNN(
  (conv1): Conv2d(3, 8, kernel_size=(4, 5), stride=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(4, 5), stride=(1, 1))
  (conv3): Conv2d(16, 256, kernel_size=(4, 5), stride=(1, 1))
)>
{0: torch.Size([8, 3, 4, 5]), 1: torch.Size([8]), 2: torch.Size([16, 8, 4, 5]), 3: torch.Size([16]), 4: torch.Size([256, 16, 4, 5]), 5: torch.Size([256])}
{0: 480, 1: 8, 2: 2560, 3: 16, 4: 81920, 5: 256}


In [10]:
# Show the shape and Python type of every tensor yielded by m.parameters().
for p in m.parameters():
    # print(p)
    print(p.size())
    print(type(p))

torch.Size([8, 3, 4, 5])
<class 'torch.nn.parameter.Parameter'>
torch.Size([8])
<class 'torch.nn.parameter.Parameter'>
torch.Size([16, 8, 4, 5])
<class 'torch.nn.parameter.Parameter'>
torch.Size([16])
<class 'torch.nn.parameter.Parameter'>
torch.Size([256, 16, 4, 5])
<class 'torch.nn.parameter.Parameter'>
torch.Size([256])
<class 'torch.nn.parameter.Parameter'>


In [11]:
# Grab the weight of the first Conv2d encountered while walking the model's submodules.
wt = None
for i in m.modules():
    if isinstance(i, nn.Conv2d):
        # print(i.weight)
        print(type(i.weight))
        wt = i.weight
        break  # stop at the first convolution layer
print(wt)
# The two bare strings below are expressions, not comments: the last one is
# what the cell displays as its result.
'''model.modules.weight'''
'''model.conv1.weight'''

<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([[[[-0.0031, -0.0178, -0.0434, -0.0886, -0.0829],
          [ 0.1275, -0.1271, -0.0019,  0.0055, -0.1037],
          [-0.0668, -0.1140,  0.0868,  0.0518,  0.0449],
          [-0.0061, -0.0160,  0.0027,  0.0258, -0.0541]],

         [[ 0.0254, -0.0016,  0.0715,  0.0558, -0.0968],
          [-0.1265,  0.1005,  0.0294,  0.0823, -0.0860],
          [ 0.1222, -0.1251,  0.1162,  0.0257, -0.0735],
          [ 0.0691,  0.0515,  0.0627,  0.0450, -0.0885]],

         [[ 0.0706, -0.0774, -0.0017,  0.0791, -0.0662],
          [-0.0790, -0.0302,  0.1079, -0.0727, -0.1196],
          [-0.0168, -0.1265,  0.0890,  0.0426,  0.0610],
          [ 0.1094,  0.0257,  0.0716,  0.0497, -0.1005]]],


        [[[ 0.0524, -0.1210, -0.0591,  0.0904, -0.0531],
          [-0.0234, -0.0213,  0.0103,  0.0561, -0.0294],
          [-0.1107,  0.1132,  0.0806,  0.0888,  0.0768],
          [-0.0161,  0.1048,  0.0172, -0.0102,  0.0477]],

         [[-0.104

'model.conv1.weight'

In [20]:
def get_filter_similar(weights, length, compress_rate=0.5, distance_rate=0.4, mix=True, dist_type="l2", verbose=True):
    '''
    Build a 0/1 mask ("codebook") that zeroes the filters of one conv layer
    whose summed distance to the other candidate filters is smallest.

    Args:
        weights: torch.nn.parameter.Parameter (or tensor). Only a 4-D input of
            shape (out_channel, in_channel, kernel_h, kernel_w) is processed.
        length: number of entries in the returned mask, e.g. model_length[index].
        compress_rate: with ``mix=True`` the int(out_channel * (1 - compress_rate))
            filters with the smallest norm are excluded before distances are computed.
        distance_rate: int(out_channel * distance_rate) filters are masked out.
        mix: if True apply the norm-based exclusion above; if False every
            filter is a candidate.
        dist_type: "l2" - rank by L2 norm, compare with Euclidean distance;
                   "l1" - rank by L1 norm, compare with Euclidean distance;
                   "cos" - rank by L2 norm, compare with cosine distance.
        verbose: print the intermediate values (default True).

    Returns:
        numpy.ndarray of shape (length,): ones everywhere except zeros over the
        flattened span of each masked filter. All ones when ``weights`` is not 4-D.

    Raises:
        ValueError: if ``dist_type`` is not one of "l2", "l1", "cos".
    '''
    if dist_type not in ("l2", "l1", "cos"):
        raise ValueError("dist_type must be 'l2', 'l1' or 'cos', got %r" % (dist_type,))

    log = print if verbose else (lambda *args, **kwargs: None)

    codebook = np.ones(length)
    if len(weights.size()) != 4:
        # Biases and other non-conv parameters are never masked.
        log(type(codebook))
        return codebook

    out_channels = weights.size()[0]
    filter_pruned_num = int(out_channels * (1 - compress_rate))
    similar_pruned_num = int(out_channels * distance_rate)

    # Flatten every filter of this layer into a single row vector.
    # reshape (unlike view) also accepts non-contiguous weights.
    weight_vec = weights.reshape(out_channels, -1)
    log('\nweight_vec', weight_vec)

    # The original test `dist_type == "l2" or "cos"` was always true, so the
    # L1 norm could never be selected; choose the order explicitly instead.
    norm_order = 1 if dist_type == "l1" else 2
    norm = torch.norm(weight_vec, norm_order, 1)
    norm_np = norm.detach().cpu().numpy()
    log('\nnorm_np\n', norm_np)

    # Filter indices ordered by ascending norm.
    filter_selected = norm_np.argsort()
    if mix:
        # Keep only the largest-norm filters as candidates.
        filter_selected = filter_selected[filter_pruned_num:]
    log('\nfilter_selected', filter_selected)

    indices = torch.as_tensor(filter_selected, dtype=torch.long, device=weight_vec.device)
    log('\nindices', indices)

    # Rows of weight_vec for the candidate filters, in filter_selected order.
    weight_vec_after_norm = torch.index_select(weight_vec, 0, indices).detach().cpu().numpy()

    # Pairwise distance between every two candidate filters.
    if dist_type == "cos":
        # This must be a distance, not a similarity: the selection below keeps
        # the SMALLEST column sums, i.e. the filters closest to all the others.
        similar_matrix = distance.cdist(weight_vec_after_norm, weight_vec_after_norm, 'cosine')
    else:
        similar_matrix = distance.cdist(weight_vec_after_norm, weight_vec_after_norm, 'euclidean')

    # For each candidate: sum of its distances to all candidates.
    similar_sum = np.sum(np.abs(similar_matrix), axis=0)
    order = similar_sum.argsort()
    log('\nsimilar_sum\n', similar_sum)
    log('\nsimilar_sum.argsort\n', order)

    # Candidates with the larger sums are kept (only reported, not used further).
    similar_large_index = order[similar_pruned_num:]
    log('\nsimilar_large_index\n', similar_large_index)
    # Candidates with the smallest sums are the ones to mask.
    similar_small_index = order[:similar_pruned_num]
    log('\nsimilar_small_index\n', similar_small_index)

    # Map positions within the candidate list back to real filter indices.
    similar_index_for_filter = [filter_selected[i] for i in similar_small_index]
    log('\nsimilar_index_for_filter\n', similar_index_for_filter)

    # Number of scalars in one filter: in_channel * kernel_h * kernel_w.
    kernel_length = weights.size()[1] * weights.size()[2] * weights.size()[3]
    for filter_index in similar_index_for_filter:
        # Zero the whole flattened span of the chosen filter.
        start = filter_index * kernel_length
        codebook[start:start + kernel_length] = 0
    log("\nsimilar index done")

    log(type(codebook))
    return codebook
# Try the mask builder on wt (the first conv layer's weight) with the default
# settings, using the element count recorded for parameter 0 as the mask length.
print('wt\n', wt.size(), '\nmodel_length\n', model_length)
print(get_filter_similar(wt, model_length[0]))

wt
 torch.Size([8, 3, 4, 5]) 
model_length
 {0: 480, 1: 8, 2: 2560, 3: 16, 4: 81920, 5: 256}

weight_vec tensor([[-0.0031, -0.0178, -0.0434, -0.0886, -0.0829,  0.1275, -0.1271, -0.0019,
          0.0055, -0.1037, -0.0668, -0.1140,  0.0868,  0.0518,  0.0449, -0.0061,
         -0.0160,  0.0027,  0.0258, -0.0541,  0.0254, -0.0016,  0.0715,  0.0558,
         -0.0968, -0.1265,  0.1005,  0.0294,  0.0823, -0.0860,  0.1222, -0.1251,
          0.1162,  0.0257, -0.0735,  0.0691,  0.0515,  0.0627,  0.0450, -0.0885,
          0.0706, -0.0774, -0.0017,  0.0791, -0.0662, -0.0790, -0.0302,  0.1079,
         -0.0727, -0.1196, -0.0168, -0.1265,  0.0890,  0.0426,  0.0610,  0.1094,
          0.0257,  0.0716,  0.0497, -0.1005],
        [ 0.0524, -0.1210, -0.0591,  0.0904, -0.0531, -0.0234, -0.0213,  0.0103,
          0.0561, -0.0294, -0.1107,  0.1132,  0.0806,  0.0888,  0.0768, -0.0161,
          0.1048,  0.0172, -0.0102,  0.0477, -0.1044,  0.0692,  0.0874,  0.0479,
         -0.0597, -0.0540, -0.1139, -0.