In [20]:
# Copyright (c) OpenMMLab. All rights reserved.
from operator import attrgetter
from typing import List, Union

import torch
import torch.nn as nn


def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Code borrowed from mmcv 2.0.1, so that this feature can be used for old
    mmcv versions.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for training as well. It reduces memory and computation cost.
    Args:
        bn (_BatchNorm): a BatchNorm module.
        conv (nn._ConvNd): a conv module
        x (torch.Tensor): Input feature map.
    """
    # These lines of code are designed to deal with various cases
    # like bn without affine transform, and conv without bias
    weight_on_the_fly = conv.weight
    if conv.bias is not None:
        bias_on_the_fly = conv.bias
    else:
        bias_on_the_fly = torch.zeros_like(bn.running_var)

    if bn.weight is not None:
        bn_weight = bn.weight
    else:
        bn_weight = torch.ones_like(bn.running_var)

    if bn.bias is not None:
        bn_bias = bn.bias
    else:
        bn_bias = torch.zeros_like(bn.running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    weight_coeff = torch.rsqrt(bn.running_var +
                               bn.eps).reshape([-1] + [1] *
                                               (len(conv.weight.shape) - 1))
    # shape of [C_out, 1, 1, 1] in Conv2d
    coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() *\
        (bias_on_the_fly - bn.running_mean)

    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)


def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """This function controls whether to use `efficient_conv_bn_eval_forward`.

    If the following `bn` is in `eval` mode, then we turn on the special
    `efficient_conv_bn_eval_forward`.
    """
    if not bn.training:
        # bn in eval mode
        output = efficient_conv_bn_eval_forward(bn, conv, x)
        return output
    else:
        conv_out = conv._conv_forward(x, conv.weight, conv.bias)
        return bn(conv_out)


def efficient_conv_bn_eval_graph_transform(fx_model):
    """Find consecutive conv+bn calls in the graph, inplace modify the graph
    with the fused operation."""
    modules = dict(fx_model.named_modules())

    patterns = [(torch.nn.modules.conv._ConvNd,
                 torch.nn.modules.batchnorm._BatchNorm)]

    pairs = []
    # Iterate through nodes in the graph to find ConvBN blocks
    for node in fx_model.graph.nodes:
        # If our current node isn't calling a Module then we can ignore it.
        if node.op != 'call_module':
            continue
        target_module = modules[node.target]
        found_pair = False
        for conv_class, bn_class in patterns:
            if isinstance(target_module, bn_class):
                source_module = modules[node.args[0].target]
                if isinstance(source_module, conv_class):
                    found_pair = True
        # Not a conv-BN pattern or output of conv is used by other nodes
        if not found_pair or len(node.args[0].users) > 1:
            continue

        # Find a pair of conv and bn computation nodes to optimize
        conv_node = node.args[0]
        bn_node = node
        pairs.append([conv_node, bn_node])

    for conv_node, bn_node in pairs:
        # set insertion point
        fx_model.graph.inserting_before(conv_node)
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `fx_model.graph.get_attr` and
        # `fx_model.graph.call_function` does not allow the `name` argument.
        conv_get_node = fx_model.graph.create_node(
            op='get_attr', target=conv_node.target, name='get_conv')
        bn_get_node = fx_model.graph.create_node(
            op='get_attr', target=bn_node.target, name='get_bn')
        # prepare args for the fused function
        args = (bn_get_node, conv_get_node, conv_node.args[0])
        # create a new node
        new_node = fx_model.graph.create_node(
            op='call_function',
            target=efficient_conv_bn_eval_control,
            args=args,
            name='efficient_conv_bn_eval')
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        fx_model.graph.erase_node(bn_node)
        fx_model.graph.erase_node(conv_node)

    # regenerate the code
    fx_model.graph.lint()
    fx_model.recompile()


def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
    import torch.fx as fx

    # currently we use `fx.symbolic_trace` to trace models.
    # in the future, we might turn to pytorch 2.0 compile infrastructure to
    # get the `fx.GraphModule` IR. Nonetheless, the graph transform function
    # can remain unchanged. We just need to change the way
    # we get `fx.GraphModule`.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    efficient_conv_bn_eval_graph_transform(fx_model)
    model.forward = fx_model.forward


def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
                                   modules: Union[List[str], str]):
    if isinstance(modules, str):
        modules = [modules]
    for module_name in modules:
        module = attrgetter(module_name)(model)
        turn_on_efficient_conv_bn_eval_for_single_model(module)

In [44]:
import torch
import torch.nn as nn
import torch.fx as fx

# 定义一个简单的卷积神经网络作为示例
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc = nn.Linear(32 * 32 * 32, 10)  # 假设输入图像大小为32x32

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc(x)
        return x

class SimpleCNN2(nn.Module):
    def __init__(self):
        super(SimpleCNN2 ,self).__init__()
        self.backbone = SimpleCNN()

    def forward(self, x):
        x = self.backbone(x)
        return x
# 调用模型并检查优化情况
model = SimpleCNN2()

# 创建一个随机输入
input_tensor = torch.randn(1, 3, 32, 32)

# 调用函数以启用高效的Conv + BN融合并打印优化前后的计算图


model = model.eval()
turn_on_efficient_conv_bn_eval(model, ['backbone'])
print(model.forward)
# 执行前向传播
output = model(input_tensor)

# 打印输出的形状
print(f"Output shape: {output.shape}")


<bound method SimpleCNN2.forward of SimpleCNN2(
  (backbone): SimpleCNN(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc): Linear(in_features=32768, out_features=10, bias=True)
  )
)>
Output shape: torch.Size([1, 10])


In [29]:
# 检查替换后的 forward 方法
print(model.forward)


<bound method forward of SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=32768, out_features=10, bias=True)
)>


In [30]:
# 打印 FX 计算图
import torch.fx as fx
fx_model = fx.symbolic_trace(model)
print(fx_model.graph)


graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%bn1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %bn2 : [num_users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
    %relu_1 : [num_users=2] = call_function[target=torch.relu](args = (%bn2,), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%relu_1, 0), kwargs = {})
    %view : [num_users=1] = call_method[target=view](args = (%relu_1, %size, -1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%view,), kwargs = {})
    return fc


In [31]:
# 打印 FX 计算图
import torch.fx as fx
fx_model = fx.symbolic_trace(model)
print(fx_model.graph)


graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%bn1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %bn2 : [num_users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
    %relu_1 : [num_users=2] = call_function[target=torch.relu](args = (%bn2,), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%relu_1, 0), kwargs = {})
    %view : [num_users=1] = call_method[target=view](args = (%relu_1, %size, -1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%view,), kwargs = {})
    return fc


In [28]:
# 检查替换后的 forward 方法
print(model.forward)


<bound method forward of SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=32768, out_features=10, bias=True)
)>


In [20]:
print(output_before_fusion)

tensor([[[[-0.4928, -0.7771, -1.0377,  ..., -0.8974, -0.3298, -0.3931],
          [ 0.2901,  0.2581,  0.7802,  ...,  0.9171,  0.9820,  0.1806],
          [-1.3592,  0.2125,  1.1953,  ..., -0.4506,  0.3617,  1.6183],
          ...,
          [-0.5671,  0.9266,  2.0484,  ..., -0.4558, -0.8388,  0.0452],
          [ 0.0935, -0.7979,  0.9503,  ...,  0.5598, -1.9626, -0.2911],
          [ 0.1995,  0.6507, -0.2142,  ..., -0.3146, -0.2791,  0.4547]],

         [[ 1.4356,  0.9129, -0.5031,  ..., -0.1020,  1.0058, -0.4193],
          [-1.4852, -0.4363, -0.8256,  ..., -0.9477, -0.0374,  0.1647],
          [-0.3652, -0.2904,  2.5691,  ...,  0.5813, -0.8416, -0.9438],
          ...,
          [ 0.4869,  2.5273,  1.1332,  ...,  2.2675,  0.0770,  0.6853],
          [ 0.0430, -1.8617, -0.6892,  ..., -0.8342, -0.1863, -0.8786],
          [ 0.0802,  0.3856, -0.6600,  ...,  0.8860,  1.6499, -0.8380]],

         [[-0.1324,  0.0349,  0.6080,  ..., -0.8111,  0.4878, -0.4131],
          [-0.3960, -0.8790, -

In [21]:
print(output_after_fusion)

tensor([[[[-0.2322, -0.3684, -0.4932,  ..., -0.4260, -0.1541, -0.1844],
          [ 0.1428,  0.1274,  0.3775,  ...,  0.4430,  0.4741,  0.0903],
          [-0.6471,  0.1056,  0.5763,  ..., -0.2120,  0.1770,  0.7788],
          ...,
          [-0.2678,  0.4476,  0.9848,  ..., -0.2145, -0.3979,  0.0255],
          [ 0.0486, -0.3783,  0.4589,  ...,  0.2719, -0.9361, -0.1356],
          [ 0.0993,  0.3154, -0.0988,  ..., -0.1468, -0.1298,  0.2216]],

         [[ 0.8490,  0.5423, -0.2886,  ..., -0.0532,  0.5968, -0.2394],
          [-0.8648, -0.2494, -0.4778,  ..., -0.5494, -0.0153,  0.1032],
          [-0.2077, -0.1638,  1.5140,  ...,  0.3477, -0.4872, -0.5472],
          ...,
          [ 0.2923,  1.4895,  0.6715,  ...,  1.3371,  0.0518,  0.4087],
          [ 0.0319, -1.0857, -0.3978,  ..., -0.4829, -0.1027, -0.5089],
          [ 0.0537,  0.2329, -0.3806,  ...,  0.5265,  0.9747, -0.4851]],

         [[-0.0990,  0.0130,  0.3966,  ..., -0.5533,  0.3161, -0.2869],
          [-0.2754, -0.5987, -

In [26]:
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from unittest import TestCase

import torch
from torch import nn





class BackboneModel(nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.conv1 = nn.Conv2d(6, 6, 6)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 6, 6)
        self.bn2 = nn.BatchNorm2d(6)
        self.conv3 = nn.Conv2d(6, 6, 6)
        self.bn3 = nn.BatchNorm2d(6)

    def forward(self, x):

        # this conv-bn pair can use efficient_conv_bn_eval feature
        x = self.bn1(self.conv1(x))
        # this conv-bn pair can use efficient_conv_bn_eval feature
        # only for the second `self.conv2` call.
        x = self.bn2(self.conv2(self.conv2(x)))
        # this conv-bn pair can use efficient_conv_bn_eval feature
        # just for the first forward of the `self.bn3`
        x = self.bn3(self.bn3(self.conv3(x)))
        return x




model = BackboneModel()
model.eval()
input = torch.randn(64, 6, 32, 32)
output = model(input)
turn_on_efficient_conv_bn_eval_for_single_model(model)
output2 = model(input)
print((output - output2).abs().max().item())


2.0116567611694336e-07


In [14]:
import torch
import torch.nn as nn

class ConvBn(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBn, self).__init__()
        self.deploy = deploy
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.bn = nn.BatchNorm2d(num_features=out_channels)


    def _fuse_bn_tensor(self, conv, bn):
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std


    def switch_to_deploy(self):
        if self.bn.training:
            raise RuntimeError("BatchNorm should be in evaluation mode (eval) before deployment.")
        deploy_k, deploy_b = self._fuse_bn_tensor(self.conv, self.bn)
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                                    kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                                    padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True,
                                    padding_mode=self.conv.padding_mode)
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.conv(input)
            square_outputs = self.bn(square_outputs)
            return square_outputs


conv_bn_layer = ConvBn(3, 16, 3, 1, 1).eval()
input_tensor = torch.randn(1, 3, 32, 32)

output_before_fusion = conv_bn_layer(input_tensor)
conv_bn_layer.switch_to_deploy()
output_after_fusion = conv_bn_layer(input_tensor)

print("Output before fusion: ", output_before_fusion.shape)
print("Output after fusion: ", output_after_fusion.shape)
print("Are the outputs close? ", torch.allclose(output_before_fusion, output_after_fusion, atol=1e-6))


Output before fusion:  torch.Size([1, 16, 32, 32])
Output after fusion:  torch.Size([1, 16, 32, 32])
Are the outputs close?  True


In [1]:

from model.base_module import BaseModule
import torch
import torch.nn as nn
class ConvBn(BaseModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.groups = groups
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
    def forward(self, x):
        if hasattr(self, 'rbr_reparam'):
            return self.rbr_reparam(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

init_cfg = [
    dict(type='Constant',val=2,layer='Conv2d'),
    dict(type='Constant',override=dict(name='bn'),val =3,bias=1), 
]
# Example of usage:
conv_bn_layer = ConvBn(3, 16, 3, 1, 1,init_cfg = init_cfg )

# Define the initialization configuration for Conv2d and BatchNorm2d
  # Example using Xavier initialization for Conv2d

# Initialize weights
conv_bn_layer.init_weights()

# Create a random input tensor
input_tensor = torch.randn(1, 3, 32, 32)

# Forward pass to see the output
output = conv_bn_layer(input_tensor)

# Print initialized weights of Conv2d and BatchNorm2d
# 打印每一层的权重均值和偏置
param_mean = {name: p.mean().item() for name, p in conv_bn_layer.named_parameters()}
for name, mean in param_mean.items():
    print(f"Mean of {name}: {mean:.6f}")


Mean of conv.weight: 2.000000
Mean of bn.weight: 3.000000
Mean of bn.bias: 1.000000


In [3]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.19.8-py3-none-macosx_11_0_arm64.whl.metadata (10 kB)
Collecting click!=8.0.0,>=7.1 (from wandb)
  Downloading click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.44-py3-none-any.whl.metadata (13 kB)
Collecting protobuf!=4.21.0,!=5.28.0,<6,>=3.19.0 (from wandb)
  Downloading protobuf-5.29.4-cp38-abi3-macosx_10_9_universal2.whl.metadata (592 bytes)
Collecting pydantic<3,>=2.6 (from wandb)
  Downloading pydantic-2.10.6-py3-none-any.whl.metadata (30 kB)
Collecting sentry-sdk>=2.0.0 (from wandb)
  Downloading sentry_sdk-2.24.0-py2.py3-none-any.whl.metadata (10 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.5-cp310-cp310-macosx_11_0_arm64.whl.metadata (10 kB)
Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)
  Downloading 

In [1]:
from apps.builder import make_model
from apps.registry import register_dir


register_dir('model')
model =  make_model(dict(type='resnet18',pretrained =True,dataset='imagenet1k_v1'))
model.init_weights()
param_mean = {name: p.mean().item() for name, p in model.named_parameters()}
for name, mean in param_mean.items():
    print(f"Mean of {name}: {mean:.6f}")

Processing directory: model
Loading module: model.backbone
Loading module: model.efficient_conv_bn_eval
Loading module: model.resnet
Loading module: model.utils
Unexpected error while loading module model.utils: name 'nn' is not defined
Loading module: model.weight_url
Loading module: model.base_module
Loading module: model.weight_init
Processing directory: model/__pycache__
Processing directory: model/segmentor
Loading module: model.segmentor.encoderdecoder
ModuleNotFoundError: No module named 'mmengine'
Loading module: model.segmentor.segmentor
Unexpected error while loading module model.segmentor.segmentor: name 'Tensor' is not defined
Processing directory: model/segmentor/__pycache__
Mean of conv1.weight: 0.000029
Mean of bn1.weight: 0.257577
Mean of bn1.bias: 0.181120
Mean of layer1.0.conv1.weight: -0.003087
Mean of layer1.0.bn1.weight: 0.339601
Mean of layer1.0.bn1.bias: -0.034137
Mean of layer1.0.conv2.weight: -0.000889
Mean of layer1.0.bn2.weight: 0.333055
Mean of layer1.0.bn2.

In [26]:
class ConvBn(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBn, self).__init__()
        self.deploy = deploy
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.bn = nn.BatchNorm2d(num_features=out_channels)


    def _fuse_bn_tensor(self, conv, bn):
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std


    def switch_to_deploy(self):
        deploy_k, deploy_b = self._fuse_bn_tensor(self.conv, self.bn)
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                                    kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                                    padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True,
                                    padding_mode=self.conv.padding_mode)
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.conv(input)
            square_outputs = self.bn(square_outputs)
            return square_outputs

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class RepConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes=(3, 7, 11), stride=1, padding=None, dilation=1, deploy=False, groups=1,use_bn):
        """
        Reparameterized convolution module for Conv2d, with an interface consistent with nn.Conv2d.
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_sizes: List of kernel sizes
        :param stride: Stride of the convolution
        :param padding: Padding size; if None, automatically computes "same" padding
        :param dilation: Dilation rate
        :param bias: Whether to include a bias term
        :param use_identity: Whether to use identity mapping (only effective when in_channels == out_channels and groups == 1)
        :param groups: Number of groups for group convolution
        """
        super(RepConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = sorted(kernel_sizes)
        self.stride = stride
        self.dilation = dilation
        self.groups = groups


        self.max_kernel_size = max(self.kernel_sizes)
        self.padding = padding if padding is not None else (self.max_kernel_size - 1) // 2 * dilation
        if use_bn:
            self.convs = nn.ModuleList([
            ConvBn(in_channels, out_channels, k, stride=stride, dilation=dilation,
                      padding=(k - 1) // 2 * dilation, bias=False, groups=groups)
            for k in self.kernel_sizes
            ])
        else:
            self.convs = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, k, stride=stride, dilation=dilation,
                      padding=(k - 1) // 2 * dilation, bias=False, groups=groups)
            for k in self.kernel_sizes
            ])
        self.deploy = deploy 
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)

    def forward(self, x):
        if self.reparameterized:
            return self.rbr_reparam(
                x
            )
        else:
            conv_outputs = []
            for conv in self.convs:
                conv_outputs.append(conv(x))
            if self.use_identity:
                return sum(conv_outputs) / len(self.convs) + x
            else:
                return sum(conv_outputs) / len(self.convs)

    def _convert_weight_and_bias(self):
        weight = self.convs[-1].weight
        bias = self.convs[-1].bias if self.convs[-1].bias is not None else torch.zeros(self.out_channels, device=weight.device)

        for conv in self.convs[:-1]:
            pad = (self.max_kernel_size - conv.weight.shape[-1]) // 2
            weight = weight + F.pad(conv.weight, [pad, pad, pad, pad])
            conv_bias = conv.bias if conv.bias is not None else torch.zeros(self.out_channels, device=weight.device)
            bias = bias + conv_bias

        weight = weight / len(self.convs)
        bias = bias / len(self.convs)

        if self.use_identity:
            pad = (self.max_kernel_size - 1) // 2
            identity_weight = F.pad(
                torch.eye(self.out_channels, self.in_channels // self.groups).unsqueeze(-1).unsqueeze(-1).to(weight.device),
                [pad, pad, pad, pad]
            ).repeat(1, self.groups, 1, 1).reshape(self.out_channels, self.in_channels // self.groups, self.max_kernel_size, self.max_kernel_size)
            weight = weight + identity_weight

        self.weight = weight.detach()
        self.bias = bias.detach()

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        self._convert_weight_and_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=(self.max_kernel_size, self.max_kernel_size),
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True
        )
        self.rbr_reparam.weight.data = self.weight
        self.rbr_reparam.bias.data = self.bias
        del self.convs
        self.reparameterized = True



In [10]:
import torch

# Initialize the RepConv2d model
model = RepConv2d(in_channels=3, out_channels=16, kernel_sizes=(3, 5, 7), stride=1).eval()

# Create a dummy input tensor
x = torch.randn(1, 3, 224, 224)  # Batch size 1, 3 channels, 224x224 image

# Run forward pass before deployment (pre-deployment)
output_before_deploy = model(x)

# Deploy the model
model.switch_to_deploy()

# Run forward pass after deployment (post-deployment)
output_after_deploy = model(x)

# Compare the outputs
print(torch.allclose(output_before_deploy, output_after_deploy))  # Should be True if the outputs are the same


False


In [20]:
output_before_deploy

tensor([[[[-1.0509e-01, -1.6618e-01,  2.3734e-02,  ...,  4.3876e-02,
           -2.7332e-01,  7.8780e-02],
          [-4.3497e-02, -2.1642e-01,  1.4031e-01,  ...,  1.1755e-01,
           -2.8986e-01, -7.6739e-02],
          [ 1.0535e-01, -2.4501e-01, -9.3907e-02,  ...,  1.5210e-01,
            1.1109e-01, -3.8096e-02],
          ...,
          [-4.9162e-02,  1.0591e-01,  2.6854e-01,  ...,  3.2434e-01,
            1.1938e-01,  1.5615e-01],
          [ 1.6362e-01,  3.3904e-01,  1.2398e-01,  ..., -1.7295e-01,
            1.4540e-01,  1.7867e-03],
          [ 3.7581e-02,  3.7735e-01,  7.6096e-02,  ...,  3.2586e-01,
            5.5626e-01, -5.7446e-02]],

         [[-1.8109e-01,  3.9394e-02,  3.0771e-01,  ...,  8.3651e-02,
           -3.2509e-01, -1.4123e-01],
          [ 7.4080e-02, -2.8459e-01, -5.4809e-02,  ...,  4.1137e-02,
            1.3135e-01,  3.6649e-01],
          [-9.8002e-02,  4.2682e-01,  2.3538e-03,  ...,  3.5936e-01,
            2.5884e-01, -3.5275e-01],
          ...,
     

In [21]:
output_after_deploy

tensor([[[[-1.0509e-01, -1.6618e-01,  2.3734e-02,  ...,  4.3876e-02,
           -2.7332e-01,  7.8780e-02],
          [-4.3497e-02, -2.1642e-01,  1.4031e-01,  ...,  1.1755e-01,
           -2.8986e-01, -7.6739e-02],
          [ 1.0535e-01, -2.4501e-01, -9.3907e-02,  ...,  1.5210e-01,
            1.1109e-01, -3.8096e-02],
          ...,
          [-4.9162e-02,  1.0591e-01,  2.6854e-01,  ...,  3.2434e-01,
            1.1938e-01,  1.5615e-01],
          [ 1.6362e-01,  3.3904e-01,  1.2398e-01,  ..., -1.7295e-01,
            1.4540e-01,  1.7867e-03],
          [ 3.7581e-02,  3.7735e-01,  7.6096e-02,  ...,  3.2586e-01,
            5.5626e-01, -5.7446e-02]],

         [[-1.8109e-01,  3.9394e-02,  3.0771e-01,  ...,  8.3651e-02,
           -3.2509e-01, -1.4123e-01],
          [ 7.4080e-02, -2.8459e-01, -5.4809e-02,  ...,  4.1137e-02,
            1.3135e-01,  3.6649e-01],
          [-9.8002e-02,  4.2682e-01,  2.3537e-03,  ...,  3.5936e-01,
            2.5884e-01, -3.5275e-01],
          ...,
     

In [169]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBn, self).__init__()
        self.deploy = deploy
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size), stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.bn = nn.BatchNorm2d(num_features=out_channels)

    def _fuse_bn_tensor(self, conv, bn):
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std

    def switch_to_deploy(self):
        if self.bn.training:
            raise RuntimeError("BatchNorm should be in evaluation mode (eval) before deployment.")
        deploy_k, deploy_b = self._fuse_bn_tensor(self.conv, self.bn)
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                                    kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                                    padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True,
                                    padding_mode=self.conv.padding_mode)
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.conv(input)
            square_outputs = self.bn(square_outputs)
            return square_outputs
class RepConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes=(0, 3, 7, 11), stride=1, dilation=1, groups=1, use_bn=True, deploy=False):
        super(RepConv2d, self).__init__()
        self.deploy = deploy
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = sorted(kernel_sizes)
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.max_kernel_size = max(self.kernel_sizes)
        self.max_padding = (self.max_kernel_size - 1) // 2 * dilation

        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=self.max_kernel_size, stride=stride,
                                        padding=self.max_padding, dilation=dilation, groups=groups, bias=True)
        else:
            convs = []
            if use_bn:
                if 0 in self.kernel_sizes:
                    assert in_channels == out_channels, "in_channels and out_channels should be equal when kernel_size is 0"
                    convs.append(nn.BatchNorm2d(out_channels))  # Add BatchNorm for kernel size 0
                    self.kernel_sizes.remove(0)
                # Add ConvBn layers for other kernel sizes
                convs.extend([
                    ConvBn(in_channels, out_channels, k, stride=stride, dilation=dilation,
                           padding=(k - 1) // 2 * dilation, groups=groups)
                    for k in self.kernel_sizes
                ])
            else:
                if 0 in self.kernel_sizes:
                    assert in_channels == out_channels, "in_channels and out_channels should be equal when kernel_size is 0"
                    convs.append(nn.Identity())  # Identity for kernel size 0
                    self.kernel_sizes.remove(0)
                # Add Conv2d layers for other kernel sizes
                convs.extend([
                    nn.Conv2d(in_channels, out_channels, k, stride=stride, dilation=dilation,
                              padding=(k - 1) // 2 * dilation, bias=True, groups=groups)
                    for k in self.kernel_sizes
                ])
            self.convs = nn.ModuleList(convs)

    def forward(self, x):
        if self.deploy:
            return self.fused_conv(x)
        else:
            conv_outputs = []
            for conv in self.convs:
                conv_outputs.append(conv(x))
            return sum(conv_outputs)

    def _convert_weight_and_bias(self):
        if hasattr(self.convs[-1], 'switch_to_deploy'):
            self.convs[-1].switch_to_deploy()
            weight = self.convs[-1].fused_conv.weight
            bias = self.convs[-1].fused_conv.bias
        else:
            weight = self.convs[-1].weight
            print(weight)
            bias = self.convs[-1].bias

        for conv in self.convs[:-1]:
            if isinstance(conv, nn.BatchNorm2d):
                std = (conv.running_var + conv.eps).sqrt()
                t = (conv.weight / std).reshape(-1, 1, 1, 1)
                pad = (self.max_kernel_size - conv.weight.shape[-1]) // 2
                identity_weight = F.pad(
                    torch.eye(self.out_channels, self.in_channels // self.groups).unsqueeze(-1).unsqueeze(-1).to(weight.device),
                    [pad, pad, pad, pad]
                ).repeat(1, self.groups, 1, 1).reshape(self.out_channels, self.in_channels // self.groups, self.max_kernel_size, self.max_kernel_size)
                weight = weight + identity_weight* t
                bias = bias + conv.bias - conv.running_mean * conv.weight / std
            elif isinstance(conv, nn.Identity):
                pad = (self.max_kernel_size - 1) // 2
                identity_weight = F.pad(
                    torch.eye(self.out_channels, self.in_channels // self.groups).unsqueeze(-1).unsqueeze(-1).to(weight.device),
                    [pad, pad, pad, pad]
                ).repeat(1, self.groups, 1, 1).reshape(self.out_channels, self.in_channels // self.groups, self.max_kernel_size, self.max_kernel_size)
                print(identity_weight)
                weight = weight + identity_weight
            elif isinstance(conv, ConvBn):
                conv.switch_to_deploy()
                conv_weight = conv.fused_conv.weight
                pad = (self.max_kernel_size - conv.fused_conv.weight.shape[-1]) // 2
                conv_weight = F.pad(conv_weight, [pad, pad, pad, pad])
                conv_bias = conv.fused_conv.bias
                weight = weight + conv_weight
                bias = bias + conv_bias
            elif isinstance(conv, nn.Conv2d):
                conv_weight = conv.weight
                print(conv.weight.shape[-1])
                pad = (self.max_kernel_size - conv.weight.shape[-1]) // 2
                conv_weight = F.pad(conv_weight, [pad, pad, pad, pad])
                print(conv_weight)
                conv_bias = conv.bias
                weight = weight + conv_weight
                bias = bias + conv_bias
            else:
                raise TypeError(f"Unsupported layer type: {type(conv)}")

        self.weight = weight.detach()
        self.bias = bias.detach()
    @torch.no_grad()
    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        self._convert_weight_and_bias()
        self.fused_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=(self.max_kernel_size, self.max_kernel_size),
            stride=self.stride,
            padding=self.max_padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True
        )
        self.fused_conv.weight.data = self.weight
        self.fused_conv.bias.data = self.bias
        print(self.fused_conv.weight,11111)
        del self.convs
        self.deploy = True
def initialize_conv_weights_to_one(model):
    """
    Initialize all Conv2d layers in the model to have weight values equal to 1.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # Initialize the weight to 1
            nn.init.constant_(module.weight, 2)
            # Initialize bias to 0 if it exists
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

