In [1]:
import argparse
import numpy as np
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

import models 
from models import channel_selection


In [2]:
# # 从原始的model 来初始化，然后把ckpt里面的参数填进去

# model = models.__dict__['resnet'](dataset='cifar10', depth=164)
# # model.cuda()

# if os.path.isfile('/userhome/34/gyu/logs_sr/3wk_p1/checkpoint_3EPO_sr.pth.tar'):
#     checkpoint = torch.load('/userhome/34/gyu/logs_sr/3wk_p1/checkpoint_3EPO_sr.pth.tar')
#     model.load_state_dict(checkpoint['state_dict'])


In [3]:
# 剪枝过后的 model 来初始化，然后把ckpt里面的参数填进去
checkpoint = torch.load('/userhome/34/gyu/logs_sr/3wk_p1/prune_1st/pruned.pth.tar')

model = models.__dict__['resnet'](dataset='cifar10', depth=164, cfg=checkpoint['cfg'])
model.load_state_dict(checkpoint['state_dict'])                
# model.cuda()



IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [4]:
model

resnet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(11, 12, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(12, 14, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(14, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection(

In [5]:
def count_conv_channels(model):
    sum_channel = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            sum_channel += m.in_channels
#     print(sum_channel)
    return sum_channel

In [6]:
print("observe from the conv angle :::")
print("  original model conv channels: ",count_conv_channels(models.__dict__['resnet'](dataset='cifar10', depth=164)))
print("  after prune 20%: ",count_conv_channels(model))
print("  9655/12067 = 0.8001")
print("observe from the cfg angle :::")
n = (164 - 2) // 9
orig_cfg = [[16, 16, 16], [64, 16, 16]*(n-1), [64, 32, 32], [128, 32, 32]*(n-1), [128, 64, 64], [256, 64, 64]*(n-1), [256]]
orig_cfg = [item for sub_list in orig_cfg for item in sub_list]
print("  oringinal cfg sum: ", sum(orig_cfg))
print("  after prune 20%: ", sum(checkpoint['cfg']))
print("  9689 / 12112 =  0.79995")

observe from the conv angle :::
  original model conv channels:  12067
  after prune 20%:  9655
  9655/12067 = 0.8001
observe from the cfg angle :::
  oringinal cfg sum:  12112
  after prune 20%:  9689
  9689 / 12112 =  0.79995


In [7]:
# total = 0
# # total ： 当前总channel数量

# for m in model.modules():
#     if isinstance(m, nn.BatchNorm2d):
#         total += m.weight.data.shape[0]
# print(total)

In [9]:
# update: 这里应该判定一下，当bn后一层是cs层的时候，这里不能用bn的shape，而是应该用它后一层的cs层里面select了多少来算
total = 0
module_list = list(model.modules())
for i in range(len(module_list)):
    m = module_list[i]
    if (i + 1< len(module_list)) and isinstance(m, nn.BatchNorm2d) and isinstance(module_list[i+1],channel_selection):
        cs = module_list[i+1]
        total += int(torch.sum(cs.indexes))
    elif isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
print(total)

9689


In [10]:
# total = 11452
# bn = torch.zeros(total)

# index = 0
# for m in model.modules():
#     if isinstance(m, nn.BatchNorm2d):
#         size = m.weight.data.shape[0]
#         bn[index:(index+size)] = m.weight.data.abs().clone() # 的确应该取绝对值
#         index += size
# print(index)

In [11]:
bn = torch.zeros(total)

index = 0
for i in range(len(module_list)):
    m = module_list[i]
    if i+1 < len(module_list) and isinstance(m, nn.BatchNorm2d) and isinstance(module_list[i+1],channel_selection):
        cs = module_list[i+1]
        size = int(torch.sum(cs.indexes))
        mask = cs.indexes.clone().detach().numpy()
        bn[index:(index+size)] = torch.from_numpy(m.weight.data.abs().clone().numpy()[np.where(mask)])
        index += size
    elif isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone() # 的确应该取绝对值
        index += size

In [12]:
bn # 所有bn的weight

tensor([0.4114, 0.7914, 0.5157,  ..., 0.6796, 0.6411, 0.6802])

In [14]:
y, i = torch.sort(bn)
thre_index = int(total * 0.2)
thre = y[thre_index]
# y,i = y.cuda(),i.cuda()
# thre = thre.cuda()
print(y)
print(thre)

tensor([0.3462, 0.3462, 0.3462,  ..., 1.3288, 1.4001, 1.5130])
tensor(0.3676)


In [15]:
# y: 排序后的bn数组
# i: 对应元素在排序前的index
# thre: 小于阈值的被剪枝

In [16]:
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float() # 大于阈值 mask 对应元素为 1
        pruned = pruned + mask.shape[0] - torch.sum(mask) # pruned += mask中为0的元素个数
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask))) # cfg 为mask中元素等于1的个数
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

layer index: 4 	 total channel: 16 	 remaining channel: 10
layer index: 7 	 total channel: 12 	 remaining channel: 11
layer index: 9 	 total channel: 14 	 remaining channel: 12
layer index: 15 	 total channel: 64 	 remaining channel: 30
layer index: 18 	 total channel: 12 	 remaining channel: 11
layer index: 20 	 total channel: 11 	 remaining channel: 11
layer index: 24 	 total channel: 64 	 remaining channel: 32
layer index: 27 	 total channel: 13 	 remaining channel: 11
layer index: 29 	 total channel: 13 	 remaining channel: 13
layer index: 33 	 total channel: 64 	 remaining channel: 38
layer index: 36 	 total channel: 12 	 remaining channel: 12
layer index: 38 	 total channel: 13 	 remaining channel: 11
layer index: 42 	 total channel: 64 	 remaining channel: 37
layer index: 45 	 total channel: 13 	 remaining channel: 13
layer index: 47 	 total channel: 14 	 remaining channel: 12
layer index: 51 	 total channel: 64 	 remaining channel: 32
layer index: 54 	 total channel: 13 	 remai

In [17]:
# for tmp in range(len(cfg)):
#     print(cfg[tmp],len(cfg_mask[tmp]),int(torch.sum(cfg_mask[tmp])))

In [18]:
len(cfg_mask)

163

In [19]:


pruned_ratio = pruned/total

print('Pre-processing Successful!')
print(pruned_ratio)

Pre-processing Successful!
tensor(0.3820)


In [None]:
# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test(model):
    model.cuda()
    kwargs = {'num_workers': 1, 'pin_memory': True} 
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=64, shuffle=False, **kwargs)
    model.eval()
    correct = 0
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    model.cpu()
    return correct / float(len(test_loader.dataset))


