In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Helper functions to compute the output length of a conv1d layer, a transposed conv1d layer, and a Tensor.unfold window.
from functools import reduce
def get_conv1d_len_func(model):
    """Return a function mapping an input length to the output length of a conv1d layer.

    ``model`` is read for ``padding``, ``dilation``, ``kernel_size`` and ``stride``
    (as on ``nn.Conv1d``). The returned function uses the PyTorch Conv1d formula
    ``floor((L + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``.
    It accepts a Python int or a tensor of lengths and returns a floating-point
    tensor (the division is true division, then floored).
    """
    def len_func(i):
        # torch.floor only accepts tensors, so promote plain ints first.
        i = torch.as_tensor(i)
        return torch.floor((i + 2 * model.padding[0] - model.dilation[0] * (model.kernel_size[0] - 1) - 1) / model.stride[0] + 1)
    return len_func
        
def get_transpose_conv1d_len_func(model):
    """Build a length function for a transposed 1-D convolution.

    The returned callable reads ``stride``, ``padding``, ``dilation``, ``kernel_size``
    and ``output_padding`` from ``model`` each time it is called. It applies
    ``(L - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1``.
    """
    def transpose_len(length):
        stretched = (length - 1) * model.stride[0]
        kernel_span = model.dilation[0] * (model.kernel_size[0] - 1)
        return stretched - 2 * model.padding[0] + kernel_span + model.output_padding[0] + 1
    return transpose_len

def get_unfold_len(length, window_ks, window_stride):
    """Count the windows of size ``window_ks`` taken every ``window_stride`` steps.

    This matches the number of windows ``Tensor.unfold`` produces along a
    dimension of size ``length``. It uses floor division.
    """
    num_strides = (length - window_ks) // window_stride
    return num_strides + 1

