In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F

In [2]:
class ShuffleV1Block(nn.Module):
    """Single ShuffleNet V1 unit: 1x1 conv -> depthwise conv -> shuffle -> 1x1 conv.

    With ``stride == 1`` the main branch output is added to the unchanged input
    and passed through ReLU, so ``inp`` must equal ``oup``.
    With ``stride == 2`` the main branch emits ``oup - inp`` channels, which are
    ReLU-activated and concatenated after a 3x3/stride-2 average pooling of the
    input, giving ``oup`` channels in total.

    Args:
        inp: number of input channels.
        oup: number of output channels of the whole block.
        group: group count for the 1x1 convolutions and the channel shuffle.
        first_group: when true, the first 1x1 convolution uses ``groups=1``
            instead of ``group``.
        mid_channels: channel width between the two 1x1 convolutions.
        ksize: kernel size of the depthwise convolution (padding is ``ksize // 2``).
        stride: 1 or 2; applied by the depthwise convolution.
    """

    def __init__(self, inp, oup, *, group, first_group, mid_channels, ksize, stride):
        super().__init__()
        assert stride in [1, 2]
        self.stride = stride
        self.mid_channels = mid_channels
        self.ksize = ksize
        self.pad = ksize // 2
        self.inp = inp
        self.group = group

        # A stride-2 block concatenates the pooled input, so the main branch only
        # has to produce the remaining channels.
        main_out = oup - inp if stride == 2 else oup
        entry_groups = 1 if first_group else group

        # Pointwise reduction followed by the depthwise spatial convolution.
        self.branch_main_1 = nn.Sequential(
            nn.Conv2d(inp, mid_channels, 1, 1, 0, groups=entry_groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, ksize, stride, self.pad,
                      groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
        )
        # Pointwise expansion applied after the channel shuffle.
        self.branch_main_2 = nn.Sequential(
            nn.Conv2d(mid_channels, main_out, 1, 1, 0, groups=group, bias=False),
            nn.BatchNorm2d(main_out),
        )

        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, old_x):
        """Run the block on ``old_x`` and return the fused output tensor."""
        shortcut = old_x
        out = self.branch_main_1(old_x)
        # Shuffling is skipped when there is only a single group.
        if self.group > 1:
            out = self.channel_shuffle(out)
        out = self.branch_main_2(out)

        if self.stride == 1:
            return F.relu(out + shortcut)
        if self.stride == 2:
            return torch.cat((self.branch_proj(shortcut), F.relu(out)), 1)

    def channel_shuffle(self, x):
        """Permute the channels of a 4-D tensor so later grouped convs mix groups.

        The channel axis is viewed as ``(channels // group, group)``, the two
        factors are transposed, and the result is flattened back.
        """
        n, c, h, w = x.shape
        assert c % self.group == 0
        per_group = c // self.group

        return (
            x.reshape(n, per_group, self.group, h, w)
            .permute(0, 2, 1, 3, 4)
            .reshape(n, c, h, w)
        )

