# In [1]:
import math
"""
输入数据大小为 (batch_size, channels, height, width)
卷积核大小为 (kernel_size, kernel_size)
填充大小为 (padding, padding)
步长为 (stride, stride)
"""
def conv_output_size(input_size, kernel_size, padding, stride):
   
    """
    计算卷积后的输出大小
    """
    output_size = math.floor((input_size + 2 * padding - kernel_size) / stride) + 1
    return output_size

# Example: one 10-channel 28x28 input, 3x3 kernel, no padding, stride 1.
input_size = (1, 10, 28, 28)
kernel_size, padding, stride = 3, 0, 1
batch_size, channels, height, width = input_size

# Height and width use the same formula, so apply it to both in one pass.
out_height, out_width = (
    conv_output_size(spatial, kernel_size, padding, stride)
    for spatial in (height, width)
)

# Print the output shape as (batch_size, channels, out_height, out_width).
print("输出大小：", (batch_size, channels, out_height, out_width))

# Out: 输出大小： (1, 10, 26, 26)


# In [15]:
def pool_output_size(input_size, pool_size, strides, padding='VALID'):
    """
    Compute the output shape of a 2-D pooling layer.

    Args:
        input_size: tuple, input tensor shape as (batch_size, channels, height, width).
        pool_size: tuple, pooling window size as (pool_height, pool_width).
        strides: tuple, strides as (stride_height, stride_width).
        padding: str, padding mode: 'VALID' (no padding) or 'SAME'
            (pad so each spatial output equals ceil(input / stride)).

    Returns:
        tuple, output shape as (batch_size, channels, output_height, output_width).

    Raises:
        ValueError: if ``padding`` is neither 'VALID' nor 'SAME'.
    """
    batch_size, channels, height, width = input_size
    pool_height, pool_width = pool_size
    stride_height, stride_width = strides

    if padding == 'VALID':
        padding_height, padding_width = 0, 0
    elif padding == 'SAME':
        # Target output length is ceil(size / stride). The integer form
        # -(-a // b) avoids float rounding.
        target_height = -(-height // stride_height)
        target_width = -(-width // stride_width)
        # TOTAL padding (summed over both sides) needed so the last window
        # still fits inside the padded input.
        padding_height = max(0, (target_height - 1) * stride_height + pool_height - height)
        padding_width = max(0, (target_width - 1) * stride_width + pool_width - width)
    else:
        raise ValueError("无效的填充方式。")

    # padding_* is already the sum over both sides, so it must NOT be doubled here.
    output_height = (height + padding_height - pool_height) // stride_height + 1
    output_width = (width + padding_width - pool_width) // stride_width + 1

    return batch_size, channels, output_height, output_width

# Example: pool a (1, 10, 26, 26) input with a 2x2 window, stride 2, SAME padding.
input_size = (1, 10, 26, 26)
pool_size, strides = (2, 2), (2, 2)
padding = 'SAME'

output_size = pool_output_size(
    input_size,
    pool_size,
    strides,
    padding,
)
print(output_size)

# Out: (1, 10, 13, 13)
