In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES']='2'
import pprint

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchsummary import summary

import _init_paths
from core.config import config
from core.config import update_config
from core.config import update_dir
from core.loss import JointsMSELoss
from core.function import validate
from utils.utils import create_logger

import dataset
import models

import quantize_dorefa
from quantize_iao import *
# from quantize_iao_uint import *  #对feature map进行uint对称量化

import numpy as np
# 保证所有数据能够显示，而不是用省略号表示，np.inf表示一个足够大的数
np.set_printoptions(threshold = np.inf) 

# # 若想不以科学计数显示:
# np.set_printoptions(suppress = True)

In [2]:
def bn_fuse_conv(bn_conv,device):
    # ******************** bn参数 *********************
    mean = bn_conv.running_mean
    std = torch.sqrt(bn_conv.running_var + bn_conv.eps)
    gamma = bn_conv.gamma
    beta = bn_conv.beta
    # ******************* conv参数 ********************
    w = bn_conv.weight
    w_fused = w.clone()
    if bn_conv.bias is not None:
        b = bn_conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    b_fused = b.clone()
    # ******************* bn融合 *******************
    w_fused = w * (gamma / std).reshape([bn_conv.out_channels, 1, 1, 1])
    b_fused = beta + (b - mean) * (gamma / std)
    bn_fused_conv = QuantConv2d(bn_conv.in_channels,
                                         bn_conv.out_channels,
                                         bn_conv.kernel_size,
                                         stride=bn_conv.stride,
                                         padding=bn_conv.padding,
                                         dilation=bn_conv.dilation,
                                         groups=bn_conv.groups,
                                         bias=True,
                                         padding_mode=bn_conv.padding_mode,
                                         a_bits=config.QUANTIZATION.A_BITS,
                                         w_bits=config.QUANTIZATION.W_BITS,
                                         q_type=config.QUANTIZATION.Q_TYPE,
                                         q_level=config.QUANTIZATION.Q_LEVEL,
                                         device=device,
                                         quant_inference=True)
    bn_fused_conv.weight.data = w_fused
    bn_fused_conv.bias.data = b_fused
    bn_fused_conv.activation_quantizer.scale.copy_(bn_conv.activation_quantizer.scale)
    bn_fused_conv.activation_quantizer.zero_point.copy_(bn_conv.activation_quantizer.zero_point)
    bn_fused_conv.activation_quantizer.eps = bn_conv.activation_quantizer.eps
    bn_fused_conv.weight_quantizer.scale.copy_(bn_conv.weight_quantizer.scale)
    bn_fused_conv.weight_quantizer.zero_point.copy_(bn_conv.weight_quantizer.zero_point)
    bn_fused_conv.weight_quantizer.eps = bn_conv.weight_quantizer.eps
    return bn_fused_conv

def bn_fuse_deconv(bn_conv,device):
    # ******************** bn参数 *********************
    mean = bn_conv.running_mean
    std = torch.sqrt(bn_conv.running_var + bn_conv.eps)
    gamma = bn_conv.gamma
    beta = bn_conv.beta
    # ******************* conv参数 ********************
    w = bn_conv.weight
    w_fused = w.clone()
    if bn_conv.bias is not None:
        b = bn_conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    b_fused = b.clone()
    # ******************* bn融合 *******************
    w_fused = w * (gamma / std).reshape([bn_conv.out_channels, 1, 1, 1])
    b_fused = beta + (b - mean) * (gamma / std)
    bn_fused_conv = QuantConvTranspose2d(bn_conv.in_channels,
                                         bn_conv.out_channels,
                                         bn_conv.kernel_size,
                                         stride=bn_conv.stride,
                                         padding=bn_conv.padding,
                                         output_padding=bn_conv.output_padding,
                                         dilation=bn_conv.dilation,
                                         groups=bn_conv.groups,
                                         bias=True,
                                         padding_mode=bn_conv.padding_mode,
                                         a_bits=config.QUANTIZATION.A_BITS,
                                         w_bits=config.QUANTIZATION.W_BITS,
                                         q_type=config.QUANTIZATION.Q_TYPE,
                                         q_level=config.QUANTIZATION.Q_LEVEL,
                                         device=device,
                                         quant_inference=True)
    bn_fused_conv.weight.data = w_fused
    bn_fused_conv.bias.data = b_fused
    bn_fused_conv.activation_quantizer.scale.copy_(bn_conv.activation_quantizer.scale)
    bn_fused_conv.activation_quantizer.zero_point.copy_(bn_conv.activation_quantizer.zero_point)
    bn_fused_conv.activation_quantizer.eps = bn_conv.activation_quantizer.eps
    bn_fused_conv.weight_quantizer.scale.copy_(bn_conv.weight_quantizer.scale)
    bn_fused_conv.weight_quantizer.zero_point.copy_(bn_conv.weight_quantizer.zero_point)
    bn_fused_conv.weight_quantizer.eps = bn_conv.weight_quantizer.eps
    return bn_fused_conv