In [None]:
acc = test(model)
print(acc)

In [22]:
# 从这里定义Net的源码 可以发现cfg决定了channel的数量
# 这里通过剪枝得到了剪完以后的cfg 根据这个来init newmodel
print("Cfg:")
print(cfg)
print(len(cfg))
print(len(cfg_mask))

Cfg:
[10, 11, 12, 30, 11, 11, 32, 11, 13, 38, 12, 11, 37, 13, 12, 32, 13, 9, 34, 11, 12, 34, 11, 13, 35, 10, 9, 36, 11, 12, 32, 9, 9, 34, 12, 10, 30, 7, 15, 38, 11, 7, 33, 11, 8, 36, 11, 8, 43, 11, 9, 34, 7, 7, 31, 22, 25, 74, 17, 28, 69, 20, 32, 72, 22, 29, 74, 23, 28, 72, 17, 30, 81, 16, 29, 67, 15, 29, 71, 21, 25, 67, 16, 29, 72, 19, 24, 67, 18, 29, 76, 21, 28, 75, 20, 26, 74, 22, 27, 71, 19, 31, 80, 17, 28, 67, 20, 29, 66, 33, 63, 147, 32, 64, 133, 34, 64, 152, 39, 64, 147, 34, 64, 163, 35, 64, 150, 36, 64, 145, 36, 63, 153, 33, 64, 151, 32, 64, 149, 33, 64, 153, 38, 64, 154, 37, 63, 159, 34, 63, 157, 42, 63, 167, 32, 64, 154, 28, 64, 155, 34, 64, 238]
163
163


In [23]:

newmodel = models.__dict__["resnet"](dataset="cifar10",depth=164,cfg=cfg)
# newmodel = resnet(depth=164, dataset='cifar-10', cfg=cfg)
newmodel

