In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
class DetBottleneck(nn.Module):
    """Dilated residual bottleneck block that keeps the spatial size unchanged.

    Main branch: 1x1 conv -> BN -> ReLU -> 3x3 conv (dilation 2, padding 2)
    -> BN -> ReLU -> 1x1 conv -> BN. The branch output is added to a shortcut
    and passed through a final ReLU.

    Two variants are selected with ``extra``:

    * ``extra=False`` (type A): the shortcut is the input itself, so
      ``inplanes`` must equal ``planes``.
    * ``extra=True`` (type B): the shortcut is a 1x1 conv + BN projection
      from ``inplanes`` to ``planes`` channels.

    Args:
        inplanes: Number of channels of the input tensor.
        planes: Number of channels produced by the block.
        stride: Accepted for interface compatibility but not used; every
            convolution in this block runs with stride 1.
        extra: If truthy, use the projection shortcut (type B); otherwise
            use the identity shortcut (type A).

    Raises:
        ValueError: If ``extra`` is falsy and ``inplanes != planes``, because
            the identity shortcut could not be added to the branch output.
    """

    def __init__(self, inplanes, planes, stride=1, extra=False):
        super().__init__()
        # Fail at construction time instead of with an opaque shape error
        # inside forward() at the residual addition.
        if not extra and inplanes != planes:
            raise ValueError(
                "DetBottleneck with extra=False uses an identity shortcut and "
                "requires inplanes == planes, got inplanes={} and planes={}; "
                "pass extra=True to use a 1x1 projection shortcut".format(
                    inplanes, planes
                )
            )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            # kernel 3 with dilation 2 spans 5 pixels; padding 2 keeps H and W.
            nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=2,
                      dilation=2, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(planes),
        )
        self.relu = nn.ReLU(inplace=True)
        self.extra = extra
        if self.extra:
            # Projection shortcut: matches the channel count of the main branch.
            self.extra_conv = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, bias=False),
                nn.BatchNorm2d(planes),
            )

    def forward(self, x):
        """Return ``relu(bottleneck(x) + shortcut(x))``."""
        identity = self.extra_conv(x) if self.extra else x
        out = self.bottleneck(x)
        out += identity
        return self.relu(out)

In [3]:
# Projection-shortcut block (extra=True): 1024 input channels -> 256 output channels.
bottleneck_b = DetBottleneck(inplanes=1024, planes=256, stride=1, extra=True)

In [4]:
# Bare expression: the notebook renders the module's repr (its registered submodules).
bottleneck_b

DetBottleneck(
  (bottleneck): Sequential(
    (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (relu): ReLU(inplace=True)
  (extra_conv): Sequential(
    (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [5]:
# Two separate identity-shortcut blocks (extra defaults to False), 256 -> 256 channels.
bottleneck_a1, bottleneck_a2 = (
    DetBottleneck(inplanes=256, planes=256) for _ in range(2)
)

In [6]:
# Random input of shape (1, 1024, 14, 14); 1024 matches bottleneck_b's input channels.
data = torch.randn((1, 1024, 14, 14))

In [10]:
# Run the projection-shortcut block on the random input, then show the result's size.
o1 = bottleneck_b(data)
o1.size()

torch.Size([1, 256, 14, 14])

In [11]:
# Feed the previous output through the first identity-shortcut block.
o2 = bottleneck_a1(o1)
o2.size()

torch.Size([1, 256, 14, 14])

In [12]:
# Feed the previous output through the second identity-shortcut block.
o3 = bottleneck_a2(o2)
o3.size()

torch.Size([1, 256, 14, 14])