In [3]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [4]:
class hswish(nn.Module):
    """Hard-swish activation, computed as ``x * relu6(x + 3) / 6``.

    Piecewise: 0 for x <= -3, x for x >= 3, and the quadratic
    ``x * (x + 3) / 6`` in between. Has no learnable parameters.
    """

    def forward(self, x):
        # ``x + 3`` is a fresh temporary, so clipping it in place is safe and
        # never modifies the caller's ``x``.
        gate = F.relu6(x + 3, inplace=True)
        return x * gate / 6

In [5]:
class ConvBNLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> optional activation.

    Args:
        in_channels: Number of input channels for the convolution.
        out_channels: Number of output channels; also the BatchNorm width.
        kernel_size: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.Conv2d``.
        groups: Forwarded to ``nn.Conv2d`` (default 1).
        if_act: If True, apply an activation after batch norm.
        act: ``"hardswish"`` selects ``hswish``; ``"relu"`` or ``None``
            select ``nn.ReLU``. Only consulted when ``if_act`` is True.
        bias: Whether the convolution has a bias term. Defaults to True to
            keep existing parameter layouts; the bias is redundant ahead of
            BatchNorm (which subtracts a per-channel mean), so ``bias=False``
            saves parameters in new models.

    Raises:
        ValueError: If ``if_act`` is True and ``act`` is not one of the
            supported names (previously such values silently became ReLU).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 if_act=True,
                 act=None,
                 bias=True):
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias)

        self.bn = nn.BatchNorm2d(out_channels)

        # ``act_fun`` is only created when an activation is requested;
        # ``forward`` checks ``if_act`` before using it.
        if self.if_act:
            if self.act == "hardswish":
                self.act_fun = hswish()
            elif self.act is None or self.act == "relu":
                self.act_fun = nn.ReLU()
            else:
                # Fail loudly instead of silently substituting ReLU for a
                # misspelled or unsupported activation name.
                raise ValueError(
                    "Unsupported activation {!r}; expected 'hardswish', "
                    "'relu' or None".format(self.act))

    def forward(self, x):
        """Apply conv, batch norm and (if enabled) the activation to ``x``."""
        x = self.conv(x)
        x = self.bn(x)
        if self.if_act:
            x = self.act_fun(x)
        return x

In [8]:
# Build a 3->3 channel, 3x3, stride-1 block with hard-swish and show its repr.
ConvBNLayer(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1,
            act="hardswish")

ConvBNLayer(
  (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fun): hswish()
)