# A simple conv + batch norm + relu model. 
# Note that the padding is set to 0.
class Block(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU with zero padding.

    Because padding is 0, the sequence length generally shrinks. Use
    ``get_output_lengths`` to compute the resulting length.

    Args:
        num_ins: number of input channels.
        num_outs: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        dilation: convolution dilation.
    """
    def __init__(self, num_ins, num_outs, kernel_size=3, stride=1, dilation=1):
        super().__init__()

        # Note that the padding is 0 here.
        self.conv = nn.Conv1d(num_ins, num_outs, kernel_size, padding=0, stride=stride, dilation=dilation)
        self.bn = nn.BatchNorm1d(num_outs)

        # Length functions of every length-changing layer, applied in order.
        # This is the same list-based convention that ResBlock.len_funcs uses.
        self.len_funcs = [get_conv1d_len_func(self.conv)]

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))

    def get_output_lengths(self, length):
        """Return the output length(s) for input length(s) ``length`` as an int tensor.

        ``length`` may be a Python int or a tensor of lengths.
        """
        return reduce(lambda x, func: func(x), self.len_funcs, torch.as_tensor(length)).int()

# A more complex CNN block with a residual path, a.k.a. ResBlock.
# You can chain several Resblock or Block together. 
class ResBlock(nn.Module):
    '''
    Residual 1-D convolution block.

    The main path is conv -> bn -> relu -> conv -> bn. The shortcut is either the
    identity or a 1x1 conv + bn. The two are summed and passed through a final ReLU.

    Gaddy and Klein, 2021, https://arxiv.org/pdf/2106.01933.pdf
    Original code:
        https://github.com/dgaddy/silent_speech/blob/master/transformer.py

    Args:
        num_ins: number of input channels.
        num_outs: number of output channels.
        kernel_size: kernel size of both main-path convolutions.
        padding: padding of both main-path convolutions.
        stride: stride of the first main-path convolution and of the 1x1 shortcut conv.

    Raises:
        AssertionError: if, for some input length in [10, 100), the main path and the
            shortcut (1x1 conv or identity) produce different output lengths. Such a
            mismatch would make the residual addition in ``forward`` fail.
    '''
    def __init__(self, num_ins, num_outs, kernel_size = 3, padding = 1, stride=1):
        super().__init__()

        self.conv1 = nn.Conv1d(num_ins, num_outs, kernel_size, padding=padding, stride=stride)
        self.bn1 = nn.BatchNorm1d(num_outs)
        self.conv2 = nn.Conv1d(num_outs, num_outs, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm1d(num_outs)

        # This helps whenever in_channels != out_channels or stride != 1.
        # E.g. the first input Resblock layer, or you want to downsample the input.
        if stride != 1 or num_ins != num_outs:
            # With kernel size of 1, this is essentially a linear layer but with stride.
            self.residual_path = nn.Conv1d(num_ins, num_outs, 1, stride=stride)
            self.res_norm = nn.BatchNorm1d(num_outs)
        else:
            self.residual_path = None

        # len_funcs lists, in order, the length function of each main-path conv.
        # Composing them maps an input length to this block's output length.
        self.len_funcs = [get_conv1d_len_func(self.conv1), get_conv1d_len_func(self.conv2)]

        # Sanity check: the shortcut must produce the same length as the main path.
        # The identity shortcut keeps the input length, so that case is checked too.
        # Otherwise a length-changing config (e.g. kernel_size=5, padding=1) would only
        # fail later, inside forward().
        if self.residual_path is not None:
            residual_len_func = get_conv1d_len_func(self.residual_path)
        else:
            residual_len_func = lambda length: length
        for data_length in torch.arange(10, 100, 1):
            main_len = reduce(lambda x, func: func(x), self.len_funcs, data_length).int()
            res_len = residual_len_func(data_length)
            assert main_len == res_len, f"Residual path length {res_len} is not the same as the main path length {main_len}. Please check the configuration or reach out to me."

    def get_output_lengths(self, length):
        """Return this block's output length(s) for input length(s) as an int tensor.

        ``length`` may be a Python int or a tensor of lengths.
        """
        return reduce(lambda x, func: func(x), self.len_funcs, torch.as_tensor(length)).int()

    def forward(self, x):
        input_value = x

        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))

        if self.residual_path is not None:
            res = self.res_norm(self.residual_path(input_value))
        else:
            res = input_value

        return F.relu(x + res)
    
# An example of CNN that chains several ResBlocks together.
# Generally, this is more powerful than the simple Block model, but is also more prone to overfitting.
# Original conv blocks: no dropout, 2 hidden-hidden blocks.
class Original(nn.Module):
    """CNN feature extractor made of three chained ResBlocks.

    With the ResBlock defaults used here (kernel_size=3, padding=1, stride=1), the
    sequence length is preserved.

    Args:
        in_channels: number of channels in the input (C in (B, C, T)).
        hidden_dim: number of channels produced by every ResBlock.
    """
    def __init__(self, in_channels=8, hidden_dim = 1024):
        super().__init__()

        self.conv_blocks = nn.Sequential(
            ResBlock(in_channels, hidden_dim, kernel_size=3, stride=1),
            ResBlock(hidden_dim, hidden_dim, kernel_size=3, stride=1),
            ResBlock(hidden_dim, hidden_dim, kernel_size=3, stride=1),
        )

        # Each ResBlock already exposes its main-path length functions (conv1, conv2)
        # in order, so concatenating them gives the length function chain for the
        # whole model. No eval() is needed.
        self.len_funcs = [func for block in self.conv_blocks for func in block.len_funcs]

    def get_output_lengths(self, length):
        """Return the model's output length(s) for input length(s) ``length``.

        ``length`` may be a Python int or a tensor of lengths. The result is a
        floating-point tensor holding whole-number lengths.
        """
        return reduce(lambda x, func: func(x), self.len_funcs, torch.as_tensor(length))

    def forward(self, x):
        """
        Args:
            x: shape (batchsize, num_in_feats, seq_len).

        Return:
            out: shape (batchsize, num_out_feats, seq_len).
        """
        return self.conv_blocks(x)
    

In [3]:
# To stay consistent with the PyTorch API, the input to a CNN typically has the shape (B, C, T)
# (B: batch size, C: number of channels, T: sequence length).

# The number of input channels of the first convolution layer ("in_channels" here,
# or "num_ins" in the blocks above) must be EXACTLY the same as the number of channels in the input.

# Seed the global RNG so the model initialization and the dummy input are reproducible.
torch.manual_seed(0)

# Here's one example.
# This model expects the input to have 8 channels, and the output will have 1024 channels.
model = Original(in_channels=8, hidden_dim=1024)

# Let's say the input has a length of 100.
# Create some dummy input with 8 channels.
# The first dimension is the batch dimension.
data = torch.rand((1, 8, 100))

In [4]:
# Forward pass.
# Switch to eval mode first. torch.no_grad() alone does not stop BatchNorm layers
# in training mode from updating their running statistics.
# As you can see, the model doesn't change the shape of the input after forward pass.
model.eval()
with torch.no_grad():
    print(model(data).shape)

torch.Size([1, 1024, 100])


In [None]:
# Sometimes the output length of a CNN can be troublesome.
# You can use get_output_lengths to get the output length of the model.
# As you can see, for every input length tested, the output length equals the input length.
# We can safely use this 'Original' model as a feature extractor without worrying about changes in length.
for data_length in torch.arange(10, 100, 1):
    output_length = model.get_output_lengths(data_length)
    assert output_length == data_length, f"Error: {output_length}"
    print(f"Input length: {data_length}, Output length: {output_length}")

In [6]:
# With the CNN encoder above, we can easily build any classifiers.
# Change num_classes to the number of classes you have.
# Here's one example of CNN + BiLSTM + Linear classifier.
class CNN_LSTM_Classifier(nn.Module):
    """CNN feature extractor (Original) -> bidirectional LSTM -> per-timestep linear layer.

    Args:
        in_channels: number of channels in the input.
        hidden_dim: channel width of the CNN and hidden size of each LSTM direction.
        num_classes: number of output classes.
        num_lstm_layers: number of stacked LSTM layers.
    """
    def __init__(self, in_channels=8, hidden_dim = 1024, num_classes=10, num_lstm_layers=2):
        super().__init__()

        self.cnn = Original(in_channels=in_channels, hidden_dim=hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_lstm_layers, batch_first=True, bidirectional=True)
        # A bidirectional LSTM concatenates forward and backward hidden states,
        # so each timestep carries 2 * hidden_dim features.
        self.linear = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        """
        Args:
            x: shape (batchsize, num_in_feats, seq_len).

        Return:
            out: per-timestep logits of shape (batchsize, seq_len, num_classes).
                 For one prediction per sequence, pool over dim 1 (e.g. ``out.mean(dim=1)``).
        """
        # (B, C, T)
        x = self.cnn(x)

        # Before passing into the LSTM, change the shape to (B, T, C).
        x, _ = self.lstm(x.permute(0, 2, 1))
        # (B, T, 2 * hidden_dim) -> (B, T, num_classes)
        return self.linear(x)

In [None]:
# Fall back to the CPU when no CUDA device is available, so this cell runs anywhere.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = CNN_LSTM_Classifier(in_channels=8, hidden_dim=1024, num_classes=10, num_lstm_layers=2).to(device)
data = torch.rand((1, 8, 100), device=device)
# Eval mode, so BatchNorm does not update its running statistics during this shape check.
model.eval()
with torch.no_grad():
    print(model(data).shape)