In [170]:
import torch

# Create the RepConv2d model with kernel sizes that include 0
model = RepConv2d(in_channels=16, out_channels=16, kernel_sizes=(0,1, 3,7,11), stride=1, dilation=3, groups=1, use_bn=False, deploy=False).eval()
print(model)
#initialize_conv_weights_to_one(model)
# Create a dummy input tensor (e.g., batch size of 1, 3 channels, 224x224 image)
x = torch.randn(1, 16, 224, 224)

# Run the model before deployment (non-deployed state)

output_before_deploy = model(x)

# Switch the model to deploy mode
model.switch_to_deploy()
print(model)
# Run the model after deployment (fused conv state)

output_after_deploy = model(x)

# Check if the outputs are 
# the same
if torch.allclose(output_before_deploy, output_after_deploy, atol=1e-5):
    print("The outputs are the same before and after deployment.")
else:
    print("The outputs differ before and after deployment.")


RepConv2d(
  (convs): ModuleList(
    (0): Identity()
    (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), dilation=(3, 3))
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))
    (3): Conv2d(16, 16, kernel_size=(7, 7), stride=(1, 1), padding=(9, 9), dilation=(3, 3))
    (4): Conv2d(16, 16, kernel_size=(11, 11), stride=(1, 1), padding=(15, 15), dilation=(3, 3))
  )
)
Parameter containing:
tensor([[[[ 1.8109e-02, -2.0485e-02,  5.2946e-03,  ...,  1.5188e-02,
           -7.0638e-03,  5.3876e-03],
          [-1.0938e-02, -2.8759e-03,  1.4745e-02,  ..., -7.8913e-03,
           -2.0768e-02,  8.0515e-03],
          [ 1.8660e-02, -2.1418e-02, -3.0451e-03,  ..., -9.5498e-03,
            8.5726e-03,  7.6391e-03],
          ...,
          [ 1.2474e-02, -3.2262e-03, -1.6963e-04,  ...,  7.2465e-03,
           -1.6129e-02,  1.1251e-02],
          [-1.8547e-02, -2.0286e-02, -2.1047e-02,  ..., -1.9354e-02,
            1.1459e-02,  1.6453e-02],
          [-

In [107]:
output_before_deploy

tensor([[[[  2.7438,   3.8604,   0.7208,  ...,   5.0495,  -3.1655,   0.6610],
          [ -6.2636,  -6.7629,   1.0658,  ...,   8.5792,   4.3009,  -2.7788],
          [ -3.7285,  -2.6189,   0.3735,  ...,  13.6187,   2.4866,   2.6248],
          ...,
          [ -7.9906, -14.5445,  -9.1656,  ..., -11.7440,  -1.0764,   5.0076],
          [ -5.1502, -10.2577,  -5.7804,  ...,   5.8917,   2.9787,  -2.1047],
          [ -1.5465,  -1.5369,  -1.9339,  ...,   5.5483,   3.7897,  -7.4650]]]])

In [108]:
output_after_deploy 

tensor([[[[  2.7438,   3.8604,   0.7208,  ...,   5.0495,  -3.1655,   0.6610],
          [ -6.2636,  -6.7629,   1.0658,  ...,   8.5792,   4.3009,  -2.7788],
          [ -3.7285,  -2.6189,   0.3735,  ...,  13.6187,   2.4866,   2.6248],
          ...,
          [ -7.9906, -14.5445,  -9.1656,  ..., -11.7440,  -1.0764,   5.0076],
          [ -5.1502, -10.2577,  -5.7804,  ...,   5.8917,   2.9787,  -2.1047],
          [ -1.5465,  -1.5369,  -1.9339,  ...,   5.5483,   3.7897,  -7.4650]]]])

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBn, self).__init__()
        self.deploy = deploy
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size), stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.bn = nn.BatchNorm2d(num_features=out_channels)

    def _fuse_bn_tensor(self, conv, bn):
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std

    def switch_to_deploy(self):
        if self.bn.training:
            raise RuntimeError("BatchNorm should be in evaluation mode (eval) before deployment.")
        deploy_k, deploy_b = self._fuse_bn_tensor(self.conv, self.bn)
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                                    kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                                    padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True,
                                    padding_mode=self.conv.padding_mode)
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.conv(input)
            square_outputs = self.bn(square_outputs)
            return square_outputs
class RepConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes=(0, 3, 7, 11), stride=1, dilation=1, groups=1, use_bn=True, deploy=False):
        super(RepConv2d, self).__init__()
        self.deploy = deploy
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = sorted(kernel_sizes)
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.max_kernel_size = max(self.kernel_sizes)
        self.max_padding = (self.max_kernel_size - 1) // 2 * dilation

        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=self.max_kernel_size, stride=stride,
                                        padding=self.max_padding, dilation=dilation, groups=groups, bias=True)
        else:
            convs = []
            if use_bn:
                if 0 in self.kernel_sizes:
                    assert in_channels == out_channels, "in_channels and out_channels should be equal when kernel_size is 0"
                    assert stride == 1, "stride should be 1 when kernel size is 0"
                    convs.append(nn.BatchNorm2d(out_channels))  # Add BatchNorm for kernel size 0
                    self.kernel_sizes.remove(0)
                # Add ConvBn layers for other kernel sizes
                convs.extend([
                    ConvBn(in_channels, out_channels, k, stride=stride, dilation=dilation,
                           padding=(k - 1) // 2 * dilation, groups=groups)
                    for k in self.kernel_sizes
                ])
            else:
                if 0 in self.kernel_sizes:
                    assert in_channels == out_channels, "in_channels and out_channels should be equal when kernel_size is 0"
                    assert stride == 1, "stride should be 1 when kernel size is 0"
                    convs.append(nn.Identity())  # Identity for kernel size 0
                    self.kernel_sizes.remove(0)
                # Add Conv2d layers for other kernel sizes
                convs.extend([
                    nn.Conv2d(in_channels, out_channels, k, stride=stride, dilation=dilation,
                              padding=(k - 1) // 2 * dilation, bias=True, groups=groups)
                    for k in self.kernel_sizes
                ])
            self.convs = nn.ModuleList(convs)

    def forward(self, x):
        if self.deploy:
            return self.fused_conv(x)
        else:
            conv_outputs = []
            for conv in self.convs:
                conv_outputs.append(conv(x))
            return sum(conv_outputs)

    def _convert_weight_and_bias(self):
        if hasattr(self.convs[-1], 'switch_to_deploy'):
            self.convs[-1].switch_to_deploy()
            weight = self.convs[-1].fused_conv.weight
            bias = self.convs[-1].fused_conv.bias
        else:
            weight = self.convs[-1].weight
            bias = self.convs[-1].bias

        for conv in self.convs[:-1]:
            if isinstance(conv, nn.BatchNorm2d):
                std = (conv.running_var + conv.eps).sqrt()
                t = (conv.weight / std).reshape(-1, 1, 1, 1)
                pad = (self.max_kernel_size - 1) // 2
                input_dim = self.in_channels // self.groups
                identity_weight = F.pad(
                    torch.zeros(self.convs[-1].fused_conv.weight.shape[0], self.convs[-1].fused_conv.weight.shape[1], 1, 1).to(weight.device),
                    [pad, pad, pad, pad]
                )
                for i in range(self.in_channels):
                    identity_weight[i, i % input_dim, self.max_kernel_size//2, self.max_kernel_size//2] = 1
                weight = weight + identity_weight* t
                bias = bias + conv.bias - conv.running_mean * conv.weight / std
            elif isinstance(conv, nn.Identity):
                pad = (self.max_kernel_size - 1) // 2
                input_dim = self.in_channels // self.groups
                identity_weight = F.pad(
                    torch.zeros(self.convs[-1].weight.shape[0], self.convs[-1].weight.shape[1], 1, 1).to(weight.device),
                    [pad, pad, pad, pad]
                ) 
                for i in range(self.in_channels):
                    identity_weight[i, i % input_dim, self.max_kernel_size//2, self.max_kernel_size//2] = 1
                weight = weight + identity_weight
            elif isinstance(conv, ConvBn):
                conv.switch_to_deploy()
                conv_weight = conv.fused_conv.weight
                pad = (self.max_kernel_size - conv.fused_conv.weight.shape[-1]) // 2
                conv_weight = F.pad(conv_weight, [pad, pad, pad, pad])
                conv_bias = conv.fused_conv.bias
                weight = weight + conv_weight
                bias = bias + conv_bias
            elif isinstance(conv, nn.Conv2d):
                conv_weight = conv.weight
                pad = (self.max_kernel_size - conv.weight.shape[-1]) // 2
                conv_weight = F.pad(conv_weight, [pad, pad, pad, pad])
                conv_bias = conv.bias
                weight = weight + conv_weight
                bias = bias + conv_bias
            else:
                raise TypeError(f"Unsupported layer type: {type(conv)}")

        self.weight = weight.detach()
        self.bias = bias.detach()
    @torch.no_grad()
    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        self._convert_weight_and_bias()
        self.fused_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=(self.max_kernel_size, self.max_kernel_size),
            stride=self.stride,
            padding=self.max_padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True
        )
        self.fused_conv.weight.data = self.weight
        self.fused_conv.bias.data = self.bias
        del self.convs
        self.deploy = True

In [19]:
import torch

# Create the RepConv2d model with kernel sizes that include 0
model = RepConv2d(in_channels=64, out_channels=64, kernel_sizes=(0,1, 3,5), stride=1, dilation=1, groups=8, use_bn=False, deploy=False).eval()
print(model)
#initialize_conv_weights_to_one(model)
# Create a dummy input tensor (e.g., batch size of 1, 3 channels, 224x224 image)
x = torch.randn(1, 64, 224, 224)

# Run the model before deployment (non-deployed state)

output_before_deploy = model(x)

# Switch the model to deploy mode
model.switch_to_deploy()
print(model)
# Run the model after deployment (fused conv state)

output_after_deploy = model(x)

# Check if the outputs are 
# the same
if torch.allclose(output_before_deploy, output_after_deploy, atol=1e-5):
    print("The outputs are the same before and after deployment.")
else:
    print("The outputs differ before and after deployment.")

RepConv2d(
  (convs): ModuleList(
    (0): Identity()
    (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), groups=8)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8)
    (3): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=8)
  )
)
RepConv2d(
  (fused_conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=8)
)
The outputs are the same before and after deployment.


