In [1]:
import argparse
import numpy as np
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

import models 
from models import channel_selection


In [2]:
# # 从原始的model 来初始化，然后把ckpt里面的参数填进去

# model = models.__dict__['resnet'](dataset='cifar10', depth=164)
# # model.cuda()

# if os.path.isfile('/userhome/34/gyu/logs_sr/3wk_p1/checkpoint_3EPO_sr.pth.tar'):
#     checkpoint = torch.load('/userhome/34/gyu/logs_sr/3wk_p1/checkpoint_3EPO_sr.pth.tar')
#     model.load_state_dict(checkpoint['state_dict'])


In [3]:
# 剪枝过后的 model 来初始化，然后把ckpt里面的参数填进去
checkpoint = torch.load('/userhome/34/gyu/logs_sr/3wk_p1/prune_1st/pruned.pth.tar')

model = models.__dict__['resnet'](dataset='cifar10', depth=164, cfg=checkpoint['cfg'])
model.load_state_dict(checkpoint['state_dict'])                
# model.cuda()



IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [4]:
model

resnet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(11, 12, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(12, 14, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(14, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection(

In [5]:
total = 0
# total ： 总channel数量

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

In [6]:
bn = torch.zeros(total)

index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone() # 的确应该取绝对值
        index += size

In [7]:
bn # 所有bn的weight

tensor([0.0000, 0.4114, 0.0000,  ..., 0.6796, 0.6411, 0.6802])

In [8]:
y, i = torch.sort(bn)
thre_index = int(total * 0.2)
thre = y[thre_index]
# y,i = y.cuda(),i.cuda()
# thre = thre.cuda()

In [9]:
# y: 排序后的bn数组
# i: 对应元素在排序前的index
# thre: 小于阈值的被剪枝

In [10]:
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float() # 大于阈值 mask 对应元素为 1
        pruned = pruned + mask.shape[0] - torch.sum(mask) # pruned += mask中为0的元素个数
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask))) # cfg 为mask中元素等于1的个数
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

layer index: 4 	 total channel: 16 	 remaining channel: 11
layer index: 7 	 total channel: 12 	 remaining channel: 12
layer index: 9 	 total channel: 14 	 remaining channel: 13
layer index: 15 	 total channel: 64 	 remaining channel: 31
layer index: 18 	 total channel: 12 	 remaining channel: 11
layer index: 20 	 total channel: 11 	 remaining channel: 11
layer index: 24 	 total channel: 64 	 remaining channel: 41
layer index: 27 	 total channel: 13 	 remaining channel: 12
layer index: 29 	 total channel: 13 	 remaining channel: 13
layer index: 33 	 total channel: 64 	 remaining channel: 44
layer index: 36 	 total channel: 12 	 remaining channel: 12
layer index: 38 	 total channel: 13 	 remaining channel: 12
layer index: 42 	 total channel: 64 	 remaining channel: 39
layer index: 45 	 total channel: 13 	 remaining channel: 13
layer index: 47 	 total channel: 14 	 remaining channel: 14
layer index: 51 	 total channel: 64 	 remaining channel: 34
layer index: 54 	 total channel: 13 	 remai

In [11]:
# for tmp in range(len(cfg)):
#     print(cfg[tmp],len(cfg_mask[tmp]),int(torch.sum(cfg_mask[tmp])))

In [12]:
len(cfg_mask)

163

In [13]:


pruned_ratio = pruned/total

print('Pre-processing Successful!')
print(pruned_ratio)

Pre-processing Successful!
tensor(0.2001)


In [14]:
# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test(model):
    model.cuda()
    kwargs = {'num_workers': 1, 'pin_memory': True} 
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=64, shuffle=False, **kwargs)
    model.eval()
    correct = 0
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    model.cpu()
    return correct / float(len(test_loader.dataset))


In [15]:
acc = test(model)
print(acc)

  



Test set: Accuracy: 2235/10000 (22.0%)

tensor(0)


