In [1]:
import torch

# A subset of VOCDataLoader just for one class (person) (0)
from utils.dataloader import VOCDataLoaderPerson

# Person-only VOC loaders: both use batch_size=1; only the training one
# is constructed with shuffle=True.
train_kwargs = dict(train=True, batch_size=1, shuffle=True)
test_kwargs = dict(train=False, batch_size=1)
loader = VOCDataLoaderPerson(**train_kwargs)
loader_test = VOCDataLoaderPerson(**test_kwargs)

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)


Using downloaded and verified file: data/VOCtrainval_11-May-2012.tar
Extracting data/VOCtrainval_11-May-2012.tar to data/
cuda


In [2]:
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d layer into the Conv2d that precedes it.

    With fixed running statistics, batch norm is a per-channel affine map
    ``y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta``.
    ``bn(conv(x))`` can therefore be computed by a single convolution whose
    filters are rescaled per output channel and whose bias is shifted.

    Args:
        conv: ``torch.nn.Conv2d`` whose output is fed to ``bn``.
        bn: ``torch.nn.BatchNorm2d`` holding the running statistics to fold
            in. ``running_mean`` / ``running_var`` must be populated.

    Returns:
        A new ``torch.nn.Conv2d`` (always with a bias term) with the same
        geometry as ``conv`` (kernel size, stride, padding, dilation, groups,
        padding mode), placed on the same device and dtype as ``conv.weight``.
        Its output matches ``bn(conv(x))`` with ``bn`` in eval mode.
    """
    with torch.no_grad():
        # Follow the layer being fused rather than a notebook-level global,
        # so the function works on any device and outside this notebook.
        target_device = conv.weight.device
        target_dtype = conv.weight.dtype

        # Mirror the full conv geometry; dropping dilation/groups would make
        # the fused layer a different operator (or break the weight shape).
        fusedconv = torch.nn.Conv2d(conv.in_channels,
                                    conv.out_channels,
                                    kernel_size=conv.kernel_size,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    dilation=conv.dilation,
                                    groups=conv.groups,
                                    padding_mode=conv.padding_mode,
                                    bias=True).to(device=target_device, dtype=target_dtype)

        # BatchNorm2d(affine=False) has no learnable gamma/beta: treat them
        # as the identity scale and a zero shift.
        gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)

        # Per-output-channel scale and shift implied by the batch norm.
        scale = gamma / torch.sqrt(bn.running_var + bn.eps)
        shift = beta - bn.running_mean * scale
        scale = scale.to(device=target_device, dtype=target_dtype)
        shift = shift.to(device=target_device, dtype=target_dtype)

        # Scale every filter by its channel's factor. Broadcasting over the
        # leading (output-channel) axis replaces the dense diag-matrix product.
        fusedconv.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))

        # A conv created with bias=False contributes a zero bias.
        if conv.bias is None:
            b_conv = torch.zeros(conv.out_channels, device=target_device, dtype=target_dtype)
        else:
            b_conv = conv.bias

        fusedconv.bias.copy_(b_conv * scale + shift)

        return fusedconv

In [3]:
import copy

from tinyyolov2NoBN import TinyYoloV2NoBN
from tinyyolov2 import TinyYoloV2

# Checkpoint locations (input: fine-tuned weights, output: BN-fused weights).
PRETRAINED_WEIGHTS_PATH = "models/voc_finetuned_100_epochs_lr0001_decay0005.pt"
FUSED_WEIGHTS_PATH = 'models/voc_fused_100_epochs_lr0001_decay0005.pt'

# Number of conv layers that are followed by a BatchNorm (conv1/bn1 .. conv8/bn8).
NUM_FUSED_LAYERS = 8

# Initialize the original model and load the weights
origin_net = TinyYoloV2(num_classes=1).to(device)

# Load pretrained weights. map_location remaps tensors saved from a CUDA
# session onto the active device, so this also works on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(PRETRAINED_WEIGHTS_PATH, map_location=device)
origin_net.load_state_dict(state_dict)

# Create a new model without BatchNorm layers
fused_net = TinyYoloV2NoBN(num_classes=1).to(device)

# Fuse each conv layer with its BatchNorm and install the result under the
# same attribute name on the BN-free network.
for layer_idx in range(1, NUM_FUSED_LAYERS + 1):
    conv_layer = getattr(origin_net, f"conv{layer_idx}")
    bn_layer = getattr(origin_net, f"bn{layer_idx}")
    setattr(fused_net, f"conv{layer_idx}", fuse_conv_and_bn(conv_layer, bn_layer))

# The final conv layer has no BN, so it is carried over unchanged. Deep-copy it
# so the two networks do not share the same module and parameter tensors.
fused_net.conv9 = copy.deepcopy(origin_net.conv9)

fused_net.eval()

# Save the fused model state dict
torch.save(fused_net.state_dict(), FUSED_WEIGHTS_PATH)