In [3]:
class ShuffleNetV1(nn.Module):
    """ShuffleNet V1 classifier assembled from ``ShuffleV1Block`` units.

    Layout: a 3x3/stride-2 stem convolution, a 3x3/stride-2 max pool, three
    stages of 4, 8 and 4 blocks (the first block of each stage has stride 2),
    a fixed 7x7 average pool and a bias-free linear classifier.

    Args:
        input_size: accepted for interface compatibility; it is not read by this
            class. The final pooling window is fixed at 7x7, which matches the
            overall stride of 32 applied to a 224x224 input.
        n_class: number of output classes of the linear classifier.
        model_size: width preset, one of ``"0.5x"``, ``"1.0x"``, ``"1.5x"``,
            ``"2.0x"``.
        group: group count for the grouped convolutions; 3 and 8 are supported.

    Raises:
        AssertionError: if ``group`` is ``None``.
        NotImplementedError: if the ``(group, model_size)`` pair has no channel
            plan.
    """

    # Channel plan per (group, model_size). Index 0 is an unused placeholder,
    # index 1 is the stem width and indices 2-4 are the three stage widths.
    _STAGE_OUT_CHANNELS = {
        (3, "0.5x"): [-1, 12, 120, 240, 480],
        (3, "1.0x"): [-1, 24, 240, 480, 960],
        (3, "1.5x"): [-1, 24, 360, 720, 1440],
        (3, "2.0x"): [-1, 48, 480, 960, 1920],
        (8, "0.5x"): [-1, 16, 192, 384, 768],
        (8, "1.0x"): [-1, 24, 384, 768, 1536],
        (8, "1.5x"): [-1, 24, 576, 1152, 2304],
        (8, "2.0x"): [-1, 48, 768, 1536, 3072],
    }

    def __init__(self, input_size=224, n_class=1000, model_size='2.0x', group=None):
        super().__init__()
        print("model size is", model_size)
        assert group is not None

        # Previously an unsupported group skipped both branches of the plan
        # selection and surfaced later as an AttributeError; fail up front with
        # the same exception type already used for an unknown model size.
        try:
            plan = self._STAGE_OUT_CHANNELS[(group, model_size)]
        except (KeyError, TypeError):
            raise NotImplementedError(
                "unsupported combination: group=%r, model_size=%r" % (group, model_size)
            ) from None

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        # Copy so that instances never share (or mutate) the class-level table.
        self.stage_out_channels = list(plan)

        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        blocks = []
        for stage_idx, num_repeat in enumerate(self.stage_repeats):
            output_channel = self.stage_out_channels[stage_idx + 2]
            for i in range(num_repeat):
                blocks.append(ShuffleV1Block(
                    input_channel, output_channel,
                    group=group,
                    # Only the very first block of the network is flagged.
                    first_group=(stage_idx == 0 and i == 0),
                    mid_channels=output_channel // 4,
                    ksize=3,
                    # Each stage opens with a down-sampling block.
                    stride=2 if i == 0 else 1,
                ))
                input_channel = output_channel
        self.features = nn.Sequential(*blocks)

        self.globalpool = nn.AvgPool2d(7)

        self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class, bias=False))

        self._initialize_weights()

    def forward(self, x):
        """Return classifier scores for a batch of 3-channel images."""
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)

        x = self.globalpool(x)
        # Flatten to (rows, last-stage channels) before the linear layer.
        x = x.contiguous().view(-1, self.stage_out_channels[-1])
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Re-initialise every conv, batch-norm and linear layer in place."""
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                # Modules whose qualified name contains 'first' (the stem) get a
                # fixed std; the others scale with the second weight dimension.
                if 'first' in name:
                    nn.init.normal_(m.weight, 0, 0.01)
                else:
                    nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


In [4]:
# Smoke test: build the default-width (2.0x) network with 3 groups and push a
# random batch of five 3x224x224 inputs through it.
model = ShuffleNetV1(group=3)
test_data = torch.rand((5, 3, 224, 224))
test_outputs = model(test_data)
# Expect one row of n_class scores per input image.
print(test_outputs.shape)

model size is 2.0x
torch.Size([5, 1000])


In [5]:
from torchsummary import summary

# Summarise on the GPU when one is available; otherwise fall back to the CPU so
# this cell still runs on machines without CUDA (a bare .cuda() would raise there).
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
summary(model.to(summary_device), input_size=(3, 224, 224), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 48, 112, 112]           1,296
       BatchNorm2d-2         [-1, 48, 112, 112]              96
              ReLU-3         [-1, 48, 112, 112]               0
         MaxPool2d-4           [-1, 48, 56, 56]               0
            Conv2d-5          [-1, 120, 56, 56]           5,760
       BatchNorm2d-6          [-1, 120, 56, 56]             240
              ReLU-7          [-1, 120, 56, 56]               0
            Conv2d-8          [-1, 120, 28, 28]           1,080
       BatchNorm2d-9          [-1, 120, 28, 28]             240
           Conv2d-10          [-1, 432, 28, 28]          17,280
      BatchNorm2d-11          [-1, 432, 28, 28]             864
        AvgPool2d-12           [-1, 48, 28, 28]               0
   ShuffleV1Block-13          [-1, 480, 28, 28]               0
           Conv2d-14          [-1, 120,