In [16]:
# 从这里定义Net的源码 可以发现cfg决定了channel的数量
# 这里通过剪枝得到了剪完以后的cfg 根据这个来init newmodel
print("Cfg:")
print(cfg)
print(len(cfg))
print(len(cfg_mask))

Cfg:
[11, 12, 13, 31, 11, 11, 41, 12, 13, 44, 12, 12, 39, 13, 14, 34, 13, 10, 37, 11, 12, 37, 12, 13, 40, 10, 10, 43, 11, 12, 35, 9, 10, 38, 12, 11, 32, 7, 15, 43, 12, 8, 40, 11, 8, 41, 12, 8, 44, 11, 10, 39, 8, 7, 40, 24, 28, 89, 20, 30, 83, 20, 32, 84, 24, 29, 87, 23, 30, 85, 21, 30, 87, 22, 30, 78, 19, 29, 88, 26, 26, 83, 17, 31, 88, 21, 28, 83, 23, 30, 86, 23, 30, 90, 21, 29, 85, 24, 31, 83, 21, 32, 93, 20, 31, 85, 23, 30, 85, 43, 63, 197, 37, 64, 188, 45, 64, 189, 45, 64, 190, 42, 64, 200, 38, 64, 200, 42, 64, 195, 39, 64, 205, 49, 64, 195, 38, 64, 208, 41, 64, 205, 48, 64, 199, 45, 64, 201, 46, 64, 212, 51, 63, 217, 44, 64, 200, 37, 64, 215, 48, 64, 242]
163
163


In [17]:

newmodel = models.__dict__["resnet"](dataset="cifar10",depth=164,cfg=cfg)
# newmodel = resnet(depth=164, dataset='cifar-10', cfg=cfg)
newmodel

