In [11]:
import torch
import torch.nn as nn

class CNNModel(nn.Module):
    """Three-layer 3-D CNN that shrinks the time axis of a spatio-temporal input.

    Input layout is ``(batch, input_features, time, height, width)``.

    * ``conv1`` / ``conv2``: 3x3x3 convolutions with padding 1, so time, height
      and width are preserved. Each is followed by ``tanh`` and channel-wise
      3-D dropout.
    * ``conv3``: a time-only convolution (kernel ``(time_kernel, 1, 1)``,
      stride ``(time_stride, 1, 1)``, no padding) that maps to
      ``output_features`` channels. The output time length is
      ``(T_in - time_kernel) // time_stride + 1``, so with the defaults a
      16-step input becomes 3 steps. No activation follows ``conv3``.

    Args:
        input_features: Number of input channels.
        output_features: Number of output channels.
        target_time_steps: Stored as ``self.target_time_steps`` for
            inspection but NOT used to build any layer. The output time
            length is determined by ``time_kernel``/``time_stride`` and the
            input length (see ``output_time_steps``).
        hidden_channels: Output channels of ``conv1`` and ``conv2``
            (must have exactly two entries).
        dropout: Drop probability of the ``Dropout3d`` applied after
            ``conv1`` and ``conv2`` (active only in training mode).
        time_kernel: Temporal kernel size of ``conv3``.
        time_stride: Temporal stride of ``conv3``.
    """

    def __init__(
        self,
        input_features,
        output_features,
        target_time_steps,
        *,
        hidden_channels=(64, 128),
        dropout=0.3,
        time_kernel=3,
        time_stride=5,
    ):
        super().__init__()
        hidden1, hidden2 = hidden_channels
        # Plain int attributes: not part of state_dict, so checkpoints are unaffected.
        self.target_time_steps = target_time_steps
        self.time_kernel = time_kernel
        self.time_stride = time_stride

        # First layer: extract joint spatial/temporal features (shape-preserving).
        self.conv1 = nn.Conv3d(
            in_channels=input_features,
            out_channels=hidden1,
            kernel_size=(3, 3, 3),
            padding=(1, 1, 1),
        )

        # Second layer: deeper features, still shape-preserving.
        self.conv2 = nn.Conv3d(
            in_channels=hidden1,
            out_channels=hidden2,
            kernel_size=(3, 3, 3),
            padding=(1, 1, 1),
        )

        # Third layer: time-only reduction to `output_features` channels;
        # height and width are untouched (kernel/stride 1 on those axes).
        self.conv3 = nn.Conv3d(
            in_channels=hidden2,
            out_channels=output_features,
            kernel_size=(time_kernel, 1, 1),
            stride=(time_stride, 1, 1),
            padding=(0, 0, 0),
        )

        # Activation and dropout. Dropout3d zeroes whole channels and is a
        # no-op in eval mode.
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout3d(p=dropout)

    def output_time_steps(self, input_time_steps):
        """Return the time length ``conv3`` produces for ``input_time_steps``.

        Valid when ``input_time_steps >= self.time_kernel``; shorter inputs
        make ``conv3`` fail in ``forward``.
        """
        return (input_time_steps - self.time_kernel) // self.time_stride + 1

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, input_features, T, H, W)``.

        Returns:
            Tensor of shape ``(batch, output_features, T_out, H, W)`` where
            ``T_out == self.output_time_steps(T)``.
        """
        # (B, C_in, T, H, W) -> (B, hidden1, T, H, W)
        x = self.dropout(self.tanh(self.conv1(x)))
        # (B, hidden1, T, H, W) -> (B, hidden2, T, H, W)
        x = self.dropout(self.tanh(self.conv2(x)))
        # (B, hidden2, T, H, W) -> (B, output_features, T_out, H, W); no activation.
        return self.conv3(x)

# Initialize the model.
input_features = 8  # number of input channels (features)
output_features = 1  # number of output channels (the final output has 1)
# NOTE(review): despite its name this is the *input* time length (16, matching
# the input tensor below); CNNModel does not use this argument to build any
# layer -- the output time length comes from conv3's kernel/stride.
target_time_steps = 16  # number of input time steps

model = CNNModel(input_features, output_features, target_time_steps)

# Create random input data.
# NOTE: full-resolution demo input. With torch's default float32 dtype the
# conv2 activation alone is 1*128*16*721*1440*4 bytes (~8.5 GB), so running
# this cell needs a lot of memory; shrink height/width for quick checks.
input_data = torch.rand((1, 8, 16, 721, 1440))  # input shape (batch, feature, time, height, width)

# Inference: eval() turns off dropout; no_grad() skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    output_data = model(input_data)

# Print the input and output shapes.
print("Input shape:", input_data.shape)
print("Output shape:", output_data.shape)


torch.Size([1, 128, 16, 721, 1440])
Input shape: torch.Size([1, 8, 16, 721, 1440])
Output shape: torch.Size([1, 1, 3, 721, 1440])
