In [3]:
from typing import Callable, Optional

import matplotlib.pyplot as plt
import torch
from torch import nn

Shape:
- Input: $(N, C_{in}, H_{in}, W_{in})$ or $(C_{in}, H_{in}, W_{in})$
- Output: $(N, C_{out}, H_{out}, W_{out})$ or $(C_{out}, H_{out}, W_{out})$, where

$$
H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
$$

$$
W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
            \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
$$

# 转置卷积

https://distill.pub/2016/deconv-checkerboard/

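转置卷积可以等价地看作以下步骤（dilation=1 时）：

1. 元素之间：在输入相邻元素之间插入 s-1 行/列 0

2. 元素四周：在四周各填充 k-p-1 行/列 0（output_padding 会在一侧再额外补 output_padding 行/列）

3. 卷积核上下左右翻转（即旋转 180°），再以 stride=1 做普通卷积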

In [4]:
# Every constructor argument spelled out by keyword; the printed repr below omits
# several of them (padding, dilation, groups, ...) because they are left at their defaults.
nn.ConvTranspose2d(
    in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=0,
    output_padding=0, groups=1, bias=True, dilation=1,
)

ConvTranspose2d(3, 16, kernel_size=(3, 3), stride=(1, 1))

In [5]:
x = torch.ones((1, 3, 4, 4))  # all-ones input, (N, C, H, W) = (1, 3, 4, 4)

In [6]:
# H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
# (4 - 1) * 1 - 2 * 0 + 1 * (3 - 1) + 0 + 1 = 3 - 0 + 2 + 0 + 1 = 6
# (h - 1) * 1 - 2 * 0 + 1 * (3 - 1) + 0 + 1 = h - 1 - 0 + 2 + 0 + 1 = h + 2
model = nn.ConvTranspose2d(3, 16, kernel_size=3, stride=1)
model(x).size()

torch.Size([1, 16, 6, 6])

In [7]:
# Configuration referenced from CenterNet:
#   https://github.com/bubbliiiing/centernet-pytorch/blob/main/nets/resnet50.py#L155
# k=4, s=2, p=1 -> exact 2x upsampling:
#   H_in = 4: (4 - 1) * 2 - 2 * 1 + 1 * (4 - 1) + 0 + 1 = 8
#   H_in = h: 2h - 2 - 2 + 3 + 0 + 1 = 2h
model = nn.ConvTranspose2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1)
model(x).shape

torch.Size([1, 16, 8, 8])

In [8]:
# k=3, s=2, p=1 alone gives 2h - 1; output_padding=1 adds the missing row/column:
#   H_in = 4: (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 8
#   H_in = h: 2h - 2 - 2 + 2 + 1 + 1 = 2h
model = nn.ConvTranspose2d(
    in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1
)
model(x).shape

torch.Size([1, 16, 8, 8])

In [9]:
# kernel_size == stride == 2 -> exact 2x upsampling:
#   H_in = 4: (4 - 1) * 2 - 2 * 0 + 1 * (2 - 1) + 0 + 1 = 8
#   H_in = h: 2h - 2 - 0 + 1 + 0 + 1 = 2h
model = nn.ConvTranspose2d(in_channels=3, out_channels=16, kernel_size=2, stride=2)
model(x).shape

torch.Size([1, 16, 8, 8])

In [21]:
# This cell builds its own 4x4 input so its result always matches the arithmetic below,
# regardless of cell execution order (a later cell rebinds `x` to a 10x10 tensor,
# for which the same layer gives (10 - 1) * 3 - 2 + 4 + 0 + 1 = 30).
# (4 - 1) * 3 - 2 * 1 + 1 * (5 - 1) + 0 + 1 = 9 - 2 + 4 + 0 + 1 = 12
# (h - 1) * 3 - 2 * 1 + 1 * (5 - 1) + 0 + 1 = h * 3 - 3 - 2 + 4 + 0 + 1 = h * 3
model = nn.ConvTranspose2d(3, 16, kernel_size=5, stride=3, padding=1)
model(torch.ones(1, 3, 4, 4)).size()

torch.Size([1, 16, 12, 12])

In [11]:
# kernel_size == stride == 3 -> exact 3x upsampling:
#   H_in = 4: (4 - 1) * 3 - 2 * 0 + 1 * (3 - 1) + 0 + 1 = 12
#   H_in = h: 3h - 3 - 0 + 2 + 0 + 1 = 3h
model = nn.ConvTranspose2d(in_channels=3, out_channels=16, kernel_size=3, stride=3)
model(x).shape

torch.Size([1, 16, 12, 12])

In [12]:
# k=3, s=2, no padding -> one more than 2x:
#   H_in = 4: (4 - 1) * 2 - 2 * 0 + 1 * (3 - 1) + 0 + 1 = 9
#   H_in = h: 2h - 2 - 0 + 2 + 0 + 1 = 2h + 1
model = nn.ConvTranspose2d(in_channels=3, out_channels=16, kernel_size=3, stride=2)
model(x).shape

torch.Size([1, 16, 9, 9])

In [13]:
# k=3, s=2, p=1 without output_padding -> one less than 2x:
#   H_in = 4: (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 7
#   H_in = h: 2h - 2 - 2 + 2 + 0 + 1 = 2h - 1
model = nn.ConvTranspose2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1)
model(x).shape

torch.Size([1, 16, 7, 7])

In [14]:
# The padding term is subtracted (- 2 * padding); with padding=0 it contributes nothing.
# (4 - 1) * 2 - 2 * 0 + 1 * (4 - 1) + 0 + 1 = 6 - 0 + 3 + 0 + 1 = 10
# (h - 1) * 2 - 2 * 0 + 1 * (4 - 1) + 0 + 1 = h * 2 - 2 - 0 + 3 + 0 + 1 = h * 2 + 2
model = nn.ConvTranspose2d(3, 16, kernel_size=4, stride=2)
model(x).size()

torch.Size([1, 16, 10, 10])

# TransposeConvNormAct

In [15]:
class TransposeConvNormAct(nn.Module):
    """``nn.ConvTranspose2d`` followed by a normalization layer and an activation.

    The output spatial size follows the ``nn.ConvTranspose2d`` formula, e.g. for height:
    ``(H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the transposed convolution.
        kernel_size, stride, padding, output_padding, dilation, groups, bias:
            Forwarded unchanged to ``nn.ConvTranspose2d``.
        norm: Factory called as ``norm(out_channels)`` to build the normalization layer
            (default ``nn.BatchNorm2d``). Pass ``None`` to skip normalization.
        act: Factory called with no arguments to build the activation
            (default ``nn.ReLU``). Pass ``None`` to skip the activation.

    Raises:
        AssertionError: If ``in_channels`` or ``out_channels`` is not divisible by ``groups``.

    Note:
        With a BatchNorm-style ``norm`` the conv bias is cancelled by the per-channel mean
        subtraction, so ``bias=False`` is usually sufficient; the default stays ``True``
        for backward compatibility.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        output_padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        norm: Optional[Callable[[int], nn.Module]] = nn.BatchNorm2d,
        act: Optional[Callable[[], nn.Module]] = nn.ReLU,
    ) -> None:
        super().__init__()
        # Fail fast with the offending values in the message.
        assert in_channels % groups == 0, (
            f"in_channels={in_channels} must be divisible by groups={groups}"
        )
        assert out_channels % groups == 0, (
            f"out_channels={out_channels} must be divisible by groups={groups}"
        )
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # nn.Identity keeps both attributes present so forward() needs no branching.
        self.norm = norm(out_channels) if norm is not None else nn.Identity()
        self.act = act() if act is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply transposed conv, then norm, then activation to ``x``."""
        return self.act(self.norm(self.conv(x)))

In [16]:
x = torch.randn((1, 3, 10, 10))  # random input, (N, C, H, W) = (1, 3, 10, 10)

In [17]:
# (10 - 1) * 1 - 2 * 0 + 1 * (3 - 1) + 0 + 1 = 12
model = TransposeConvNormAct(
    in_channels=3, out_channels=4, kernel_size=3, stride=1
).eval()
with torch.inference_mode():
    out = model(x)
# Pin the expected shape so a stale or rebound `x` fails loudly instead of printing a wrong size.
assert out.shape == (1, 4, 12, 12), out.shape
print(out.shape)

torch.Size([1, 4, 12, 12])


In [18]:
# (10 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 19
model = TransposeConvNormAct(
    in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1
).eval()
with torch.inference_mode():
    out = model(x)
# Pin the expected shape so a stale or rebound `x` fails loudly instead of printing a wrong size.
assert out.shape == (1, 4, 19, 19), out.shape
print(out.shape)

torch.Size([1, 4, 19, 19])


In [19]:
# (10 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 20
model = TransposeConvNormAct(
    in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1, output_padding=1
).eval()
with torch.inference_mode():
    out = model(x)
# Pin the expected shape so a stale or rebound `x` fails loudly instead of printing a wrong size.
assert out.shape == (1, 4, 20, 20), out.shape
print(out.shape)

torch.Size([1, 4, 20, 20])


In [20]:
# (10 - 1) * 2 - 2 * 1 + 1 * (4 - 1) + 0 + 1 = 20
model = TransposeConvNormAct(
    in_channels=3, out_channels=4, kernel_size=4, stride=2, padding=1
).eval()
with torch.inference_mode():
    out = model(x)
# Pin the expected shape so a stale or rebound `x` fails loudly instead of printing a wrong size.
assert out.shape == (1, 4, 20, 20), out.shape
print(out.shape)

torch.Size([1, 4, 20, 20])
