In [40]:
import os
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy

In [41]:
from torchvision import models

In [42]:
from models.seresnet.FastPose import createModel

In [43]:
from src.opt import opt

In [44]:
# Checkpoint whose BatchNorm2d scale factors are analysed for channel pruning.
# NOTE(review): absolute, machine-specific path — the notebook cannot run elsewhere without editing it.
weights = "/media/hkuit164/MB155_1/sparsed_demo/origin_5E-4-acc/origin_5E-4_best_acc.pkl"

In [45]:
# Build the network with cfg=None and keep it on the CPU.
# NOTE(review): presumably cfg=None selects the default (unpruned) channel widths — confirm in createModel.
model =  createModel(cfg=None).cpu()

In [46]:
# Set `kps` (presumably the keypoint count) on the shared options object.
# NOTE(review): this runs after createModel(...) above; if createModel reads opt.kps it must be set first — confirm.
opt.kps = 17

In [47]:
# Read the checkpoint with all tensors mapped to the CPU and copy it into the model.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(weights, map_location="cpu"))

<All keys matched successfully>

In [48]:
def obtain_prune_idx(path):
    """Collect the positions of the BatchNorm2d entries found in a model text dump.

    The file at ``path`` is scanned line by line. A counter is advanced on every
    line that contains ``"):"`` and the counter value is recorded for every line
    that mentions ``"BatchNorm2d"``. (The result is used elsewhere in this
    notebook as positions in ``enumerate(model.modules())``.)

    Args:
        path: Path of the text file holding the printed model.

    Returns:
        List of recorded counter values with the first entry removed.
    """
    with open(path, 'r') as dump:
        text_lines = dump.readlines()

    module_pos = 0
    bn_positions = []
    for text in text_lines:
        module_pos += 1 if "):" in text else 0
        if "BatchNorm2d" in text:
            bn_positions.append(module_pos)

    # Drop the first entry: the first bn1 layer is not pruned.
    return bn_positions[1:]

# Dump the model's printed structure to a text file, then parse that file for the
# positions of the BatchNorm2d layers to prune.
model_name = "./model.txt"
# Write inside a `with` block so the file is flushed and closed before it is read back.
with open(model_name, 'w') as model_file:
    print(model, file=model_file)
prune_idx = obtain_prune_idx(model_name)

In [49]:
# Number of BatchNorm2d layers selected for pruning.
len(prune_idx)

105

In [50]:
def sort_bn(model, prune_idx):
    """Gather the absolute weights of the selected modules and sort them ascending.

    Args:
        model: Module whose ``modules()`` traversal is enumerated.
        prune_idx: Positions (in ``enumerate(model.modules())``) of the modules
            whose ``weight`` values are collected.

    Returns:
        1-D tensor holding every selected ``|weight|`` value in ascending order.
    """
    selected = [layer for pos, layer in enumerate(model.modules()) if pos in prune_idx]
    widths = [layer.weight.data.shape[0] for layer in selected]

    # Pre-allocate one flat buffer and fill it layer by layer.
    gammas = torch.zeros(sum(widths))
    offset = 0
    for layer, width in zip(selected, widths):
        gammas[offset:offset + width] = layer.weight.data.abs()
        offset += width

    ordered, _ = torch.sort(gammas)
    return ordered

# Absolute BatchNorm2d weights of all layers in prune_idx, pooled and sorted ascending.
sorted_bn = sort_bn(model, prune_idx)

In [51]:
# Inspect the sorted values.
sorted_bn

tensor([6.0189e-10, 2.9141e-09, 3.3896e-09,  ..., 9.9507e-01, 1.0077e+00,
        1.0361e+00])

