## PyTorch Padding Same

使用 PyTorch 实现 TensorFlow 风格的 padding="same"：每个空间维度的输出尺寸为 ceil(输入尺寸 / stride)；当总填充量为奇数时，多出的一行/一列补在底部/右侧。

In [53]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Date    : Sep-29-21 20:17
# @Author  : Kan HUANG (kan.huang@connect.ust.hk)
# @RefLink : https://github.com/pytorch/pytorch/blob/1.7/torch/nn/modules/conv.py
# @RefLink : https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D
# @RefLink : https://oldpan.me/archives/pytorch-same-padding-tflike

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair
from torch.nn.modules.conv import _ConvNd

In [59]:
def conv2d_same_padding(input_, weight, bias=None, stride=1, padding="same", dilation=1, groups=1):
    """2-D convolution with TensorFlow-style ``padding`` semantics.

    Args:
        input_: input tensor; its last two dimensions are treated as (H, W).
        weight: convolution kernel passed to ``F.conv2d``; its last two
            dimensions are the kernel size (kH, kW).
        bias: optional bias, forwarded to ``F.conv2d``.
        stride: int or (sH, sW) pair.
        padding: one of "valid" or "same" (case-insensitive).
            "valid" applies no padding. "same" pads so that each spatial
            output size is ceil(input_size / stride). When the total padding
            along a dimension is odd, the extra row/column goes at the
            bottom/right, as TensorFlow does.
        dilation: int or (dH, dW) pair.
        groups: forwarded to ``F.conv2d``.

    Returns:
        The result of ``F.conv2d`` on the (possibly padded) input.

    Raises:
        ValueError: if ``padding`` is not "valid" or "same".
    """
    mode = padding.lower() if isinstance(padding, str) else padding
    if mode not in ("same", "valid"):
        raise ValueError(f"""padding: {padding} not in ["same", "valid"]""")

    # Accept plain ints as well as pairs; the "same" branch indexes per dimension.
    stride = _pair(stride)
    dilation = _pair(dilation)

    if mode == "valid":
        return F.conv2d(input_, weight, bias, stride,
                        padding=0,
                        dilation=dilation, groups=groups)

    # Rows use dimension -2 and index 0; columns use dimension -1 and index 1.
    padding_rows = _same_padding_total(input_.size(-2), weight.size(-2), stride[0], dilation[0])
    padding_cols = _same_padding_total(input_.size(-1), weight.size(-1), stride[1], dilation[1])

    rows_odd = padding_rows % 2
    cols_odd = padding_cols % 2
    if rows_odd or cols_odd:
        # F.conv2d only pads symmetrically, so put the odd extra row/column
        # on the bottom/right up front (F.pad order: [left, right, top, bottom]).
        input_ = F.pad(input_, [0, cols_odd, 0, rows_odd])

    return F.conv2d(input_, weight, bias, stride,
                    padding=(padding_rows // 2, padding_cols // 2),
                    dilation=dilation, groups=groups)


def _same_padding_total(in_size, kernel_size, stride, dilation):
    """Return the total TF "same" padding needed along one spatial dimension."""
    effective_kernel = (kernel_size - 1) * dilation + 1
    out_size = (in_size + stride - 1) // stride  # ceil(in_size / stride)
    return max(0, (out_size - 1) * stride + effective_kernel - in_size)

In [60]:
class Conv2dSamePadding(_ConvNd):
    """2-D convolution layer with TensorFlow-style ``padding`` ("valid" or "same").

    Constructor arguments mirror ``torch.nn.Conv2d``, except ``padding``, which
    must be one of "valid" or "same" (case-insensitive). Refer to:
    https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D

    The padding is computed in ``forward`` from the actual input size by
    ``conv2d_same_padding``, so "same" also works with strided convolutions.

    NOTE: ``conv2d_same_padding`` always zero-pads. A ``padding_mode`` other than
    "zeros" is accepted for signature compatibility, but has no effect on the
    output; a warning is emitted in that case.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: str = 'valid',
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        mode = padding.lower() if isinstance(padding, str) else padding
        if mode not in ("same", "valid"):
            raise ValueError(f"""padding: {padding} not in ["same", "valid"]""")
        if padding_mode != 'zeros':
            import warnings
            warnings.warn(
                f"Conv2dSamePadding always zero-pads; padding_mode={padding_mode!r} is ignored.")

        # Give the base class numeric zero padding. Newer PyTorch releases validate
        # string padding inside _ConvNd and reject padding="same" when stride != 1,
        # which this layer supports because it pads at forward time.
        super().__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(0),
            _pair(dilation), False, _pair(0), groups, bias, padding_mode)

        # Store the TF-style padding mode that forward() hands to conv2d_same_padding.
        self.padding = mode

        # _ConvNd in PyTorch 1.7 takes no device/dtype arguments, so apply them
        # after construction to stay compatible with old and new releases.
        factory_kwargs = {k: v for k, v in (("device", device), ("dtype", dtype)) if v is not None}
        if factory_kwargs:
            self.to(**factory_kwargs)

    def forward(self, x):
        """Convolve ``x`` using the layer's parameters and TF-style padding."""
        return conv2d_same_padding(x, self.weight, self.bias, stride=self.stride,
                                   padding=self.padding, dilation=self.dilation,
                                   groups=self.groups)

In [65]:
# Two otherwise-identical 3x3, stride-1 layers (3 -> 1 channels) that differ only in padding.
conv2d_1 = Conv2dSamePadding(in_channels=3, out_channels=1, kernel_size=3, stride=1, padding="same")
conv2d_2 = Conv2dSamePadding(in_channels=3, out_channels=1, kernel_size=3, stride=1, padding="valid")

In [66]:
# Random float32 batch: 64 images, 3 channels, 32x32 pixels.
x = torch.from_numpy(np.random.rand(64, 3, 32, 32).astype(np.float32))
# Compare output shapes of the "same" layer and the "valid" layer.
for layer in (conv2d_1, conv2d_2):
    out = layer(x)
    print(out.shape)

torch.Size([64, 1, 32, 32])
torch.Size([64, 1, 30, 30])
