In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from src.model import CRNN

## load model

In [2]:
# Build the unpruned CRNN; its `cnn` sub-module is the part analysed and pruned below.
model = CRNN(img_channel=32, img_height=128, img_width=128, num_class=5)

In [3]:
# Display the CNN backbone so the conv / batchnorm layout can be inspected.
model.cnn

Sequential(
  (conv0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu0): ReLU(inplace)
  (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace)
  (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU(inplace)
  (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU(inplace)
  (pooling

In [14]:
# NOTE(review): in the visible notebook `cfg_new` is first assigned in the later
# "start pruning" cell, so this cell would raise NameError on Restart & Run All,
# and its saved output predates the leading image-channel entry - confirm and move/remove.
cfg_new

[33, 66, 139, 145, 257, 232, 247]

## setting parameter

In [4]:
# Pruning ratio: fraction of all BatchNorm channels (ranked by |gamma|) used to
# locate the global threshold in the sorted gamma list.
pruning_rate = 0.5

In [5]:
# Pruning configuration, one entry per sub-network.
#   model:     the module whose layers are pruned
#   skip:      layers that must not be pruned (matched against the index from
#              enumerate(model.cnn.modules()) in the pruning cell)
#   cfg:       number of channels left in each layer after pruning
#   cfg_mask:  positions of the channels that remain after pruning
#   cat_layer: layers that involve a concat
pruning_cfg = {}
pruning_cfg['cnn'] = {
    'model': model.cnn,
    'skip': [],
    'cfg': [],
    'cfg_mask': [],
    'cat_layer': [],
}

## compute threshold

In [6]:
"""計算global threshold"""
# Gather every BatchNorm layer of the CNN once; m.weight is the BN scale (gamma).
bn_layers = [layer for layer in model.cnn.modules() if isinstance(layer, nn.BatchNorm2d)]

# Total number of BN channels across the CNN.
total = sum(layer.weight.data.shape[0] for layer in bn_layers)

# Flatten |gamma| of all layers into a single 1-D tensor of length `total`.
bn = torch.zeros(total)
offset = 0
for layer in bn_layers:
    gammas = layer.weight.data.abs().clone()
    bn[offset:offset + gammas.shape[0]] = gammas
    offset += gammas.shape[0]

# Sort ascending and pick the value at the pruning_rate position as the threshold.
y, i = torch.sort(bn)
thre_index = int(total * pruning_rate)
if thre_index != 0:
    thre = y[thre_index]
else:
    # Index 0 means "keep everything", so use 0 rather than the smallest gamma.
    thre = 0
# Later each layer's |gamma| is compared with `thre` to build a 0/1 mask:
# channels above the threshold are kept, the rest are not copied to the new model.
print('Global threshold: {}'.format(thre))
print('Total channels: {}'.format(total))

Global threshold: 0.50532466173172
Total channels: 2240


## start pruning

(改變) 注意：cfg 的第一個元素必須是第一層的 image channel。<br>
輸入影像的 image_channel 全部保留，不做剪枝。

In [26]:
"""記錄誰該留下誰該剪掉"""
# Decide, for every BatchNorm layer, which channels are kept and which are pruned.
pruned = 0  # number of removed channels, kept as a plain int for exact reporting

# The image channels feeding the first conv are never pruned. Read their count from
# the model instead of hard-coding it, so this cell follows the model definition.
image_channels = next(
    conv.in_channels for conv in model.cnn.modules() if isinstance(conv, nn.Conv2d)
)
cfg_new = [image_channels]  # remaining channels per layer, image channels first
cfg_mask = [torch.ones(image_channels)]  # 0/1 keep-mask per layer, e.g. 3 channels -> [0, 1, 1]

for k, m in enumerate(model.cnn.modules()):
    if isinstance(m, nn.BatchNorm2d):
        # Layers listed in 'skip' use threshold 0, i.e. they are not pruned by `thre`.
        thre_ = 0 if k in pruning_cfg['cnn']['skip'] else thre
        weight_copy = m.weight.data.abs().clone()
        # 1 where |gamma| is above the threshold (keep), 0 otherwise (prune).
        mask = weight_copy.gt(thre_).float()

        # A global threshold can wipe out a whole layer, which would leave a conv
        # with zero channels; always keep the channel with the largest |gamma|.
        if int(torch.sum(mask)) == 0:
            mask[int(torch.argmax(weight_copy))] = 1.0

        remaining = int(torch.sum(mask))
        cfg_new.append(remaining)
        cfg_mask.append(mask.clone())

        pruned += mask.shape[0] - remaining  # accumulate for the pruning ratio
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], remaining))

# Guard against a model without BatchNorm layers (total == 0).
pruned_ratio = pruned / total if total else 0.0
print('-------------------------------------------------------------------------')
print('channels pruned / channels total: {} / {}'.format(pruned, total))
print('pruned ratio: {}'.format(pruned_ratio))

layer index: 2 	 total channel: 64 	 remaining channel: 33
layer index: 6 	 total channel: 128 	 remaining channel: 66
layer index: 10 	 total channel: 256 	 remaining channel: 139
layer index: 13 	 total channel: 256 	 remaining channel: 145
layer index: 17 	 total channel: 512 	 remaining channel: 257
layer index: 20 	 total channel: 512 	 remaining channel: 232
layer index: 24 	 total channel: 512 	 remaining channel: 247
-------------------------------------------------------------------------
channels pruned / channels total: 1121.0 / 2240
pruned ratio: 0.5004464387893677


In [27]:
# Remaining channel count per layer; the first entry is the unpruned image channels.
print(cfg_new)

[32, 33, 66, 139, 145, 257, 232, 247]


In [28]:
# Sanity check: one mask for the image channels plus one per BatchNorm layer.
len(cfg_mask)

8

## save weights to new model

In [15]:
# Pruning configuration, one entry per sub-network.
#   model:     the module whose layers are pruned
#   skip:      layers that must not be pruned
#   cfg:       number of channels left in each layer after pruning
#   cfg_mask:  positions of the channels that remain after pruning
#   cat_layer: layers that involve a concat
pruning_cfg = {
    'cnn':{
        'model': model.cnn,
        'skip': [],
        # Take the channel counts from the pruning cell instead of pasting numbers
        # from an earlier run, so they always agree with the masks in `cfg_mask`.
        'cfg': list(cfg_new),
        # Left empty as before; the weight-transfer cell reads the global `cfg_mask`.
        'cfg_mask': [],
        'cat_layer': []
    }
}

In [16]:
# Define the pruned model architecture from the new per-layer channel counts in cfg.
newmodel = CRNN(img_channel=32, img_height=128, img_width=128, num_class=5, pruning_cfg=pruning_cfg['cnn']['cfg'])

In [24]:
# Display the pruned CNN backbone to confirm the reduced channel counts.
newmodel.cnn

Sequential(
  (conv0): Conv2d(32, 33, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm0): BatchNorm2d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu0): ReLU(inplace)
  (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv1): Conv2d(33, 66, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm1): BatchNorm2d(66, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace)
  (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(66, 139, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm2): BatchNorm2d(139, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU(inplace)
  (conv3): Conv2d(139, 145, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm3): BatchNorm2d(145, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU(inplace)
  (pooling2):

In [42]:
# Flatten both CNNs into parallel module lists (original first, pruned second).
old_modules, new_modules = (list(net.cnn.modules()) for net in (model, newmodel))

# Mask cursor: start_mask selects a conv's input channels, end_mask its output channels.
layer_id_in_cfg = 0
start_mask, end_mask = cfg_mask[layer_id_in_cfg], cfg_mask[layer_id_in_cfg + 1]

(改變) 先做 conv 再做 batchnorm

In [43]:
# Copy the surviving weights of the original CNN into the pruned CNN.
# Each conv is sliced on its input channels (start_mask) and output channels
# (end_mask); the batchnorm that follows is sliced on end_mask, after which the
# mask cursor advances to the next conv/batchnorm pair.

# Reset the mask cursor here so re-running this cell gives the same result.
layer_id_in_cfg = 0
start_mask = cfg_mask[layer_id_in_cfg]
end_mask = cfg_mask[layer_id_in_cfg + 1]
# Last cursor position that still has a following mask (replaces the magic number 6).
last_cfg_index = len(cfg_mask) - 2

assert len(old_modules) == len(new_modules), 'model and newmodel must have the same layer layout'

for m0, m1 in zip(old_modules, new_modules):

    # Conv layers
    if isinstance(m0, nn.Conv2d):
        # flatnonzero always returns a 1-D index array, including for a single kept channel.
        idx0 = np.flatnonzero(start_mask.cpu().numpy())
        idx1 = np.flatnonzero(end_mask.cpu().numpy())
        print('=====================================================')
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone() # keep surviving input channels
        w1 = w1[idx1.tolist(), :, :, :].clone() # keep surviving output channels
        m1.weight.data = w1.clone() # store the pruned weights
        # The conv bias has one value per output channel and must be pruned too;
        # otherwise the new conv keeps its random initial bias.
        if m0.bias is not None and m1.bias is not None:
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()

    # BatchNorm layers
    elif isinstance(m0, nn.BatchNorm2d):
        idx1 = np.flatnonzero(end_mask.cpu().numpy())

        # Store the pruned affine parameters and running statistics.
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()

        # After the last batchnorm there is no further mask; advancing would index
        # past the end of cfg_mask.
        if layer_id_in_cfg < last_cfg_index:
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            end_mask = cfg_mask[layer_id_in_cfg + 1]

In shape: 32, Out shape 33.
In shape: 33, Out shape 66.
In shape: 66, Out shape 139.
In shape: 139, Out shape 145.
In shape: 145, Out shape 257.
In shape: 257, Out shape 232.
In shape: 232, Out shape 247.


In [61]:
# Pruned model: print the weight shape of every conv and batchnorm layer.
# Conv weight shape order: output channels, input channels, kernel size.
for key, tensor in newmodel.cnn.state_dict().items():
    if 'weight' not in key:
        continue
    if 'conv' in key:
        print(("================= {} =================").format(key.split('.')[0]))
        print('Conv shape: {}'.format(tensor.shape))
    if 'batchnorm' in key:
        print('Batch shape: {}'.format(tensor.shape))

Conv shape: torch.Size([33, 32, 3, 3])
Batch shape: torch.Size([33])
Conv shape: torch.Size([66, 33, 3, 3])
Batch shape: torch.Size([66])
Conv shape: torch.Size([139, 66, 3, 3])
Batch shape: torch.Size([139])
Conv shape: torch.Size([145, 139, 3, 3])
Batch shape: torch.Size([145])
Conv shape: torch.Size([257, 145, 3, 3])
Batch shape: torch.Size([257])
Conv shape: torch.Size([232, 257, 3, 3])
Batch shape: torch.Size([232])
Conv shape: torch.Size([247, 232, 2, 2])
Batch shape: torch.Size([247])


In [62]:
# Original (unpruned) model: print the weight shape of every conv and batchnorm layer.
# Conv weight shape order: output channels, input channels, kernel size.
for key, tensor in model.cnn.state_dict().items():
    if 'weight' not in key:
        continue
    if 'conv' in key:
        print(("================= {} =================").format(key.split('.')[0]))
        print('Conv shape: {}'.format(tensor.shape))
    if 'batchnorm' in key:
        print('Batch shape: {}'.format(tensor.shape))

Conv shape: torch.Size([64, 32, 3, 3])
Batch shape: torch.Size([64])
Conv shape: torch.Size([128, 64, 3, 3])
Batch shape: torch.Size([128])
Conv shape: torch.Size([256, 128, 3, 3])
Batch shape: torch.Size([256])
Conv shape: torch.Size([256, 256, 3, 3])
Batch shape: torch.Size([256])
Conv shape: torch.Size([512, 256, 3, 3])
Batch shape: torch.Size([512])
Conv shape: torch.Size([512, 512, 3, 3])
Batch shape: torch.Size([512])
Conv shape: torch.Size([512, 512, 2, 2])
Batch shape: torch.Size([512])
