In [16]:
import os
import struct

import torch
import torch.nn as nn
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.layers import Conv2d
from detectron2.modeling import build_model

In [24]:
def fuse_conv_and_bn(conv):
    """Fold ``conv.norm`` (a BatchNorm-style layer) into ``conv``; return one equivalent ``nn.Conv2d``.

    With running statistics, the norm applied to the conv output ``y = W*x + b`` is a
    per-output-channel affine map, so it can be absorbed into the conv:

        scale = gamma / sqrt(running_var + eps)
        W'    = W * scale                      (each output filter scaled by its channel's factor)
        b'    = scale * b + (beta - scale * running_mean)

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/.

    Args:
        conv: a conv module exposing the ``nn.Conv2d`` attributes (channels, kernel_size,
            stride, padding, dilation, groups, weight, bias) plus a ``norm`` attribute with
            ``running_mean``, ``running_var``, ``eps`` and optional ``weight``/``bias``.

    Returns:
        A new ``nn.Conv2d`` (always with a bias, gradients disabled) on the same device and
        dtype as ``conv.weight``.

    Raises:
        ValueError: if ``conv.norm`` has no running statistics (e.g. GroupNorm); such a
            norm depends on the input and cannot be folded into a conv.
    """
    bn = conv.norm
    if getattr(bn, 'running_mean', None) is None or getattr(bn, 'running_var', None) is None:
        raise ValueError('cannot fuse {}: norm has no running statistics'.format(type(bn).__name__))

    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          # dilation/padding_mode must be carried over, otherwise a
                          # dilated conv is replaced by a non-dilated one.
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True,
                          padding_mode=getattr(conv, 'padding_mode', 'zeros'))
    fusedconv = fusedconv.requires_grad_(False).to(device=conv.weight.device, dtype=conv.weight.dtype)
    # NOTE(review): if ``conv`` applies an activation of its own in forward, the plain
    # nn.Conv2d returned here does not — TODO confirm none of the fused convs carry one.

    # Pure parameter rewriting: keep autograd from recording these ops, otherwise the
    # fused parameters end up attached to a graph through conv.weight / bn.weight.
    with torch.no_grad():
        gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_var)
        # Per-channel scale, broadcast over each filter instead of a dense diag() matmul.
        scale = gamma / torch.sqrt(bn.running_var + bn.eps)
        fusedconv.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))

        b_conv = conv.bias if conv.bias is not None else torch.zeros_like(scale)
        fusedconv.bias.copy_(scale * b_conv + beta - scale * bn.running_mean)

    return fusedconv

def fuse_bn(model):
    """Recursively fold conv+norm pairs in ``model`` into single convs, in place.

    Every ``Conv2d`` (the detectron2 layer imported at the top of the notebook) whose
    ``norm`` keeps running statistics is replaced by ``fuse_conv_and_bn(child)``.
    Convs whose ``norm`` has no running statistics (e.g. GroupNorm, or a BatchNorm
    built with ``track_running_stats=False``) cannot be folded. They are left as they
    are instead of aborting the whole pass.

    Args:
        model: the ``nn.Module`` to rewrite; modified in place.
    """
    # Snapshot the children so modules are not replaced while iterating the live registry.
    for child_name, child in list(model.named_children()):
        norm = getattr(child, 'norm', None)
        foldable = (isinstance(child, Conv2d)
                    and norm is not None
                    and getattr(norm, 'running_mean', None) is not None
                    and getattr(norm, 'running_var', None) is not None)
        if foldable:
            setattr(model, child_name, fuse_conv_and_bn(child))
        else:
            fuse_bn(child)

def gen_wts(model, filename):
    """Dump ``model.state_dict()`` to ``./<filename>.wts`` as plain text.

    File layout:
        line 1: number of state_dict entries
        then one line per entry: ``<key> <count> `` followed by every element of the
        flattened tensor, each written as a space-prefixed big-endian float32 hex string.

    The key, shape and element count of every entry are also printed.

    Args:
        model: module whose ``state_dict()`` is exported.
        filename: output name without extension; written relative to the working directory.
    """
    state = model.state_dict()  # single snapshot (the original built it twice)
    # ``with`` guarantees the file is closed even if a tensor fails to convert mid-dump.
    with open('./' + filename + '.wts', 'w') as f:
        f.write('{}\n'.format(len(state)))
        for k, v in state.items():
            vr = v.reshape(-1).cpu().numpy()
            print('{}\t{}\t{}'.format(k, v.shape, len(vr)))
            # The header ends with a space and each value carries a leading space, so key/count
            # and the first value are separated by two spaces; kept byte-for-byte.
            f.write('{} {} '.format(k, len(vr)))
            f.write(''.join(' ' + struct.pack('>f', float(vv)).hex() for vv in vr))
            f.write('\n')