In [3]:
def bn_fuse_module(module, device):
    for name, child in module.named_children():
        if isinstance(child, QuantBNFuseConv2d):
            bn_fused_conv = bn_fuse_conv(child, device)
            module._modules[name] = bn_fused_conv
        elif isinstance(child, QuantBNFuseConvTranspose2d):
            bn_fused_deconv = bn_fuse_deconv(child, device)
            module._modules[name] = bn_fused_deconv
        else:
            bn_fuse_module(child, device)


def model_bn_fuse(model, inplace=False):
    if not inplace:
        model = copy.deepcopy(model)
    device = next(model.parameters()).device
    bn_fuse_module(model,device)
    return model

In [4]:
def select_device(device='', apex=False, batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # if device requested other than 'cpu'
        # os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availablity

    cuda = False if cpu_request else torch.cuda.is_available()
    if cuda:
        c = 1024 ** 2  # bytes to MB
        ng = torch.cuda.device_count()
        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
        s = 'Using CUDA ' + ('Apex ' if apex else '')  # apex for mixed precision https://github.com/NVIDIA/apex
        for i in range(0, ng):
            if i == 1:
                s = ' ' * len(s)
            print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
                  (s, i, x[i].name, x[i].total_memory / c))
    else:
        print('Using CPU')

    print('')  # skip a line
    return torch.device('cuda:0' if cuda else 'cpu')

In [5]:
cfg='../experiments/coco/resnet50/mobile_quant_relu_int.yaml' #MODEL_FILE: 'output/weights_quan/float_mobilenetpose_nobn.pt'
update_config(cfg)
# cudnn related setting
cudnn.benchmark = config.CUDNN.BENCHMARK
torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
torch.backends.cudnn.enabled = config.CUDNN.ENABLED

# for shufflenetv2
shufflenetv2_spec = {'0.5': ([4, 8, 4], [24, 48, 96, 192, 1024]),
                        '1.0': ([4, 8, 4], [24, 116, 232, 464, 1024]),
                        '1.5': ([4, 8, 4], [24, 176, 352, 704, 1024]),
                        '2.0': ([4, 8, 4], [24, 244, 488, 976, 2048])}
stages_repeats, stages_out_channels = shufflenetv2_spec['1.0']
print('models.'+config.MODEL.NAME+'.get_pose_net')
model = eval('models.'+config.MODEL.NAME+'.get_pose_net')(
        config, 
        stages_repeats, stages_out_channels,
        is_train=False
    )

models.pose_mobilenet_relu.get_pose_net


  exp_config = edict(yaml.load(f))


In [6]:
####################################### bnfuse model ############################################
bnfuse_model = eval('models.pose_mobilenet_relu_bnfuse.get_pose_net')(
    config, 
    stages_repeats, stages_out_channels,
    is_train=False
)

In [7]:
gpus = [int(i) for i in config.GPUS.split(',')]
device = select_device(config.GPUS, batch_size=config.TEST.BATCH_SIZE*len(gpus))

model = model.to(device)
# summary(model,input_size=(3, 256, 192))

Using CUDA device0 _CudaDeviceProperties(name='Tesla V100-PCIE-32GB', total_memory=32510MB)



In [8]:
#print('*******************ori_model*******************\n', model)
if(config.QUANTIZATION.QUANT_METHOD == 1): # DoReFa
    quantize_dorefa.prepare(model, inplace=True, a_bits=config.QUANTIZATION.A_BITS, w_bits=config.QUANTIZATION.W_BITS, quant_inference=config.QUANTIZATION.QUANT_INFERENCE, is_activate=False)
else: #default quant_method == 0   IAO
    prepare(model, inplace=True, a_bits=config.QUANTIZATION.A_BITS, w_bits=config.QUANTIZATION.W_BITS,q_type=config.QUANTIZATION.Q_TYPE, q_level=config.QUANTIZATION.Q_LEVEL, device=device,#device=next(model.parameters()).device, 
                        weight_observer=config.QUANTIZATION.WEIGHT_OBSERVER, bn_fuse=config.QUANTIZATION.BN_FUSE, quant_inference=config.QUANTIZATION.QUANT_INFERENCE)
