In [1]:
import os
import torch
import torch.nn as nn
import numpy as np

In [2]:
from torchvision import models

In [3]:
from models.mobilenet.MobilePose import createModel

True


In [4]:
# Build the MobilePose network (cfg=None) and keep it on the CPU.
model = createModel(cfg=None)
model = model.cpu()

In [5]:
# Path of the checkpoint whose state_dict is loaded into `model` below.
weights = "saved/model_102.pkl"

In [6]:
# Load the checkpoint onto the CPU and copy its parameters into the model.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints (recent PyTorch also offers weights_only=True
# for plain state_dicts).
model.load_state_dict(torch.load(weights, map_location="cpu"))

<All keys matched successfully>

In [7]:
# Display the module structure (repr) of the loaded network.
model

MobilePose(
  (mobile): MobileNetV2(
    (features): Sequential(
      (0): ConvBNReLU(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNReLU(
            (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)

In [9]:
def obtain_prune_idx(path, skip_first=1):
    """Return the ``model.modules()`` indices of BatchNorm2d layers, parsed from a repr dump.

    ``path`` must contain the text produced by ``print(model)``. In that repr
    every child module is introduced by a line of the form ``(name): Type(...)``,
    while the root module line has no ``(name): `` prefix. Counting the child
    lines therefore reproduces each module's position in ``model.modules()``,
    with the root at index 0.

    NOTE(review): this assumes no module instance is shared, because
    ``modules()`` de-duplicates shared instances and the repr does not. It also
    assumes no container prints a compressed repr (e.g. ``(0-2): 3 x ...``).

    Args:
        path: text file holding the model's printed repr.
        skip_first: number of leading BatchNorm2d layers to leave out of the
            result. The default of 1 excludes the very first BN layer from pruning.

    Returns:
        list[int]: module indices of the remaining BatchNorm2d layers, in order.
    """
    with open(path, 'r') as f:
        lines = f.readlines()

    idx = 0
    bn_indices = []
    for line in lines:
        head, sep, rest = line.lstrip().partition("): ")
        # Only lines shaped like "(name): Type(...)" introduce a module; closing
        # parens and extra_repr text must not advance the counter.
        if not (sep and head.startswith("(")):
            continue
        idx += 1
        # Match the module type exactly, so FrozenBatchNorm2d and extra_repr text
        # mentioning BatchNorm2d are not treated as BN layers.
        if rest.startswith("BatchNorm2d("):
            bn_indices.append(idx)

    return bn_indices[skip_first:]

# Dump the model's printed structure to a text file, then derive the indices of
# the prunable BatchNorm2d layers from it. Writing inside `with` guarantees the
# file is flushed and closed before it is read back.
model_name = "./model.txt"
with open(model_name, 'w') as model_txt:
    print(model, file=model_txt)
prune_idx = obtain_prune_idx(model_name)

In [10]:
# Number of BatchNorm2d layers selected as prunable.
len(prune_idx)

53

In [12]:
def sort_bn(model, prune_idx):
    """Gather |gamma| of every selected BN layer into one ascending 1-D tensor.

    Args:
        model: network whose ``modules()`` are enumerated.
        prune_idx: indices into ``model.modules()`` of the layers to include.
            Each selected layer must have a 1-D ``weight``, as BatchNorm2d does.

    Returns:
        torch.Tensor: 1-D CPU tensor in the default dtype. It holds the absolute
        weights of all selected layers, sorted ascending, and is empty when
        nothing is selected.
    """
    wanted = set(prune_idx)  # O(1) membership instead of scanning a list per module
    bn_weights = [
        m.weight.data.abs().cpu()
        for idx, m in enumerate(model.modules())
        if idx in wanted
    ]
    if not bn_weights:
        return torch.zeros(0)
    # Always return a CPU tensor in the default dtype, whatever the model's device/dtype.
    flat = torch.cat(bn_weights).to(torch.get_default_dtype())
    return torch.sort(flat)[0]

# Ascending |gamma| values of all prunable BN layers; the threshold is picked from these.
sorted_bn = sort_bn(model, prune_idx)

In [13]:
# Inspect the sorted |gamma| values.
sorted_bn

tensor([4.4228e-10, 1.2192e-09, 1.5593e-09,  ..., 2.3769e-01, 2.5831e-01,
        3.5836e-01])

In [28]:
# Quantile of the sorted |gamma| values used as the pruning threshold
# (0.7 -> the value at position int(0.7 * len(sorted_bn))).
percent = 0.7

In [29]:
def obtain_bn_threshold(model, sorted_bn, percentage):
    """Pick the |gamma| value at the given quantile of the sorted BN weights.

    Args:
        model: unused; kept so existing call sites keep working.
        sorted_bn: 1-D tensor of |gamma| values sorted ascending (see ``sort_bn``).
        percentage: fraction in ``[0, 1]``. Roughly this share of channels falls
            below the returned threshold.

    Returns:
        0-dim tensor ``sorted_bn[int(len(sorted_bn) * percentage)]``. The index is
        clamped to the last element, so ``percentage == 1`` is valid.

    Raises:
        ValueError: if ``sorted_bn`` is empty or ``percentage`` is outside [0, 1].
    """
    n = len(sorted_bn)
    if n == 0:
        raise ValueError("sorted_bn is empty; nothing to threshold")
    if not 0.0 <= percentage <= 1.0:
        # A negative value would otherwise silently wrap around via negative indexing.
        raise ValueError(f"percentage must be in [0, 1], got {percentage}")
    # int(n * 1.0) == n is one past the end; clamp to the largest weight.
    thre_index = min(int(n * percentage), n - 1)
    return sorted_bn[thre_index]

# Threshold on |gamma|: channels below it are candidates for pruning.
threshold = obtain_bn_threshold(model, sorted_bn, percent)
print(threshold)

tensor(0.0003)


In [30]:
# Target device for this pruning session; "cpu" matches the map_location used
# when loading the weights.
device = "cpu"

In [31]:
def obtain_bn_mask(bn_module, thre):
    """Return a 0/1 float mask of the channels whose |gamma| is >= ``thre``.

    Args:
        bn_module: normalization layer with a ``weight`` tensor (e.g. nn.BatchNorm2d).
        thre: threshold, as a Python number or a (0-dim) tensor on any device.

    Returns:
        torch.Tensor: float tensor shaped like ``bn_module.weight`` and on the same
        device. It is 1.0 for kept channels and 0.0 for pruned ones.
    """
    weight = bn_module.weight.data
    # Move the threshold to wherever the weights actually live, instead of relying
    # on a global device flag that may disagree with the model's real placement.
    thre = torch.as_tensor(thre, device=weight.device)
    return weight.abs().ge(thre).float()

def obtain_filters_mask(model, prune_idx, thre):
    """Build per-layer channel masks for the prunable BatchNorm2d layers.

    For every ``nn.BatchNorm2d`` whose index in ``model.modules()`` is in
    ``prune_idx``, channels with ``|gamma| >= thre`` are kept (via
    ``obtain_bn_mask``). If a layer would lose every channel, only the
    channel(s) with the largest ``|gamma|`` are kept instead. A line per prunable
    layer and an overall summary are printed.

    Args:
        model: network whose modules are enumerated.
        prune_idx: indices into ``model.modules()`` of prunable BN layers.
        thre: pruning threshold on ``|gamma|``.

    Returns:
        (pruned_filters, pruned_maskers): for each prunable BN layer, in module
        order, the number of kept channels and its 0/1 float mask as a numpy array.
    """
    prune_set = set(prune_idx)
    pruned = 0  # channels removed across all prunable layers
    total = 0   # channels across *all* BatchNorm2d layers, prunable or not
    pruned_filters = []
    pruned_maskers = []

    for idx, module in enumerate(model.modules()):
        if not isinstance(module, nn.BatchNorm2d):
            continue
        if idx not in prune_set:
            # Non-prunable BN layers still count toward the ratio's denominator.
            total += module.num_features
            continue

        mask = obtain_bn_mask(module, thre).cpu().numpy()
        if mask.sum() == 0:
            # Ensure at least one channel survives: keep the channel(s) with max |gamma|.
            max_value = module.weight.data.abs().max()
            mask = obtain_bn_mask(module, max_value).cpu().numpy()
        # Count removed channels once, after any fallback, so a fully-pruned
        # layer is not double-counted.
        remain = int(mask.sum())
        pruned += mask.shape[0] - remain
        total += mask.shape[0]
        print(f'layer index: {idx:>3d} \t total channel: {mask.shape[0]:>4d} \t '
              f'remaining channel: {remain:>4d}')

        pruned_filters.append(remain)
        pruned_maskers.append(mask.copy())  # detach from the tensor's memory

    prune_ratio = pruned / total if total else 0.0
    print(f'Prune channels: {pruned}\tPrune ratio: {prune_ratio:.3f}')

    return pruned_filters, pruned_maskers

In [32]:
# Per-layer kept-channel counts and 0/1 channel masks at the chosen threshold.
pruned_filters, pruned_maskers = obtain_filters_mask(model, prune_idx, threshold)

layer index:  11 	 total channel:   32 	 remaining channel:   26
layer index:  14 	 total channel:   16 	 remaining channel:   16
layer index:  19 	 total channel:   96 	 remaining channel:   78
layer index:  23 	 total channel:   96 	 remaining channel:   69
layer index:  26 	 total channel:   24 	 remaining channel:   24
layer index:  31 	 total channel:  144 	 remaining channel:   88
layer index:  35 	 total channel:  144 	 remaining channel:   74
layer index:  38 	 total channel:   24 	 remaining channel:   24
layer index:  43 	 total channel:  144 	 remaining channel:   93
layer index:  47 	 total channel:  144 	 remaining channel:   74
layer index:  50 	 total channel:   32 	 remaining channel:   32
layer index:  55 	 total channel:  192 	 remaining channel:   98
layer index:  59 	 total channel:  192 	 remaining channel:   78
layer index:  62 	 total channel:   32 	 remaining channel:   32
layer index:  67 	 total channel:  192 	 remaining channel:   72
layer index:  71 	 total 