In [1]:
import torch
from torch import nn


In [29]:
class Bottleneck(nn.Module):
    """ResNet-style bottleneck residual block with squeeze-and-excitation (SE) gating.

    Main path: 1x1 conv (inplanes -> planes) -> BN -> ReLU ->
    3x3 conv (planes -> planes, ``stride``, padding 1) -> BN -> ReLU ->
    1x1 conv (planes -> planes * expansion) -> BN.
    The main-path output is rescaled per channel by an SE gate (global average
    pool -> 1x1 conv reduce -> ReLU -> 1x1 conv expand -> sigmoid), added to the
    shortcut, and passed through a final ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Bottleneck width; the block outputs ``planes * expansion`` channels.
        stride: Stride of the 3x3 conv (and of the auto-built projection shortcut).
        downsample: Shortcut selection.
            * ``None`` / falsy: identity shortcut. This requires
              ``inplanes == planes * expansion`` and ``stride == 1``, otherwise
              the residual add fails on shape.
            * an ``nn.Module``: used as-is as the projection shortcut.
            * any other truthy value (e.g. ``True``): a 1x1 conv projecting
              ``inplanes -> planes * expansion`` with the same ``stride``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        width_out = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)
        self.relu = nn.ReLU(inplace=True)

        # SE gate squeezes width_out -> round(planes / 4) channels (a 16x
        # reduction). Clamp to >= 1: round() yields 0 for planes <= 2, which
        # would leave the gate with no hidden channels.
        se_hidden = max(1, round(planes / 4))
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(width_out, se_hidden, kernel_size=1, stride=1)
        self.fc2 = nn.Conv2d(se_hidden, width_out, kernel_size=1, stride=1)
        self.sigmoid = nn.Sigmoid()

        # Always define self.downsample so forward() can test it against None.
        if isinstance(downsample, nn.Module):
            self.downsample = downsample
        elif downsample:
            # The projection must match the main path's output channels AND its
            # stride, otherwise `out += residual` fails on shape.
            self.downsample = nn.Conv2d(inplanes, width_out, kernel_size=1, stride=stride)
        else:
            self.downsample = None
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x``.

        Args:
            x: Input tensor fed to ``nn.Conv2d``; assumed (N, inplanes, H, W).

        Returns:
            ReLU-activated tensor with ``planes * expansion`` channels. Its spatial
            size is reduced by the 3x3 conv's ``stride``.
        """
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)

        # SE: per-channel weights in (0, 1) computed from globally pooled
        # features. The (N, C, 1, 1) gate broadcasts over H and W.
        gate = self.globalAvgPool(out)
        gate = self.relu(self.fc1(gate))
        gate = self.sigmoid(self.fc2(gate))
        out = out * gate

        # In-place add is safe: `out` is a fresh tensor from the multiply above.
        out += residual
        return self.relu(out)

In [31]:
# Export one SE-Bottleneck block to ONNX, traced with a dummy (1, 512, 28, 28) input.
# Seed first: both the dummy input and the block's initial weights are random, so
# without a seed every run writes different weights to test.onnx.
torch.manual_seed(0)
x = torch.randn(1, 512, 28, 28)
test = Bottleneck(512, 512, downsample=True)

# eval() puts BatchNorm in inference mode before the model is traced.
test.eval()
torch.onnx.export(test, x, 'test.onnx')