# Permalink
# Find file Copy path
# ea25bb0 Jul 28, 2018
# 4 contributors
#
# Users who have contributed to this file
#
# @CreativeCVision @vfdev-5 @Cadene @wangg12
# 662 lines (504 sloc) 30.3 KB
"""
NASNet Mobile
Thanks to Anastasiia (https://github.com/DagnyT) for the great help, support and motivation!
------------------------------------------------------------------------------------
Architecture | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M)
------------------------------------------------------------------------------------
| NASNet-A (4 @ 1056) | 74.08% | 91.74% | 564 M | 5.3 |
------------------------------------------------------------------------------------
# References:
- [Learning Transferable Architectures for Scalable Image Recognition]
(https://arxiv.org/abs/1707.07012)
"""
from __future__ import print_function, division, absolute_import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
import numpy as np
# Checkpoint location and preprocessing metadata, keyed by model name and then
# by the name of the pretrained weight set.
pretrained_settings = {
    'nasnetamobile': {
        'imagenet': {
            # Alternative mirror, currently disabled:
            # 'url': 'https://github.com/veronikayurchuk/pretrained-models.pytorch/releases/download/v1.0/nasnetmobile-7e03cead.pth.tar',
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetamobile-7e03cead.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],  # resize 256
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000
        },
        # A 1001-class 'imagenet+background' entry (same preprocessing values,
        # no URL) is intentionally left disabled:
        # 'imagenet+background': {
        #     # 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
        #     'input_space': 'RGB',
        #     'input_size': [3, 224, 224],  # resize 256
        #     'input_range': [0, 1],
        #     'mean': [0.5, 0.5, 0.5],
        #     'std': [0.5, 0.5, 0.5],
        #     'num_classes': 1001
        # }
    }
}
class MaxPoolPad(nn.Module):
    """3x3, stride-2 max pooling preceded by a one-pixel top/left zero pad.

    The input is zero-padded on the left and top only, pooled with a symmetric
    padding of 1, and the first row and column of the pooled map are then
    discarded.
    """

    def __init__(self):
        super(MaxPoolPad, self).__init__()
        # ZeroPad2d takes (left, right, top, bottom).
        self.pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        """Pad, pool, then drop the first row and column of the result."""
        pooled = self.pool(self.pad(x))
        # The slice is a strided view; return a contiguous copy.
        return pooled[:, :, 1:, 1:].contiguous()
