# Fusing BatchNorm (BN) into the preceding convolution

Source: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/

In [1]:
import torch
import torchvision

In [2]:
def fuse_conv_and_bn(conv, bn):
    """Fold an eval-mode ``BatchNorm2d`` into the ``Conv2d`` that feeds it.

    With ``std = sqrt(running_var + eps)`` and ``scale = gamma / std``, the
    BatchNorm applied to ``conv(x) = W * x + b`` equals a single convolution
    with

        W_fused[c] = scale[c] * W[c]                 (per output channel c)
        b_fused    = beta - gamma * running_mean / std + scale * b

    This is only valid when ``bn`` normalises with its running statistics
    (i.e. ``bn.eval()``); in training mode BatchNorm uses batch statistics.

    Args:
        conv: the convolution whose output ``bn`` normalises.
        bn: the BatchNorm2d layer; it must track running statistics.
            If it is not affine (``weight``/``bias`` are None), gamma=1 and
            beta=0 are used.

    Returns:
        A new ``torch.nn.Conv2d`` (with bias) on ``conv``'s device and dtype
        that reproduces ``bn(conv(x))``.

    Raises:
        ValueError: if ``bn`` has no running statistics to fold.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError("BatchNorm must track running statistics to be fused")

    # Copy every hyper-parameter that affects the result; groups and
    # dilation in particular determine the weight shape / receptive field.
    fusedconv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
        padding_mode=conv.padding_mode,
    ).to(device=conv.weight.device, dtype=conv.weight.dtype)

    # In-place writes into leaf parameters must not be tracked by autograd;
    # without this, copy_ raises unless the caller is already under no_grad.
    with torch.no_grad():
        std = torch.sqrt(bn.running_var + bn.eps)
        gamma = bn.weight if bn.weight is not None else torch.ones_like(std)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(std)
        scale = gamma / std  # one multiplier per output channel

        # Broadcast the per-channel scale over (out, in/groups, kH, kW);
        # equivalent to diag(scale) @ W without building a C x C matrix.
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        b_fused = beta - gamma * bn.running_mean / std
        if conv.bias is not None:
            # The conv bias passes through BN too, so it is scaled as well.
            b_fused = b_fused + conv.bias * scale
        fusedconv.bias.copy_(b_fused)

    return fusedconv

In [3]:
with torch.no_grad():
    # Pretrained ImageNet weights: prefer the torchvision>=0.13 ``weights=``
    # API and fall back to the deprecated ``pretrained=True`` on older versions.
    try:
        r18_weights = torchvision.models.ResNet18_Weights.DEFAULT
    except AttributeError:
        r18 = torchvision.models.resnet18(pretrained=True)
    else:
        r18 = torchvision.models.resnet18(weights=r18_weights)

    x = torch.randn(16, 3, 256, 256)
    # eval() makes bn1 use its running statistics, which the fusion assumes.
    r18.eval()

    net = torch.nn.Sequential(
        r18.conv1,
        r18.bn1,
    )

    # Call the modules (not .forward()) so any registered hooks still run.
    original_output = net(x)

    fusedconv = fuse_conv_and_bn(net[0], net[1])
    fused_output = fusedconv(x)

    # Relative L2 error between the unfused and the fused outputs.
    diff = (original_output - fused_output).norm().div(original_output.norm()).item()
    print(f"error: {diff:.8f}")

error: 0.00000030