In [52]:
def obtain_bn_mask(bn_module, thre):
    """Return a 0/1 float mask of the channels whose ``|weight|`` reaches ``thre``.

    Args:
        bn_module: Module exposing a ``weight`` tensor (a BatchNorm2d layer in
            this notebook).
        thre: Threshold, given as a tensor or a plain number.

    Returns:
        Float tensor with the same shape as ``bn_module.weight`` holding 1.0
        where ``|weight| >= thre`` and 0.0 elsewhere.
    """
    scales = bn_module.weight.data
    # Place the threshold on the same device as the weights instead of consulting
    # a module-level ``device`` name: the helper then works on a fresh kernel and
    # for any device placement of the module.
    thre = torch.as_tensor(thre, device=scales.device)
    return scales.abs().ge(thre).float()

def obtain_filters_mask(model, prune_idx, thre):
    """Compute the kept-channel count and mask of every BatchNorm2d layer to prune.

    Every ``nn.BatchNorm2d`` found in ``model.modules()`` contributes its channel
    count to the total. Layers whose position is in ``prune_idx`` are masked with
    ``obtain_bn_mask`` (defined elsewhere in this notebook); if the threshold
    would remove every channel of a layer, the channel(s) with the largest
    ``|weight|`` are kept instead. One summary line is printed per pruned layer
    and a final line reports the overall pruned count and ratio.

    Args:
        model: Module whose ``modules()`` traversal is enumerated.
        prune_idx: Positions (in ``enumerate(model.modules())``) of the
            BatchNorm2d layers to prune.
        thre: Threshold forwarded to ``obtain_bn_mask``.

    Returns:
        Tuple ``(pruned_filters, pruned_maskers)``: the number of kept channels
        and the 0/1 numpy mask for each pruned layer, in traversal order.
    """
    pruned = 0
    total = 0
    pruned_filters = []
    pruned_maskers = []

    for idx, module in enumerate(model.modules()):
        if not isinstance(module, nn.BatchNorm2d):
            continue

        # Every BatchNorm2d counts toward the total, whether it is pruned or not.
        total += module.weight.data.shape[0]
        if idx not in prune_idx:
            continue

        mask = obtain_bn_mask(module, thre).cpu().numpy()
        remain = int(mask.sum())

        if remain == 0:
            # Keep at least one channel: fall back to the largest |weight|.
            max_value = module.weight.data.abs().max()
            mask = obtain_bn_mask(module, max_value).cpu().numpy()
            remain = int(mask.sum())

        # Count removed channels once, after the fallback, so a layer that hit
        # the fallback is not counted twice.
        pruned += mask.shape[0] - remain

        print(f'layer index: {idx:>3d} \t total channel: {mask.shape[0]:>4d} \t '
              f'remaining channel: {remain:>4d}')

        pruned_filters.append(remain)
        pruned_maskers.append(mask.copy())

    # Guard against models without any BatchNorm2d layer.
    prune_ratio = pruned / total if total else 0.0
    print(f'Prune channels: {pruned}\tPrune ratio: {prune_ratio:.3f}')

    return pruned_filters, pruned_maskers

In [53]:
# Fractional position in the sorted BatchNorm2d weights at which the pruning threshold is read.
percent = 0.5

In [54]:
def obtain_bn_threshold(model, sorted_bn, percentage):
    """Pick the pruning threshold at a fractional position of the sorted values.

    Args:
        model: Unused; kept so existing calls keep working.
        sorted_bn: Non-empty indexable sequence (a 1-D tensor in this notebook)
            of values sorted ascending.
        percentage: Fraction in ``[0, 1]`` locating the threshold.

    Returns:
        ``sorted_bn[int(len(sorted_bn) * percentage)]``; the index is clamped to
        the last element so ``percentage == 1.0`` returns the maximum instead of
        indexing past the end.

    Raises:
        ValueError: If ``sorted_bn`` is empty or ``percentage`` is outside ``[0, 1]``.
    """
    count = len(sorted_bn)
    if count == 0:
        raise ValueError("sorted_bn is empty; there is nothing to derive a threshold from")
    if not 0 <= percentage <= 1:
        # A negative fraction would silently index from the end of the tensor.
        raise ValueError(f"percentage must be within [0, 1], got {percentage}")

    thre_index = min(int(count * percentage), count - 1)
    return sorted_bn[thre_index]

