In [28]:
import torch
import torch.nn as nn

### Example Module

In [29]:
class ConvBNAct(nn.Module):
    """Convolution -> batch normalisation -> ReLU, packaged as one module.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size, forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # The convolution is created without a bias term; it is followed
        # directly by a BatchNorm2d layer.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        """Apply conv, then batch norm, then ReLU to ``x`` and return the result."""
        return self.act(self.bn(self.conv(x)))

In [30]:
class SampleModel(nn.Module):
    """Toy network made of two ConvBNAct stages (3 -> 16 -> 32 channels)."""

    def __init__(self):
        super().__init__()
        # Both stages use a kernel size of 3; the attribute names become the
        # dotted prefixes ("block1.*", "block2.*") of the submodules.
        self.block1 = ConvBNAct(in_channels=3, out_channels=16, kernel_size=3)
        self.block2 = ConvBNAct(in_channels=16, out_channels=32, kernel_size=3)

    def forward(self, x):
        """Feed ``x`` through block1 and then block2, returning block2's output."""
        return self.block2(self.block1(x))

In [31]:
# Instantiate the two-block example network defined above.
model = SampleModel()

In [32]:
# Display the module tree (repr) of the model before any fusion is applied.
model

SampleModel(
  (block1): ConvBNAct(
    (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU()
  )
  (block2): ConvBNAct(
    (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU()
  )
)

In [13]:
# Fuse each Conv2d + BatchNorm2d + ReLU triple into a single module

In [33]:
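# NOTE: model.eval() is not called here; the model stays in training mode
# for the quantization-aware-training (QAT) fusion performed below.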

In [36]:
# fuse_modules_qat defaults to inplace=False: it fuses a deep copy of the model
# and returns that copy, leaving `model` itself unchanged. Bind the result to a
# name so the fused network is not lost once the cell output is discarded.
fused_model = torch.ao.quantization.fuse_modules_qat(
    model,
    [
        # Each inner list names one conv -> bn -> act chain by dotted path.
        ["block1.conv", "block1.bn", "block1.act"],
        ["block2.conv", "block2.bn", "block2.act"],
    ],
)
# Show the fused module tree as the cell output.
fused_model

SampleModel(
  (block1): ConvBNAct(
    (conv): ConvBnReLU2d(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (bn): Identity()
    (act): Identity()
  )
  (block2): ConvBNAct(
    (conv): ConvBnReLU2d(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (bn): Identity()
    (act): Identity()
  )
)

In [21]:
# Inspect the class of the first block; the recorded output shows it is the
# user-defined ConvBNAct wrapper.
type(model.block1)

__main__.ConvBNAct