#print('\n*******************quant_model*******************\n', model)
# print('\n*******************Using quant_model in test*******************\n')

a_bits= 8 	w_bits= 8 	q_type= 0 	q_level= 0 	device= cuda:0 	weight_observer= 0 	bn_fuse= 1 	quant_inference= False


In [9]:
# if config.TEST.MODEL_FILE:
#     # logger.info('=> loading model from {}'.format(config.TEST.MODEL_FILE))
#     if(config.TEST.MODEL_FILE.split('/')[-1]=='checkpoint.pth.tar'):
#         model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
#         #model.load_state_dict(torch.load(config.TEST.MODEL_FILE,map_location=torch.device('cuda'))['state_dict'])
#         model.load_state_dict(torch.load(config.TEST.MODEL_FILE,map_location=device)['state_dict'])
#         #torch.save(model.module.state_dict(), 'output/coco_quan/mobile_quant_relu_w8a8_bnfuse0/checkpoint_nomodule.pth.tar')
#     elif(config.TEST.MODEL_FILE.split('/')[-1]=='model_best.pth.tar'):  #multiGPU has model.module.
#         model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
#         model.load_state_dict(torch.load(config.TEST.MODEL_FILE,map_location=device))
#     elif(config.TEST.MODEL_FILE.split('/')[-1]=='checkpoint_resave.pth.tar'):  #multiGPU has model.module.
#         model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
#         model.load_state_dict(torch.load(config.TEST.MODEL_FILE,map_location=device))
#     else:  #final_state.pth.tar
#         model.load_state_dict(torch.load(config.TEST.MODEL_FILE,map_location=device))
#         model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

In [10]:
######################################################################################################
# ********************* quant_bn_fused_model_inference **********************
model.to(device)
model_bn_fuse(model, inplace=True)  # bn融合
# print('\n*******************For inference bn_fuse quant_model*******************\n', model)
# ckpt = {'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict()}
# torch.save(ckpt, '../output/weights_quan/int8_mobilenet8_relu_bnfuse_inference.pt')
print('*******************For inference bn_fuse quant_model*******************')

*******************For inference bn_fuse quant_model*******************


In [11]:
model.load_state_dict(torch.load('../'+config.TEST.MODEL_FILE,map_location=device)['model'])  ##为什么还在'model'里面呀？
# model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

<All keys matched successfully>

In [12]:
remapped_state = {}
print('Model.state_dict:')
# ######################################## before #######################################
# for n,param_tensor in enumerate(model.state_dict()):
#     #打印 key value字典
#     print(n, param_tensor,'\t',model.state_dict()[param_tensor].size())
#     # if(n<5):
#     #     print(n, param_tensor,'\t',model.state_dict()[param_tensor].size())
#     #     # print(model.state_dict()[param_tensor])

# ######################################### after #######################################
# for n,param_tensor in enumerate(bnfuse_model.state_dict()):
#     #打印 key value字典
#     print(n, param_tensor,'\t',bnfuse_model.state_dict()[param_tensor].size())
#     # if(n<4):
#     #     print(n, param_tensor,'\t',model.state_dict()[param_tensor].size())
#     #     print(model.state_dict()[param_tensor])

Model.state_dict:


In [14]:
remapped_state = {}
for n,state_key in enumerate(bnfuse_model.state_dict().keys()):
    k = state_key.split('.') # pytorch  ['features', '0', '0', 'weight']
    if(k[0]!='final_layer'):
        number = int(k[-2])//2*3 
        # print(number)
        k[-2]=str(number)
        # print(k)
        remapped_state_key=('.').join(k) #进行重映射
    else: #final_layer
        remapped_state_key=state_key
    print(n, state_key, model.state_dict()[remapped_state_key].shape)
    remapped_state[state_key]= model.state_dict()[remapped_state_key]     