# Threshold read at the `percent` position of the sorted values, printed for inspection.
threshold = obtain_bn_threshold(model, sorted_bn, percent)
print(threshold)

tensor(0.1801)


In [55]:
# Device string for the run ("cpu" here).
# NOTE(review): this module-level name is defined only after the helper definitions above; any helper
# that reads it needs this cell run first on a fresh kernel.
device = "cpu"

In [56]:
# Kept-channel count and 0/1 mask for every BatchNorm2d layer in prune_idx at this threshold.
pruned_filters, pruned_maskers = obtain_filters_mask(model, prune_idx, threshold)

layer index:   9 	 total channel:   64 	 remaining channel:   33
layer index:  11 	 total channel:   64 	 remaining channel:   27
layer index:  13 	 total channel:  256 	 remaining channel:  104
layer index:  23 	 total channel:  256 	 remaining channel:  175
layer index:  26 	 total channel:   64 	 remaining channel:    6
layer index:  28 	 total channel:   64 	 remaining channel:   15
layer index:  30 	 total channel:  256 	 remaining channel:   27
layer index:  33 	 total channel:   64 	 remaining channel:   40
layer index:  35 	 total channel:   64 	 remaining channel:   46
layer index:  37 	 total channel:  256 	 remaining channel:   52
layer index:  41 	 total channel:  128 	 remaining channel:   89
layer index:  43 	 total channel:  128 	 remaining channel:  106
layer index:  45 	 total channel:  512 	 remaining channel:  217
layer index:  55 	 total channel:  512 	 remaining channel:  187
layer index:  58 	 total channel:  128 	 remaining channel:    9
layer index:  60 	 total 

In [57]:
# Save the per-layer kept-channel counts; "ceiling.txt" is later passed to createModel as cfg.
# The `with` block guarantees the data is flushed and the handle closed before that read.
with open("ceiling.txt", "w") as ceiling_file:
    print(pruned_filters, file=ceiling_file)

In [60]:
# NOTE(review): model.modules is a bound method, so this prints the method's repr
# (which embeds the model repr); print(model) would show the model directly.
print(model.modules)

<bound method Module.modules of FastPose(
  (preact): SeResnet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): SeBottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (se): SELayer(
          (avg_pool

In [59]:
# NOTE(review): this cell fails with AttributeError — the FastPose model has no `module_defs`.
# Presumably a leftover from a different, cfg-driven pruning flow — confirm and remove.
compact_module_defs = deepcopy(model.module_defs)

AttributeError: 'FastPose' object has no attribute 'module_defs'

In [None]:
# Write each filter count into the matching 'convolutional' definition.
# NOTE(review): never executed. CBL_idx and a top-level num_filters are not defined in any
# visible cell, and compact_module_defs is never successfully created — confirm and remove.
for idx, num in zip(CBL_idx, num_filters):
    assert compact_module_defs[idx]['type'] == 'convolutional'
    compact_module_defs[idx]['filters'] = str(num)

In [None]:
# Rebuild the network from the channel counts saved in ceiling.txt.
# NOTE(review): unexecuted duplicate of the following cell.
new_model = createModel(cfg="ceiling.txt").cpu()

In [32]:
# Rebuild the network from the channel counts saved in ceiling.txt and keep it on the CPU.
# NOTE(review): execution count 32 predates the cells above — re-run in order on a fresh kernel.
new_model = createModel(cfg="ceiling.txt").cpu()

In [33]:
# Show the rebuilt architecture to check the reduced channel widths.
new_model

FastPose(
  (preact): SeResnet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): SeBottleneck(
        (conv1): Conv2d(64, 33, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(33, 27, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(27, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (se): SELayer(
          (avg_pool): AdaptiveAvgPool2d(output_size