In [2]:
# Checkpoint and config of the model to export. Set D2_MODEL_PATH / D2_CONFIG_PATH in the
# environment to point elsewhere instead of editing these absolute local defaults.
model_path = os.environ.get('D2_MODEL_PATH', 'F:/models/detectron2/COCO-Detection/faster_rcnn_R_50_C4_1x.pkl')
config_path = os.environ.get('D2_CONFIG_PATH', 'F:/models/detectron2/COCO-Detection/faster_rcnn_R_50_C4_1x.yaml')

In [5]:
def redirection(config, config_dir='E:/works/timesai_tools/timesAI_platform/times-ai/visionCode/algorithm_contribute/py_module/timesai_det1_interfance/configs/COCO-Detection'):
    """Map a config path onto the same-named file inside ``config_dir``.

    Only the basename of ``config`` is kept. This lets a config path from elsewhere
    (e.g. the model directory) resolve to the project's local copy of that config.

    Args:
        config: any path whose file name identifies the config.
        config_dir: directory holding the local configs; defaults to the original
            hard-coded location, so existing one-argument calls behave as before.

    Returns:
        ``os.path.join(config_dir, basename(config))``.
    """
    return os.path.join(config_dir, os.path.basename(config))

In [8]:
# Build the model from the project's copy of the config, then load the checkpoint.
# build_model only constructs the architecture with freshly initialised parameters, and
# setting cfg.MODEL.WEIGHTS by itself loads nothing. Without the explicit load below, the
# exported .wts would contain random weights.
cfg = get_cfg()
cfg.DATALOADER.NUM_WORKERS = 0
cfg.merge_from_file(redirection(config_path))
cfg.MODEL.WEIGHTS = model_path
cfg.freeze()
model = build_model(cfg)
# NOTE(review): the checkpoint is a pickle file; only load weights from a trusted source.
# Trailing ';' suppresses display of the returned (leftover) checkpoint dict.
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS);

In [25]:
# Fold every conv+BatchNorm pair into a single conv, then dump the weights to ./faster.wts.
# Fusion only rewrites parameter values, so it runs without autograd tracking. eval() is
# called afterwards so the newly created fused convs are in inference mode too.
with torch.no_grad():
    fuse_bn(model)
model.eval()
gen_wts(model, 'faster')

backbone.stem.conv1.weight	torch.Size([64, 3, 7, 7])	9408
backbone.stem.conv1.bias	torch.Size([64])	64
backbone.res2.0.shortcut.weight	torch.Size([256, 64, 1, 1])	16384
backbone.res2.0.shortcut.bias	torch.Size([256])	256
backbone.res2.0.conv1.weight	torch.Size([64, 64, 1, 1])	4096
backbone.res2.0.conv1.bias	torch.Size([64])	64
backbone.res2.0.conv2.weight	torch.Size([64, 64, 3, 3])	36864
backbone.res2.0.conv2.bias	torch.Size([64])	64
backbone.res2.0.conv3.weight	torch.Size([256, 64, 1, 1])	16384
backbone.res2.0.conv3.bias	torch.Size([256])	256
backbone.res2.1.conv1.weight	torch.Size([64, 256, 1, 1])	16384
backbone.res2.1.conv1.bias	torch.Size([64])	64
backbone.res2.1.conv2.weight	torch.Size([64, 64, 3, 3])	36864
backbone.res2.1.conv2.bias	torch.Size([64])	64
backbone.res2.1.conv3.weight	torch.Size([256, 64, 1, 1])	16384
backbone.res2.1.conv3.bias	torch.Size([256])	256
backbone.res2.2.conv1.weight	torch.Size([64, 256, 1, 1])	16384
backbone.res2.2.conv1.bias	torch.Size([64])	64
backbone.