In [20]:
# Copyright (c) OpenMMLab. All rights reserved.
from operator import attrgetter
from typing import List, Union

import torch
import torch.nn as nn


def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Code borrowed from mmcv 2.0.1, so that this feature can be used for old
    mmcv versions.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for training as well. It reduces memory and computation cost.
    Args:
        bn (_BatchNorm): a BatchNorm module.
        conv (nn._ConvNd): a conv module
        x (torch.Tensor): Input feature map.
    """
    # These lines of code are designed to deal with various cases
    # like bn without affine transform, and conv without bias
    weight_on_the_fly = conv.weight
    if conv.bias is not None:
        bias_on_the_fly = conv.bias
    else:
        bias_on_the_fly = torch.zeros_like(bn.running_var)

    if bn.weight is not None:
        bn_weight = bn.weight
    else:
        bn_weight = torch.ones_like(bn.running_var)

    if bn.bias is not None:
        bn_bias = bn.bias
    else:
        bn_bias = torch.zeros_like(bn.running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    weight_coeff = torch.rsqrt(bn.running_var +
                               bn.eps).reshape([-1] + [1] *
                                               (len(conv.weight.shape) - 1))
    # shape of [C_out, 1, 1, 1] in Conv2d
    coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() *\
        (bias_on_the_fly - bn.running_mean)

    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)


def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """This function controls whether to use `efficient_conv_bn_eval_forward`.

    If the following `bn` is in `eval` mode, then we turn on the special
    `efficient_conv_bn_eval_forward`.
    """
    if not bn.training:
        # bn in eval mode
        output = efficient_conv_bn_eval_forward(bn, conv, x)
        return output
    else:
        conv_out = conv._conv_forward(x, conv.weight, conv.bias)
        return bn(conv_out)


def efficient_conv_bn_eval_graph_transform(fx_model):
    """Find consecutive conv+bn calls in the graph, inplace modify the graph
    with the fused operation."""
    modules = dict(fx_model.named_modules())

    patterns = [(torch.nn.modules.conv._ConvNd,
                 torch.nn.modules.batchnorm._BatchNorm)]

    pairs = []
    # Iterate through nodes in the graph to find ConvBN blocks
    for node in fx_model.graph.nodes:
        # If our current node isn't calling a Module then we can ignore it.
        if node.op != 'call_module':
            continue
        target_module = modules[node.target]
        found_pair = False
        for conv_class, bn_class in patterns:
            if isinstance(target_module, bn_class):
                source_module = modules[node.args[0].target]
                if isinstance(source_module, conv_class):
                    found_pair = True
        # Not a conv-BN pattern or output of conv is used by other nodes
        if not found_pair or len(node.args[0].users) > 1:
            continue

        # Find a pair of conv and bn computation nodes to optimize
        conv_node = node.args[0]
        bn_node = node
        pairs.append([conv_node, bn_node])

    for conv_node, bn_node in pairs:
        # set insertion point
        fx_model.graph.inserting_before(conv_node)
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `fx_model.graph.get_attr` and
        # `fx_model.graph.call_function` does not allow the `name` argument.
        conv_get_node = fx_model.graph.create_node(
            op='get_attr', target=conv_node.target, name='get_conv')
        bn_get_node = fx_model.graph.create_node(
            op='get_attr', target=bn_node.target, name='get_bn')
        # prepare args for the fused function
        args = (bn_get_node, conv_get_node, conv_node.args[0])
        # create a new node
        new_node = fx_model.graph.create_node(
            op='call_function',
            target=efficient_conv_bn_eval_control,
            args=args,
            name='efficient_conv_bn_eval')
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        fx_model.graph.erase_node(bn_node)
        fx_model.graph.erase_node(conv_node)

    # regenerate the code
    fx_model.graph.lint()
    fx_model.recompile()


def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
    import torch.fx as fx

    # currently we use `fx.symbolic_trace` to trace models.
    # in the future, we might turn to pytorch 2.0 compile infrastructure to
    # get the `fx.GraphModule` IR. Nonetheless, the graph transform function
    # can remain unchanged. We just need to change the way
    # we get `fx.GraphModule`.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    efficient_conv_bn_eval_graph_transform(fx_model)
    model.forward = fx_model.forward


def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
                                   modules: Union[List[str], str]):
    if isinstance(modules, str):
        modules = [modules]
    for module_name in modules:
        module = attrgetter(module_name)(model)
        turn_on_efficient_conv_bn_eval_for_single_model(module)

In [44]:
import torch
import torch.nn as nn
import torch.fx as fx

# 定义一个简单的卷积神经网络作为示例
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc = nn.Linear(32 * 32 * 32, 10)  # 假设输入图像大小为32x32

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc(x)
        return x

class SimpleCNN2(nn.Module):
    def __init__(self):
        super(SimpleCNN2 ,self).__init__()
        self.backbone = SimpleCNN()

    def forward(self, x):
        x = self.backbone(x)
        return x
# 调用模型并检查优化情况
model = SimpleCNN2()

# 创建一个随机输入
input_tensor = torch.randn(1, 3, 32, 32)

# 调用函数以启用高效的Conv + BN融合并打印优化前后的计算图


model = model.eval()
turn_on_efficient_conv_bn_eval(model, ['backbone'])
print(model.forward)
# 执行前向传播
output = model(input_tensor)

# 打印输出的形状
print(f"Output shape: {output.shape}")


<bound method SimpleCNN2.forward of SimpleCNN2(
  (backbone): SimpleCNN(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc): Linear(in_features=32768, out_features=10, bias=True)
  )
)>
Output shape: torch.Size([1, 10])


In [29]:
# 检查替换后的 forward 方法
print(model.forward)


<bound method forward of SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=32768, out_features=10, bias=True)
)>


In [30]:
# 打印 FX 计算图
import torch.fx as fx
fx_model = fx.symbolic_trace(model)
print(fx_model.graph)


graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%bn1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %bn2 : [num_users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
    %relu_1 : [num_users=2] = call_function[target=torch.relu](args = (%bn2,), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%relu_1, 0), kwargs = {})
    %view : [num_users=1] = call_method[target=view](args = (%relu_1, %size, -1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%view,), kwargs = {})
    return fc


In [31]:
# 打印 FX 计算图
import torch.fx as fx
fx_model = fx.symbolic_trace(model)
print(fx_model.graph)


graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%bn1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %bn2 : [num_users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
    %relu_1 : [num_users=2] = call_function[target=torch.relu](args = (%bn2,), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%relu_1, 0), kwargs = {})
    %view : [num_users=1] = call_method[target=view](args = (%relu_1, %size, -1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%view,), kwargs = {})
    return fc


In [28]:
# 检查替换后的 forward 方法
print(model.forward)


<bound method forward of SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=32768, out_features=10, bias=True)
)>


In [20]:
print(output_before_fusion)

tensor([[[[-0.4928, -0.7771, -1.0377,  ..., -0.8974, -0.3298, -0.3931],
          [ 0.2901,  0.2581,  0.7802,  ...,  0.9171,  0.9820,  0.1806],
          [-1.3592,  0.2125,  1.1953,  ..., -0.4506,  0.3617,  1.6183],
          ...,
          [-0.5671,  0.9266,  2.0484,  ..., -0.4558, -0.8388,  0.0452],
          [ 0.0935, -0.7979,  0.9503,  ...,  0.5598, -1.9626, -0.2911],
          [ 0.1995,  0.6507, -0.2142,  ..., -0.3146, -0.2791,  0.4547]],

         [[ 1.4356,  0.9129, -0.5031,  ..., -0.1020,  1.0058, -0.4193],
          [-1.4852, -0.4363, -0.8256,  ..., -0.9477, -0.0374,  0.1647],
          [-0.3652, -0.2904,  2.5691,  ...,  0.5813, -0.8416, -0.9438],
          ...,
          [ 0.4869,  2.5273,  1.1332,  ...,  2.2675,  0.0770,  0.6853],
          [ 0.0430, -1.8617, -0.6892,  ..., -0.8342, -0.1863, -0.8786],
          [ 0.0802,  0.3856, -0.6600,  ...,  0.8860,  1.6499, -0.8380]],

         [[-0.1324,  0.0349,  0.6080,  ..., -0.8111,  0.4878, -0.4131],
          [-0.3960, -0.8790, -

In [21]:
print(output_after_fusion)

tensor([[[[-0.2322, -0.3684, -0.4932,  ..., -0.4260, -0.1541, -0.1844],
          [ 0.1428,  0.1274,  0.3775,  ...,  0.4430,  0.4741,  0.0903],
          [-0.6471,  0.1056,  0.5763,  ..., -0.2120,  0.1770,  0.7788],
          ...,
          [-0.2678,  0.4476,  0.9848,  ..., -0.2145, -0.3979,  0.0255],
          [ 0.0486, -0.3783,  0.4589,  ...,  0.2719, -0.9361, -0.1356],
          [ 0.0993,  0.3154, -0.0988,  ..., -0.1468, -0.1298,  0.2216]],

         [[ 0.8490,  0.5423, -0.2886,  ..., -0.0532,  0.5968, -0.2394],
          [-0.8648, -0.2494, -0.4778,  ..., -0.5494, -0.0153,  0.1032],
          [-0.2077, -0.1638,  1.5140,  ...,  0.3477, -0.4872, -0.5472],
          ...,
          [ 0.2923,  1.4895,  0.6715,  ...,  1.3371,  0.0518,  0.4087],
          [ 0.0319, -1.0857, -0.3978,  ..., -0.4829, -0.1027, -0.5089],
          [ 0.0537,  0.2329, -0.3806,  ...,  0.5265,  0.9747, -0.4851]],

         [[-0.0990,  0.0130,  0.3966,  ..., -0.5533,  0.3161, -0.2869],
          [-0.2754, -0.5987, -

In [26]:
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from unittest import TestCase

import torch
from torch import nn





class BackboneModel(nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.conv1 = nn.Conv2d(6, 6, 6)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 6, 6)
        self.bn2 = nn.BatchNorm2d(6)
        self.conv3 = nn.Conv2d(6, 6, 6)
        self.bn3 = nn.BatchNorm2d(6)

    def forward(self, x):

        # this conv-bn pair can use efficient_conv_bn_eval feature
        x = self.bn1(self.conv1(x))
        # this conv-bn pair can use efficient_conv_bn_eval feature
        # only for the second `self.conv2` call.
        x = self.bn2(self.conv2(self.conv2(x)))
        # this conv-bn pair can use efficient_conv_bn_eval feature
        # just for the first forward of the `self.bn3`
        x = self.bn3(self.bn3(self.conv3(x)))
        return x




model = BackboneModel()
model.eval()
input = torch.randn(64, 6, 32, 32)
output = model(input)
turn_on_efficient_conv_bn_eval_for_single_model(model)
output2 = model(input)
print((output - output2).abs().max().item())


2.0116567611694336e-07


In [14]:
import torch
import torch.nn as nn

class ConvBn(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBn, self).__init__()
        self.deploy = deploy
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.bn = nn.BatchNorm2d(num_features=out_channels)


    def _fuse_bn_tensor(self, conv, bn):
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std


    def switch_to_deploy(self):
        if self.bn.training:
            raise RuntimeError("BatchNorm should be in evaluation mode (eval) before deployment.")
        deploy_k, deploy_b = self._fuse_bn_tensor(self.conv, self.bn)
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                                    kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                                    padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True,
                                    padding_mode=self.conv.padding_mode)
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.conv(input)
            square_outputs = self.bn(square_outputs)
            return square_outputs


conv_bn_layer = ConvBn(3, 16, 3, 1, 1).eval()
input_tensor = torch.randn(1, 3, 32, 32)

output_before_fusion = conv_bn_layer(input_tensor)
conv_bn_layer.switch_to_deploy()
output_after_fusion = conv_bn_layer(input_tensor)

print("Output before fusion: ", output_before_fusion.shape)
print("Output after fusion: ", output_after_fusion.shape)
print("Are the outputs close? ", torch.allclose(output_before_fusion, output_after_fusion, atol=1e-6))


Output before fusion:  torch.Size([1, 16, 32, 32])
Output after fusion:  torch.Size([1, 16, 32, 32])
Are the outputs close?  True


In [1]:

from model.base_module import BaseModule
import torch
import torch.nn as nn
class ConvBn(BaseModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.groups = groups
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
    def forward(self, x):
        if hasattr(self, 'rbr_reparam'):
            return self.rbr_reparam(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

init_cfg = [
    dict(type='Constant',val=2,layer='Conv2d'),
    dict(type='Constant',override=dict(name='bn'),val =3,bias=1), 
]
# Example of usage:
conv_bn_layer = ConvBn(3, 16, 3, 1, 1,init_cfg = init_cfg )

# Define the initialization configuration for Conv2d and BatchNorm2d
  # Example using Xavier initialization for Conv2d

# Initialize weights
conv_bn_layer.init_weights()

# Create a random input tensor
input_tensor = torch.randn(1, 3, 32, 32)

# Forward pass to see the output
output = conv_bn_layer(input_tensor)

# Print initialized weights of Conv2d and BatchNorm2d
# 打印每一层的权重均值和偏置
param_mean = {name: p.mean().item() for name, p in conv_bn_layer.named_parameters()}
for name, mean in param_mean.items():
    print(f"Mean of {name}: {mean:.6f}")


Mean of conv.weight: 2.000000
Mean of bn.weight: 3.000000
Mean of bn.bias: 1.000000


In [3]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.19.8-py3-none-macosx_11_0_arm64.whl.metadata (10 kB)
Collecting click!=8.0.0,>=7.1 (from wandb)
  Downloading click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.44-py3-none-any.whl.metadata (13 kB)
Collecting protobuf!=4.21.0,!=5.28.0,<6,>=3.19.0 (from wandb)
  Downloading protobuf-5.29.4-cp38-abi3-macosx_10_9_universal2.whl.metadata (592 bytes)
Collecting pydantic<3,>=2.6 (from wandb)
  Downloading pydantic-2.10.6-py3-none-any.whl.metadata (30 kB)
Collecting sentry-sdk>=2.0.0 (from wandb)
  Downloading sentry_sdk-2.24.0-py2.py3-none-any.whl.metadata (10 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.5-cp310-cp310-macosx_11_0_arm64.whl.metadata (10 kB)
Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)
  Downloading 

In [1]:
from apps.builder import make_model
from apps.registry import register_dir


register_dir('model')
model =  make_model(dict(type='resnet18',pretrained =True,dataset='imagenet1k_v1'))
model.init_weights()
param_mean = {name: p.mean().item() for name, p in model.named_parameters()}
for name, mean in param_mean.items():
    print(f"Mean of {name}: {mean:.6f}")

Processing directory: model
Loading module: model.backbone
Loading module: model.efficient_conv_bn_eval
Loading module: model.resnet
Loading module: model.utils
Unexpected error while loading module model.utils: name 'nn' is not defined
Loading module: model.weight_url
Loading module: model.base_module
Loading module: model.weight_init
Processing directory: model/__pycache__
Processing directory: model/segmentor
Loading module: model.segmentor.encoderdecoder
ModuleNotFoundError: No module named 'mmengine'
Loading module: model.segmentor.segmentor
Unexpected error while loading module model.segmentor.segmentor: name 'Tensor' is not defined
Processing directory: model/segmentor/__pycache__
Mean of conv1.weight: 0.000029
Mean of bn1.weight: 0.257577
Mean of bn1.bias: 0.181120
Mean of layer1.0.conv1.weight: -0.003087
Mean of layer1.0.bn1.weight: 0.339601
Mean of layer1.0.bn1.bias: -0.034137
Mean of layer1.0.conv2.weight: -0.000889
Mean of layer1.0.bn2.weight: 0.333055
Mean of layer1.0.bn2.

In [26]:
class ConvBn(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBn, self).__init__()
        self.deploy = deploy
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.bn = nn.BatchNorm2d(num_features=out_channels)


    def _fuse_bn_tensor(self, conv, bn):
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std


    def switch_to_deploy(self):
        deploy_k, deploy_b = self._fuse_bn_tensor(self.conv, self.bn)
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                                    kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                                    padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True,
                                    padding_mode=self.conv.padding_mode)
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.conv(input)
            square_outputs = self.bn(square_outputs)
            return square_outputs

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class RepConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes=(3, 7, 11), stride=1, padding=None, dilation=1, deploy=False, groups=1,use_bn):
        """
        Reparameterized convolution module for Conv2d, with an interface consistent with nn.Conv2d.
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_sizes: List of kernel sizes
        :param stride: Stride of the convolution
        :param padding: Padding size; if None, automatically computes "same" padding
        :param dilation: Dilation rate
        :param bias: Whether to include a bias term
        :param use_identity: Whether to use identity mapping (only effective when in_channels == out_channels and groups == 1)
        :param groups: Number of groups for group convolution
        """
        super(RepConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = sorted(kernel_sizes)
        self.stride = stride
        self.dilation = dilation
        self.groups = groups


        self.max_kernel_size = max(self.kernel_sizes)
        self.padding = padding if padding is not None else (self.max_kernel_size - 1) // 2 * dilation
        if use_bn:
            self.convs = nn.ModuleList([
            ConvBn(in_channels, out_channels, k, stride=stride, dilation=dilation,
                      padding=(k - 1) // 2 * dilation, bias=False, groups=groups)
            for k in self.kernel_sizes
            ])
        else:
            self.convs = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, k, stride=stride, dilation=dilation,
                      padding=(k - 1) // 2 * dilation, bias=False, groups=groups)
            for k in self.kernel_sizes
            ])
        self.deploy = deploy 
        if deploy:


    def forward(self, x):
        if self.reparameterized:
            return self.rbr_reparam(
                x
            )
        else:
            conv_outputs = []
            for conv in self.convs:
                conv_outputs.append(conv(x))
            if self.use_identity:
                return sum(conv_outputs) / len(self.convs) + x
            else:
                return sum(conv_outputs) / len(self.convs)

    def _convert_weight_and_bias(self):
        weight = self.convs[-1].weight
        bias = self.convs[-1].bias if self.convs[-1].bias is not None else torch.zeros(self.out_channels, device=weight.device)

        for conv in self.convs[:-1]:
            pad = (self.max_kernel_size - conv.weight.shape[-1]) // 2
            weight = weight + F.pad(conv.weight, [pad, pad, pad, pad])
            conv_bias = conv.bias if conv.bias is not None else torch.zeros(self.out_channels, device=weight.device)
            bias = bias + conv_bias

        weight = weight / len(self.convs)
        bias = bias / len(self.convs)

        if self.use_identity:
            pad = (self.max_kernel_size - 1) // 2
            identity_weight = F.pad(
                torch.eye(self.out_channels, self.in_channels // self.groups).unsqueeze(-1).unsqueeze(-1).to(weight.device),
                [pad, pad, pad, pad]
            ).repeat(1, self.groups, 1, 1).reshape(self.out_channels, self.in_channels // self.groups, self.max_kernel_size, self.max_kernel_size)
            weight = weight + identity_weight

        self.weight = weight.detach()
        self.bias = bias.detach()

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        self._convert_weight_and_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=(self.max_kernel_size, self.max_kernel_size),
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True
        )
        self.rbr_reparam.weight.data = self.weight
        self.rbr_reparam.bias.data = self.bias
        del self.convs
        self.reparameterized = True



In [10]:
import torch

# Initialize the RepConv2d model
model = RepConv2d(in_channels=3, out_channels=16, kernel_sizes=(3, 5, 7), stride=1).eval()

# Create a dummy input tensor
x = torch.randn(1, 3, 224, 224)  # Batch size 1, 3 channels, 224x224 image

# Run forward pass before deployment (pre-deployment)
output_before_deploy = model(x)

# Deploy the model
model.switch_to_deploy()

# Run forward pass after deployment (post-deployment)
output_after_deploy = model(x)

# Compare the outputs
print(torch.allclose(output_before_deploy, output_after_deploy))  # Should be True if the outputs are the same


False


In [20]:
output_before_deploy

tensor([[[[-1.0509e-01, -1.6618e-01,  2.3734e-02,  ...,  4.3876e-02,
           -2.7332e-01,  7.8780e-02],
          [-4.3497e-02, -2.1642e-01,  1.4031e-01,  ...,  1.1755e-01,
           -2.8986e-01, -7.6739e-02],
          [ 1.0535e-01, -2.4501e-01, -9.3907e-02,  ...,  1.5210e-01,
            1.1109e-01, -3.8096e-02],
          ...,
          [-4.9162e-02,  1.0591e-01,  2.6854e-01,  ...,  3.2434e-01,
            1.1938e-01,  1.5615e-01],
          [ 1.6362e-01,  3.3904e-01,  1.2398e-01,  ..., -1.7295e-01,
            1.4540e-01,  1.7867e-03],
          [ 3.7581e-02,  3.7735e-01,  7.6096e-02,  ...,  3.2586e-01,
            5.5626e-01, -5.7446e-02]],

         [[-1.8109e-01,  3.9394e-02,  3.0771e-01,  ...,  8.3651e-02,
           -3.2509e-01, -1.4123e-01],
          [ 7.4080e-02, -2.8459e-01, -5.4809e-02,  ...,  4.1137e-02,
            1.3135e-01,  3.6649e-01],
          [-9.8002e-02,  4.2682e-01,  2.3538e-03,  ...,  3.5936e-01,
            2.5884e-01, -3.5275e-01],
          ...,
     

In [21]:
output_after_deploy

tensor([[[[-1.0509e-01, -1.6618e-01,  2.3734e-02,  ...,  4.3876e-02,
           -2.7332e-01,  7.8780e-02],
          [-4.3497e-02, -2.1642e-01,  1.4031e-01,  ...,  1.1755e-01,
           -2.8986e-01, -7.6739e-02],
          [ 1.0535e-01, -2.4501e-01, -9.3907e-02,  ...,  1.5210e-01,
            1.1109e-01, -3.8096e-02],
          ...,
          [-4.9162e-02,  1.0591e-01,  2.6854e-01,  ...,  3.2434e-01,
            1.1938e-01,  1.5615e-01],
          [ 1.6362e-01,  3.3904e-01,  1.2398e-01,  ..., -1.7295e-01,
            1.4540e-01,  1.7867e-03],
          [ 3.7581e-02,  3.7735e-01,  7.6096e-02,  ...,  3.2586e-01,
            5.5626e-01, -5.7446e-02]],

         [[-1.8109e-01,  3.9394e-02,  3.0771e-01,  ...,  8.3651e-02,
           -3.2509e-01, -1.4123e-01],
          [ 7.4080e-02, -2.8459e-01, -5.4809e-02,  ...,  4.1137e-02,
            1.3135e-01,  3.6649e-01],
          [-9.8002e-02,  4.2682e-01,  2.3537e-03,  ...,  3.5936e-01,
            2.5884e-01, -3.5275e-01],
          ...,
     