In [22]:
import torch

def test_deployment_consistency_sequential():
    # Define a larger set of parameter combinations for testing
    parameter_combinations = [
        (3, 3, (0, 1, 3), 1, 1, 1, True), 
        (3, 3, (1, 3, 5), 1, 1, 1, False), 
        (16, 16, (1, 3), 1, 1, 4, True),
        (16, 16, (3, 7, 11), 1, 2, 2, False), 
        (32, 32, (1, 3, 5, 7), 1, 1, 4, True),
        (64, 64, (1, 3, 5), 1, 1, 1, False), 
        (128, 128, (0, 1, 3, 5), 1, 2, 16, True),
        (256, 256, (0,3, 7), 1, 1, 32, False),
        (64, 64, (1, 5), 1, 1, 4, True),
        (128, 64, (3, 5, 7), 1, 2, 2, False)
    ]

    results = []
    
    for i, params in enumerate(parameter_combinations):
        in_channels, out_channels, kernel_sizes, stride, dilation, groups, use_bn = params
        
        # Create model instance for each parameter combination
        model = RepConv2d(in_channels=in_channels, out_channels=out_channels, 
                          kernel_sizes=kernel_sizes, stride=stride, dilation=dilation, 
                          groups=groups, use_bn=use_bn, deploy=False).eval()

        # Create a dummy input tensor with a smaller size to reduce memory usage
        x = torch.randn(1, in_channels, 64, 64)  # Reduced input size

        # Run the model before deployment (non-deployed state)
        with torch.no_grad():
            output_before_deploy = model(x)

        # Switch the model to deploy mode
        model.switch_to_deploy()

        # Run the model after deployment (fused conv state)
        with torch.no_grad():
            output_after_deploy = model(x)

        # Check if the outputs are the same (ignoring small numerical differences)
        outputs_are_the_same = torch.allclose(output_before_deploy, output_after_deploy, atol=1e-5)
        results.append((i, outputs_are_the_same))
        
        # Clear GPU memory (if you're using CUDA) after each test
        torch.cuda.empty_cache()

    return results

