In [None]:
import math

import torch
import torch.nn as nn

In [None]:
class Bottle2neck(nn.Module):
    """Bottleneck residual block with hierarchical multi-scale 3x3 convolutions.

    The block runs a 1x1 convolution that expands the input to
    ``width * scale`` channels, splits the result into ``scale`` groups of
    ``width`` channels, pushes the first ``num_scales`` groups through their
    own 3x3 conv + BN + ReLU (each group after the first is added to the
    previous group's output unless this is a "first" block), concatenates
    everything, and projects to ``planes * expansion`` channels with a final
    1x1 convolution before adding the shortcut.

    NOTE(review): the layout appears to follow the Res2Net "Bottle2neck"
    design -- confirm against the reference implementation if exact parity
    is required.

    Args:
        inplanes: Number of input channels.
        planes: Base channel count; the block outputs ``planes * expansion``.
        stride: Stride of the 3x3 convolutions (and of the pooling branch).
        downsample: Optional module applied to the input to build the
            shortcut. When ``None`` the input itself is the shortcut.
        cardinality: ``groups`` of the 3x3 convolutions; also multiplies the
            per-split width.
        base_width: Per-split width is
            ``floor(planes * base_width / 64) * cardinality``.
        scale: Number of channel splits.
        dilation: Dilation used when ``first_dilation`` is not given.
        first_dilation: Dilation (and padding) of the 3x3 convolutions.
            Falls back to ``dilation`` when ``None``/0.
        attn_layer: Optional callable ``attn_layer(outplanes)`` returning a
            module applied to the output of the last batch norm.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=26,
                 scale=8, dilation=1, first_dilation=None, attn_layer=None):
        super().__init__()
        self.scale = scale
        # A "first" block changes resolution or channel count, so its splits
        # are processed independently and the last split is pooled.
        self.is_first = stride > 1 or downsample is not None
        # With scale == 1 there is a single split that still gets a 3x3 conv.
        self.num_scales = max(1, scale - 1)
        width = int(math.floor(planes * (base_width / 64.0))) * cardinality
        self.width = width
        outplanes = planes * self.expansion
        # None is not a valid padding/dilation for Conv2d; fall back to `dilation`.
        first_dilation = first_dilation or dilation

        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)

        # padding == dilation keeps the spatial size of a 3x3 conv at stride 1.
        self.convs = nn.ModuleList([
            nn.Conv2d(
                width, width, kernel_size=(3, 3), stride=stride, padding=first_dilation,
                dilation=first_dilation, groups=cardinality, bias=False)
            for _ in range(self.num_scales)
        ])
        self.bns = nn.ModuleList([nn.BatchNorm2d(width) for _ in range(self.num_scales)])

        if self.is_first:
            # Brings the untouched last split to the same spatial size as the
            # strided 3x3 branches.
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        else:
            self.pool = None

        self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(outplanes)

        # Always define the attribute so forward() can test it safely.
        self.se = attn_layer(outplanes) if attn_layer is not None else None

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def zero_init_last(self):
        """Zero the scale of the last batch norm so the residual branch starts at zero."""
        if getattr(self.bn3, 'weight', None) is not None:
            nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        shortcut = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Split along the channel dimension into chunks of `width` channels.
        spx = torch.split(out, self.width, 1)
        spo = []
        sp = None

        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            if i == 0 or self.is_first:
                sp = spx[i]
            else:
                # Feed the previous split's output into the current split.
                sp = sp + spx[i]
            sp = conv(sp)
            sp = bn(sp)
            sp = self.relu(sp)
            spo.append(sp)

        if self.scale > 1:
            # The last split bypasses the 3x3 convs (pooled if resolution changes).
            if self.pool is not None:
                spo.append(self.pool(spx[-1]))
            else:
                spo.append(spx[-1])

        out = torch.cat(spo, 1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.se is not None:
            out = self.se(out)

        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        out = self.relu(out)

        return out