resnet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(10, 11, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(11, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(12, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection(

In [24]:

# num of param in new model
num_parameters = sum([param.nelement() for param in newmodel.parameters()])

old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0


# 把原来model里面的参数 填入new model

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
#     print('start_mask: ',start_mask)
#     print('end_mask: ',end_mask)
    if isinstance(m0, nn.BatchNorm2d):
#         print("current is BN layer")
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1,(1,))
        
        # 当后面一层是channel selection时
        if isinstance(old_modules[layer_id + 1], channel_selection):
#             print("  next layer is CS")
            # If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.
            # 既然后面是 channel selection layer，就不用挑出哪些是prune的
            # 直接原模型所有参数apply到新的上
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()

            # We need to set the channel selection layer.
            # 需要把cs layer的indexs参数改成这里的idx就行
            # 用cs layer 来实现prune的效果
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0

            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            # 从老模型里面挑出保留下来的channel
            # 然后把这些参数apply到新的模型
#             print("  next layer is NOT CS")
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
#         print("current is conv layer")
        if conv_count == 0: # 第一层卷积层不prune
#             print("  current is first conv layer")
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
            # 所有残差模块里面的conv层要么跟在bn后面 要么跟在channel selection后面
            # This convers the convolutions in the residual block.
            # The convolutions are either after the channel selection layer or after the batch normalization layer.
            conv_count += 1
            # start_mask 指的是当前层的前一层的mask
            # end_mask 代表当前层
            
            
            #奇怪的补丁：
            if isinstance(old_modules[layer_id-1], channel_selection):
                old_prev_cs = old_modules[layer_id-1]
                new_prev_cs = new_modules[layer_id-1]
                old_ones = np.squeeze(np.argwhere(old_prev_cs.indexes.data))
                new_ones = np.squeeze(np.argwhere(new_prev_cs.indexes.data))
                start_mask = [0] * len(old_ones)
                for idx in range(len(old_ones)):
                    if old_ones[idx] in new_ones:
                        start_mask[idx] = 1
                    else:
                        start_mask[idx] = 0
                start_mask = torch.FloatTensor(start_mask)
                
#             print('conv_count: ',conv_count)
#             print('layer_id: ',layer_id)
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
#             print(idx0)
#             print(idx1)
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
#             print("idx0 to list: ")
#             print(idx0.tolist())
#             print(m0.weight.data.shape)
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
#             print('w1')
#             print(w1)
            # If the current convolution is not the last convolution in the residual block, then we can change the 
            # number of output channels. Currently we use `conv_count` to detect whether it is such convolution.
            # 这里残差模块里面一共有3个conv层，prune的时候不剪第三个conv层
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            continue

        # We need to consider the case where there are downsampling convolutions. 
        # For these convolutions, we just copy the weights.
        # 对于下采样层也不剪枝
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        
        # 奇怪的补丁
        old_prev_cs = old_modules[layer_id-3]
        new_prev_cs = new_modules[layer_id-3]
        old_ones = np.squeeze(np.argwhere(old_prev_cs.indexes.data))
        new_ones = np.squeeze(np.argwhere(new_prev_cs.indexes.data))
        start_mask = [0] * len(old_ones)
        for idx in range(len(old_ones)):
            if old_ones[idx] in new_ones:
                start_mask[idx] = 1
            else:
                start_mask[idx] = 0
        start_mask = torch.FloatTensor(start_mask)

        
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()



In shape: 10, Out shape 11.
In shape: 11, Out shape 12.
In shape: 12, Out shape 30.
In shape: 30, Out shape 11.
In shape: 11, Out shape 11.
In shape: 11, Out shape 32.
In shape: 32, Out shape 11.
In shape: 11, Out shape 13.
In shape: 13, Out shape 38.
In shape: 38, Out shape 12.
In shape: 12, Out shape 11.
In shape: 11, Out shape 37.
In shape: 37, Out shape 13.
In shape: 13, Out shape 12.
In shape: 12, Out shape 32.
In shape: 32, Out shape 13.
In shape: 13, Out shape 9.
In shape: 9, Out shape 34.
In shape: 34, Out shape 11.
In shape: 11, Out shape 12.
In shape: 12, Out shape 34.
In shape: 34, Out shape 11.
In shape: 11, Out shape 13.
In shape: 13, Out shape 35.
In shape: 35, Out shape 10.
In shape: 10, Out shape 9.
In shape: 9, Out shape 36.
In shape: 36, Out shape 11.
In shape: 11, Out shape 12.
In shape: 12, Out shape 32.
In shape: 32, Out shape 9.
In shape: 9, Out shape 9.
In shape: 9, Out shape 34.
In shape: 34, Out shape 12.
In shape: 12, Out shape 10.
In shape: 10, Out shape 30.


In [25]:
# 不保存
# torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))
print(newmodel)


resnet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(10, 11, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(11, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(12, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection(

In [31]:
print(count_conv_channels(model))
print(count_conv_channels(newmodel))
print('7724 / 9655 = 0.8')

9655
7724
7724 / 9655 = 0.8


0.8

In [None]:
model = newmodel
test(model)

In [None]:
np.squeeze(np.argwhere(np.asarray(torch.FloatTensor(start_mask).cpu().data)))

In [None]:
new_cs = new_modules[layer_id-1]

In [None]:
print('len of cs indexes: ',len(np.asarray(new_cs.indexes.data)))
print(np.asarray(new_cs.indexes.data))
print(np.squeeze(np.argwhere(np.asarray(new_cs.indexes.data))))

In [None]:
idx1

In [None]:
len(old_modules)

In [None]:
old_idx =[0,1,1,1,1,1,1]
new_idx =[0,0,1,1,0,0,0]

In [None]:
old_ones = np.squeeze(np.argwhere(old_idx))

In [None]:
new_ones = np.squeeze(np.argwhere(new_idx))

In [None]:
new_ones.shape[0]

In [None]:
old_ones

In [None]:
np.setdiff1d(old_ones,new_ones)

In [None]:
tmp = [1,2,3,4,5]

In [None]:
del tmp,[0,4]

In [None]:
a = [0] * 4

In [None]:
a[2]=1

In [None]:
a

In [None]:
1 in [10,20,30]