# Run the test for all combinations
test_results = test_deployment_consistency_sequential()

# Output the results
test_results


[(0, True),
 (1, True),
 (2, True),
 (3, True),
 (4, True),
 (5, True),
 (6, True),
 (7, True),
 (8, True),
 (9, True)]

In [201]:
output_after_deploy

tensor([[[[ 3.2503, -2.3336, -0.3168,  ...,  1.6845,  1.8223,  0.5355],
          [-0.7155, -1.3402, -1.5542,  ...,  1.8213,  0.3087,  0.4821],
          [ 1.1372,  2.1317,  0.5604,  ..., -1.3089,  1.2011,  1.0502],
          ...,
          [-1.0010,  0.4254, -1.7083,  ...,  0.6790, -1.0151,  2.9918],
          [-0.9721,  0.8507, -0.8526,  ...,  2.0636, -2.1848,  0.4944],
          [-1.6197,  1.3920,  2.0821,  ...,  3.0186,  1.5093, -0.8927]],

         [[ 4.2952, -2.6332, -2.8288,  ..., -1.3617,  0.1294, -0.8138],
          [-1.1744, -1.9993, -3.4749,  ...,  0.2073,  1.2643, -0.7046],
          [-3.1230,  1.4844, -0.8444,  ...,  0.4686, -1.0576, -2.8354],
          ...,
          [ 1.1390, -1.3227, -1.9438,  ..., -0.1970, -0.8131, -0.0600],
          [-1.8434,  0.0814, -0.4531,  ..., -1.8145, -1.4689,  0.6596],
          [-1.7036,  0.1385,  1.2308,  ..., -4.2808,  1.7674,  1.1973]],

         [[-0.9048, -3.2768,  3.2289,  ..., -2.0433,  0.2257, -0.5137],
          [-3.4765, -1.6263, -

In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBn, self).__init__()
        self.deploy = deploy
        self.padding = (kernel_size - 1) // 2 * dilation
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride,
                                        padding=self.padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=self.padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.bn = nn.BatchNorm2d(num_features=out_channels)

    def _fuse_bn_tensor(self, conv, bn):
        # Normalization and scaling
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        
        # Apply scaling and shifting to bias
        deploy_weight = conv.weight * t
        deploy_bias = bn.bias - bn.running_mean * bn.weight / std
        
        return deploy_weight, deploy_bias

    @torch.no_grad()
    def switch_to_deploy(self):
        if self.bn.training:
            raise RuntimeError("BatchNorm should be in evaluation mode (eval) before deployment.")
        
        deploy_k, deploy_b = self._fuse_bn_tensor(self.conv, self.bn)
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                                    kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                                    padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True,
                                    padding_mode=self.conv.padding_mode)
        self.__delattr__('conv')
        self.__delattr__('bn')
        
        # Apply the fused weights and biases
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.conv(input)
            square_outputs = self.bn(square_outputs)
            return square_outputs

        
def test_deployment_consistency_sequential():
    # Define a larger set of parameter combinations for testing
    parameter_combinations = [
        (256, 256,  3, 1, 1, 1, True), 
        (3, 3,  5, 1, 1, 1, False), 
        (16, 16, 3, 1, 1, 4, True),
        (16, 16,  11, 1, 2, 2, False), 
        (32, 32,  7, 1, 1, 4, True),
        (64, 64,  5, 1, 1, 1, False), 
        (128, 128,  5, 1, 2, 16, True),
        (256, 256,  7, 1, 1, 32, False),
        (64, 64,  5, 1, 1, 4, True),
        (128, 64,  7, 1, 2, 2, False)
    ]

    results = []
    
    for i, params in enumerate(parameter_combinations):
        in_channels, out_channels, kernel_sizes, stride, dilation, groups, use_bn = params
        
        # Create model instance for each parameter combination
        model = ConvBn(in_channels=in_channels, out_channels=out_channels, 
                          kernel_size=kernel_sizes, stride=stride, dilation=dilation, 
                          groups=groups, deploy=False).eval()

        # Create a dummy input tensor with a smaller size to reduce memory usage
        x = torch.randn(1, in_channels, 64, 64)  # Reduced input size

        # Run the model before deployment (non-deployed state)
        with torch.no_grad():
            output_before_deploy = model(x)

        # Switch the model to deploy mode
        model.switch_to_deploy()

        # Run the model after deployment (fused conv state)
        with torch.no_grad():
            output_after_deploy = model(x)

        # Check if the outputs are the same (ignoring small numerical differences)
        outputs_are_the_same = torch.allclose(output_before_deploy, output_after_deploy, atol=1e-6)
        results.append((i, outputs_are_the_same))
        
        # Clear GPU memory (if you're using CUDA) after each test
        torch.cuda.empty_cache()

    return results

# Run the test for all combinations
test_results = test_deployment_consistency_sequential()

# Output the results
test_results

[(0, False),
 (1, True),
 (2, True),
 (3, False),
 (4, False),
 (5, False),
 (6, True),
 (7, True),
 (8, False),
 (9, False)]

In [71]:

import torch 

# 初始化模型
model = ConvBn(256,256,3).eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 准备输入数据
input_tensor_1 = torch.randn(1, 256, 128, 128).to(device)
input_tensor_2 = torch.randn(1, 256, 128, 128).to(device)

# 打印模型的输出（切换前）
output_before = model(input_tensor_1)


# 计算初始模型参数
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total Parameters:", total_params)
print("Trainable Parameters:", trainable_params)

# 使用 thop 计算 FLOPs 和参数
from thop import profile
from thop import clever_format
flops, params = profile(model, inputs=[input_tensor_1])
macs, params = clever_format([flops, total_params], "%.3f")
gflops = flops / 1e9
print(f"模型的 FLOPs: {gflops} GigaFLOPs")
print(f"模型的 params: {params} params")

# 切换到部署模式
model = repvgg_model_convert(model)

# 打印模型的输出（切换后）
output_after = model(input_tensor_1)


# 检查输出是否相同
output_difference = torch.allclose(output_before[0], output_after[0],atol=1e-5)
print("Do the outputs match?", output_difference)

# 重新计算模型参数
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total Parameters after switch:", total_params)
print("Trainable Parameters after switch:", trainable_params)

# 重新计算 FLOPs 和参数
flops, params = profile(model, inputs=[input_tensor_1])
macs, params = clever_format([flops, total_params], "%.3f")
gflops = flops / 1e9
print(f"模型的 FLOPs after switch: {gflops} GigaFLOPs")
print(f"模型的 params after switch: {params} params")



Total Parameters: 590336
Trainable Parameters: 590336
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
模型的 FLOPs: 9.680453632 GigaFLOPs
模型的 params: 590.336K params
Do the outputs match? True
Total Parameters after switch: 590080
Trainable Parameters after switch: 590080
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
模型的 FLOPs after switch: 9.663676416 GigaFLOPs
模型的 params after switch: 590.080K params


In [67]:
class RepVGGBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        if use_se:
            #   Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models uses SE after nonlinearity.
            self.se = nn.Identity()
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)

        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
            print('RepVGG Block, identity = ', self.rbr_identity)


    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))


    #   Optional. This may improve the accuracy and facilitates quantization in some cases.
    #   1.  Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
    #   2.  Use like this.
    #       loss = criterion(....)
    #       for every RepVGGBlock blk:
    #           loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2()
    #       optimizer.zero_grad()
    #       loss.backward()
    def get_custom_L2(self):
        K3 = self.rbr_dense.conv.weight
        K1 = self.rbr_1x1.conv.weight
        t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()

        l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()      # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1                           # The equivalent resultant central point of 3x3 kernel.
        l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()        # Normalize for an L2 coefficient comparable to regular L2.
        return l2_loss_eq_kernel + l2_loss_circle