resnet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(11, 12, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(12, 13, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(13, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection(

In [18]:

# num of param in new model
num_parameters = sum([param.nelement() for param in newmodel.parameters()])

old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0


# 把原来model里面的参数 填入new model

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
#     print('start_mask: ',start_mask)
#     print('end_mask: ',end_mask)
    if isinstance(m0, nn.BatchNorm2d):
#         print("current is BN layer")
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1,(1,))
        
        # 当后面一层是channel selection时
        if isinstance(old_modules[layer_id + 1], channel_selection):
#             print("  next layer is CS")
            # If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.
            # 既然后面是 channel selection layer，就不用挑出哪些是prune的
            # 直接原模型所有参数apply到新的上
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()

            # We need to set the channel selection layer.
            # 需要把cs layer的indexs参数改成这里的idx就行
            # 用cs layer 来实现prune的效果
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0

            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            # 从老模型里面挑出保留下来的channel
            # 然后把这些参数apply到新的模型
#             print("  next layer is NOT CS")
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
#         print("current is conv layer")
        if conv_count == 0: # 第一层卷积层不prune
#             print("  current is first conv layer")
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
            # 所有残差模块里面的conv层要么跟在bn后面 要么跟在channel selection后面
            # This convers the convolutions in the residual block.
            # The convolutions are either after the channel selection layer or after the batch normalization layer.
            conv_count += 1
            # start_mask 指的是当前层的前一层的mask
            # end_mask 代表当前层
            
            
            #奇怪的补丁：
            if isinstance(old_modules[layer_id-1], channel_selection):
                old_prev_cs = old_modules[layer_id-1]
                new_prev_cs = new_modules[layer_id-1]
                old_ones = np.squeeze(np.argwhere(old_prev_cs.indexes.data))
                new_ones = np.squeeze(np.argwhere(new_prev_cs.indexes.data))
                start_mask = [0] * len(old_ones)
                for idx in range(len(old_ones)):
                    if old_ones[idx] in new_ones:
                        start_mask[idx] = 1
                    else:
                        start_mask[idx] = 0
                start_mask = torch.FloatTensor(start_mask)
                
#             print('conv_count: ',conv_count)
#             print('layer_id: ',layer_id)
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
#             print(idx0)
#             print(idx1)
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
#             print("idx0 to list: ")
#             print(idx0.tolist())
#             print(m0.weight.data.shape)
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
#             print('w1')
#             print(w1)
            # If the current convolution is not the last convolution in the residual block, then we can change the 
            # number of output channels. Currently we use `conv_count` to detect whether it is such convolution.
            # 这里残差模块里面一共有3个conv层，prune的时候不剪第三个conv层
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            continue

        # We need to consider the case where there are downsampling convolutions. 
        # For these convolutions, we just copy the weights.
        # 对于下采样层也不剪枝
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        
        # 奇怪的补丁
        old_prev_cs = old_modules[layer_id-3]
        new_prev_cs = new_modules[layer_id-3]
        old_ones = np.squeeze(np.argwhere(old_prev_cs.indexes.data))
        new_ones = np.squeeze(np.argwhere(new_prev_cs.indexes.data))
        start_mask = [0] * len(old_ones)
        for idx in range(len(old_ones)):
            if old_ones[idx] in new_ones:
                start_mask[idx] = 1
            else:
                start_mask[idx] = 0
        start_mask = torch.FloatTensor(start_mask)

        
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()



In shape: 11, Out shape 12.
In shape: 12, Out shape 13.
In shape: 13, Out shape 31.
In shape: 31, Out shape 11.
In shape: 11, Out shape 11.
In shape: 11, Out shape 41.
In shape: 41, Out shape 12.
In shape: 12, Out shape 13.
In shape: 13, Out shape 44.
In shape: 44, Out shape 12.
In shape: 12, Out shape 12.
In shape: 12, Out shape 39.
In shape: 39, Out shape 13.
In shape: 13, Out shape 14.
In shape: 14, Out shape 34.
In shape: 34, Out shape 13.
In shape: 13, Out shape 10.
In shape: 10, Out shape 37.
In shape: 37, Out shape 11.
In shape: 11, Out shape 12.
In shape: 12, Out shape 37.
In shape: 37, Out shape 12.
In shape: 12, Out shape 13.
In shape: 13, Out shape 40.
In shape: 40, Out shape 10.
In shape: 10, Out shape 10.
In shape: 10, Out shape 43.
In shape: 43, Out shape 11.
In shape: 11, Out shape 12.
In shape: 12, Out shape 35.
In shape: 35, Out shape 9.
In shape: 9, Out shape 10.
In shape: 10, Out shape 38.
In shape: 38, Out shape 12.
In shape: 12, Out shape 11.
In shape: 11, Out shap

In [19]:
# 不保存
# torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))
print(newmodel)


resnet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(11, 12, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(12, 13, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(13, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection(

  



Test set: Accuracy: 2235/10000 (22.0%)



tensor(0)

In [None]:
model = newmodel
test(model)

In [20]:
np.squeeze(np.argwhere(np.asarray(torch.FloatTensor(start_mask).cpu().data)))

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104,
       105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,
       118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130,
       131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
       144, 145, 146, 147, 149, 150, 151, 152, 153, 154, 155, 156, 157,
       158, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171,
       172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 18

In [None]:
new_cs = new_modules[layer_id-1]

In [None]:
print('len of cs indexes: ',len(np.asarray(new_cs.indexes.data)))
print(np.asarray(new_cs.indexes.data))
print(np.squeeze(np.argwhere(np.asarray(new_cs.indexes.data))))

In [None]:
idx1

In [None]:
len(old_modules)

In [None]:
old_idx =[0,1,1,1,1,1,1]
new_idx =[0,0,1,1,0,0,0]

In [None]:
old_ones = np.squeeze(np.argwhere(old_idx))

In [None]:
new_ones = np.squeeze(np.argwhere(new_idx))

In [None]:
new_ones.shape[0]

In [None]:
old_ones

In [None]:
np.setdiff1d(old_ones,new_ones)

In [None]:
tmp = [1,2,3,4,5]

In [None]:
del tmp,[0,4]

In [None]:
a = [0] * 4

In [None]:
a[2]=1

In [None]:
a

In [None]:
1 in [10,20,30]