0 features.0.0.weight torch.Size([16, 3, 3, 3])
1 features.0.0.bias torch.Size([16])
2 features.1.conv.0.weight torch.Size([16, 1, 3, 3])
3 features.1.conv.0.bias torch.Size([16])
4 features.1.conv.2.weight torch.Size([8, 16, 1, 1])
5 features.1.conv.2.bias torch.Size([8])
6 features.2.conv.0.weight torch.Size([48, 8, 1, 1])
7 features.2.conv.0.bias torch.Size([48])
8 features.2.conv.2.weight torch.Size([48, 1, 3, 3])
9 features.2.conv.2.bias torch.Size([48])
10 features.2.conv.4.weight torch.Size([16, 48, 1, 1])
11 features.2.conv.4.bias torch.Size([16])
12 features.3.conv.0.weight torch.Size([96, 16, 1, 1])
13 features.3.conv.0.bias torch.Size([96])
14 features.3.conv.2.weight torch.Size([96, 1, 3, 3])
15 features.3.conv.2.bias torch.Size([96])
16 features.3.conv.4.weight torch.Size([16, 96, 1, 1])
17 features.3.conv.4.bias torch.Size([16])
18 features.4.conv.0.weight torch.Size([96, 16, 1, 1])
19 features.4.conv.0.bias torch.Size([96])
20 features.4.conv.2.weight torch.Size([96, 1, 

In [15]:
bnfuse_model.load_state_dict(remapped_state)

<All keys matched successfully>

In [18]:
ckpt = {'model': bnfuse_model.module.state_dict() if hasattr(bnfuse_model, 'module') else bnfuse_model.state_dict()}
torch.save(ckpt, '../output/weights_quan/float_mobilenetpose_nobn.pt')

In [13]:
#导出权重和偏置至二进制文件中
result_path='scripts/exported_weights/'
#result_path='scripts/before_bnfuse/'
filename=''
a=0.0
print('Model.state_dict:')
for n,param_tensor in enumerate(model.state_dict()):
    #打印 key value字典
    # print(n, param_tensor,'\t',model.state_dict()[param_tensor].size())
    if(n==2): #if(n<4*2):
        print(n, param_tensor,'\t',model.state_dict()[param_tensor].size())
        a=model.state_dict()[param_tensor].numpy()  # module_list.0.0.weight   a:tensor->numpy dtype(float32) 
        #print(a.shape)  # [out_channel,in_channel,kernel_size,kernel_size] [8,3,3,3]
        #filename=('_').join(param_tensor.split('.')) + '.bin'  # module_list_0_0_weight
        print(filename)

        if(param_tensor.split('.')[-1]=='weight'):
            
            # ############################################# needed: [outer for, ..., inner for] [kernel_size,kernel_size,out_channel,in_channel] [3,3,8,3]
        elif(param_tensor.split('.')[-1]=='bias'):
            a_reshape=a
        else:
            print("###################### WARNING ########################\n")
            print(model.state_dict()[param_tensor])

        print(a_reshape)
        #a_reshape.astype(np.float32).tofile(result_path+filename+'.bin')
        np.savetxt(result_path+filename+'.txt', a_reshape, fmt="%f", delimiter='  ')
        #b=np.loadtxt(result_path+filename, dtype=np.float32, delimiter='  ')


Model.state_dict:
2 conv1.activation_quantizer.zero_point 	 torch.Size([1])


IndexError: list index out of range

In [8]:
###验证导出的权重是否正确
######################以下是bnfuse后的权重，与bnfuse_model.pt对应######################
###格式： [out_ch,in_ch,kernel_size1,kernel_size0]
###第一个是普通卷积
weight0_shape=[8,3,3,3]
bias0_shape=[8]
###接下来是深度可分离卷积
weight1_shape=[8,8,1,1]
bias1_shape=[8]
weight2_shape=[8,1,3,3]
bias2_shape=[8]
weight3_shape=[4,8,1,1]
bias3_shape=[4]

list=[[8,3,3,3],[8],[8,8,1,1],[8],[8,1,3,3],[8],[4,8,1,1],[4]]

#####################给的bin文件是权重重排后的结果######################
###格式： [kernel_size1,kernel_size0,out_ch,in_ch]
#以下可以将.bin文件中的权重返回.pt的数据排列，以便进行数据比对
result_path='./'
filename=['0_weight','0_bias','1_weight','1_bias','2_weight','2_bias','3_weight','3_bias']
for i in range(4*2):
    b=np.fromfile(result_path+filename[i]+'.bin',dtype=np.float32) ##一维的数据排列 ->[kernel_size1*kernel_size0*out_ch*in_ch]
    if(i==1):
        if(filename[i].split('_')[-1]=='weight'): #是weight
            b=b.reshape(list[i][2]*list[i][3],list[i][0],list[i][1]) #->[kernel_size1*kernel_size0,out_ch,in_ch]
            print(b.shape)
            b=b.transpose(1,2,0) #->[out_ch,in_ch,kernel_size1*kernel_size0]
            print(b.shape)
            b=b.reshape(list[i][0],list[i][1],list[i][2],list[i][3]) #->[out_ch,in_ch,kernel_size1,kernel_size0]
            print(b.shape)
            print(b)
        else: #否则是bias 无需转换，直接打印
            print(b)



FileNotFoundError: [Errno 2] No such file or directory: './0_weight.bin'

In [35]:
#b=np.fromfile(result_path+filename+'.bin',dtype=np.float32)
b=np.loadtxt(result_path+filename+'.txt', dtype=np.float32, delimiter='  ')
print(b.shape)
b=b.reshape(9,8,3)
print(b.shape)
b=b.transpose(1,2,0) #[9,8,3]->[8,3,9]
print(b.shape)
b=b.reshape(8,3,3,3)
print(b.shape)
print(b)

# print(a.shape)
# print(a.shape[0],a.shape[1],a.shape[2],a.shape[3])
# #a_resize=a.reshape([8,3,-1])
# a_reshape=a.reshape([a.shape[0],a.shape[1],-1])
# print(a_reshape.shape)
# a_reshape=a_reshape.transpose(2,0,1) #[8,3,9]->[9,8,3]
# print(a_reshape.shape)

(72, 3)
(9, 8, 3)
(8, 3, 9)
(8, 3, 3, 3)
[[[[ 0.169188  0.189951  0.165524]
   [ 0.15833   0.17425   0.18598 ]
   [ 0.181422  0.20745   0.193455]]

  [[ 0.127036  0.130175  0.13121 ]
   [ 0.108236  0.123582  0.129688]
   [ 0.135348  0.153904  0.160524]]

  [[ 0.081796  0.095659  0.107188]
   [ 0.101577  0.093322  0.11248 ]
   [ 0.09426   0.111809  0.128836]]]


 [[[-0.560707 -0.90087  -0.724879]
   [-0.370212 -0.85396  -0.942397]
   [-0.189665 -0.430859 -0.813114]]

  [[ 0.23858  -0.06389  -0.333687]
   [ 0.172932 -0.257281 -0.48805 ]
   [ 0.045519 -0.056772 -0.272473]]

  [[ 0.176914  0.34184   0.036285]
   [-0.111699  0.024917 -0.026201]
   [-0.267103  0.064347  0.139831]]]


 [[[-0.286476 -0.299834 -0.234171]
   [-0.238944 -0.192332 -0.181511]
   [-0.090237  0.026358 -0.08187 ]]

  [[-0.115784 -0.009535 -0.043332]
   [ 0.017743  0.077202  0.13005 ]
   [ 0.133964  0.113496  0.13871 ]]

  [[-0.151784 -0.026175 -0.001373]
   [ 0.069346  0.049721 -0.005715]
   [ 0.032771  0.093534  0.07

In [9]:
#a.reshape(8,3,9).transpose(2,1,0).shape #一次可以置换三个
#a.swapaxes(1,2)  #swapaxes只能两两置换 对于swapaxes来说，括号内的两参数，交换位置和不交换，实际结果相同。

# import cv2
# image=cv2.imread('/home/ytwang/dataset/VOC/images/val/000001.jpg')
# cv2.imshow("image",image)
# cv2.waitKey(0)
#cv2.destroyAllWindows()

In [10]:
##################################### 权重、feature map导出 ##################################
x0=torch.randn(1,3,352,256)
print(x0.shape) #[1,3,352,256]
x=x0.numpy().squeeze()
print(x.shape) #(3, 352, 256)
x=x.transpose(1,2,0)
print(x.shape) #(352,256,3)
#x.astype(np.float32).tofile(result_path+'input000001_352x256.bin') # 二进制文件导出
x=x.reshape([-1,x.shape[-1]]) #(90112,3)=(352*256,3)
print(x.shape)
#np.savetxt(result_path+'input.txt', x, fmt="%f", delimiter=',\t') # .txt文件导出



torch.Size([1, 3, 352, 256])
(3, 352, 256)
(352, 256, 3)
(90112, 3)


In [11]:
#b=np.fromfile(result_path+'input000001_352x256.bin',dtype=np.float32)
# b=np.loadtxt(result_path+'input.txt', dtype=np.float32, delimiter=',\t')
# print(b.shape)
# b=b.reshape(352,256,3).transpose(2,0,1)
# print(b.shape)
# b=b.reshape(1,3,352,256)
# print(b.shape)
# b=torch.tensor(b)
# type(b)
b=np.loadtxt(result_path+'input.txt', dtype=np.float32, delimiter=',\t')
print(b.shape)
print(x0.shape)



OSError: ./input.txt not found.

In [23]:
b=b.reshape(x0.shape[2],x0.shape[3],x0.shape[1]).transpose(2,0,1)
print(b.shape)
b=np.expand_dims(b,0) #[3,352,256]->[1,3,352,256]
print(b.shape)
x=torch.tensor(b)
print(type(x))

(3, 352, 256)
(1, 3, 352, 256)
<class 'torch.Tensor'>


In [13]:
# (x0==b).all()
print(x0[0][0][0]) #torch.Size([1, 3, 352, 256])

tensor([-6.81406e-01, -8.01516e-01,  3.28198e-01, -7.11460e-01, -9.71156e-01, -2.00248e-02, -1.63450e+00,  9.71171e-01, -7.00366e-01,  1.60246e+00, -1.78841e+00,  3.63713e-01,  9.74244e-01,  1.33411e+00, -3.02220e-01, -4.06555e-01,  1.37745e-01, -6.31540e-01, -1.46539e+00,  2.76620e-01,  1.23277e+00, -1.40946e+00,
        -1.51452e+00,  1.19671e-01,  4.42624e-01, -7.22908e-01, -4.04740e-01, -4.34604e-01,  2.06584e-01, -1.40587e+00,  5.70870e-01,  1.07140e-01,  1.61050e+00,  1.02770e+00, -3.53327e-01,  2.04177e-01,  2.31896e-02,  1.28927e+00, -5.48752e-01,  1.42098e-01,  9.03911e-01, -1.20936e+00,  9.63919e-01, -1.01060e+00,
         8.54119e-02, -1.53845e-01, -1.54839e+00, -6.02410e-02, -2.57630e+00,  3.93433e-01,  2.39229e-01, -1.21564e-01, -1.15327e+00,  1.26149e+00,  1.45040e+00, -9.07547e-01,  1.31109e+00, -1.65926e-02, -3.07028e-01,  6.83296e-01, -5.43891e-01,  6.85024e-01,  4.69522e-01,  1.76420e+00,  1.61510e+00,  1.25916e+00,
        -1.77150e-01,  2.47641e+00, -1.46249e+00, -2

In [36]:
# print(b)

In [18]:
####################################### 验证bnfuse后结果是否正确 #########################################
bn_weight=torch.tensor([
    0.9627532362937927,
    1.6811193227767944,
    1.0011911392211914,
    1.6239477396011353,
    1.3853583335876465,
    0.9575594663619995,
    1.325505256652832,
    0.9901668429374695])
bn_bias=torch.tensor([
    -24.751123428344727,
    16.553693771362305,
    -21.58427619934082,
    13.835719108581543,
    7.6066741943359375,
    12.691370010375977,
    24.215431213378906,
    -10.96066665649414])
bn_running_mean=torch.tensor([
    0.2264288365840912,
    -1.438944935798645,
    -0.012756695970892906,
    -1.1514368057250977,
    2.064279079437256,
    -1.000322699546814,
    1.4272160530090332,
    -0.01157035119831562])
bn_running_var=torch.tensor([
    0.018602905794978142,
    0.7236471176147461,
    0.0006818491965532303,
    0.586725652217865,
    1.5904737710952759,
    0.3611223101615906,
    0.7421475648880005,
    0.00004840613837586716])
bn_eps=torch.tensor([1E-4]).expand([8])

In [22]:
# bn.bias=
# bn.weight=
# bn.running_mean=
# bn.running_var=
# bn.eps=
afterfuse_b=bn_bias-bn_weight*bn_running_mean/(torch.sqrt(bn_running_var+bn_eps))
print(afterfuse_b)

b_bn = bn_bias - bn_weight.mul(bn_running_mean).div(torch.sqrt(bn_running_var + bn_eps))
print(b_bn)


tensor([-26.34514,  19.39717, -21.12751,  16.27666,   5.33914,  14.28511,  22.01961, -10.02023])
tensor([-26.34514,  19.39717, -21.12751,  16.27666,   5.33914,  14.28511,  22.01961, -10.02023])


In [None]:
############################### torch_utils.py中的bu_fuse实现 ##########################
def fuse_conv_and_bn(conv, bn):
    # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    with torch.no_grad():
        # init
        fusedconv = torch.nn.Conv2d(conv.in_channels,
                                    conv.out_channels,
                                    kernel_size=conv.kernel_size,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    groups=conv.groups,
                                    bias=True)

        # prepare filters
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0))
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

        return fusedconv

In [None]:
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))