#   This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
#   You can get the equivalent kernel and bias at any time and do whatever you want,
    #   for example, apply some penalties or constraints during training, just like you do to the other models.
#   May be useful for quantization or pruning.
    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1,1,1,1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True

def test_deployment_consistency_sequential():
    # Define a larger set of parameter combinations for testing
    parameter_combinations = [
        (256, 256,  3, 1, 1, 1, True), 
    ]

    results = []
    
    for i, params in enumerate(parameter_combinations):
        in_channels, out_channels, kernel_sizes, stride, dilation, groups, use_bn = params
        
        # Create model instance for each parameter combination
        model = RepVGGBlock(in_channels=in_channels, out_channels=out_channels, 
                          kernel_size=kernel_sizes, stride=stride, dilation=dilation, 
                          groups=groups, deploy=False,padding=1).eval()

        # Create a dummy input tensor with a smaller size to reduce memory usage
        x = torch.randn(1, in_channels, 64, 64)  # Reduced input size

        # Run the model before deployment (non-deployed state)
        with torch.no_grad():
            output_before_deploy = model(x)

        # Switch the model to deploy mode
        model.switch_to_deploy()

        # Run the model after deployment (fused conv state)
        with torch.no_grad():
            output_after_deploy = model(x)

        # Check if the outputs are the same (ignoring small numerical differences)
        outputs_are_the_same = torch.allclose(output_before_deploy, output_after_deploy, atol=1e-5)
        results.append((i, outputs_are_the_same))
        
        # Clear GPU memory (if you're using CUDA) after each test
        torch.cuda.empty_cache()

    return results

