In [1]:
# import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F

In [68]:
class LSTMFilter(nn.Module):
    """Bidirectional LSTM whose per-timestep outputs are flattened into one row per sample.

    Args:
        input_dim: feature size of each timestep fed to the LSTM.
        target_dim: LSTM hidden size per direction; every timestep therefore
            yields ``2 * target_dim`` features (e.g. 16 * 2 = 32).

    Forward:
        x: (batch, seq, input_dim), because the LSTM is built with ``batch_first=True``.
        returns: (batch, seq * 2 * target_dim). All timesteps are concatenated, so
            the output width depends on the sequence length.

    NOTE(review): for unbatched (seq, input_dim) input, dim 0 is the sequence
    axis, so the result stays (seq, 2 * target_dim) rather than being flattened.
    """

    def __init__(self, input_dim: int = 768, target_dim: int = 16) -> None:
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=target_dim,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x: torch.Tensor):
        per_step, _final_state = self.lstm(x)
        leading = per_step.size(0)
        # Merge the time and feature axes into a single vector per leading index.
        return per_step.reshape(leading, -1)

test_x = torch.randn(5, 10, 768)  # (batch, seq, features); LSTMFilter is batch_first
lstm = LSTMFilter(768, 16)
output = lstm(test_x)
# Each of the 10 timesteps contributes 2 * 16 bidirectional features: 10 * 32 = 320.
assert output.shape == (test_x.size(0), test_x.size(1) * 2 * 16), output.shape
output.shape

torch.Size([5, 320])

In [65]:
class CNNLayer(nn.Module):
    """Conv1d -> stride-1 MaxPool1d -> BatchNorm1d building block.

    Args:
        in_channels: channels of the (batch, in_channels, length) input.
        out_channels: channels produced by the convolution (and normalized by BatchNorm).
        kernel_size: shared by the convolution and the pooling window.
        stride: convolution stride. The pooling always uses stride 1.

    Both the convolution and the pooling pad by ``kernel_size // 2``.
    - The convolution's output length is ``(length + 2 * (kernel_size // 2) - kernel_size) // stride + 1``.
    - The pooling keeps that length for odd ``kernel_size`` and grows it by 1 for even ``kernel_size``.
    """

    def __init__(self, in_channels: int = 8, out_channels: int = 16, kernel_size: int = 7, stride: int = 2) -> None:
        super().__init__()
        half_kernel = kernel_size // 2
        self.cnn = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel,
        )
        # Stride 1 means the pooling only smooths; all downsampling comes from the conv stride.
        self.pool = nn.MaxPool1d(kernel_size=kernel_size, stride=1, padding=half_kernel)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor):
        convolved = self.cnn(x)
        pooled = self.pool(convolved)
        return self.bn(pooled)

class CNNFilter(nn.Module):
    """Two stacked CNNLayer stages that reduce the input to one channel, then drop that channel axis.

    Args:
        in_channels: size of dim 1 of the input, which Conv1d treats as channels.
        kernel_size: kernel size used by both stages.

    Stages:
        - ``cnn1``: ``in_channels`` -> 1 channel, convolution stride 3 (downsamples the length).
        - ``cnn2``: 1 -> 1 channel, convolution stride 1.

    Returns a tensor of shape (batch, length') once the singleton channel axis is squeezed away.

    NOTE(review): the demo feeds (batch, seq, features) data with
    ``in_channels == seq``, so the sequence axis is convolved as channels and the
    feature axis as length. Confirm this is the intended layout.
    """

    def __init__(self, in_channels: int = 8, kernel_size: int = 7) -> None:
        super().__init__()
        # Construction order (cnn1 before cnn2) fixes parameter init order under a seed.
        self.cnn1 = CNNLayer(in_channels, 1, kernel_size, stride=3)
        self.cnn2 = CNNLayer(1, 1, kernel_size, stride=1)

    def forward(self, x: torch.Tensor):
        single_channel = self.cnn2(self.cnn1(x))
        # (batch, 1, length') -> (batch, length')
        return single_channel.squeeze(1)
    
test_x = torch.randn(5, 10, 768)
cnn = CNNFilter(10, 7)  # in_channels must equal test_x.size(1): Conv1d reads dim 1 as channels
output = cnn(test_x)
# First stage (stride 3, padding 3): (768 + 2*3 - 7) // 3 + 1 = 256 positions.
# Second stage (stride 1, odd kernel) keeps that length.
assert output.shape == (test_x.size(0), (test_x.size(2) + 2 * (7 // 2) - 7) // 3 + 1), output.shape
output.shape

torch.Size([5, 256])

In [63]:
class LSTMEncoder(nn.Module):
    """Encode a sequence as the concatenated final hidden states of a bidirectional LSTM.

    The forward direction's final hidden state (after reading left-to-right) is
    concatenated with the backward direction's final hidden state (after reading
    right-to-left). This gives ``2 * target_dim`` features (e.g. 128 * 2 = 256).

    Args:
        input_dim: feature size of each timestep.
        target_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers (default 1, matching the
            original behaviour). Only the top layer's final states are returned.

    Forward:
        x: (batch, seq, input_dim), or unbatched (seq, input_dim).
        returns: (batch, 2 * target_dim), or (2 * target_dim,) for unbatched input.
    """

    def __init__(self, input_dim: int = 768, target_dim: int = 128, num_layers: int = 1) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_dim, target_dim, num_layers=num_layers, batch_first=True, bidirectional=True)

    def forward(self, x: torch.Tensor):
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers * 2, [batch,] target_dim), ordered layer by layer with
        # the forward direction before the backward one, so the last two entries
        # belong to the top layer. Concatenating on the last axis (not dim=1)
        # keeps unbatched input working, where each h_n entry is 1-D.
        return torch.cat([h_n[-2], h_n[-1]], dim=-1)

test_x = torch.randn(5, 10, 768)  # (batch, seq, features)
lstm = LSTMEncoder(768, 128)
h = lstm(test_x)
# One vector per sample: forward + backward final hidden states, 2 * 128 = 256.
assert h.shape == (test_x.size(0), 2 * 128), h.shape
h.shape

torch.Size([5, 256])