In [None]:
import torch.nn as nn

from collections import OrderedDict

class ConvNormActivation(nn.Module):
    """Convolution followed by optional normalization and optional activation.

    The layers are wrapped in an ``nn.Sequential`` stored as ``self.block``,
    whose children are named ``<name>_conv``, ``<name>_bn`` and
    ``<name>_relu``. These names stay the same whichever norm / activation
    classes are chosen, so state_dict keys remain stable.

    Args:
        name: Prefix for the child module names.
        in_channels: Input channels, forwarded to ``conv_layer``.
        out_channels: Output channels, forwarded to ``conv_layer`` and used
            as the sole positional argument of ``norm_layer``.
        kernel_size: Forwarded to ``conv_layer``.
        conv_layer: Convolution class, called with ``in_channels``,
            ``out_channels``, ``kernel_size``, ``bias`` and ``conv_kwargs``.
        norm_layer: Normalization class called as ``norm_layer(out_channels)``,
            or ``None`` to omit normalization.
        activation_layer: Activation class, or ``None`` to omit it.
        bias: Forwarded to ``conv_layer``.
        inplace: Passed as ``inplace=`` to ``activation_layer``. Set to
            ``None`` for activation classes whose constructor has no
            ``inplace`` argument (e.g. ``nn.GELU``).
        **conv_kwargs: Extra keyword arguments (e.g. ``stride``, ``padding``)
            forwarded to ``conv_layer``.
    """

    def __init__(
        self,
        name,
        in_channels, out_channels, kernel_size,
        conv_layer=nn.Conv1d,
        norm_layer=nn.BatchNorm1d,
        activation_layer=nn.ReLU,
        bias=True,
        inplace=True,
        **conv_kwargs,
    ):
        super().__init__()

        # Input args (stored so make_block can read them)
        self.name = name

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        self.conv_layer = conv_layer
        self.norm_layer = norm_layer
        self.activation_layer = activation_layer

        self.bias = bias
        self.inplace = inplace
        self.conv_kwargs = conv_kwargs

        # Layers
        self.block = self.make_block()

    def make_block(self):
        """Build and return the ``nn.Sequential`` conv -> [norm] -> [activation]."""
        items = [
            (
                self.name + '_conv',
                self.conv_layer(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel_size,
                    bias=self.bias,
                    **self.conv_kwargs,
                ),
            )
        ]

        if self.norm_layer is not None:
            items.append(
                (self.name + '_bn', self.norm_layer(self.out_channels))
            )

        if self.activation_layer is not None:
            # Only pass ``inplace`` when requested: many activation classes
            # (nn.GELU, nn.Sigmoid, nn.Tanh, ...) do not accept it.
            act_kwargs = {} if self.inplace is None else {'inplace': self.inplace}
            # The '_relu' suffix is kept for every activation so existing
            # checkpoints keep loading with the same keys.
            items.append(
                (self.name + '_relu', self.activation_layer(**act_kwargs))
            )

        return nn.Sequential(OrderedDict(items))

    def forward(self, x):
        """Apply the conv/norm/activation block to ``x``."""
        return self.block(x)

In [None]:
import torch.nn as nn

from collections import OrderedDict

class ConvNormActivationPoolDropout(nn.Module):
    """1-D conv stage: Conv1d -> ReLU -> BatchNorm1d -> max-pool -> Dropout.

    Despite the class name, the activation is applied *before* the batch
    norm. That is the order the layers are registered in ``self.block``.

    Child modules are named ``<name>_conv``, ``<name>_relu``, ``<name>_bn``,
    ``<name>_mp`` and ``<name>_dropout``.

    Args:
        name: Prefix for the child module names.
        in_channels: Input channels of the Conv1d.
        out_channels: Output channels of the Conv1d and BatchNorm1d width.
        kernel_size: Conv1d kernel size.
        dropout: Dropout probability.
        gp: If truthy, pool with ``AdaptiveMaxPool1d(1)`` (global pooling);
            otherwise use ``MaxPool1d(2)``.
    """

    def __init__(self, name, in_channels, out_channels, kernel_size, dropout, gp=False):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )

        if gp:
            pooling = nn.AdaptiveMaxPool1d(1)
        else:
            pooling = nn.MaxPool1d(2)

        # (suffix, layer) pairs in execution order.
        stages = (
            ('_conv', conv),
            ('_relu', nn.ReLU(inplace=True)),
            ('_bn', nn.BatchNorm1d(out_channels)),
            ('_mp', pooling),
            ('_dropout', nn.Dropout(dropout)),
        )
        named = OrderedDict((name + suffix, layer) for suffix, layer in stages)
        self.block = nn.Sequential(named)

    def forward(self, x):
        """Run ``x`` through the stage."""
        return self.block(x)