# Run the test for all combinations
test_results = test_deployment_consistency_sequential()

# Output the results
test_results

RepVGG Block, identity =  BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


[(0, True)]

In [52]:
# --------------------------------------------------------
# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
# Github source: https://github.com/DingXiaoH/RepVGG
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch.nn as nn
import numpy as np
import torch
import copy
import torch.utils.checkpoint as checkpoint

def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                                  kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result

class RepVGGBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        if use_se:
            #   Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models uses SE after nonlinearity.
            self.se = nn.Identity()
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)

        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
            print('RepVGG Block, identity = ', self.rbr_identity)


    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))


    #   Optional. This may improve the accuracy and facilitates quantization in some cases.
    #   1.  Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
    #   2.  Use like this.
    #       loss = criterion(....)
    #       for every RepVGGBlock blk:
    #           loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2()
    #       optimizer.zero_grad()
    #       loss.backward()
    def get_custom_L2(self):
        K3 = self.rbr_dense.conv.weight
        K1 = self.rbr_1x1.conv.weight
        t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()

        l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()      # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1                           # The equivalent resultant central point of 3x3 kernel.
        l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()        # Normalize for an L2 coefficient comparable to regular L2.
        return l2_loss_eq_kernel + l2_loss_circle



#   This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
#   You can get the equivalent kernel and bias at any time and do whatever you want,
    #   for example, apply some penalties or constraints during training, just like you do to the other models.