class AvgPoolPad(nn.Module):
    """3x3 average pooling preceded by a one-pixel top/left zero pad.

    Args:
        stride: stride handed to the underlying ``nn.AvgPool2d``.
        padding: symmetric padding handed to the underlying ``nn.AvgPool2d``;
            it is excluded from the averages (``count_include_pad=False``).
    """

    def __init__(self, stride=2, padding=1):
        super(AvgPoolPad, self).__init__()
        # ZeroPad2d takes (left, right, top, bottom).
        self.pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False)

    def forward(self, x):
        """Pad, pool, then drop the first row and column of the result."""
        pooled = self.pool(self.pad(x))
        # The slice is a strided view; return a contiguous copy.
        return pooled[:, :, 1:, 1:].contiguous()
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: channels of the input; also the number of depthwise groups.
        out_channels: channels produced by the pointwise convolution.
        dw_kernel: kernel size of the depthwise convolution.
        dw_stride: stride of the depthwise convolution.
        dw_padding: padding of the depthwise convolution.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
        super(SeparableConv2d, self).__init__()
        # groups == in_channels gives every input channel its own spatial filter.
        self.depthwise_conv2d = nn.Conv2d(
            in_channels, in_channels, dw_kernel,
            stride=dw_stride, padding=dw_padding, bias=bias, groups=in_channels)
        # The 1x1 convolution mixes the per-channel results.
        self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise one."""
        return self.pointwise_conv2d(self.depthwise_conv2d(x))
class BranchSeparables(nn.Module):
    """Two stacked ReLU -> separable conv -> batch norm stages.

    The first stage keeps ``in_channels`` and applies ``stride``; the second
    maps to ``out_channels`` with stride 1. Both use the same ``kernel_size``
    and ``padding``.

    Args:
        in_channels: channels of the input and of the intermediate map.
        out_channels: channels of the output.
        kernel_size: depthwise kernel size for both separable convolutions.
        stride: stride of the first separable convolution.
        padding: padding of both separable convolutions.
        name: when equal to ``'specific'``, the input is zero-padded by one
            pixel on the top/left before the first convolution and the first
            row/column of its output is dropped afterwards.
        bias: whether the separable convolutions carry bias terms.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, name=None, bias=False):
        super(BranchSeparables, self).__init__()
        self.relu = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)
        self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True)
        self.relu1 = nn.ReLU()
        self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
        self.name = name

    def forward(self, x):
        """Run both stages, with the asymmetric pad/crop when ``name == 'specific'``."""
        asymmetric = self.name == 'specific'
        x = self.relu(x)
        if asymmetric:
            # Functional zero padding (left, right, top, bottom); same result as
            # nn.ZeroPad2d((1, 0, 1, 0)) without building a module on every call.
            x = F.pad(x, (1, 0, 1, 0))
        x = self.separable_1(x)
        if asymmetric:
            # Discard the row/column introduced by the extra padding.
            x = x[:, :, 1:, 1:].contiguous()
        x = self.bn_sep_1(x)
        x = self.relu1(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x
class BranchSeparablesStem(nn.Module):
    """Two stacked ReLU -> separable conv -> batch norm stages for the stem.

    Here the first separable convolution already maps ``in_channels`` to
    ``out_channels`` (and applies ``stride``); the second keeps
    ``out_channels`` with stride 1.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the intermediate and output maps.
        kernel_size: depthwise kernel size for both separable convolutions.
        stride: stride of the first separable convolution.
        padding: padding of both separable convolutions.
        bias: whether the separable convolutions carry bias terms.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
        super(BranchSeparablesStem, self).__init__()
        bn_args = dict(eps=0.001, momentum=0.1, affine=True)
        self.relu = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.bn_sep_1 = nn.BatchNorm2d(out_channels, **bn_args)
        self.relu1 = nn.ReLU()
        self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, **bn_args)

    def forward(self, x):
        """Apply the six layers in sequence."""
        pipeline = (
            self.relu, self.separable_1, self.bn_sep_1,
            self.relu1, self.separable_2, self.bn_sep_2,
        )
        for layer in pipeline:
            x = layer(x)
        return x
class BranchSeparablesReduction(BranchSeparables):
    """``BranchSeparables`` variant that always pads top/left before stage one.

    Args:
        in_channels: channels of the input and of the intermediate map.
        out_channels: channels of the output.
        kernel_size: depthwise kernel size for both separable convolutions.
        stride: stride of the first separable convolution.
        padding: padding of both separable convolutions.
        z_padding: number of zero rows/columns added on the top/left of the
            input before the first separable convolution.
        bias: whether the separable convolutions carry bias terms.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
        # `bias` must be passed by keyword: the parent's sixth positional
        # parameter is `name`, so a positional `bias` would land in `name`
        # and leave the parent's bias at its default.
        super(BranchSeparablesReduction, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        # ZeroPad2d takes (left, right, top, bottom).
        self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))

    def forward(self, x):
        """Run both stages with the top/left pad and the one-pixel crop."""
        x = self.relu(x)
        x = self.padding(x)
        x = self.separable_1(x)
        # NOTE: exactly one row/column is dropped here regardless of z_padding.
        x = x[:, :, 1:, 1:].contiguous()
        x = self.bn_sep_1(x)
        x = self.relu1(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x
class CellStem0(nn.Module):
    """First stem cell, fed by a single input tensor.

    A 1x1 projection of the input and the raw input itself feed five pairwise
    branch combinations; four of them are concatenated along the channel
    axis, so the output carries ``4 * num_filters`` channels.

    Args:
        stem_filters: channels of the incoming tensor.
        num_filters: channels produced by every branch.
    """

    def __init__(self, stem_filters, num_filters=42):
        super(CellStem0, self).__init__()
        self.num_filters = num_filters
        self.stem_filters = stem_filters
        stem, width = self.stem_filters, self.num_filters

        # ReLU -> 1x1 conv -> BN projection from stem_filters to num_filters.
        self.conv_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(stem, width, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(width, eps=0.001, momentum=0.1, affine=True))):
            self.conv_1x1.add_module(label, layer)

        # "left" branches consume the projection (or an earlier combination),
        # "right" branches consume the raw input.
        self.comb_iter_0_left = BranchSeparables(width, width, 5, 2, 2)
        self.comb_iter_0_right = BranchSeparablesStem(stem, width, 7, 2, 3, bias=False)

        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
        self.comb_iter_1_right = BranchSeparablesStem(stem, width, 7, 2, 3, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.comb_iter_2_right = BranchSeparablesStem(stem, width, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(width, width, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        """Return the channel-wise concatenation of combinations 1 to 4."""
        projected = self.conv_1x1(x)

        comb_0 = self.comb_iter_0_left(projected) + self.comb_iter_0_right(x)
        comb_1 = self.comb_iter_1_left(projected) + self.comb_iter_1_right(x)
        comb_2 = self.comb_iter_2_left(projected) + self.comb_iter_2_right(x)
        # Combinations 3 and 4 build on combination 0, which is not emitted itself.
        comb_3 = self.comb_iter_3_right(comb_0) + comb_1
        comb_4 = self.comb_iter_4_left(comb_0) + self.comb_iter_4_right(projected)

        return torch.cat([comb_1, comb_2, comb_3, comb_4], 1)
class CellStem1(nn.Module):
    """Second stem cell, fed by the initial conv output and the first stem cell.

    The first stem cell's output is projected with a 1x1 convolution (it is
    expected to carry ``2 * num_filters`` channels). The initial conv output
    is downsampled along two paths whose results are concatenated and batch
    normalised. Four pairwise branch combinations are concatenated along the
    channel axis.

    Args:
        stem_filters: channels of the initial conv output.
        num_filters: channels produced by every branch.
    """

    def __init__(self, stem_filters, num_filters):
        super(CellStem1, self).__init__()
        self.num_filters = num_filters
        self.stem_filters = stem_filters
        stem, width = self.stem_filters, self.num_filters
        bn_args = dict(eps=0.001, momentum=0.1, affine=True)

        self.conv_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(2 * width, width, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(width, **bn_args))):
            self.conv_1x1.add_module(label, layer)

        self.relu = nn.ReLU()

        # Path 1: stride-2 subsampling followed by a 1x1 conv to half the width.
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(stem, width // 2, 1, stride=1, bias=False))

        # Path 2: same layers plus a bottom/right pad. The ModuleList serves as
        # a named container only; forward() calls its children one by one.
        self.path_2 = nn.ModuleList()
        self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(stem, width // 2, 1, stride=1, bias=False))

        self.final_path_bn = nn.BatchNorm2d(width, **bn_args)

        self.comb_iter_0_left = BranchSeparables(width, width, 5, 2, 2, name='specific', bias=False)
        self.comb_iter_0_right = BranchSeparables(width, width, 7, 2, 3, name='specific', bias=False)

        # MaxPoolPad / AvgPoolPad stand in for plain stride-2 pools here.
        self.comb_iter_1_left = MaxPoolPad()
        self.comb_iter_1_right = BranchSeparables(width, width, 7, 2, 3, name='specific', bias=False)

        self.comb_iter_2_left = AvgPoolPad()
        self.comb_iter_2_right = BranchSeparables(width, width, 5, 2, 2, name='specific', bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(width, width, 3, 1, 1, name='specific', bias=False)
        self.comb_iter_4_right = MaxPoolPad()

    def forward(self, x_conv0, x_stem_0):
        """Return the channel-wise concatenation of combinations 1 to 4.

        Args:
            x_conv0: output of the initial convolution.
            x_stem_0: output of the first stem cell.
        """
        x_left = self.conv_1x1(x_stem_0)

        activated = self.relu(x_conv0)
        first_path = self.path_1(activated)
        # Pad bottom/right then drop the first row/column: a one-pixel shift
        # of the map before it is subsampled.
        shifted = self.path_2.pad(activated)[:, :, 1:, 1:]
        second_path = self.path_2.conv(self.path_2.avgpool(shifted))
        x_right = self.final_path_bn(torch.cat([first_path, second_path], 1))

        comb_0 = self.comb_iter_0_left(x_left) + self.comb_iter_0_right(x_right)
        comb_1 = self.comb_iter_1_left(x_left) + self.comb_iter_1_right(x_right)
        comb_2 = self.comb_iter_2_left(x_left) + self.comb_iter_2_right(x_right)
        comb_3 = self.comb_iter_3_right(comb_0) + comb_1
        comb_4 = self.comb_iter_4_left(comb_0) + self.comb_iter_4_right(x_left)

        return torch.cat([comb_1, comb_2, comb_3, comb_4], 1)
class FirstCell(nn.Module):
    """Normal-style cell whose "previous" input is first spatially reduced.

    ``x_prev`` is downsampled along two stride-2 paths whose results are
    concatenated into ``2 * out_channels_left`` channels. Several branches
    sized with ``out_channels_right`` are applied to that tensor, so the cell
    relies on ``2 * out_channels_left == out_channels_right``. The output
    concatenates six tensors along the channel axis.

    Args:
        in_channels_left: channels of ``x_prev``.
        out_channels_left: channels produced by each of the two reduction paths.
        in_channels_right: channels of ``x``.
        out_channels_right: channels of the projected ``x`` and of every branch.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(FirstCell, self).__init__()
        bn_args = dict(eps=0.001, momentum=0.1, affine=True)
        width = out_channels_right

        self.conv_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels_right, **bn_args))):
            self.conv_1x1.add_module(label, layer)

        self.relu = nn.ReLU()

        # Path 1: stride-2 subsampling followed by a 1x1 conv.
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))

        # Path 2: same layers plus a bottom/right pad. The ModuleList serves as
        # a named container only; forward() calls its children one by one.
        self.path_2 = nn.ModuleList()
        self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))

        self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, **bn_args)

        self.comb_iter_0_left = BranchSeparables(width, width, 5, 1, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(width, width, 3, 1, 1, bias=False)

        self.comb_iter_1_left = BranchSeparables(width, width, 5, 1, 2, bias=False)
        self.comb_iter_1_right = BranchSeparables(width, width, 3, 1, 1, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(width, width, 3, 1, 1, bias=False)

    def forward(self, x, x_prev):
        """Return the reduced ``x_prev`` concatenated with five combinations.

        Args:
            x: output of the preceding cell.
            x_prev: output of the cell before that.
        """
        activated = self.relu(x_prev)
        first_path = self.path_1(activated)
        # Pad bottom/right then drop the first row/column: a one-pixel shift
        # of the map before it is subsampled.
        shifted = self.path_2.pad(activated)[:, :, 1:, 1:]
        second_path = self.path_2.conv(self.path_2.avgpool(shifted))
        x_left = self.final_path_bn(torch.cat([first_path, second_path], 1))

        x_right = self.conv_1x1(x)

        comb_0 = self.comb_iter_0_left(x_right) + self.comb_iter_0_right(x_left)
        comb_1 = self.comb_iter_1_left(x_left) + self.comb_iter_1_right(x_left)
        comb_2 = self.comb_iter_2_left(x_right) + x_left
        comb_3 = self.comb_iter_3_left(x_left) + self.comb_iter_3_right(x_left)
        comb_4 = self.comb_iter_4_left(x_right) + x_right

        return torch.cat([x_left, comb_0, comb_1, comb_2, comb_3, comb_4], 1)
class NormalCell(nn.Module):
    """Resolution-preserving cell combining the two most recent cell outputs.

    Both inputs are projected with ReLU -> 1x1 conv -> BN, then mixed through
    five pairwise combinations using stride-1 branches. The output
    concatenates the projected ``x_prev`` and the five combinations along the
    channel axis.

    Args:
        in_channels_left: channels of ``x_prev``.
        out_channels_left: channels of the projected ``x_prev``.
        in_channels_right: channels of ``x``.
        out_channels_right: channels of the projected ``x``.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(NormalCell, self).__init__()
        bn_args = dict(eps=0.001, momentum=0.1, affine=True)
        left, right = out_channels_left, out_channels_right

        self.conv_prev_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels_left, **bn_args))):
            self.conv_prev_1x1.add_module(label, layer)

        self.conv_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels_right, **bn_args))):
            self.conv_1x1.add_module(label, layer)

        self.comb_iter_0_left = BranchSeparables(right, right, 5, 1, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(left, left, 3, 1, 1, bias=False)

        self.comb_iter_1_left = BranchSeparables(left, left, 5, 1, 2, bias=False)
        self.comb_iter_1_right = BranchSeparables(left, left, 3, 1, 1, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(right, right, 3, 1, 1, bias=False)

    def forward(self, x, x_prev):
        """Return the projected ``x_prev`` concatenated with five combinations.

        Args:
            x: output of the preceding cell.
            x_prev: output of the cell before that.
        """
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        comb_0 = self.comb_iter_0_left(x_right) + self.comb_iter_0_right(x_left)
        comb_1 = self.comb_iter_1_left(x_left) + self.comb_iter_1_right(x_left)
        # Combinations 2 and 4 add an unmodified projection as a skip term.
        comb_2 = self.comb_iter_2_left(x_right) + x_left
        comb_3 = self.comb_iter_3_left(x_left) + self.comb_iter_3_right(x_left)
        comb_4 = self.comb_iter_4_left(x_right) + x_right

        return torch.cat([x_left, comb_0, comb_1, comb_2, comb_3, comb_4], 1)
class ReductionCell0(nn.Module):
    """Stride-2 cell built from ``BranchSeparablesReduction`` branches.

    Both inputs are projected with ReLU -> 1x1 conv -> BN. Every branch,
    including those applied to the projected ``x_prev``, is sized with
    ``out_channels_right``, so the cell relies on
    ``out_channels_left == out_channels_right``. Four combinations are
    concatenated along the channel axis.

    Args:
        in_channels_left: channels of ``x_prev``.
        out_channels_left: channels of the projected ``x_prev``.
        in_channels_right: channels of ``x``.
        out_channels_right: channels of the projected ``x`` and of every branch.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(ReductionCell0, self).__init__()
        bn_args = dict(eps=0.001, momentum=0.1, affine=True)
        width = out_channels_right

        self.conv_prev_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels_left, **bn_args))):
            self.conv_prev_1x1.add_module(label, layer)

        self.conv_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels_right, **bn_args))):
            self.conv_1x1.add_module(label, layer)

        self.comb_iter_0_left = BranchSeparablesReduction(width, width, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparablesReduction(width, width, 7, 2, 3, bias=False)

        self.comb_iter_1_left = MaxPoolPad()
        self.comb_iter_1_right = BranchSeparablesReduction(width, width, 7, 2, 3, bias=False)

        self.comb_iter_2_left = AvgPoolPad()
        self.comb_iter_2_right = BranchSeparablesReduction(width, width, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparablesReduction(width, width, 3, 1, 1, bias=False)
        self.comb_iter_4_right = MaxPoolPad()

    def forward(self, x, x_prev):
        """Return the channel-wise concatenation of combinations 1 to 4.

        Args:
            x: output of the preceding cell.
            x_prev: output of the cell before that.
        """
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        comb_0 = self.comb_iter_0_left(x_right) + self.comb_iter_0_right(x_left)
        comb_1 = self.comb_iter_1_left(x_right) + self.comb_iter_1_right(x_left)
        comb_2 = self.comb_iter_2_left(x_right) + self.comb_iter_2_right(x_left)
        # Combinations 3 and 4 build on combination 0, which is not emitted itself.
        comb_3 = self.comb_iter_3_right(comb_0) + comb_1
        comb_4 = self.comb_iter_4_left(comb_0) + self.comb_iter_4_right(x_right)

        return torch.cat([comb_1, comb_2, comb_3, comb_4], 1)
class ReductionCell1(nn.Module):
    """Stride-2 cell built from ``BranchSeparables(name='specific')`` branches.

    Both inputs are projected with ReLU -> 1x1 conv -> BN. Every branch,
    including those applied to the projected ``x_prev``, is sized with
    ``out_channels_right``, so the cell relies on
    ``out_channels_left == out_channels_right``. Four combinations are
    concatenated along the channel axis.

    Args:
        in_channels_left: channels of ``x_prev``.
        out_channels_left: channels of the projected ``x_prev``.
        in_channels_right: channels of ``x``.
        out_channels_right: channels of the projected ``x`` and of every branch.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(ReductionCell1, self).__init__()
        bn_args = dict(eps=0.001, momentum=0.1, affine=True)
        width = out_channels_right

        self.conv_prev_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels_left, **bn_args))):
            self.conv_prev_1x1.add_module(label, layer)

        self.conv_1x1 = nn.Sequential()
        for label, layer in (
                ('relu', nn.ReLU()),
                ('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels_right, **bn_args))):
            self.conv_1x1.add_module(label, layer)

        self.comb_iter_0_left = BranchSeparables(width, width, 5, 2, 2, name='specific', bias=False)
        self.comb_iter_0_right = BranchSeparables(width, width, 7, 2, 3, name='specific', bias=False)

        # MaxPoolPad / AvgPoolPad stand in for plain stride-2 pools here.
        self.comb_iter_1_left = MaxPoolPad()
        self.comb_iter_1_right = BranchSeparables(width, width, 7, 2, 3, name='specific', bias=False)

        self.comb_iter_2_left = AvgPoolPad()
        self.comb_iter_2_right = BranchSeparables(width, width, 5, 2, 2, name='specific', bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(width, width, 3, 1, 1, name='specific', bias=False)
        self.comb_iter_4_right = MaxPoolPad()

    def forward(self, x, x_prev):
        """Return the channel-wise concatenation of combinations 1 to 4.

        Args:
            x: output of the preceding cell.
            x_prev: output of the cell before that.
        """
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        comb_0 = self.comb_iter_0_left(x_right) + self.comb_iter_0_right(x_left)
        comb_1 = self.comb_iter_1_left(x_right) + self.comb_iter_1_right(x_left)
        comb_2 = self.comb_iter_2_left(x_right) + self.comb_iter_2_right(x_left)
        # Combinations 3 and 4 build on combination 0, which is not emitted itself.
        comb_3 = self.comb_iter_3_right(comb_0) + comb_1
        comb_4 = self.comb_iter_4_left(comb_0) + self.comb_iter_4_right(x_right)

        return torch.cat([comb_1, comb_2, comb_3, comb_4], 1)
class NASNetAMobile(nn.Module):
    """NASNetAMobile (4 @ 1056)

    Stem convolution, two stem cells, then three groups of one ``FirstCell``
    and three ``NormalCell`` blocks, with a reduction cell between groups.
    A pooled, dropout-regularised linear layer produces the logits.

    Args:
        num_classes: size of the final linear layer's output.
        stem_filters: channels produced by the initial 3x3 convolution.
        penultimate_filters: channels of the final feature map; the base
            width is ``penultimate_filters // 24``.
        filters_multiplier: divisor used to size the two stem cells relative
            to the base width.
    """

    def __init__(self, num_classes=1001, stem_filters=32, penultimate_filters=1056, filters_multiplier=2):
        super(NASNetAMobile, self).__init__()
        self.num_classes = num_classes
        self.stem_filters = stem_filters
        self.penultimate_filters = penultimate_filters
        self.filters_multiplier = filters_multiplier

        # 24 is the default ratio for this architecture.
        filters = self.penultimate_filters // 24
        f = filters

        # Unpadded stride-2 3x3 convolution on a 3-channel input, then BN.
        self.conv0 = nn.Sequential()
        self.conv0.add_module(
            'conv',
            nn.Conv2d(in_channels=3, out_channels=self.stem_filters, kernel_size=3,
                      padding=0, stride=2, bias=False))
        self.conv0.add_module(
            'bn', nn.BatchNorm2d(self.stem_filters, eps=0.001, momentum=0.1, affine=True))

        self.cell_stem_0 = CellStem0(self.stem_filters, num_filters=f // (filters_multiplier ** 2))
        self.cell_stem_1 = CellStem1(self.stem_filters, num_filters=f // filters_multiplier)

        # Trailing comments give (in, out) channels as multiples of `filters`.
        self.cell_0 = FirstCell(
            in_channels_left=f, out_channels_left=f // 2,            # 1, 0.5
            in_channels_right=2 * f, out_channels_right=f)           # 2, 1
        self.cell_1 = NormalCell(
            in_channels_left=2 * f, out_channels_left=f,             # 2, 1
            in_channels_right=6 * f, out_channels_right=f)           # 6, 1
        self.cell_2 = NormalCell(
            in_channels_left=6 * f, out_channels_left=f,             # 6, 1
            in_channels_right=6 * f, out_channels_right=f)           # 6, 1
        self.cell_3 = NormalCell(
            in_channels_left=6 * f, out_channels_left=f,             # 6, 1
            in_channels_right=6 * f, out_channels_right=f)           # 6, 1

        self.reduction_cell_0 = ReductionCell0(
            in_channels_left=6 * f, out_channels_left=2 * f,         # 6, 2
            in_channels_right=6 * f, out_channels_right=2 * f)       # 6, 2

        self.cell_6 = FirstCell(
            in_channels_left=6 * f, out_channels_left=f,             # 6, 1
            in_channels_right=8 * f, out_channels_right=2 * f)       # 8, 2
        self.cell_7 = NormalCell(
            in_channels_left=8 * f, out_channels_left=2 * f,         # 8, 2
            in_channels_right=12 * f, out_channels_right=2 * f)      # 12, 2
        self.cell_8 = NormalCell(
            in_channels_left=12 * f, out_channels_left=2 * f,        # 12, 2
            in_channels_right=12 * f, out_channels_right=2 * f)      # 12, 2
        self.cell_9 = NormalCell(
            in_channels_left=12 * f, out_channels_left=2 * f,        # 12, 2
            in_channels_right=12 * f, out_channels_right=2 * f)      # 12, 2

        self.reduction_cell_1 = ReductionCell1(
            in_channels_left=12 * f, out_channels_left=4 * f,        # 12, 4
            in_channels_right=12 * f, out_channels_right=4 * f)      # 12, 4

        self.cell_12 = FirstCell(
            in_channels_left=12 * f, out_channels_left=2 * f,        # 12, 2
            in_channels_right=16 * f, out_channels_right=4 * f)      # 16, 4
        self.cell_13 = NormalCell(
            in_channels_left=16 * f, out_channels_left=4 * f,        # 16, 4
            in_channels_right=24 * f, out_channels_right=4 * f)      # 24, 4
        self.cell_14 = NormalCell(
            in_channels_left=24 * f, out_channels_left=4 * f,        # 24, 4
            in_channels_right=24 * f, out_channels_right=4 * f)      # 24, 4
        self.cell_15 = NormalCell(
            in_channels_left=24 * f, out_channels_left=4 * f,        # 24, 4
            in_channels_right=24 * f, out_channels_right=4 * f)      # 24, 4

        self.relu = nn.ReLU()
        # Fixed 7x7 window; presumably matches the final map for 224x224
        # inputs -- TODO confirm for other resolutions.
        self.avg_pool = nn.AvgPool2d(7, stride=1, padding=0)
        self.dropout = nn.Dropout()
        self.last_linear = nn.Linear(24 * f, self.num_classes)

    def features(self, input):
        """Run the convolutional trunk and return the last cell's output.

        Every cell receives the output of the preceding cell together with
        the output of the cell before that.
        """
        conv0 = self.conv0(input)
        stem_0 = self.cell_stem_0(conv0)
        stem_1 = self.cell_stem_1(conv0, stem_0)

        cell_0 = self.cell_0(stem_1, stem_0)
        cell_1 = self.cell_1(cell_0, stem_1)
        cell_2 = self.cell_2(cell_1, cell_0)
        cell_3 = self.cell_3(cell_2, cell_1)

        reduction_0 = self.reduction_cell_0(cell_3, cell_2)

        cell_6 = self.cell_6(reduction_0, cell_3)
        cell_7 = self.cell_7(cell_6, reduction_0)
        cell_8 = self.cell_8(cell_7, cell_6)
        cell_9 = self.cell_9(cell_8, cell_7)

        reduction_1 = self.reduction_cell_1(cell_9, cell_8)

        cell_12 = self.cell_12(reduction_1, cell_9)
        cell_13 = self.cell_13(cell_12, reduction_1)
        cell_14 = self.cell_14(cell_13, cell_12)
        return self.cell_15(cell_14, cell_13)

    def logits(self, features):
        """Map a feature tensor to class scores.

        Applies ReLU, average pooling, flattening to ``(batch, -1)``, dropout
        and the final linear layer.
        """
        pooled = self.avg_pool(self.relu(features))
        flat = pooled.view(pooled.size(0), -1)
        return self.last_linear(self.dropout(flat))

    def forward(self, input):
        """Return class scores for ``input`` (``features`` then ``logits``)."""
        return self.logits(self.features(input))
def nasnetamobile(num_classes=None, pretrained='imagenet'):
    r"""NASNet-A Mobile model architecture from the
    `"NASNet" <https://arxiv.org/abs/1707.07012>`_ paper.

    Args:
        num_classes: size of the classifier output. When omitted it resolves
            to the checkpoint's class count if ``pretrained`` is set, and to
            1001 (the historical default) otherwise. With the previous
            default of 1001, a call with no arguments always failed the
            class-count check against the 1000-class checkpoint.
        pretrained: key into ``pretrained_settings['nasnetamobile']`` naming
            the weights to download, or a falsy value for an untrained model.

    Returns:
        A ``NASNetAMobile`` instance carrying the preprocessing metadata
        attributes ``input_space``, ``input_size``, ``input_range``, ``mean``
        and ``std``.

    Raises:
        KeyError: if ``pretrained`` is not a known weight set.
        AssertionError: if an explicit ``num_classes`` disagrees with the
            selected checkpoint.
    """
    if pretrained:
        settings = pretrained_settings['nasnetamobile'][pretrained]
        if num_classes is None:
            num_classes = settings['num_classes']
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)

        model = NASNetAMobile(num_classes=num_classes)
        model.load_state_dict(model_zoo.load_url(settings['url'], map_location=None))
    else:
        # Untrained models still advertise the ImageNet preprocessing values.
        settings = pretrained_settings['nasnetamobile']['imagenet']
        if num_classes is None:
            num_classes = 1001
        model = NASNetAMobile(num_classes=num_classes)

    # Attach preprocessing metadata so callers can build a matching pipeline.
    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']
    return model
if __name__ == "__main__":
    # Smoke test: push a random batch of two 224x224 RGB images through an
    # untrained model and print the shape of the result.
    model = NASNetAMobile()
    batch_shape = (2, 3, 224, 224)
    input = Variable(torch.randn(*batch_shape))
    output = model(input)
    print(output.size())