In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.nn import init

In [8]:
def conv3x3(in_c, out_c, stride=1, padding=1, bias=True, groups=1):
    """Return a 3x3 ``nn.Conv2d``.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        stride: convolution stride.
        padding: zero-padding added on each spatial side (1 keeps the
            spatial size unchanged at stride 1).
        bias: whether the convolution has a learnable bias.
        groups: number of channel groups (``groups == in_c`` gives a
            depthwise convolution).
    """
    layer = nn.Conv2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=3,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=bias,
    )
    return layer

def conv1x1(in_c, out_c, stride=1, groups=1, bias=True):
    """Return a 1x1 (pointwise, optionally grouped) ``nn.Conv2d``.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        stride: convolution stride.
        groups: number of channel groups; ``nn.Conv2d`` requires both
            ``in_c`` and ``out_c`` to be divisible by it.
        bias: whether the convolution has a learnable bias. Defaults to
            ``True`` (the previous, implicit behavior), mirroring ``conv3x3``.
    """
    return nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride,
                     groups=groups, bias=bias)

def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis is split into ``groups`` contiguous groups. The result
    takes the first channel of every group, then the second of every group,
    and so on. For example, 6 channels with ``groups=2`` are reordered as
    ``[0, 3, 1, 4, 2, 5]``.

    Args:
        x: tensor of shape (batch, channels, height, width).
        groups: number of groups; must divide ``channels``.

    Returns:
        A tensor of the same shape as ``x`` with permuted channels.

    Raises:
        ValueError: if ``channels`` is not divisible by ``groups``.
    """
    # Use x.size() rather than x.data.size(): .data bypasses autograd
    # tracking and is discouraged; the shape is the same either way.
    n, c, h, w = x.size()
    if c % groups != 0:
        raise ValueError(
            f"channel count {c} is not divisible by groups={groups}"
        )

    channels_per_group = c // groups

    # (n, g, c/g, h, w) -> swap group and per-group axes -> flatten back.
    x = x.view(n, groups, channels_per_group, h, w)
    x = torch.transpose(x, 1, 2).contiguous()
    return x.view(n, -1, h, w)


class ShuffleUnit(nn.Module):
    """ShuffleNet (v1) bottleneck unit.

    Branch: 1x1 (grouped) conv -> BN -> ReLU -> channel shuffle ->
    3x3 depthwise conv -> BN -> 1x1 grouped conv -> BN.
    The branch is combined with the shortcut and passed through a final ReLU.

    Args:
        in_c: number of input channels.
        out_c: number of output channels of the whole unit.
        groups: number of groups for the grouped 1x1 convolutions and for
            the channel shuffle.
        grouped_conv: if False, the first 1x1 conv is ungrouped.
        combine: how the branch and the shortcut are merged.
            ``'add'``: identity shortcut, stride 1, element-wise sum
            (requires ``in_c == out_c``).
            ``'concat'``: 3x3 average-pool shortcut, stride 2, channel
            concatenation (requires ``out_c > in_c``).

    Raises:
        ValueError: if ``combine`` is unsupported or the channel counts are
            incompatible with the chosen ``combine`` mode.
    """

    def __init__(self, in_c, out_c, groups=3, grouped_conv=True, combine='add'):
        super(ShuffleUnit, self).__init__()

        self.in_c = in_c
        self.out_c = out_c
        self.groups = groups
        self.grouped_conv = grouped_conv
        self.combine = combine
        # Bottleneck width is a quarter of the unit's total output width.
        self.bottleneck_channels = self.out_c // 4

        # type of shuffle unit: add or concat
        if self.combine == 'add':
            if self.in_c != self.out_c:
                raise ValueError(
                    f"combine='add' requires in_c == out_c, "
                    f"got {self.in_c} and {self.out_c}"
                )
            self.depthwise_stride = 1
            self._combine_func = self._add
            expand_c = self.out_c
        elif self.combine == 'concat':
            self.depthwise_stride = 2
            self._combine_func = self._concat
            # The shortcut's in_c channels are concatenated onto the branch
            # output, so the branch produces only the remaining channels
            # and the unit still emits out_c channels in total.
            expand_c = self.out_c - self.in_c
            if expand_c <= 0:
                raise ValueError(
                    f"combine='concat' requires out_c > in_c, "
                    f"got in_c={self.in_c}, out_c={self.out_c}"
                )
        else:
            raise ValueError(f"Not Supported type {self.combine}")
        # Number of channels produced by the last 1x1 conv of the branch.
        self.expand_channels = expand_c

        # 1x1 (g)conv: compress in_c -> bottleneck_channels
        self.first_1x1 = self.groups if grouped_conv else 1
        self.gconv_1x1 = self._make_1x1_grouped(
            self.in_c,
            self.bottleneck_channels,
            self.first_1x1,
            bn=True,
            relu=True,
        )

        # 3x3 DWConv : depthwise conv (one group per channel)
        self.depthwise_conv3x3 = conv3x3(
            self.bottleneck_channels,
            self.bottleneck_channels,
            stride=self.depthwise_stride,
            groups=self.bottleneck_channels,
        )
        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)

        # 1x1 gconv: expand bottleneck_channels -> expand_channels
        self.g_conv_1x1_last = self._make_1x1_grouped(
            self.bottleneck_channels,
            self.expand_channels,
            self.groups,
            bn=True,
            relu=False,
        )

    @staticmethod
    def _add(x, out):
        """Element-wise sum of shortcut and branch."""
        return x + out

    @staticmethod
    def _concat(x, out):
        """Concatenate shortcut and branch along the channel axis."""
        return torch.cat((x, out), 1)

    def _make_1x1_grouped(self, in_c, out_c, groups, bn=True, relu=False):
        """Build a 1x1 grouped conv, optionally followed by BN and ReLU.

        Returns the bare conv when neither BN nor ReLU is requested,
        otherwise an ``nn.Sequential`` of the requested layers.
        """
        modules = OrderedDict()

        conv = conv1x1(in_c, out_c, groups=groups)
        modules['conv1x1'] = conv

        if bn:
            modules['batch_norm'] = nn.BatchNorm2d(out_c)
        if relu:
            modules['relu'] = nn.ReLU()
        if len(modules) > 1:
            return nn.Sequential(modules)
        return conv

    def forward(self, x):
        """Apply the unit to ``x`` of shape (batch, in_c, height, width)."""
        residual = x

        if self.combine == 'concat':
            # Downsample the shortcut to match the stride-2 depthwise conv.
            residual = F.avg_pool2d(residual, kernel_size=3,
                                    stride=2, padding=1)

        out = self.gconv_1x1(x)
        out = channel_shuffle(out, self.groups)
        out = self.depthwise_conv3x3(out)
        out = self.bn_after_depthwise(out)
        out = self.g_conv_1x1_last(out)
        out = self._combine_func(residual, out)

        return F.relu(out)


class ShuffleNet(nn.Module):
    """ShuffleNet (v1) backbone followed by global average pooling and a linear head.

    Layout: 3x3 conv (stride 2) -> 3x3 max-pool (stride 2) -> stages 2-4
    of ShuffleUnits -> global average pooling -> ``nn.Linear``.

    Args:
        groups: number of groups for the grouped 1x1 convolutions; one of
            1, 2, 3, 4 or 8. Selects the per-stage channel widths.
        in_c: number of channels of the input images.
        num_outputs: output width of the linear head. Defaults to 1, a
            single regression output (the original behavior).

    Raises:
        ValueError: if ``groups`` is not a supported value.
    """

    # Per-stage output channels indexed as [unused, conv1, stage2, stage3,
    # stage4] for each supported group count (ShuffleNet paper, Table 1).
    _STAGE_OUT_CHANNELS = {
        1: (-1, 24, 144, 288, 576),
        2: (-1, 24, 200, 400, 800),
        3: (-1, 24, 240, 480, 960),
        4: (-1, 24, 272, 544, 1088),
        8: (-1, 24, 384, 768, 1536),
    }

    def __init__(self, groups=3, in_c=3, num_outputs=1):
        super(ShuffleNet, self).__init__()

        self.groups = groups
        # Number of stride-1 'add' units following the first unit of
        # stages 2, 3 and 4 respectively.
        self.stage_repeats = [3, 7, 3]
        self.in_c = in_c
        self.num_outputs = num_outputs

        if groups not in self._STAGE_OUT_CHANNELS:
            raise ValueError(f'Not Supported groups num : {groups}')
        # Fresh list per instance so no mutable state is shared.
        self.stage_out_channels = list(self._STAGE_OUT_CHANNELS[groups])

        self.conv1 = conv3x3(self.in_c, self.stage_out_channels[1], stride=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.stage2 = self._make_stage(2)
        self.stage3 = self._make_stage(3)
        self.stage4 = self._make_stage(4)

        # global pooling happens in forward(); the head sees stage-4 channels
        num_inputs = self.stage_out_channels[-1]
        self.fc = nn.Linear(num_inputs, self.num_outputs)
        self.init_params()

    def init_params(self):
        """Initialize conv, batch-norm and linear parameters in place."""
        # Use the in-place (trailing underscore) initializers; the
        # non-underscore aliases are deprecated.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_stage(self, stage):
        """Build stage ``stage`` (2, 3 or 4) as an ``nn.Sequential`` of ShuffleUnits.

        The first unit downsamples (``combine='concat'``). It is followed by
        ``self.stage_repeats[stage - 2]`` ``combine='add'`` units.
        """
        modules = OrderedDict()
        stage_name = f'ShuffleUnit_Stage{stage}'
        # In stage 2, the first unit's first 1x1 conv is not grouped.
        grouped_conv = stage > 2

        modules[f'{stage_name}_0'] = ShuffleUnit(
            self.stage_out_channels[stage - 1],
            self.stage_out_channels[stage],
            groups=self.groups,
            grouped_conv=grouped_conv,
            combine='concat',
        )

        for i in range(self.stage_repeats[stage - 2]):
            modules[f'{stage_name}_{i + 1}'] = ShuffleUnit(
                self.stage_out_channels[stage],
                self.stage_out_channels[stage],
                groups=self.groups,
                grouped_conv=True,
                combine='add',
            )

        return nn.Sequential(modules)

    def forward(self, x):
        """Map images of shape (batch, in_c, H, W) to (batch, num_outputs)."""
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # Global average pooling to (batch, C, 1, 1), then flatten to
        # (batch, C) so nn.Linear consumes the channel axis as features.
        x = F.avg_pool2d(x, x.size()[-2:])
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