#   May be useful for quantization or pruning.
    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1,1,1,1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True



class RepVGG(nn.Module):

    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False):
        super(RepVGG, self).__init__()
        assert len(width_multiplier) == 4
        self.deploy = deploy
        self.override_groups_map = override_groups_map or dict()
        assert 0 not in self.override_groups_map
        self.use_se = use_se
        self.use_checkpoint = use_checkpoint

        self.in_planes = min(64, int(64 * width_multiplier[0]))
        self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se)
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2)
        self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
        self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2)
        self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)

    def _make_stage(self, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        blocks = []
        for stride in strides:
            cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
                                      stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.ModuleList(blocks)

    def forward(self, x):
        out = self.stage0(x)
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            for block in stage:
                if self.use_checkpoint:
                    out = checkpoint.checkpoint(block, out)
                else:
                    out = block(out)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_map = {l: 2 for l in optional_groupwise_layers}
g4_map = {l: 4 for l in optional_groupwise_layers}

def create_RepVGG_A0(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_A1(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_A2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B0(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B1(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B1g2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B1g4(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B2g2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B2g4(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B3(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B3g2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B3g4(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_D2se(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[8, 14, 24, 1], num_classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True, use_checkpoint=use_checkpoint)


func_dict = {
'RepVGG-A0': create_RepVGG_A0,
'RepVGG-A1': create_RepVGG_A1,
'RepVGG-A2': create_RepVGG_A2,
'RepVGG-B0': create_RepVGG_B0,
'RepVGG-B1': create_RepVGG_B1,
'RepVGG-B1g2': create_RepVGG_B1g2,
'RepVGG-B1g4': create_RepVGG_B1g4,
'RepVGG-B2': create_RepVGG_B2,
'RepVGG-B2g2': create_RepVGG_B2g2,
'RepVGG-B2g4': create_RepVGG_B2g4,
'RepVGG-B3': create_RepVGG_B3,
'RepVGG-B3g2': create_RepVGG_B3g2,
'RepVGG-B3g4': create_RepVGG_B3g4,
'RepVGG-D2se': create_RepVGG_D2se,      #   Updated at April 25, 2021. This is not reported in the CVPR paper.
}
def get_RepVGG_func_by_name(name):
    return func_dict[name]



#   Use this for converting a RepVGG model or a bigger model with RepVGG as its component
#   Use like this
#   model = create_RepVGG_A0(deploy=False)
#   train model or load weights
#   repvgg_model_convert(model, save_path='repvgg_deploy.pth')
#   If you want to preserve the original model, call with do_copy=True

#   ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
#   train_backbone = create_RepVGG_B2(deploy=False)
#   train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
#   train_pspnet = build_pspnet(backbone=train_backbone)
#   segmentation_train(train_pspnet)
#   deploy_pspnet = repvgg_model_convert(train_pspnet)
#   segmentation_test(deploy_pspnet)
#   =====================   example_pspnet.py shows an example

def repvgg_model_convert(model:torch.nn.Module, save_path=None, do_copy=True):
    if do_copy:
        model = copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, 'switch_to_deploy'):
            module.switch_to_deploy()
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return model

In [57]:
model = create_RepVGG_B1g2().eval()

#initialize_conv_weights_to_one(model)
# Create a dummy input tensor (e.g., batch size of 1, 3 channels, 224x224 image)
x = torch.randn(1, 3, 224, 224)

# 计算初始模型参数
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total Parameters:", total_params)
print("Trainable Parameters:", trainable_params)

# 使用 thop 计算 FLOPs 和参数
from thop import profile
from thop import clever_format
flops, params = profile(model, inputs=[x])
macs, params = clever_format([flops, total_params], "%.3f")
gflops = flops / 1e9
print(f"模型的 FLOPs: {gflops} GigaFLOPs")
print(f"模型的 params: {params} params")


output_before_deploy = model(x)

# Switch the model to deploy mode
model= repvgg_model_convert(model, save_path=None, do_copy=True)

# Run the model after deployment (fused conv state)

output_after_deploy = model(x)
# 计算初始模型参数
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total Parameters:", total_params)
print("Trainable Parameters:", trainable_params)

# 使用 thop 计算 FLOPs 和参数
from thop import profile
from thop import clever_format
flops, params = profile(model, inputs=[x])
macs, params = clever_format([flops, total_params], "%.3f")
gflops = flops / 1e9
print(f"模型的 FLOPs: {gflops} GigaFLOPs")
print(f"模型的 params: {params} params")
# Check if the outputs are 
# the same
if torch.allclose(output_before_deploy, output_after_deploy, atol=1e-5):
    print("The outputs are the same before and after deployment.")
else:
    print("The outputs differ before and after deployment.")

RepVGG Block, identity =  None
RepVGG Block, identity =  None
RepVGG Block, identity =  BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  None
RepVGG Block, identity =  BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
RepVGG Block, identity =  None
RepVGG Block, identi

In [55]:
!pip install thop

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238
